import argparse
import os

import torch
import torch.nn as nn
import pandas as pd
from utils import *
from models.toy import ToyProjector
import sys
import time




# Hyper-parameters of the per-phase fine-tuning loop (previously inlined literals).
_INNER_LR = 0.001
_INNER_WEIGHT_DECAY = 1e-3
_INNER_STEPS = 10


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Only ``seed``, ``task``, ``batch_size``, ``dimension``, ``class_num``,
    ``hidden_dimension``, ``phase_num``, ``confidence`` and the three path
    options are read by this script; the remaining options are kept so that
    existing command lines keep working.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Code for GGF')
    # NOTE(review): --gpu_id is parsed but nothing in this script reads it;
    # every ``.cuda()`` call below targets the default CUDA device.
    parser.add_argument('--gpu_id', type=str, default='0', help="device id to run")
    parser.add_argument('--dimension', type=int, default=8, help="dimension of the features")
    parser.add_argument('--class_num', type=int, default=2, help="number of classes")
    parser.add_argument('--task', type=str, default='portraits')
    # NOTE(review): --lr is not used by the inner optimizer, which keeps its
    # own fixed learning rate (_INNER_LR).
    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
    parser.add_argument('--batch_size', type=int, default=1024, help="batch size")
    parser.add_argument('--hidden_dimension', type=int, default=512)
    parser.add_argument('--entropy_reg', type=float, default=0.01)

    parser.add_argument('--phi1', type=str, default='linear', choices=['linear', 'kl'], help='Choices of phi1 star')
    parser.add_argument('--phi2', type=str, default='kl', choices=['linear', 'kl'], help='Choices of phi2 star')
    parser.add_argument('--regularize', action='store_true', default=True, help='use regularization or not')
    # '--regularize' defaults to True and store_true can only set True, so
    # without this companion flag the option could never be switched off.
    parser.add_argument('--no_regularize', dest='regularize', action='store_false',
                        help='disable regularization')

    parser.add_argument('--sink_reg', type=float, default=1.0e-2,
                        help='the regularization strength of sinkhorn algorithm')
    parser.add_argument('--kl_reg', type=float, default=1.0e-3, help='the regularization strength of kl divergence')

    parser.add_argument('--alpha', type=int, default=4, help="alpha")
    parser.add_argument('--iterations', type=int, default=120, help="iterations = alpha * T")
    parser.add_argument('--lamb', type=float, default=0, help="the weight of two classifier-based energy functions")
    parser.add_argument('--eta1', type=float, default=0.03, help="step size of distribution-based energy functions")
    parser.add_argument('--eta2', type=float, default=0.08, help="step size of classifier-based energy functions")
    parser.add_argument('--eta3', type=float, default=0.01, help="step size of sample-based energy functions")
    parser.add_argument('--confidence', type=float, default=0.00,
                        help="confidence threshold, 0 means finetuning with preserved learning")
    parser.add_argument('--save_path', type=str, default='save/', help="modules path")
    parser.add_argument('--seed', type=int, default=1024, help="random seed")
    parser.add_argument('--dis_coeff', type=float, default=0.0075, help="the coefficient of the discriminator")
    parser.add_argument('--clfr_path', type=str, default='./init_cfr', help='path for initialize classifier')
    parser.add_argument('--gen_path', type=str, default='./trans_map', help='path for generator')
    parser.add_argument('--csv_path', type=str, default='./result_csv(GDA)', help='path for result csv files')
    parser.add_argument('--phase_num', type=int, default=5, help='number of phases')
    parser.add_argument('--top_k_reg', type=int, default=5, help='top k regularization strength')

    return parser.parse_args(argv)


def _transport_source_samples(args, source_data):
    """Push the source samples through the saved per-phase generators.

    For every phase a ``ToyProjector`` is built, its weights are loaded from
    ``args.gen_path`` and it is applied (in eval mode, without gradients) to
    the output of the previous phase.

    Args:
        args: Parsed options; ``gen_path`` must already point at the task
            sub-directory.
        source_data: CUDA tensor holding the source samples.

    Returns:
        A list of ``args.phase_num + 1`` tensors: ``source_data`` followed by
        the output of each phase, in order.
    """
    transported = [source_data]
    current = source_data
    for phase_idx in range(args.phase_num):
        netG = ToyProjector(data_dim=args.dimension, hidden_dim=args.hidden_dimension).cuda()
        weights_file = os.path.join(args.gen_path, f'netG_task_{args.task}_phase_{phase_idx}.pt')
        netG.load_state_dict(torch.load(weights_file))

        netG.eval()
        with torch.no_grad():
            current = netG(current)
        transported.append(current)
    return transported


def main():
    """Fine-tune the initial classifier phase by phase and log target accuracy.

    A copy of the pre-trained classifier is trained for a few Adam steps on
    the source samples and then on each transported version of them. After
    every phase it is evaluated on the target data. One row per phase, plus
    the accumulated training time, is written to
    ``<csv_path>/task_<task>/total_result_<task>.csv``.

    Requires a CUDA device: tensors and models are moved with ``.cuda()``.
    """
    from loguru import logger

    args = parse_args()

    setup_seed(args.seed)
    task_dir = f"task_{args.task}"
    args.csv_path = os.path.join(args.csv_path, task_dir)
    args.gen_path = os.path.join(args.gen_path, task_dir)
    args.clfr_path = os.path.join(args.clfr_path, task_dir)
    construct_path(args.csv_path)

    datapkl = pd.read_pickle('dataset/%s.pkl' % args.task)
    z_all = datapkl['data']
    y_all = datapkl['label']

    # The first entry is treated as the source domain, the last as the target.
    s_data = torch.from_numpy(z_all[0])
    s_label = torch.from_numpy(y_all[0])
    t_data = torch.from_numpy(z_all[-1])
    t_label = torch.from_numpy(y_all[-1])
    target_loader_test = DataLoader(dataset=TensorDataset(t_data, t_label), batch_size=args.batch_size,
                                    shuffle=False, drop_last=False)

    # ``.cuda()`` already yields a fresh device copy, so no extra deepcopy.
    tmp_labels = s_label.cuda()

    classifier = initial_classfier(args.dimension, args.class_num)
    classifier.load_state_dict(torch.load(os.path.join(args.clfr_path, f'netC_task_{args.task}.pt')))

    # The loaded classifier stays untouched; only this copy is fine-tuned.
    tmp_clf = deepcopy(classifier)
    optimizer_inner = optim.Adam([{"params": tmp_clf.parameters(), "lr": _INNER_LR}],
                                 weight_decay=_INNER_WEIGHT_DECAY)

    # Evaluate once and reuse the value for both the log line and the CSV.
    initial_accuracy = eval_model(classifier, target_loader_test)
    logger.warning(f"Initial Accuracy: {initial_accuracy:.5f}")

    tmp_data_list = _transport_source_samples(args, s_data.cuda())

    total_time = 0.0
    result_dataframe_list = []
    for phase_idx, tmp_target in enumerate(tmp_data_list):
        # CUDA kernels are launched asynchronously: synchronise around the
        # timed region so the clock covers the work actually executed.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        logger.info(f"the phase is: {phase_idx}")

        if args.confidence == 0:
            # Keep every sample together with its original source label.
            tmp_pseudo, tmp_labels_pseudo = deepcopy(tmp_target), deepcopy(tmp_labels)
        else:
            # presumably returns the samples (and labels) selected by the
            # classifier's confidence; verify against utils.get_pseudo_dataset
            tmp_pseudo, tmp_labels_pseudo = get_pseudo_dataset(tmp_clf, tmp_target,
                                                               confidence_q=args.confidence)

        for _ in range(_INNER_STEPS):
            pred_t = tmp_clf(tmp_pseudo)
            ce_loss = F.cross_entropy(pred_t, tmp_labels_pseudo)
            optimizer_inner.zero_grad()
            ce_loss.backward()
            optimizer_inner.step()

        torch.cuda.synchronize()
        total_time += time.perf_counter() - start_time

        # Evaluation is deliberately kept outside the timed region.
        acc = eval_model(tmp_clf, target_loader_test)
        temp_dict = {"seed": args.seed, "phase_idx": phase_idx, "init accuracy": initial_accuracy, "acc": acc}
        result_dataframe_list.append(pd.DataFrame(temp_dict, index=[0]))

    total_result_dataframe = pd.concat(result_dataframe_list, axis=0)
    # The same accumulated training time is stored on every row.
    total_result_dataframe["time"] = total_time
    total_result_dataframe.to_csv(os.path.join(args.csv_path, f'total_result_{args.task}.csv'))


if __name__ == "__main__":
    main()

