8 #ifndef INCLUDE_TEMPLE_OPTIMIZATION_SO3_NELDER_MEAD_H 
    9 #define INCLUDE_TEMPLE_OPTIMIZATION_SO3_NELDER_MEAD_H 
   12 #include <Eigen/Eigenvalues> 
   13 #include <unsupported/Eigen/MatrixFunctions> 
   18 namespace Molassembler {
 
   26 template<
typename FloatType = 
double>
 
   28   using Matrix = Eigen::Matrix<FloatType, 3, 3>;
 
   31     Eigen::Matrix<FloatType, 3, 12> matrix;
 
   33     decltype(
auto) at(
const unsigned i) {
 
   35       return matrix.template block<3, 3>(0, 3 * i);
 
   38     decltype(
auto) at(
const unsigned i)
 const {
 
   40       return matrix.template block<3, 3>(0, 3 * i);
 
   43     EIGEN_MAKE_ALIGNED_OPERATOR_NEW
 
   58     template<
typename Derived>
 
   59     static Matrix skew(
const Eigen::MatrixBase<Derived>& m) {
 
   60       return 0.5 * (m - m.transpose());
 
   63     template<
typename DerivedA, 
typename DerivedB>
 
   64     static Matrix log(
const Eigen::MatrixBase<DerivedA>& X, 
const Eigen::MatrixBase<DerivedB>& Y) {
 
   65       return skew((X.transpose() * Y).log().real());
 
   68     template<
typename DerivedA, 
typename DerivedB>
 
   69     static Matrix exp(
const Eigen::MatrixBase<DerivedA>& X, 
const Eigen::MatrixBase<DerivedB>& Y) {
 
   73     template<
typename DerivedA, 
typename DerivedB>
 
   74     static FloatType distanceSquared(
const Eigen::MatrixBase<DerivedA>& X, 
const Eigen::MatrixBase<DerivedB>& Y) {
 
   75       return FloatType {0.5} * log(X, Y).squaredNorm();
 
   78     template<
typename DerivedA, 
typename DerivedB>
 
   79     static Matrix geodesic(
const Eigen::MatrixBase<DerivedA>& a, 
const Eigen::MatrixBase<DerivedB>& b, 
const FloatType 
tau) {
 
   80       return exp(b, tau * log(b, a));
 
   83     template<
typename Derived>
 
   84     static bool contains(
const Eigen::MatrixBase<Derived>& m) {
 
   85       return (m * m.transpose()).isApprox(Matrix::Identity(), 1e-5);
 
   88     static Matrix karcherMean(
const Parameters& points, 
const unsigned excludeIdx) {
 
   89       auto calculateOmega = [excludeIdx](
const Parameters& p, 
const Matrix& speculativeMean) {
 
   90         Matrix omega = Matrix::Zero();
 
   91         for(
unsigned i = 0; i < 4; ++i) {
 
   96           omega += Manifold::log(speculativeMean, p.at(i));
 
  102       constexpr FloatType delta = 1e-5;
 
  103       Matrix q = (excludeIdx == 0) ? points.at(1) : points.at(0);
 
  104       Matrix omega = calculateOmega(points, q);
 
  106       unsigned iterations = 0;
 
  107       while(omega.norm() >= delta && iterations < 100) {
 
  108         q = Manifold::exp(q, omega);
 
  109         assert(q.allFinite());
 
  111         omega = calculateOmega(points, q);
 
  118     static Matrix randomRotation() {
 
  119       auto A = Matrix::Random();
 
  120       Eigen::ColPivHouseholderQR<Matrix> decomposition(A);
 
  121       Matrix Q = decomposition.householderQ();
 
  122       auto R = decomposition.matrixR();
 
  124       Matrix intermediate = Matrix::Zero();
 
  125       for(
unsigned i = 0; i < 3; ++i) {
 
  126         double value = R(i, i);
 
  128           intermediate(i, i) = -1;
 
  129         } 
else if(value > 0) {
 
  130           intermediate(i, i) = 1;
 
  133       Q = Q * intermediate;
 
  136       if(Q.determinant() < 0) {
 
  137         Q.col(0).swap(Q.col(1));
 
  141       assert(Q.allFinite());
 
  148     constexpr FloatType ballRadiusSquared = M_PI * M_PI;
 
  150     parameters.at(0) = Manifold::randomRotation();
 
  151     for(
unsigned i = 1; i < 4; ++i) {
 
  154         R = Manifold::randomRotation();
 
  157           Temple::iota<unsigned>(i),
 
  158           [&](
const unsigned j) -> 
bool {
 
  159             return Manifold::distanceSquared(R, parameters.at(j)) >= ballRadiusSquared;
 
  163       parameters.at(i) = R;
 
  174       return value < other.value;
 
  178   static FloatType valueStandardDeviation(
const std::vector<IndexValuePair>& sortedPairs) {
 
  179     const unsigned V = sortedPairs.size();
 
  184       [](
const FloatType carry, 
const IndexValuePair& pair) -> FloatType {
 
  185         return carry + pair.value;
 
  192         [
average](
const FloatType carry, 
const IndexValuePair& pair) -> FloatType {
 
  193           const FloatType diff = pair.value - 
average;
 
  194           return carry + diff * diff;
 
  200   template<
typename UpdateFunction>
 
  203     std::vector<IndexValuePair>& values,
 
  204     UpdateFunction&& 
function 
  206     constexpr FloatType shrinkCoefficient = 0.5;
 
  207     static_assert(0 < shrinkCoefficient && shrinkCoefficient < 1, 
"Shrink coefficient bounds not met");
 
  208     const Matrix& bestVertex = points.at(values.front().column);
 
  211     for(
unsigned i = 1; i < 4; ++i) {
 
  212       auto& value = values.at(i);
 
  213       Eigen::Ref<Matrix> vertex = points.at(value.column);
            // Writable view onto this vertex's 3x3 slot in `points`. It is typed via
            // the class's Matrix alias (FloatType-based), not Eigen::Matrix3d: a
            // non-const Ref<Matrix3d> cannot bind to a non-double block, which broke
            // shrink for FloatType other than double.

  214       vertex = Manifold::geodesic(vertex, bestVertex, shrinkCoefficient);
 
  215       value.value = 
function(vertex);
 
  221   static void replaceWorst(
 
  222     std::vector<IndexValuePair>& sortedPairs,
 
  223     const Matrix& newVertex,
 
  224     const FloatType newValue,
 
  227     assert(std::is_sorted(std::begin(sortedPairs), std::end(sortedPairs)));
 
  228     IndexValuePair replacementPair {
 
  229       sortedPairs.back().column,
 
  234         std::begin(sortedPairs),
 
  235         std::end(sortedPairs),
 
  241     sortedPairs.pop_back();
 
  243     vertices.at(replacementPair.column) = newVertex;
 
  247     typename UpdateFunction,
 
  249   > 
static OptimizationReturnType minimize(
 
  251     UpdateFunction&& 
function,
 
  254     constexpr FloatType reflectionCoefficient = 1;
 
  255     constexpr FloatType expansionCoefficient = 2;
 
  256     constexpr FloatType contractionCoefficient = 0.5;
 
  258     static_assert(0 < reflectionCoefficient, 
"Reflection coefficient bounds not met");
 
  259     static_assert(1 < expansionCoefficient, 
"Expansion coefficient bounds not met");
 
  260     static_assert(0 < contractionCoefficient && contractionCoefficient <= 0.5, 
"Contraction coefficient bounds not met");
 
  268     constexpr FloatType ballRadiusSquared = M_PI * M_PI;
 
  271         Temple::Adaptors::allPairs(Temple::iota<unsigned>(4)),
 
  272         [&points](
const unsigned i, 
const unsigned j) -> 
bool {
 
  273           return Manifold::distanceSquared(points.at(i), points.at(j)) >= ballRadiusSquared;
 
  277       throw std::logic_error(
 
  278         "Initial simplex points do not lie within ball of radius pi/2" 
  283     std::vector<IndexValuePair> values = Temple::sorted(
 
  285         Temple::iota<unsigned>(4),
 
  286         [&](
const unsigned i) -> IndexValuePair {
 
  289             function(points.at(i))
 
  294     assert(values.size() == 4);
 
  296     auto ballCheckingFunction = [](
 
  297       auto&& objectiveFunction,
 
  298       const Parameters& simplexVertices,
 
  299       const Matrix& speculativePoint,
 
  300       const unsigned replacingIndex
 
  302       for(
unsigned i = 0; i < 4; ++i) {
 
  303         if(i == replacingIndex) {
 
  307         if(Manifold::distanceSquared(speculativePoint, simplexVertices.at(i)) >= ballRadiusSquared) {
 
  313           return std::numeric_limits<FloatType>::max();
 
  317       return objectiveFunction(speculativePoint);
 
  320     FloatType standardDeviation;
 
  321     unsigned iteration = 0;
 
  323       const Matrix simplexCentroid = Manifold::karcherMean(points, values.back().column);
 
  324       const Matrix& worstVertex = points.at(values.back().column);
 
  325       const FloatType worstVertexValue = values.back().value;
 
  326       const FloatType bestVertexValue = values.front().value;
 
  329       const Matrix reflectedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -reflectionCoefficient);
 
  330       const FloatType reflectedValue = ballCheckingFunction(
function, points, reflectedVertex, values.back().column);
 
  332       if(reflectedValue < bestVertexValue) {
 
  334         const Matrix expandedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -expansionCoefficient);
 
  335         const FloatType expandedValue = ballCheckingFunction(
function, points, expandedVertex, values.back().column);
 
  337         if(expandedValue < reflectedValue) {
 
  339           replaceWorst(values, expandedVertex, expandedValue, points);
 
  342           replaceWorst(values, reflectedVertex, reflectedValue, points);
 
  344       } 
else if(bestVertexValue <= reflectedValue && reflectedValue < values.at(2).value) {
 
  345         replaceWorst(values, reflectedVertex, reflectedValue, points);
 
  346       } 
else if(values.at(2).value <= reflectedValue && reflectedValue < worstVertexValue) {
 
  348         const Matrix outsideContractedVertex = Manifold::geodesic(worstVertex, simplexCentroid, -(reflectionCoefficient * contractionCoefficient));

              // Outside contraction: score the contracted vertex itself. Scoring
              // reflectedVertex here just repeats the reflectedValue computation,
              // so (for a deterministic objective) the acceptance test below would
              // always pass and replaceWorst would store the reflected value
              // against the contracted vertex.
  349         const FloatType outsideContractedValue = ballCheckingFunction(
function, points, outsideContractedVertex, values.back().column);
 
  350         if(outsideContractedValue <= reflectedValue) {
 
  351           replaceWorst(values, outsideContractedVertex, outsideContractedValue, points);
 
  353           shrink(points, values, 
function);
 
  357         const Matrix insideContractedVertex = Manifold::geodesic(worstVertex, simplexCentroid, contractionCoefficient);
 
  358         const FloatType insideContractedValue = 
function(insideContractedVertex);
 
  360         if(insideContractedValue < worstVertexValue) {
 
  361           replaceWorst(values, insideContractedVertex, insideContractedValue, points);
 
  363           shrink(points, values, 
function);
 
  367       standardDeviation = valueStandardDeviation(values);
 
  369     } 
while(check.shouldContinue(iteration, values.front().value, standardDeviation));
 
  373       values.front().value,
 
  374       values.front().column
 
void sort(Container &container)
Calls std::sort on a container. 
Definition: Functional.h:261
Nelder-Mead optimization on SO(3) manifold. 
Definition: SO3NelderMead.h:27
FloatType value
Final function value. 
Definition: SO3NelderMead.h:51
unsigned iterations
Number of iterations. 
Definition: SO3NelderMead.h:49
Definition: SO3NelderMead.h:56
Type returned from an optimization. 
Definition: SO3NelderMead.h:47
unsigned minimalIndex
Simplex vertex with minimal function value. 
Definition: SO3NelderMead.h:53
T accumulate(const Container &container, T init, BinaryFunction &&reductionFunction)
Shorthand for std::accumulate over a whole container.
Definition: Functional.h:216
Provides pair generation within a single container or across two containers.
double tau(const std::vector< double > &angles)
Calculates the tau value for four and five angle symmetries. 
Definition: TauCriteria.h:63
constexpr std::enable_if_t< std::is_floating_point< Traits::getValueType< ContainerType > >::value, Traits::getValueType< ContainerType >> average(const ContainerType &container)
Definition: Numeric.h:64
bool any_of(const Container &container, UnaryPredicate &&predicate=UnaryPredicate{})
Shorthand for std::any_of over a whole container.
Definition: Functional.h:249
constexpr auto map(const ArrayType< T, size > &array, UnaryFunction &&function)
Maps all elements of any array-like container with a unary function. 
Definition: Containers.h:63
Four 3x3 matrices form the simplex vertices. 
Definition: SO3NelderMead.h:30
Functional-style container-related algorithms. 
Definition: SO3NelderMead.h:169