go home Home | Main Page | Topics | Namespace List | Class Hierarchy | Alphabetical List | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages
Loading...
Searching...
No Matches
itkImpactModelConfiguration.h
Go to the documentation of this file.
1/*=========================================================================
2 *
3 * Copyright UMC Utrecht and contributors
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19#ifndef itkImpactModelConfiguration_h
20#define itkImpactModelConfiguration_h
21
#include <torch/script.h>
#include <torch/torch.h>

// Standard C++ header files:
#include <algorithm> // For transform.
#include <cmath>     // For llround.
#include <cstddef>   // For size_t.
#include <cstdint>   // For int16_t, int64_t.
#include <memory>    // For shared_ptr.
#include <sstream>
#include <stdexcept> // For invalid_argument.
#include <string>
#include <utility>   // For move.
#include <vector>

#include "itkStatisticsImageFilter.h"
32
/**
 * Formats the elements of a vector as a parenthesized, space-separated list,
 * e.g. {1, 2, 3} -> "(1 2 3)" and {} -> "()".
 *
 * Each element is written with operator<< of std::ostream, so bool elements
 * print as 1/0 and floating-point values use the default stream precision.
 *
 * \param vec  Elements to format.
 * \return The formatted string.
 */
template <typename T>
std::string
GetStringFromVector(const std::vector<T> & vec)
{
  std::ostringstream ss;
  ss << '(';
  // A std::size_t index avoids the signed/unsigned comparison (and int truncation) of an int index.
  for (std::size_t i = 0; i < vec.size(); ++i)
  {
    // Emit the separator before every element except the first, so no trailing space is produced.
    if (i != 0)
    {
      ss << ' ';
    }
    ss << vec[i];
  }
  ss << ')';
  return ss.str();
} // end GetStringFromVector
53
54
55namespace itk
56{
65{
66public:
67 ImpactModelConfiguration(std::string modelPath,
68 unsigned int dimension,
69 unsigned int numberOfChannels,
70 std::vector<unsigned int> patchSize,
71 std::vector<float> voxelSize,
72 std::vector<bool> layersMask,
73 bool isStatic,
74 bool useMixedPrecision)
75 : m_ModelPath(modelPath)
76 , m_Dimension(dimension)
77 , m_NumberOfChannels(numberOfChannels)
78 , m_PatchSize(patchSize.begin(), patchSize.end())
79 , m_VoxelSize(voxelSize)
80 , m_LayersMask(layersMask)
81 , m_DataType(useMixedPrecision ? torch::kFloat16 : torch::kFloat32)
82 {
83 m_Model = std::make_shared<torch::jit::script::Module>(torch::jit::load(m_ModelPath, torch::Device(torch::kCPU)));
84 m_Model->eval();
86 m_nArgs = m_Model->get_method("forward").function().getSchema().arguments().size() - 1;
87 m_nLayers = torch::tensor(static_cast<int64_t>(m_LayersMask.size()), torch::kInt16);
88
89 if (!isStatic)
90 {
92 m_PatchIndex.clear();
93 if (m_PatchSize.size() == 2)
94 {
95 for (int y = 0; y < m_PatchSize[1]; ++y)
96 {
97 for (int x = 0; x < m_PatchSize[0]; ++x)
98 {
99 m_PatchIndex.push_back(
100 { (x - m_PatchSize[0] / 2) * m_VoxelSize[0], (y - m_PatchSize[1] / 2) * m_VoxelSize[1] });
101 }
102 }
103 }
104 else
105 {
106 for (int z = 0; z < m_PatchSize[2]; ++z)
107 {
108 for (int y = 0; y < m_PatchSize[1]; ++y)
109 {
110 for (int x = 0; x < m_PatchSize[0]; ++x)
111 {
112 m_PatchIndex.push_back({ (x - m_PatchSize[0] / 2) * m_VoxelSize[0],
113 (y - m_PatchSize[1] / 2) * m_VoxelSize[1],
114 (z - m_PatchSize[2] / 2) * m_VoxelSize[2] });
115 }
116 }
117 }
118 }
119 }
120 }
121
122 // Disable (delete) copying, to avoid having multiple copies of the same model:
126
127 // Enable (default) move semantics:
131
132 // Destructor.
134
135 bool
142
143 friend std::ostream &
144 operator<<(std::ostream & os, const ImpactModelConfiguration & config)
145 {
146 os << "\t\tPath : " << config.m_ModelPath << "\n\t\tDimension : " << config.m_Dimension
147 << "\n\t\tNumberOfChannels : " << config.m_NumberOfChannels
148 << "\n\t\tPatchSize : " << GetStringFromVector<int64_t>(config.m_PatchSize)
149 << "\n\t\tVoxelSize : " << GetStringFromVector<float>(config.m_VoxelSize)
150 << "\n\t\tLayersMask : " << GetStringFromVector<bool>(config.m_LayersMask);
151 return os;
152 }
153
154 const std::string &
156 {
157 return m_ModelPath;
158 }
159
160 const torch::ScalarType &
162 {
163 return m_DataType;
164 }
165
166 unsigned int
168 {
169 return m_Dimension;
170 }
171 unsigned int
173 {
174 return m_NumberOfChannels;
175 }
176 const std::vector<int64_t> &
178 {
179 return m_PatchSize;
180 }
181 const std::vector<float> &
183 {
184 return m_VoxelSize;
185 }
186 const std::vector<bool> &
188 {
189 return m_LayersMask;
190 }
191
192 void
193 to(torch::Device device) const
194 {
195 m_Model->to(device);
196 }
197
198 template <class TImage>
199 void
200 setup(typename TImage::ConstPointer image)
201 {
202 auto imageStats = itk::StatisticsImageFilter<TImage>::New();
203 imageStats->SetInput(image);
204 imageStats->Update();
205
206 torch::Tensor imageStatsTensor = torch::tensor({ static_cast<float>(imageStats->GetMinimum()),
207 static_cast<float>(imageStats->GetMaximum()),
208 static_cast<float>(imageStats->GetMean()),
209 static_cast<float>(imageStats->GetSigma()) },
210 torch::kFloat32);
211
212 const auto & imageDirection = image->GetDirection(); // itk::Matrix<double,TImage::Dimension,TImage::Dimension>
213
214 constexpr unsigned int imageDimension = TImage::ImageDimension;
215 torch::Tensor imageDirectionTensor = torch::empty({ imageDimension, imageDimension }, torch::kInt16);
216
217 for (unsigned int r = 0; r < imageDimension; ++r)
218 {
219 for (unsigned int c = 0; c < imageDimension; ++c)
220 {
221 imageDirectionTensor[r][c] = static_cast<int16_t>(std::llround(imageDirection(r, c)));
222 }
223 }
224 m_imageStatsTensor = imageStatsTensor;
225 m_imageDirectionTensor = imageDirectionTensor;
226 }
227
228 std::vector<torch::jit::IValue>
229 forward(torch::Tensor inputPatch) const
230 {
231
232 std::vector<torch::jit::IValue> args;
233 args.reserve(m_nArgs);
234 args.emplace_back(inputPatch);
235
236 if (m_nArgs >= 2)
237 { // number of requested layers (retrocompatible models may not have it)
238 args.emplace_back(m_nLayers);
239 }
240
241 if (m_nArgs >= 4)
242 { // Arg 2-3: optional image metadata (image stats + direction)
243 args.emplace_back(m_imageStatsTensor);
244 args.emplace_back(m_imageDirectionTensor);
245 }
246
247 return m_Model->forward(std::move(args)).toList().vec();
248 }
249
250 const std::vector<std::vector<float>> &
252 {
253 return m_PatchIndex;
254 }
255 const std::vector<std::vector<torch::indexing::TensorIndex>> &
257 {
259 }
260 void
261 SetCentersIndexLayers(std::vector<std::vector<torch::indexing::TensorIndex>> & centersIndexLayers)
262 {
263 m_CentersIndexLayers = centersIndexLayers;
264 }
265
266
267private:
268 std::string m_ModelPath;
269 unsigned int m_Dimension;
270 unsigned int m_NumberOfChannels;
271 std::vector<int64_t> m_PatchSize;
272 std::vector<float> m_VoxelSize;
273 std::vector<bool> m_LayersMask;
274 std::shared_ptr<torch::jit::script::Module> m_Model;
275 std::vector<std::vector<float>> m_PatchIndex;
276 std::vector<std::vector<torch::indexing::TensorIndex>> m_CentersIndexLayers;
277 torch::ScalarType m_DataType;
278 torch::Tensor m_imageStatsTensor;
280 std::size_t m_nArgs;
281 torch::Tensor m_nLayers;
282};
283
284
285} // end namespace itk
286
287#endif // end #ifndef itkImpactModelConfiguration_h
std::string GetStringFromVector(const std::vector< T > &vec)
const std::vector< int64_t > & GetPatchSize() const
std::vector< torch::jit::IValue > forward(torch::Tensor inputPatch) const
ImpactModelConfiguration(std::string modelPath, unsigned int dimension, unsigned int numberOfChannels, std::vector< unsigned int > patchSize, std::vector< float > voxelSize, std::vector< bool > layersMask, bool isStatic, bool useMixedPrecision)
friend std::ostream & operator<<(std::ostream &os, const ImpactModelConfiguration &config)
const torch::ScalarType & GetDataType() const
ImpactModelConfiguration & operator=(ImpactModelConfiguration &&)=default
ImpactModelConfiguration & operator=(const ImpactModelConfiguration &)=delete
const std::vector< bool > & GetLayersMask() const
std::vector< std::vector< torch::indexing::TensorIndex > > m_CentersIndexLayers
void to(torch::Device device) const
std::shared_ptr< torch::jit::script::Module > m_Model
void setup(typename TImage::ConstPointer image)
const std::string & GetModelPath() const
void SetCentersIndexLayers(std::vector< std::vector< torch::indexing::TensorIndex > > &centersIndexLayers)
ImpactModelConfiguration(ImpactModelConfiguration &&)=default
ImpactModelConfiguration(const ImpactModelConfiguration &)=delete
std::vector< std::vector< float > > m_PatchIndex
const std::vector< std::vector< torch::indexing::TensorIndex > > & GetCentersIndexLayers() const
const std::vector< float > & GetVoxelSize() const
const std::vector< std::vector< float > > & GetPatchIndex() const
bool operator==(const ImpactModelConfiguration &rhs) const


Generated on 1774142652 for elastix by doxygen 1.15.0 elastix logo