8 #ifndef INCLUDE_TEMPLE_OPTIMIZATION_SO3_NELDER_MEAD_H
9 #define INCLUDE_TEMPLE_OPTIMIZATION_SO3_NELDER_MEAD_H
12 #include <Eigen/Eigenvalues>
13 #include <unsupported/Eigen/MatrixFunctions>
18 namespace Molassembler {
// Template header of the enclosing class template; FloatType selects the
// scalar type and defaults to double.
// NOTE(review): the declaration line following this header (original line 27)
// is not visible in this listing.
26 template<
typename FloatType =
double>
// 3x3 matrix of FloatType. Used below as the type of simplex vertices and as
// the return type of the manifold helper functions.
28 using Matrix = Eigen::Matrix<FloatType, 3, 3>;
// Backing storage: one 3x12 matrix, accessed through at() as 3x3 column blocks
// (12 columns hold blocks for i = 0..3).
31 Eigen::Matrix<FloatType, 3, 12> matrix;
// Mutable accessor: returns the 3x3 block starting at row 0, column 3 * i.
// decltype(auto) keeps the block expression type, so callers can assign
// through the result (e.g. `parameters.at(0) = ...`). No bounds check on i is
// performed here.
33 decltype(
auto) at(
const unsigned i) {
35 return matrix.template block<3, 3>(0, 3 * i);
// Const accessor: same block selection as above on a const object.
38 decltype(
auto) at(
const unsigned i)
const {
40 return matrix.template block<3, 3>(0, 3 * i);
// Eigen-provided macro for classes with fixed-size Eigen members; see the
// Eigen documentation on structures having Eigen members.
43 EIGEN_MAKE_ALIGNED_OPERATOR_NEW
// Returns the skew-symmetric part of m, computed as 0.5 * (m - m^T).
// The result is materialized as a Matrix (3x3 of FloatType).
58 template<
typename Derived>
59 static Matrix skew(
const Eigen::MatrixBase<Derived>& m) {
60 return 0.5 * (m - m.transpose());
// Two-argument logarithm: takes the matrix logarithm of X^T * Y, keeps its
// real part, and returns the skew-symmetric part of that via skew().
// The matrix log() comes from Eigen's unsupported MatrixFunctions module.
63 template<
typename DerivedA,
typename DerivedB>
64 static Matrix log(
const Eigen::MatrixBase<DerivedA>& X,
const Eigen::MatrixBase<DerivedB>& Y) {
65 return skew((X.transpose() * Y).log().real());
// Two-argument exponential taking matrices X and Y and returning a Matrix.
// NOTE(review): the body is not visible in this listing. Callers below pass a
// vertex as X and a log()-derived matrix as Y (see geodesic and karcherMean),
// so this is presumably the counterpart of log(X, Y) -- confirm against the
// original file.
68 template<
typename DerivedA,
typename DerivedB>
69 static Matrix exp(
const Eigen::MatrixBase<DerivedA>& X,
const Eigen::MatrixBase<DerivedB>& Y) {
// Squared distance measure between X and Y: half the squared norm of
// log(X, Y). Callers compare the result against M_PI * M_PI.
73 template<
typename DerivedA,
typename DerivedB>
74 static FloatType distanceSquared(
const Eigen::MatrixBase<DerivedA>& X,
const Eigen::MatrixBase<DerivedB>& Y) {
75 return FloatType {0.5} * log(X, Y).squaredNorm();
// Point on the curve through a and b parameterized by tau, computed as
// exp(b, tau * log(b, a)). Note the curve is based at b (second argument).
// minimize() calls this as geodesic(worstVertex, simplexCentroid, coefficient)
// with negative coefficients for reflection/expansion and a positive one for
// inside contraction.
// NOTE(review): which endpoint tau = 0 and tau = 1 map to depends on exp(),
// whose body is not visible here.
78 template<
typename DerivedA,
typename DerivedB>
79 static Matrix geodesic(
const Eigen::MatrixBase<DerivedA>& a,
const Eigen::MatrixBase<DerivedB>& b,
const FloatType
tau) {
80 return exp(b, tau * log(b, a));
// Orthogonality test: returns true if m * m^T is approximately the identity
// (Eigen isApprox with precision 1e-5). Only this product is checked; the sign
// of the determinant is not tested here.
83 template<
typename Derived>
84 static bool contains(
const Eigen::MatrixBase<Derived>& m) {
85 return (m * m.transpose()).isApprox(Matrix::Identity(), 1e-5);
// Iteratively computes a mean of the simplex vertices in `points`.
//
// points     : the four simplex vertices.
// excludeIdx : vertex index captured by the helper lambda; minimize() passes
//              the column of the worst vertex. NOTE(review): the lines that
//              use it inside the loop are elided -- presumably the vertex at
//              this index is skipped.
// Returns a Matrix; the return statement itself is not visible here.
88 static Matrix karcherMean(
const Parameters& points,
const unsigned excludeIdx) {
// Helper: accumulates log(speculativeMean, vertex) over vertices i = 0..3.
// NOTE(review): original lines 92-95 and 97-101 are elided, so any skipping,
// scaling and the lambda's return are not shown.
89 auto calculateOmega = [excludeIdx](
const Parameters& p,
const Matrix& speculativeMean) {
90 Matrix omega = Matrix::Zero();
91 for(
unsigned i = 0; i < 4; ++i) {
96 omega += Manifold::log(speculativeMean, p.at(i));
// Convergence threshold on the norm of omega.
102 constexpr FloatType delta = 1e-5;
// Initial guess: vertex 1 if vertex 0 is the excluded one, otherwise vertex 0.
103 Matrix q = (excludeIdx == 0) ? points.at(1) : points.at(0);
104 Matrix omega = calculateOmega(points, q);
106 unsigned iterations = 0;
// Iterate until omega is small or the iteration cap of 100 is hit.
// NOTE(review): the increment of `iterations` is not visible in this listing.
107 while(omega.norm() >= delta && iterations < 100) {
// Move the current estimate along omega, then recompute omega at the new point.
108 q = Manifold::exp(q, omega);
109 assert(q.allFinite());
111 omega = calculateOmega(points, q);
// Builds a random matrix Q from the QR decomposition of a random 3x3 matrix,
// adjusts column signs, and swaps two columns if the determinant is negative.
// Returns a Matrix; the return statement is not visible in this listing.
118 static Matrix randomRotation() {
// Matrix::Random() yields a random-coefficient expression; it is consumed by
// the decomposition constructor on the next line.
119 auto A = Matrix::Random();
120 Eigen::ColPivHouseholderQR<Matrix> decomposition(A);
121 Matrix Q = decomposition.householderQ();
122 auto R = decomposition.matrixR();
// `intermediate` becomes a diagonal matrix whose entries follow the sign of
// the diagonal of R; entries stay 0 where neither branch assigns.
124 Matrix intermediate = Matrix::Zero();
125 for(
unsigned i = 0; i < 3; ++i) {
126 double value = R(i, i);
// NOTE(review): the condition guarding this assignment (original line 127) is
// elided; given the visible `else if(value > 0)` it is presumably `value < 0`.
128 intermediate(i, i) = -1;
129 }
else if(value > 0) {
130 intermediate(i, i) = 1;
// Apply the per-column sign pattern to Q.
133 Q = Q * intermediate;
// A negative determinant is flipped by exchanging the first two columns.
136 if(Q.determinant() < 0) {
137 Q.col(0).swap(Q.col(1));
141 assert(Q.allFinite());
// Body of a routine that fills `parameters` with four random vertices.
// NOTE(review): the function header and several lines (declaration of
// `parameters` and `R`, the retry construct around the predicate) are elided
// from this listing.
// Threshold compared against distanceSquared() results.
148 constexpr FloatType ballRadiusSquared = M_PI * M_PI;
// The first vertex is accepted unconditionally.
150 parameters.at(0) = Manifold::randomRotation();
// Vertices 1..3: draw a candidate R and test it against all previously
// accepted vertices j < i.
151 for(
unsigned i = 1; i < 4; ++i) {
154 R = Manifold::randomRotation();
157 Temple::iota<unsigned>(i),
// Predicate: true if the candidate is at or beyond the threshold distance
// from accepted vertex j. Presumably a candidate is redrawn while this holds
// for any j -- the surrounding loop/any_of lines are not visible.
158 [&](
const unsigned j) ->
bool {
159 return Manifold::distanceSquared(R, parameters.at(j)) >= ballRadiusSquared;
// Store the candidate as vertex i.
163 parameters.at(i) = R;
// Ordering comparison by `value` only (ascending).
// NOTE(review): the enclosing declaration is elided; presumably this is the
// less-than operator of IndexValuePair, which std::is_sorted in replaceWorst
// relies on.
174 return value < other.value;
// Computes a spread measure of the `value` fields of the given pairs.
// NOTE(review): the accumulate calls wrapping the two lambdas, the definition
// of `average`, and the final return expression are elided from this listing;
// the name suggests the result is a standard deviation, but whether it divides
// by V or V - 1 cannot be determined from here.
178 static FloatType valueStandardDeviation(
const std::vector<IndexValuePair>& sortedPairs) {
// Number of pairs (size_t narrowed to unsigned).
179 const unsigned V = sortedPairs.size();
// First reduction: running sum of the values.
184 [](
const FloatType carry,
const IndexValuePair& pair) -> FloatType {
185 return carry + pair.value;
// Second reduction: running sum of squared deviations from `average`.
192 [
average](
const FloatType carry,
const IndexValuePair& pair) -> FloatType {
193 const FloatType diff = pair.value -
average;
194 return carry + diff * diff;
// Shrink step of the simplex method: moves every vertex except the best one
// along the geodesic towards the best vertex and re-evaluates the objective.
//
// points   : simplex vertices, modified in place (parameter line elided in
//            this listing).
// values   : (column, value) pairs; values.front() is treated as the best
//            vertex. The `value` of entries 1..3 is overwritten.
// function : objective, called with the updated vertex.
// NOTE(review): original lines after 215 are elided; presumably `values` is
// re-sorted afterwards -- confirm against the original file.
200 template<
typename UpdateFunction>
203 std::vector<IndexValuePair>& values,
204 UpdateFunction&&
function
206 constexpr FloatType shrinkCoefficient = 0.5;
207 static_assert(0 < shrinkCoefficient && shrinkCoefficient < 1,
"Shrink coefficient bounds not met");
// Binding a const Matrix& to the block returned by at() materializes a copy,
// so bestVertex is unaffected by the writes to `points` below.
208 const Matrix& bestVertex = points.at(values.front().column);
211 for(
unsigned i = 1; i < 4; ++i) {
212 auto& value = values.at(i);
// Writable view onto the vertex block. Uses Matrix (FloatType-based) rather
// than a hard-coded Eigen::Matrix3d so the view also binds when FloatType is
// not double; for the default FloatType = double the type is identical.
213 Eigen::Ref<Matrix> vertex = points.at(value.column);
// geodesic() returns by value, so reading and assigning `vertex` in one
// statement does not alias.
214 vertex = Manifold::geodesic(vertex, bestVertex, shrinkCoefficient);
215 value.value =
function(vertex);
// Replaces the worst simplex vertex (the last entry of sortedPairs) with a new
// vertex and value.
//
// sortedPairs : (column, value) pairs, required to be sorted on entry
//               (asserted below).
// newVertex   : matrix written into the column previously held by the worst
//               vertex.
// newValue    : objective value belonging to newVertex.
// A further parameter (`vertices`, used on the last line) is elided from this
// listing.
221 static void replaceWorst(
222 std::vector<IndexValuePair>& sortedPairs,
223 const Matrix& newVertex,
224 const FloatType newValue,
227 assert(std::is_sorted(std::begin(sortedPairs), std::end(sortedPairs)));
// The replacement reuses the column index of the current worst entry.
228 IndexValuePair replacementPair {
229 sortedPairs.back().column,
// NOTE(review): the call these iterator arguments belong to is elided;
// presumably the new pair is inserted at its sorted position before the old
// worst entry is popped -- confirm against the original file.
234 std::begin(sortedPairs),
235 std::end(sortedPairs),
// Drop the old worst entry (still last at this point).
241 sortedPairs.pop_back();
// Overwrite the vertex storage for the reused column.
243 vertices.at(replacementPair.column) = newVertex;
// Nelder-Mead style minimization over a simplex of four 3x3 matrices.
//
// points   : simplex vertices, updated in place (parameter line elided in this
//            listing).
// function : objective; called with a single vertex, result used as FloatType.
// check    : convergence checker; check.shouldContinue(iteration, bestValue,
//            standardDeviation) decides whether to iterate again (parameter
//            line elided in this listing).
// Returns an OptimizationReturnType built from (at least) the best value and
// the column of the best vertex.
// Throws std::logic_error for an unsuitable initial simplex (see below).
//
// NOTE(review): this listing elides many original lines (the `do {` opener,
// several braces/else branches, the iteration increment); the comments below
// describe only what is visible.
247 typename UpdateFunction,
249 >
static OptimizationReturnType minimize(
251 UpdateFunction&&
function,
254 constexpr FloatType reflectionCoefficient = 1;
255 constexpr FloatType expansionCoefficient = 2;
256 constexpr FloatType contractionCoefficient = 0.5;
258 static_assert(0 < reflectionCoefficient,
"Reflection coefficient bounds not met");
259 static_assert(1 < expansionCoefficient,
"Expansion coefficient bounds not met");
260 static_assert(0 < contractionCoefficient && contractionCoefficient <= 0.5,
"Contraction coefficient bounds not met");
// Threshold for the pairwise distanceSquared() checks below.
268 constexpr FloatType ballRadiusSquared = M_PI * M_PI;
// Pairwise check over all vertex pairs (i, j); the exception below is thrown
// for an initial simplex that fails it (the enclosing if/any_of lines are
// elided).
271 Temple::Adaptors::allPairs(Temple::iota<unsigned>(4)),
272 [&points](
const unsigned i,
const unsigned j) ->
bool {
273 return Manifold::distanceSquared(points.at(i), points.at(j)) >= ballRadiusSquared;
277 throw std::logic_error(
278 "Initial simplex points do not lie within ball of radius pi/2"
// Evaluate the objective at each vertex and keep the (column, value) pairs
// sorted, so front() is the best and back() the worst vertex.
283 std::vector<IndexValuePair> values = Temple::sorted(
285 Temple::iota<unsigned>(4),
286 [&](
const unsigned i) -> IndexValuePair {
289 function(points.at(i))
294 assert(values.size() == 4);
// Objective wrapper for speculative points: if the point is at or beyond the
// distance threshold from any vertex other than the one being replaced, the
// largest finite FloatType is returned so the point is never accepted as an
// improvement; otherwise the objective is evaluated.
296 auto ballCheckingFunction = [](
297 auto&& objectiveFunction,
298 const Parameters& simplexVertices,
299 const Matrix& speculativePoint,
300 const unsigned replacingIndex
302 for(
unsigned i = 0; i < 4; ++i) {
303 if(i == replacingIndex) {
307 if(Manifold::distanceSquared(speculativePoint, simplexVertices.at(i)) >= ballRadiusSquared) {
313 return std::numeric_limits<FloatType>::max();
317 return objectiveFunction(speculativePoint);
320 FloatType standardDeviation;
321 unsigned iteration = 0;
// Per iteration: mean of the vertices with the worst one passed as the
// excluded index, plus snapshots of the worst vertex and the best/worst
// values.
323 const Matrix simplexCentroid = Manifold::karcherMean(points, values.back().column);
324 const Matrix& worstVertex = points.at(values.back().column);
325 const FloatType worstVertexValue = values.back().value;
326 const FloatType bestVertexValue = values.front().value;
// Reflection of the worst vertex through the centroid.
329 const Matrix reflectedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -reflectionCoefficient);
330 const FloatType reflectedValue = ballCheckingFunction(
function, points, reflectedVertex, values.back().column);
// Reflection beats the current best: try expanding further in that direction
// and keep whichever of the two is better.
332 if(reflectedValue < bestVertexValue) {
334 const Matrix expandedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -expansionCoefficient);
335 const FloatType expandedValue = ballCheckingFunction(
function, points, expandedVertex, values.back().column);
337 if(expandedValue < reflectedValue) {
339 replaceWorst(values, expandedVertex, expandedValue, points);
342 replaceWorst(values, reflectedVertex, reflectedValue, points);
344 }
// Reflection is between the best and the second-worst value: accept it.
else if(bestVertexValue <= reflectedValue && reflectedValue < values.at(2).value) {
345 replaceWorst(values, reflectedVertex, reflectedValue, points);
346 }
// Reflection is between the second-worst and the worst value: outside
// contraction.
else if(values.at(2).value <= reflectedValue && reflectedValue < worstVertexValue) {
348 const Matrix outsideContractedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -(reflectionCoefficient * contractionCoefficient));
// Evaluate the objective at the contracted vertex itself. Previously the
// reflected vertex was passed here, so the contracted vertex was stored with
// the reflected vertex's value and the comparison below was trivially true
// whenever the ball check passed.
349 const FloatType outsideContractedValue = ballCheckingFunction(
function, points, outsideContractedVertex, values.back().column);
350 if(outsideContractedValue <= reflectedValue) {
351 replaceWorst(values, outsideContractedVertex, outsideContractedValue, points);
353 shrink(points, values,
function);
// Inside contraction (the branch opener is elided): evaluated with the plain
// objective, without the ball check.
357 const Matrix insideContractedVertex = Manifold::geodesic(worstVertex, simplexCentroid, contractionCoefficient);
358 const FloatType insideContractedValue =
function(insideContractedVertex);
360 if(insideContractedValue < worstVertexValue) {
361 replaceWorst(values, insideContractedVertex, insideContractedValue, points);
363 shrink(points, values,
function);
// Spread of the vertex values, handed to the convergence checker below.
367 standardDeviation = valueStandardDeviation(values);
369 }
while(check.shouldContinue(iteration, values.front().value, standardDeviation));
// Result fields visible here: best value and the column of the best vertex.
373 values.front().value,
374 values.front().column
void sort(Container &container)
Calls std::sort on a container.
Definition: Functional.h:261
Nelder-Mead optimization on SO(3) manifold.
Definition: SO3NelderMead.h:27
FloatType value
Final function value.
Definition: SO3NelderMead.h:51
unsigned iterations
Number of iterations.
Definition: SO3NelderMead.h:49
Definition: SO3NelderMead.h:56
Type returned from an optimization.
Definition: SO3NelderMead.h:47
unsigned minimalIndex
Simplex vertex with minimal function value.
Definition: SO3NelderMead.h:53
T accumulate(const Container &container, T init, BinaryFunction &&reductionFunction)
Accumulate shorthand.
Definition: Functional.h:216
Provides pair-generation within a single container or two.
double tau(const std::vector< double > &angles)
Calculates the tau value for four and five angle symmetries.
Definition: TauCriteria.h:63
constexpr std::enable_if_t< std::is_floating_point< Traits::getValueType< ContainerType > >::value, Traits::getValueType< ContainerType >> average(const ContainerType &container)
Definition: Numeric.h:64
bool any_of(const Container &container, UnaryPredicate &&predicate=UnaryPredicate{})
any_of shorthand
Definition: Functional.h:249
constexpr auto map(const ArrayType< T, size > &array, UnaryFunction &&function)
Maps all elements of any array-like container with a unary function.
Definition: Containers.h:63
Four 3x3 matrices form the simplex vertices.
Definition: SO3NelderMead.h:30
Functional-style container-related algorithms.
Definition: SO3NelderMead.h:169