import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from enum import Enum
from torch.utils.data import DataLoader
import torch
from args import get_cifar_fe_args
from sklearn.preprocessing import normalize
import sys
sys.path.append("..")
from globa_utils import setup_seed
from imbalanced_datasets import get_dataset, get_transform, PC
from train_eval import train


class FeatureExtractor(object):
    """Run a trained network in inference mode and collect its outputs.

    On construction the wrapped model is switched to eval mode and moved to
    ``device``. Two kinds of model are distinguished by ``feature_extractor_type``:

    * ``'simclr'``: calling the model returns a single tensor, used as the features.
    * any other value: calling the model returns a ``(logits, features)`` pair.
    """

    def __init__(self, feature_extractor_type, model, device):
        self.feature_extractor = model
        self.feature_extractor.eval()
        self.feature_extractor.to(device)
        self.fe_type = feature_extractor_type
        self.device = device

    def get_features(self, datas, is_logits=False):
        """Return the model output for one batch with an extra leading axis.

        Args:
            datas: input batch tensor; it is moved to ``self.device`` first.
            is_logits: when True return the logits instead of the features.

        Returns:
            The selected output tensor after ``unsqueeze(dim=0)``.

        Raises:
            ValueError: if ``is_logits`` is requested from a ``'simclr'``
                extractor, whose forward pass yields features only.
        """
        if is_logits and self.fe_type == 'simclr':
            # Previously this fell through to an unbound local `logits`.
            raise ValueError("a 'simclr' feature extractor does not produce logits")

        with torch.no_grad():
            datas = datas.to(self.device)
            if self.fe_type == 'simclr':
                outputs = self.feature_extractor(datas)
            else:
                logits, features = self.feature_extractor(datas)
                outputs = logits if is_logits else features
        return outputs.unsqueeze(dim=0)

    def extractor_features_from_dst(self, dst, num_classes, is_logits=False,
                                    batch_size=64, num_workers=2):
        """Extract outputs for every sample of ``dst`` and group them by label.

        Args:
            dst: dataset yielding ``(data, label)`` pairs with integer labels.
            num_classes: number of classes; the result has one entry per class.
            is_logits: forwarded to :meth:`get_features`.
            batch_size: batch size of the internal ``DataLoader``.
            num_workers: worker count of the internal ``DataLoader``.

        Returns:
            dict mapping each class index in ``range(num_classes)`` to the
            tensor of outputs of that class, rows kept in dataset order
            (the loader does not shuffle). Classes without samples map to a
            tensor with zero rows.

        Raises:
            Exception: if the number of extracted rows differs from ``len(dst)``.
            IndexError: if a label lies outside ``[0, num_classes)``.
        """
        loader = DataLoader(dst, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
        feature_chunks = []
        label_chunks = []
        for datas, labels in loader:
            batch_out = self.get_features(datas, is_logits=is_logits).squeeze(dim=0)
            feature_chunks.append(batch_out)
            label_chunks.append(labels)
        all_features = torch.cat(feature_chunks, dim=0)
        all_labels = torch.cat(label_chunks, dim=0)
        if len(dst) != all_features.shape[0]:
            raise Exception('extractor features error!')

        if all_labels.numel() > 0:
            lowest, highest = int(all_labels.min()), int(all_labels.max())
            if lowest < 0 or highest >= num_classes:
                raise IndexError(
                    'label out of range for num_classes=%d: min=%d, max=%d'
                    % (num_classes, lowest, highest))

        # split feature by class; a boolean mask keeps the original row order
        all_labels = all_labels.to(all_features.device)
        return {c: all_features[all_labels == c] for c in range(num_classes)}


def load_feature_extractor(path, model_name='resnet110', device="cuda:0", num_classes=10, fe_type='default'):
    """Build a network and load its weights from the checkpoint at ``path``.

    Args:
        path: checkpoint file readable by ``torch.load``.
        model_name: ``'resnet110'`` or ``'resnet20'``; used only when
            ``fe_type == 'default'``.
        device: device the ``'default'`` model is moved to. The ``'simclr'``
            branch ignores it and uses ``model.cuda()`` when CUDA is available.
        num_classes: forwarded to the ``'default'`` model constructor.
        fe_type: ``'default'`` (weights read from ``checkpoint['state_dict']``)
            or ``'simclr'`` (a ``SupConResNet`` with weights read from
            ``ckpt['model']``).

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        ValueError: on an unsupported ``fe_type`` or ``model_name``.
    """
    if fe_type == 'default':
        # Validate before touching the file so a typo fails fast and clearly
        # instead of surfacing later as an unbound `model`.
        if model_name not in ('resnet110', 'resnet20'):
            raise ValueError('unsupported model_name: %r' % (model_name,))
        # Deserialize on the CPU so a checkpoint saved from a GPU also loads on
        # hosts without that device; the model is moved to `device` afterwards.
        checkpoint = torch.load(path, map_location='cpu')
        model = getattr(models, model_name)(num_classes=num_classes)
        model.load_state_dict(checkpoint['state_dict'])
        model.to(device)
    elif fe_type == 'simclr':
        from simclr.resnet_big import SupConResNet
        model = SupConResNet(name='resnet18')
        ckpt = torch.load(path, map_location='cpu')
        state_dict = ckpt['model']

        use_cuda = torch.cuda.is_available()
        if use_cuda and torch.cuda.device_count() > 1:
            # Wrapping the encoder makes its parameter names carry "module.",
            # so the checkpoint keys are used unchanged.
            model.encoder = torch.nn.DataParallel(model.encoder)
        else:
            # No DataParallel wrapper here: drop "module." from the key names.
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        if use_cuda:
            model = model.cuda()
            cudnn.benchmark = True
        # Load on every path; previously a CPU-only host silently returned a
        # randomly initialised model.
        model.load_state_dict(state_dict)
    else:
        raise ValueError('unsupported fe_type: %r' % (fe_type,))

    return model


# CUDA_VISIBLE_DEVICES=1 python feature_extractor.py --arch resnet20 --num-class 100 --save-dir /data/omf/model/DataValidation/CV/feature_extractor/cifar100/
def main():
    """Train a CIFAR ResNet on the imbalanced dataset and save it to ``args.save_dir``.

    Command-line options come from ``get_cifar_fe_args``; the architecture is
    looked up by name in the ``models`` module and training is delegated to
    ``train``.
    """
    # Collect the lowercase callables in `models` whose name starts with "resnet".
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and name.startswith("resnet")
                         and callable(models.__dict__[name]))

    print(model_names)

    args = get_cifar_fe_args(model_names)
    device = "cuda:0"
    num_classes = args.num_class
    dataset_name = 'im_cifar' + str(num_classes)

    # set up random seed
    random_seed = PC.get_global_random_seed(dataset_name)
    setup_seed(seed=random_seed)

    # Create the save_dir if needed; exist_ok avoids the check-then-create race.
    os.makedirs(args.save_dir, exist_ok=True)

    # === prepare data begin ===
    train_dst = get_dataset(dataset_name, split='train', rand_number=random_seed)
    train_dst.transform = get_transform(dataset_name, t_type='train')
    test_dst = get_dataset(dataset_name, split='test', rand_number=random_seed)
    test_dst.transform = get_transform(dataset_name, t_type='test')
    train_loader = DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
                              pin_memory=True)

    # The validation loader iterates the training set (unshuffled).
    val_loader = DataLoader(train_dst, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                            pin_memory=True)
    test_loader = DataLoader(test_dst, batch_size=128, shuffle=False, num_workers=args.workers, pin_memory=True)
    print("train size" + str(len(train_dst)))
    # === prepare data end ===

    # === training module set up ===
    model = models.__dict__[args.arch](num_classes=num_classes)
    model.to(device)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150])
    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this setup it will correspond for first epoch.
        print('update lr for resnet')
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    # === training ===
    train(train_loader, val_loader, test_loader,
          model, criterion, optimizer, lr_scheduler,
          args.epochs, device, save_dir=args.save_dir)


if __name__ == '__main__':
    main()