Molassembler  1.0.0
Molecule graph and conformer library
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
NelderMead.h
Go to the documentation of this file.
1 
8 #ifndef INCLUDE_TEMPLE_OPTIMIZATION_NELDER_MEAD_H
9 #define INCLUDE_TEMPLE_OPTIMIZATION_NELDER_MEAD_H
10 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

#include <Eigen/Core>
13 
14 namespace Scine {
15 namespace Molassembler {
16 namespace Temple {
17 
23 template<typename FloatType = double>
24 struct NelderMead {
25  using MatrixType = Eigen::Matrix<FloatType, Eigen::Dynamic, Eigen::Dynamic>;
26  using VectorType = Eigen::Matrix<FloatType, Eigen::Dynamic, 1>;
27 
31  unsigned iterations;
33  FloatType value;
35  unsigned minimalIndex;
36  };
37 
38  static VectorType generateVertex(
39  const FloatType coefficient,
40  const VectorType& centroid,
41  const VectorType& worstVertex
42  ) {
43  return centroid + coefficient * (centroid - worstVertex);
44  }
45 
46  struct IndexValuePair {
47  unsigned column;
48  FloatType value;
49 
50  bool operator < (const IndexValuePair& other) const {
51  return value < other.value;
52  }
53  };
54 
55  static FloatType valueStandardDeviation(const std::vector<IndexValuePair>& sortedPairs) {
56  const unsigned V = sortedPairs.size();
57  // Calculate standard deviation of values
58  const FloatType average = Temple::accumulate(
59  sortedPairs,
60  FloatType {0},
61  [](const FloatType carry, const IndexValuePair& pair) -> FloatType {
62  return carry + pair.value;
63  }
64  ) / V;
65  return std::sqrt(
67  sortedPairs,
68  FloatType {0},
69  [average](const FloatType carry, const IndexValuePair& pair) -> FloatType {
70  const FloatType diff = pair.value - average;
71  return carry + diff * diff;
72  }
73  ) / V
74  );
75  }
76 
78  static VectorType centroid(
79  const MatrixType& vertices,
80  const std::vector<IndexValuePair>& pairs
81  ) {
82  return (vertices.rowwise().sum() - vertices.col(pairs.back().column)) / (vertices.cols() - 1);
83  }
84 
85  template<typename UpdateFunction>
86  static void shrink(
87  Eigen::Ref<MatrixType> vertices,
88  std::vector<IndexValuePair>& pairs,
89  UpdateFunction&& function
90  ) {
91  constexpr FloatType shrinkCoefficient = 0.5;
92  static_assert(0 < shrinkCoefficient && shrinkCoefficient < 1, "Shrink coefficient bounds not met");
93 
94  const unsigned V = vertices.cols();
95  const unsigned bestColumn = pairs.front().column;
96  const auto& bestVertex = vertices.col(bestColumn);
97 
98  // Shrink all points besides the best one and recalculate function values
99  for(unsigned i = 1; i < V; ++i) {
100  auto& value = pairs.at(i);
101  auto vertex = vertices.col(value.column);
102  vertex = bestVertex + shrinkCoefficient * (vertex - bestVertex);
103  value.value = function(vertex);
104  }
105  Temple::sort(pairs);
106  }
107 
108  static void replaceWorst(
109  std::vector<IndexValuePair>& sortedPairs,
110  const VectorType& newVertex,
111  const FloatType newValue,
112  Eigen::Ref<MatrixType> vertices
113  ) {
114  assert(std::is_sorted(std::begin(sortedPairs), std::end(sortedPairs)));
115  IndexValuePair replacementPair {
116  sortedPairs.back().column,
117  newValue
118  };
119  // Insert the replacement into values, keeping ordering
120  sortedPairs.insert(
121  std::lower_bound(
122  std::begin(sortedPairs),
123  std::end(sortedPairs),
124  replacementPair
125  ),
126  replacementPair
127  );
128  // Drop the worst value
129  sortedPairs.pop_back();
130  // Replace the worst column
131  vertices.col(replacementPair.column) = newVertex;
132  }
133 
134  template<
135  typename UpdateFunction,
136  typename Checker
137  > static OptimizationReturnType minimize(
138  Eigen::Ref<MatrixType> vertices,
139  UpdateFunction&& function,
140  Checker&& check
141  ) {
142  constexpr FloatType reflectionCoefficient = 1;
143  constexpr FloatType expansionCoefficient = 2;
144  constexpr FloatType contractionCoefficient = 0.5;
145 
146  static_assert(0 < reflectionCoefficient, "Reflection coefficient bounds not met");
147  static_assert(1 < expansionCoefficient, "Expansion coefficient bounds not met");
148  static_assert(0 < contractionCoefficient && contractionCoefficient <= 0.5, "Contraction coefficient bounds not met");
149 
150  const unsigned N = vertices.rows();
151  assert(vertices.cols() == N + 1);
152 
153  std::vector<IndexValuePair> values = Temple::sorted(
154  Temple::map(
155  Temple::iota<unsigned>(N + 1),
156  [&](const unsigned i) -> IndexValuePair {
157  return {
158  i,
159  function(vertices.col(i))
160  };
161  }
162  )
163  );
164 
165  FloatType standardDeviation;
166  unsigned iteration = 0;
167  do {
168  const VectorType simplexCentroid = centroid(vertices, values);
169  const VectorType& worstVertex = vertices.col(values.back().column);
170  const FloatType worstVertexValue = values.back().value;
171  const FloatType bestVertexValue = values.front().value;
172 
173  /* Reflect */
174  const VectorType reflectedVertex = generateVertex(reflectionCoefficient, simplexCentroid, worstVertex);
175  const FloatType reflectedValue = function(reflectedVertex);
176 
177  if(reflectedValue < bestVertexValue) {
178  /* Expansion */
179  const VectorType expandedVertex = generateVertex(expansionCoefficient, simplexCentroid, worstVertex);
180  const FloatType expandedValue = function(expandedVertex);
181 
182  if(expandedValue < reflectedValue) {
183  // Replace the worst value with the expanded point
184  replaceWorst(values, expandedVertex, expandedValue, vertices);
185  } else {
186  // Replace the worst value with the reflected point
187  replaceWorst(values, reflectedVertex, reflectedValue, vertices);
188  }
189  } else if(bestVertexValue <= reflectedValue && reflectedValue < values.at(N - 1).value) {
190  replaceWorst(values, reflectedVertex, reflectedValue, vertices);
191  } else if(values.at(N - 1).value <= reflectedValue && reflectedValue < worstVertexValue) {
192  /* Outside contraction */
193  const VectorType outsideContractedVertex = generateVertex(reflectionCoefficient * contractionCoefficient, simplexCentroid, worstVertex);
194  const FloatType outsideContractedValue = function(outsideContractedVertex);
195  if(outsideContractedValue <= reflectedValue) {
196  replaceWorst(values, outsideContractedVertex, outsideContractedValue, vertices);
197  } else {
198  shrink(vertices, values, function);
199  }
200  } else {
201  /* Inside contraction */
202  const VectorType insideContractedVertex = generateVertex(-contractionCoefficient, simplexCentroid, worstVertex);
203  const FloatType insideContractedValue = function(insideContractedVertex);
204  if(insideContractedValue < worstVertexValue) {
205  replaceWorst(values, insideContractedVertex, insideContractedValue, vertices);
206  } else {
207  shrink(vertices, values, function);
208  }
209  }
210 
211  standardDeviation = valueStandardDeviation(values);
212  ++iteration;
213  } while(check.shouldContinue(iteration, values.front().value, standardDeviation));
214 
215  return {
216  iteration,
217  values.front().value,
218  values.front().column
219  };
220  }
221 };
222 
223 } // namespace Temple
224 } // namespace Molassembler
225 } // namespace Scine
226 
227 #endif
void sort(Container &container)
Calls std::sort on a container.
Definition: Functional.h:271
Type returned from an optimization.
Definition: NelderMead.h:29
static VectorType centroid(const MatrixType &vertices, const std::vector< IndexValuePair > &pairs)
Calculates the simplex centroid excluding the worst vertex.
Definition: NelderMead.h:78
unsigned iterations
Number of iterations.
Definition: NelderMead.h:31
T accumulate(const Container &container, T init, BinaryFunction &&reductionFunction)
Accumulate shorthand.
Definition: Functional.h:226
unsigned minimalIndex
Simplex vertex with minimal function value.
Definition: NelderMead.h:35
Nelder-Mead optimization.
Definition: NelderMead.h:24
constexpr std::enable_if_t< std::is_floating_point< Traits::getValueType< ContainerType > >::value, Traits::getValueType< ContainerType >> average(const ContainerType &container)
Definition: Numeric.h:64
constexpr auto map(const ArrayType< T, size > &array, UnaryFunction &&function)
Maps all elements of any array-like container with a unary function.
Definition: Containers.h:62
Functional-style container-related algorithms.
FloatType value
Final function value.
Definition: NelderMead.h:33