import itertools
import sys
import numpy as np
import os
import copy
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.ops import MLP
from blackbox_models import *
from mech_models import *
from models import *
from dags import *
from utils_g import *
from hyper_param_settings import *
# Make float64 the default floating-point dtype for torch tensors created after this point.
torch.set_default_dtype(torch.float64)
# Module-level placeholder. main() assigns its own local `device` and never rebinds this global.
device=None
# Original note: "comment this out if not using GPU".
# NOTE(review): no statement follows this note in the file, so whatever line it referred to
# appears to have been removed -- confirm and delete the note if it is stale.


def gen_syn_data(seed,exp=1,train_size=100):
    """Generate a synthetic time-series dataset plus one random ordering of its samples.

    Every sample is a series of 60 time steps with 9 columns:

    * column 0    -- response ``v``, integrated with a forward-Euler step of 0.05
                     starting from ``v = 0`` (the initial 0 itself is not stored);
    * column 1    -- driving covariate: ``0.01 * exp(1 - t / 600)`` for ``t = 1..60``
                     plus Gaussian noise with standard deviation 0.5;
    * columns 2-7 -- six pure Gaussian-noise covariates (standard deviation 0.5);
    * column 8    -- a standard-normal noise column appended after all samples are built.

    Args:
        seed: Seed for ``np.random.default_rng``; the same seed reproduces the same output.
        exp: Experiment selector. ``1`` drives ``v`` with column 1 only; any other
            value drives it with a fixed linear combination of columns 1-6.
        train_size: Number of samples to generate.

    Returns:
        ``[perms, cases]`` where ``cases`` has shape ``(train_size, 60, 9)`` and
        ``perms`` is an int32 array of shape ``(1, train_size)`` holding one random
        permutation of the sample indices.

    Side effects:
        Prints ``cases.shape``.
    """
    rng=np.random.default_rng(seed=seed)
    n=60      # time steps per sample
    dt=5e-2   # forward-Euler step size
    t=np.linspace(1,n,n).reshape(n,1)
    # Deterministic part of the driving covariate; identical for every sample.
    trend=1/100*np.exp(1-t/n/10)
    data=[]
    for _ in range(train_size):
        # Draw order matters for reproducibility: the driving covariate's noise
        # first, then the six noise covariates, one column at a time.
        columns=[trend+rng.normal(0,0.5,(n,1))]
        for _ in range(6):
            columns.append(rng.normal(0,0.5,(n,1)))
        x=np.concatenate(columns,axis=1)
        v=[0]
        for i in range(n):
            if exp==1:
                drive=4*x[i,0]
            else:
                # NOTE(review): the 4e+4 coefficient breaks the otherwise decreasing
                # 4, 4e-1, 4e-2, 4e-3, ..., 4e-5 pattern and may be a typo for 4e-4.
                # Left unchanged because altering it would change the generated data.
                drive=4*x[i,0]-4e-1*x[i,1]+4e-2*x[i,2]-4e-3*x[i,3]+4e+4*x[i,4]-4e-5*x[i,5]
            # Euler step of dv/dt = drive - 0.5*(v - 1).
            v.append(v[-1]+dt*(drive-0.5*(v[-1]-1)))
        sample=np.concatenate([np.reshape(v[1:],(n,1)),x],axis=-1)
        data.append(sample)
    cases=np.array(data)
    # A full-shape noise tensor is drawn although only its first channel is kept;
    # shrinking the draw would shift the random stream and change `perms`.
    noise=rng.standard_normal(size=cases.shape,dtype='float64')
    cases=np.concatenate([cases,noise[:,:,:1]],axis=-1)
    perms=np.zeros((1,cases.shape[0]),dtype='int32')
    perms[0]=rng.permutation(cases.shape[0])
    print(cases.shape)
    return [perms, cases]
    


def _build_model(model_name, start_g, r_mode, hyper_params, device, dag, edge_map):
    """Instantiate the model class named ``model_name`` looked up in this module's globals."""
    model_cls=globals()[model_name]
    if start_g!='no_graph':
        # Graph-based models receive the DAG and its edge map.
        return model_cls(DAG=dag,edge_map=edge_map,hyper_params=hyper_params)
    if r_mode=='TS':
        # Only the 'TS' variant is handed the device at construction time.
        return model_cls(hyper_params=hyper_params,device=device)
    return model_cls(hyper_params=hyper_params)


def main(args, repeats=40, N=1, V=3):
    """Run hyper-parameter search and final training on freshly generated synthetic data.

    For every repeat a new dataset is generated with seed ``2024 + repeat``. For
    every train split each candidate in ``hyper_param_dicts[model_name]`` is
    scored over ``V`` validation splits, the best candidate is saved to an
    ``.npz`` file and a model is retrained with it (checkpoint path handed to
    ``train_model``).

    Args:
        args: Sequence of six strings, in order:
            ``exp_type``   -- experiment name; its last character is parsed as an
                              int and passed to ``gen_syn_data`` as ``exp``;
            ``model_name`` -- name of a model class available in this module's
                              globals and key into ``hyper_param_dicts``; the part
                              after the first ``_`` is used as ``r_mode``;
            ``start_g``    -- key into ``syn_dags`` / ``syn_edge_maps``;
                              ``'no_graph'`` selects the graph-free constructors;
            ``GPU_ID``     -- CUDA device index, used when CUDA is available;
            ``epoch``      -- number of training epochs (int);
            ``sz``         -- number of generated samples (int), also used in the
                              output directory name ``final_{sz}``.
        repeats: Number of independent dataset repeats.
        N: Number of train splits, forwarded to ``cv_split2``.
        V: Number of validation splits, forwarded to ``cv_split2``.

    Raises:
        IndexError: If fewer than six arguments are supplied.
    """
    if len(args)<6:
        raise IndexError("expected 6 arguments: exp_type model_name start_g GPU_ID epoch sz")
    sz=int(args[5])
    epoch=int(args[4])
    GPU_ID=args[3]
    start_g=args[2]
    model_name=args[1]
    exp_type=args[0]
    # Everything after the first underscore; the whole name when there is none.
    r_mode=model_name.split('_',1)[-1]
    print(r_mode)

    #choose GPU
    device = torch.device('cuda:'+str(GPU_ID) if torch.cuda.is_available() else 'cpu')
    print(device)

    dag=syn_dags[start_g]
    edge_map=syn_edge_maps[start_g]

    # Make sure the output directory exists before anything is written into it.
    os.makedirs(f"final_{sz}",exist_ok=True)

    for repeat in range(repeats):
        perms, cases=gen_syn_data(2024+repeat, int(exp_type[-1]), sz)

        for train_split in range(N):
            print(f"repeat {repeat}")
            #hyperparam tuning with cv
            # Score is the negated sum over folds of the minimum validation loss
            # (the original comment calls it negative mse). Start from -inf so a
            # candidate is never discarded merely because its loss is large.
            best_score=float('-inf')
            list_param_dicts=list(hyper_param_dicts[model_name])
            best_hyper_param=list_param_dicts[0]
            for hyper_param in list_param_dicts:
                score=0
                for val_split in range(V):
                    # Same seed for every candidate and fold.
                    torch.manual_seed(2024)
                    train,val,test,train_mean,train_std=cv_split2(perms, cases, repeat,
                                                                  train_split, val_split, 0,
                                                                  N, V, batch_size=len(cases))
                    model=_build_model(model_name, start_g, r_mode, hyper_param,
                                       device, dag, edge_map)
                    if epoch==1:
                        # Single-epoch runs also report the parameter count.
                        print(sum(p.numel() for p in model.parameters()))
                    train_h,val_h=train_model(model, train, val, epochs=epoch, hyper_params=hyper_param,
                                              train_std=train_std, device=device, verbose=False,
                                              r_mode=r_mode)
                    score-=np.min(val_h)
                if score>best_score:
                    best_hyper_param=hyper_param
                    best_score=score

            #re-train with best hyper_param
            print(f"best_hyper{best_hyper_param}")
            # NOTE(review): this file name hard-codes "lcp " (with a trailing space)
            # where the checkpoint path below uses {start_g}. Kept byte-identical
            # because downstream readers may rely on it -- confirm and fix together.
            np.savez(f"final_{sz}/{model_name}_{exp_type}_lcp _best_hyper_param_{repeat}_{train_split}.npz", **best_hyper_param)
            torch.manual_seed(2024)
            train,val,test,train_mean,train_std=cv_split2(perms, cases, repeat,
                                                          train_split, 0, 0,
                                                          N, V, batch_size=len(cases))
            model=_build_model(model_name, start_g, r_mode, best_hyper_param,
                               device, dag, edge_map)
            # Train with the selected configuration. Previously the last candidate
            # of the search loop was passed here instead of the best one.
            train_h,val_h=train_model(model, train, val, epochs=epoch, hyper_params=best_hyper_param,
                                      train_std=train_std, device=device, verbose=False,
                                      path=f"final_{sz}/{model_name}_{exp_type}_{start_g}_{repeat}_{train_split}.pth",
                                      r_mode=r_mode)
            
            


if __name__=='__main__':
    # Command-line arguments, in the order main() indexes them:
    #   experiment_type model_name starting_graph GPU_ID epochs train_size
    main(sys.argv[1:])
