import copy
import torch
import time
import logging
import argparse
import os
# Raise the root logger's threshold to CRITICAL so that loggers inheriting
# their level from it stay quiet during training.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# from reptile import Learner
from methods import adaminimax, sgda, pdsm, tiada

from data_loader import Sent140Dataset,  collate_pad
from torch.utils.data import DataLoader
import random
import numpy as np
# Disable cuDNN for the whole process at import time.
# NOTE(review): the reason is not stated in this file — presumably needed by the
# learners' RNN / second-order gradient code; confirm before re-enabling.
torch.backends.cudnn.enabled = False
def random_seed(value):
    """Seed every random number generator used by this script.

    Switches cuDNN into deterministic mode, then applies ``value`` to the
    PyTorch CPU generator, the PyTorch CUDA generator, NumPy's global
    generator and Python's ``random`` module, in that order.

    Args:
        value: Seed shared by all generators.
    """
    torch.backends.cudnn.deterministic = True
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(value)

def ImbalanceGenerator(data, imratio=0.2, split=True, max_neg=1000):
    """Subsample ``data`` in place into a class-imbalanced binary dataset.

    The examples are shuffled, the first ``max_neg`` examples labelled 0 are
    kept, and the first ``imratio / (1 - imratio) * n_neg`` examples labelled 1
    are kept (fewer if not enough exist), so that label-1 examples make up
    about ``imratio`` of the result. The kept examples are shuffled again and
    written back to ``data.sentences`` / ``data.labels``.

    Randomness comes from NumPy's global generator (two ``np.random.shuffle``
    calls), so seed it beforehand for reproducible output.

    Args:
        data: Object exposing parallel ``sentences`` and ``labels`` sequences.
            It is modified in place.
        imratio: Target fraction of label-1 examples; must satisfy
            ``0 <= imratio < 1``.
        split: When true, the kept examples are divided into two halves:
            ``data`` keeps the first half and a deep copy of ``data`` holding
            the second half is returned alongside it.
        max_neg: Maximum number of label-0 examples to keep.

    Returns:
        ``(data, val)`` when ``split`` is true, otherwise ``data``.

    Raises:
        ValueError: If ``imratio`` is outside ``[0, 1)``. Such values would
            otherwise divide by zero or produce a negative slice bound that
            silently keeps the wrong number of positives.
    """
    if not 0 <= imratio < 1:
        raise ValueError(f"imratio must be in [0, 1), got {imratio}")

    # First shuffle: decides which examples survive the per-class truncation.
    order = list(range(len(data.sentences)))
    np.random.shuffle(order)
    shuffled_x = [data.sentences[i] for i in order]
    shuffled_y = [data.labels[i] for i in order]

    label_arr = np.array(shuffled_y)
    neg_idx = np.where(label_arr == 0)[0][:max_neg]
    keep_num_pos = int((imratio / (1 - imratio)) * neg_idx.shape[0])
    pos_idx = np.where(label_arr == 1)[0][:keep_num_pos]

    # Second shuffle: interleave the kept negatives and positives.
    kept = neg_idx.tolist() + pos_idx.tolist()
    np.random.shuffle(kept)
    sentences = [shuffled_x[i] for i in kept]
    labels = [shuffled_y[i] for i in kept]

    if not split:
        data.sentences = sentences
        data.labels = labels
        return data

    half = len(labels) // 2
    data.sentences = sentences[:half]
    data.labels = labels[:half]
    # Copy only after shrinking: deep-copying the untouched dataset would
    # duplicate every example just to throw most of them away.
    val = copy.deepcopy(data)
    val.sentences = sentences[half:]
    val.labels = labels[half:]
    return data, val

def main():
    """Run one training/evaluation experiment configured from the command line.

    Steps:
      1. Parse the command-line options and seed the random generators.
      2. Load the train/validation/test datasets from the ``data/`` cache, or
         build them (and populate the cache) when any cache file is missing.
      3. Build the learner selected by ``--methods``. Each method overrides
         some hyper-parameters with fixed values, regardless of what was
         passed on the command line.
      4. For ``--epoch`` epochs: train on the train/validation loaders, then
         evaluate on the test loader, collecting AUC and loss per epoch.
      5. Write the collected histories to ``logs/ada``.

    Raises:
        SystemExit: If ``--data`` or ``--methods`` names an unsupported option.
    """

    def str2bool(text):
        # argparse's ``type=bool`` calls bool() on the raw string, so any
        # non-empty value (including "False") would parse as True.
        lowered = str(text).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser()

    parser.add_argument("--data", default='sentment140', type=str,
                        help="dataset: [news_data, snli, sentment140]")
    parser.add_argument("--data_path", default='../data/news-data/dataset.json', type=str,
                        help="Path to dataset file")
    parser.add_argument("--batch_size", default=32, type=int,
                        help="batch_size")
    parser.add_argument("--test_batch_size", default=32, type=int,
                        help="test batch size")
    parser.add_argument("--save_direct", default='sentiment140', type=str,
                        help="Path to save file")
    parser.add_argument("--methods", default='adaminimax', type=str,
                        help="choice method [sgda, pdsm, tiada, adaminimax]")
    parser.add_argument("--num_labels", default=2, type=int,
                        help="Number of class for classification")
    parser.add_argument("--epoch", default=50, type=int,
                        help="Number of outer interation")
    parser.add_argument("--inner_batch_size", default=32, type=int,
                        help="Training batch size in inner iteration")
    parser.add_argument("--neumann_lr", default=1e-2, type=float,
                        help="update for neumann series")
    parser.add_argument("--hessian_q", default=3, type=int,
                        help="Q steps for hessian-inverse-vector product")
    parser.add_argument("--outer_update_lr", default=1e-1, type=float,
                        help="Meta learning rate")
    parser.add_argument("--inner_update_lr", default=1e-1, type=float,
                        help="Inner update learning rate")
    parser.add_argument("--inner_update_step", default=3, type=int,
                        help="Number of interation in the inner loop during train time")
    parser.add_argument("--gamma", default=1e-3, type=float,
                        help="clipping threshold")
    parser.add_argument("--seed", default=2, type=int,
                        help="random seed")
    parser.add_argument("--beta", default=0.9, type=float,
                        help="momentum parameters")
    parser.add_argument("--imratio", default=0.9, type=float,
                        help="The ratio of imbalance")

    # RNN hyperparameter settings
    parser.add_argument("--word_embed_dim", default=300, type=int,
                        help="word embedding dimensions")
    parser.add_argument("--encoder_dim", default=4096, type=int,
                        help="encodding dimensions")
    parser.add_argument("--n_enc_layers", default=2, type=int,
                        help="encoding layers")
    parser.add_argument("--fc_dim", default=1024, type=int,
                        help="dimension of fully-connected layer")
    parser.add_argument("--n_classes", default=2, type=int,
                        help="classes of targets")
    parser.add_argument("--linear_fc", default=False, type=str2bool,
                        help="boolean linear_fc flag passed to the model (true/false)")
    parser.add_argument("--pool_type", default="max", type=str,
                        help="type of pooling")
    parser.add_argument("--noise_rate", default=0.0, type=float,
                        help="rate for label noise")
    parser.add_argument("--alpha", default=1.0, type=float,
                        help="parameter for adaminimax")
    parser.add_argument("--power_alpha", default=0.6, type=float,
                        help="power parameter for tiada")

    args = parser.parse_args()

    # Fail fast on unsupported choices instead of crashing later on an
    # unbound ``train`` / ``learner`` after the (slow) dataset build.
    if args.data != 'sentment140':
        raise SystemExit('Do not support this data')
    if args.methods not in ('pdsm', 'sgda', 'tiada', 'adaminimax'):
        raise SystemExit('No such method, please change the method name!')

    random_seed(args.seed)

    train_path = f'data/train_data_{args.imratio}'
    val_path = f'data/val_data_{args.imratio}'
    test_path = f'data/test_data_{args.imratio}'
    # Require all three cache files so a partially written cache is rebuilt
    # rather than failing on the missing file.
    if all(os.path.isfile(path) for path in (train_path, val_path, test_path)):
        print('loading data...')
        # NOTE: these files are pickled dataset objects, so only load caches
        # produced by this script. Recent PyTorch releases default torch.load
        # to weights_only=True, which may refuse them — TODO confirm against
        # the installed version.
        train = torch.load(train_path)
        val = torch.load(val_path)
        test = torch.load(test_path)
    else:
        os.makedirs('data', exist_ok=True)
        trainset = Sent140Dataset("../../data", "train", noise_rate=args.noise_rate)
        train, val = ImbalanceGenerator(trainset, imratio=args.imratio)
        torch.save(train, train_path)
        torch.save(val, val_path)
        test = Sent140Dataset("../../data", "test")
        torch.save(test, test_path)
    args.n_labels = 2
    args.n_classes = 2

    st = time.time()

    # Each method pins its own hyper-parameters; these assignments override
    # whatever was given on the command line.
    if args.methods == 'pdsm':
        args.outer_update_lr = 1e-2
        args.inner_update_lr = 1e-1
        learner = pdsm.Learner(args)
    elif args.methods == 'sgda':
        args.outer_update_lr = 1e-1
        args.inner_update_lr = 5e-2
        # larger batch size
        args.inner_batch_size = int(args.inner_batch_size * 2)
        args.batch_size = int(args.batch_size * 2)
        # Previously built pdsm.Learner here, leaving the imported sgda
        # module unused.
        learner = sgda.Learner(args)
    elif args.methods == 'tiada':
        args.outer_update_lr = 1e-1
        args.inner_update_lr = 5e-2
        args.power_alpha = 0.6
        learner = tiada.Learner(args)
    else:  # 'adaminimax' (method name validated above)
        args.outer_update_lr = 1e-2
        args.inner_update_lr = 1e-2
        args.gamma = 0.1
        args.alpha = 0.5
        learner = adaminimax.Learner(args)

    print(args)

    # The loaders do not change between epochs; shuffling loaders draw a new
    # order every time they are iterated.
    train_loader = DataLoader(train, shuffle=True, batch_size=args.inner_batch_size, collate_fn=collate_pad)
    val_loader = DataLoader(val, shuffle=True, batch_size=args.batch_size, collate_fn=collate_pad)
    test_loader = DataLoader(test, batch_size=args.test_batch_size, collate_fn=collate_pad)

    auc_all_test = []
    loss_all_test = []
    auc_all_train = []
    loss_all_train = []
    for epoch in range(args.epoch):
        print(f"[epoch/epochs]:{epoch}/{args.epoch}")
        auc, loss = learner(train_loader, val_loader, training=True, epoch=epoch)
        auc_all_train.append(round(auc, 4))
        loss_all_train.append(round(loss, 4))
        print('training Loss:', loss_all_train)
        print('training Auc:', auc_all_train)
        print("---------- Testing Mode -------------")

        auc, loss = learner.test(test_loader)
        auc_all_test.append(round(auc, 4))
        loss_all_test.append(round(loss, 4))

        print(f'{args.methods} Test loss:, {loss_all_test}')
        print(f'{args.methods} Test auc:, {auc_all_test}')

    file_name = f'{args.methods}_outlr{args.outer_update_lr}_inlr{args.inner_update_lr}_alpha{args.alpha}_gamma{args.gamma}_seed{args.seed}_grad_diff'
    save_path = 'logs/ada'
    # makedirs also creates the missing parent 'logs'; a plain mkdir would
    # raise here and lose the results of the finished run.
    os.makedirs(save_path, exist_ok=True)
    total_time = (time.time() - st) / 3600
    with open(os.path.join(save_path, file_name) + '.txt', 'w') as report:
        report.write(str({'Exp configuration': str(args), 'AVG Train AUC': str(auc_all_train),
                          'AVG Test AUC': str(auc_all_test), 'AVG Train LOSS': str(loss_all_train),
                          'AVG Test LOSS': str(loss_all_test), 'time': total_time}))
    torch.save((auc_all_train, auc_all_test, loss_all_train, loss_all_test), os.path.join(save_path, file_name))
    print(args)
    print(f'time:{total_time} h')
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
