#%%
import math
import os
import time
import sys
sys.path.append('./') 
import matplotlib.pyplot as plt
# plt.rcParams["text.usetex"] = True
import numpy as np
from functools import partial
import random
from tqdm import tqdm
import torch
import torch.distributions as dist
import functorch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from einops import rearrange, reduce, repeat
from torch.utils.tensorboard import SummaryWriter
import functools
import itertools
import argparse
import yaml
import pdb 
import io
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
import copy
import ot

import sys
sys.path.append('../')
from datasets import synthetic
from torchEFM import extended_flow_matching, utils
from torchEFM.model.models import MLP
from datasets import synthetic
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import torch.distributed as dist


# Working directory that evaluation artefacts are resolved against.
default_path = './'
# Serialized model checkpoint expected inside ``default_path``.
model_path = os.path.join(default_path, './model.pt')


## Path Helpers
def straight_gamma(sval, cond):
    """Return the point ``(sval, *cond)`` of a straight path in (s, c)-space.

    Args:
        sval: path coordinate, wrapped via ``torch.tensor([sval])`` (a Python
            number; tensors presumably need to be single-element — confirm).
        cond: 1-D condition tensor; the result is placed on its device.

    Returns:
        1-D tensor with ``sval`` prepended to ``cond``.
    """
    return torch.cat([torch.tensor([sval]).to(cond.device), cond])

def straight_gammadot(s, cond):
    """Velocity of the vertical straight path ``s -> (s, cond)``.

    The path moves with unit speed in the s-coordinate and is constant in the
    condition coordinates, so the result is ``[1, 0, ..., 0]`` with
    ``len(cond)`` zeros. ``s`` is ignored.
    """
    unit_step = torch.ones(1).to(cond.device)
    frozen_cond = torch.zeros_like(cond)
    return torch.cat([unit_step, frozen_cond])

def straight_gammadot_yoko(s, cond):
    """Velocity ``[0, 1, ..., 1]`` of a horizontal path with an all-ones direction.

    No motion in the s-coordinate and unit speed in each of the ``len(cond)``
    condition coordinates. ``s`` is ignored; only the length, device and dtype
    of ``cond`` are used.

    The result is built directly on ``cond``'s device with a floating dtype:
    ``cond``'s own dtype if it is floating point (so float64 conditions stay
    float64, like ``straight_gamma``), otherwise torch's default float dtype.
    This avoids building an int64 tensor from a Python list and relying on
    ``torch.cat`` type promotion.
    """
    dtype = cond.dtype if cond.is_floating_point() else torch.get_default_dtype()
    head = torch.zeros(1, dtype=dtype, device=cond.device)
    tail = torch.ones(len(cond), dtype=dtype, device=cond.device)
    return torch.cat([head, tail])

def vert_path(cvaltensor, device):
    """Build the vertical straight path ``t -> (t, cvaltensor)`` and its velocity.

    Args:
        cvaltensor: 1-D condition tensor held fixed along the path.
        device: device the returned tensors are moved to.

    Returns:
        ``(gamma, gammadot)``: callables of the path parameter ``t``.
    """
    def gamma(t):
        return straight_gamma(t, cvaltensor).to(device)

    def gammadot(t):
        return straight_gammadot(t, cvaltensor).to(device)

    return gamma, gammadot

def hor_path(device, direction=None):
    """Build the horizontal straight path ``s -> (1, s * direction)`` and its velocity.

    Args:
        device: device the returned tensors are moved to.
        direction: 1-D condition-space direction tensor. It is required; the
            ``None`` default is kept only for signature compatibility.

    Returns:
        ``(gamma_yoko, gammadot_yoko)``. ``gamma_yoko(s)`` expects a
        single-element tensor ``s`` (it calls ``s.item()``).
        ``gammadot_yoko(t)`` returns the constant velocity ``(0, direction)``.

    Raises:
        ValueError: if ``direction`` is None. Previously this only surfaced
            later as a TypeError inside the ODE solver.
    """
    if direction is None:
        raise ValueError("hor_path requires a direction tensor")

    def gamma_yoko(s):
        return straight_gamma(1., s.item() * direction).to(device)

    def gammadot_yoko(t):
        # d/ds (1, s * direction) = (0, direction). The velocity must carry the
        # actual direction values, not a vector of ones, for non-unit directions.
        zero = torch.zeros(1, dtype=direction.dtype, device=direction.device)
        return torch.cat([zero, direction]).to(device)

    return gamma_yoko, gammadot_yoko

##Image Helpers 
def convert2array(fig):
    """Render ``fig`` with the Agg backend and return its pixels as an RGB tensor.

    Returns:
        uint8 tensor of shape (height, width, 3) holding the R, G, B channels.
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() is an (H, W, 4) RGBA view at the renderer's physical size,
    # so the first three channels are RGB. The ARGB byte string stores alpha
    # first; indexing it with [2, 1, 0] yields (G, R, A), not RGB.
    rgba = np.asarray(canvas.buffer_rgba())
    return torch.tensor(rgba[:, :, :3])


def plot_to_tensor(fig):
    """Convert a matplotlib figure to a (C, H, W) tensor and close the figure.

    The figure is saved to an in-memory PNG, read back with ``plt.imread`` and
    converted with torchvision's ``ToTensor``. The figure is closed even if
    saving or reading fails, so repeated calls cannot accumulate open figures.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            image = plt.imread(buf)
    finally:
        plt.close(fig)
    return ToTensor()(image)


def _ot_cost(gt_targ, pred_targ, evalmode='W2'):
    """Exact OT cost between two uniformly weighted empirical point clouds.

    ``'W1'`` uses the Euclidean ground cost. ``'W2'`` uses POT's default
    ``ot.dist`` metric (squared Euclidean), i.e. the *squared* W2 cost.
    Each side's weights are sized by its own sample count, so clouds of
    different sizes are handled correctly.
    """
    n_gt, n_pred = len(gt_targ), len(pred_targ)
    a = torch.ones((n_gt,)) / n_gt
    b = torch.ones((n_pred,)) / n_pred
    if evalmode == 'W1':
        M = ot.dist(gt_targ, pred_targ, metric='euclidean').detach()
    elif evalmode == 'W2':
        M = ot.dist(gt_targ, pred_targ).detach()
    else:
        raise ValueError(f"evalmode must be 'W1' or 'W2', got {evalmode!r}")
    return ot.emd2(a, b, M).item()


def evaluate(mydata=None, model=None, writer=None, configs=None, device='cpu',
             myloader=None, eval_clist=None, guidance=None, evalmode=None, **kwargs):
    """Evaluate a trained (extended) flow-matching model against ground truth.

    The steps are:

    1. Generate samples for every condition in the evaluation grid by solving
       the ODE along the vertical path (``process_vertical``).
    2. Score those samples against ground-truth samples with an OT cost.
    3. For ``extended_flow_matching`` models only, transfer
       ``mydata.pdata[0]`` along the horizontal path
       (``process_horizontal``) and score the endpoints too.

    Args:
        mydata: dataset object; reads ``pdata``, ``c_list``, ``x_dim``.
        model: vector-field model; if None it is loaded from
            ``configs['model_path']``.
        writer: a ``SummaryWriter``. None creates one under
            ``configs['savedir']``, which is closed before returning. The
            string ``'return'`` disables image logging.
        configs: OmegaConf-style config (``savedir``, ``efm``, ``batch_size``,
            ``data``, ``rival``, ...).
        device: torch device used for the model and ODE integration.
        myloader: unused; kept for interface compatibility.
        eval_clist: per-dimension condition values. The evaluation grid is
            their Cartesian product over ``configs.data.c_dim`` dimensions.
        guidance: guidance strength, only used when ``configs.rival == 'guided'``.
        evalmode: ``'W1'`` (Euclidean ground cost) or ``'W2'`` (squared
            Euclidean ground cost). None defaults to ``'W2'``.
        **kwargs: ignored.

    Returns:
        dict with keys ``'terminal_ver'``, ``'gt'`` and ``'gen'`` (keys, costs).
        For extended flow matching it also has ``'terminal_hor'``, ``'trf'``
        and ``'path'``.

    Raises:
        ValueError: if ``evalmode`` is not None, ``'W1'`` or ``'W2'``.
        NotImplementedError: on an unexpected string key in the horizontal results.
    """
    # Without this default an unset evalmode left the cost matrix unbound and
    # crashed only after all ODE solves; validate before doing any work.
    if evalmode is None:
        evalmode = 'W2'
    if evalmode not in ('W1', 'W2'):
        raise ValueError(f"evalmode must be 'W1' or 'W2', got {evalmode!r}")

    print(f"""Using device {device}""")
    result_dict = {}

    log_root = configs['savedir']
    model_name = 'test'
    print(model_name)
    print(log_root)

    # Only close writers that this function created itself.
    owns_writer = writer is None
    if owns_writer:
        writer = SummaryWriter(log_dir=log_root)

    if model is None:
        # NOTE(review): full-model unpickling; only load trusted checkpoints.
        model = torch.load(configs['model_path'], map_location=torch.device('cpu')).to(device)

    if isinstance(mydata.pdata, torch.Tensor):
        pdata = copy.deepcopy(mydata.pdata).to(device)
    else:
        # Per-condition sample sets may differ in size; truncate all of them to
        # the shortest one so they can be stacked.
        min_len = min(len(samples) for samples in mydata.pdata)
        pdata = torch.stack([samples[:min_len] for samples in mydata.pdata]).to(device)

    model.eval()
    efmobj = instantiate(configs.efm, device=device, batch=configs.batch_size)
    xbdry = rearrange(copy.deepcopy(pdata), 'c b x -> b 1 c x').to(device)
    cbdry = mydata.c_list.to(device)
    if hasattr(efmobj, 'lm'):
        efmobj.lm.create_regmats(copy.deepcopy(xbdry), cbdry)

    # Grid of conditions on which to evaluate.
    if eval_clist is None:
        pre_c_list = torch.tensor([[0.], [0.25], [0.5], [0.75], [1.], [2.]])
    else:
        pre_c_list = torch.tensor(eval_clist)
    cset = torch.tensor(list(itertools.product(pre_c_list, repeat=configs.data.c_dim)))

    image_tensors = []

    normalvar = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(mydata.x_dim), torch.eye(mydata.x_dim))

    if configs.rival == 'guided':
        configs.data.rival = 'guided'
        gtdata = instantiate(configs.data, c_preset=cset)
        # Skip conditions whose last coordinate equals 1 (the [0 0 ... 1] entries).
        gt_dict = {
            tuple(gtdata.c_list[k].numpy()): gtdata.pdata[k]
            for k in range(len(gtdata.c_list))
            if gtdata.c_list[k][-1] != 1
        }
    else:
        gtdata = instantiate(configs.data, c_preset=cset)
        gt_dict = {tuple(gtdata.c_list[k].numpy()): gtdata.pdata[k]
                   for k in range(len(gtdata.c_list))}

    terminal_results_vert = process_vertical(
        efmobj=efmobj, cset=cset, model=model, vert_path=vert_path,
        image_tensors=image_tensors, xbdry=xbdry, cbdry=cbdry, m=normalvar,
        device=device, configs=configs, eval_size=len(gtdata.pdata[0]),
        guidance=guidance)
    result_dict['terminal_ver'] = terminal_results_vert

    print(f"""Evaluating with {evalmode}""")
    w2_result = []
    for tuplekey in gt_dict.keys():
        gt_targ = copy.deepcopy(gt_dict[tuplekey])
        pred_targ = copy.deepcopy(terminal_results_vert[tuplekey])
        w2_result.append(_ot_cost(gt_targ, pred_targ, evalmode))

    if 'extended_flow_matching' in configs.efm._target_:
        terminal_results_hor = process_horizontal(
            efmobj=efmobj, cset=cset, model=model, hor_path=hor_path,
            image_tensors=image_tensors, xbdry=xbdry, cbdry=cbdry, m=normalvar,
            mydata=mydata, device=device)
        result_dict['terminal_hor'] = terminal_results_hor
        trf_result = []
        for tuplekey, value in terminal_results_hor.items():
            if not isinstance(tuplekey, str):
                # Transfer results are always scored with the squared-Euclidean cost.
                trf_result.append(_ot_cost(gt_dict[tuplekey], value, 'W2'))
            elif tuplekey == 'path':
                result_dict['path'] = value
            else:
                raise NotImplementedError(
                    f"Unexpected key {tuplekey!r} in horizontal results")

        # NOTE: the key list also contains 'path', which has no entry in trf_result.
        result_dict['trf'] = (list(terminal_results_hor.keys()), trf_result)
    else:
        print('Transfer not supported')
    result_dict['gt'] = gt_dict
    result_dict['gen'] = (list(gt_dict.keys()), w2_result)

    grid_tensor = make_grid(image_tensors, nrow=len(pre_c_list))

    if writer != 'return':
        writer.add_image(model_name, grid_tensor)
        if owns_writer:
            # Flush and release the event file we opened above.
            writer.close()

    return result_dict

def process_vertical(efmobj=None, cset=None, model=None, image_tensors=None,
                     vert_path=None, xbdry=None, cbdry=None, m=None, device=None,
                     configs=None, eval_size=None, guidance=None):
    """Generate samples at each condition by integrating along the vertical path.

    For every row ``c`` of ``cset``, this function:

    1. Draws ``eval_size`` source samples from ``m``. When ``efmobj`` has an
       ``lm`` attribute, they are shifted by that condition's regression mean.
    2. Solves the ODE driven by ``model`` along ``vert_path(c)``.
    3. Appends the trajectory plot to ``image_tensors``.
    4. Stores the terminal state.

    Args:
        guidance: passed to ``utils.torch_wrapper`` only when
            ``configs.rival == 'guided'``; otherwise replaced by None.
        eval_size: number of samples per condition (default 1024).

    Returns:
        dict mapping ``tuple(c.numpy())`` (c on CPU) to the terminal samples (on CPU).
    """
    if configs.rival != 'guided':
        guidance = None
    else:
        print(f"""Guidance Mode Activated with Guidance={guidance}""")
        # Append a zero column to every condition (presumably the guided
        # model's extra conditioning flag — confirm against the model).
        cset = torch.cat((cset, torch.zeros(len(cset), 1)), dim=1).to(device)

    if eval_size is None:
        eval_size = 1024
    print(f"""Evaluating with {eval_size} """)

    cset = cset.to(device)

    # The regression inputs (xbdry, cbdry, cset) do not change across the loop,
    # so build the matrices and regress once instead of once per condition.
    # Assumes create_regmats/compute_regression are deterministic for fixed inputs.
    xtensor_mean = None
    if hasattr(efmobj, 'lm'):
        efmobj.lm.create_regmats(copy.deepcopy(xbdry), cbdry)
        xtensor_mean, _ = efmobj.lm.compute_regression(cset)

    terminal_results_vert = {}
    for k in range(len(cset)):
        source = m.sample((eval_size,)).to(device)
        if xtensor_mean is not None:
            source = source + xtensor_mean[0, 0, [k]]

        print(f"""reporting c={cset[k]} """)
        gamma, gammadot = vert_path(cset[k], device=device)

        derivatives = utils.torch_wrapper(model, gamma, gammadot, guidance=guidance)
        node0 = NeuralODE(derivatives, solver="dopri5", sensitivity="adjoint",
                          atol=1e-4, rtol=1e-4)

        with torch.no_grad():
            traj0 = node0.trajectory(source, t_span=torch.linspace(0, 1, 100))
            cval = cset[k].to('cpu')
            fig0 = utils.plot_trajectories(traj0.cpu(), returnFig=True, title=f'$c={cval}$')
            terminal_results_vert[tuple(cval.numpy())] = traj0.detach().cpu()[-1]
        image_tensors.append(plot_to_tensor(fig0))
    return terminal_results_vert

def process_horizontal(efmobj=None, cset=None, model=None, image_tensors=None, hor_path=None,
                       xbdry=None, cbdry=None, m=None, device=None,
                       mydata=None, eval_size=None):
    """Transfer ``mydata.pdata[0]`` along the horizontal path in condition space.

    Uses ``hor_path`` with an all-ones ``direction`` of length ``mydata.c_dim``,
    i.e. it moves from condition 0 to condition (1, ..., 1). The trajectory plot
    is appended to ``image_tensors``.

    Only ``model``, ``image_tensors``, ``hor_path``, ``device`` and ``mydata``
    are used. The other parameters are accepted for call-site symmetry with
    ``process_vertical``.

    Returns:
        dict (all tensors on CPU) with three entries:

        - ``(0,) * c_dim``: the start state.
        - ``tuple(direction)``: the end state.
        - ``'path'``: the full trajectory.
    """
    direction = torch.ones(mydata.c_dim, device=device, dtype=torch.float32)
    gamma_yoko, gammadot_yoko = hor_path(device=device, direction=direction)

    node_yoko = NeuralODE(utils.torch_wrapper(model, gamma_yoko, gammadot_yoko),
                          solver="dopri5", sensitivity="adjoint",
                          atol=1e-4, rtol=1e-4)
    with torch.no_grad():
        traj_yoko = node_yoko.trajectory(
            mydata.pdata[0].to(device), t_span=torch.linspace(0, 1, 100),
        )
        figyoko = utils.plot_trajectories(traj_yoko.cpu(), returnFig=True)
        traj_cpu = traj_yoko.detach().cpu()

    # The start key must match the ground-truth key of the zero condition for
    # any c_dim (a literal (0, 0) only matched when c_dim == 2).
    start_key = tuple([0] * mydata.c_dim)
    terminal_results_hor = {
        start_key: traj_cpu[0],
        tuple(direction.to('cpu').numpy()): traj_cpu[-1],
        'path': traj_cpu,
    }

    image_tensors.append(plot_to_tensor(figyoko))
    return terminal_results_hor


def evaluate_dir(targroot='', targname='', expname='', eval_clist=None, device=0,
                 guidance=None, evalmode=None, evalsize=None):
    """Load an experiment's config and model from disk and run ``evaluate``.

    The experiment directory is ``<targroot>/<targname>/<expname>``. It must
    contain ``config.yaml`` and ``model.pt``. If ``expname`` is empty, the
    alphabetically first subdirectory of ``<targroot>/<targname>`` is used.

    Args:
        evalsize: if given, overrides ``config['data']['data_num']``.
        eval_clist, device, guidance, evalmode: forwarded to ``evaluate``.

    Returns:
        The result dict returned by ``evaluate`` (run with ``writer='return'``).

    Raises:
        FileNotFoundError: if no experiment directory can be found.
    """
    import pathlib
    import tempfile

    targdir = os.path.join(targroot, targname)

    if not expname:
        # os.listdir order is arbitrary and may include plain files; pick a
        # deterministic experiment directory instead.
        candidates = sorted(
            name for name in os.listdir(targdir)
            if os.path.isdir(os.path.join(targdir, name))
        )
        if not candidates:
            raise FileNotFoundError(f"No experiment directories found in {targdir}")
        expname = candidates[0]
        print(f"""Using {expname}""")

    resultdir = os.path.join(targdir, expname)
    configpath = os.path.join(resultdir, 'config.yaml')
    with open(configpath, 'r') as f:
        config = yaml.safe_load(f)
    if evalsize is not None:
        config['data']['data_num'] = evalsize
    config['model_path'] = os.path.join(resultdir, 'model.pt')
    config = OmegaConf.create(config)
    print(f"""Loaded {configpath}""")

    mydata = instantiate(config.data)
    myloader = instantiate(config.loader, dataset=mydata)
    # NOTE(review): torch.load unpickles the full model; only load trusted files.
    try:
        model = torch.load(config.model_path, map_location=torch.device('cpu')).to(device)
    except Exception:
        print("Possibly detecting the data parallel mode")
        # torch.distributed's file:// init needs a brand-new empty file each
        # time; a stale fixed path can deadlock or fail, so use a fresh one.
        # `torch.distributed` is referenced explicitly because the module-level
        # name `dist` is imported twice (distributions / distributed).
        with tempfile.TemporaryDirectory() as tmpdir:
            init_method = pathlib.Path(tmpdir, 'pg_init').as_uri()
            torch.distributed.init_process_group(
                backend='gloo', init_method=init_method, rank=0, world_size=1)
            try:
                model = torch.load(config.model_path,
                                   map_location=torch.device('cpu')).to(device)
            finally:
                torch.distributed.destroy_process_group()

    print(f"""Loaded {config.model_path}""")

    print("Test configuration complete. Evaluating...")
    dir_result = evaluate(mydata=mydata, model=model, writer='return', configs=config,
                          myloader=myloader, device=device, eval_clist=eval_clist,
                          guidance=guidance, evalmode=evalmode)

    return dir_result

if __name__ == '__main__':
    # Usage:
    #
    #     python evaluate.py --targroot=$ROOT --targname=$NAME [--expname=$EXP] \
    #         [--device=0] [--evalmode=W2]
    #
    # Evaluates <targroot>/<targname>/<expname> (config.yaml + model.pt) via
    # evaluate_dir and prints the generation / transfer OT costs.
    parser = argparse.ArgumentParser(description='Evaluate a trained flow-matching experiment.')
    parser.add_argument('--targroot', type=str, default='../')
    parser.add_argument('--targname', type=str, default='')
    parser.add_argument('--expname', type=str, default='')

    parser.add_argument('--device', type=int, default=0)
    # evaluate() needs a concrete metric; it was previously never passed here.
    parser.add_argument('--evalmode', type=str, default='W2', choices=['W1', 'W2'])
    # Accepted for compatibility with existing launch scripts; currently unused.
    parser.add_argument('-a', '--attrs', nargs='*', default=())
    parser.add_argument('-w', '--warning', action='store_true')
    args = parser.parse_args()

    results = evaluate_dir(targroot=args.targroot, targname=args.targname,
                           expname=args.expname, eval_clist=[[0.], [0.5], [1.], [2.]],
                           device=args.device, evalmode=args.evalmode)
    print('gen', results['gen'])
    if 'trf' in results:
        print('trf', results['trf'])
