
import os
import random
import sys
import warnings
from argparse import ArgumentParser

import numpy as np
import torch

from datasets.eps_data_u import simulate
from train_model import train_model
# The star import stays last so that any names exported by utils keep
# overriding the imports above, exactly as before.
from utils import *


# Process-wide setup performed at import time.
# Install a blanket "ignore" filter so no warning category is reported.
warnings.simplefilter('ignore')
# NOTE(review): presumably needed to tolerate a duplicated OpenMP runtime
# being loaded (e.g. by numpy and torch together) -- confirm it is still required.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` in argparse treats every non-empty string as ``True``, so
    ``--t_bin False`` used to switch the flag *on*.  This converter accepts
    the usual textual spellings instead.

    Args:
        value: a ``bool`` (returned unchanged) or a string token.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if the token is not a recognised spelling.  argparse
            reports a ``ValueError`` raised by a ``type`` callable as a
            regular usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was the only falsy input under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Build the command-line parser holding every experiment setting.

    Returns:
        An ``ArgumentParser`` whose defaults match the previous inline
        definition; only ``--t_bin`` parsing and the ``--n_val`` / ``--n_test``
        help texts differ.
    """
    parser = ArgumentParser()

    # NOTE(review): several help strings below look copy-pasted ('num of
    # epochs to train', 'the number of samples for training', ...) and do not
    # describe their option.  Only the ones whose meaning is visible in this
    # script were corrected; confirm the rest with the code consuming `args`.

    # i/o
    parser.add_argument('--data_dir', type=str, default='dataset/simulation_low_d', help='dir of data')
    parser.add_argument('--save_dir', type=str, default='log/9-20/5t_200u', help='dir to save result')
    parser.add_argument('--ihdp', type=str, default='/Users/zmin/Desktop/DCHCT_V4_2/data', help='dir to save result')

    # training
    parser.add_argument('--T_model', type=str, default=None,
                        help='model to use,[None, "gps", "bart", "cbgps", "npcbgps", "gbm", "eb", "dcows", "dcw", "vsr"]')
    # NOTE: main() overwrites args.model with each entry of its own model list.
    parser.add_argument('--model', type=str, default='gps', help='model to use,["nn", "drnet", "vcnet"]')
    parser.add_argument('--n_epochs', type=int, default=3000, help='num of epochs to train')
    parser.add_argument('--mdn_epochs', type=int, default=300, help='num of epochs to train')
    parser.add_argument('--beta', type=int, default=200, help='num of epochs to train')

    parser.add_argument('--n_exps', type=int, default=30, help="the number of experiments")
    parser.add_argument("--train_bs", default=1800, type=int, help='train batch size')
    parser.add_argument('--n_samples', type=int, default=3700, help="the number of generated samples")
    parser.add_argument('--n_train', type=int, default=1800, help="the number of samples for training")
    parser.add_argument('--n_val', type=int, default=900, help="the number of samples for validation")
    parser.add_argument('--n_test', type=int, default=1000, help="the number of samples for testing")
    parser.add_argument('--alpha', type=float, default=1, help="the number of experiments")
    parser.add_argument('--temperature', type=int, default=1, help="the number of experiments")

    parser.add_argument('--t_left', type=int, default=-7, help="the number of samples for training")
    parser.add_argument('--t_right', type=int, default=15, help="the number of samples for training")
    parser.add_argument('--n_adrf', type=int, default=300, help="the number of samples for training")
    parser.add_argument('--lr', type=float, default=1e-2, help="the number of samples for training")

    parser.add_argument('--t_dim', type=int, default=5, help="the dimension of treatments")
    # NOTE: main() overwrites args.x_dim with each entry of its own x_dims list.
    parser.add_argument('--x_dim', type=int, default=200, help="the dimension of covariates")
    parser.add_argument('--t_bin', type=str2bool, default=False, help="treament is binary(True) or not(False)")

    # print train info
    parser.add_argument('--verbose', type=int, default=500, help='print train info freq')
    parser.add_argument('--n_workers', type=int, default=0, help='num of workers')

    return parser


def predict_adrf(args, device, name, train_data, val_data, adrf, save_path):
    """Train the model called ``name`` and evaluate it on the ADRF rows.

    Args:
        args: parsed settings; ``args.model`` is expected to equal ``name``
            on entry and is equal to ``name`` again on exit.
        device: torch device the input tensors are moved to.
        name: model identifier selecting the training/evaluation recipe.
        train_data: training rows; columns ``[:t_dim]`` are read as
            treatments and ``[t_dim:-1]`` as covariates in the ``mdnr`` recipe.
        val_data: validation rows forwarded to ``train_model``.
        adrf: evaluation rows; the last column is excluded from model inputs.
        save_path: directory where the ``cr`` recipe stores its weights.

    Returns:
        The model output tensor for the ADRF rows.

    Raises:
        ValueError: if ``name`` is not a known model.  Previously an unknown
            name silently reused the prediction of the preceding experiment
            (or failed with ``NameError`` on the first one).
    """
    if name in ('crnetv3', 'nn'):
        net = train_model(args, device, train_data, val_data, adrf)
        net.eval()
        inputs = torch.tensor(adrf[:, :-1], dtype=torch.float32).to(device)
        return net(inputs)

    if name in ('crw', 'nnv2', 'cr'):
        net = train_model(args, device, train_data, val_data, adrf)
        net.eval()
        t = torch.tensor(adrf[:, :args.t_dim], dtype=torch.float32).to(device)
        x = torch.tensor(adrf[:, args.t_dim:-1], dtype=torch.float32).to(device)
        # b_forward yields five values; only the last one is the prediction.
        _, _, _, _, out = net.b_forward(t, x)
        if name == 'cr':
            # Written once per experiment, so the file ends up holding the
            # weights of the last experiment only.
            torch.save(net.state_dict(), os.path.join(save_path, name + '.pkl'))
        return out

    if name == 'mdnr':
        mdn = train_model(args, device, train_data)
        t = torch.tensor(train_data[:, :args.t_dim], dtype=torch.float32).to(device)
        x = torch.tensor(train_data[:, args.t_dim:-1], dtype=torch.float32).to(device)
        p_t = mdn.get_p(x, t)
        if args.t_dim > 1:
            # Product over dim 1, computed as exp(sum(log(.))).
            p_t = torch.log(p_t).sum(dim=1).exp()

        # Sorting the sort indices gives each sample's rank under descending p_t.
        _, order = torch.sort(p_t, descending=True)
        _, rank = torch.sort(order)

        train_tensor = torch.tensor(train_data, dtype=torch.float32).to(device)
        # Cast the integer ranks explicitly so the concatenation does not
        # depend on implicit dtype promotion.
        new_train = torch.cat((train_tensor, rank.reshape(-1, 1).float()), dim=1)

        # The second stage is trained through the 'nn' recipe of train_model.
        args.model = 'nn'
        try:
            net = train_model(args, device, new_train, val_data, adrf)
        finally:
            args.model = name

        net.eval()
        inputs = torch.tensor(adrf[:, :-1], dtype=torch.float32).to(device)
        return net(inputs)

    raise ValueError(f'unknown model name: {name!r}')


def main():
    """Run the repeated simulation experiments and log the ADRF MSE.

    For every covariate dimension and model in the hard-coded lists below,
    ``args.n_exps`` experiments are run: data is simulated, a model is
    trained, and the squared error between its prediction and the last
    column of ``adrf`` is logged per experiment and summarised at the end.
    """
    args = build_parser().parse_args()

    # Default to GPU index 6, but respect a device selection already made by
    # the caller (the previous code always overwrote it).
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '6')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    save_path = args.save_dir
    os.makedirs(save_path, exist_ok=True)

    models = ['cr']
    x_dims = [200]
    for x_dim in x_dims:
        args.x_dim = x_dim
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        random.seed(0)
        for name in models:
            args.model = name

            adrf_MSE = np.zeros((args.n_exps,))
            for exp in range(args.n_exps):
                print(f'model: {args.model}')

                setup_seed(exp)

                data, adrf, _ = simulate(args)

                train_data = data[:args.n_train, :]
                # NOTE(review): validation ends at n_samples - n_test, while
                # the (unused here) test split was taken from n_train + n_val
                # onward; both agree only when the three sizes add up to
                # n_samples -- confirm that is always intended.
                val_data = data[args.n_train:args.n_samples - args.n_test, :]

                out = predict_adrf(args, device, name, train_data, val_data, adrf, save_path)

                prediction = out.squeeze().detach().cpu().numpy()
                adrf_mse = ((prediction - adrf[:, -1].squeeze()) ** 2).mean()
                log(save_path, f'exp: {exp} mse: {adrf_mse}\n', name + '.txt')

                adrf_MSE[exp] = adrf_mse

            mean_txt = str(adrf_MSE.mean().round(3))
            std_txt = str(adrf_MSE.std().round(4))
            log(save_path,
                'data_' + str(args.t_dim) + '_' + str(args.x_dim) + ' ADRF mse: ' + mean_txt + ' ± ' + std_txt + '\n',
                name + '.txt')
            log(save_path, 'ADRF mse: ' + mean_txt + ' ± ' + std_txt + '\n', name + '.txt')


if __name__ == '__main__':
    main()

