import os
import pickle
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import numpy as np
import models
from args import get_cifar10_args
from imbalanced_datasets import get_dataset, get_transform, PC
from train_eval import train
from torch.utils.data import DataLoader
import yaml
from globa_utils import setup_seed, generate_seed_set
from  val_sampler import split_train_val


# Names of the lowercase, callable ``resnet*`` entries exposed by ``models``;
# they are handed to the argument parser below.
model_names = sorted(
    candidate
    for candidate, member in vars(models).items()
    if candidate.islower()
    and not candidate.startswith("__")
    and candidate.startswith("resnet")
    and callable(member)
)

print(model_names)

# Parsed once at import time; the functions below read this module-level object.
args = get_cifar10_args(model_names)
print(args)


def one_round_training(seed):
    """Train and evaluate ``args.k`` models for one seed and average the results.

    The validation set is selected by ``args.val_method``:

    * ``split_free`` / ``split_free_joint`` / ``split_free_random``: a
      validation set is requested from ``get_dataset`` with strategy ``DB`` /
      ``DB_ADJOINT`` / ``RANDOM``; training uses the whole training set.
    * ``split_free_holdout``: training uses the train part of each
      ``split_train_val`` split; validation uses a ``DB_ADJOINT`` set.
    * ``split_free_test``: the test set is used as the validation set.
    * ``split_free_noval``: the whole training set is used as the validation set.
    * ``LZO``: a validation set is requested with strategy ``LZO``.
    * any other value: train and validation sets both come from the
      ``split_train_val`` splits.

    Args:
        seed: Seed forwarded to ``split_train_val`` and, as ``rand_number``,
            to ``get_dataset`` when a validation set is requested.

    Returns:
        Tuple ``(mean test perf, mean val perf, mean val/test bias)`` taken
        over the ``args.k`` runs, using the values returned by ``train``.

    Raises:
        ValueError: If ``args.val_method`` contains ``split_free`` but is not
            one of the supported variants listed above.
        Exception: If the bias returned by ``train`` differs from
            ``abs(test_perf - best_val_perf)`` by more than 1e-7.
    """
    device = "cuda:0"
    num_k = args.k
    num_classes = args.num_class
    val_method = args.val_method
    dataset_name = 'im_cifar' + str(num_classes)

    test_perf_list = []
    val_perf_list = []
    perf_bias_list = []

    # Create the save directory if needed; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.save_dir, exist_ok=True)

    dst_rand_seed = PC.get_global_random_seed(dataset_name)
    whole_train_dst = get_dataset(dataset_name, split='train',
                                  rand_number=dst_rand_seed, is_wrapper=True)
    test_dst = get_dataset(dataset_name, 'test', rand_number=dst_rand_seed)
    test_dst.transform = get_transform(dataset_name, t_type='test')

    # val_method -> strategy name passed to get_dataset(split='val').
    sampled_val_strategies = {'split_free': 'DB',
                              'split_free_joint': 'DB_ADJOINT',
                              'split_free_random': 'RANDOM'}

    # Set only by the methods that rely on split_train_val.
    train_val_index_list = None
    val_dst = None
    if val_method in sampled_val_strategies:
        val_dst = get_dataset(dataset_name, split='val', rand_number=seed,
                              val_method=sampled_val_strategies[val_method])
    elif val_method == 'split_free_holdout':
        train_val_index_list = split_train_val(whole_train_dst.indexset,
                                               whole_train_dst.get_label_list(),
                                               seed=seed, k=num_k, val_ratio=0.2)
        val_dst = get_dataset(dataset_name, split='val', rand_number=seed,
                              val_method='DB_ADJOINT')
    elif val_method == 'split_free_test':
        val_dst = test_dst
    elif val_method == 'split_free_noval':
        # The validation set is rebuilt from the training set inside the loop.
        pass
    elif val_method == 'LZO':
        val_dst = get_dataset(dataset_name, split='val', rand_number=seed,
                              val_method='LZO')
    elif 'split_free' in val_method:
        # Previously this case left val_dst unbound and surfaced as a
        # NameError inside the training loop.
        raise ValueError("unsupported split_free val_method: %r" % (val_method,))
    else:
        train_val_index_list = split_train_val(whole_train_dst.indexset,
                                               whole_train_dst.get_label_list(),
                                               seed=seed, k=num_k, val_ratio=0.2)

    for i in range(num_k):
        if val_method == 'split_free_holdout':
            # Only the train part of the split is used; val_dst was built above.
            train_index, _ = train_val_index_list[i]
            train_dst = whole_train_dst.get_dataset_by_indexes(train_index)
        elif train_val_index_list is not None:
            train_index, val_index = train_val_index_list[i]
            train_dst = whole_train_dst.get_dataset_by_indexes(train_index)
            val_dst = whole_train_dst.get_dataset_by_indexes(val_index)
        else:
            train_dst = whole_train_dst.get_dataset_by_indexes(whole_train_dst.indexset)
            if val_method == 'split_free_noval':
                # train set as val set
                val_dst = whole_train_dst.get_dataset_by_indexes(whole_train_dst.indexset)

        train_dst.transform = get_transform(dataset_name, t_type='train')
        val_dst.transform = get_transform(dataset_name, t_type='test')

        print("="*20+"curr k "+str(i)+'='*20)
        print('train set size %d, val set size %d, test set size %d'%(len(train_dst), len(val_dst), len(test_dst)))
        # NOTE(review): the seed here is fixed at 42 and does not depend on
        # `seed`; presumably intentional so every run starts from the same
        # state - confirm.
        setup_seed(seed=42)
        train_loader = DataLoader(train_dst, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.workers, pin_memory=True)
        val_loader = DataLoader(val_dst, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.workers, pin_memory=True)
        test_loader = DataLoader(test_dst, batch_size=128, shuffle=False,
                                 num_workers=args.workers, pin_memory=True)
        # === prepare data end ===

        # === training module set up ===
        model = models.__dict__[args.arch](num_classes)
        model.to(device)

        # define loss function (criterion) and optimizer
        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150])
        if args.arch in ['resnet1202', 'resnet110']:
            # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
            # then switch back. In this setup it will correspond for first epoch.
            # NOTE(review): the lr is lowered after the scheduler is created and
            # is not restored in this function; presumably train() switches it
            # back - confirm.
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr * 0.1

        # === training ===
        test_perf, best_val_perf, val_test_bias = train(train_loader, val_loader, test_loader,
              model, criterion, optimizer, lr_scheduler,
              args.epochs, device, save_dir=args.save_dir)

        # === recording ===
        # Sanity check: the reported bias must equal |test - val| up to 1e-7.
        if abs(abs(test_perf - best_val_perf) - val_test_bias) > 0.0000001:
            raise Exception("test perf - val perf not match val_test_bias")
        test_perf_list.append(test_perf)
        val_perf_list.append(best_val_perf)
        perf_bias_list.append(val_test_bias)

        # === clear ===
        # Drop the model and release cached GPU memory before the next fold.
        del model
        torch.cuda.empty_cache()

    return np.mean(test_perf_list), np.mean(val_perf_list), np.mean(perf_bias_list)


def Kfold_cross_validation():
    """Run ``one_round_training`` for five generated seeds, then report and save.

    For each seed from ``generate_seed_set(5)`` the per-round test
    performance, validation performance and performance bias are collected.
    The lists and their summary statistics are printed, and the lists are
    pickled together with ``args`` to ``args.save_name``.
    """
    round_results = []
    for round_idx, round_seed in enumerate(generate_seed_set(5)):
        print("="*20+str(round_idx)+"="*20)
        # Each entry is (test_perf, val_perf, perf_bias) for one seed.
        round_results.append(one_round_training(round_seed))

    test_performance_list = [test_perf for test_perf, _, _ in round_results]
    val_performance_list = [val_perf for _, val_perf, _ in round_results]
    performance_bias_list = [perf_bias for _, _, perf_bias in round_results]

    print(args)
    for collected in (val_performance_list, test_performance_list, performance_bias_list):
        print(collected)
    print("val average performance", np.mean(val_performance_list))
    print("test average performance", np.mean(test_performance_list))
    print("val performance std ", np.std(val_performance_list))
    print("performance bias ", np.mean(performance_bias_list))

    summary = {
        'args': args,
        'val_performance_list': val_performance_list,
        'test_performance_list': test_performance_list,
        'performance_bias_list': performance_bias_list,
    }
    with open(args.save_name, 'wb') as out_file:
        pickle.dump(summary, out_file)


def JKfold_cross_validation(repeat_j=4, num_rounds=5):
    """Run repeated K-fold rounds, average within each round, report and save.

    ``num_rounds`` outer seeds are drawn after seeding NumPy's global RNG
    with 0. For each outer seed the global RNG is reseeded with it and
    ``repeat_j`` inner seeds are drawn. ``one_round_training`` is called once
    per inner seed and its three results are averaged per outer round. The
    per-round averages and their summary statistics are printed, and the
    per-round lists are pickled together with ``args`` to ``args.save_name``.

    NOTE(review): ``main`` selects this function when ``args.J > 1`` but
    ``args.J`` is not read here; the repeat count comes from ``repeat_j``
    (default 4). Confirm whether callers are expected to pass ``args.J``.

    Args:
        repeat_j: Number of ``one_round_training`` repetitions per outer
            round. Defaults to 4, the previously hard-coded value.
        num_rounds: Number of outer rounds. Defaults to 5, the previously
            hard-coded value.

    Raises:
        ValueError: If ``repeat_j`` or ``num_rounds`` is smaller than 1, which
            would otherwise yield empty lists and NaN averages.
    """
    if repeat_j < 1:
        raise ValueError("repeat_j must be >= 1, got %r" % (repeat_j,))
    if num_rounds < 1:
        raise ValueError("num_rounds must be >= 1, got %r" % (num_rounds,))

    # The global NumPy RNG is (re)seeded here on purpose so the seed schedule
    # is reproducible; note this also changes the global RNG state for any
    # code that runs afterwards.
    np.random.seed(0)
    seed_set = np.random.randint(0, 10000, size=num_rounds).tolist()
    seeds_for_kfold_list = []
    for s in seed_set:
        np.random.seed(s)
        seeds_for_kfold_list.append(np.random.randint(0, 10000, size=repeat_j).tolist())

    val_performance_list = []
    test_performance_list = []
    performance_bias_list = []

    for i, seeds_for_kfold in enumerate(seeds_for_kfold_list):
        jk_val_performance_list = []
        jk_test_performance_list = []
        jk_performance_bias_list = []
        print("=" * 20 + str(i) + "=" * 20)
        for kfold_seed in seeds_for_kfold:
            test_perf, val_perf, perf_bias = one_round_training(kfold_seed)
            jk_val_performance_list.append(val_perf)
            jk_test_performance_list.append(test_perf)
            jk_performance_bias_list.append(perf_bias)

        # One averaged entry per outer round.
        val_performance_list.append(np.mean(jk_val_performance_list))
        test_performance_list.append(np.mean(jk_test_performance_list))
        performance_bias_list.append(np.mean(jk_performance_bias_list))

    print(args)
    print(val_performance_list)
    print(test_performance_list)
    print(performance_bias_list)
    print("val average performance", np.mean(val_performance_list))
    print("test average performance", np.mean(test_performance_list))
    print("val performance std ", np.std(val_performance_list))
    print("performance bias ", np.mean(performance_bias_list))

    with open(args.save_name, 'wb') as f:
        pickle.dump({'args': args,
                     'val_performance_list': val_performance_list,
                     'test_performance_list': test_performance_list,
                     'performance_bias_list': performance_bias_list}, f)


def main():
    """Pick the cross-validation driver from ``args.J``.

    ``args.J == 1`` runs ``Kfold_cross_validation``; ``args.J > 1`` runs
    ``JKfold_cross_validation``; any other value raises ``Exception``.
    """
    j_value = args.J
    if j_value == 1:
        Kfold_cross_validation()
        return
    if j_value > 1:
        JKfold_cross_validation()
        return
    raise Exception("J value error %d" % j_value)


# Script entry point: run the cross-validation driver only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()
