RTK  2.6.0
Reconstruction Toolkit
rtkSchlomka2008NegativeLogLikelihood.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright RTK Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * https://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #ifndef rtkSchlomka2008NegativeLogLikelihood_h
20 #define rtkSchlomka2008NegativeLogLikelihood_h
21 
23 #include "rtkMacro.h"
24 
25 #include <itkVectorImage.h>
27 #include <itkVariableSizeMatrix.h>
28 
29 namespace rtk
30 {
44 // Spectral CT cost function (class name refers to Schlomka 2008): negative log-likelihood of the counts measured in each energy bin, as a function of the material line integrals.
46 {
47 public:
// Copying and moving this cost function are disabled.
48  ITK_DISALLOW_COPY_AND_MOVE(Schlomka2008NegativeLogLikelihood);
49
54
// ITK-style creation method; presumably provides the usual Self::New() factory.
56  itkNewMacro(Self);
57
// Run-time class name: uses itkOverrideGetNameOfClassMacro when the ITK version defines it,
// otherwise the #else branch (its content is not visible in this listing — confirm).
59 #ifdef itkOverrideGetNameOfClassMacro
60  itkOverrideGetNameOfClassMacro(Schlomka2008NegativeLogLikelihood);
61 #else
63 #endif
64
65
69
74
75  // Constructor
// NOTE(review): the constructor declaration following this comment is not visible in this listing.
77
78  // Destructor
79  ~Schlomka2008NegativeLogLikelihood() override = default;
80
// Precomputes m_IncidentSpectrumAndDetectorResponseProduct, the bins x energies matrix that
// GetDerivative() and ComputeFischerMatrix() multiply with per-energy vectors.
81  void
82  Initialize() override
83  {
84  // This method computes the combined m_IncidentSpectrumAndDetectorResponseProduct
85  // from m_DetectorResponse and m_IncidentSpectrum
86
87  // In spectral CT, m_DetectorResponse has as many rows as the number of bins,
88  // and m_IncidentSpectrum has only one row (there is only one spectrum illuminating
89  // the object)
// NOTE(review): the statement sizing the product matrix and the body of the inner loop are not
// visible in this listing; presumably entry [i][j] combines m_DetectorResponse[i][j] with
// m_IncidentSpectrum[0][j] — confirm against the full source.
91  for (unsigned int i = 0; i < m_DetectorResponse.rows(); i++)
92  for (unsigned int j = 0; j < m_DetectorResponse.cols(); j++)
94  }
95 
96  // Not used with a simplex optimizer, but may be useful later
97  // for gradient based methods
98  void
99  GetDerivative(const ParametersType & lineIntegrals, DerivativeType & derivatives) const override
100  {
101  // Set the size of the derivatives vector
102  derivatives.set_size(m_NumberOfMaterials);
103 
104  // Get some required data
105  vnl_vector<double> attenuationFactors;
106  attenuationFactors.set_size(this->m_NumberOfEnergies);
107  GetAttenuationFactors(lineIntegrals, attenuationFactors);
108  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
109 
110  // Compute the vector of 1 - m_b / lambda_b
111  vnl_vector<double> weights;
112  weights.set_size(m_NumberOfSpectralBins);
113  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
114  weights[i] = 1 - (m_MeasuredData[i] / lambdas[i]);
115 
116  // Prepare intermediate variables
117  vnl_vector<double> intermediate_a;
118  vnl_vector<double> partial_derivative_a;
119 
120  for (unsigned int a = 0; a < m_NumberOfMaterials; a++)
121  {
122  // Compute the partial derivatives of lambda_b with respect to the material line integrals
123  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
124  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
125 
126  // Multiply them together element-wise, then dot product with the weights
127  derivatives[a] = dot_product(partial_derivative_a, weights);
128  }
129  }
130 
131  // Main method
133  GetValue(const ParametersType & parameters) const override
134  {
135  // Forward model: compute the expected number of counts in each bin
136  vnl_vector<double> forward = ForwardModel(parameters);
137 
138  long double measure = 0;
139  // Compute the negative log likelihood from the lambdas
140  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
141  measure += forward[i] - std::log((long double)forward[i]) * m_MeasuredData[i];
142  return measure;
143  }
144 
145  void
146  ComputeFischerMatrix(const ParametersType & lineIntegrals) override
147  {
148  // Get some required data
149  vnl_vector<double> attenuationFactors;
150  attenuationFactors.set_size(this->m_NumberOfEnergies);
151  GetAttenuationFactors(lineIntegrals, attenuationFactors);
152  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
153 
154  // Compute the vector of m_b / lambda_b^2
155  vnl_vector<double> weights;
156  weights.set_size(m_NumberOfSpectralBins);
157  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
158  weights[i] = m_MeasuredData[i] / (lambdas[i] * lambdas[i]);
159 
160  // Prepare intermediate variables
161  vnl_vector<double> intermediate_a;
162  vnl_vector<double> intermediate_a_prime;
163  vnl_vector<double> partial_derivative_a;
164  vnl_vector<double> partial_derivative_a_prime;
165 
166  // Compute the Fischer information matrix
168  for (unsigned int a = 0; a < m_NumberOfMaterials; a++)
169  {
170  for (unsigned int a_prime = 0; a_prime < m_NumberOfMaterials; a_prime++)
171  {
172  // Compute the partial derivatives of lambda_b with respect to the material line integrals
173  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
174  intermediate_a_prime = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a_prime));
175 
176  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
177  partial_derivative_a_prime = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a_prime;
178 
179  // Multiply them together element-wise, then dot product with the weights
180  partial_derivative_a_prime = element_product(partial_derivative_a, partial_derivative_a_prime);
181  m_Fischer[a][a_prime] = dot_product(partial_derivative_a_prime, weights);
182  }
183  }
184  }
185 };
186 
187 } // namespace rtk
188 
189 #endif
virtual vnl_vector< double > ForwardModel(const ParametersType &lineIntegrals) const
void ComputeFischerMatrix(const ParametersType &lineIntegrals) override
void GetAttenuationFactors(const ParametersType &lineIntegrals, vnl_vector< double > &attenuationFactors) const
void GetDerivative(const ParametersType &lineIntegrals, DerivativeType &derivatives) const override
MeasureType GetValue(const ParametersType &parameters) const override
~Schlomka2008NegativeLogLikelihood() override=default