import time
import os
import numpy as np
#import matplotlib.pyplot as plt
#from math import *
import pandas as pd
import argparse
#from IPython import embed
import math
import models
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torchvision import datasets
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import utils



def train_model(test_loader, model_A, model_B, model_D, lr, device, opt_A, opt_B, Data_opt1, train_loader=None,
                train_data=True, A_det='det'):
    """Run one pass over `test_loader`, fitting the surrogates A/B and (optionally) the generator D.

    For every batch of `test_loader` (a dict with 'feature' and 'label' tensors):
      1. model_A and model_B take one optimizer step towards the frozen targets
         ``model_D(x_train) + ep_train`` of a training batch. The training batch is drawn from
         `train_loader`, or is the current test batch when `train_loader` is None.
      2. Predictions of A, B and ``y = model_D(x) + ep`` are computed on the test batch, where
         x and ep carry zero-initialised, trainable additive offsets.
      3. If `train_data` is true, model_D (via `Data_opt1`) and the offsets (via a per-batch SGD
         with step `lr`) are updated to minimise ``MSE(y, A(x)) - MSE(y, B(x))`` plus a penalty
         pushing the batch mean of y towards 0 and its batch std towards 1.

    Args:
        test_loader: loader yielding {'feature', 'label'} batches that are evaluated/adjusted.
        model_A, model_B: surrogate models, put in train mode here.
        model_D: generator model; train mode if `train_data`, eval mode otherwise.
        lr: learning rate of the SGD step applied to the input/noise offsets.
        device: torch device batches are moved to.
        opt_A, opt_B: optimizers of model_A / model_B.
        Data_opt1: optimizer of model_D (only stepped when `train_data` is true).
        train_loader: optional loader of training batches for step 1. If it runs out before
            `test_loader`, it is restarted.
        train_data: whether to update model_D and the offsets.
        A_det: 'det' detaches model_A's test prediction, so the data loss does not back-propagate
            into model_A. Any other value keeps the gradient path.

    Returns:
        data_x_np, data_ep_np: concatenated inputs / noise of all test batches, including the
            offset update when `train_data` is true.
        ((mseA, maeA, mapeA), (mseB, maeB, mapeB)): results of utils.get_MSE_np / get_MAE_np /
            get_MAPE_np on A's and B's test predictions against y (all taken before step 3).
    """
    model_A.train()
    model_B.train()
    if train_data:
        model_D.train()
    else:
        model_D.eval()
    # Without a dedicated training loader the surrogates are fitted on the batch they are evaluated on.
    train_error = train_loader is None
    iter_source_train = None if train_error else iter(train_loader)
    criterion = nn.MSELoss()

    data_x_list = []
    data_ep_list = []
    y_A_list = []
    y_B_list = []
    y_test_list = []

    for batch in test_loader:
        batch_X = batch['feature'].to(device)
        batch_ep = batch['label'].to(device)

        if train_error:
            batch_X_train = batch_X
            batch_ep_train = batch_ep
        else:
            try:
                # Python-3 iterator protocol; DataLoader iterators no longer provide `.next()`.
                batch_train = next(iter_source_train)
            except StopIteration:
                # The training loader is shorter than the test loader: start another pass over it.
                iter_source_train = iter(train_loader)
                batch_train = next(iter_source_train)
            batch_X_train = batch_train['feature'].to(device)
            batch_ep_train = batch_train['label'].to(device)

        # Zero-initialised additive offsets on the inputs and the noise. They are leaf tensors,
        # so the per-batch SGD optimiser below can move them.
        Data_adjust = torch.zeros_like(batch_X, requires_grad=True)
        Data_adjusted = batch_X + Data_adjust
        ep_adjust = torch.zeros_like(batch_ep, requires_grad=True)
        ep_adjusted = batch_ep + ep_adjust

        # 1) One fitting step of A and B on the generator's frozen targets for the training batch.
        A_out_train = model_A(batch_X_train.detach())
        B_out_train = model_B(batch_X_train.detach())
        Data_out_train = model_D(batch_X_train.detach()).detach() + batch_ep_train.detach()

        opt_A.zero_grad()
        opt_B.zero_grad()
        loss_fit_A = criterion(A_out_train, Data_out_train)
        loss_fit_B = criterion(B_out_train, Data_out_train)
        loss_train = loss_fit_A + loss_fit_B
        loss_train.backward()
        opt_A.step()
        opt_B.step()

        # 2) Test-batch predictions with the just-updated surrogates.
        A_out_test = model_A(Data_adjusted)
        if A_det == 'det':  # equality, not identity: strings from argparse are not interned
            A_out_test = A_out_test.detach()
        B_out_test = model_B(Data_adjusted).detach()
        Data_out_test = model_D(Data_adjusted) + ep_adjusted

        # 3) Optionally update the generator and the input/noise offsets.
        if train_data:
            Data_opt2 = torch.optim.SGD([{"params": [Data_adjust, ep_adjust]}], lr=lr)
            Data_opt1.zero_grad()
            Data_opt2.zero_grad()
            loss_data = criterion(Data_out_test, A_out_test) - criterion(Data_out_test, B_out_test)
            data_mean = torch.mean(Data_out_test, dim=0)
            data_std = torch.std(Data_out_test, dim=0)
            loss_data_regular = (criterion(data_mean, torch.zeros_like(data_mean))
                                 + criterion(data_std, torch.ones_like(data_mean)))
            loss_data_all = loss_data + loss_data_regular
            loss_data_all.backward()
            Data_opt1.step()
            Data_opt2.step()

        # Rebuild from the offsets *after* the SGD step. `Data_adjusted` / `ep_adjusted` were
        # materialised before it, so recording them would discard the adjustment. When
        # train_data is false the offsets are still zero, so this equals the raw batch.
        data_x_list.append((batch_X + Data_adjust).detach().cpu().numpy())
        data_ep_list.append((batch_ep + ep_adjust).detach().cpu().numpy())
        y_A_list.append(A_out_test.detach().cpu().numpy())
        y_B_list.append(B_out_test.detach().cpu().numpy())
        y_test_list.append(Data_out_test.detach().cpu().numpy())

    data_x_np = np.concatenate(data_x_list, axis=0)
    data_ep_np = np.concatenate(data_ep_list, axis=0)
    y_A_np = np.concatenate(y_A_list, axis=0)
    y_B_np = np.concatenate(y_B_list, axis=0)
    y_test_np = np.concatenate(y_test_list, axis=0)
    metrics_A = (utils.get_MSE_np(y_A_np, y_test_np),
                 utils.get_MAE_np(y_A_np, y_test_np),
                 utils.get_MAPE_np(y_A_np, y_test_np))
    metrics_B = (utils.get_MSE_np(y_B_np, y_test_np),
                 utils.get_MAE_np(y_B_np, y_test_np),
                 utils.get_MAPE_np(y_B_np, y_test_np))
    return data_x_np, data_ep_np, (metrics_A, metrics_B)
def test_model(test_loader, model_A, model_B, model_D, device):
    """Evaluate model_A and model_B against the generator's targets without updating anything.

    All three models are switched to eval mode and left there; train_model switches them back.
    For every {'feature', 'label'} batch, the target is ``model_D(x) + label``.

    Args:
        test_loader: loader yielding {'feature', 'label'} batches.
        model_A, model_B: surrogate models under evaluation.
        model_D: generator producing the targets.
        device: torch device batches are moved to.

    Returns:
        (0, 0, ((mseA, maeA, mapeA), (mseB, maeB, mapeB))). The two leading zeros mirror the
        (data_x_np, data_ep_np, metrics) shape returned by train_model. The metrics come from
        utils.get_MSE_np / get_MAE_np / get_MAPE_np.
    """
    model_A.eval()
    model_B.eval()
    model_D.eval()

    y_A_list = []
    y_B_list = []
    y_test_list = []
    # Pure evaluation: skip building autograd graphs for every batch.
    with torch.no_grad():
        for batch in test_loader:
            batch_X = batch['feature'].to(device)
            batch_ep = batch['label'].to(device)
            y_test_list.append((model_D(batch_X) + batch_ep).cpu().numpy())
            y_A_list.append(model_A(batch_X).cpu().numpy())
            y_B_list.append(model_B(batch_X).cpu().numpy())

    y_A_np = np.concatenate(y_A_list, axis=0)
    y_B_np = np.concatenate(y_B_list, axis=0)
    y_test_np = np.concatenate(y_test_list, axis=0)
    metrics_A = (utils.get_MSE_np(y_A_np, y_test_np),
                 utils.get_MAE_np(y_A_np, y_test_np),
                 utils.get_MAPE_np(y_A_np, y_test_np))
    metrics_B = (utils.get_MSE_np(y_B_np, y_test_np),
                 utils.get_MAE_np(y_B_np, y_test_np),
                 utils.get_MAPE_np(y_B_np, y_test_np))
    return 0, 0, (metrics_A, metrics_B)
def updata_dataset(data__np, batch_size=None):
    """Standardise `data__np` per feature (statistics over axis 0) and wrap it in a DataLoader.

    Args:
        data__np: numpy array with samples stacked along axis 0.
        batch_size: batch size of the returned loader. When omitted, the module-level
            ``batch_size`` assigned in the ``__main__`` block is used, which was the previous
            implicit behaviour.

    Returns:
        An unshuffled torch DataLoader over ``utils.A_Dataset(features=<standardised data>)``.

    Raises:
        ValueError: if `batch_size` is omitted and no module-level ``batch_size`` exists
            (e.g. the module was imported rather than run as a script).
    """
    if batch_size is None:
        batch_size = globals().get('batch_size')
        if batch_size is None:
            raise ValueError("updata_dataset: pass batch_size explicitly (no module-level batch_size is set)")
    npmean = np.mean(data__np, axis=0, keepdims=True)
    npstd = np.std(data__np, axis=0, keepdims=True)
    trainx = (data__np - npmean) / npstd
    dataset = utils.A_Dataset(features=trainx)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)


def initial_models(models_shapeA, models_shapeB, train_size, model_num, device):
    """Create `model_num` fresh (A, B) model pairs plus a random training subset for each pair.

    Args:
        models_shapeA: dash-separated integer layer spec for every model A, e.g. '6-4-50-6'.
        models_shapeB: dash-separated integer layer spec for every model B.
        train_size: either the number of training indices to draw per pair, or a
            ``(data_len, train_size)`` tuple, the convention used by ``models.initial_models``.
            With a plain int, indices are drawn from ``range(len(testx))``, using the
            module-level ``testx`` created in the ``__main__`` block (the previous behaviour).
        model_num: number of (A, B) pairs to create.
        device: torch device the models are moved to.

    Returns:
        (modelA_list, modelB_list, train_choicelist): two ``nn.ModuleList``s of
        ``models.Linear_N_H`` and a list of index arrays, each drawn without replacement.
    """
    if isinstance(train_size, (tuple, list)):
        data_len, n_train = train_size
    else:
        data_len, n_train = None, train_size
    shape_A = [int(x) for x in models_shapeA.split("-")]
    shape_B = [int(x) for x in models_shapeB.split("-")]

    modelA_list = nn.ModuleList([])  # one fresh model per pair
    modelB_list = nn.ModuleList([])
    train_choicelist = []
    for _ in range(model_num):
        modelA_list.append(models.Linear_N_H(shape_A).to(device))
        modelB_list.append(models.Linear_N_H(shape_B).to(device))
        # Backward-compatible int form: fall back to the global pool created under __main__.
        pool_size = len(testx) if data_len is None else data_len
        train_choicelist.append(np.random.choice(pool_size, size=n_train, replace=False))
    return modelA_list, modelB_list, train_choicelist


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--epochs2', type=int, default=10)
    parser.add_argument('--epochs_for_newmodels', type=int, default=5)
    parser.add_argument('--epochs_for_finmodels', type=int, default=10)  # 30-120 training epochs
    parser.add_argument('--exp_dir', type=str, default=None)  # None -> './exp_dir'
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--m_num_', type=int, default=1)

    parser.add_argument('--train_size', type=int, default=500)
    parser.add_argument('--test_eachsize', type=int, default=5000)
    parser.add_argument('--test_allsize', type=int, default=100000)
    # The test pool, its noise (ep) and the training subsets are updated test_eachsize samples at a time.

    parser.add_argument('--testep_std', type=int, default=500)
    parser.add_argument('--model_shapeA', type=str, default='6-4-50-6')
    parser.add_argument('--model_shapeB', type=str, default='6-32-50-6')
    parser.add_argument('--model_shapeData', type=str, default='6-4-50-6')
    parser.add_argument('--data_input_shape', type=str, default='6-100')  # [dim, length]
    parser.add_argument('--models_num', type=int, default=1)
    parser.add_argument('--dataset_name', type=str, default='???')
    parser.add_argument('--training_error', type=int, default=0)
    parser.add_argument('--testing_error', type=int, default=0)

    parser.add_argument('--model_typeA', type=str, default='LSTM_1_H')
    parser.add_argument('--model_typeB', type=str, default='LSTM_1_H')
    parser.add_argument('--model_typeData', type=str, default='LSTM_1_H')
    parser.add_argument('--A_det', type=str, default='nodet')  # 'det' or 'nodet'; try both

    args = parser.parse_args()

    print('start', args.dataset_name)

    if args.training_error:
        # Shrink the pool relative to the training subset and keep batches smaller than it.
        args.test_allsize = args.train_size * 10
        args.batch_size = min(args.batch_size, args.train_size - 5)
    args.m_num = args.m_num_ - 1
    batch_size = args.batch_size  # module-level; also read by updata_dataset

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    models_shapeData = [int(x) for x in args.model_shapeData.split("-")]
    input_shape = [int(x) for x in args.data_input_shape.split("-")]

    # Random input/noise pool, standardised per position; the noise is further scaled by 1/testep_std.
    testx = np.float32(np.random.rand(args.test_allsize, input_shape[0], input_shape[1]))
    testep = np.float32(np.random.rand(args.test_allsize, models_shapeData[-1]))

    npmean = np.mean(testx, axis=0, keepdims=True)
    npstd = np.std(testx, axis=0, keepdims=True)
    epmean = np.mean(testep, axis=0, keepdims=True)
    epstd = np.std(testep, axis=0, keepdims=True)
    testx = (testx - npmean) / npstd
    testep = (testep - epmean) / epstd / args.testep_std

    # Respect --exp_dir when given; the previous code overwrote it unconditionally.
    if args.exp_dir is None:
        args.exp_dir = './exp_dir'
    dataset_name = args.dataset_name
    exp_dir = args.exp_dir
    model_name1 = 'regression_1d_PDS0' + str(args.m_num)
    runs_dir = os.path.join(exp_dir, 'runs', dataset_name)
    weights_dir = os.path.join(exp_dir, 'weights', dataset_name)
    logs_dir = os.path.join(exp_dir, 'logs', dataset_name)
    for needed_dir in (runs_dir, weights_dir, logs_dir):
        os.makedirs(needed_dir, exist_ok=True)

    model_D = models.model_and_shape(models_shapeData, args.model_typeData).to(device)

    if args.testing_error:
        # ---- Phase 1: adapt the generator/pool against freshly initialised surrogate pairs ----
        train_logs = []

        for i in range(args.epochs):
            modelA_list, modelB_list, train_choicelist = models.initial_models(
                args.model_shapeA, args.model_shapeB, (len(testx), args.train_size), args.models_num, device,
                models_typeA=args.model_typeA, models_typeB=args.model_typeB)
            Data_opt1 = torch.optim.Adam(model_D.parameters(), lr=args.lr)  # re-initialised every epoch
            # Accumulate over every model pair and inner epoch of this outer epoch. The lists used to be
            # reset inside the innermost loop, so only the very last run was averaged.
            mseA, maeA, mapeA, mseB, maeB, mapeB = [], [], [], [], [], []

            for modelA, modelB, train_choice1 in zip(modelA_list, modelB_list, train_choicelist):
                opt_A = torch.optim.Adam(modelA.parameters(), lr=args.lr)
                opt_B = torch.optim.Adam(modelB.parameters(), lr=args.lr)

                for j in range(args.epochs_for_newmodels):
                    test_choice1 = np.random.choice(len(testx), size=args.test_eachsize, replace=False)
                    test_dataset = utils.A_Dataset(features=testx[test_choice1], labels=testep[test_choice1])

                    # Repeat the small training subset so its loader is at least as long as the test loader.
                    trainx = testx[train_choice1]
                    trainep = testep[train_choice1]
                    n_repeat = args.test_eachsize // len(trainx) + 1
                    trainx = trainx.repeat(n_repeat, axis=0)
                    trainep = trainep.repeat(n_repeat, axis=0)
                    train_dataset = utils.A_Dataset(features=trainx, labels=trainep)

                    dataset_loader_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                                                      shuffle=False, num_workers=0)
                    dataset_loader_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                                       shuffle=True, num_workers=0)

                    data_x_np, data_ep_np, evaluate = train_model(
                        dataset_loader_test, modelA, modelB, model_D, args.lr, device, opt_A, opt_B, Data_opt1,
                        train_loader=dataset_loader_train, A_det=args.A_det)

                    # Re-standardise the returned chunk before writing it back into the pool.
                    npmeans = np.mean(data_x_np, axis=0, keepdims=True)
                    npstds = np.std(data_x_np, axis=0, keepdims=True)
                    epmeans = np.mean(data_ep_np, axis=0, keepdims=True)
                    epstds = np.std(data_ep_np, axis=0, keepdims=True)
                    testx[test_choice1] = (data_x_np - npmeans) / npstds
                    testep[test_choice1] = (data_ep_np - epmeans) / epstds / args.testep_std

                    (mseAi, maeAi, mapeAi), (mseBi, maeBi, mapeBi) = evaluate
                    mseA.append(mseAi)
                    maeA.append(maeAi)
                    mapeA.append(mapeAi)
                    mseB.append(mseBi)
                    maeB.append(maeBi)
                    mapeB.append(mapeBi)

            mseAV, maeAV, mapeAV = np.mean(mseA), np.mean(maeA), np.mean(mapeA)
            mseBV, maeBV, mapeBV = np.mean(mseB), np.mean(maeB), np.mean(mapeB)
            rmseA, rmseB = np.sqrt(mseAV), np.sqrt(mseBV)
            now_epoch_result = {'epoch': i, 'maeA': maeAV,
                                'mapeA': mapeAV, 'maeB': maeBV, 'mapeB': mapeBV,
                                'rmseA': rmseA, 'rmseB': rmseB, 'M-rmse': rmseA - rmseB}

            train_logs.append(now_epoch_result)
            train_logs_df = pd.DataFrame(train_logs)
            train_logs_path = os.path.join(logs_dir, model_name1 + 'testing_not_fin.csv')
            train_logs_df.to_csv(train_logs_path)

        npmean = np.mean(testx, axis=0, keepdims=True)
        npstd = np.std(testx, axis=0, keepdims=True)
        epmean = np.mean(testep, axis=0, keepdims=True)
        epstd = np.std(testep, axis=0, keepdims=True)
        testx = (testx - npmean) / npstd
        testep = (testep - epmean) / epstd / args.testep_std

        # ---- Phase 2: train fresh surrogates on the final pool; select by validation MSE ----
        train_logs = []
        for iii in range(args.epochs2):
            models_num = 1
            modelA_list, modelB_list, train_choicelist = models.initial_models(
                args.model_shapeA, args.model_shapeB, (len(testx), args.train_size), models_num, device,
                models_typeA=args.model_typeA, models_typeB=args.model_typeB)
            now_valid_mseA, now_valid_mseB = float('inf'), float('inf')
            # Defaults so a NaN validation MSE is logged as NaN instead of raising NameError.
            mseAi = maeAi = mapeAi = mseBi = maeBi = mapeBi = float('nan')

            train_choice1 = train_choicelist[0]
            trainx = testx[train_choice1]
            trainep = testep[train_choice1]
            n_repeat = args.test_eachsize // len(trainx) + 1
            trainx = trainx.repeat(n_repeat, axis=0)
            trainep = trainep.repeat(n_repeat, axis=0)
            train_dataset = utils.A_Dataset(features=trainx, labels=trainep)

            # Test split: drawn from pool indices outside the training subset.
            mask = np.ones(len(testx), dtype=bool)
            mask[train_choice1] = False
            notrain_idx = np.flatnonzero(mask)
            test_choice1 = np.random.choice(len(notrain_idx), size=args.test_eachsize, replace=False)
            test_idx = notrain_idx[test_choice1]  # map subset positions back to pool indices
            test_dataset = utils.A_Dataset(features=testx[test_idx], labels=testep[test_idx])

            # Validation split: disjoint from both training and test indices. The test positions must be
            # mapped back to pool indices first; they used to be applied to the full mask directly.
            valid_mask = mask.copy()
            valid_mask[test_idx] = False
            valid_notrain_npy = testx[valid_mask]
            valid_notrain_ep = testep[valid_mask]
            valid_choice1 = np.random.choice(len(valid_notrain_npy), size=args.test_eachsize, replace=False)
            valid_dataset = utils.A_Dataset(features=valid_notrain_npy[valid_choice1],
                                            labels=valid_notrain_ep[valid_choice1])

            dataset_loader_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                                              shuffle=False, num_workers=0)
            dataset_loader_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                                               shuffle=False, num_workers=0)
            dataset_loader_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                               shuffle=True, num_workers=0)
            modelA = modelA_list[0]
            modelB = modelB_list[0]
            opt_A = torch.optim.Adam(modelA.parameters(), lr=args.lr)
            opt_B = torch.optim.Adam(modelB.parameters(), lr=args.lr)
            Data_opt1 = torch.optim.Adam(model_D.parameters(), lr=args.lr)
            for j in range(args.epochs_for_finmodels):
                # train_data=False: only A and B are updated; model_D and the pool stay fixed.
                _, _, evaluate_valid = train_model(dataset_loader_valid, modelA, modelB, model_D, args.lr,
                                                   device, opt_A, opt_B, Data_opt1,
                                                   train_loader=dataset_loader_train, train_data=False)
                _, _, evaluate_test = test_model(dataset_loader_test, modelA, modelB, model_D, device)
                # Report the test metrics from the epoch with the best validation MSE so far.
                if now_valid_mseA > evaluate_valid[0][0]:
                    mseAi, maeAi, mapeAi = evaluate_test[0]
                    now_valid_mseA = evaluate_valid[0][0]
                if now_valid_mseB > evaluate_valid[1][0]:
                    mseBi, maeBi, mapeBi = evaluate_test[1]
                    now_valid_mseB = evaluate_valid[1][0]
            rmseA, rmseB = np.sqrt(mseAi), np.sqrt(mseBi)
            now_epoch_result = {'epoch': iii, 'maeA': maeAi,
                                'mapeA': mapeAi, 'maeB': maeBi, 'mapeB': mapeBi,
                                'rmseA': rmseA, 'rmseB': rmseB, 'M-rmse': rmseA - rmseB}
            train_logs.append(now_epoch_result)
            train_logs_df = pd.DataFrame(train_logs)
            train_logs_path = os.path.join(logs_dir, model_name1 + 'testing_fin.csv')
            train_logs_df.to_csv(train_logs_path)
        npmean = np.mean(testx, axis=0, keepdims=True)
        npstd = np.std(testx, axis=0, keepdims=True)
        epmean = np.mean(testep, axis=0, keepdims=True)
        epstd = np.std(testep, axis=0, keepdims=True)
        testx = (testx - npmean) / npstd
        testep = (testep - epmean) / epstd / args.testep_std



