"""The threshold diagram class.
This module implements a class to plot a threshold diagram.
"""
# -*- coding: utf-8 -*-
__copyright__ = """This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Laboratory of Physical Chemistry, Reiher Group.
See LICENSE.txt for details. """
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
class ThresholdPlot:
    """A class to plot the threshold diagram for s1 values.

    A threshold diagram provides a graphical way to check back the autocas
    plateau algorithm: for every threshold (given as a fraction of the
    largest single orbital entropy) it shows how many orbitals have a
    relative entropy above that threshold.

    Attributes
    ----------
    plateau_elements : int, default = 10
        number of consecutive threshold steps with an unchanged orbital
        count that are required to form a plateau
    """

    __slots__ = ["plateau_elements"]

    def __init__(self):
        """Construct a ThresholdPlot object with the default plateau length."""
        # Number of consecutive threshold steps necessary to count as a plateau.
        self.plateau_elements: int = 10

    def set_number_of_platueau_elements(self, n_elements: int):
        """
        Set number of steps required for a plateau.

        The (misspelled) method name is kept as is, because it is part of
        the public interface.

        Parameters
        ----------
        n_elements : int
            new number of n_elements
        """
        self.plateau_elements = n_elements

    @staticmethod
    def _count_selected(s1_entropy, thresholds) -> np.ndarray:
        """Count, per threshold, the orbitals whose relative entropy exceeds it.

        Parameters
        ----------
        s1_entropy : array_like
            single orbital entropies
        thresholds : array_like
            thresholds as fractions of the largest entropy

        Returns
        -------
        counts : np.ndarray
            integer array with one orbital count per threshold

        Raises
        ------
        ValueError
            if ``s1_entropy`` is empty
        """
        entropies = np.asarray(s1_entropy, dtype=float)
        if entropies.size == 0:
            raise ValueError("s1_entropy must contain at least one value")
        max_s1 = entropies.max()
        if max_s1 == 0:
            # Avoid 0/0: without any entropy no orbital exceeds any threshold.
            relative = np.zeros_like(entropies)
        else:
            relative = entropies / max_s1
        return np.array([np.count_nonzero(relative > threshold) for threshold in thresholds])

    def _find_plateaus(self, x_values, y_values):
        """Collect the points that belong to a plateau of the diagram.

        A point is part of a plateau once it is at least the
        ``plateau_elements``-th consecutive point with the same y value.

        Parameters
        ----------
        x_values : sequence
            thresholds of the diagram
        y_values : sequence
            number of selected orbitals for each threshold

        Returns
        -------
        plateau_x : list
            x values of all plateau points
        plateau_y : list
            y values of all plateau points
        """
        plateau_x = []
        plateau_y = []
        run_length = 0
        for i, y_value in enumerate(y_values):
            # Every run starts at length one, including the very first one,
            # so that all plateaus need the same number of points.
            if i > 0 and y_value == y_values[i - 1]:
                run_length += 1
            else:
                run_length = 1
            if run_length >= self.plateau_elements:
                plateau_x.append(x_values[i])
                plateau_y.append(y_value)
        return plateau_x, plateau_y

    def plot(self, s1_entropy: np.ndarray, *, show_plateaus: bool = False) -> Any:
        """
        Plot the threshold diagram from the single orbital entropy.

        Parameters
        ----------
        s1_entropy : np.ndarray
            single orbital entropy
        show_plateaus : bool, default = False
            if True, additionally mark all points that belong to a plateau
            of at least ``plateau_elements`` steps with an "x"

        Returns
        -------
        plt : matplotlib.pyplot
            the matplotlib object, with the newly created figure being the
            current one

        Raises
        ------
        ValueError
            if ``s1_entropy`` is empty
        """
        number_of_orbitals = len(s1_entropy)
        # 101 thresholds from 0 % to 100 % of the largest entropy; linspace
        # guarantees the number of points independent of float rounding.
        x_values = np.linspace(0.0, 1.0, 101)
        y_values = self._count_selected(s1_entropy, x_values)

        # plot threshold diagram
        plt.figure()
        plt.plot(x_values, y_values, marker="o", markersize=3, ls="", color="#203f78")
        if show_plateaus:
            plateau_x, plateau_y = self._find_plateaus(x_values, y_values)
            plt.plot(plateau_x, plateau_y, ls="", marker="x")
        plt.ylabel("# selected Orbitals", fontsize=16)
        plt.xlabel("threshold in % of largest element", fontsize=16)
        # Use a coarser y grid for large orbital spaces to keep ticks readable.
        y_tick_step = 5 if number_of_orbitals > 40 else 1
        plt.yticks(np.arange(0, number_of_orbitals + 1, y_tick_step), fontsize=16)
        plt.xticks(np.arange(0, 1.1, 0.1), fontsize=16)
        return plt