import matplotlib.pyplot as plt
import numpy as np
import mne
from regression.helpers import load_sensor_locations

class HelmetPlot():
    """Draw topographic ("helmet") maps of per-sensor values with MNE.

    Sensor positions are loaded once at construction time and reused by
    every call to :meth:`plot`.
    """

    def __init__(self, positions_file_loc, sphere_size = 47.0):
        """Load the sensor positions and remember the head-sphere size.

        Args:
            positions_file_loc: Location of the sensor-positions file. It is
                handed to ``load_sensor_locations`` with
                ``partial_sensors=False``.
            sphere_size: Forwarded unchanged to ``mne.viz.plot_topomap`` as
                its ``sphere`` argument. MNE treats a single float there as
                the head-outline radius. NOTE(review): it presumably has to be
                in the same units as the loaded positions, which would explain
                why the default (47.0) differs from MNE's metre-based default.
                Confirm against ``load_sensor_locations``.
        """
        self.positions = load_sensor_locations(positions_file_loc, partial_sensors=False)
        self.sphere_size = sphere_size

    def plot(self, values: np.ndarray, title=None,
             cmap="RdBu_r", vlim=(-0.1, 0.1), colorbar_title=None, fontsize=20):
        """Plot ``values`` as a topomap over the stored sensor positions.

        Args:
            values: Data to plot. Assumed to hold one entry per sensor in
                ``self.positions`` (MNE rejects mismatched lengths).
            title: Figure super-title. When ``None``, no title is added.
            cmap: Matplotlib colormap name or object passed to MNE.
            vlim: Colour limits passed to ``plot_topomap`` when truthy. A
                falsy value (e.g. ``None``) omits the argument, so MNE's
                default scaling is used instead.
            colorbar_title: Title drawn above the colorbar when truthy.
            fontsize: Font size of the super-title.

        Returns:
            The newly created matplotlib ``Figure``. It is neither shown nor
            closed here; the caller owns it.
        """
        fig, ax = plt.subplots(ncols=1)
        # Only add a suptitle when one was requested. Calling suptitle(None)
        # would otherwise create an empty (or, on old matplotlib, "None")
        # title artist.
        if title is not None:
            fig.suptitle(title, fontsize=fontsize)

        topomap_kwargs = dict(
            axes=ax,
            show=False,  # the caller decides when and how to display the figure
            size=4,
            outlines="head",
            cmap=cmap,
            sphere=self.sphere_size,
        )
        # Leave vlim out entirely when it is falsy so that plot_topomap falls
        # back to its own default colour scaling.
        if vlim:
            topomap_kwargs["vlim"] = vlim
        # plot_topomap returns (image, contour set); only the image is needed
        # to build the colorbar.
        im, _ = mne.viz.plot_topomap(values, self.positions, **topomap_kwargs)

        cbar = fig.colorbar(im, ax=ax)
        if colorbar_title:
            cbar.ax.set_title(colorbar_title)
        return fig
