Molassembler  3.0.0
Molecule graph and conformer library
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Pages
SO3NelderMead.h
Go to the documentation of this file.
1 
8 #ifndef INCLUDE_TEMPLE_OPTIMIZATION_SO3_NELDER_MEAD_H
9 #define INCLUDE_TEMPLE_OPTIMIZATION_SO3_NELDER_MEAD_H
10 
11 #include <Eigen/Core>
12 #include <Eigen/Eigenvalues>
13 #include <unsupported/Eigen/MatrixFunctions>
16 
17 namespace Scine {
18 namespace Molassembler {
19 namespace Temple {
20 
26 template<typename FloatType = double>
27 struct SO3NelderMead {
28  using Matrix = Eigen::Matrix<FloatType, 3, 3>;
30  struct Parameters {
31  Eigen::Matrix<FloatType, 3, 12> matrix;
32 
33  decltype(auto) at(const unsigned i) {
34  assert(i < 4);
35  return matrix.template block<3, 3>(0, 3 * i);
36  }
37 
38  decltype(auto) at(const unsigned i) const {
39  assert(i < 4);
40  return matrix.template block<3, 3>(0, 3 * i);
41  }
42 
43  EIGEN_MAKE_ALIGNED_OPERATOR_NEW
44  };
45 
49  unsigned iterations;
51  FloatType value;
53  unsigned minimalIndex;
54  };
55 
56  struct Manifold {
57  // Returns the skew-symmetric parts of m
58  template<typename Derived>
59  static Matrix skew(const Eigen::MatrixBase<Derived>& m) {
60  return 0.5 * (m - m.transpose());
61  }
62 
63  template<typename DerivedA, typename DerivedB>
64  static Matrix log(const Eigen::MatrixBase<DerivedA>& X, const Eigen::MatrixBase<DerivedB>& Y) {
65  return skew((X.transpose() * Y).log().real());
66  }
67 
68  template<typename DerivedA, typename DerivedB>
69  static Matrix exp(const Eigen::MatrixBase<DerivedA>& X, const Eigen::MatrixBase<DerivedB>& Y) {
70  return X * (Y.exp());
71  }
72 
73  template<typename DerivedA, typename DerivedB>
74  static FloatType distanceSquared(const Eigen::MatrixBase<DerivedA>& X, const Eigen::MatrixBase<DerivedB>& Y) {
75  return FloatType {0.5} * log(X, Y).squaredNorm();
76  }
77 
78  template<typename DerivedA, typename DerivedB>
79  static Matrix geodesic(const Eigen::MatrixBase<DerivedA>& a, const Eigen::MatrixBase<DerivedB>& b, const FloatType tau) {
80  return exp(b, tau * log(b, a));
81  }
82 
83  template<typename Derived>
84  static bool contains(const Eigen::MatrixBase<Derived>& m) {
85  return (m * m.transpose()).isApprox(Matrix::Identity(), 1e-5);
86  }
87 
88  static Matrix karcherMean(const Parameters& points, const unsigned excludeIdx) {
89  auto calculateOmega = [excludeIdx](const Parameters& p, const Matrix& speculativeMean) {
90  Matrix omega = Matrix::Zero();
91  for(unsigned i = 0; i < 4; ++i) {
92  if(i == excludeIdx) {
93  continue;
94  }
95 
96  omega += Manifold::log(speculativeMean, p.at(i));
97  }
98  omega /= 4;
99  return omega;
100  };
101 
102  constexpr FloatType delta = 1e-5;
103  Matrix q = (excludeIdx == 0) ? points.at(1) : points.at(0);
104  Matrix omega = calculateOmega(points, q);
105 
106  unsigned iterations = 0;
107  while(omega.norm() >= delta && iterations < 100) {
108  q = Manifold::exp(q, omega);
109  assert(q.allFinite());
110  assert(contains(q));
111  omega = calculateOmega(points, q);
112  ++iterations;
113  }
114 
115  return q;
116  }
117 
118  static Matrix randomRotation() {
119  auto A = Matrix::Random();
120  Eigen::ColPivHouseholderQR<Matrix> decomposition(A);
121  Matrix Q = decomposition.householderQ();
122  auto R = decomposition.matrixR();
123 
124  Matrix intermediate = Matrix::Zero();
125  for(unsigned i = 0; i < 3; ++i) {
126  double value = R(i, i);
127  if(value < 0) {
128  intermediate(i, i) = -1;
129  } else if(value > 0) {
130  intermediate(i, i) = 1;
131  }
132  }
133  Q = Q * intermediate;
134 
135  // Now Q is in O(n), but not yet in SO(n), which we can ensure with:
136  if(Q.determinant() < 0) {
137  Q.col(0).swap(Q.col(1));
138  }
139 
140  // Is it really orthogonal?
141  assert(Q.allFinite());
142  assert(contains(Q));
143  return Q;
144  }
145  };
146 
147  static Parameters randomParameters() {
148  constexpr FloatType ballRadiusSquared = M_PI * M_PI;
149  Parameters parameters;
150  parameters.at(0) = Manifold::randomRotation();
151  for(unsigned i = 1; i < 4; ++i) {
152  Matrix R;
153  do {
154  R = Manifold::randomRotation();
155  } while(
157  Temple::iota<unsigned>(i),
158  [&](const unsigned j) -> bool {
159  return Manifold::distanceSquared(R, parameters.at(j)) >= ballRadiusSquared;
160  }
161  )
162  );
163  parameters.at(i) = R;
164  }
165 
166  return parameters;
167  }
168 
169  struct IndexValuePair {
170  unsigned column;
171  FloatType value;
172 
173  bool operator < (const IndexValuePair& other) const {
174  return value < other.value;
175  }
176  };
177 
178  static FloatType valueStandardDeviation(const std::vector<IndexValuePair>& sortedPairs) {
179  const unsigned V = sortedPairs.size();
180  // Calculate standard deviation of values
181  const FloatType average = Temple::accumulate(
182  sortedPairs,
183  FloatType {0},
184  [](const FloatType carry, const IndexValuePair& pair) -> FloatType {
185  return carry + pair.value;
186  }
187  ) / V;
188  return std::sqrt(
190  sortedPairs,
191  FloatType {0},
192  [average](const FloatType carry, const IndexValuePair& pair) -> FloatType {
193  const FloatType diff = pair.value - average;
194  return carry + diff * diff;
195  }
196  ) / V
197  );
198  }
199 
200  template<typename UpdateFunction>
201  static void shrink(
202  Parameters& points,
203  std::vector<IndexValuePair>& values,
204  UpdateFunction&& function
205  ) {
206  constexpr FloatType shrinkCoefficient = 0.5;
207  static_assert(0 < shrinkCoefficient && shrinkCoefficient < 1, "Shrink coefficient bounds not met");
208  const Matrix& bestVertex = points.at(values.front().column);
209 
210  // Shrink all points besides the best one and recalculate function values
211  for(unsigned i = 1; i < 4; ++i) {
212  auto& value = values.at(i);
213  Eigen::Ref<Eigen::Matrix3d> vertex = points.at(value.column);
214  vertex = Manifold::geodesic(vertex, bestVertex, shrinkCoefficient);
215  value.value = function(vertex);
216  // NOTE: No need to worry about ball radius in shrink operation
217  }
218  Temple::sort(values);
219  }
220 
221  static void replaceWorst(
222  std::vector<IndexValuePair>& sortedPairs,
223  const Matrix& newVertex,
224  const FloatType newValue,
225  Parameters& vertices
226  ) {
227  assert(std::is_sorted(std::begin(sortedPairs), std::end(sortedPairs)));
228  IndexValuePair replacementPair {
229  sortedPairs.back().column,
230  newValue
231  };
232  sortedPairs.insert(
233  std::lower_bound(
234  std::begin(sortedPairs),
235  std::end(sortedPairs),
236  replacementPair
237  ),
238  replacementPair
239  );
240  // Drop the worst value
241  sortedPairs.pop_back();
242  // Replace the worst column
243  vertices.at(replacementPair.column) = newVertex;
244  }
245 
246  template<
247  typename UpdateFunction,
248  typename Checker
249  > static OptimizationReturnType minimize(
250  Parameters& points,
251  UpdateFunction&& function,
252  Checker&& check
253  ) {
254  constexpr FloatType reflectionCoefficient = 1;
255  constexpr FloatType expansionCoefficient = 2;
256  constexpr FloatType contractionCoefficient = 0.5;
257 
258  static_assert(0 < reflectionCoefficient, "Reflection coefficient bounds not met");
259  static_assert(1 < expansionCoefficient, "Expansion coefficient bounds not met");
260  static_assert(0 < contractionCoefficient && contractionCoefficient <= 0.5, "Contraction coefficient bounds not met");
261 
262  /* We need the points of the simplex to always lie within a ball of radius
263  * pi/2 so that geodesics and the karcher mean are unique.
264  *
265  * We think that if all pairs of points have distance less than pi from one
266  * another, they are within a ball of radius pi/2.
267  */
268  constexpr FloatType ballRadiusSquared = M_PI * M_PI;
269  if(
271  Temple::Adaptors::allPairs(Temple::iota<unsigned>(4)),
272  [&points](const unsigned i, const unsigned j) -> bool {
273  return Manifold::distanceSquared(points.at(i), points.at(j)) >= ballRadiusSquared;
274  }
275  )
276  ) {
277  throw std::logic_error(
278  "Initial simplex points do not lie within ball of radius pi/2"
279  );
280  }
281 
282  // Sort the vertex values
283  std::vector<IndexValuePair> values = Temple::sorted(
284  Temple::map(
285  Temple::iota<unsigned>(4),
286  [&](const unsigned i) -> IndexValuePair {
287  return {
288  i,
289  function(points.at(i))
290  };
291  }
292  )
293  );
294  assert(values.size() == 4);
295 
296  auto ballCheckingFunction = [](
297  auto&& objectiveFunction,
298  const Parameters& simplexVertices,
299  const Matrix& speculativePoint,
300  const unsigned replacingIndex
301  ) -> FloatType {
302  for(unsigned i = 0; i < 4; ++i) {
303  if(i == replacingIndex) {
304  continue;
305  }
306 
307  if(Manifold::distanceSquared(speculativePoint, simplexVertices.at(i)) >= ballRadiusSquared) {
308  /* The new point will lie outside a ball of radius pi / 2 for the
309  * existing points. Geodesics may no longer be unique and the Karcher
310  * mean may be incalculable. Returning a near-infinite function value
311  * will discourage the use of this point.
312  */
313  return std::numeric_limits<FloatType>::max();
314  }
315  }
316 
317  return objectiveFunction(speculativePoint);
318  };
319 
320  FloatType standardDeviation;
321  unsigned iteration = 0;
322  do {
323  const Matrix simplexCentroid = Manifold::karcherMean(points, values.back().column);
324  const Matrix& worstVertex = points.at(values.back().column);
325  const FloatType worstVertexValue = values.back().value;
326  const FloatType bestVertexValue = values.front().value;
327 
328  /* Reflect */
329  const Matrix reflectedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -reflectionCoefficient);
330  const FloatType reflectedValue = ballCheckingFunction(function, points, reflectedVertex, values.back().column);
331 
332  if(reflectedValue < bestVertexValue) {
333  /* Expansion */
334  const Matrix expandedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -expansionCoefficient);
335  const FloatType expandedValue = ballCheckingFunction(function, points, expandedVertex, values.back().column);
336 
337  if(expandedValue < reflectedValue) {
338  // Replace the worst value with the expanded point
339  replaceWorst(values, expandedVertex, expandedValue, points);
340  } else {
341  // Replace the worst value with the reflected point
342  replaceWorst(values, reflectedVertex, reflectedValue, points);
343  }
344  } else if(bestVertexValue <= reflectedValue && reflectedValue < values.at(2).value) {
345  replaceWorst(values, reflectedVertex, reflectedValue, points);
346  } else if(values.at(2).value <= reflectedValue && reflectedValue < worstVertexValue) {
347  /* Outside contraction */
348  const Matrix outsideContractedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -(reflectionCoefficient * contractionCoefficient));
349  const FloatType outsideContractedValue = ballCheckingFunction(function, points, reflectedVertex, values.back().column);
350  if(outsideContractedValue <= reflectedValue) {
351  replaceWorst(values, outsideContractedVertex, outsideContractedValue, points);
352  } else {
353  shrink(points, values, function);
354  }
355  } else {
356  /* Inside contraction */
357  const Matrix insideContractedVertex = Manifold::geodesic(worstVertex, simplexCentroid, contractionCoefficient);
358  const FloatType insideContractedValue = function(insideContractedVertex);
359  // NOTE: No need to worry about ball radius in inside contraction
360  if(insideContractedValue < worstVertexValue) {
361  replaceWorst(values, insideContractedVertex, insideContractedValue, points);
362  } else {
363  shrink(points, values, function);
364  }
365  }
366 
367  standardDeviation = valueStandardDeviation(values);
368  ++iteration;
369  } while(check.shouldContinue(iteration, values.front().value, standardDeviation));
370 
371  return {
372  iteration,
373  values.front().value,
374  values.front().column
375  };
376  }
377 };
378 
379 } // namespace Temple
380 } // namespace Molassembler
381 } // namespace Scine
382 
383 #endif
void sort(Container &container)
Calls std::sort on a container.
Definition: Functional.h:261
Nelder-Mead optimization on SO(3) manifold.
Definition: SO3NelderMead.h:27
FloatType value
Final function value.
Definition: SO3NelderMead.h:51
unsigned iterations
Number of iterations.
Definition: SO3NelderMead.h:49
Type returned from an optimization.
Definition: SO3NelderMead.h:47
unsigned minimalIndex
Simplex vertex with minimal function value.
Definition: SO3NelderMead.h:53
T accumulate(const Container &container, T init, BinaryFunction &&reductionFunction)
Accumulate shorthand.
Definition: Functional.h:216
Provides pair-generation within a single container or two.
double tau(const std::vector< double > &angles)
Calculates the tau value for four and five angle symmetries.
Definition: TauCriteria.h:63
constexpr std::enable_if_t< std::is_floating_point< Traits::getValueType< ContainerType > >::value, Traits::getValueType< ContainerType >> average(const ContainerType &container)
Definition: Numeric.h:64
bool any_of(const Container &container, UnaryPredicate &&predicate=UnaryPredicate{})
any_of shorthand
Definition: Functional.h:249
constexpr auto map(const ArrayType< T, size > &array, UnaryFunction &&function)
Maps all elements of any array-like container with a unary function.
Definition: Containers.h:63
Four 3x3 matrices form the simplex vertices.
Definition: SO3NelderMead.h:30
Functional-style container-related algorithms.