import copy
import itertools
import sys
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.ops import MLP
from blackbox_models import *
from mech_models import *
from models import *
from dags import *
from utils_g import *
from hyper_param_settings import *
# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)
# Module-level placeholder only: main() binds its own local `device` (there is no
# `global device` statement in it), so this value stays None.
device=None
# NOTE(review): the original note here read "comment this out if not using GPU",
# but no statement follows it in this file -- presumably a GPU-selection line was
# removed at some point; confirm and drop this note if so.

def load_real_data(repeats,exp=1):
    """Load the case array, append one noise channel, and draw shuffles.

    Only the file that is actually used gets read: ``mr_cases_corrupted.npy``
    when ``exp == 2`` (used as stored), otherwise ``mr_cases.npy`` with index
    0 taken along its second axis. Both paths are relative to the current
    working directory.

    Args:
        repeats: number of independent permutations of the first axis to draw.
        exp: experiment selector; ``2`` picks the corrupted file, any other
            value picks the clean one.

    Returns:
        A two-element list ``[perms, cases]`` where ``perms`` is an int32
        array of shape ``(repeats, cases.shape[0])`` and ``cases`` is the
        loaded array with one extra standard-normal column concatenated on
        its last axis.
    """
    if exp == 2:
        cases = np.load("mr_cases_corrupted.npy")
    else:
        cases = np.load("mr_cases.npy")[:, 0]

    # Fixed seed so the noise column and the permutations are reproducible.
    # The draw order (noise first, then permutations) must not change, or
    # previously generated splits would no longer be reproduced.
    rng = np.random.default_rng(seed=2024)
    noise = rng.standard_normal(size=cases.shape, dtype='float64')
    # Only the first slice of the noise along the third axis is kept, so the
    # last axis of `cases` grows by exactly one.
    cases = np.concatenate([cases, noise[:, :, :1]], axis=-1)

    # One permutation of the sample axis per repeat.
    n_samples = cases.shape[0]
    perms = np.zeros((repeats, n_samples), dtype='int32')
    for row in range(repeats):
        perms[row] = rng.permutation(n_samples)
    print(cases.shape)
    return [perms, cases]
    


def _build_model(model_name, start_g, r_mode, dag, edge_map, hyper_params, device):
    """Instantiate the model class named ``model_name`` for the current setting.

    The class is looked up in this module's globals. With a starting graph the
    class receives ``DAG`` and ``edge_map``; with ``start_g == 'no_graph'`` it
    receives only ``hyper_params``, plus ``device`` when ``r_mode == 'TS'``.
    """
    model_cls = globals()[model_name]
    if start_g != 'no_graph':
        return model_cls(DAG=dag, edge_map=edge_map, hyper_params=hyper_params)
    if r_mode == 'TS':
        return model_cls(hyper_params=hyper_params, device=device)
    return model_cls(hyper_params=hyper_params)


def _binarize_edge_weights(model, sorted_edge_weights, noise, total, rate):
    """Return a fresh 0/1 tensor keeping edges whose weight exceeds a threshold.

    The threshold is the sorted weight found ``int(rate * total)`` positions
    after the first occurrence of ``noise`` in ``sorted_edge_weights``. The
    last entry (the one ``noise`` was read from) is always set to 0.

    Raises:
        IndexError: if that position lies past the end of the sorted weights.
    """
    step = int(rate * total)
    noise_pos = np.where(sorted_edge_weights == noise)[0][0]
    th = float(sorted_edge_weights[noise_pos + step])
    raw = model.return_edge_weights().detach()
    # Build the mask in a single comparison on a new tensor. Two in-place
    # passes (<=th -> 0, then >th -> 1) flip the zeroed entries back to 1
    # whenever th is negative, and writing through detach() would also modify
    # the source tensor because detach() shares its storage.
    mask = (raw > th).to(raw.dtype)
    mask[-1] = 0
    return mask


def main(args, repeats=10, N=4, V=3):
    """Run hyper-parameter search and final training for every repeat/split.

    Args:
        args: positional command-line strings, read by index:
            ``args[0]`` exp_type (its last character is parsed as the ``exp``
            selector for ``load_real_data``), ``args[1]`` model_name (a class
            name resolvable in this module's globals and a key of
            ``hyper_param_dicts``), ``args[2]`` start_g (key into
            ``real_dags`` / ``real_edge_maps``; ``'no_graph'`` selects the
            graph-free constructors), ``args[3]`` GPU_ID, ``args[4]`` epoch
            count, ``args[5]`` refit flag (integer).
        repeats: number of data permutations to iterate over.
        N: number of train splits per repeat.
        V: number of validation splits used when scoring a candidate.

    For each (repeat, train_split) the candidate with the highest score is
    selected, saved with ``np.savez`` and used to train the final model. When
    ``model_name == 'MNODE_GL'`` and the refit flag equals 1, a second search
    over ``hyper_param_dicts[model_name + "R"]`` retrains a model whose edge
    weights are thresholded to 0/1 and frozen.
    """
    refit = int(args[5])
    epoch = int(args[4])
    GPU_ID = args[3]
    start_g = args[2]
    model_name = args[1]
    exp_type = args[0]
    # Everything after the first underscore (the whole name if there is none).
    r_mode = model_name.split('_', 1)[-1]
    print(r_mode)

    # choose GPU
    device = torch.device('cuda:' + str(GPU_ID) if torch.cuda.is_available() else 'cpu')
    print(device)

    # load and prep data
    perms, cases = load_real_data(repeats, int(exp_type[-1]))
    dag = real_dags[start_g]
    edge_map = real_edge_maps[start_g]
    batch_size = len(cases)

    for repeat in range(repeats):
        seed = 2024 + repeat
        for train_split in range(N):
            print(f"repeat {repeat} train_split {train_split}")

            # --- hyper-parameter tuning with cross-validation -------------
            # Score = negated sum over folds of the minimum of the validation
            # history (negative MSE per the original author's note). Start
            # from -inf so any finite score can win.
            best_score = float('-inf')
            candidates = hyper_param_dicts[model_name]
            best_hyper_param = candidates[0]
            for hyper_param in candidates:
                score = 0
                for val_split in range(V):
                    torch.manual_seed(seed)
                    train, val, _, _, train_std = cv_split2(
                        perms, cases, repeat, train_split, val_split, 0,
                        N, V, batch_size=batch_size)
                    model = _build_model(model_name, start_g, r_mode, dag,
                                         edge_map, hyper_param, device)
                    if epoch == 1:
                        print(sum(p.numel() for p in model.parameters()))
                    _, val_h = train_model(
                        model, train, val, epochs=epoch, hyper_params=hyper_param,
                        train_std=train_std, device=device, verbose=False,
                        r_mode=r_mode)
                    score -= np.min(val_h)
                if score > best_score:
                    best_hyper_param = hyper_param
                    best_score = score

            # --- re-train with the best hyper-parameters -------------------
            np.savez(f"{model_name}_{exp_type}rep10_best_hyper_param_{repeat}_{train_split}.npz", **best_hyper_param)
            torch.manual_seed(seed)
            train, val, _, _, train_std = cv_split2(
                perms, cases, repeat, train_split, 0, 0,
                N, V, batch_size=batch_size)
            model = _build_model(model_name, start_g, r_mode, dag, edge_map,
                                 best_hyper_param, device)
            ckpt_path = f"{model_name}_{exp_type}rep10_{start_g}_{repeat}_{train_split}.pth"
            # Train with the selected configuration (not the last candidate
            # left over from the search loop).
            train_model(
                model, train, val, epochs=epoch, hyper_params=best_hyper_param,
                train_std=train_std, device=device, verbose=False,
                path=ckpt_path, r_mode=r_mode)

            # The refit stage only applies to MNODE_GL with the flag set to 1.
            if model_name != 'MNODE_GL' or refit != 1:
                continue

            # Reload from the same path that was just passed to train_model.
            # NOTE(review): assumes train_model writes its checkpoint to
            # `path` -- confirm against train_model.
            model.load_state_dict(torch.load(ckpt_path))
            learned_weights = model.return_edge_weights().detach().cpu()
            total = len(learned_weights)
            # The last entry serves as the reference ("noise") weight.
            noise = learned_weights[-1].item()
            sorted_edge_weights = np.sort(learned_weights.numpy())

            # --- refit search over pruning configurations ------------------
            best_score = float('-inf')
            refit_candidates = hyper_param_dicts[model_name + "R"]
            best_hyper_param = refit_candidates[0]
            for hyper_param in refit_candidates:
                edge_mask = _binarize_edge_weights(
                    model, sorted_edge_weights, noise, total, hyper_param['rate'])
                score = 0
                for val_split in range(V):
                    torch.manual_seed(seed)
                    train, val, _, _, train_std = cv_split2(
                        perms, cases, repeat, train_split, val_split, 0,
                        N, V, batch_size=batch_size)
                    # Build the model from the candidate being scored, in
                    # line with the first search loop.
                    model2 = globals()[model_name](
                        DAG=dag, edge_map=edge_map, init_ew=edge_mask,
                        hyper_params=hyper_param)
                    model2.freeze_ew()
                    _, val_h = train_model(
                        model2, train, val, epochs=epoch, hyper_params=hyper_param,
                        train_std=train_std, device=device, verbose=False,
                        r_mode=r_mode)
                    score -= np.min(val_h)
                if score > best_score:
                    best_hyper_param = hyper_param
                    best_score = score

            # --- final refit with the best pruning configuration -----------
            edge_mask = _binarize_edge_weights(
                model, sorted_edge_weights, noise, total, best_hyper_param['rate'])
            torch.manual_seed(seed)
            train, val, _, _, train_std = cv_split2(
                perms, cases, repeat, train_split, 0, 0,
                N, V, batch_size=batch_size)
            model2 = globals()[model_name](
                DAG=dag, edge_map=edge_map, init_ew=edge_mask,
                hyper_params=best_hyper_param)
            model2.freeze_ew()
            train_model(
                model2, train, val, epochs=epoch, hyper_params=best_hyper_param,
                train_std=train_std, device=device, verbose=False,
                path=f"{model_name}R_{exp_type}_{start_g}_{repeat}_{train_split}.pth",
                r_mode=r_mode)
            


if __name__=='__main__':
    # Positional command-line arguments, in the order main() indexes them:
    #   exp_type  model_name  start_g  GPU_ID  epoch  refit
    main(sys.argv[1:])
