import argparse
import os
import json
import time
import numpy as np

import torch
from torch import optim
import torch.utils.data

from .model import VAE
from .. import utils
# from model import VAE
# import sys; sys.path.append('../'); import utils
from .dataloader import MyDataLoader
from .metrics import mean_corr_coef
from sklearn.cross_decomposition import CCA

def set_parser():
    """Create the command-line parser for data, model and optimisation options.

    Every option is declared in a single table and registered in table
    order, so the generated ``--help`` listing keeps that order.

    Returns:
        argparse.ArgumentParser: parser ready for ``parse_args()``.
    """
    option_table = (
        # input/output setting
        ('--outdir', dict(type=str, required=True)),
        ('--datadir', dict(type=str, required=True)),
        ('--dataname-train', dict(type=str, default='train')),
        ('--dataname-valid', dict(type=str, default='valid')),

        # prior knowledge
        ('--range-dcoeff', dict(type=float, nargs=2, default=[5e-3, 2e-1])),

        # model (general)
        ('--dim-z-aux1', dict(
            type=int, required=True,
            help="if 0, aux1 is still alive without latent variable; set -1 to deactivate")),
        ('--dim-z-aux2', dict(
            type=int, required=True,
            help="if 0, aux2 is still alive without latent variable; set -1 to deactivate")),
        # intended values: relu, leakyrelu, elu, softplus, prelu (not enforced)
        ('--activation', dict(type=str, default='elu')),
        ('--ode-solver', dict(type=str, default='euler')),
        ('--intg-lev', dict(type=int, default=1)),
        ('--no-phy', dict(action='store_true', default=False)),

        # model (decoder)
        ('--x-lnvar', dict(type=float, default=-10.0)),
        ('--hidlayers-aux1-dec', dict(type=int, nargs='+', default=[128])),
        ('--hidlayers-aux2-dec', dict(type=int, nargs='+', default=[128])),

        # model (encoder)
        ('--hidlayers-aux1-enc', dict(type=int, nargs='+', default=[128])),
        ('--hidlayers-aux2-enc', dict(type=int, nargs='+', default=[128])),
        ('--hidlayers-unmixer', dict(type=int, nargs='+', default=[128])),
        ('--hidlayers-dcoeff', dict(type=int, nargs='+', default=[128])),
        ('--arch-feat', dict(type=str, default='mlp')),
        ('--num-units-feat', dict(type=int, default=256)),
        ('--hidlayers-feat', dict(type=int, nargs='+', default=[256])),
        ('--num-rnns-feat', dict(type=int, default=1)),

        # optimization (base)
        ('--learning-rate', dict(type=float, default=1e-3)),
        ('--weight-decay', dict(type=float, default=1e-3)),
        ('--adam-eps', dict(type=float, default=1e-3)),
        ('--grad-clip', dict(type=float, default=10.0)),
        ('--batch-size', dict(type=int, default=200)),
        ('--epochs', dict(type=int, default=5000)),
        ('--balance-kld', dict(type=float, default=1.0)),
        ('--balance-unmix', dict(type=float, default=0.0)),
        ('--balance-dataug', dict(type=float, default=0.0)),
        ('--balance-lact-dec', dict(type=float, default=0.0)),
        ('--balance-lact-enc', dict(type=float, default=0.0)),

        # others
        ('--train-size', dict(type=int, default=-1)),
        ('--save-interval', dict(type=int, default=999999999)),
        ('--num-workers', dict(type=int, default=0)),
        ('--cuda', dict(action='store_true', default=False)),
        ('--seed', dict(type=int, default=1234567890)),
    )

    parser = argparse.ArgumentParser(description='')
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser


if __name__ == '__main__':
    # Evaluation entry point: restore a trained VAE checkpoint and, for each
    # test batch, print
    #   1. the per-sample summed squared reconstruction error (batch mean),
    #   2. the MSE between the encoder's dcoeff mean and param[:, 0],
    #   3. the CCA-based mean correlation coefficient between param and the
    #      concatenated latent means,
    #   4. the per-sample summed squared error when decoding the last context
    #      sequence from its own initial slice,
    # and save ground truth / predictions of both decodings to *.pt files.
    parser = set_parser()
    args = parser.parse_args()
    # args.cuda is forwarded to the model through vars(args) below.
    args.cuda = args.cuda and torch.cuda.is_available()
    # Previously hard-coded to 'cuda:0', which crashes on CPU-only hosts.
    # Keep the GPU as the default when one exists; otherwise use the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # set random seed
    torch.manual_seed(args.seed)
    # Fixed problem geometry, not exposed on the command line.
    # NOTE(review): presumably must match the values used for training — confirm.
    args.dim_x = 20
    args.dim_t = 50
    args.dx = 0.1
    args.dt = 0.02
    # NOTE(review): the meaning of the positional arguments (8, 5000, False)
    # is defined by MyDataLoader and is not visible here.
    loader_test = MyDataLoader('{}/data_{}.pt'.format(args.datadir, 'test'), 8, 5000, False)
    model = VAE(vars(args)).to(device)
    # NOTE(review): checkpoint path is hard-coded; --outdir is parsed but unused.
    model.load_state_dict(torch.load('output1/model.pt', map_location=device))
    model.eval()
    # NOTE(review): every output file below is overwritten on each batch, so
    # only the last batch is kept if the loader yields more than one.
    for data, context, param, _ in loader_test:
        # The loader's data/context are replaced by fixed tensors from disk,
        # while param still comes from the loader batch.
        # NOTE(review): these only line up if data.pt/context.pt were saved
        # from this same batch — confirm.
        # map_location lets CUDA-saved tensors load on a CPU-only host.
        data = torch.load('data.pt', map_location=device).to(device)
        context = torch.load('context.pt', map_location=device).to(device)
        param = param.to(device)

        # Stack data in front of the context sequences along dim 1; the last
        # entry (== context[:, -1]) is withheld from the encoder.
        all_data = torch.cat((data.unsqueeze(1), context), dim=1)
        dcoeff_stat, z_aux1_stat, z_aux2_stat, _ = model.encode(all_data[:, :-1])
        dcoeff, z_aux1, z_aux2 = model.draw(dcoeff_stat, z_aux1_stat, z_aux2_stat, hard_z=False)

        # 1. Reconstruction: decode starting from data[:, :, 0].
        init_y = data[:, :, 0].clone()
        x_mean, _, _, _, _ = model.decode(dcoeff, z_aux1, z_aux2, init_y, full=True)
        torch.save(data.detach().cpu(), 'GT-Meta-Hybrid-Rec.pt')
        torch.save(x_mean.detach().cpu(), 'XT-Meta-Hybrid-Rec.pt')
        print('%.5f' % torch.sum((x_mean - data).pow(2), dim=[1, 2]).mean())

        # 2. Squared error of the dcoeff posterior mean against param[:, 0].
        print('%.7f' % torch.mean((dcoeff_stat['mean'] - param[:, [0]]).pow(2)))

        # 3. CCA between the true parameters and the concatenated latent means.
        z_aux = torch.cat((dcoeff_stat['mean'], z_aux1_stat['mean'], z_aux2_stat['mean']), dim=-1)
        param_np = param.detach().cpu().numpy()
        z_aux_np = z_aux.detach().cpu().numpy()
        cca = CCA(n_components=2, max_iter=5000)
        cca.fit(param_np, z_aux_np)
        res_in = cca.transform(param_np, z_aux_np)
        mcc_weak_in = mean_corr_coef(res_in[0], res_in[1])
        print(mcc_weak_in)

        # 4. Generation: decode the withheld last context sequence from its
        #    own initial slice, reusing the latents inferred above.
        init_y = context[:, -1, :, 0].clone()
        x_mean, _, _, _, _ = model.decode(dcoeff, z_aux1, z_aux2, init_y, full=True)
        print('%.5f' % torch.sum((x_mean - all_data[:, -1]).pow(2), dim=[1, 2]).mean())
        torch.save(context[:, -1].detach().cpu(), 'GT-Meta-Hybrid-Gen.pt')
        torch.save(x_mean.detach().cpu(), 'XT-Meta-Hybrid-Gen.pt')




