__author__ = "Anon"
__version__ = "0.1"

import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from models import inn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler
import random
import itertools
from torch.utils.data._utils.collate import default_collate

# Per-dataset channel statistics, keyed first by dataset name and then by the string '256'.
# Every leaf is a two-element list of 3-tuples; presumably (per-channel mean, per-channel std)
# used for input normalisation at image size 256 -- TODO confirm against the data-loading code.
# NOTE(review): the 'cifar10' / 'inn_cifar10' entries carry exactly the same numbers as the
# 'cifar100' / 'inn_cifar100' entries -- verify that this is intended.
NORM = {'cifar100': { '256': [(0.5071, 0.4865, 0.4409),(0.2673, 0.2564, 0.2762)]},
        'cifar10': {'256': [(0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)]},
        'stl10': {'256': [(0.4467107057571411, 0.4398099184036255, 0.4066465198993683), (0.26034098863601685, 0.2565772533416748, 0.2712674140930176)]},
        'cub200': {'256': [(0.48562151193618774, 0.49943259358406067, 0.43239057064056396), (0.22643816471099854, 0.22179797291755676, 0.2606353461742401)]},
        'pets': {'256': [(0.47828713059425354, 0.4458635747432709, 0.3956933319568634), (0.26269590854644775, 0.25722986459732056, 0.26532548666000366)]},
        'inn_cars': {'256': [(0.4708038866519928, 0.4602188467979431, 0.4550218880176544), (0.29070574045181274, 0.2897859811782837, 0.2982892692089081)]},
        'inn_stl10': {'256': [(0.4467107057571411, 0.4398099184036255, 0.4066465198993683), (0.26034098863601685, 0.2565772533416748, 0.2712674140930176)]},
        'inn_cifar100': { '256': [(0.5071, 0.4865, 0.4409),(0.2673, 0.2564, 0.2762)]},
        'inn_cifar10': {'256': [(0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)]},
        'inn_cub200':{'256': [(0.485, 0.499, 0.432), (0.226, 0.221, 0.260)]},
        'inn_cub20':{'256': [(0.485, 0.499, 0.432), (0.226, 0.221, 0.260)]},
        'cub20':{'256': [(0.485, 0.499, 0.432), (0.226, 0.221, 0.260)]},
        'inn_opets': {'256': [(0.478, 0.445, 0.395), (0.262, 0.257, 0.265)]},
        'opets': {'256': [(0.478, 0.445, 0.395), (0.262, 0.257, 0.265)]},
        'inn_tiny':{'256': [(0.480, 0.448, 0.397), (0.276, 0.268, 0.281)]},
        'inn_bmw10': {'256': [(0.45809218287467957, 0.4521467685699463, 0.45898690819740295), (0.2697070837020874, 0.2785671651363373, 0.2733687460422516)]},
        'bmw10': {'256': [(0.45809218287467957, 0.4521467685699463, 0.45898690819740295), (0.2697070837020874, 0.2785671651363373, 0.2733687460422516)]},
        }

def reload_dict(model, state_dict):
    """Load ``state_dict`` into ``model`` after removing any ``'module.'`` key prefix.

    Checkpoints saved from a model wrapped in ``torch.nn.DataParallel`` carry a
    ``'module.'`` prefix on every key. Only keys that actually start with that
    prefix are shortened; other keys are passed through untouched (previously
    the first seven characters were cut from every key unconditionally, which
    mangled unprefixed names).

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        state_dict: Mapping from parameter name to value.

    Returns:
        The same ``model`` object, after a non-strict load.
    """
    from collections import OrderedDict

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    # Non-strict: keys that still do not match the model are ignored.
    model.load_state_dict(new_state_dict, strict=False)
    return model


def _load_checkpoint(model, load_weights_from, strict=True, retry_without_prefix=False):
    """Load a checkpoint file into ``model``; does nothing when the path is None.

    Args:
        model: Network exposing ``load_state_dict``.
        load_weights_from: Path handed to ``torch.load`` or ``None``.
        strict: Forwarded to ``load_state_dict``.
        retry_without_prefix: When True, a ``RuntimeError`` raised by
            ``load_state_dict`` triggers a second attempt through ``reload_dict``.

    Returns:
        The model (possibly with freshly loaded weights).
    """
    if load_weights_from is None:
        return model
    # Map to CPU so that a checkpoint written on a GPU host also loads on a CPU-only host.
    state_dict = torch.load(load_weights_from, map_location=torch.device('cpu'))
    if not retry_without_prefix:
        model.load_state_dict(state_dict, strict=strict)
        return model
    try:
        model.load_state_dict(state_dict, strict=strict)
    except RuntimeError:
        # load_state_dict reports key / shape mismatches as RuntimeError.
        model = reload_dict(model, state_dict)
    return model


def get_net(net_name, num_classes, pretrained, load_weights_from=None, opts=None):
    """Build the network called ``net_name`` and optionally load weights into it.

    Args:
        net_name: Architecture name, matched case-insensitively.
        num_classes: Number of output classes passed to the constructor.
        pretrained: Forwarded to the constructors that accept it; must be falsy
            for ``'reduced_inn18'``.
        load_weights_from: Optional checkpoint path; ``None`` skips loading.
        opts: Options mapping. ``'Feat_dim'`` is read for the ``modified_*``
            architectures, ``'class_in_activation'`` for the reduced INN ones and
            ``'NUM_CLASSES'`` for ``'reduced_inn18'``.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: For an unknown ``net_name`` or when ``pretrained``
            is requested for ``'reduced_inn18'``.
    """
    arch = net_name.lower()

    if arch == 'reduced_resnet18':
        from models.small import resnet18
        model = resnet18(num_classes=num_classes)
        return _load_checkpoint(model, load_weights_from)
    elif arch == 'resnet18':
        from models.resnet import resnet18
        model = resnet18(num_classes=num_classes)
        return _load_checkpoint(model, load_weights_from)
    elif arch == 'resnet50':
        from models.resnet import resnet50
        model = resnet50(num_classes=num_classes)
        return _load_checkpoint(model, load_weights_from)
    elif arch == 'inn50':
        from models.inn import resnet50
        model = resnet50(pretrained=pretrained, num_classes=num_classes)
        # NOTE(review): with strict=False, key-name mismatches do not raise, so the
        # prefix-stripping retry only runs for errors strict=False still reports.
        return _load_checkpoint(model, load_weights_from, strict=False, retry_without_prefix=True)
    elif arch == 'inn18':
        from models.inn import resnet18
        model = resnet18(pretrained=pretrained, num_classes=num_classes)
        return _load_checkpoint(model, load_weights_from, strict=False, retry_without_prefix=True)
    elif arch == 'modified_inn18':
        from models.modified_inn import resnet18
        model = resnet18(pretrained=pretrained, num_classes=num_classes, feat_dim=opts['Feat_dim'])
        return _load_checkpoint(model, load_weights_from, retry_without_prefix=True)
    elif arch == 'modified_resnet18':
        from models.modified_resnet import resnet18
        model = resnet18(pretrained=pretrained, num_classes=num_classes, feat_dim=opts['Feat_dim'])
        return _load_checkpoint(model, load_weights_from, retry_without_prefix=True)
    elif arch in ('inn_vgg11', 'reduced_inn_vgg11'):
        from models.inn_vgg import VGG
        model = VGG(vgg_name='VGG11', num_classes=num_classes)
        return _load_checkpoint(model, load_weights_from)
    elif arch == 'vgg11':
        from models.vgg import VGG
        model = VGG(vgg_name='VGG11', num_classes=num_classes)
        return _load_checkpoint(model, load_weights_from)
    elif arch == 'reduced_inn18':
        from models.reduced.inn_resnet import resnet18
        if pretrained:
            raise NotImplementedError('Reduced models are not pre-trained')
        # NOTE(review): this branch takes the class count from opts['NUM_CLASSES'],
        # not from the num_classes argument -- kept as is; confirm it is intended.
        model = resnet18(class_in_activation=opts['class_in_activation'], num_classes=opts['NUM_CLASSES'])
        return _load_checkpoint(model, load_weights_from, strict=False, retry_without_prefix=True)
    elif arch == 'modified_reduced_inn18':
        from models.reduced.modified_inn_resnet import resnet18
        model = resnet18(num_classes=num_classes, class_in_activation=opts['class_in_activation'], feat_dim=opts['Feat_dim'])
        return _load_checkpoint(model, load_weights_from, retry_without_prefix=True)
    else:
        raise NotImplementedError('ARCH {} currently not supported!'.format(net_name))


class CustomTensorDataset(torch.utils.data.Dataset):
    """Dataset over a sequence of tensors, with an optional transform.

    The first tensor supplies the samples and the second one the targets; the
    transform, when given, is applied to the sample only.
    """

    def __init__(self, tensors, transform=None):
        # Every tensor must have the same length along dimension 0.
        assert all(t.size(0) == tensors[0].size(0) for t in tensors)
        self.transform = transform
        self.tensors = tensors

    def __len__(self):
        return self.tensors[0].size(0)

    def __getitem__(self, index):
        sample = self.tensors[0][index]
        if self.transform:
            sample = self.transform(sample)
        return sample, self.tensors[1][index]

class BinImageIntensities(object):
    """Quantise image intensities into equally sized bins.

    Args:
        num_bins: Number of bins the 256 intensity levels are divided into.
    """

    def __init__(self, num_bins=256):
        self.num_intensities = 256
        self.num_bins = num_bins

    def __call__(self, img):
        """Return a binned copy of ``img``.

        Args:
            img (PIL Image): Image to be binned.

        Returns:
            PIL Image: Image whose values were floored to a multiple of the bin size.
        """
        bin_size = int(self.num_intensities / self.num_bins)
        pixels = np.asarray(img, dtype=np.uint8)
        # Floor each value down to the lower edge of its bin.
        binned = np.floor(pixels / bin_size) * bin_size
        return Image.fromarray(binned.astype(np.uint8))

    def __repr__(self):
        return '{}(bin_range={})'.format(self.__class__.__name__, self.num_bins)


class RepeatChannels(object):
    """Tile a tensor three times along its first dimension.
    """

    def __init__(self):
        self.repeat_times = 3

    def __call__(self, img):
        """
        Args:
            img (tensor): 3-dimensional tensor to be tiled.

        Returns:
            tensor: ``img`` repeated ``repeat_times`` times along dimension 0.
        """
        repeats = (self.repeat_times, 1, 1)
        return img.repeat(*repeats)

    def __repr__(self):
        return '{}(repeat_times={})'.format(self.__class__.__name__, self.repeat_times)


class NegativeImages(object):
    """Create negative images by subtracting the input from the maximum intensity.
    """

    def __init__(self):
        self.max_intensity = 1

    def __call__(self, img):
        """
        Args:
            img (tensor): tensor to be negated.

        Returns:
            tensor: ``max_intensity - img``.
        """
        return self.max_intensity - img

    def __repr__(self):
        # Previously read a non-existent ``num_bins`` attribute and raised AttributeError.
        return self.__class__.__name__ + '(max_intensity={})'.format(self.max_intensity)


class RandomSwitchChannels(object):
    """Reorder the first three channels using a randomly chosen permutation.

    All six permutations of ``(0, 1, 2)`` are candidates, the identity included.
    """

    def __init__(self):
        self.orders = list(itertools.permutations([0, 1, 2]))

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor C, H, W. It is modified in place.

        Returns:
            Tensor: the same tensor object with its channels switched.
        """
        random_ordering = random.choice(self.orders)
        # Snapshot the input so every assignment reads the original channel values.
        tmp_tensor = img.clone()
        img[0] = tmp_tensor[random_ordering[0]]
        img[1] = tmp_tensor[random_ordering[1]]
        # NOTE(review): only the third channel is additionally inverted against 255;
        # kept unchanged to preserve existing outputs -- confirm this is intended.
        img[2] = 255 - tmp_tensor[random_ordering[2]]
        return img

    def __repr__(self):
        # Previously read a non-existent ``new_order`` attribute and raised AttributeError.
        return self.__class__.__name__ + '(orders={})'.format(self.orders)


class SwitchChannels(object):
    """Reorder the first three channels of a tensor.

    Args:
        new_order (list): The order to place the RGB channels ex: [1, 2, 0] = GBR
    """

    def __init__(self, new_order):
        self.new_order = new_order

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor C, H, W; left unmodified.

        Returns:
            Tensor: a copy of ``img`` with its first three channels rearranged.
        """
        switched = img.clone()
        for dst in range(3):
            switched[dst] = img[self.new_order[dst]]
        return switched

    def __repr__(self):
        return '{}(new order={})'.format(self.__class__.__name__, self.new_order)

def innet_collate(batch):
    """Collate samples and merge the per-sample axis into the batch axis.

    After default collation the batch must hold exactly four entries. The first
    is a 5-D tensor (batch, samples, c, h, w) flattened to (batch*samples, c, h, w);
    the second is flattened to (batch*samples, -1); the last two become 1-D
    tensors of length batch*samples (then squeezed).
    Presumably the four entries are the outputs of ``INNTransform`` -- TODO confirm.
    """
    collated = default_collate(batch)
    assert len(collated) == 4
    n_items, n_views, channels, height, width = collated[0].size()
    flat = n_items * n_views
    collated[0] = collated[0].view([flat, channels, height, width])
    collated[1] = collated[1].view([flat, -1])
    for idx in (2, 3):
        collated[idx] = collated[idx].view([flat]).squeeze()
    return collated

class INNTransform(object):
    """Expand one (image, label) pair into a positive and several negative pairs.

    Args:
        start_transform: Callable applied to the image first; its result must be a tensor.
        pos_conf_range: ``(low, high)`` bounds of the uniform confidence value
            written at the proposed class of each row.
        num_negatives: Number of wrong-class copies to create; ``None`` is
            treated as zero.
        neg_conf_factor: When given, the remaining entries of each label row are
            uniform noise in ``[0, 1/neg_conf_factor)``; otherwise they are zero.
    """

    def __init__(self, start_transform, pos_conf_range=(0.9, 1.0), num_negatives=None, neg_conf_factor=None):
        self.start_transform = start_transform
        self.pos_conf_range = pos_conf_range
        self.neg_conf_factor = neg_conf_factor
        self.num_negatives = num_negatives

    def __call__(self, img, y, num_classes):
        """Build the stacked samples for one image.

        Args:
            img: Raw image handed to ``start_transform``.
            y: True class index.
            num_classes: Total number of classes; must be at least
                ``num_negatives + 1``.

        Returns:
            tuple: ``(images, y_dist, bools, y_original)`` where ``images`` stacks
            ``num_negatives + 1`` copies of the transformed image, ``y_dist`` is
            the ``(num_negatives + 1, num_classes)`` label matrix, ``bools`` is 1
            for the true-class row and 0 for the others, and ``y_original``
            repeats ``y`` for every row.
        """
        img = self.start_transform(img)
        # Assume the img is now resized, cropped, flipped, to_tensor, normalised etc.
        num_negatives = 0 if self.num_negatives is None else self.num_negatives

        confidence = torch.empty(num_negatives + 1, dtype=torch.float32).uniform_(
            self.pos_conf_range[0], self.pos_conf_range[1])

        # Draw one spare candidate so that enough wrong classes remain if y is drawn.
        y_negs = random.sample(range(0, num_classes), num_negatives + 1)
        if y in y_negs:
            y_negs.remove(y)
        y_negs = y_negs[:num_negatives]

        imgs = [img] + [torch.clone(img) for _ in range(num_negatives)]
        y_in_tensor = torch.LongTensor([y] + y_negs)
        bools = torch.LongTensor([1] + [0] * num_negatives)
        y_original = torch.LongTensor([y] * (num_negatives + 1))

        label_shape = torch.Size((y_in_tensor.size(0), num_classes))
        if self.neg_conf_factor is not None:
            # Background noise lies in [0, 1/neg_conf_factor).
            y_dist = torch.rand(label_shape, device=img.device) / self.neg_conf_factor
        else:
            y_dist = torch.zeros(label_shape, device=img.device)
        # scatter_ needs index and source on the same device as the destination.
        y_dist.scatter_(1,
                        y_in_tensor.unsqueeze(1).to(y_dist.device),
                        confidence.unsqueeze(1).to(y_dist.device))

        return torch.stack(imgs, dim=0), y_dist, bools, y_original

class UnNormalize(object):
    """Undo a per-channel normalisation, i.e. compute ``t * std + mean`` in place."""

    def __init__(self, mean, std):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be un-normalized.
        Returns:
            Tensor: The same tensor object, modified in place.
        """
        stats = zip(self.mean, self.std)
        for channel, (mean, std) in zip(tensor, stats):
            # Inverse of the normalize step ``t.sub_(m).div_(s)``.
            channel.mul_(std).add_(mean)
        return tensor

class WarmUpLR(_LRScheduler):
    """Warm-up learning rate scheduler.

    Args:
        optimizer: optimizer (e.g. SGD)
        total_iters: total number of iterations in the warm-up phase
        last_epoch: index of the last step, forwarded to the base scheduler
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Set before delegating: the base initialiser may already query get_lr().
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * m / total_iters`` per parameter group, where ``m``
        is the number of steps taken so far (``last_epoch``).
        """
        # The small epsilon guards against a zero-length warm-up.
        denominator = self.total_iters + 1e-8
        rates = []
        for base_lr in self.base_lrs:
            rates.append(base_lr * self.last_epoch / denominator)
        return rates

class ExperimentSettings:
    """Holds an experiment's settings object and can write it to disk as JSON."""

    def __init__(self, settings):
        self.settings = settings

    def dump(self, path):
        """Write the settings to ``<path>/settings.json`` with 4-space indentation."""
        import json
        target = path + '/settings.json'
        with open(target, 'w') as handle:
            json.dump(self.settings, handle, indent=4)

def evaluate(net, dataloader, criterion):
    """Compute top-1 accuracy and mean per-batch loss of ``net`` on ``dataloader``.

    Batches are moved to the device of the network's first parameter, so the
    function also works on CPU-only hosts. A network without parameters falls
    back to CUDA when available (the former hard-coded behaviour), else CPU.

    Args:
        net: ``torch.nn.Module``; switched to eval mode.
        dataloader: Iterable of ``(x, y)`` batches exposing ``.dataset``; must
            yield at least one batch.
        criterion: Callable ``criterion(outputs, y)`` returning the batch loss.

    Returns:
        tuple: ``(accuracy, loss)`` tensors -- correct predictions divided by
        ``len(dataloader.dataset)`` and summed loss divided by ``len(dataloader)``.
    """
    net.eval()
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    correct = 0.0
    loss = 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            # Transfer the targets once and reuse them for both loss and accuracy.
            y = y.to(device)
            outputs = net(x)
            _, preds = outputs.max(1)
            loss += criterion(outputs, y)
            correct += preds.eq(y).sum()

    return correct.float() / len(dataloader.dataset), loss.float() / len(dataloader)


def show(tensors):
    """Display a batch of image tensors as one grid in a matplotlib window.

    Args:
        tensors: Images accepted by ``torchvision.utils.make_grid``.
    """
    img = make_grid(tensors)
    # detach + cpu so tensors that track gradients or live on a GPU can be converted.
    np_img = img.detach().cpu().numpy()
    # Move the channel axis last, as imshow expects height x width x channels.
    plt.imshow(np.transpose(np_img, (1, 2, 0)), interpolation='nearest')
    plt.show()


import numpy as np
from PIL import Image

class PermutePixels(object):
    """Shuffle pixel positions, using one random permutation for all three channels."""

    def __init__(self):
        pass

    def __call__(self, img):
        """
        Args:
            img (Tensor): Normalised tensor C, H, W

        Returns:
            Tensor: copy of ``img`` whose first three channels have permuted pixels
        """
        height, width = img.shape[1], img.shape[2]
        rand_perm = torch.randperm(height * width).view(1, height, width)
        # Apply the identical permutation to each of the first three channels.
        shuffled = [img[c].view(-1)[rand_perm].view(height, width) for c in range(3)]
        p_img = img.clone()
        for c, plane in enumerate(shuffled):
            p_img[c] = plane
        return p_img

    def __repr__(self):
        return '{}{}'.format(self.__class__.__name__, 'Rand perm!')


def accuracy_cls(model, input, target, num_classes=1000):
    """Count the top-1 correct predictions of ``model`` on one batch.

    The model is first called as ``model(input, y_in)`` with a random ``y_in`` of
    shape ``(batch, num_classes)`` and is expected to return a 3-tuple whose
    middle element holds the class scores. If that call or the unpacking fails,
    the model is called as ``model(input)`` instead.

    Args:
        model: Callable network.
        input: Batch tensor; ``y_in`` is created on the same device.
        target: Class indices on the same device as the model output.
        num_classes: Width of the random ``y_in``.

    Returns:
        0-dim tensor with the number of correct predictions.
    """
    with torch.no_grad():
        batch_size = input.size(0)
        # Follow the input's device instead of forcing CUDA.
        y_in = torch.rand([batch_size, num_classes]).to(input.device)
        try:
            _, output, _ = model(input, y_in)
        except Exception:
            # Plain classifiers take a single argument / return a single tensor.
            output = model(input)
        _, preds = output.max(1)
        correct = preds.eq(target).sum()
    return correct


def accuracy_inn(model, inputs, target, y_dist, num_classes=1000):
    """Count top-1 correct predictions by querying the model once per class.

    For class ``i`` the row ``y_dist[i]`` is repeated for the whole batch and fed
    to ``model(inputs, y_in)``; column 1 of the output is taken as the score of
    class ``i``. The predicted class is the one with the highest score.

    All work happens on ``inputs.device`` (no hard-coded CUDA transfer and no
    CPU round trip per class).

    Args:
        model: Callable ``model(inputs, y_in)`` returning a ``(batch, >=2)`` tensor.
        inputs: Batch tensor.
        target: True class indices, one per batch element.
        y_dist: Tensor whose row ``i`` is the class-conditioning vector for class ``i``.
        num_classes: Number of classes to query.

    Returns:
        0-dim tensor with the number of correct predictions.
    """
    with torch.no_grad():
        device = inputs.device
        batch_size = inputs.size(0)
        pred_matrix = torch.zeros([batch_size, num_classes], device=device)
        for i in range(num_classes):
            y_in = y_dist[i].repeat(batch_size).view(batch_size, -1).to(device)
            output = model(inputs, y_in)
            pred_matrix[:, i] = output[:, 1]
        _, pred = pred_matrix.max(1)
        correct = pred.eq(target.to(device)).sum()
        return correct


def accuracy_inn_v2(model, input, target, y_dist, num_classes=1000):
    """Count top-1 correct predictions with a single batched model call.

    Every input is repeated ``num_classes`` times and paired with each row of
    ``y_dist``; column 1 of the model output is reshaped to
    ``(batch, num_classes)`` and the arg-max per row is the predicted class.

    Tensors are moved to the device of the model's first parameter (falling back
    to ``input.device`` for a parameter-free model) instead of always to CUDA.

    Args:
        model: ``torch.nn.Module`` called as ``model(inputs, y_in)``; put in eval mode.
        input: Batch tensor.
        target: True class indices, one per batch element.
        y_dist: ``(num_classes, k)`` tensor of class-conditioning rows.
        num_classes: Number of classes (rows of ``y_dist``).

    Returns:
        0-dim tensor with the number of correct predictions.
    """
    model.eval()
    with torch.no_grad():
        first_param = next(model.parameters(), None)
        device = input.device if first_param is None else first_param.device
        batch_size = input.size(0)
        # Sample j occupies rows j*num_classes .. (j+1)*num_classes - 1 ...
        en_input = input.repeat_interleave(num_classes, dim=0).to(device)
        # ... and is paired with conditioning rows 0 .. num_classes - 1 in order.
        y_in = y_dist.repeat([batch_size, 1]).to(device)
        output = model(en_input, y_in)
        scores = output[:, 1].view([batch_size, num_classes])
        _, preds = scores.max(1)
        correct = preds.eq(target.to(preds.device)).sum()

    return correct


def validate(model, val_loader, opts):
    """Compute top-1 validation accuracy for the INN or the classification head.

    When ``opts['weight']['cls'] == 0`` the INN accuracy is measured
    (``accuracy_inn`` for more than 25 classes, ``accuracy_inn_v2`` otherwise);
    otherwise the classifier accuracy is measured with ``accuracy_cls``.

    Data is moved to the device of the model's first parameter. A model without
    parameters falls back to CUDA when available (the former hard-coded
    behaviour), else CPU.

    Args:
        model: ``torch.nn.Module``; put in eval mode.
        val_loader: Iterable of ``(images, ys)`` batches exposing ``.dataset``;
            must yield at least one batch.
        opts: Mapping with ``'NUM_CLASSES'``, ``'weight'`` (containing ``'cls'``)
            and optionally ``'neg_red_factor'``.

    Returns:
        tuple: ``(inn_top1, inn_top5, cls_top1, cls_top5, total_inn_matrix,
        total_cls_matrix)``. The top-5 values are never accumulated (always 0)
        and both matrices are always ``None``.
    """
    inn_top1, inn_top5, cls_top1, cls_top5 = 0, 0, 0, 0
    # switch to evaluate mode
    model.eval()
    total_inn_matrix = None
    total_cls_matrix = None
    num_classes = opts['NUM_CLASSES']

    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        label_shape = torch.Size((num_classes, num_classes))
        if 'neg_red_factor' in opts and opts['neg_red_factor'] is not None:
            # ensuring all values are in [0, 1/neg_red_factor)
            y_dist = torch.rand(label_shape).to(device) / opts['neg_red_factor']
        else:
            y_dist = torch.zeros(label_shape, device=device)
        # Row i is the conditioning vector of class i: write 1 on the diagonal.
        indicator = torch.arange(num_classes, device=device).unsqueeze(1)
        y_dist.scatter_(1, indicator, torch.ones(num_classes, 1, dtype=y_dist.dtype, device=device))

        for images, ys in val_loader:
            images = images.to(device, non_blocking=True)
            ys = ys.to(device, non_blocking=True)
            # measure accuracy
            if opts['weight']['cls'] == 0:
                if num_classes > 25:
                    inn_top1 += accuracy_inn(model, images, ys, y_dist, num_classes=num_classes)
                else:
                    inn_top1 += accuracy_inn_v2(model, images, ys, y_dist, num_classes=num_classes)
                cls_top1 += torch.tensor(0)
            else:
                cls_top1 += accuracy_cls(model, images, ys, num_classes=num_classes)
                inn_top1 += torch.tensor(0)

    inn_top1 = inn_top1.float() / len(val_loader.dataset)
    inn_top5 /= len(val_loader.dataset)
    cls_top1 = cls_top1.float() / len(val_loader.dataset)
    cls_top5 /= len(val_loader.dataset)
    return inn_top1, inn_top5, cls_top1, cls_top5, total_inn_matrix, total_cls_matrix
