import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from enum import Enum
from torch.utils.data import DataLoader
import torch
from args import get_cifar_fe_args
from sklearn.preprocessing import normalize
import sys
sys.path.append("..")
from globa_utils import setup_seed
from imbalanced_datasets import get_dataset, get_transform, PC
from train_eval import train
from  val_sampler import split_train_val


# CUDA_VISIBLE_DEVICES=5 python kfoldCV_angle_distribution.py --arch resnet20 --num-class 100 --save-dir /data/omf/model/DataValidation/CV/feature_extractor/cifar100/kfoldCV_resnet20/
def main():
    """Train one model per cross-validation fold on imbalanced CIFAR.

    Parses options with ``get_cifar_fe_args`` (architecture, ``num_class``,
    batch size, workers, SGD hyper-parameters, epochs, ``save_dir``), splits
    the imbalanced CIFAR training set into ``NUM_K`` train/validation folds
    via ``split_train_val``, and for every fold trains a freshly built model
    with SGD + ``MultiStepLR``.  The checkpoint for fold ``k`` is written to
    ``<save_dir>/<k>_model.th`` (path handed to ``train`` as ``save_dir``).
    """
    NUM_K = 5

    # Only lowercase, callable "resnet*" entries of the local `models` module
    # are valid architectures (a name starting with "resnet" can never also
    # start with "__", so no separate dunder check is needed).
    model_names = sorted(
        name for name in models.__dict__
        if name.startswith("resnet")
        and name.islower()
        and callable(models.__dict__[name])
    )

    print(model_names)

    args = get_cifar_fe_args(model_names)
    num_classes = args.num_class
    # Single source of truth for the dataset identifier used below.
    dataset_name = 'im_cifar' + str(num_classes)

    random_seed = PC.get_global_random_seed(dataset_name)
    setup_seed(seed=random_seed)
    device = "cuda:0"

    # exist_ok=True avoids the race between an existence check and creation.
    os.makedirs(args.save_dir, exist_ok=True)

    # === prepare data begin ===
    whole_train_dst = get_dataset(dataset_name, split='train',
                                  rand_number=random_seed, is_wrapper=True)
    train_val_index_list = split_train_val(whole_train_dst.indexset,
                                           whole_train_dst.get_label_list(), seed=PC.get_kfoldCV_seed(),
                                           k=NUM_K, val_ratio=0.2)

    # The transforms do not depend on the fold, so build them once.
    train_transform = get_transform(dataset_name, t_type='train')
    val_transform = get_transform(dataset_name, t_type='test')

    for k in range(NUM_K):
        train_index, val_index = train_val_index_list[k]
        train_dst = whole_train_dst.get_dataset_by_indexes(train_index)
        val_dst = whole_train_dst.get_dataset_by_indexes(val_index)

        train_dst.transform = train_transform
        val_dst.transform = val_transform

        print("="*20+"curr k "+str(k)+'='*20)
        print('train set size %d, val set size %d'%(len(train_dst), len(val_dst)))

        train_loader = DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
                                  pin_memory=True)
        val_loader = DataLoader(val_dst, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                                pin_memory=True)
        # === prepare data end ===

        # === training module set up ===
        model = models.__dict__[args.arch](num_classes=num_classes)
        model.to(device)

        # define loss function (criterion) and optimizer
        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150])
        if args.arch in ['resnet1202', 'resnet110']:
            # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
            # then switch back. In this setup it will correspond for first epoch.
            # NOTE(review): nothing in this function switches the lr back. With the
            # chainable MultiStepLR of recent PyTorch, step() keeps the current lr
            # between milestones, so unless `train` restores args.lr after the
            # first epoch, the reduced lr may persist until epoch 100. Confirm.
            print('update lr for resnet')
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr * 0.1

        # === training ===
        # The same val_loader is passed in both loader slots of `train`.
        # os.path.join keeps the checkpoint inside save_dir even when the
        # directory was given without a trailing separator.
        train(train_loader, val_loader, val_loader,
              model, criterion, optimizer, lr_scheduler,
              args.epochs, device,
              save_dir=os.path.join(args.save_dir, str(k) + '_model.th'))


if __name__ == "__main__":
    # Launch k-fold training only when run as a script, never on import.
    main()