Source code for scine_autocas.plots.threshold_plot

"""The threshold diagram class.

This module implements a class to plot a threshold diagram.
"""
# -*- coding: utf-8 -*-
__copyright__ = """This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details. """

from typing import Any

import matplotlib.pyplot as plt
import numpy as np


class ThresholdPlot:
    """A class to plot the threshold diagram for s1 values.

    A threshold diagram provides a graphical way to check back the autocas
    plateau algorithm.

    Attributes
    ----------
    plateau_elements : int, default = 10
        number of consecutive threshold steps with an unchanged number of
        selected orbitals required to form a plateau
    """

    __slots__ = ["plateau_elements"]

    def __init__(self):
        """Construct a ThresholdPlot object."""
        # number of consecutive equal steps necessary to count as a plateau
        self.plateau_elements: int = 10

    def set_number_of_platueau_elements(self, n_elements: int):
        """Set number of steps required for a plateau.

        Parameters
        ----------
        n_elements : int
            new number of plateau elements
        """
        self.plateau_elements = n_elements

    def set_number_of_plateau_elements(self, n_elements: int):
        """Set number of steps required for a plateau.

        Correctly spelled alias of ``set_number_of_platueau_elements``, which
        is kept for backward compatibility.

        Parameters
        ----------
        n_elements : int
            new number of plateau elements
        """
        self.set_number_of_platueau_elements(n_elements)

    @staticmethod
    def _threshold_curve(s1_entropy):
        """Return the threshold grid and the number of orbitals above each threshold.

        Parameters
        ----------
        s1_entropy : array_like
            single orbital entropy, one value per orbital

        Returns
        -------
        x_values : np.ndarray
            thresholds 0.00, 0.01, ..., 1.00 relative to the largest entropy
        y_values : np.ndarray
            number of orbitals whose relative entropy is strictly larger than
            the corresponding threshold (float array)

        Raises
        ------
        ValueError
            if ``s1_entropy`` is empty
        """
        thresholds = np.asarray(s1_entropy, dtype=float)
        if thresholds.size == 0:
            raise ValueError("s1_entropy must contain at least one value")
        max_s1 = thresholds.max()
        # Guard against an all-zero input: dividing by zero would yield NaNs
        # (and a RuntimeWarning); without division no orbital exceeds any
        # threshold >= 0, which is the meaningful result.
        if max_s1 > 0:
            thresholds = thresholds / max_s1
        x_values = np.arange(0, 1.01, 0.01)
        # Broadcast (n_thresholds, 1) against (1, n_orbitals) and count per row.
        above = thresholds[np.newaxis, :] > x_values[:, np.newaxis]
        y_values = np.count_nonzero(above, axis=1).astype(float)
        return x_values, y_values

    @staticmethod
    def _find_plateaus(x_values, y_values, plateau_elements):
        """Collect the points that lie on a plateau.

        A point is on a plateau once at least ``plateau_elements`` consecutive
        points (including itself) share the same y value.

        Parameters
        ----------
        x_values : iterable of float
            threshold values
        y_values : iterable of float
            number of selected orbitals per threshold
        plateau_elements : int
            minimal run length that counts as a plateau

        Returns
        -------
        plateau_index : list of float
            thresholds of the points on a plateau
        plateau_vector : list of float
            number of selected orbitals of the points on a plateau
        """
        plateau_index = []
        plateau_vector = []
        run_length = 0
        previous = None
        for x_value, y_value in zip(x_values, y_values):
            # The first point starts a run of length 1 (no off-by-one).
            run_length = run_length + 1 if y_value == previous else 1
            previous = y_value
            if run_length >= plateau_elements:
                plateau_index.append(x_value)
                plateau_vector.append(y_value)
        return plateau_index, plateau_vector

    def plot(self, s1_entropy: np.ndarray, show_plateau: bool = False) -> Any:
        """Plot the threshold diagram from the single orbital entropy.

        Parameters
        ----------
        s1_entropy : np.ndarray
            single orbital entropy, one value per orbital
        show_plateau : bool, default = False
            if True, additionally mark every point belonging to a plateau of at
            least ``plateau_elements`` consecutive equal steps

        Returns
        -------
        plt : matplotlib.pyplot
            the matplotlib object

        Raises
        ------
        ValueError
            if ``s1_entropy`` is empty
        """
        number_of_orbitals = len(s1_entropy)
        x_values, y_values = self._threshold_curve(s1_entropy)

        # plot threshold diagram
        plt.figure()
        plt.plot(x_values, y_values, marker="o", markersize=3, ls="", color="#203f78")
        if show_plateau:
            plateau_index, plateau_vector = self._find_plateaus(
                x_values, y_values, self.plateau_elements
            )
            plt.plot(plateau_index, plateau_vector, ls="", marker="x")
        plt.ylabel("# selected Orbitals", fontsize=16)
        plt.xlabel("threshold in % of largest element", fontsize=16)
        # coarser y ticks for large orbital spaces to keep labels readable
        tick_step = 5 if number_of_orbitals > 40 else 1
        plt.yticks(np.arange(0, number_of_orbitals + 1, tick_step), fontsize=16)
        plt.xticks(np.arange(0, 1.1, 0.1), fontsize=16)
        return plt