import time
import torch
import sys
sys.path.append('../')
sys.path.append('../torchEFM')
sys.path.append('../configs')
sys.path.append('../utils')
sys.path.append('../datasets')

from datasets import synthetic
from matplotlib import pyplot as plt
from importlib import reload
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
import datasets
from datasets import multidat
import yaml
import numpy as np
import random
from torchEFM.model.models_new import FlowNet
from torchEFM import extended_flow_matching as efm
from einops import rearrange, reduce, repeat, einsum
from torchEFM import utils as ut
import pdb
import copy
import tqdm
import time
from tqdm import tqdm
import argparse
from torch.utils.tensorboard import SummaryWriter
import os
# from utils import evaluation as ev
import importlib
import datasets
from torch.profiler import profile, record_function



# This script owns exactly one command-line option (--modelname); every other
# token on the command line is meant for Hydra.
parser = argparse.ArgumentParser()
parser.add_argument("--modelname", default="output", type=str)
# parse_known_args returns the parsed namespace together with the list of
# tokens it did not recognise instead of erroring out on them.
args, remaining_sys_argv = parser.parse_known_args(sys.argv[1:])
# Rebuild sys.argv from the program name plus the unrecognised tokens so that
# Hydra never sees --modelname.
sys.argv = [sys.argv[0], *remaining_sys_argv]


def trace_handler(prof):
    """Print the profiler's averaged event table, sorted by total CPU time.

    Args:
        prof: Object exposing ``key_averages()`` whose result has a
            ``table(sort_by=..., row_limit=...)`` method (e.g. a
            ``torch.profiler.profile`` instance).
    """
    averaged_events = prof.key_averages()
    # row_limit=-1 is forwarded unchanged to the table formatter.
    report = averaged_events.table(sort_by="cpu_time_total", row_limit=-1)
    print(report)

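'''
Usage (CAUTION: the Hydra flag is spelled --config-name, with a hyphen, not --config_name)
e.g.> python -u run.py --config-name=debug_MNISTWAE
Optionally pass --modelname=<name> (default: "output") to choose the output
sub-directory created under config.savedir.
'''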
@hydra.main(version_base=None, config_path='./configs', config_name='hoge')
def main(config: DictConfig):
    """Train a FlowNet to match the Jacobian field produced by the EFM object.

    The dataset, loader, network, optimizer and EFM helper are all built with
    ``hydra.utils.instantiate`` from the matching sub-configs.  Each iteration
    draws boundary data from the loader, builds the psi path, and minimises
    ``ut.lp(jacobian - model(model_input), p=2)``.  TensorBoard events, a copy
    of the config and ``model.pt`` snapshots are written to
    ``<config.savedir>/<--modelname>``.

    Args:
        config: Hydra-composed configuration.  Keys read directly here:
            ``seed``, ``device``, ``evaluation._target_``, ``data``,
            ``rival``, ``T``, ``loader`` (including ``num_samplec``),
            ``model``, ``optimizer``, ``efm``, ``batch_size``, ``savedir``,
            ``iter``, ``Csamples``, ``report_freq``, ``eval_freq`` and
            ``debug``.  The whole object is also handed to ``FlowNet``,
            ``efmobj.create_source`` and ``myeval.evaluate``.
    """
    print(config)

    # Seed every RNG used by the script.
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.manual_seed(config.seed)

    # Resolve the device once and use the resolved name everywhere.  The
    # fallback to CPU must also reach the samplers and the EFM helper,
    # otherwise they would be placed on a device the model is not on.
    device_name = config.device if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    # config.evaluation._target_ is used as a module path; the module must
    # expose an ``evaluate`` function.
    myeval = importlib.import_module(config.evaluation._target_)

    # Load the data and the loader built on top of it.
    mydata = instantiate(config.data, rival=config.rival)
    mydata.prepare_tcsampler(config.T, device=device_name, samplesize=config.loader.num_samplec)
    myloader = instantiate(config.loader, dataset=mydata)
    myloader.prepare_LoaderSampler(device=device_name)
    print('myloader completed')

    # Create the model.
    mynet = instantiate(config.model, conditional_dim=mydata.c_dim).to(device)
    model = FlowNet(mynet, config).to(device)
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'model completed, # of params:{params}')

    optimizer = instantiate(config.optimizer, params=model.parameters())

    # Object that builds psi paths and their Jacobians.
    efmobj = instantiate(config.efm, batch=config.batch_size, device=device_name)

    writerlocation = os.path.join(config.savedir, args.modelname)
    writer = SummaryWriter(writerlocation)
    config_path = f"{writerlocation}/config.yaml"
    model_path = f"{writerlocation}/model.pt"

    # try/finally guarantees the event file is flushed and closed even when
    # training or evaluation raises.
    try:
        # Keep a copy of the exact configuration next to the results.
        with open(config_path, 'w') as file:
            OmegaConf.save(config, file)
        print(f"Config saved at  {config_path} ")

        # The loader type never changes, so decide the sampling branch once.
        is_sync_loader = 'Sync' in str(type(myloader))

        last_eval_iter = 0
        timenow = time.time()

        for k in tqdm(range(config.iter)):

            optimizer.zero_grad()

            # Paths require cbdry, xbdry and source.
            if is_sync_loader:
                cbdry, sample_c_idx = myloader.get_DataLoaderList()
                # NOTE(review): a fresh iterator is created every step, so
                # only the first batch of the loader is ever drawn -- confirm
                # this is the intended sampling scheme.
                xbdries = next(iter(myloader.DataLoaderList[0]))
                xbdry = xbdries[:, :, sample_c_idx]
            else:
                # Sample the cs and xs from which the psi paths are defined,
                # i.e. the interpolated points.
                cbdry, Loaderlist = myloader.get_DataLoaderList()
                xbdries = list(next(zip(*Loaderlist)))
                xbdry = rearrange(torch.stack(xbdries), 'c b x -> b 1 c x')

            # Send the psi conditioning variables to the device.
            cbdry = cbdry.to(device)
            xbdry = xbdry.to(device)
            cset, tset = mydata.create_tcsets(config.T, cbdry, config.Csamples, device)

            # Create the source configured for the extended flow matcher.
            source = efmobj.create_source(config, mydata)
            source = source.to(device)  # batch numsamplec dimc dimx
            # Create the psi function conditioned on the boundary data above.
            efmobj.create_psi(xbdry, cbdry, source)

            # Sample psi at (t, c).
            psi_xi = efmobj.sample_psi_tc(tset, cset, source, prepareForDeriv=True)  # batch x nT x nC x dimX

            # Supervision target: b nT nC dimX (1 + dimC), flattened over
            # the batch/time/condition axes to line up with the model output.
            jacobian_val = efmobj.compute_jacobian(tset, cset)
            jacobian_tobe_matched = rearrange(jacobian_val, 'b nT nC x d -> (b nT nC) x d')

            # Matching field (learner).
            model_input = efmobj.create_model_input(psi_xi, tset, cset)  # (b nT nC) x (dimX + (1 + dimC))
            estimated_field = model(model_input)  # (b nT nC) x d ; d = dimC + 1

            loss = ut.lp(jacobian_tobe_matched - estimated_field, p=2)

            if k % config.report_freq == 0:
                loss_value = loss.item()
                writer.add_scalar('training loss', loss_value, k)
                if config.debug == 1:
                    print(loss_value)

            # NOTE(review): evaluation runs before this step's backward(),
            # i.e. on the weights after k updates.  If evaluate() switches
            # the model to eval mode it has to restore train mode itself --
            # TODO confirm.
            if k % config.eval_freq == 0 and k > 9:
                delta = time.time() - timenow
                # Use the real number of iterations since the timer was
                # reset; the k > 9 guard can make the first interval longer
                # than eval_freq.
                iters_per_sec = (k - last_eval_iter) / delta
                print('\n \n')
                print(f"{k}th Iteration Complete.  {iters_per_sec} iters per sec. Evaluating...")
                myeval.evaluate(mydata=mydata, model=model, writer=writer, configs=config, device=device, myloader=myloader, n_iter=k)
                timenow = time.time()
                last_eval_iter = k

                # Snapshot the whole module object (not just its state_dict).
                torch.save(model, model_path)
                print(f"Model saved at  {model_path} ")

            loss.backward()
            optimizer.step()

        torch.save(model, model_path)
        print(f"Save succeeded in {writerlocation}")
        myeval.evaluate(mydata=mydata, model=model, writer=writer, configs=config, device=device, myloader=myloader, n_iter=None)
    finally:
        writer.close()



if __name__ == "__main__":
    # Script entry point: main is called without arguments because the
    # @hydra.main decorator builds and passes the config.
    main()





