8 #ifndef INCLUDE_TEMPLE_OPTIMIZATION_NELDER_MEAD_H
9 #define INCLUDE_TEMPLE_OPTIMIZATION_NELDER_MEAD_H
15 namespace Molassembler {
23 template<
typename FloatType =
double>
25 using MatrixType = Eigen::Matrix<FloatType, Eigen::Dynamic, Eigen::Dynamic>;
26 using VectorType = Eigen::Matrix<FloatType, Eigen::Dynamic, 1>;
38 static VectorType generateVertex(
39 const FloatType coefficient,
41 const VectorType& worstVertex
43 return centroid + coefficient * (centroid - worstVertex);
51 return value < other.value;
55 static FloatType valueStandardDeviation(
const std::vector<IndexValuePair>& sortedPairs) {
56 const unsigned V = sortedPairs.size();
61 [](
const FloatType carry,
const IndexValuePair& pair) -> FloatType {
62 return carry + pair.value;
69 [
average](
const FloatType carry,
const IndexValuePair& pair) -> FloatType {
70 const FloatType diff = pair.value -
average;
71 return carry + diff * diff;
79 const MatrixType& vertices,
80 const std::vector<IndexValuePair>& pairs
82 return (vertices.rowwise().sum() - vertices.col(pairs.back().column)) / (vertices.cols() - 1);
85 template<
typename UpdateFunction>
87 Eigen::Ref<MatrixType> vertices,
88 std::vector<IndexValuePair>& pairs,
89 UpdateFunction&&
function
91 constexpr FloatType shrinkCoefficient = 0.5;
92 static_assert(0 < shrinkCoefficient && shrinkCoefficient < 1,
"Shrink coefficient bounds not met");
94 const unsigned V = vertices.cols();
95 const unsigned bestColumn = pairs.front().column;
96 const auto& bestVertex = vertices.col(bestColumn);
99 for(
unsigned i = 1; i < V; ++i) {
100 auto& value = pairs.at(i);
101 auto vertex = vertices.col(value.column);
102 vertex = bestVertex + shrinkCoefficient * (vertex - bestVertex);
103 value.value =
function(vertex);
108 static void replaceWorst(
109 std::vector<IndexValuePair>& sortedPairs,
110 const VectorType& newVertex,
111 const FloatType newValue,
112 Eigen::Ref<MatrixType> vertices
114 assert(std::is_sorted(std::begin(sortedPairs), std::end(sortedPairs)));
115 IndexValuePair replacementPair {
116 sortedPairs.back().column,
122 std::begin(sortedPairs),
123 std::end(sortedPairs),
129 sortedPairs.pop_back();
131 vertices.col(replacementPair.column) = newVertex;
135 typename UpdateFunction,
137 >
static OptimizationReturnType minimize(
138 Eigen::Ref<MatrixType> vertices,
139 UpdateFunction&&
function,
142 constexpr FloatType reflectionCoefficient = 1;
143 constexpr FloatType expansionCoefficient = 2;
144 constexpr FloatType contractionCoefficient = 0.5;
146 static_assert(0 < reflectionCoefficient,
"Reflection coefficient bounds not met");
147 static_assert(1 < expansionCoefficient,
"Expansion coefficient bounds not met");
148 static_assert(0 < contractionCoefficient && contractionCoefficient <= 0.5,
"Contraction coefficient bounds not met");
150 const unsigned N = vertices.rows();
151 assert(vertices.cols() == N + 1);
153 std::vector<IndexValuePair> values = Temple::sorted(
155 Temple::iota<unsigned>(N + 1),
156 [&](
const unsigned i) -> IndexValuePair {
159 function(vertices.col(i))
165 FloatType standardDeviation;
166 unsigned iteration = 0;
168 const VectorType simplexCentroid =
centroid(vertices, values);
169 const VectorType& worstVertex = vertices.col(values.back().column);
170 const FloatType worstVertexValue = values.back().value;
171 const FloatType bestVertexValue = values.front().value;
174 const VectorType reflectedVertex = generateVertex(reflectionCoefficient, simplexCentroid, worstVertex);
175 const FloatType reflectedValue =
function(reflectedVertex);
177 if(reflectedValue < bestVertexValue) {
179 const VectorType expandedVertex = generateVertex(expansionCoefficient, simplexCentroid, worstVertex);
180 const FloatType expandedValue =
function(expandedVertex);
182 if(expandedValue < reflectedValue) {
184 replaceWorst(values, expandedVertex, expandedValue, vertices);
187 replaceWorst(values, reflectedVertex, reflectedValue, vertices);
189 }
else if(bestVertexValue <= reflectedValue && reflectedValue < values.at(N - 1).value) {
190 replaceWorst(values, reflectedVertex, reflectedValue, vertices);
191 }
else if(values.at(N - 1).value <= reflectedValue && reflectedValue < worstVertexValue) {
193 const VectorType outsideContractedVertex = generateVertex(reflectionCoefficient * contractionCoefficient, simplexCentroid, worstVertex);
194 const FloatType outsideContractedValue =
function(outsideContractedVertex);
195 if(outsideContractedValue <= reflectedValue) {
196 replaceWorst(values, outsideContractedVertex, outsideContractedValue, vertices);
198 shrink(vertices, values,
function);
202 const VectorType insideContractedVertex = generateVertex(-contractionCoefficient, simplexCentroid, worstVertex);
203 const FloatType insideContractedValue =
function(insideContractedVertex);
204 if(insideContractedValue < worstVertexValue) {
205 replaceWorst(values, insideContractedVertex, insideContractedValue, vertices);
207 shrink(vertices, values,
function);
211 standardDeviation = valueStandardDeviation(values);
213 }
while(check.shouldContinue(iteration, values.front().value, standardDeviation));
217 values.front().value,
218 values.front().column
void sort(Container &container)
Calls std::sort on a container.
Definition: Functional.h:261
Definition: NelderMead.h:46
Type returned from an optimization.
Definition: NelderMead.h:29
static VectorType centroid(const MatrixType &vertices, const std::vector< IndexValuePair > &pairs)
Calculates the simplex centroid excluding the worst vertex.
Definition: NelderMead.h:78
unsigned iterations
Number of iterations.
Definition: NelderMead.h:31
T accumulate(const Container &container, T init, BinaryFunction &&reductionFunction)
Accumulate shorthand.
Definition: Functional.h:216
unsigned minimalIndex
Simplex vertex with minimal function value.
Definition: NelderMead.h:35
Nelder-Mead optimization.
Definition: NelderMead.h:24
constexpr std::enable_if_t< std::is_floating_point< Traits::getValueType< ContainerType > >::value, Traits::getValueType< ContainerType >> average(const ContainerType &container)
Definition: Numeric.h:64
constexpr auto map(const ArrayType< T, size > &array, UnaryFunction &&function)
Maps all elements of any array-like container with a unary function.
Definition: Containers.h:63
Functional-style container-related algorithms.
FloatType value
Final function value.
Definition: NelderMead.h:33