import torch
import torch.nn as nn
import functorch
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split

import argparse
import os
import pandas as pd
from models.toy import *
from utils import select_phi, initial_classfier, eval_model, construct_path
from copy import deepcopy

import json



def train_classifier_model(model, train_loader, val_loader, epochs, lr=0.001, cuda=True):
    """Train ``model`` with SGD (momentum 0.9) on a cross-entropy objective.

    Args:
        model: Module whose ``forward`` maps an input batch to class logits.
        train_loader: Iterable yielding ``(inputs, labels)`` pairs.
        val_loader: Optional iterable passed to ``eval_model`` every second
            epoch; ``None`` disables validation and checkpoint selection.
        epochs: Number of passes over ``train_loader``.
        lr: SGD learning rate.
        cuda: When true, each batch is moved to the GPU with ``.cuda()``.

    Returns:
        With ``val_loader`` set: a deep copy of the model taken at the
        validation step with the highest accuracy, or ``model`` itself when no
        validation step ever improved on the initial score of 0 (for example
        when ``epochs < 2``). Without ``val_loader``: ``model`` itself.

    Raises:
        ValueError: If a batch is not a two-element ``(inputs, labels)`` pair.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    best_val = 0
    # Fall back to the live model so the return below is always bound, even if
    # validation never runs or never beats the initial best_val.
    best_model = model
    for epoch in range(epochs):
        for src_data in train_loader:
            if len(src_data) != 2:
                raise ValueError(
                    f"expected (inputs, labels) batches, got {len(src_data)} elements"
                )
            s_x, s_y = src_data
            if cuda:
                s_x = s_x.cuda()
                s_y = s_y.cuda()
            s_pred = model.forward(s_x)
            ce_loss = torch.nn.functional.cross_entropy(s_pred, s_y.long())
            optimizer.zero_grad()
            ce_loss.backward()
            optimizer.step()
        # Validate on every second epoch only.
        if val_loader is not None and (epoch + 1) % 2 == 0:
            acc_val = eval_model(model, val_loader)
            if best_val < acc_val:
                best_val = acc_val
                best_model = deepcopy(model)
    if val_loader is not None:
        return best_model
    return model


def cost_matrix_sqeuclidean(x, y):
    """Return the pairwise squared Euclidean distances between rows of x and y.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so for
    ``x`` of shape [n, d] and ``y`` of shape [m, d] the result is [n, m].
    """
    row_sq_norms = torch.square(x).sum(dim=1, keepdim=True)      # [n, 1]
    col_sq_norms = torch.square(y).sum(dim=1, keepdim=True).t()  # [1, m]
    inner_products = torch.matmul(x, y.t())                      # [n, m]
    return row_sq_norms - 2 * inner_products + col_sq_norms


def _fit_potential(netD, optimizerD, phi2, x_src_covariate, x_tar_covariate, entropy_reg, num_steps):
    """Run ``num_steps`` optimizer updates of the potential on one batch pair and return the summed loss."""
    # The cost matrix and log(By) depend only on the fixed batch, so they are
    # computed once instead of on every optimisation step.
    C = torch.cdist(x_src_covariate, x_tar_covariate, p=2) ** 2  # [Bx, By]
    log_num_tar = torch.log(
        torch.tensor(x_tar_covariate.shape[0], dtype=torch.float, device=x_tar_covariate.device)
    )

    total_D_loss = 0.0
    for _ in range(num_steps):
        pred_result = netD(x_tar_covariate)
        dual_result = phi2(-pred_result).mean()

        # NOTE(review): the broadcast against C looks like it expects netD to
        # return a 1-D tensor of length By; if it returns [By, 1] the
        # subtraction broadcasts differently -- confirm against ToyDiscriminator.
        logits = (pred_result.unsqueeze(0) - C) / entropy_reg

        # log-mean-exp over the target axis, scaled by the entropy weight.
        log_mean_exp = torch.logsumexp(logits, dim=1) - log_num_tar
        loss_soft = (entropy_reg * log_mean_exp).mean()

        total_loss = loss_soft + dual_result

        optimizerD.zero_grad()
        total_loss.backward()
        optimizerD.step()
        total_D_loss += total_loss.item()
    return total_D_loss


def _fit_transport_map(netG, optimizerG, netD, x_src_covariate, num_steps):
    """Run ``num_steps`` optimizer updates of the transport map on one source batch and return the summed loss."""
    total_G_loss = 0.0
    for _ in range(num_steps):
        # Push the source batch through the current map.
        x_src_pred = netG(x_src_covariate)

        displacement = (x_src_pred - x_src_covariate).view(x_src_covariate.size(0), -1)
        # Squared displacement minus the (frozen) potential at the mapped point.
        # NOTE(review): if netD returns [B, 1] this broadcasts to [B, B]; the
        # mean is unaffected, but confirm the intended shape.
        cost_term = torch.sum(displacement ** 2, dim=1) - netD(x_src_pred)
        loss = cost_term.mean()

        optimizerG.zero_grad()
        loss.backward()
        optimizerG.step()
        total_G_loss += loss.item()
    return total_G_loss


def _transport_source(netG, src_dataloader, batch_size):
    """Map every source batch through ``netG`` and return a shuffled DataLoader over the moved samples."""
    netG.eval()
    src_covariate_list, src_label_list = [], []
    with torch.no_grad():
        # Iterate the source loader on its own: zipping it with the target
        # loader would silently drop source batches whenever the target loader
        # is shorter, shrinking the dataset from phase to phase.
        for x_src in src_dataloader:
            src_covariate_list.append(netG(x_src[0].float().cuda()).cpu())
            src_label_list.append(x_src[1].float())
    s_dataset = TensorDataset(torch.cat(src_covariate_list), torch.cat(src_label_list))
    return DataLoader(dataset=s_dataset, batch_size=batch_size, shuffle=True, drop_last=False)


def train_esuot(cfg, src_dataloader, tar_dataloader):
    """Train a sequence of potential / transport-map pairs, one pair per phase.

    In each phase a fresh ``ToyDiscriminator`` (potential) is fitted on paired
    source/target batches, frozen, and then used to fit a fresh
    ``ToyProjector`` (transport map). The source samples are then pushed
    through that map and become the source data of the next phase.

    All tensors and networks are moved to the GPU with ``.cuda()``, so a CUDA
    device is required.

    Args:
        cfg: Configuration object; the attributes read here are ``phase_num``,
            ``entropy_reg``, ``phi2``, ``dimension``, ``hidden_dimension``,
            ``lr``, ``epochs`` and ``batch_size``.
        src_dataloader: Loader yielding ``(covariates, labels)`` source batches.
        tar_dataloader: Loader yielding target batches; only element 0
            (covariates) of each batch is used.

    Returns:
        ``(netG_list, netD_list, netG_loss_dict, netD_loss_dict)``. The lists
        hold one network per phase. Each dict maps ``"phase_<i>"`` to a list
        with one entry per batch, each entry being the loss summed over the
        ``cfg.epochs`` optimisation steps taken on that batch.
    """
    # Imported here so the function also works when this module is imported
    # rather than run as a script (the script entry point binds its own logger).
    from loguru import logger

    number_of_phase = cfg.phase_num
    entropy_reg = cfg.entropy_reg
    phi2 = select_phi(cfg.phi2)

    netD_list, netG_list = [], []
    netD_loss_dict, netG_loss_dict = {}, {}

    for phase_idx in range(number_of_phase):

        logger.info(f"we start training phase: {phase_idx}")

        # --- potential function ---
        netD = ToyDiscriminator(data_dim=cfg.dimension, hidden_dim=cfg.hidden_dimension).cuda()
        optimizerD = torch.optim.Adam(netD.parameters(), lr=cfg.lr)
        netD_loss_list = []

        # Note: cfg.epochs counts optimisation steps per batch pair, not
        # passes over the loaders; each loader pair is traversed once.
        for x_src, x_tar in zip(src_dataloader, tar_dataloader):
            x_src_covariate = x_src[0].float().cuda()
            x_tar_covariate = x_tar[0].float().cuda()
            netD_loss_list.append(
                _fit_potential(
                    netD, optimizerD, phi2,
                    x_src_covariate, x_tar_covariate,
                    entropy_reg, cfg.epochs,
                )
            )
        netD_loss_dict[f"phase_{phase_idx}"] = netD_loss_list

        # --- transport map; the potential is frozen from here on ---
        for params in netD.parameters():
            params.requires_grad = False
        netD.eval()

        netG = ToyProjector(data_dim=cfg.dimension, hidden_dim=cfg.hidden_dimension).cuda()
        optimizerG = torch.optim.Adam(netG.parameters(), lr=cfg.lr)
        netG_loss_list = []

        # Still zipped with the target loader so the map sees the same number
        # of batches as the potential did; the target batch itself is unused.
        for x_src, _ in zip(src_dataloader, tar_dataloader):
            x_src_covariate = x_src[0].float().cuda()
            netG_loss_list.append(
                _fit_transport_map(netG, optimizerG, netD, x_src_covariate, cfg.epochs)
            )
        netG_loss_dict[f"phase_{phase_idx}"] = netG_loss_list

        netD_list.append(netD)
        netG_list.append(netG)

        # Update the source locations for the next phase.
        src_dataloader = _transport_source(netG, src_dataloader, cfg.batch_size)

    return netG_list, netD_list, netG_loss_dict, netD_loss_dict









if __name__ == "__main__":
    # Script entry point: parse options, train an initial classifier on the
    # first domain of the dataset, then train the phase-wise transport maps
    # from the first domain to the last one and save every trained network.

    # Bound at module level on purpose: train_esuot looks up `logger` globally.
    from loguru import logger
    from utils import setup_seed

    parser = argparse.ArgumentParser(description='Code for ESUOT approach')
    parser.add_argument('--gpu_id', type=str, default='0', help="device id to run")
    parser.add_argument('--dimension', type=int, default=8, help="dimension of the features")
    parser.add_argument('--class_num', type=int, default=2, help="number of classes")
    parser.add_argument('--task', type=str, default='portraits')

    parser.add_argument('--hidden_dimension', type=int, default=512)
    parser.add_argument('--entropy_reg', type=float, default=0.1)

    # Raw string: "\s" is not a valid escape sequence in a normal literal.
    parser.add_argument('--phi2', type=str, default='kl', choices=['linear', 'kl'], help=r'Choices of $f^{\star}$')

    parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")
    parser.add_argument('--batch_size', type=int, default=1024, help="batch size")
    parser.add_argument('--save_path', type=str, default='save/', help="modules path")
    parser.add_argument('--epoch1', type=int, default=2000, help="epoch for training initial classifier")
    parser.add_argument('--epoch2', type=int, default=50000, help="epoch for training score network")
    parser.add_argument('--epoch3', type=int, default=50000, help="epoch for training rectified flow")
    parser.add_argument('--savepath', type=str, default="train_logs/uot_sample")
    parser.add_argument('--seed', type=int, default=4096, help="random seed")
    parser.add_argument('--dis_coeff', type=float, default=0.0075, help="the coefficient of the discirminator")
    parser.add_argument('--lmbda', type=float, default=0.01, help='regularization hyperparameter')
    parser.add_argument('--epochs', type=int, default=1000, help='The number of epochs')
    parser.add_argument('--regularize', action='store_true', default=True, help='use regularization or not')
    parser.add_argument('--tau', type=float, default=0.10, help='scalar value multiplied to quadratic cost functional')
    parser.add_argument('--phase_num', type=int, default=5, help='number of phases')

    parser.add_argument('--clfr_path', type=str, default='./init_cfr', help='path for initialize classifier')

    parser.add_argument('--gen_path', type=str, default='./trans_map', help='path for generator')

    args = parser.parse_args()

    # Restrict the visible GPUs before anything that might initialise CUDA
    # (setup_seed is defined elsewhere and may touch the CUDA RNG).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    setup_seed(args.seed)

    # Keep the outputs of different tasks in separate sub-directories.
    args.gen_path = args.gen_path + f"/task_{args.task}"
    args.clfr_path = args.clfr_path + f"/task_{args.task}"

    construct_path(args.clfr_path)
    construct_path(args.gen_path)

    batch_size = args.batch_size
    dimension = args.dimension

    # SECURITY: unpickling executes arbitrary code; only load trusted dataset files.
    datapkl = pd.read_pickle(f'dataset/{args.task}.pkl')

    z_all = datapkl['data']
    y_all = datapkl['label']

    logger.info(f"the dataset length: {type(datapkl['data'][0])}, shape: {datapkl['data'][0].shape}, shape: {datapkl['data'][1].shape}, the type: {len(datapkl['data'])}.")

    logger.info("let us start playing!")

    # The first entry is used as the source domain and the last as the target.
    s_data = torch.from_numpy(z_all[0])
    s_label = torch.from_numpy(y_all[0])
    t_data = torch.from_numpy(z_all[-1])
    t_label = torch.from_numpy(y_all[-1])
    s_dataset = TensorDataset(s_data, s_label)
    logger.info(f"the s data and s label shapes: {s_data.shape}, {s_label.shape}")
    t_dataset = TensorDataset(t_data, t_label)
    source_loader = DataLoader(dataset=s_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    target_loader = DataLoader(dataset=t_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    source_loader_test = DataLoader(dataset=s_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    target_loader_test = DataLoader(dataset=t_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    # Initial classifier: trained on the source loader, which also serves as
    # the validation loader for checkpoint selection.
    classifier = initial_classfier(dimension, args.class_num)
    classifier = train_classifier_model(classifier, source_loader, source_loader, args.epoch1)
    logger.warning(f"Initial Accuracy: {eval_model(classifier, target_loader_test)}")
    torch.save(classifier.state_dict(), os.path.join(args.clfr_path, f'netC_task_{args.task}.pt'))

    # Train one transport map per phase; only the generators are saved.
    netG_list, _, _, _ = train_esuot(cfg=args, src_dataloader=source_loader, tar_dataloader=target_loader)

    for phase_idx, netG in enumerate(netG_list):
        torch.save(netG.state_dict(), os.path.join(args.gen_path, f'netG_task_{args.task}_phase_{phase_idx}.pt'))

