from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AdamW, get_scheduler
from datasets import load_metric, load_dataset, get_dataset_config_names
import numpy as np
from torch.utils.data import DataLoader
from train_eval import train, setup_seed
import torch
import os
from utils import split_train_val
from nlp_dataset import get_dataset, get_val_set, get_num_class
from args import get_args
from utils import PathConfig, generate_seed_set, read_config


# Path helper used below to locate the per-dataset configuration directory.
PC = PathConfig()

# Options for this run are obtained once, at import time, and echoed to stdout.
args = get_args()
print(args)

# The dataset-specific config file is named "<dst_name>.yaml" and is appended
# directly to the string returned by PC.get_dataset_config_path().
_dataset_cfg_path = PC.get_dataset_config_path() + (args.dst_name + '.yaml')
cfg = read_config(cfg_path=_dataset_cfg_path)


def compute_metrics(eval_pred):
    """Score predictions with the macro-averaged ``"f1"`` metric.

    Args:
        eval_pred: A two-item iterable ``(logits, labels)``. The predicted
            class of each sample is the arg-max of ``logits`` over its last
            axis.

    Returns:
        Whatever ``metric.compute`` returns for the ``"f1"`` metric with
        ``average="macro"`` (presumably a dict holding the score -- confirm
        against the ``datasets`` version in use).
    """
    # The metric object is loaded anew on every call.
    scorer = load_metric("f1")
    scores, references = eval_pred
    return scorer.compute(
        predictions=np.argmax(scores, axis=-1),
        references=references,
        average="macro",
    )


def one_round_training(seed):
    """Fine-tune one model per fold and return the fold-averaged F1 scores.

    The validation strategy comes from ``args.val_method`` and the number of
    folds from ``args.k``. For every fold a fresh ``distilbert-base-uncased``
    sequence classifier is created, trained through ``train`` and discarded.

    Args:
        seed: Seed forwarded to ``split_train_val`` and ``get_val_set``. It
            only drives how the train/validation data is built; model
            initialisation and training are seeded separately with
            ``cfg["training_seed"]``.

    Returns:
        A tuple ``(mean_test_f1, mean_val_f1, mean_f1_bias)`` averaged over
        the ``args.k`` folds.

    Raises:
        ValueError: If ``args.val_method`` is not one of the supported
            validation methods.
        RuntimeError: If the bias returned by ``train`` does not equal
            ``|test_f1 - val_f1|`` (within 1e-7).
    """
    # Fall back to the CPU so the script still runs on hosts without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model_name = "distilbert-base-uncased"
    dst_name = args.dst_name
    num_classes = get_num_class(dst_name)
    num_k = args.k
    val_method = args.val_method

    # Methods for which a train/validation index split is computed up front.
    # NOTE(review): 'LZO' never reads the split, but the call is kept for it
    # in case split_train_val has side effects -- confirm before removing.
    split_methods = ('holdout', 'kfold', 'jkfold', 'LZO', 'split_free_holdout')
    # split_free_* methods whose validation set is produced by get_val_set,
    # mapped to the generation method passed to it.
    val_method_map = {'split_free_joint': 'DB_ADJOINT',
                      'split_free_random': 'RANDOM',
                      'split_free_holdout': 'DB_ADJOINT'}
    supported_methods = (set(split_methods) | set(val_method_map)
                         | {'split_free_test', 'split_free_noval'})
    # Fail fast with a clear message instead of a late NameError/KeyError
    # after the dataset has already been loaded.
    if val_method not in supported_methods:
        raise ValueError("unsupported val_method %r, expected one of %s"
                         % (val_method, sorted(supported_methods)))

    test_f1_list = []
    val_f1_list = []
    f1_bias_list = []

    whole_train_dst, test_dst = get_dataset(dst_name, model_name)
    if val_method in split_methods:
        train_val_index_list = split_train_val(list(range(len(whole_train_dst))), whole_train_dst['labels'],
                                               seed=seed, k=num_k, val_ratio=0.2)
    for fold in range(num_k):
        if 'split_free' in val_method:
            # Training data: only the hold-out variant trains on a subset.
            if val_method == 'split_free_holdout':
                train_dst = whole_train_dst.select(train_val_index_list[fold][0])
            else:
                train_dst = whole_train_dst
            # Validation data: test set, training set itself, or a set
            # produced by get_val_set.
            if val_method == 'split_free_test':
                val_dst = test_dst
            elif val_method == 'split_free_noval':
                val_dst = whole_train_dst
            else:
                val_dst = get_val_set(dst_name, model_name, method=val_method_map[val_method], seed=seed,
                                      val_num_per_class=cfg['val_num_per_class'])
        elif val_method == 'LZO':
            train_dst = whole_train_dst
            val_dst = get_val_set(dst_name, model_name, method='LZO', seed=seed)
        else:
            # holdout / kfold / jkfold: use the index pair of this fold.
            train_dst = whole_train_dst.select(train_val_index_list[fold][0])
            val_dst = whole_train_dst.select(train_val_index_list[fold][1])

        # Re-seed before building the model so every fold starts from the
        # same training seed regardless of the data seed.
        setup_seed(cfg["training_seed"])
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
        print("----------------train dst size %d val dst size %d ------------" % (len(train_dst), len(val_dst)))
        model.to(device)
        train_dst.set_format("torch")
        val_dst.set_format("torch")
        test_dst.set_format("torch")
        train_dataloader = DataLoader(train_dst, shuffle=True, batch_size=args.batch_size)
        eval_dataloader = DataLoader(val_dst, batch_size=128)
        test_dataloader = DataLoader(test_dst, batch_size=128)

        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        num_epochs = args.num_epochs
        lr_scheduler = get_scheduler(
            "linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_epochs * len(train_dataloader)
        )
        test_f1, val_f1, f1_bias, _ = train(model, train_dataloader, eval_dataloader, test_dataloader,
                                            optimizer, lr_scheduler, num_epochs, device)
        # Sanity check: the reported bias must be |test_f1 - val_f1|.
        if abs(abs(test_f1 - val_f1) - f1_bias) > 0.0000001:
            raise RuntimeError("test f1 - val f1 not match f1_bias")
        test_f1_list.append(test_f1)
        val_f1_list.append(val_f1)
        f1_bias_list.append(f1_bias)

        # The optimizer (and the scheduler through it) still references the
        # model parameters, so all three must be dropped before the cache
        # can actually be released for the next fold.
        del model, optimizer, lr_scheduler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return np.mean(test_f1_list), np.mean(val_f1_list), np.mean(f1_bias_list)


def Kfold_cross_validation():
    """Repeat ``one_round_training`` over a set of seeds, then report and save.

    One round is run for each seed returned by ``generate_seed_set(10)``.
    The per-round validation score, test score and bias are printed together
    with their summary statistics, and pickled to
    ``<args.save_path>/<args.dst_name>/<args.val_method>.pkl``.
    """
    import pickle

    seed_set = generate_seed_set(10)
    val_performance_list = []
    test_performance_list = []
    performance_bias_list = []
    for i, s in enumerate(seed_set):
        print("=" * 20 + str(i) + "=" * 20)
        test_perf, val_perf, perf_bias = one_round_training(s)
        val_performance_list.append(val_perf)
        test_performance_list.append(test_perf)
        performance_bias_list.append(perf_bias)

    print(args)
    print(val_performance_list)
    print(test_performance_list)
    print(performance_bias_list)
    print("val average performance", np.mean(val_performance_list))
    print("test average performance", np.mean(test_performance_list))
    print("val performance std ", np.std(val_performance_list))
    print("performance bias ", np.mean(performance_bias_list))

    save_path = os.path.join(args.save_path, args.dst_name)
    # exist_ok avoids the check-then-create race when several runs (e.g. one
    # per validation method) write into the same dataset directory.
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, args.val_method + '.pkl'), 'wb') as f:
        pickle.dump({'args': args,
                     'val_performance_list': val_performance_list,
                     'test_performance_list': test_performance_list,
                     'performance_bias_list': performance_bias_list}, f)


def JKfold_cross_validation():
    """Run J repetitions of ``one_round_training`` per outer round, then report and save.

    Ten outer seeds are drawn from NumPy's global RNG seeded with 0; each
    outer seed is in turn used to draw ``args.J`` seeds, one per call to
    ``one_round_training``. The J results of an outer round are averaged,
    and the ten averages are printed with their summary statistics and
    pickled to ``<args.save_path>/<args.dst_name>/<args.val_method>.pkl``.

    Note: this function reseeds NumPy's global random state.
    """
    import pickle

    repeat_j = args.J
    np.random.seed(0)
    # 10 outer rounds are averaged in the final report.
    seed_set = np.random.randint(0, 10000, size=10).tolist()
    seeds_for_kfold_list = []
    for s in seed_set:
        # Derive the J inner seeds deterministically from the outer seed.
        np.random.seed(s)
        seeds_for_kfold_list.append(np.random.randint(0, 10000, size=repeat_j).tolist())

    val_performance_list = []
    test_performance_list = []
    performance_bias_list = []

    for i, seeds_for_kfold in enumerate(seeds_for_kfold_list):
        jk_val_performance_list = []
        jk_test_performance_list = []
        jk_performance_bias_list = []
        print("=" * 20 + str(i) + "=" * 20)
        for inner_seed in seeds_for_kfold:
            test_perf, val_perf, perf_bias = one_round_training(inner_seed)
            jk_val_performance_list.append(val_perf)
            jk_test_performance_list.append(test_perf)
            jk_performance_bias_list.append(perf_bias)

        # Collapse the J repetitions of this outer round into one number each.
        val_performance_list.append(np.mean(jk_val_performance_list))
        test_performance_list.append(np.mean(jk_test_performance_list))
        performance_bias_list.append(np.mean(jk_performance_bias_list))

    print(args)
    print(seeds_for_kfold_list)
    print(val_performance_list)
    print(test_performance_list)
    print(performance_bias_list)
    print("val average performance", np.mean(val_performance_list))
    print("test average performance", np.mean(test_performance_list))
    print("val performance std ", np.std(val_performance_list))
    print("performance bias ", np.mean(performance_bias_list))

    save_path = os.path.join(args.save_path, args.dst_name)
    # exist_ok avoids the check-then-create race when several runs (e.g. one
    # per validation method) write into the same dataset directory.
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, args.val_method + '.pkl'), 'wb') as f:
        pickle.dump({'args': args,
                     'val_performance_list': val_performance_list,
                     'test_performance_list': test_performance_list,
                     'performance_bias_list': performance_bias_list}, f)


def main():
    """Dispatch to the single or repeated cross-validation driver based on ``args.J``.

    ``J == 1`` runs ``Kfold_cross_validation``; ``J > 1`` runs
    ``JKfold_cross_validation``.

    Raises:
        Exception: If ``args.J`` is neither equal to 1 nor greater than 1.
    """
    repeats = args.J
    # Reject anything that is neither exactly one nor larger than one.
    if not (repeats == 1 or repeats > 1):
        raise Exception("J value error %d" % repeats)

    runner = Kfold_cross_validation if repeats == 1 else JKfold_cross_validation
    runner()


# Example invocation:
# CUDA_VISIBLE_DEVICES=4 python eval_main.py --num_epochs 15 --batch_size 8 --dst_name reuters --val_method holdout --k 1
# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()




