RTK  2.6.0
Reconstruction Toolkit
rtkSchlomka2008NegativeLogLikelihood.h
Go to the documentation of this file.
1 /*=========================================================================
2  *
3  * Copyright RTK Consortium
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  * https://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #ifndef rtkSchlomka2008NegativeLogLikelihood_h
20 #define rtkSchlomka2008NegativeLogLikelihood_h
21 
23 #include "rtkMacro.h"
24 
25 #include <itkVectorImage.h>
27 #include <itkVariableSizeMatrix.h>
28 
29 namespace rtk
30 {
44 // We have to define the cost function first
46 {
47 public:
  // Copy and move construction/assignment are disabled for this class.
  ITK_DISALLOW_COPY_AND_MOVE(Schlomka2008NegativeLogLikelihood);

  /** Standard ITK object-factory creation method (New()). */
  itkNewMacro(Self);

  /** Provides the GetNameOfClass() override. */
  itkOverrideGetNameOfClassMacro(Schlomka2008NegativeLogLikelihood);

  // NOTE(review): the Self/Superclass/ParametersType/MeasureType type aliases used by
  // this class are not visible in this view (the original numbering skips them);
  // confirm against the full header.

  // Constructor
  // NOTE(review): the constructor declaration itself is not visible in this view.

  // Destructor
  ~Schlomka2008NegativeLogLikelihood() override = default;
75 
  /**
   * Builds m_IncidentSpectrumAndDetectorResponseProduct from m_DetectorResponse and
   * m_IncidentSpectrum. GetDerivative and ComputeFischerMatrix multiply vectors of
   * per-energy terms by this combined matrix to obtain per-bin quantities.
   */
  void
  Initialize() override
  {
    // This method computes the combined m_IncidentSpectrumAndDetectorResponseProduct
    // from m_DetectorResponse and m_IncidentSpectrum

    // In spectral CT, m_DetectorResponse has as many rows as the number of bins,
    // and m_IncidentSpectrum has only one row (there is only one spectrum illuminating
    // the object)
    // NOTE(review): the statement preceding these loops and the body of the inner loop
    // appear to be missing from this view (the original line numbering skips them), so
    // as shown the nested loops have no body. Presumably they size the product matrix
    // and fill each element from m_DetectorResponse and m_IncidentSpectrum — restore
    // from the full header before relying on this.
    for (unsigned int i = 0; i < m_DetectorResponse.rows(); i++)
      for (unsigned int j = 0; j < m_DetectorResponse.cols(); j++)
  }
90 
91  // Not used with a simplex optimizer, but may be useful later
92  // for gradient based methods
93  void
94  GetDerivative(const ParametersType & lineIntegrals, DerivativeType & derivatives) const override
95  {
96  // Set the size of the derivatives vector
97  derivatives.set_size(m_NumberOfMaterials);
98 
99  // Get some required data
100  vnl_vector<double> attenuationFactors;
101  attenuationFactors.set_size(this->m_NumberOfEnergies);
102  GetAttenuationFactors(lineIntegrals, attenuationFactors);
103  vnl_vector<double> lambdas = ForwardModel(lineIntegrals);
104 
105  // Compute the vector of 1 - m_b / lambda_b
106  vnl_vector<double> weights;
107  weights.set_size(m_NumberOfSpectralBins);
108  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
109  weights[i] = 1 - (m_MeasuredData[i] / lambdas[i]);
110 
111  // Prepare intermediate variables
112  vnl_vector<double> intermediate_a;
113  vnl_vector<double> partial_derivative_a;
114 
115  for (unsigned int a = 0; a < m_NumberOfMaterials; a++)
116  {
117  // Compute the partial derivatives of lambda_b with respect to the material line integrals
118  intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
119  partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
120 
121  // Multiply them together element-wise, then dot product with the weights
122  derivatives[a] = dot_product(partial_derivative_a, weights);
123  }
124  }
125 
126  // Main method
128  GetValue(const ParametersType & parameters) const override
129  {
130  // Forward model: compute the expected number of counts in each bin
131  vnl_vector<double> forward = ForwardModel(parameters);
132 
133  long double measure = 0;
134  // Compute the negative log likelihood from the lambdas
135  for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
136  measure += forward[i] - std::log((long double)forward[i]) * m_MeasuredData[i];
137  return measure;
138  }
139 
  /**
   * Fills m_Fischer, evaluated at lineIntegrals, with
   *   F[a][a'] = sum_b (m_b / lambda_b^2) * (d lambda_b / dA_a) * (d lambda_b / dA_a')
   * where lambda_b are the expected counts (ForwardModel) and m_b the measured counts.
   * The vector of d lambda_b / dA_a over all bins is
   *   m_IncidentSpectrumAndDetectorResponseProduct * (-attenuationFactors .* mu_a)
   * with mu_a the column a of m_MaterialAttenuations.
   *
   * NOTE(review): the original line numbering skips one line just before the loops;
   * presumably it sizes m_Fischer to m_NumberOfMaterials x m_NumberOfMaterials — it is
   * not visible here, so confirm against the full header.
   * NOTE(review): intermediate_a / partial_derivative_a do not depend on a_prime, yet
   * they are recomputed on every inner iteration. The matrix is also symmetric in
   * (a, a_prime). Both could be exploited to cut the matrix-vector products.
   * NOTE(review): weights[i] is 0 / 0 = NaN when both m_b and lambda_b are 0.
   */
  void
  ComputeFischerMatrix(const ParametersType & lineIntegrals) override
  {
    // Get some required data
    vnl_vector<double> attenuationFactors;
    attenuationFactors.set_size(this->m_NumberOfEnergies);
    GetAttenuationFactors(lineIntegrals, attenuationFactors);
    vnl_vector<double> lambdas = ForwardModel(lineIntegrals);

    // Compute the vector of m_b / lambda_b^2
    vnl_vector<double> weights;
    weights.set_size(m_NumberOfSpectralBins);
    for (unsigned int i = 0; i < m_NumberOfSpectralBins; i++)
      weights[i] = m_MeasuredData[i] / (lambdas[i] * lambdas[i]);

    // Prepare intermediate variables
    vnl_vector<double> intermediate_a;
    vnl_vector<double> intermediate_a_prime;
    vnl_vector<double> partial_derivative_a;
    vnl_vector<double> partial_derivative_a_prime;

    // Compute the Fischer information matrix
    for (unsigned int a = 0; a < m_NumberOfMaterials; a++)
    {
      for (unsigned int a_prime = 0; a_prime < m_NumberOfMaterials; a_prime++)
      {
        // Compute the partial derivatives of lambda_b with respect to the material line integrals
        intermediate_a = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a));
        intermediate_a_prime = element_product(-attenuationFactors, m_MaterialAttenuations.get_column(a_prime));

        partial_derivative_a = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a;
        partial_derivative_a_prime = m_IncidentSpectrumAndDetectorResponseProduct * intermediate_a_prime;

        // Multiply them together element-wise, then dot product with the weights
        // (partial_derivative_a_prime is reused to hold the element-wise product)
        partial_derivative_a_prime = element_product(partial_derivative_a, partial_derivative_a_prime);
        m_Fischer[a][a_prime] = dot_product(partial_derivative_a_prime, weights);
      }
    }
  }
180 };
181 
182 } // namespace rtk
183 
184 #endif
virtual vnl_vector< double > ForwardModel(const ParametersType &lineIntegrals) const
void ComputeFischerMatrix(const ParametersType &lineIntegrals) override
void GetAttenuationFactors(const ParametersType &lineIntegrals, vnl_vector< double > &attenuationFactors) const
Superclass::ParametersType ParametersType
void GetDerivative(const ParametersType &lineIntegrals, DerivativeType &derivatives) const override
MeasureType GetValue(const ParametersType &parameters) const override
~Schlomka2008NegativeLogLikelihood() override=default