#%%
import math
import os
import time
import sys
sys.path.append('./') 
import matplotlib.pyplot as plt
# plt.rcParams["text.usetex"] = True
import numpy as np
from functools import partial
import random
from tqdm import tqdm
import torch
import torch.distributions as dist
import functorch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from einops import rearrange, reduce, repeat
from torch.utils.tensorboard import SummaryWriter
import functools
import itertools
import argparse
import yaml
import pdb 
import io
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
import copy


import sys
sys.path.append('../')
from datasets import synthetic
from torchEFM import extended_flow_matching, utils
from torchEFM.model.models import MLP
from datasets import synthetic
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Directory that holds training artefacts; the default checkpoint lives inside it.
default_path = '../results'
# Default value for the --model_path CLI flag.
model_path = os.path.join(default_path, "model.pt")


## Path Helpers
def straight_gamma(sval, cond):
    """Return ``[sval, *cond]``: the scalar path parameter prepended to ``cond``.

    Args:
        sval: scalar path parameter, wrapped into a 1-element tensor.
        cond: 1-D condition tensor (must be a tensor; its ``.device`` is read).

    Returns:
        1-D tensor of length ``len(cond) + 1`` on ``cond``'s device.
    """
    prefix = torch.tensor([sval]).to(cond.device)
    return torch.cat((prefix, cond))

def straight_gammadot(s, cond):
    """Return the constant velocity of the vertical straight path.

    The result is ``[1, 0, ..., 0]``: unit speed along the path-parameter
    axis and zero along every condition axis. ``s`` is not used.

    Args:
        s: path parameter (ignored).
        cond: 1-D condition tensor; fixes the length, dtype and device of the
            zero block.

    Returns:
        1-D tensor of length ``len(cond) + 1``.
    """
    zero_block = torch.zeros_like(cond)
    unit = torch.ones(1).to(cond.device)
    return torch.cat([unit, zero_block])

def straight_gammadot_yoko(s, cond):
    """Return the constant velocity of the horizontal ("yoko") straight path.

    The result is ``[0, 1, ..., 1]``: no motion along the path-parameter axis
    and unit speed along each of the ``len(cond)`` condition axes. ``s`` is
    not used; only the length and device of ``cond`` matter.

    Returns:
        1-D tensor of length ``len(cond) + 1`` on ``cond``'s device.
    """
    dev = cond.device
    n_cond = len(cond)
    # float zero + integer ones: torch.cat type-promotes the result to float.
    head = torch.zeros(1).to(dev)
    tail = torch.ones(n_cond, dtype=torch.long).to(dev)
    return torch.cat([head, tail])

def vert_path(cvaltensor, device):
    """Build the vertical straight path at a fixed condition.

    Args:
        cvaltensor: 1-D condition tensor held fixed along the path.
        device: device the path point / velocity are moved to.

    Returns:
        ``(gamma, gammadot)`` where ``gamma(t)`` is ``[t, *cvaltensor]`` and
        ``gammadot(t)`` is the constant velocity ``[1, 0, ..., 0]``.
    """
    def gamma(t):
        return straight_gamma(t, cvaltensor).to(device)

    def gammadot(t):
        return straight_gammadot(t, cvaltensor).to(device)

    return gamma, gammadot

def hor_path(device, direction=None):
    """Build the horizontal straight path at path parameter 1.

    The path moves through condition space along ``direction``:
    ``gamma_yoko(s) = [1, *(s * direction)]`` and its velocity is
    ``[0, 1, ..., 1]``.

    Args:
        device: device the path point / velocity are moved to.
        direction: condition-space direction tensor. Required despite the
            ``None`` default, which is kept only for signature compatibility.

    Returns:
        ``(gamma_yoko, gammadot_yoko)``. Both take a 1-element tensor
        (``.item()`` is called on it).

    Raises:
        ValueError: if ``direction`` is None.
    """
    if direction is None:
        # Without a direction the path is undefined: ``s.item() * None`` would
        # raise a TypeError the first time the returned callables are used.
        raise ValueError("hor_path requires a `direction` tensor")

    def gamma_yoko(s):
        return straight_gamma(1., s.item() * direction).to(device)

    def gammadot_yoko(t):
        return straight_gammadot_yoko(1., t.item() * direction).to(device)

    return gamma_yoko, gammadot_yoko

## Image Helpers
def convert2array(fig):
    """Rasterise a matplotlib figure into an ``(H, W, 3)`` uint8 RGB tensor.

    The figure is drawn on an Agg canvas and its RGBA pixel buffer is read
    directly. The alpha channel is dropped.

    Args:
        fig: matplotlib ``Figure`` to render.

    Returns:
        ``torch.uint8`` tensor of shape ``(height, width, 3)`` in RGB order.
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() already has shape (H, W, 4) at the renderer's physical
    # pixel size, so no manual reshape (which breaks on HiDPI canvases) is
    # needed. The former tostring_argb() path produced ARGB bytes, so
    # indexing [2, 1, 0] selected G, R, A rather than RGB.
    rgba = np.asarray(canvas.buffer_rgba())
    rgb = np.ascontiguousarray(rgba[:, :, :3])
    return torch.tensor(rgb)


def plot_to_tensor(fig):
    """Convert a matplotlib figure to a CHW image tensor and close the figure.

    The figure is saved as PNG into an in-memory buffer, decoded with
    ``plt.imread`` and converted with torchvision's ``ToTensor``.

    Args:
        fig: matplotlib ``Figure``. It is closed even if rendering fails.

    Returns:
        Channel-first image tensor.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            image = plt.imread(buf)
    finally:
        # Always release the figure so repeated calls do not accumulate
        # open figures when savefig/imread raises.
        plt.close(fig)
    return ToTensor()(image)


def evaluate(mydata=None, model=None, writer=None, configs=None, device='cpu',
             myloader=None, guidance=1.3, **kwargs):
    """Sample the model on a grid of conditions and log the trajectories.

    For each condition vector in ``cset`` (every ``configs.data.c_dim``-tuple
    of values from ``pre_c_list``, with a trailing 0 appended), 1024
    standard-normal source samples are integrated along the vertical path
    with a ``dopri5`` NeuralODE. When the EFM object has an ``lm`` attribute,
    the samples are first shifted by its regression mean. Each trajectory plot
    is rendered to an image, and all images are written to TensorBoard as one
    grid under the tag ``model_name``.

    Args:
        mydata: dataset object; ``pdata``, ``c_list`` and ``x_dim`` are read.
        model: trained model; loaded from ``configs['model_path']`` if None.
        writer: TensorBoard ``SummaryWriter``. If None, one is created in the
            log directory and closed before returning.
        configs: OmegaConf config; reads ``savedir`` (falling back to
            ``logdir``), ``model_path``, ``efm``, ``device``, ``rival`` and
            ``data.c_dim``.
        device: device for the model, source samples and condition tensors.
        myloader: unused; kept for interface compatibility.
        guidance: guidance weight passed to ``utils.torch_wrapper``.
        **kwargs: ignored.
    """
    print(f"""Using device {device}""")

    # `savedir` is the historical key; the CLI entry point only sets `logdir`.
    log_root = configs.get('savedir') or configs.get('logdir')
    model_name = 'test'
    print(model_name)
    print(log_root)

    owns_writer = writer is None
    if owns_writer:
        writer = SummaryWriter(log_dir=log_root)

    try:
        if model is None:
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
            # which cannot unpickle a full nn.Module — confirm torch version.
            model = torch.load(configs['model_path'],
                               map_location=torch.device('cpu')).to(device)

        pdata = torch.stack(copy.deepcopy(mydata.pdata[:-1]))

        # Source generation and model loading.
        efmobj = instantiate(configs.efm, device=configs.device)
        xbdry = rearrange(copy.deepcopy(pdata), 'c b x -> b 1 c x').to(device)
        cbdry = mydata.c_list.to(device)
        if configs.rival == 'guided':
            # In guided mode the last entry of c_list is excluded.
            cbdry = cbdry[:-1]

        # Grid of conditions to evaluate on.
        pre_c_list = torch.tensor([[0.], [0.25], [0.5], [0.75], [1.], [2.]]).to(device)
        cset = torch.tensor(list(itertools.product(pre_c_list, repeat=configs.data.c_dim)))
        cset = torch.cat((cset, torch.zeros(len(cset), 1)), dim=1).to(device)

        m = torch.distributions.multivariate_normal.MultivariateNormal(
            torch.zeros(mydata.x_dim), torch.eye(mydata.x_dim))
        lamb = guidance

        # The regression depends only on (xbdry, cbdry, cset), not on the grid
        # index, so fit it once rather than once per condition.
        xtensor_mean = None
        if hasattr(efmobj, 'lm'):
            efmobj.lm.create_regmats(copy.deepcopy(xbdry), cbdry)
            xtensor_mean, _ = efmobj.lm.compute_regression(cset)
            xtensor_mean = xtensor_mean.to(device)

        image_tensors = []
        print(f"""Guidance = {lamb}""")
        for k, cvec in enumerate(cset):
            source = m.sample((1024,)).to(device)
            if xtensor_mean is not None:
                # Shift the source by the regression mean for condition k.
                # NOTE(review): original author was unsure about the
                # [0, 0, [k]] indexing / matrix orientation — confirm.
                source = source + xtensor_mean[0, 0, [k]]

            print(f"""reporting c={cvec} """)
            gamma, gammadot = vert_path(cvec, device=device)
            node0 = NeuralODE(utils.torch_wrapper(model, gamma, gammadot, guidance=lamb),
                              solver="dopri5", sensitivity="adjoint",
                              atol=1e-4, rtol=1e-4)

            with torch.no_grad():
                traj0 = node0.trajectory(
                    source,
                    t_span=torch.linspace(0, 1, 100),
                )
                cval = cvec.to('cpu')
                # Raw f-string: `\lambda` is a LaTeX command, not an escape.
                fig0 = utils.plot_trajectories(
                    traj0.cpu(), returnFig=True,
                    title=rf"""$c={cval}$,$\lambda={lamb}$""")
            image_tensors.append(plot_to_tensor(fig0))

        grid_tensor = make_grid(image_tensors, nrow=len(pre_c_list))
        writer.add_image(model_name, grid_tensor)
    finally:
        # Close only a writer we created; closing flushes pending events to disk.
        if owns_writer:
            writer.close()



if __name__ == '__main__':

    # Usage:
    #
    #     python evaluate.py --model_path=$MYMODELPATH --logdir=$MYLOGDIR
    #
    # View the results with:
    #
    #     tensorboard --logdir=$MYLOGDIR
    #
    # Config entries can be overridden with dotted paths, e.g. `-a data.c_dim=2`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir', type=str, default="./")
    parser.add_argument('--config_path', type=str, default='./configs/basic_2d_new.yaml')
    parser.add_argument('--model_path', type=str, default=model_path)
    parser.add_argument('--device', default='cpu')
    parser.add_argument('-a', '--attrs', nargs='*', default=())
    parser.add_argument('-w', '--warning', action='store_true')
    args = parser.parse_args()

    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    config['config_path'] = args.config_path
    config['model_path'] = args.model_path
    config['logdir'] = args.logdir
    # evaluate() reads the TensorBoard directory from `savedir`; use --logdir
    # when the YAML does not define it.
    config.setdefault('savedir', args.logdir)

    # Apply `dotted.key=value` overrides; values are parsed as YAML so numbers,
    # booleans and lists get their natural types. Split on the first '=' only
    # so values may themselves contain '='.
    for attr in args.attrs:
        module, new_value = attr.split('=', 1)
        keys = module.split('.')
        target = functools.reduce(dict.__getitem__, keys[:-1], config)
        if keys[-1] not in target:
            raise ValueError(
                f'The following key is not defined in the config file:{keys}')
        target[keys[-1]] = yaml.safe_load(new_value)

    config = DictConfig(config)

    device = args.device
    mydata = instantiate(config.data)
    myloader = instantiate(config.loader, dataset=mydata)

    model = torch.load(config.model_path, map_location=torch.device('cpu')).to(device)

    # Forward the device so samples and conditions live where the model does.
    evaluate(mydata=mydata, model=model, writer=None, configs=config,
             device=device, myloader=myloader)
