from matplotlib.collections import LineCollection
import torch
from enzyme import CACHE_DIR, FIGPATH, PRJ_ROOT



def save_plot(
        name,
        path=FIGPATH,
        fig=None,
        file_formats=("svg", "pdf", "png"),
        **save_args
):
    """Save a figure once for every requested file format.

    NOTE: a later ``save_plot`` definition in this module rebinds the name,
    so this version is only reachable until that definition has executed.

    Parameters
    ----------
    name : str
        File name without extension.
    path : path-like supporting the ``/`` operator
        Target directory (defaults to ``FIGPATH``).
    fig : object with a ``savefig`` method, optional
        Figure to save. ``None`` falls back to the current pyplot figure.
    file_formats : iterable of str
        Extensions to write, one file per entry.
    **save_args
        Forwarded to ``fig.savefig``. ``transparent`` defaults to ``True``
        and may be overridden by the caller.
    """
    if fig is None:
        # The signature advertises fig=None, so resolve it instead of failing
        # on ``None.savefig``.
        import matplotlib.pyplot as plt
        fig = plt.gcf()

    # Pull ``transparent`` out of the forwarded kwargs: passing it both
    # explicitly and through **save_args raises a duplicate-keyword TypeError.
    options = dict(save_args)
    transparent = options.pop("transparent", True)

    for file_format in file_formats:
        fig.savefig(path / (name + f".{file_format}"), transparent=transparent, **options)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.scale import ScaleBase, register_scale
import matplotlib.transforms as mtransforms
from matplotlib.ticker import LogLocator, LogFormatterMathtext

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.scale import ScaleBase, register_scale
import matplotlib.transforms as mtransforms
from matplotlib.ticker import LogLocator, LogFormatter

def riemann(integrand, a, b, n, args=()):
    """Approximate the integral of ``integrand`` over ``[a, b]`` from ``n`` samples.

    The integrand is evaluated in a single vectorised call on ``n`` evenly
    spaced nodes that include both end points; the summed samples are scaled
    by a width of ``(b - a) / n``.

    Returns
    -------
    tuple
        ``(estimate, 0.)`` -- the second entry is a constant placeholder.
    """
    nodes = np.linspace(a, b, n)
    width = (b - a) / n
    samples = integrand(nodes, *args)
    return np.sum(samples) * width, 0.


class CustomLogLocator(LogLocator):
    """Log tick locator applied to the mirrored coordinate ``1 - v``.

    The view limits are reflected through ``1 - v`` before the parent
    locator is queried, and the resulting ticks are reflected back, so the
    log spacing is measured as a distance from 1 rather than from 0.
    """

    def tick_values(self, vmin, vmax):
        """Return tick positions for the original (un-mirrored) interval."""
        mirrored_ticks = super().tick_values(1 - vmin, 1 - vmax)
        return 1 - mirrored_ticks
    
from scipy.stats import beta
from math import gamma

from numba import njit, vectorize
def beta(alpha, beta, x):
    """Evaluate the Beta(alpha, beta) probability density at ``x``.

    NOTE: this definition rebinds the module-level name ``beta`` imported
    from ``scipy.stats`` just above.

    Parameters
    ----------
    alpha, beta : float
        Shape parameters (scalars; they are passed to ``math.gamma``).
    x : float or array-like supporting ``**`` and ``1 - x``
        Evaluation point(s).

    Returns
    -------
    Same kind as ``x``: normalisation constant times
    ``x**(alpha-1) * (1-x)**(beta-1)``.
    """
    from math import exp, lgamma

    try:
        norm = gamma(alpha + beta) / (gamma(alpha) * gamma(beta))
    except OverflowError:
        # math.gamma overflows once its argument exceeds roughly 171, which
        # happens for large pseudo-counts. Fall back to log-gamma for the
        # normalisation only on that path, so results that were computable
        # before stay bit-for-bit the same.
        norm = exp(lgamma(alpha + beta) - lgamma(alpha) - lgamma(beta))
    return norm * x**(alpha - 1) * (1 - x)**(beta - 1)
# beta = vectorize(nopython=True)(beta)
    
def get_beta_prior(theta_mean, theta_vals, N_pseudo=1):
    """Return a Beta-shaped prior over ``theta_vals``, normalised to sum to one.

    The two shape parameters are ``N_pseudo * theta_mean`` and
    ``N_pseudo * (1 - theta_mean)``; the density is evaluated with the
    module-level ``beta`` function and divided by its sum over ``theta_vals``.
    """
    shape_a = N_pseudo * theta_mean
    shape_b = N_pseudo * (1 - theta_mean)
    density = beta(shape_a, shape_b, theta_vals)
    return density / density.sum()
    
def logitspace(a, b, n, eps=1e-3):
    """Return points in ``(a, b)`` that crowd geometrically towards both ends.

    An even ``n`` is bumped to the next odd number. The lower half runs
    geometrically from ``a + eps`` up to the midpoint offset ``(b - a) / 2``;
    the upper half is built the same way measured down from ``b`` and
    reversed, with the shared midpoint dropped from it.
    """
    count = n if n % 2 else n + 1
    per_side = (count + 1) // 2
    half_width = (b - a) / 2

    lower = a + np.geomspace(eps, half_width, per_side)
    # NOTE(review): the upper side starts at eps*b rather than eps, so the two
    # ends are only symmetric when b == 1 -- confirm this is intended.
    upper_offsets = np.geomspace(eps * b, half_width, per_side)[:-1][::-1]
    upper = b - upper_offsets

    return np.concatenate([lower, upper])

def mystep(ax, x, y, where='post', colors=None, **kwargs):
    """Draw ``y`` against ``x`` as flat steps without connecting risers.

    For each pair of neighbouring x values a horizontal piece is emitted at
    the level of the left sample (``where='post'``) or the right sample
    (``where='pre'``), followed by a NaN y-value (presumably so the drawn
    line is interrupted between consecutive steps).

    With ``colors`` given, the flattened points are turned into a
    ``LineCollection`` added to ``ax`` (``kwargs`` are not used on that
    path); otherwise they go to ``ax.plot`` together with ``kwargs``. A falsy
    ``ax`` is replaced by ``plt.gca()``.

    Returns a 1-tuple holding whatever ``ax.add_collection`` / ``ax.plot``
    returned.
    """
    assert where in ['post', 'pre']
    xs = np.array(x)
    ys = np.array(y)

    level = ys[:-1] if where == 'post' else ys[1:]
    left, right = xs[:-1], xs[1:]
    gap = np.zeros_like(left) * np.nan

    flat_x = np.c_[left, right, right].flatten()
    flat_y = np.c_[level, level, gap].flatten()

    ax = ax or plt.gca()

    if colors is None:
        return (ax.plot(flat_x, flat_y, **kwargs),)

    vertices = np.array([flat_x, flat_y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([vertices[:-1], vertices[1:]], axis=1)
    collection = LineCollection(segments, colors=colors)
    return (ax.add_collection(collection),)


def save_plot(
        name,
        path=FIGPATH,
        fig=None,
        file_formats=("svg", "pdf", "png"),
        **save_args
):
    """Save a figure once for every requested file format.

    Parameters
    ----------
    name : str
        File name without extension.
    path : path-like supporting the ``/`` operator
        Target directory (defaults to ``FIGPATH``).
    fig : object with a ``savefig`` method, optional
        Figure to save. ``None`` falls back to the current pyplot figure.
    file_formats : iterable of str
        Extensions to write, one file per entry.
    **save_args
        Forwarded to ``fig.savefig``. ``transparent`` defaults to ``True``.
        For ``"png"`` the call always uses ``transparent=False`` and
        ``dpi=600``, regardless of what was passed in.
    """
    if fig is None:
        # The signature advertises fig=None, so resolve it instead of failing
        # on ``None.savefig``.
        import matplotlib.pyplot as plt
        fig = plt.gcf()

    # ``transparent`` is passed explicitly below; leaving it in the forwarded
    # kwargs as well would raise a duplicate-keyword TypeError.
    base_transparent = save_args.pop("transparent", True)

    for file_format in file_formats:
        # Work on per-format copies so the PNG overrides do not leak into
        # formats that are iterated after "png".
        format_args = dict(save_args)
        format_transparent = base_transparent
        if file_format == "png":
            format_transparent = False
            format_args["dpi"] = 600
        fig.savefig(
            path / (name + f".{file_format}"),
            transparent=format_transparent,
            **format_args
        )

from pathlib import Path
from joblib import Memory
# joblib cache rooted at CACHE_DIR; used below as the ``@memory.cache``
# decorator on get_data_dict.
memory = Memory(CACHE_DIR, verbose=0)

def get_manager(agent_params, mouse_task_params, manager_params, manager=None, ):
    """Return a manager populated with (possibly cached) simulation data.

    An ``Actor_Critic`` agent is built from ``agent_params`` and its weights
    are loaded from ``agent_params['load_path']`` onto
    ``manager_params['device']``.

    * ``manager is None``: an un-run manager is created through
      ``run_simulation(..., run=False)``, filled from ``get_data_dict`` via
      ``manager.data.from_dictionary`` and then passed through
      ``manager.sim.preprocess_data``.
    * ``manager`` given: ``get_data_dict`` is called with it (on a cache miss
      that function reads ``manager.data.to_dictionary()``), and the same
      manager object is returned.

    Parameters
    ----------
    agent_params, mouse_task_params, manager_params : dict
        Merged into one ``sim_params`` dict; ``agent_params`` must contain
        ``'load_path'`` and ``manager_params`` must contain ``'device'``.
    manager : optional
        Existing manager to reuse.
    """
    # Imported lazily, mirroring get_data_dict: these names are not imported
    # at module level in this file, so using them here without the imports
    # would raise NameError.
    from enzyme.src.main.run_simulation import run_simulation
    from enzyme.src.mouse_task.mouse_task import mouse_task
    from enzyme.src.network.Actor_Critic import Actor_Critic

    sim_params = dict(agent_params, **mouse_task_params, **manager_params)

    # NOTE(review): the agent is only consumed in the ``manager is None``
    # branch, but it is still built and loaded unconditionally to keep the
    # original behaviour (including load errors) unchanged.
    agent = Actor_Critic(**agent_params)
    device = manager_params['device']
    agent.load_state_dict(torch.load(agent_params['load_path'], map_location=torch.device(device)))

    if manager is None:
        manager = run_simulation(mouse_task, sim_params, agent, plot_episode=False, run=False)
        data = get_data_dict(agent_params, mouse_task_params, manager_params)
        manager.data.from_dictionary(data)
        manager.sim.preprocess_data(manager)
    else:
        # Hands the existing manager to get_data_dict; the returned dict is
        # not needed here.
        get_data_dict(agent_params, mouse_task_params, manager_params, manager=manager)

    return manager

@memory.cache(ignore=['manager'], verbose=10)
def get_data_dict(agent_params, mouse_task_params, manager_params, manager=None):
    """Return the simulation data dictionary, minus a fixed set of purged keys.

    The call is wrapped by ``memory.cache`` with ``manager`` listed under
    ``ignore``. An ``Actor_Critic`` agent is built and its weights loaded;
    if no ``manager`` is supplied, one is produced by ``run_simulation``.
    The manager's data is exported with ``manager.data.to_dictionary()``, a
    per-key summary is printed, and the keys ``LTM``, ``f_gate``, ``i_gate``,
    ``c_gate`` and ``o_gate`` are removed before returning.
    """
    from enzyme.src.main.run_simulation import run_simulation
    from enzyme.src.mouse_task.mouse_task import mouse_task
    from enzyme.src.network.Actor_Critic import Actor_Critic

    agent = Actor_Critic(**agent_params)
    device = manager_params['device']
    weights = torch.load(agent_params['load_path'], map_location=torch.device(device))
    agent.load_state_dict(weights)

    sim_params = dict(agent_params, **mouse_task_params, **manager_params)
    if manager is None:
        manager = run_simulation(mouse_task, sim_params, agent, plot_episode=False)
    data = manager.data.to_dictionary()

    # Entries that are never persisted.
    purge_keys = ["LTM", "f_gate", "i_gate", "c_gate", "o_gate"]

    # Iterate over a snapshot of the keys because entries are deleted inside
    # the loop.
    for key in list(data.keys()):
        values = data[key]
        print(values.dtype)

        # Describe the first element only; non-tensors go through numpy to
        # obtain a shape and dtype.
        first = values[0]
        probe = first if isinstance(first, torch.Tensor) else np.array(first)
        shape, dtype = probe.shape, probe.dtype

        if key not in purge_keys:
            print(f"persisting {key} of shape ({len(values)}, {shape}, {dtype})")
            continue

        print(f"deleting {key} of shape ({len(values)}, {shape}, {dtype})")
        del data[key]

    return data

def notebook_cache(mem, module, **mem_kwargs):
    """Build a decorator that caches a function through ``mem.cache``.

    Before caching, the function's ``__module__`` is set to ``module`` and
    its ``__qualname__`` to its plain ``__name__``; ``mem_kwargs`` are
    forwarded to ``mem.cache``.

    Background:
    https://stackoverflow.com/questions/75202475/joblib-persistence-across-sessions-machines/
    """
    def decorate(func):
        func.__module__, func.__qualname__ = module, func.__name__
        return mem.cache(func, **mem_kwargs)

    return decorate