Source code for scine_autocas.plots.threshold_plot

"""The threshold diagram class.

This module implements a class to plot a threshold diagram.
"""
# -*- coding: utf-8 -*-
__copyright__ = """This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Laboratory of Physical Chemistry, Reiher Group.
See LICENSE.txt for details. """

from typing import Any

import matplotlib.pyplot as plt
import numpy as np


class ThresholdPlot:
    """A class to plot the threshold diagram for s1 values.

    A threshold diagram provides a graphical way to check back the autocas
    plateau algorithm: for every threshold between 0 and 1 (a fraction of the
    largest single orbital entropy) it shows how many orbitals have a
    normalized entropy strictly above that threshold.

    Attributes
    ----------
    plateau_elements : int, default = 10
        number of steps required to form a plateau
    """

    __slots__ = ["plateau_elements"]

    def __init__(self):
        """Construct a ThresholdPlot object."""
        # number of elements necessary to count as a plateau
        self.plateau_elements: int = 10

    def set_number_of_platueau_elements(self, n_elements: int):
        """
        Set number of steps required for a plateau.

        Parameters
        ----------
        n_elements : int
            new number of n_elements
        """
        self.plateau_elements = n_elements

    @staticmethod
    def _count_selected_orbitals(s1_entropy, x_values: np.ndarray) -> np.ndarray:
        """
        Count, for every threshold, the orbitals whose normalized s1 exceeds it.

        Parameters
        ----------
        s1_entropy : np.ndarray
            single orbital entropy (one-dimensional)
        x_values : np.ndarray
            thresholds, given as fractions of the largest entropy

        Returns
        -------
        y_values : np.ndarray
            float array with one orbital count per threshold

        Raises
        ------
        ValueError
            if `s1_entropy` is empty
        """
        entropies = np.asarray(s1_entropy, dtype=float)
        if entropies.size == 0:
            raise ValueError("s1_entropy must contain at least one value")
        max_s1 = entropies.max()
        if max_s1 != 0:
            normalized = entropies / max_s1
        else:
            # All entropies are zero: avoid 0/0, no orbital exceeds any threshold.
            normalized = np.zeros_like(entropies)
        # Broadcast to (n_thresholds, n_orbitals) and count hits per threshold.
        # The count does not depend on the order of the orbitals, so no
        # sorting is required.
        above = normalized[np.newaxis, :] > np.asarray(x_values)[:, np.newaxis]
        return above.sum(axis=1).astype(float)

    def _find_plateaus(self, x_values: np.ndarray, y_values: np.ndarray) -> tuple:
        """
        Collect the points of the threshold curve which belong to a plateau.

        A point is part of a plateau as soon as the run counter of equal,
        consecutive `y_values` reaches `self.plateau_elements`.

        Parameters
        ----------
        x_values : np.ndarray
            thresholds of the diagram
        y_values : np.ndarray
            number of selected orbitals for each threshold

        Returns
        -------
        plateau_index : List[float]
            thresholds of all plateau points
        plateau_vector : List[float]
            number of selected orbitals of all plateau points
        """
        plateau_index = []
        plateau_vector = []
        if len(y_values) == 0:
            return plateau_index, plateau_vector
        current_value = y_values[0]
        # NOTE: the counter starts at 1 and is incremented for the very first
        # element as well, hence the first run reaches the plateau length one
        # element earlier than all later runs. This is kept unchanged from the
        # original counting scheme.
        run_length = 1
        for i, y_value in enumerate(y_values):
            if y_value == current_value:
                run_length += 1
            else:
                run_length = 1
                current_value = y_value
            if run_length >= self.plateau_elements:
                plateau_index.append(float(x_values[i]))
                plateau_vector.append(float(y_value))
        return plateau_index, plateau_vector

    def plot(self, s1_entropy: np.ndarray, *, show_plateaus: bool = False) -> Any:
        """
        Plot the threshold diagram from the single orbital entropy.

        Parameters
        ----------
        s1_entropy : np.ndarray
            single orbital entropy
        show_plateaus : bool, default = False
            if True, all points which belong to a plateau of at least
            `plateau_elements` steps are additionally marked with an "x"

        Returns
        -------
        plt : matplotlib.pyplot
            the matplotlib object

        Raises
        ------
        ValueError
            if `s1_entropy` is empty
        """
        number_of_orbitals = len(s1_entropy)
        # thresholds from 0 to 1 in steps of 0.01
        x_values = np.arange(0, 1.01, 0.01)
        y_values = self._count_selected_orbitals(s1_entropy, x_values)

        # plot threshold diagram
        plt.figure()
        plt.plot(x_values, y_values, marker="o", markersize=3, ls="", color="#203f78")
        if show_plateaus:
            plateau_index, plateau_vector = self._find_plateaus(x_values, y_values)
            plt.plot(plateau_index, plateau_vector, ls="", marker="x")
        plt.ylabel("# selected Orbitals", fontsize=16)
        plt.xlabel("threshold in % of largest element", fontsize=16)
        # use a coarser tick spacing for many orbitals to keep the axis readable
        tick_step = 5 if number_of_orbitals > 40 else 1
        plt.yticks(np.arange(0, number_of_orbitals + 1, tick_step), fontsize=16)
        plt.xticks(np.arange(0, 1.1, 0.1), fontsize=16)
        return plt