
#https://pythonot.github.io/index.html#installation

#%%
import math
import os
import time
import sys
sys.path.append('./') 
import matplotlib.pyplot as plt
# plt.rcParams["text.usetex"] = True
import numpy as np
from functools import partial
import random
from tqdm import tqdm
import torch
import torch.distributions as dist
import functorch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from einops import rearrange, reduce, repeat
from torch.utils.tensorboard import SummaryWriter
import functools
import itertools
import argparse
import yaml
import pdb 
import io
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
import copy
from importlib import reload
import numpy as np
import hydra


import sys
sys.path.append('../')
from datasets import synthetic
from torchEFM import extended_flow_matching, utils
from torchEFM.model.models import MLP
from datasets import synthetic
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import importlib
import pickle


@hydra.main(version_base=None, config_path='../configs', config_name='20240116_results')
def main(config: DictConfig):
    """Hydra entry point.

    Creates the output directory ``<config.resultdir>/<config.config_name>``
    (if needed) and runs ``evaluate`` with it as the save location.
    """
    outdir = os.path.join(config.resultdir, config.config_name)
    os.makedirs(outdir, exist_ok=True)
    evaluate(config, savepath=outdir)


def evaluate(config: DictConfig, savepath=None, evallist=None):
    """Evaluate each compared experiment and collect (or save) the results.

    Each comma-separated entry of ``config.comparisons`` names a subdirectory
    of ``config.resultdir``. The run inside it is chosen with ``filterout``,
    using the query at the same position in ``config.comparekey``. Each run is
    passed to ``evaluate_dir`` of the module named in
    ``config.evaluation._target_``. Runs whose stored config has
    ``rival == 'guided'`` are evaluated once per comma-separated value of
    ``config.guidance``.

    :param config: Hydra config. Fields read: figout, evaluation._target_,
        comparisons, comparekey, resultdir, guidance, evalclist, evalmode,
        evalsize.
    :param savepath: directory in which ``result.pkl`` is written; required
        when ``config.figout == 0``.
    :param evallist: unused; kept for backward compatibility.
    :return: the results dict when ``config.figout != 0``; otherwise ``None``
        (the dict is pickled to ``savepath/result.pkl`` instead).
    :raises ValueError: if ``savepath`` is missing while ``config.figout == 0``,
        or if ``comparekey`` has fewer queries than ``comparisons`` has entries.
    """
    figout = config.figout
    if figout == 0 and savepath is None:
        # Fail before the (potentially slow) evaluation instead of at the final save.
        raise ValueError("savepath is required when config.figout == 0")

    myeval = importlib.import_module(config.evaluation._target_)

    comparelist = config.comparisons.split(',')
    querylist = parse_querykey(config.comparekey)
    if len(querylist) < len(comparelist):
        raise ValueError(
            f"comparekey has {len(querylist)} queries but comparisons has "
            f"{len(comparelist)} entries; they must be aligned."
        )

    # Resolve every run directory up front so a bad query aborts before any evaluation.
    queried = [filterout(query, targname, config.resultdir)
               for query, targname in zip(querylist, comparelist)]

    guidancelist = str(config.guidance).split(',')
    # Loop-invariant: parse once. A clone is passed to each call below in case
    # evaluate_dir modifies its eval_clist argument in place (not visible here).
    evalclist = torch.tensor([float(item) for item in config.evalclist.split(',')]).reshape(-1, 1)

    results = {}
    for k, (targname, expname) in enumerate(zip(comparelist, queried)):
        targyaml = retrieve_yaml(config.resultdir, targname)

        if 'guided' == targyaml['rival']:
            # Guided runs are evaluated once per guidance value.
            for guide in guidancelist:
                resultsX = myeval.evaluate_dir(targroot=config.resultdir,
                                               targname=targname, expname=expname,
                                               eval_clist=evalclist.clone(),
                                               guidance=float(guide), evalmode=config.evalmode,
                                               evalsize=config.evalsize)
                naming = _unique_name(results, targname + str(guide), k)
                read_results(results, resultsX, figout=figout, naming=naming)
        else:
            resultsX = myeval.evaluate_dir(targroot=config.resultdir,
                                           targname=targname, expname=expname,
                                           eval_clist=evalclist.clone(),
                                           guidance=None, evalmode=config.evalmode,
                                           evalsize=config.evalsize)
            naming = _unique_name(results, targname, k)
            read_results(results, resultsX, figout=figout, naming=naming)

    if figout == 0:
        with open(os.path.join(savepath, 'result.pkl'), 'wb') as f:
            pickle.dump(results, f)
    else:
        return results


def _unique_name(results, name, k):
    """Return ``name``, suffixed with comparison index ``k`` if ``results`` already has that key."""
    return name + str(k) if name in results else name

def retrieve_yaml(root, targname):
    """Load a run's ``config.yaml`` from under ``root/targname``.

    Entries of ``root/targname`` are scanned in sorted name order, and the
    first one containing a ``config.yaml`` file is used. Sorting makes the
    choice reproducible (``os.listdir`` order is arbitrary), and skipping
    entries without a config tolerates stray files in the directory.
    Only one run's config is read. The caller presumably relies on settings
    (e.g. ``rival``) being shared by all runs in ``targname``; TODO confirm.

    :param root: base results directory.
    :param targname: experiment subdirectory of ``root``.
    :return: the parsed YAML content (via ``yaml.safe_load``).
    :raises FileNotFoundError: if no entry under ``root/targname`` contains a
        ``config.yaml``.
    """
    targdir = os.path.join(root, targname)
    for entry in sorted(os.listdir(targdir)):
        targpath = os.path.join(targdir, entry, 'config.yaml')
        if os.path.isfile(targpath):
            with open(targpath, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
    raise FileNotFoundError(f"No run directory containing config.yaml found under {targdir}")

    
def read_results(results, resultsX, figout=0, naming=None):
    """Store one experiment's evaluation output into ``results`` under key ``naming``.

    When ``figout`` is 0, the entries 'terminal_ver', 'terminal_hor' and 'gt'
    are first removed from ``resultsX`` in place (via ``cleandict``).
    Otherwise ``resultsX`` is stored as-is and a notice is printed. In both
    cases the current number of stored entries is printed afterwards.
    """
    if figout != 0:
        results[naming] = resultsX
        print("Outputting the generated outputs.")
    else:
        results[naming] = cleandict(resultsX, ['terminal_ver', 'terminal_hor', 'gt'])
    print(len(results))


def cleandict(mydict, toremove):
    """Delete every key listed in ``toremove`` from ``mydict`` in place.

    Keys not present are ignored. Returns the same (mutated) dict object.
    """
    for unwanted in toremove:
        mydict.pop(unwanted, None)
    return mydict

def filterout(querylist, dirname, root):
    """Return the entry of ``root/dirname`` whose name contains every query substring.

    Entries are sorted by name first, because ``os.listdir`` order is
    arbitrary. When several entries match, a warning listing them is printed
    and the first match in sorted order is returned.

    :param querylist: substrings that must all occur in the entry name.
    :param dirname: subdirectory of ``root`` to search.
    :param root: base results directory.
    :return: the matching entry name (not a full path).
    :raises FileNotFoundError: if no entry matches all queries.
    """
    dirpath = os.path.join(root, dirname)
    dirlist = sorted(os.listdir(dirpath))
    matches = [(idx, name) for idx, name in enumerate(dirlist)
               if all(query in name for query in querylist)]
    if not matches:
        # Raise instead of dropping into pdb, so non-interactive runs fail cleanly.
        raise FileNotFoundError(
            f"NO MATCH FOUND FOR {querylist} on {dirname}. Make sure queries are aligned. "
            f"Available entries: {dirlist}")
    if len(matches) > 1:
        print('!!!!!Warning!!!! More than one directory is filtered. Taking the first element')
        for idx, name in matches:
            print(f"{idx}: {name}")
        print('+' * 20)
    return matches[0][1]


def arein(querylist, targname):
    """Return whether every string in ``querylist`` is a substring of ``targname``.

    The result is a NumPy boolean. An empty ``querylist`` yields True.
    """
    hits = np.array([query in targname for query in querylist])
    return hits.all()

def parse_querykey(querystr):
    """Parse a comparison query string into one list of substrings per comparison.

    Parentheses are removed, ``,`` separates comparisons, and `` and ``
    separates substrings that must all appear in the selected directory name.
    Whitespace around each substring is stripped. Otherwise a query such as
    ``"(a and b), (c)"`` would look for ``" c"`` and fail to match.

    :param querystr: e.g. ``"(a and b),(c)"``.
    :return: e.g. ``[['a', 'b'], ['c']]``.
    """
    cleaned = querystr.translate(str.maketrans('', '', '()'))
    return [[term.strip() for term in group.split(' and ')]
            for group in cleaned.split(',')]
    
def remove_keys_from_string(s, keys):
    """
    Return a copy of ``s`` with every character that appears in ``keys`` deleted.

    :param s: The original string.
    :param keys: A string whose characters are all to be removed from ``s``.
    :return: ``s`` without any occurrence of those characters.
    """
    # str.translate deletes any code point that maps to None.
    deletions = {ord(ch): None for ch in keys}
    return s.translate(deletions)


if __name__=='__main__':
    # Script entry point; the @hydra.main decorator on main() supplies the config argument.
    main()