# Source code for ensembler.visualisation.plotPotentials

from numbers import Number
from typing import Tuple, List

import matplotlib
import numpy as np
from matplotlib import pyplot as plt, colorbar

import ensembler.potentials.TwoD as pot2D
from ensembler.potentials import OneD as pot
from ensembler.potentials._basicPotentials import _potentialNDCls,_potential1DCls, _potential1DClsPerturbed

from ensembler.visualisation import plot_layout_settings
from ensembler.visualisation import style

# Apply the package-wide matplotlib layout settings once, when this module is imported.
for setting_name, setting_value in plot_layout_settings.items():
    matplotlib.rcParams[setting_name] = setting_value


# UTIL FUNCTIONS
def significant_decimals(s: float) -> float:
    """Round ``s`` to two decimals plus one extra decimal per leading decimal zero.

    Integral values are returned unchanged. For a non-integral value, every
    zero directly after the decimal point adds one more kept decimal, so the
    first two non-zero fractional digits survive (e.g. ``0.0012345 -> 0.0012``,
    ``12.3456 -> 12.35``).

    Parameters
    ----------
    s : float
        value to round.

    Returns
    -------
    float
        the rounded value, or ``s`` itself if it has no fractional part.
    """
    if s % 1 == 0:
        return s

    significant_decimal = 2
    # np.format_float_positional never switches to scientific notation, unlike
    # str(float(s)) (e.g. "1e-05"), which made small values round to 0.0.
    decimals = np.format_float_positional(float(s)).split(".")[-1]
    for digit in decimals:
        if digit != "0":
            break
        significant_decimal += 1
    return round(s, significant_decimal)
def plot_potential(potential: _potentialNDCls, positions: list, out_path: str = None, x_range=None, y_range=None,
                   title: str = None, ax=None, ):
    """Plot a potential with the 1D or 2D routine, chosen by its dimensionality.

    The dimensionality is read as ``potential.constants[potential.nDimensions]``
    (NOTE(review): assumes ``constants`` is keyed by the ``nDimensions`` symbol).

    Parameters
    ----------
    potential : _potentialNDCls
        potential to plot.
    positions : list
        positions to evaluate (1D values, or 2D points for 2D potentials).
    out_path : str, optional
        if given, the figure is saved there.
    x_range, y_range : (min, max), optional
        1D: axis limits of the position and energy axes.
        2D: ranges of the first and second coordinate; only needed when
        ``positions`` is None.
    title : str, optional
        figure title.
    ax : optional
        forwarded to the 1D routine only.

    Returns
    -------
    whatever ``plot_1DPotential`` / ``plot_2DPotential`` return.

    Raises
    ------
    NotImplementedError
        if the potential is neither one- nor two-dimensional.
    """
    n_dimensions = potential.constants[potential.nDimensions]
    if n_dimensions == 1:
        return plot_1DPotential(potential=potential, positions=positions, out_path=out_path, x_range=x_range,
                                y_range=y_range, title=title, ax=ax)
    elif n_dimensions == 2:
        # Forward the ranges so a grid can be generated when no positions are given.
        return plot_2DPotential(potential=potential, positions=positions, title=title, out_path=out_path,
                                x1_range=x_range, x2_range=y_range)
    # Previously this fell through and returned None silently.
    raise NotImplementedError("plot_potential only supports 1D and 2D potentials, got " + str(n_dimensions)
                              + " dimensions.")
""" 1D Plotting Functions """
def plot_1DPotential(potential: _potential1DCls, positions: list, out_path: str = None, x_range=None, y_range=None,
                     title: str = None, ax=None):
    """Draw a 1D potential and its derivative side by side in a new two-panel figure.

    Parameters
    ----------
    potential : _potential1DCls
        potential to plot.
    positions : list
        positions at which energy and derivative are evaluated.
    out_path : str, optional
        if given, the figure is saved there and then closed.
    x_range, y_range : (min, max), optional
        axis limits, applied to both panels.
    title : str, optional
        super-title; defaults to ``potential.name``.
    ax : optional
        accepted for interface compatibility but not used; a new figure is always created.

    Returns
    -------
    (plt.Figure, str or None)
        the figure and ``out_path``.
    """
    fig, (energy_ax, derivative_ax) = plt.subplots(nrows=1, ncols=2)

    panel_kwargs = dict(potential=potential, positions=positions, x_range=x_range, y_range=y_range, title="")
    plot_1DPotential_V(ax=energy_ax, color=style.potential_color(0), **panel_kwargs)
    plot_1DPotential_dhdpos(ax=derivative_ax, **panel_kwargs)

    fig.tight_layout()
    fig.subplots_adjust(top=0.8)
    if title is None:
        fig.suptitle(str(potential.name), y=0.96)
    else:
        fig.suptitle(title, y=0.95)

    if out_path:
        fig.savefig(out_path)
        plt.close(fig)
    return fig, out_path
def plot_1DPotential_V(potential: _potential1DCls, positions: list, color=None, x_range=None, y_range=None,
                       title: str = None, ax=None, y_label: str = "$V\\ [k_B T]$"):
    """Plot the energy ``potential.ene(positions)`` of a 1D potential.

    Parameters
    ----------
    potential : _potential1DCls
        potential to plot.
    positions : list
        positions at which the energy is evaluated.
    color : optional
        line colour; matplotlib's default colour cycle is used if falsy.
    x_range, y_range : (min, max), optional
        axis limits; default to the range of the positions / energies.
    title : str, optional
        axes title; defaults to ``"Potential " + potential.name``.
    ax : plt.Axes, optional
        target axes; a new figure is created if None.
    y_label : str
        label of the energy axis.

    Returns
    -------
    (plt.Figure or None, plt.Axes)
        the new figure (None if ``ax`` was given) and the axes drawn on.
    """
    energies = potential.ene(positions=positions)

    # Only create a figure when the caller did not provide target axes.
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = None

    if color:
        ax.plot(positions, energies, c=color)
    else:
        ax.plot(positions, energies)

    # `is not None` rather than `!= None`, so numpy arrays are accepted as ranges.
    if x_range is not None:
        ax.set_xlim(min(x_range), max(x_range))
    else:
        ax.set_xlim(min(positions), max(positions))
    if y_range is not None:
        ax.set_ylim(min(y_range), max(y_range))
    else:
        ax.set_ylim(min(energies), max(energies))

    ax.set_xlabel('$r$')
    ax.set_ylabel(y_label)
    ax.set_title(title if title is not None else "Potential " + str(potential.name))
    return fig, ax
def plot_1DPotential_dhdpos(potential: _potential1DCls, positions: list, color=style.potential_color(1),
                            x_range=None, y_range=None, title: str = None, ax=None, yUnit: str = "kT"):
    """Plot ``potential.force(positions)`` of a 1D potential (labelled as dV/dr).

    Parameters
    ----------
    potential : _potential1DCls
        potential to plot.
    positions : list
        positions at which the derivative is evaluated.
    color : optional
        line colour.
    x_range, y_range : (min, max), optional
        axis limits; default to the range of the positions / derivative values.
    title : str, optional
        axes title; defaults to ``"Potential " + potential.name``.
    ax : plt.Axes, optional
        target axes; a new figure is created if None.
    yUnit : str
        unit shown in the y-axis label.

    Returns
    -------
    (plt.Figure or None, plt.Axes)
        the new figure (None if ``ax`` was given) and the axes drawn on.
    """
    forces = potential.force(positions=positions)

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = None

    ax.plot(positions, forces, c=color)

    # `is not None` rather than `!= None`, so numpy arrays are accepted as ranges.
    if x_range is not None:
        ax.set_xlim(min(x_range), max(x_range))
    else:
        ax.set_xlim(min(positions), max(positions))
    if y_range is not None:
        ax.set_ylim(min(y_range), max(y_range))
    else:
        ax.set_ylim(min(forces), max(forces))

    ax.set_xlabel('$r$')
    # Raw string: same text as before, without the invalid "\p" escape warning.
    ax.set_ylabel(r'$\partial V / \partial r\ [' + yUnit + ']$')
    ax.set_title(title if title is not None else "Potential " + str(potential.name))
    return fig, ax
def plot_1DPotential_Termoverlay(potential: _potential1DCls, positions: list, x_range=None, y_range=None,
                                 title: str = None, ax=None):
    """Overlay V, dV/dpos and |dV/dpos| of a 1D potential in a single axes.

    Parameters
    ----------
    potential : _potential1DCls
        potential to plot; ``ene`` and ``force`` are evaluated.
    positions : list
        evaluation positions.
    x_range, y_range : (min, max), optional
        axis limits; by default y spans all three plotted curves.
    title : str, optional
        axes title; defaults to ``"Potential " + potential.name``.
    ax : plt.Axes, optional
        target axes; a new figure is created if None.

    Returns
    -------
    (plt.Figure or None, plt.Axes)
    """
    energies = potential.ene(positions=positions)
    dVdpos = potential.force(positions=positions)
    abs_dVdpos = list(map(abs, dVdpos))

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111)
    else:
        fig = None

    ax.plot(positions, energies, label="V", c=style.potential_color(1))
    ax.plot(positions, abs_dVdpos, label="absdVdpos", c=style.potential_color(2))
    ax.plot(positions, dVdpos, label="dVdpos", c=style.potential_color(3))

    if x_range is not None:
        ax.set_xlim(min(x_range), max(x_range))
    else:
        ax.set_xlim(min(positions), max(positions))
    if y_range is not None:
        ax.set_ylim(min(y_range), max(y_range))
    else:
        # Include |dV/dpos| in the upper bound so that curve is not clipped.
        ax.set_ylim(min(min(energies), min(dVdpos)),
                    max(max(energies), max(dVdpos), max(abs_dVdpos)))

    ax.set_ylabel("$V/kJ$")
    ax.set_xlabel("$x$")
    ax.legend()
    # Potential instances expose `.name` (as used elsewhere); `__name__` only exists on classes.
    ax.set_title(title if title is not None else "Potential " + str(potential.name))
    return fig, ax
""" 2D Plotting Functions """
def plot_2DPotential(potential: pot2D._potential2DCls, positions: List[Tuple[Number, Number]] = None,
                     title: str = None, out_path: str = None, x1_range=None, x2_range=None) -> (
        plt.Figure, plt.Axes):
    """Plot the potential energy landscape of a 2D potential with a colour bar.

    Parameters
    ----------
    potential : pot2D._potential2DCls
        potential to plot.
    positions : list of (x1, x2), optional
        grid points, forwarded to ``plot_2D_potential_V``. If None, a grid
        spanning ``x1_range`` x ``x2_range`` is generated there.
    title : str, optional
        figure super-title; defaults to ``potential.name``.
    out_path : str, optional
        if given, the figure is saved there and then closed.
    x1_range, x2_range : (min, max), optional
        coordinate ranges; required when ``positions`` is None.

    Returns
    -------
    (plt.Figure, str or None)
        the figure and ``out_path`` (despite the ``plt.Axes`` in the annotation).
    """
    fig, ax = plt.subplots(nrows=1, ncols=1)
    _, _, surf = plot_2D_potential_V(potential=potential, positions2D=positions, ax=ax,
                                     space_range=[x1_range, x2_range], title="", x_label="$r_1$",
                                     y_label="$r_2$")

    cb = fig.colorbar(surf, ax=ax)
    cb.set_label("V [kT]")

    fig.tight_layout()
    fig.subplots_adjust(top=0.8)
    if title is not None:
        fig.suptitle(title, y=0.95)
    else:
        fig.suptitle(str(potential.name), y=0.96)

    if out_path:
        fig.savefig(out_path)
        plt.close(fig)
    return fig, out_path
def plot_2D_potential_V(potential: pot2D._potential2DCls, positions2D: List[Tuple[Number, Number]] = None,
                        title: str = None, out_path: str = None, x_label: str = None, y_label: str = None,
                        space_range: Tuple[Tuple[Number, Number], Tuple[Number, Number]] = (-10, 10),
                        point_resolution: int = 1000, ax=None, dpi: int = 300, cmap=style.qualitative_map) -> (
        plt.Figure, plt.Axes, np.array):
    """Draw the energy landscape of a 2D potential as an image.

    Parameters
    ----------
    potential : pot2D._potential2DCls
        potential whose ``ene`` is evaluated on the grid.
    positions2D : list of (x, y), optional
        grid points. They are reshaped to a square ``side x side`` image, so
        they must form a square grid with x varying fastest (as built from
        ``np.meshgrid``). If None, a ``point_resolution`` x ``point_resolution``
        grid over ``space_range`` is built.
    title : str, optional
        axes title; defaults to "Potential Landscape".
    out_path : str, optional
        if a string and a new figure was created here, the figure is saved there.
    x_label, y_label : str, optional
        axis labels; default to "x" / "y".
    space_range : ((xmin, xmax), (ymin, ymax)) or (min, max)
        grid extent; a flat (min, max) pair (such as the default) is used for both axes.
    point_resolution : int
        grid points per axis when the grid is generated.
    ax : plt.Axes, optional
        target axes; if None a new figure with its own colour bar is created.
    dpi : int
        resolution of a newly created figure.
    cmap :
        colormap for the image.

    Returns
    -------
    (plt.Figure or None, str or None, AxesImage)
        the new figure (None if ``ax`` was given), ``out_path`` and the image.
    """
    # build positions
    if positions2D is None:
        # A flat (min, max) pair, e.g. the default (-10, 10), applies to both axes;
        # previously min(-10) raised a TypeError.
        if isinstance(space_range[0], Number):
            space_range = (space_range, space_range)
        minX, maxX = min(space_range[0]), max(space_range[0])
        minY, maxY = min(space_range[1]), max(space_range[1])
        positionsX = np.linspace(minX, maxX, point_resolution)
        positionsY = np.linspace(minY, maxY, point_resolution)
        x_positions, y_positions = np.meshgrid(positionsX, positionsY)
        positions2D = np.array([x_positions.flatten(), y_positions.flatten()]).T
    else:
        positions2D = np.array(positions2D)
        minX, maxX = min(positions2D[:, 0]), max(positions2D[:, 0])
        minY, maxY = min(positions2D[:, 1]), max(positions2D[:, 1])

    # landscape
    V_pots = potential.ene(positions2D)
    minV, maxV = np.min(V_pots), np.max(V_pots)
    side = int(np.sqrt(positions2D.shape[0]))
    V_land = V_pots.reshape([side, side])

    # make figure
    if ax is None:
        fig, ax = plt.subplots(ncols=1, dpi=dpi)
    else:
        fig = None

    # Row 0 of V_land holds y == minY, so it has to be drawn at the bottom.
    surf = ax.imshow(V_land, cmap=cmap, extent=[minX, maxX, minY, maxY], origin="lower")
    # Show the full extent (previously maxX-1 / maxY-1 cropped one data unit).
    ax.set_xlim([minX, maxX])
    ax.set_ylim([minY, maxY])

    ax.set_xlabel("x" if x_label is None else x_label)
    ax.set_ylabel("y" if y_label is None else y_label)
    ax.set_xticks(np.linspace(minX, maxX, 3))
    ax.set_yticks(np.linspace(minY, maxY, 3))
    ax.set_title("Potential Landscape" if title is None else title)

    # colour bar and saving only for figures created here
    if fig is not None:
        cbaxes = fig.add_axes([0.9, 0.1, 0.03, 0.8])
        cb = plt.colorbar(surf, fraction=0.046, pad=0.04, cax=cbaxes,
                          ticks=list(np.round(np.linspace(minV, maxV, 5), 2)))
        cb.set_label("V/[kT]")
        fig.tight_layout()
        if isinstance(out_path, str):
            fig.savefig(out_path)

    return fig, out_path, surf
""" MultiState Plotting Functions """ # 1D
def multiState_overlays(states: list, positions: list = np.linspace(-8, 8, 500), y_range: tuple = (0, 10),
                        title: str = "Multiple state overlay", label_prefix: str = "State", out_path: str = None):
    """Overlay the energies of several 1D potentials in one axes.

    Lines are labelled ``label_prefix + " A"``, ``" B"``, ... in the order of ``states``.

    Parameters
    ----------
    states : list
        1D potentials to overlay; must not be empty.
    positions : list
        evaluation positions (not modified).
    y_range : (min, max)
        energy axis limits, applied after the last state is drawn.
    title : str
        axes title.
    label_prefix : str
        prefix of the legend labels.
    out_path : str, optional
        if given, the figure is saved there and then closed.

    Returns
    -------
    (plt.Figure, str or None)

    Raises
    ------
    IndexError
        if ``states`` is empty (as before).
    """
    if len(states) == 0:
        raise IndexError("multiState_overlays needs at least one state to plot.")

    fig, ax = None, None
    last_index = len(states) - 1
    for index, state in enumerate(states):
        is_last = index == last_index
        # The last call fixes the y-range and title. A single state is drawn only once.
        new_fig, ax = plot_1DPotential_V(potential=state, positions=positions, ax=ax,
                                         y_range=y_range if is_last else None,
                                         title=title if is_last else None)
        if fig is None:
            fig = new_fig

    for num, line in enumerate(ax.lines):
        line.set_label(label_prefix + " " + chr(num + 65))
    ax.legend()

    if out_path:
        fig.savefig(out_path)
        plt.close(fig)
    return fig, out_path
def plot_2perturbedEnergy_landscape(potential: _potential1DClsPerturbed, positions: list, lambdas: list,
                                    cmap=style.qualitative_map, x_range=None, lam_range=None, title: str = None,
                                    colbar: bool = False, ax=None, vmin: float = 0, vmax: float = 100):
    """Plot the energy of a perturbed 1D potential as a (position x lambda) image.

    ``potential.set_lambda`` is called for every lambda, so the potential is
    left at ``lambdas[-1]`` afterwards.

    Parameters
    ----------
    potential : _potential1DClsPerturbed
        lambda-dependent potential.
    positions : list
        positions (horizontal axis).
    lambdas : list
        lambda values (vertical axis, drawn bottom to top in the given order).
    cmap :
        colormap.
    x_range, lam_range : (min, max), optional
        axis limits.
    title : str, optional
        axes title.
    colbar : bool
        add a colour bar; always True when a new figure is created.
    ax : plt.Axes, optional
        target axes; a new figure is created if None.
    vmin, vmax : float
        colour limits (defaults keep the previous hard-coded 0..100).

    Returns
    -------
    (plt.Figure or None, plt.Axes, AxesImage)
    """
    energy_map_lin = []
    for lam in lambdas:
        potential.set_lambda(lam)
        energy_map_lin.append(potential.ene(positions))
    energy_map_lin = np.array(energy_map_lin)

    if ax is None:
        fig = plt.figure(figsize=(15, 5))
        ax = fig.add_subplot(111)
        colbar = True
    else:
        fig = None

    # origin='center' is rejected by current matplotlib; old versions treated it as 'lower'.
    surf = ax.imshow(energy_map_lin, cmap=cmap, interpolation="nearest", origin='lower',
                     extent=[min(positions), max(positions), min(lambdas), max(lambdas)],
                     vmax=vmax, vmin=vmin, aspect="auto")
    if colbar:
        # Colorbar(ax, ...) would draw the bar *into* the image axes. Let the
        # figure create a separate colour-bar axes next to `ax` instead.
        ax.figure.colorbar(surf, ax=ax, label='Energy')

    if x_range is not None:
        ax.set_xlim(min(x_range), max(x_range))
    if lam_range is not None:
        ax.set_ylim(min(lam_range), max(lam_range))

    ax.set_xlabel('x')
    ax.set_ylabel(r'$\lambda$')
    if title:
        ax.set_title(title)
    return fig, ax, surf
# show feature landscape per s
def envPot_differentS_overlay_min0_plot(eds_potential: pot.envelopedPotential, s_values: list, positions: list,
                                        y_range: tuple = None, hide_legend: bool = False, title: str = None,
                                        out_path: str = None):
    """Overlay the reference-state energy for several s-values, each shifted so its minimum is 0.

    ``eds_potential.s`` is overwritten and left at the last s-value.

    Parameters
    ----------
    eds_potential : pot.envelopedPotential
        EDS potential.
    s_values : list
        s-values to compare.
    positions : list
        evaluation positions.
    y_range : optional
        y-axis limits passed to ``set_ylim``.
    hide_legend : bool
        suppress the legend.
    title : str, optional
        figure super-title.
    out_path : str, optional
        if given, the figure is saved there.

    Returns
    -------
    (plt.Figure, plt.Axes)
    """
    ys = []
    scale = 1  # 0.1
    for s in s_values:
        eds_potential.s = s
        enes = eds_potential.ene(positions)
        y_min = min(enes)
        ys.append([(z - y_min) * scale for z in enes])

    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(20, 10))
    for i, (s, y) in enumerate(reversed(list(zip(s_values, ys)))):
        # potential_color is a function: cycle through 12 indices, as the sibling
        # plots do. len() of the function raised a TypeError.
        color = style.potential_color(i % 12)
        axes.plot(positions, y, label="s_" + str(significant_decimals(s)), c=color)

    if y_range is not None:
        axes.set_ylim(y_range)
    axes.set_xlim(min(positions), max(positions))

    axes.set_ylabel("Vr/[kJ]")
    axes.set_xlabel("r")
    axes.set_title("different Vrs aligned at 0 with different s-values overlayed ")

    if not hide_legend:
        axes.legend()
    if title:
        fig.suptitle(title)
    if out_path:
        fig.savefig(out_path)
    return fig, axes
# show feature landscape per s
def envPot_differentS_overlay_plot(eds_potential: pot.envelopedPotential, s_values: list, positions: list,
                                   y_range: tuple = None, hide_legend: bool = False, title: str = None,
                                   out_path: str = None, axes=None):
    """Overlay the reference-state energy of an EDS potential for several s-values.

    ``eds_potential.s`` is overwritten and left at the last s-value.

    Parameters
    ----------
    eds_potential : pot.envelopedPotential
        EDS potential.
    s_values : list
        s-values to compare.
    positions : list
        evaluation positions.
    y_range : optional
        y-axis limits passed to ``set_ylim``.
    hide_legend : bool
        suppress the legend.
    title : str, optional
        axes title (and figure super-title if the figure is created here).
    out_path : str, optional
        saved only if the figure is created here.
    axes : plt.Axes, optional
        target axes; a new figure is created if None.

    Returns
    -------
    (plt.Figure or None, plt.Axes)
    """
    ys = []
    for s in s_values:
        eds_potential.s = s
        ys.append(eds_potential.ene(positions))

    if axes is None:
        fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(20, 10))
    else:
        fig = None

    for i, (s, y) in enumerate(reversed(list(zip(s_values, ys)))):
        color = style.potential_color(i % 12)
        axes.plot(positions, y, label="s_" + str(significant_decimals(s)), color=color)

    axes.set_xlim(min(positions), max(positions))
    axes.set_ylabel("Vr/[kJ]")
    axes.set_xlabel("r")
    if title is None:
        axes.set_title("different $V_{r}$s with different s-values overlayed ")
    else:
        axes.set_title(title)

    # `is not None` rather than `!= None`, so array-valued ranges work.
    if y_range is not None:
        axes.set_ylim(y_range)
    if not hide_legend:
        axes.legend()
    if title and fig is not None:
        fig.suptitle(title)
    if out_path and fig is not None:
        fig.savefig(out_path)
    return fig, axes
def envPot_diffS_compare(eds_potential: pot.envelopedPotential, s_values: list, positions: list,
                         y_range: tuple = None, title: str = None, out_path: str = None):
    """Plot the reference-state energy for each s-value in its own panel (4 per row).

    ``eds_potential.s`` is overwritten and left at the last s-value.

    Parameters
    ----------
    eds_potential : pot.envelopedPotential
        EDS potential.
    s_values : list
        s-values, one panel each; must not be empty.
    positions : list
        evaluation positions.
    y_range : optional
        y-axis limits for every panel.
    title : str, optional
        figure super-title.
    out_path : str, optional
        if given, the figure is saved there.

    Returns
    -------
    (plt.Figure, list of plt.Axes)
        the figure and all panels in row-major order.
    """
    per_row = 4
    n_rows = (len(s_values) + per_row - 1) // per_row  # ceil division

    # squeeze=False keeps a 2D array even for a single row, which the old nested flatten could not handle.
    fig, axes = plt.subplots(nrows=n_rows, ncols=per_row, figsize=(20, 10), squeeze=False)
    axes = list(axes.flat)

    for ind, (ax, s) in enumerate(zip(axes, s_values)):
        # potential_color is a function: cycle through 12 indices like the sibling plots.
        color = style.potential_color(ind % 12)
        eds_potential.s = s
        y = eds_potential.ene(positions)
        ax.plot(positions, y, c=color)

        ax.set_xlim(min(positions), max(positions))
        ax.set_title("s_" + str(significant_decimals(s)))
        ax.set_ylabel("Vr/[kJ]")
        ax.set_xlabel("r")
        if y_range is not None:
            ax.set_ylim(y_range)

    if title:
        fig.suptitle(title)
    if out_path:
        fig.savefig(out_path)
    return fig, axes
def plot_envelopedPotential_system(eds_potential: pot.envelopedPotential, positions: list, s_value: float = None,
                                   Eoffi: list = None, y_range: tuple = None, title: str = None,
                                   out_path: str = None):
    """Plot every end state and the reference state of an EDS potential in separate panels.

    Parameters
    ----------
    eds_potential : pot.envelopedPotential
        EDS potential; its states are taken from ``V_is``.
    positions : list
        evaluation positions.
    s_value : float, optional
        if given, assigned to ``eds_potential.s``.
    Eoffi : list, optional
        energy offsets (one per state) passed to ``eds_potential.set_Eoff``.
    y_range : optional
        y-axis limits of every panel (None keeps matplotlib's autoscale).
    title : str, optional
        figure super-title.
    out_path : str, optional
        if given, the figure is saved there.

    Returns
    -------
    (plt.Figure, list of plt.Axes)

    Raises
    ------
    IOError
        if the number of offsets does not match the number of states.
    """
    if s_value is not None:
        eds_potential.s = s_value
    if Eoffi is not None:
        if len(Eoffi) == len(eds_potential.V_is):
            eds_potential.set_Eoff(Eoffi)
        else:
            raise IOError("There are " + str(len(eds_potential.V_is)) + " states and " + str(
                len(Eoffi)) + " Eoff values, but the numbers have to be equal!")

    energy_Vr = eds_potential.ene(positions)
    energy_Vis = [state.ene(positions) for state in eds_potential.V_is]

    # One panel per state plus one for the reference state, two per row.
    # At least 2x2 as before; more rows if needed so no state is silently dropped.
    n_panels = len(energy_Vis) + 1
    n_rows = max(2, (n_panels + 1) // 2)
    fig, axes = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, 5 * n_rows), squeeze=False)
    axes = list(axes.flat)

    y_values = energy_Vis + [energy_Vr]
    labels = ["state_" + str(ind) for ind in range(1, len(energy_Vis) + 1)] + ["refState"]
    for i, (ax, y, label) in enumerate(zip(axes, y_values, labels)):
        color = style.potential_color(i % 12)
        ax.plot(positions, y, c=color)
        ax.set_xlim(min(positions), max(positions))
        ax.set_ylim(y_range)
        ax.set_title(label)
        ax.set_ylabel("Vr/[kJ]")
        ax.set_xlabel("r_" + label)

    if title:
        fig.suptitle(title)
    if out_path:
        fig.savefig(out_path)
    return fig, axes
def plot_envelopedPotential_2State_System(eds_potential: pot.envelopedPotential, positions: list,
                                          s_value: float = None, Eoffi: list = None, title: str = None,
                                          out_path: str = None, V_max: float = 600, V_min: float = None):
    """Plot both end states, the reference state and the 2D reference-state map of a two-state EDS potential.

    Parameters
    ----------
    eds_potential : pot.envelopedPotential
        EDS potential with exactly two states in ``V_is``.
    positions : list
        1D grid, used for the line plots and for both coordinates of the map.
    s_value : float, optional
        if given, assigned to ``eds_potential.s``.
    Eoffi : list, optional
        energy offsets (one per state) passed to ``eds_potential.set_Eoff``.
    title : str, optional
        figure super-title.
    out_path : str, optional
        if given, the figure is saved there.
    V_max : float, optional
        map energies above this are clipped; also the upper y / colour limit.
    V_min : float, optional
        lower y / colour limit; defaults to min(0, lowest map energy).

    Returns
    -------
    (plt.Figure, list of plt.Axes)

    Raises
    ------
    IOError
        if the potential does not have exactly two states, or the offsets do not match them.
    """
    # Exactly two states are required: the map and its axis labels use labels[0] and labels[1].
    if len(eds_potential.V_is) != 2:
        raise IOError(__name__ + " can only be used with two states in the potential!")
    if s_value is not None:
        eds_potential.s = s_value
    if Eoffi is not None:
        if len(Eoffi) == len(eds_potential.V_is):
            eds_potential.set_Eoff(Eoffi)
        else:
            raise IOError("There are " + str(len(eds_potential.V_is)) + " states and " + str(
                len(Eoffi)) + " Eoff values, but the numbers have to be equal!")

    # 1D energies
    energy_Vr = eds_potential.ene(positions)
    energy_Vis = [state.ene(positions) for state in eds_potential.V_is]

    # 2D map of the reference state, clipped at V_max
    energy_map = []
    min_e = 0
    for x in positions:
        row = eds_potential.ene([[[x], [y]] for y in positions])
        energy_map.append([V_max if (V_max is not None and float(e) > V_max) else float(e) for e in row])
        if min(row) < min_e:
            min_e = min(row)
    # energy_map[i][j] belongs to (first=positions[i], second=positions[j]). imshow draws
    # rows vertically, so transpose to put the first coordinate on the x-axis, as labelled.
    energy_map = np.array(energy_map).T
    if V_min is None:
        V_min = min_e

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
    axes = [ax for ax_row in axes for ax in ax_row]
    y_values = energy_Vis + [energy_Vr]
    labels = ["State_" + str(ind) for ind in range(1, len(energy_Vis) + 1)] + ["State_R"]

    # line potentials
    colors = ["steelblue", "orange", "forestgreen"]
    for ax, y, label, c in zip(axes, y_values, labels, colors):
        ax.plot(positions, y, color=c)
        ax.set_xlim(min(positions), max(positions))
        ax.set_ylim([V_min, V_max])
        ax.set_title("Potential $" + label + "$")
        ax.set_ylabel("$V/[kJ]$")
        ax.set_xlabel("$r_{ " + label + "} $")

    # phase-space surface; origin='center' is rejected by current matplotlib
    # (old versions treated it as 'lower').
    ax = axes[-1]
    ax.imshow(energy_map, cmap="inferno", interpolation="nearest", origin='lower',
              extent=[min(positions), max(positions), min(positions), max(positions)], vmax=V_max, vmin=V_min)
    ax.set_xlabel("$r_{" + labels[0] + "}$")
    ax.set_ylabel("$r_{" + labels[1] + "}$")
    ax.set_title("complete phaseSpace of $state_R$")

    if title:
        fig.suptitle(title)
    if out_path:
        fig.savefig(out_path)
    return fig, axes
def envPot_diffS_2stateMap_compare(eds_potential: pot.envelopedPotential, s_values: list, positions: list,
                                   V_max: float = 500, V_min: float = None, title: str = None,
                                   out_path: str = None):
    """Plot the 2D reference-state map of a two-state EDS potential for each s-value (4 panels per row).

    Sets ``eds_potential.s_i`` for every s-value
    (NOTE(review): other functions here set ``.s`` instead; confirm which attribute is intended).

    Parameters
    ----------
    eds_potential : pot.envelopedPotential
        two-state EDS potential.
    s_values : list
        s-values, one panel each; must not be empty.
    positions : list
        1D grid used for both map coordinates.
    V_max : float, optional
        map energies above this are clipped; also the upper colour limit.
    V_min : float, optional
        lower colour limit shared by all panels; defaults to min(0, lowest energy
        of the first s-value's map).
    title : str, optional
        figure super-title.
    out_path : str, optional
        if given, the figure is saved there.

    Returns
    -------
    (plt.Figure, list of plt.Axes)
    """
    per_row = 4
    n_rows = (len(s_values) + per_row - 1) // per_row  # ceil division

    # squeeze=False keeps a 2D array even for a single row.
    fig, axes = plt.subplots(nrows=n_rows, ncols=per_row, figsize=(20, 10), squeeze=False)
    axes = list(axes.flat)

    for ax, s in zip(axes, s_values):
        eds_potential.s_i = s
        min_e = 0
        energy_map = []
        for x in positions:
            row = eds_potential.ene([[[x], [y]] for y in positions])
            energy_map.append([V_max if (V_max is not None and float(e) > V_max) else float(e) for e in row])
            if min(row) < min_e:
                min_e = min(row)
        # Only the first map defines V_min, so all panels share one colour scale.
        if V_min is None:
            V_min = min_e

        # Transpose so the first coordinate runs along the x-axis (labelled r_1).
        energy_map = np.array(energy_map).T
        surf = ax.imshow(energy_map, cmap="viridis", interpolation="nearest", origin='lower',
                         extent=[min(positions), max(positions), min(positions), max(positions)],
                         vmax=V_max, vmin=V_min)
        ax.set_xlabel("$r_1$")
        ax.set_ylabel("$r_2$")
        ax.set_title("complete phaseSpace of $state_R$")
        # Attach the colour bar to this panel rather than to the current axes.
        fig.colorbar(surf, ax=ax, aspect=10, label='Energy/kJ')

    if title:
        fig.suptitle(title)
    if out_path:
        fig.savefig(out_path)
    return fig, axes
# 2D """ Wrappers for special Cases """
def plot_2D_2states(V1, V2, space_range: Tuple[Number, Number] = None, point_resolution=500):
    """Plot two 2D potentials side by side on a shared colour scale.

    Parameters
    ----------
    V1, V2 :
        2D potentials ("State 1" and "State 2").
    space_range : ((xmin, xmax), (ymin, ymax)) or (min, max), optional
        grid extent; a flat (min, max) pair is used for both axes. If None, the
        default of ``plot_2D_potential_V`` is used.
    point_resolution : int
        grid points per axis.

    Returns
    -------
    plt.Figure
    """
    phi_label = r"$\phi/[^{\circ}]$"
    psi_label = r"$\psi/[^{\circ}]$"

    range_kwargs = {}
    if space_range is not None:
        if isinstance(space_range[0], Number):
            space_range = (space_range, space_range)
        range_kwargs["space_range"] = space_range

    fig, axes = plt.subplots(ncols=2, figsize=[15, 10])
    _, _, surf1 = plot_2D_potential_V(V1, ax=axes[0], title="State 1", x_label=phi_label, y_label=psi_label,
                                      point_resolution=point_resolution, **range_kwargs)
    _, _, surf2 = plot_2D_potential_V(V2, ax=axes[1], title="State 2", x_label=phi_label, y_label=psi_label,
                                      point_resolution=point_resolution, **range_kwargs)

    # The two images were scaled independently. Put them on one common scale so
    # the single colour bar (and its ticks) is valid for both panels.
    values1, values2 = surf1.get_array(), surf2.get_array()
    minV = min(np.min(values1), np.min(values2))
    maxV = max(np.max(values1), np.max(values2))
    surf1.set_clim(minV, maxV)
    surf2.set_clim(minV, maxV)

    # colour bar to the right of the second panel
    right_pos = axes[1].get_position()
    cbaxes = fig.add_axes([right_pos.x1 * 1.15, right_pos.y0, 0.03, right_pos.height])
    cb = plt.colorbar(surf2, cax=cbaxes, ticks=list(np.round(np.linspace(minV, maxV, 5), 2)))
    cb.set_label("V/[kT]")

    fig.tight_layout()
    fig.suptitle("The Two End States for EDS Potential", y=0.9)
    return fig
def plot_2D_2State_EDS_potential(eds_pot, out_path: str = None, traj=None, s=100, positions2D=None,
                                 space_range=[[-np.pi, np.pi], [-np.pi, np.pi]], point_resolution=500,
                                 x_label=r"$\phi/[^{\circ}$]", y_label=r"$\psi/[^{\circ}$]", verbose=False):
    """Plot both end states and the EDS reference state of a two-state 2D EDS potential. Used in publication.

    All three panels share the colour scale of state 1.

    Parameters
    ----------
    eds_pot :
        EDS potential whose first two ``V_is`` are 2D potentials; ``eds_pot.s_i`` is set to ``s``.
    out_path : str, optional
        if given, the figure is saved there (``bbox_inches='tight'``), closed, and the path is returned.
    traj : optional
        trajectory whose ``position`` column is scattered on every panel
        (start marked green, end marked red).
    s :
        value assigned to ``eds_pot.s_i``.
    positions2D : optional
        explicit square grid of 2D positions. Otherwise a
        ``point_resolution`` x ``point_resolution`` grid over ``space_range`` is built.
    space_range : ((xmin, xmax), (ymin, ymax))
        grid extent when ``positions2D`` is None.
    point_resolution : int
        grid points per axis when the grid is generated.
    x_label, y_label : str
        axis labels.
    verbose : bool
        print progress messages.

    Returns
    -------
    plt.Figure or str
        the figure if ``out_path`` is None, otherwise ``out_path``.
    """
    trajectory_color = style.trajectory_color

    # build positions
    if positions2D is None:
        minX, maxX = min(space_range[0]), max(space_range[0])
        minY, maxY = min(space_range[1]), max(space_range[1])
        positionsX = np.linspace(minX, maxX, point_resolution)
        positionsY = np.linspace(minY, maxY, point_resolution)
        x_positions, y_positions = np.meshgrid(positionsX, positionsY)
        positions2D = np.array([x_positions.flatten(), y_positions.flatten()]).T
    else:
        positions2D = np.array(positions2D)
        minX, maxX = min(positions2D[:, 0]), max(positions2D[:, 0])
        minY, maxY = min(positions2D[:, 1]), max(positions2D[:, 1])

    # energies over the whole space
    eds_pot.s_i = s
    V1 = eds_pot.V_is[0]
    V2 = eds_pot.V_is[1]
    energies1 = V1.ene(positions2D)
    energies2 = V2.ene(positions2D)
    energiesEds = eds_pot.ene(positions2D)
    side = int(np.sqrt(positions2D.shape[0]))

    if verbose:
        print("map data")
    energies1Map = energies1.reshape([side, side])
    energies2Map = energies2.reshape([side, side])
    energiesEdsMap = energiesEds.reshape([side, side])

    if verbose:
        print("plot")
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=[15, 6], dpi=300)
    axes = (ax1, ax2, ax3)

    # Shared colour scale from state 1. origin='lower' replaces 'center', which
    # current matplotlib rejects and old versions treated as 'lower'.
    minV, maxV = np.min(energies1Map), np.max(energies1Map)
    for ax, energy_map in zip(axes, (energies1Map, energies2Map, energiesEdsMap)):
        surf = ax.imshow(energy_map, cmap=style.qualitative_map, interpolation="nearest", origin='lower',
                         vmax=maxV, vmin=minV, extent=[minX, maxX, minY, maxY])

    # colour bar (from the EDS image, drawn last)
    cbaxes = fig.add_axes([1.0, 0.1, 0.03, 0.8])
    cb = plt.colorbar(surf, fraction=0.046, pad=0.04, cax=cbaxes,
                      ticks=list(np.round(np.linspace(minV, maxV, 5), 2)))
    cb.set_label("V/[kT]")

    # Labelling. Evenly spaced ticks equal the former [min, min/2, mean, max/2, max]
    # for ranges symmetric about 0, and are also correct for asymmetric ranges.
    x_ticks = np.linspace(minX, maxX, 5)
    y_ticks = np.linspace(minY, maxY, 5)
    for ax in axes:
        ax.set_ylim([minY, maxY])
        ax.set_xlim([minX, maxX])
        ax.set_xlabel(x_label, fontsize=18)
        ax.set_xticks(x_ticks)
        ax.tick_params(labelsize=14)
    ax1.set_ylabel(y_label, fontsize=18)
    ax1.set_yticks(y_ticks)
    ax2.set_yticks([])
    ax3.set_yticks([])

    # trajectory on top of every landscape
    if traj is not None:
        vis_pos_x, vis_pos_y = np.squeeze(np.array(list(map(np.array, traj.position)))).T
        for ax in axes:
            ax.scatter(vis_pos_x, vis_pos_y, c=trajectory_color)
            ax.scatter(vis_pos_x[-1], vis_pos_y[-1], c="r")  # end point
            ax.scatter(vis_pos_x[0], vis_pos_y[0], c="g")  # start point

    ax1.set_title("State 0", fontsize=20)
    ax2.set_title("State 1", fontsize=20)
    ax3.set_title("$s=" + str(eds_pot.s_i) + "$", fontsize=16)
    fig.suptitle("EDS potential: s=" + str(eds_pot.s_i))

    if out_path is None:
        return fig
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)
    return out_path
def plot_2D_2State_EDS_potential_sDependency(sVal_traj_Dict: (dict, List), eds_pot, out_path: str = None,
                                             plot_trajs=False, space_range=[[-np.pi, np.pi], [-np.pi, np.pi]],
                                             point_resolution=500, positions2D=None,
                                             x_label=r"$\phi/[^{\circ}$]", y_label=r"$\psi/[^{\circ}$]",
                                             verbose=False):
    """Plot a grid with one row per s-value: state 1, state 2 and the EDS state.

    Parameters
    ----------
    sVal_traj_Dict : dict
        keys are s-values (rows are drawn in descending order). With ``plot_trajs``
        the values must be trajectories with a ``position`` column.
        NOTE(review): annotated as ``(dict, List)``, but ``sVal_traj_Dict[s]`` only
        works for a dict when ``plot_trajs`` is set.
    eds_pot :
        EDS potential with exactly two states in ``V_is``; its ``s_i`` is overwritten per row.
    out_path : str, optional
        if given, the figure is saved there (``bbox_inches='tight'``), closed, and the path is returned.
    plot_trajs : bool
        scatter each row's trajectory onto the three panels of that row.
    space_range : ((xmin, xmax), (ymin, ymax))
        grid extent when ``positions2D`` is None.
    point_resolution : int
        grid points per axis; recomputed from ``positions2D`` if that is given.
    positions2D : optional
        explicit square grid of 2D positions.
    x_label, y_label : str
        axis labels.
    verbose : bool
        print progress messages.

    Returns
    -------
    plt.Figure or str
        the figure if ``out_path`` is None, otherwise ``out_path``.
    """
    traj_color = "orange"

    # build positions
    if positions2D is None:
        minX, maxX = min(space_range[0]), max(space_range[0])
        minY, maxY = min(space_range[1]), max(space_range[1])
        positionsX = np.linspace(minX, maxX, point_resolution)
        positionsY = np.linspace(minY, maxY, point_resolution)
        x_positions, y_positions = np.meshgrid(positionsX, positionsY)
        positions2D = np.array([x_positions.flatten(), y_positions.flatten()]).T
    else:
        positions2D = np.array(positions2D)
        point_resolution = len(np.unique(positions2D[:, 0]))
        minX, maxX = min(positions2D[:, 0]), max(positions2D[:, 0])
        minY, maxY = min(positions2D[:, 1]), max(positions2D[:, 1])

    if verbose:
        print("calc tot space")
    (V1, V2) = eds_pot.V_is
    energies1 = V1.ene(positions2D)
    energies2 = V2.ene(positions2D)

    if verbose:
        print("map data")
    energies1Map = energies1.reshape([point_resolution, point_resolution])
    energies2Map = energies2.reshape([point_resolution, point_resolution])
    energyMaps = [energies1Map, energies2Map, []]

    # Colour limits: the state columns use min + (rounded) barrier of state 1.
    # The EDS column uses state 1's raw range; this is loop-invariant, so compute it once.
    relative_barrier = round(np.max(energies1Map) - np.min(energies1Map), 2)
    minV, maxV = min(energies1), min(energies1) + relative_barrier
    eminV, emaxV = np.min(energies1Map), np.max(energies1Map)

    # tick positions shared by all panels
    y_ticks = np.round([minY, np.mean([minY, maxY]), maxY])
    x_ticks = np.round([minX, np.mean([minX, maxX]), maxX])
    column_titles = ("State 1", "State 2", "EDS state")

    if verbose:
        print("plot")
    nrows = len(sVal_traj_Dict)
    ncols = 3  # 3 states in the system
    fig = plt.figure(figsize=(7, 21), constrained_layout=False, dpi=300)
    outer_grid = fig.add_gridspec(nrows, ncols, wspace=0.1, hspace=0.1)

    for row, s in enumerate(sorted(sVal_traj_Dict, reverse=True)):
        if verbose:
            print("fun")
        # EDS energies for this s
        eds_pot.s_i = s
        energiesEds = eds_pot.ene(positions2D)
        energiesEdsMap = energiesEds.reshape([point_resolution, point_resolution])
        energyMaps[-1] = energiesEdsMap
        if verbose:
            # barrier of the EDS map itself (previously state 1's range was printed)
            print("EDS - Barrier: ", np.max(energiesEdsMap) - np.min(energiesEdsMap))

        if plot_trajs:
            tmp_visit_x, tmp_visit_y = np.squeeze(np.array(list(map(np.array, sVal_traj_Dict[s].position)))).T

        for col in range(ncols):
            ax = fig.add_subplot(outer_grid[row, col])
            # origin='lower' replaces 'center' (rejected by current matplotlib, 'lower' in old versions)
            if col == 2:
                ax.imshow(energyMaps[col], cmap=style.qualitative_map, origin='lower', vmax=emaxV, vmin=eminV,
                          extent=[minX, maxX, minY, maxY])
            else:
                ax.imshow(energyMaps[col], cmap=style.qualitative_map, interpolation="nearest", origin='lower',
                          vmax=maxV, vmin=minV, extent=[minX, maxX, minY, maxY])

            if plot_trajs:
                ax.scatter(tmp_visit_x, tmp_visit_y, c=traj_color, alpha=0.3, s=2)

            ax.set_ylim([minY, maxY])
            ax.set_xlim([minX, maxX])
            ax.tick_params(labelsize=14)

            # labelling
            if row == 0:
                ax.set_title(column_titles[col], fontsize=20)

            if col == 0:
                ax.set_ylabel(y_label, fontsize=18)
                ax.set_yticks(y_ticks)
            else:
                ax.set_yticks([])

            if row == nrows - 1:
                ax.set_xlabel(x_label, fontsize=18)
                # Fix the tick positions first so the labels belong to them
                # (labels alone were attached to arbitrary automatic ticks).
                ax.set_xticks(x_ticks)
                ax.set_xticklabels(x_ticks, rotation=45)
            else:
                ax.set_xticks([])

    # shared colour bar of the state columns
    norm = matplotlib.colors.Normalize(vmin=minV, vmax=maxV)
    cbaxes = fig.add_axes([1.0, 0.1, 0.03, 0.8])
    cb = matplotlib.colorbar.ColorbarBase(cbaxes, cmap=style.qualitative_map, norm=norm, orientation='vertical')
    cb.set_label("V/[kT]")

    if out_path is None:
        return fig
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)
    return out_path
if __name__ == "__main__":
    # This module only provides plotting helpers; running it directly does nothing.
    pass