import os
import random
import logging
import numpy as np
from math import inf
from scipy import stats

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from randaugment import RandAugmentMC

import caltech
import imbalance_cifar

# Per-sample loss (reduction='none'): the caller receives one loss value per
# example and is responsible for reducing it (e.g. weighting/selecting samples).
train_criterion = nn.CrossEntropyLoss(reduction='none').cuda()
# Default-reduced loss for validation: a single scalar per batch.
val_criterion = nn.CrossEntropyLoss().cuda()

def sparse2coarse(targets):
    """Map CIFAR-100 fine ("sparse") labels to their coarse superclass labels.

    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)

    Args:
        targets: fine labels in ``[0, 100)``; anything NumPy accepts as an
            index (int, list of ints, integer array).

    Returns:
        The coarse label(s) in ``[0, 20)`` looked up from the table below
        (a NumPy array for sequence input, a NumPy scalar for an int).
    """
    # Row r holds the coarse labels of fine labels 10*r .. 10*r + 9.
    fine_to_coarse = np.array([
        4, 1, 14, 8, 0, 6, 7, 7, 18, 3,        # 0-9
        3, 14, 9, 18, 7, 11, 3, 9, 7, 11,      # 10-19
        6, 11, 5, 10, 7, 6, 13, 15, 3, 15,     # 20-29
        0, 11, 1, 10, 12, 14, 16, 9, 11, 5,    # 30-39
        5, 19, 8, 8, 15, 13, 14, 17, 18, 10,   # 40-49
        16, 4, 17, 4, 2, 0, 17, 4, 18, 17,     # 50-59
        10, 3, 2, 12, 12, 16, 12, 1, 9, 19,    # 60-69
        2, 10, 0, 1, 16, 12, 9, 13, 15, 13,    # 70-79
        16, 19, 2, 4, 6, 19, 5, 5, 8, 19,      # 80-89
        18, 1, 2, 15, 6, 0, 17, 8, 14, 13,     # 90-99
    ])
    return fine_to_coarse[targets]

def get_transforms(args):
    """Return ``(train_transform, val_transform, strong_train_transform)`` for ``args.dataset``.

    * 'caltech256' / 'stanfordcars': 224px random-resized crops for training;
      resize to 256 then center-crop 224 for evaluation.
    * 'tinyimagenet': horizontal flip only for training; no resizing.
    * any other dataset: horizontal flip + 32px random crop with 4px padding for
      training; 'mnist' additionally resizes evaluation images to 32x32 and
      uses its own single-channel normalization.

    Every other dataset (CIFAR included) is normalized with ImageNet mean/std.
    The strong pipeline equals the training pipeline with
    ``RandAugmentMC(n=2, m=10)`` inserted just before tensor conversion.
    """
    name = args.dataset
    large_images = name in ['caltech256', 'stanfordcars']

    if name == 'mnist':
        normalize = transforms.Normalize(mean=(0.13066373765468597,),
                                         std=(0.30810782313346863,))
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    def training_pipeline(strong):
        # Each call builds its own transform instances, mirroring separate
        # Compose lists for the plain and the strong pipeline.
        steps = [transforms.RandomHorizontalFlip()]
        if large_images:
            steps.append(transforms.RandomResizedCrop(224))
        elif name != 'tinyimagenet':
            steps.append(transforms.RandomCrop(32, 4))
        if strong:
            steps.append(RandAugmentMC(n=2, m=10))
        steps.extend([transforms.ToTensor(), normalize])
        return transforms.Compose(steps)

    if large_images:
        eval_steps = [transforms.Resize(256), transforms.CenterCrop(224)]
    elif name == 'mnist':
        eval_steps = [transforms.Resize((32, 32))]
    else:
        eval_steps = []
    val_transform = transforms.Compose(eval_steps + [transforms.ToTensor(), normalize])

    return training_pipeline(False), val_transform, training_pipeline(True)

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize ``state`` (typically a dict of training state) to ``filename``.

    Thin wrapper around ``torch.save``; returns ``None``.
    """
    torch.save(state, f=filename)


def accuracy(output, target, topk=(1,)):
    """Compute the top-k precision (in percent) for each k in ``topk``.

    Args:
        output: score tensor whose rows are samples and whose dim 1 ranks the
            classes (``output.topk(maxk, 1)`` is taken along it).
        target: ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of rows whose target
        is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred becomes (maxk, batch): row i holds every sample's (i+1)-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so
        # `.view(-1)` fails for k > 1 on recent PyTorch; `reshape` copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def param_grad_norms(model):
    """Return parameter names and the L2 norms of their gradients.

    Args:
        model: module whose ``named_parameters()`` are inspected.

    Returns:
        ``(names, grad_norms)``: two parallel lists in ``named_parameters()``
        order. A parameter whose ``.grad`` is ``None`` (no gradient has been
        accumulated for it) is reported with norm ``0.0`` instead of raising,
        so the two lists always stay aligned.
    """
    names = []
    grad_norms = []
    for name, param in model.named_parameters():
        grad = param.grad
        # Tensor.norm() with default arguments is the 2-norm over all elements,
        # i.e. the norm of the flattened gradient; skipping `.view(-1)` also
        # works for non-contiguous gradients.
        grad_norms.append(0.0 if grad is None else grad.norm().item())
        names.append(name)

    return names, grad_norms


def set_seed(args):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``args.seed``."""
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


class AverageMeter(object):
    """Track the latest value of a metric and its running, sample-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and mean back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counting it ``n`` times in the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

class GradualWarmupScheduler(_LRScheduler):
    """
    from: https://github.com/ildoonet/pytorch-gradual-warmup-lr
    Gradually warm up (increase) the learning rate of an optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0.
            If multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: the target learning rate is reached gradually at total_epoch.
        after_scheduler: scheduler that takes over after total_epoch
            (e.g. ReduceLROnPlateau); if None the target lr is held.

    Raises:
        ValueError: if multiplier < 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def _warmup_lrs(self):
        """Learning rates for ``self.last_epoch`` during the warm-up phase.

        Shared by the regular and the ReduceLROnPlateau code paths so both
        follow the same ramp.
        """
        if self.multiplier == 1.0:
            # Linear ramp from 0 to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        # Linear ramp from base_lr to base_lr * multiplier.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def _sync_last_lr(self):
        """Mirror the optimizer's current lrs into ``_last_lr``.

        Needed whenever the lr was set without going through the base-class
        ``step`` (delegation to ``after_scheduler`` or the plateau path), so
        that ``get_last_lr()`` is not stale or missing.
        """
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # First post-warm-up step: the follow-up scheduler
                    # continues from the warmed-up learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        return self._warmup_lrs()

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Advance one epoch when ``after_scheduler`` is a ReduceLROnPlateau.

        During warm-up the optimizer lrs are written directly; afterwards
        ``metrics`` is forwarded to the plateau scheduler.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of epoch, whereas others are
        # called at beginning, so epoch 0 is treated as epoch 1.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            for param_group, lr in zip(self.optimizer.param_groups, self._warmup_lrs()):
                param_group['lr'] = lr
        else:
            # `epoch` was defaulted above, so it is never None here.
            self.after_scheduler.step(metrics, epoch - self.total_epoch)
        self._sync_last_lr()

    def step(self, epoch=None, metrics=None):
        """Advance the schedule; ``metrics`` is only used with ReduceLROnPlateau."""
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            self.step_ReduceLROnPlateau(metrics, epoch)
            return None
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
            self._sync_last_lr()
            return None
        return super(GradualWarmupScheduler, self).step(epoch)


def get_instance_noisy_label(n, dataset, labels, num_classes, feature_size, norm_std, seed):
    """Generate instance-dependent noisy labels.

    Each sample i draws a flip rate from a normal distribution with mean ``n``
    and std ``norm_std`` truncated to [0, 1]. A random linear map ``W[y]``
    (feature_size x num_classes) scores the flattened input; the true class is
    masked out (-inf) before the softmax, the resulting distribution is scaled
    by the flip rate, and the true class keeps probability ``1 - flip_rate``.
    A new label is sampled from that distribution.

    Args:
        n: noise rate (mean of the flip-rate distribution).
        dataset: iterable of ``(x, y)`` pairs (not a DataLoader); ``x`` is a
            tensor with ``feature_size`` elements, ``y`` its class index.
        labels: clean labels, tensor or list, one per sample in ``dataset``.
        num_classes: number of classes.
        feature_size: number of elements of each flattened input (e.g. 28*28).
        norm_std: std of the flip-rate distribution (0.1 is the default used
            by callers in this file).
        seed: seeds NumPy and PyTorch (CPU and CUDA) RNGs.

    Returns:
        ``np.ndarray`` with one noisy label per sample.
    """
    print("building dataset...")
    label_num = num_classes
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))
    torch.cuda.manual_seed(int(seed))
    # Use the GPU when present (the previous behavior) but fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert first: a plain list has no `.shape` to size the draws below.
    if isinstance(labels, list):
        labels = torch.FloatTensor(labels)
    num_samples = labels.shape[0]

    flip_distribution = stats.truncnorm((0 - n) / norm_std, (1 - n) / norm_std, loc=n, scale=norm_std)
    flip_rate = flip_distribution.rvs(num_samples)

    W = torch.FloatTensor(np.random.randn(label_num, feature_size, label_num)).to(device)

    P = []
    for i, (x, y) in enumerate(dataset):
        # (1 x feature_size) @ (feature_size x label_num) -> (label_num,) scores
        x = x.to(device)
        A = x.reshape(1, -1).mm(W[y]).squeeze(0)
        A[y] = -inf  # true class gets zero mass inside the softmax
        A = flip_rate[i] * F.softmax(A, dim=0)
        A[y] += 1 - flip_rate[i]
        P.append(A)
    P = torch.stack(P, 0).cpu().numpy()
    classes = [c for c in range(label_num)]
    new_label = [np.random.choice(classes, p=P[i]) for i in range(num_samples)]

    # An earlier version drew these 1000 indices for an unused debug printout.
    # The draw is kept so the global NumPy RNG state after this call (and hence
    # the reproducibility of later random operations) is unchanged.
    np.random.choice(range(P.shape[0]), 1000)
    return np.array(new_label)

class IndexedDataset(Dataset):
    """Dataset wrapper whose items are ``(data, target, index)`` triples.

    ``dataset`` is a string key selecting the underlying dataset; unrecognised
    keys fall back to MNIST. Every branch sets ``self.dataset`` and
    ``self.targets`` (``__len__`` returns ``len(self.targets)``); the
    waterbirds (train) and cmnist branches also set ``self.groups``.

    For ``'cifar10'`` with ``noisy > 0`` and ``train=True`` the labels are
    corrupted (symmetric flips, or instance-dependent noise when ``idn`` is
    set) and cached under ``output_cifar10/record/...`` so later runs with the
    same settings reload identical labels.

    ``args`` is required for 'cifar10' (``args.longtail``) and 'cmnist'.
    ``transform`` is not used by 'cinic10' and 'cifar10a' (they build their own
    pipelines) nor by 'waterbirds' and 'cmnist'.
    """

    def __init__(self, dataset, transform=None, train=False, corruption=None, noisy=0., seed=0, idn=False,
                 num_classes=10, feature_size=3*32*32, norm_std=0.1, args=None, group=None):
        if group is None:
            group = [0, 0]  # fresh list per call instead of a shared mutable default
        if dataset == 'cifar10':
            if args.longtail:
                self.dataset = imbalance_cifar.IMBALANCECIFAR10(root='./data', imb_type='', imb_factor=args.imb_factor, rand_number=args.seed, train=train, transform=transform, download=True)
            else:
                self.dataset = datasets.CIFAR10(root='./data', train=train, transform=transform, download=True)
            if noisy > 0 and train:
                self._apply_label_noise(noisy, seed, idn, num_classes, feature_size, norm_std)
            self.targets = self.dataset.targets
        elif dataset == 'cifar10c':
            self.dataset = caltech.CIFAR10C(root='./data/CIFAR-10-C', name=corruption, transform=transform)
            self.targets = self.dataset.targets
        elif dataset == 'cinic10':
            cinic_mean = [0.47889522, 0.47227842, 0.43047404]
            cinic_std = [0.24205776, 0.23828046, 0.25874835]
            if not train:
                self.dataset = datasets.ImageFolder('./data/CINIC-10' + '/test',
                    transform=transforms.Compose([transforms.ToTensor(),
                    transforms.Normalize(mean=cinic_mean,std=cinic_std)]))
            else:
                self.dataset = datasets.ImageFolder('./data/CINIC-10' + '/train',
                    transform=transforms.Compose([transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, 4),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=cinic_mean,std=cinic_std)]))
            self.targets = self.dataset.targets
        elif dataset == 'cifar10a':
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
            self.dataset = datasets.ImageFolder('./data/CIFAR-10-A-32/box', transform=val_transform)
            # Previously missing, which made __len__ raise AttributeError.
            self.targets = self.dataset.targets
        elif dataset == 'cifar100':
            self.dataset = datasets.CIFAR100(root='./data', train=train, transform=transform, download=True)
            self.targets = self.dataset.targets
        elif dataset == 'cifar100sup':
            self.dataset = datasets.CIFAR100(root='./data', train=train, transform=transform, download=True)
            self.dataset.targets = sparse2coarse(self.dataset.targets)
            self.targets = self.dataset.targets
        elif dataset == 'caltech256':
            self.dataset = caltech.Caltech256(root='./data', train=train, transform=transform)
            self.targets = np.array(self.dataset.targets)
        elif dataset == 'stanfordcars':
            if train:
                self.dataset = caltech.StanfordCars(root='./data', split='train', transform=transform, download=True)
            else:
                self.dataset = caltech.StanfordCars(root='./data', split='test', transform=transform, download=True)
            self.targets = np.array(self.dataset.targets)
        elif dataset == 'tinyimagenet':
            if train:
                self.dataset = caltech.TinyImageNet(root='./data/tiny-imagenet-200', split='train', transform=transform)
            else:
                self.dataset = caltech.TinyImageNet(root='./data/tiny-imagenet-200', split='val', transform=transform)
            self.targets = np.array(self.dataset.targets)
        elif dataset == 'waterbirds':
            from waterbird_loader import WB_MultiDomainLoader, WB_DomainTest
            root_path = './data/waterbirds'
            if train:
                self.dataset = WB_MultiDomainLoader(dataset_root_dir=root_path, train_split='train')  # , 'D2'
                self.groups = self.dataset.groups
            else:
                self.dataset = WB_DomainTest(dataset_root_dir=root_path,split='test', group=group)
            self.targets = self.dataset.labels
        elif dataset == 'cmnist':
            from colored_mnist import load_dataloaders
            if train:
                saved_filename = 'train.pth'
            else:
                saved_filename = 'test.pth'
            if not os.path.exists(os.path.join(args.save_dir, saved_filename)):
                self.dataset = load_dataloaders(args, train)
                torch.save(self.dataset, os.path.join(args.save_dir, saved_filename))
            else:
                self.dataset = torch.load(os.path.join(args.save_dir, saved_filename))
            self.targets = self.dataset.targets
            self.groups = self.dataset.group_array
            self.targets_all = self.dataset.targets_all
            torch.save(self.groups, os.path.join(args.save_dir, 'groups.pth'))
        else:
            self.dataset = datasets.MNIST(root='./data', train=train, transform=transform, download=True)
            self.targets = self.dataset.targets

    def _apply_label_noise(self, noisy, seed, idn, num_classes, feature_size, norm_std):
        """Replace ``self.dataset.targets`` with cached or freshly generated noisy labels."""
        if seed <= 0:
            raise ValueError('noisy labels require a positive seed')
        idn_str = '_idn' if idn else ''
        print(idn_str)
        # NOTE(review): the noise rate is formatted with one decimal, so e.g.
        # 0.2 and 0.25 share a cache file; the path is kept for compatibility
        # with existing caches.
        label_path = f'output_cifar10/record/cifar10_transform{idn_str}_noisy_{noisy:.1f}_resnet18_1.0_128_seed_{seed}_lr_0.1_mile_epochs_200_record/noisy_label.pkl'
        if not os.path.exists(label_path):
            if idn:
                targets = torch.tensor(self.dataset.targets)
                self.dataset.targets = get_instance_noisy_label(noisy, zip(torch.from_numpy(self.dataset.data).float(), targets), targets, num_classes, feature_size, norm_std, seed)
            else:
                self._flip_labels_uniformly(noisy)
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(label_path), exist_ok=True)
            torch.save(self.dataset.targets, label_path)
        # Always reload so freshly generated and cached runs see the same object.
        # NOTE(review): the cache may hold NumPy values; torch versions whose
        # torch.load defaults to weights_only=True may refuse it — confirm.
        self.dataset.targets = torch.load(label_path)

    def _flip_labels_uniformly(self, noisy):
        """Reassign a ``noisy`` fraction of ``self.dataset.targets`` (in place) to a different, uniformly drawn class."""
        nlabels = len(self.dataset.targets)
        nlabels_to_change = int(noisy * nlabels)
        nclasses = len(np.unique(self.dataset.targets))
        print('flipping ' + str(nlabels_to_change) + ' labels')

        # Randomly choose which labels to change, get indices
        labels_inds_to_change = np.random.choice(
            np.arange(nlabels), nlabels_to_change, replace=False)

        for label_ind_to_change in labels_inds_to_change:
            # Candidates are all classes except the current one (a label equals
            # its index in arange(nclasses)).
            label_choices = np.arange(nclasses)
            true_label = self.dataset.targets[label_ind_to_change]
            label_choices = np.delete(label_choices, true_label)
            noisy_label = np.random.choice(label_choices, 1)
            self.dataset.targets[label_ind_to_change] = noisy_label[0]

    def __getitem__(self, index):
        data, target = self.dataset[index]
        # Your transformations here (or set it in CIFAR10)
        return data, target, index

    def __len__(self):
        return len(self.targets)

def get_name(args):
    """Build the run-name string (starting with '/') that identifies an experiment.

    Encodes dataset, transform usage, architecture, subset size, batch size,
    the subset-selection strategy (only when ``args.subset_size < 1``), start
    epoch, seed, learning rate, batch-norm usage and run id.

    Side effect: for MNIST, ``args.no_transforms`` is set to True.
    """
    pieces = []
    if args.subset_size < 1:
        if args.greedy:
            pieces.append(f'_grd_w_{args.subset_criterion}_{args.metric}')
            if args.st_grd > 0:
                pieces.append(f'_st_{args.st_grd}')
            if args.cluster_features:
                pieces.append('_feature')
            if args.cluster_all:
                pieces.append('_ca')
        elif args.barrier:
            pieces.append('_barrier')
            if args.by_class:
                pieces.append('_by_class')
            if args.start_k:
                pieces.append('_start_20')
        else:
            pieces.append(f'_rand_rsize_{args.random_subset_size}_{args.select_subset}')
        pieces.append(f'_{args.select_subset}')
        if args.interp:
            pieces.append(f'_interp_{args.grad_layer}')
    if args.start_epoch > 0:
        pieces.append(f'_start_{args.start_epoch}')
    pieces.append(f'_seed_{args.seed}_lr_{args.lr}')
    if not args.bn:
        pieces.append('_no-bn')

    if args.dataset == 'mnist':
        args.no_transforms = True
    folder = f'/{args.dataset}' + ('' if args.no_transforms else '_transform')
    suffix = ''.join(pieces)
    return f'{folder}_{args.arch}_{args.subset_size}_{args.batch_size}{suffix}_{args.run}'


def set_logger(args, logger):
    """Configure file logging under ``args.save_dir`` and echo ``logger`` to the console.

    Calls ``logging.basicConfig`` targeting ``<args.save_dir>/output.log``
    (append mode, INFO level), then attaches a new console ``StreamHandler``
    (level DEBUG) to ``logger``. Each call adds another console handler.

    Returns:
        The same ``logger`` object that was passed in.
    """
    log_file = f"{args.save_dir}/output.log"
    log_format = "%(asctime)s - %(levelname)s - %(name)s -   %(message)s"
    logging.basicConfig(filename=log_file, filemode='a', format=log_format,
                        datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    return logger


def get_optimizer_and_scheduler(args, model):
    """Build the optimizer and learning-rate scheduler described by ``args``.

    Reads ``args.lr``, ``args.gamma``, ``args.lr_schedule``, ``args.ig``
    (optimizer name: 'adam', 'adagrad', anything else -> SGD),
    ``args.momentum``, ``args.weight_decay``, ``args.epochs`` and
    ``args.dataset``; logs through ``args.logger``.

    ``args.gamma == -1`` is a sentinel for the 'mile' and 'cosine' schedules
    meaning "use the default decay factor 0.1". The resolved factor ``b`` is
    logged and used by every schedule that takes a decay factor.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    lr = args.lr
    if args.lr_schedule in ('mile', 'cosine') and args.gamma == -1:
        b = 0.1
    else:
        b = args.gamma

    args.logger.info(f'lr schedule: {args.lr_schedule}, epochs: {args.epochs}')
    args.logger.info(f'lr: {lr}, b: {b}')

    if args.ig == 'adam':
        args.logger.info('using adam')
        # NOTE(review): Adam uses a fixed lr of 1e-3 and ignores args.lr — confirm intended.
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=args.weight_decay)
    elif args.ig == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    if args.lr_schedule == 'exp':
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=b)
    elif args.lr_schedule == 'step':
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=b)
    elif args.lr_schedule == 'mile':
        if args.dataset != 'caltech256':
            # Use the resolved factor: passing args.gamma here turned the
            # -1 sentinel into a negative learning rate after epoch 60.
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[60, 120, 160], gamma=b)
        else:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[20, 30], gamma=0.1)
    elif args.lr_schedule == 'cosine':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=5e-4)
    elif args.lr_schedule == 'plateau':
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, min_lr=5e-4, verbose=True, factor=b)
    else:  # constant lr
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs, gamma=1.0)

    return optimizer, lr_scheduler
