import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.autograd import Variable

from operation import apply_augment
from networks import get_model
from utils import PolicyHistory
from torchvision.utils import make_grid


def perturb_param(param, delta):
    """Randomly nudge ``param`` up or down and keep the result inside [0, 1].

    Args:
        param: Current value to perturb.
        delta: Upper bound of the perturbation size; the actual step is
            drawn uniformly from ``[0, delta]``.

    Returns:
        ``param`` moved by the drawn step, in a direction chosen with equal
        probability, limited below by 0 (downward move) or above by 1
        (upward move).
    """
    # The step is drawn before the direction; keep this order so that seeded
    # runs stay reproducible.
    step = random.uniform(0, delta)
    move_down = random.random() < 0.5
    if move_down:
        return max(0, param - step)
    return min(1, param + step)


class MixedAugment(nn.Module):
    """Apply the augmentation operations named in ``sub_policies`` to image batches.

    Two families of methods are provided:

    * ``*_mixed_*`` methods apply every operation to every image and blend the
      results (or their features) with softmax weights.  Magnitudes are
      attached through :meth:`stop_gradient`, so gradients of the blend reach
      both the weights and the magnitudes.
    * ``*_aug_*`` methods pick a few operations per image (sampled or top-k,
      see ``sampling``) and chain them.

    Args:
        sub_policies: Sequence of operation names passed to ``apply_augment``.
        after_transforms: Callable applied to each augmented PIL image; its
            outputs are combined with ``torch.stack``, so it must return tensors.
        n_class: Number of classes tracked by the policy history.
        k_ops: Number of operations chained per image by the ``*_aug_*`` methods.
        sampling: ``'prob'`` (``torch.multinomial``) or ``'max'`` (``torch.topk``).
        temp: Temperature dividing the predicted operation logits before softmax.
        save_dir: Forwarded to ``PolicyHistory``.
        infer_n_class: Replaces ``n_class`` for the policy history when given.
        resize: When true, images are interpolated to size 32 before feature
            extraction in the projection methods.
    """

    def __init__(self, sub_policies, after_transforms, n_class, k_ops=1, sampling='prob', temp=1.0, save_dir=None, infer_n_class=None, resize=False):
        super(MixedAugment, self).__init__()
        # list of operation names, e.g. [shearing, contrast, flipping, ...]
        self.ops_names = sub_policies
        self.after_transforms = after_transforms
        self.k_ops = k_ops
        self.sampling = sampling
        self.temp = temp
        self.vis_class = n_class if infer_n_class is None else infer_n_class
        self.resize = resize

        self.history = PolicyHistory(sub_policies, save_dir, self.vis_class)

    # ------------------------------------------------------------------
    # Policy history
    # ------------------------------------------------------------------
    def save_history(self, class2label=None):
        """Persist the recorded policy history (delegates to ``PolicyHistory.save``)."""
        self.history.save(class2label)

    def plot_history(self):
        """Return whatever ``PolicyHistory.plot`` produces for the recorded history."""
        return self.history.plot()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_device(tensor, use_cuda):
        """Move ``tensor`` to the GPU only when ``use_cuda`` is set."""
        return tensor.cuda() if use_cuda else tensor

    def _check_sampling(self):
        """Raise ``ValueError`` when ``self.sampling`` is not a supported strategy."""
        if self.sampling not in ('prob', 'max'):
            raise ValueError(
                "sampling must be 'prob' or 'max', got {!r}".format(self.sampling))

    @staticmethod
    def _class_stats(values, targets, k):
        """Return per-operation (mean, std) lists of ``values`` over samples of class ``k``.

        ``squeeze(1)`` (rather than a bare ``squeeze()``) keeps the index
        one-dimensional when exactly one sample matches, so the reduction is
        always taken over samples and never over operations.
        """
        rows = values[(targets == k).nonzero().squeeze(1)]
        mean = rows.mean(0).detach().cpu().tolist()
        std = rows.std(0).detach().cpu().tolist()
        return mean, std

    def _record_history(self, magnitudes, weights, targets):
        """Add per-class statistics of the current batch to the policy history."""
        for k in range(self.vis_class):
            mean_lambda, std_lambda = self._class_stats(magnitudes, targets, k)
            mean_p, std_p = self._class_stats(weights, targets, k)
            self.history.add(k, mean_lambda, mean_p, std_lambda, std_p)

    def _augment_with_all_ops(self, image, magnitudes, use_cuda):
        """Apply every operation to one image.

        Args:
            image: Single image tensor accepted by ``transforms.ToPILImage``.
            magnitudes: One magnitude per operation, indexed as ``magnitudes[k]``.
            use_cuda: Whether the transformed tensors are moved to the GPU.

        Returns:
            List with one straight-through tensor per operation, in the order
            of ``self.ops_names``.
        """
        pil_img = transforms.ToPILImage()(image)
        views = []
        for k, ops_name in enumerate(self.ops_names):
            trans_image = apply_augment(pil_img, ops_name, magnitudes[k])
            trans_image = self.after_transforms(trans_image)
            trans_image = self._to_device(trans_image, use_cuda)
            views.append(self.stop_gradient(trans_image, magnitudes[k]))
        return views

    def _predict_policy(self, images, model, projection, use_cuda):
        """Predict per-image magnitudes and operation probabilities.

        The projection output is split along dim 1 into two halves of
        ``len(self.ops_names)`` columns: the first goes through a sigmoid
        (magnitudes), the second through a temperature softmax (weights).
        """
        resize_imgs = F.interpolate(images, size=32) if self.resize else images
        features = model.extract_feature(self._to_device(resize_imgs, use_cuda))
        outputs = projection(features)

        magnitudes, weights = torch.split(outputs, len(self.ops_names), dim=1)
        magnitudes = torch.sigmoid(magnitudes)
        weights = torch.nn.functional.softmax(weights / self.temp, dim=-1)
        return magnitudes, weights

    def _split_cnn_outputs(self, outputs):
        """Split augmentation-CNN outputs into raw (magnitudes, weights) halves.

        The output columns are interpreted like the projection head's: the
        first ``len(self.ops_names)`` columns are magnitudes, the rest weights.
        """
        return torch.split(outputs, len(self.ops_names), dim=1)

    # ------------------------------------------------------------------
    # Straight-through estimator
    # ------------------------------------------------------------------
    def stop_gradient(self, trans_image, magnitude):
        """Return ``trans_image`` values with a gradient path to ``magnitude``.

        The result equals ``trans_image`` numerically (up to floating point
        rounding) but its only differentiable dependency is ``magnitude``,
        with unit gradient per element.

        Args:
            trans_image: Transformed image tensor.
            magnitude: Magnitude used to produce ``trans_image``.

        Returns:
            Tensor shaped like ``trans_image``.
        """
        return (trans_image - magnitude).detach() + magnitude

    # ------------------------------------------------------------------
    # Mixed images
    # ------------------------------------------------------------------
    def get_ada_mixed_images(self, images, magnitudes, weights, use_cuda):
        """Blend all augmented views of each image with per-image weights.

        Args:
            images: Iterable of image tensors.
            magnitudes: Per-image, per-operation magnitudes (``magnitudes[i][k]``).
            weights: Per-image, per-operation scores; softmax is applied over
                the last dim here.
            use_cuda: Whether transformed images are moved to the GPU.

        Returns:
            Tensor of blended images, one per input image.
        """
        per_image = [
            torch.stack(self._augment_with_all_ops(image, magnitudes[i], use_cuda), dim=0)
            for i, image in enumerate(images)
        ]
        # (batch, ops, k, l, m) as required by the einsum below
        all_trans_image = torch.stack(per_image, dim=0)

        sm_weights = torch.nn.functional.softmax(weights, dim=-1)
        # Weighted sum over the operation axis; the result already has the
        # (batch, k, l, m) layout, so no reshape is needed.
        return torch.einsum("ji,jiklm->jklm", sm_weights, all_trans_image)

    def get_vec_mixed_images(self, images, magnitudes, weights, use_cuda):
        """Blend all augmented views of each image with one shared weight vector.

        Args:
            images: Iterable of image tensors.
            magnitudes: One magnitude per operation, shared by all images.
                Clamped to [0, 1] before use, as in the other vector-mode
                methods.
            weights: One score per operation; softmax is applied here.
            use_cuda: Whether transformed images are moved to the GPU.

        Returns:
            Tensor of blended images, one per input image.
        """
        adj_magnitudes = magnitudes.clamp(0, 1)
        # Loop-invariant, so computed once for the whole batch.
        sm_weights = torch.nn.functional.softmax(weights, dim=-1)

        sum_images = []
        for image in images:
            views = self._augment_with_all_ops(image, adj_magnitudes, use_cuda)
            sum_images.append(sum(w * view for w, view in zip(sm_weights, views)))

        return torch.stack(sum_images, dim=0)

    # ------------------------------------------------------------------
    # Mixed features
    # ------------------------------------------------------------------
    def get_vec_mixed_features(self, images, magnitudes, weights, model, use_cuda):
        """Return the mixed latent feature using one shared policy vector.

        Args:
            images: Iterable of image tensors.
            magnitudes: One magnitude per operation; clamped to [0, 1].
            weights: One score per operation; softmax is applied here.
            model: Feature extraction model exposing ``extract_feature``.
            use_cuda: Whether transformed images are moved to the GPU.

        Returns:
            Tensor: Mixed latent feature, one row per input image.
        """
        adj_magnitudes = magnitudes.clamp(0, 1)
        trans_image_list = []
        for image in images:
            trans_image_list.extend(
                self._augment_with_all_ops(image, adj_magnitudes, use_cuda))

        all_trans_image = torch.stack(trans_image_list, dim=0)
        all_features = model.extract_feature(all_trans_image)
        batched_features = all_features.reshape(len(images), len(self.ops_names), -1)
        sm_weights = torch.nn.functional.softmax(weights, dim=-1)
        return sm_weights.matmul(batched_features)

    def get_ada_mixed_features(self, images, magnitudes, weights, model, use_cuda):
        """Return the mixed latent feature using per-image policies.

        Args:
            images: Iterable of image tensors.
            magnitudes: Per-image, per-operation magnitudes (``magnitudes[i][k]``).
            weights: Per-image, per-operation scores; softmax is applied here.
            model: Feature extraction model exposing ``extract_feature``.
            use_cuda: Whether transformed images are moved to the GPU.

        Returns:
            Tensor: Mixed latent feature, one row per input image.
        """
        trans_image_list = []
        for i, image in enumerate(images):
            trans_image_list.extend(
                self._augment_with_all_ops(image, magnitudes[i], use_cuda))

        all_trans_image = torch.stack(trans_image_list, dim=0)
        all_features = model.extract_feature(all_trans_image)
        batched_features = all_features.reshape(len(images), len(self.ops_names), -1)
        sm_weights = torch.nn.functional.softmax(weights, dim=-1)
        # Per-image weighted sum over the operation axis.
        return torch.einsum("bk,bkd->bd", sm_weights, batched_features)

    def get_proj_mixed_features(self, images, model, projection, use_cuda):
        """Predict a policy with ``projection`` and return the mixed features.

        Magnitudes are squashed with a sigmoid; the raw weight logits are
        handed to :meth:`get_ada_mixed_features`, which applies the softmax.
        """
        outputs = projection(model.extract_feature(self._to_device(images, use_cuda)))

        magnitudes, weights = torch.split(outputs, len(self.ops_names), dim=1)
        magnitudes = torch.sigmoid(magnitudes)

        return self.get_ada_mixed_features(images, magnitudes, weights, model, use_cuda)

    def get_cnn_mixed_features(self, images, model, aug_cnn, use_cuda):
        """Predict a policy with ``aug_cnn`` and return the mixed features.

        The CNN output is split column-wise into magnitudes and weights, the
        same layout the projection head uses.
        """
        outputs = aug_cnn(self._to_device(images, use_cuda))
        magnitudes, weights = self._split_cnn_outputs(outputs)
        # NOTE(review): assumes aug_cnn emits unbounded logits like the
        # projection head, hence the sigmoid - confirm against get_model.
        magnitudes = torch.sigmoid(magnitudes)

        return self.get_ada_mixed_features(images, magnitudes, weights, model, use_cuda)

    # ------------------------------------------------------------------
    # Sampled augmentation
    # ------------------------------------------------------------------
    def get_vec_aug_images(self, images, magnitudes, weights, use_cuda):
        """Augment each image with ``k_ops`` operations chosen from one shared policy.

        Args:
            images: Iterable of image tensors.
            magnitudes: One magnitude per operation; clamped to [0, 1].
            weights: One score per operation, used directly by
                ``torch.multinomial`` / ``torch.topk``.
            use_cuda: Whether the resulting batch is moved to the GPU.

        Returns:
            Tensor batch of augmented images.

        Raises:
            ValueError: If ``self.sampling`` is neither ``'prob'`` nor ``'max'``.
        """
        self._check_sampling()
        adj_magnitudes = magnitudes.clamp(0, 1)
        to_pil = transforms.ToPILImage()

        trans_images = []
        for image in images:
            trans_image = to_pil(image)
            # Operations are (re)selected for every image.
            if self.sampling == 'prob':
                idxs = torch.multinomial(weights, self.k_ops)
            else:
                idxs = torch.topk(weights, self.k_ops)[1]
            for i in idxs:
                trans_image = apply_augment(
                    trans_image, self.ops_names[i], adj_magnitudes[i])
            trans_images.append(self.after_transforms(trans_image))

        return self._to_device(torch.stack(trans_images, dim=0), use_cuda)

    def get_ada_aug_images(self, images, magnitudes, weights, use_cuda, targets=None, writer=None, step=None, delta=0.3, kops=3):
        """Augment each image with ``kops`` operations chosen from its own policy.

        Args:
            images: Iterable of image tensors.
            magnitudes: Per-image, per-operation magnitudes.
            weights: Per-image, per-operation scores used for selection.
            use_cuda: Whether the resulting batch is moved to the GPU.
            targets: Optional class labels; when given (and ``kops > 0``),
                per-class policy statistics are added to the history.
            writer: Unused; kept for interface compatibility.
            step: Unused; kept for interface compatibility.
            delta: Maximum random perturbation applied to each magnitude
                (see ``perturb_param``).
            kops: Number of operations chained per image; with ``kops <= 0``
                only ``after_transforms`` is applied.

        Returns:
            Tensor batch of augmented images.

        Raises:
            ValueError: If ``kops > 0`` and ``self.sampling`` is unsupported.
        """
        idx_matrix = None
        if kops > 0:
            self._check_sampling()
            if targets is not None:
                self._record_history(magnitudes, weights, targets)

            if self.sampling == 'prob':
                idx_matrix = torch.multinomial(weights, kops)
            else:
                idx_matrix = torch.topk(weights, kops, dim=1)[1]

        to_pil = transforms.ToPILImage()
        trans_images = []
        for i, image in enumerate(images):
            trans_image = to_pil(image)
            if idx_matrix is not None:
                for idx in idx_matrix[i]:
                    trans_image = apply_augment(
                        trans_image, self.ops_names[idx],
                        perturb_param(magnitudes[i][idx], delta))
            trans_images.append(self.after_transforms(trans_image))

        return self._to_device(torch.stack(trans_images, dim=0), use_cuda)

    def get_proj_aug_images(self, images, model, projection, use_cuda, targets, writer=None, step=None):
        """Predict a policy with ``projection`` and return sampled augmentations."""
        magnitudes, weights = self._predict_policy(images, model, projection, use_cuda)
        return self.get_ada_aug_images(
            images, magnitudes, weights, use_cuda, targets,
            writer=writer, step=step, kops=self.k_ops)

    def get_proj_mix_images(self, images, model, projection, use_cuda, targets, writer=None, step=None):
        """Predict a policy with ``projection`` and return blended images.

        ``targets``, ``writer`` and ``step`` are accepted for signature parity
        with :meth:`get_proj_aug_images` and are not used.
        """
        magnitudes, weights = self._predict_policy(images, model, projection, use_cuda)
        # NOTE(review): ``weights`` are already soft-maxed here and
        # get_ada_mixed_images applies softmax again.  Behaviour kept as in
        # the original - confirm whether the second softmax is intended.
        return self.get_ada_mixed_images(images, magnitudes, weights, use_cuda)

    def get_cnn_aug_images(self, images, aug_cnn, use_cuda, targets):
        """Predict a policy with ``aug_cnn`` and return sampled augmentations.

        The CNN output is split column-wise into magnitudes and weights and
        normalised exactly like the projection head's output, so that the
        weights are valid probabilities for ``torch.multinomial``.
        """
        outputs = aug_cnn(self._to_device(images, use_cuda))
        magnitudes, weights = self._split_cnn_outputs(outputs)
        # NOTE(review): assumes aug_cnn emits unbounded logits - confirm.
        magnitudes = torch.sigmoid(magnitudes)
        weights = torch.nn.functional.softmax(weights / self.temp, dim=-1)
        return self.get_ada_aug_images(
            images, magnitudes, weights, use_cuda, targets, kops=self.k_ops)


class Network(nn.Module):
    """Classifier with an optional, learnable augmentation agent.

    The wrapped classifier is built by :meth:`create_model`.  After
    :meth:`add_augment_agent` has been called, :meth:`forward` either runs the
    differentiable policy search (``search`` true) or trains on sampled
    augmentations (``search`` false).  Without an agent, or after
    ``set_augmenting(False)``, it simply classifies the input.

    Args:
        model_name: Name forwarded to ``get_model``.
        num_classes: Number of output classes.
        n_channel: Forwarded to ``get_model``.
        use_cuda: Whether tensors/modules are placed on the GPU.
        use_parallel: Forwarded to ``get_model``.
        temperature: Initial temperature, stored as a tensor.
        criterion: Loss used by :meth:`_loss`.
        latent: When searching, mix in feature space (true) or image space.
        search: Whether ``forward`` runs the policy search.
        writer: Optional logger exposing ``add_image``.
    """

    def __init__(
            self,
            model_name,
            num_classes,
            n_channel,
            use_cuda,
            use_parallel,
            temperature,
            criterion,
            latent=False,
            search=False,
            writer=None):
        super(Network, self).__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.n_channel = n_channel

        self.use_cuda = use_cuda
        self.use_parallel = use_parallel
        self.temperature = torch.tensor(temperature).cuda() if use_cuda else torch.tensor(temperature)
        self._criterion = criterion
        self.latent = latent
        self.search = search
        self.writer = writer
        # No agent exists until add_augment_agent() runs, so forward() must
        # be able to fall back to plain classification.
        self.add_aug = False

        self.model = self.create_model()

    def _device(self):
        """Return the device implied by ``self.use_cuda``."""
        return torch.device('cuda' if self.use_cuda else 'cpu')

    def add_augment_agent(
                        self,
                        sub_policies,
                        after_transforms,
                        aug_mode,
                        search=True,
                        ops_weights=None,
                        magnitudes=None,
                        projection=None,
                        aug_cnn=None,
                        k_ops=1,
                        sampling='prob',
                        temperature=1.0,
                        save_dir=None,
                        infer_n_class=None,
                        resize=False,
                        n_proj_layer=0):
        """Attach the augmentation policy and enable augmentation.

        Args:
            sub_policies: Non-empty sequence of operation names.
            after_transforms: Callable applied after the augmentation ops.
            aug_mode: ``'vector'`` (one learnable magnitude/weight vector),
                ``'projection'`` (MLP on classifier features) or ``'cnn'``
                (separate network built by ``get_model``).
            search: In ``'vector'`` mode, false means ``ops_weights`` and
                ``magnitudes`` are loaded instead of learned.  In ``'cnn'``
                mode only ``search=True`` is implemented.
            ops_weights: Weights to load (``'vector'`` mode, ``search=False``).
            magnitudes: Magnitudes to load (``'vector'`` mode, ``search=False``).
            projection: Unused; kept for interface compatibility.
            aug_cnn: Unused; kept for interface compatibility.
            k_ops, sampling, temperature, save_dir, infer_n_class, resize:
                Forwarded to ``MixedAugment``.
            n_proj_layer: Number of extra 128-unit hidden layers in the
                projection head (0 gives a single linear layer).

        Raises:
            NotImplementedError: ``aug_mode == 'cnn'`` with ``search=False``.
        """
        assert len(sub_policies) > 0  # at least one augmentation to use augment agent
        self.after_transforms = after_transforms
        self.sub_policies = sub_policies
        self.aug_mode = aug_mode
        self.n_proj_layer = n_proj_layer

        n_outputs = 2 * len(sub_policies)  # one weight and one magnitude per operation

        if aug_mode == 'cnn':
            if not search:
                # TODO load cnn parameters
                raise NotImplementedError(
                    "loading a trained augmentation CNN is not implemented")
            self.aug_cnn = get_model(
                            self.model_name,
                            n_outputs,
                            self.n_channel,
                            self.use_cuda,
                            self.use_parallel)
            # Materialise the generator: augment_parameters() may be consumed
            # more than once.
            self._augment_parameters = list(self.aug_cnn.parameters())
        elif aug_mode == 'projection':
            in_features = self.model.fc.in_features
            if self.n_proj_layer > 0:
                layers = [nn.Linear(in_features, 128), nn.ReLU()]
                for _ in range(self.n_proj_layer):
                    layers.append(nn.Linear(128, 128))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(128, n_outputs))
            else:
                layers = [nn.Linear(in_features, n_outputs)]

            self.projection = nn.Sequential(*layers)
            if self.use_cuda:
                self.projection = self.projection.cuda()
            self._augment_parameters = list(self.projection.parameters())
        elif aug_mode == 'vector':
            self._initialize_augment_parameters()
            if not search:  # load weights and magnitude vector
                assert ops_weights is not None and magnitudes is not None
                assert len(sub_policies) == len(ops_weights) and len(sub_policies) == len(magnitudes)
                self.update_augment_parameters(ops_weights, magnitudes)

        self.mix_augment = MixedAugment(sub_policies, after_transforms, self.num_classes,
                                        k_ops, sampling, temperature, save_dir=save_dir,
                                        infer_n_class=infer_n_class, resize=resize)
        # Enabled last so a failure above never leaves a half-built agent active.
        self.add_aug = True

    def update_augment_parameters(self, ops_weights, magnitudes):
        """Replace the policy vectors with fixed (non-trainable) values.

        Args:
            ops_weights: Sequence/array/tensor with one weight per operation.
            magnitudes: Sequence/array/tensor with one magnitude per operation.
        """
        device = self._device()

        def _frozen(values):
            # Copy so later edits of the caller's data cannot leak in.
            return torch.as_tensor(values).detach().clone().to(device)

        self.magnitudes = _frozen(magnitudes)
        self.ops_weights = _frozen(ops_weights)
        if getattr(self, 'aug_mode', None) == 'vector':
            # Keep augment_parameters() pointing at the live tensors.
            self._augment_parameters = [self.magnitudes, self.ops_weights]

    def _initialize_augment_parameters(self):
        """Create trainable policy vectors: magnitudes at 0.5, weights at 1e-3."""
        num_ops = len(self.sub_policies)  # TODO here sub policy is 1D array
        # TODO later need to initialize a larger Variable tensor for more variables
        device = self._device()
        self.magnitudes = (0.5 * torch.ones(num_ops, device=device)).requires_grad_()
        self.ops_weights = (1e-3 * torch.ones(num_ops, device=device)).requires_grad_()

        self._augment_parameters = [
            self.magnitudes,
            self.ops_weights
        ]

    def set_augmenting(self, value):
        """Enable or disable augmentation in :meth:`forward`."""
        assert value in [False, True]
        self.add_aug = value

    def set_search(self, value):
        """Switch :meth:`forward` between policy search and augmented training."""
        assert value in [False, True]
        self.search = value

    def create_model(self):
        """Build the classifier through ``get_model``."""
        return get_model(
            self.model_name,
            self.num_classes,
            self.n_channel,
            self.use_cuda,
            self.use_parallel)

    def new(self):
        """Return a fresh, search-mode network with the same configuration.

        Parameters are newly initialised (nothing is copied); the augmentation
        agent is rebuilt with the same mode and settings so that its
        parameters line up with this network's.
        """
        network_new = Network(
            self.model_name,
            self.num_classes,
            self.n_channel,
            self.use_cuda,
            self.use_parallel,
            self.temperature.detach().item(),
            self._criterion,
            self.latent,
            True)

        network_new.add_augment_agent(
            self.sub_policies,
            self.after_transforms,
            self.aug_mode,
            search=True,
            k_ops=self.mix_augment.k_ops,
            sampling=self.mix_augment.sampling,
            temperature=self.mix_augment.temp,
            infer_n_class=self.mix_augment.vis_class,
            resize=self.mix_augment.resize,
            n_proj_layer=self.n_proj_layer)

        return network_new

    def update_temperature(self, value):
        """Set the stored temperature to ``value`` in place."""
        # fill_ assigns the value exactly; computing t - (t - value) would
        # accumulate floating point rounding.
        self.temperature.data.fill_(value)

    def augment_parameters(self):
        """Return the learnable parameters of the augmentation agent."""
        return self._augment_parameters

    def genotype(self, valid_queue=None):
        """Describe the learned vector policy.

        Args:
            valid_queue: Unused; kept for interface compatibility.

        Returns:
            In ``'vector'`` mode, a list of ``(operation, magnitude, weight)``
            tuples with magnitudes clamped to [0, 1] and weights soft-maxed;
            ``None`` in every other mode.
        """
        if self.aug_mode != 'vector':
            return None
        magnitudes = self.magnitudes.clamp(0, 1).detach().tolist()
        ops_weights = torch.nn.functional.softmax(self.ops_weights, dim=-1).detach().tolist()
        return [
            (sub_policy, magnitudes[idx], ops_weights[idx])
            for idx, sub_policy in enumerate(self.sub_policies)
        ]

    def foward_latent_search(self, images):
        """Classify the feature-space mixture of all augmented views.

        Raises:
            ValueError: If ``self.aug_mode`` is not a supported mode.
        """
        if self.aug_mode == 'vector':
            mix_feat = self.mix_augment.get_vec_mixed_features(
                images, self.magnitudes, self.ops_weights, use_cuda=self.use_cuda, model=self.model)
        elif self.aug_mode == 'projection':
            mix_feat = self.mix_augment.get_proj_mixed_features(
                images, use_cuda=self.use_cuda, model=self.model, projection=self.projection)
        elif self.aug_mode == 'cnn':
            mix_feat = self.mix_augment.get_cnn_mixed_features(
                images, use_cuda=self.use_cuda, model=self.model, aug_cnn=self.aug_cnn)
        else:
            raise ValueError(
                "unsupported aug_mode for latent search: {!r}".format(self.aug_mode))
        return self.model.classify(mix_feat)

    def forward_input_search(self, images):
        """Classify the image-space mixture of all augmented views.

        Only the ``'projection'`` and ``'vector'`` modes are supported.

        Raises:
            ValueError: For any other ``self.aug_mode``.
        """
        if self.aug_mode == 'projection':
            mix_image = self.mix_augment.get_proj_mix_images(
                images, model=self.model, projection=self.projection,
                use_cuda=self.use_cuda, targets=None)
        elif self.aug_mode == 'vector':
            # get_vec_mixed_images takes no ``model`` argument.
            mix_image = self.mix_augment.get_vec_mixed_images(
                images, self.magnitudes, self.ops_weights, use_cuda=self.use_cuda)
        else:
            raise ValueError(
                "unsupported aug_mode for input search: {!r}".format(self.aug_mode))
        return self.model(mix_image)

    def get_aug_images(self, images, targets, step=None):
        """Return a batch augmented with the current policy.

        Args:
            images: Batch of input images.
            targets: Class labels, used for the policy history.
            step: When given and a writer is configured, an image grid of the
                augmented batch is logged under ``'train_aug'``.

        Raises:
            ValueError: If ``self.aug_mode`` is not a supported mode.
        """
        if self.aug_mode == 'vector':
            aug_image = self.mix_augment.get_vec_aug_images(
                images, self.magnitudes, self.ops_weights, use_cuda=self.use_cuda)
        elif self.aug_mode == 'projection':
            aug_image = self.mix_augment.get_proj_aug_images(
                images, use_cuda=self.use_cuda, model=self.model, projection=self.projection, targets=targets, writer=self.writer, step=step)
        elif self.aug_mode == 'cnn':
            # get_cnn_aug_images takes no ``model`` argument.
            aug_image = self.mix_augment.get_cnn_aug_images(
                images, use_cuda=self.use_cuda, aug_cnn=self.aug_cnn, targets=targets)
        else:
            raise ValueError(
                "unsupported aug_mode: {!r}".format(self.aug_mode))

        if step is not None and self.writer is not None:
            self.writer.add_image('train_aug', make_grid(aug_image, normalize=True), step)

        return aug_image

    def forward_aug_input(self, images, targets, step):
        """Classify a batch after sampled augmentation."""
        aug_image = self.get_aug_images(images, targets, step)
        return self.model(aug_image)

    def foward_search(self, images, latent):
        """Run the policy search in feature space (``latent``) or image space."""
        if latent:
            return self.foward_latent_search(images)
        return self.forward_input_search(images)

    # Correctly spelled aliases; the original names stay for existing callers.
    forward_latent_search = foward_latent_search
    forward_search = foward_search

    def forward_test(self, images):
        """Classify ``images`` without any augmentation."""
        return self.model(images)

    def forward(self, images, targets=None, step=None):
        """Dispatch to search, augmented training or plain classification."""
        if not self.add_aug:
            return self.forward_test(images)
        if self.search:
            return self.foward_search(images, self.latent)
        return self.forward_aug_input(images, targets, step)

    def _loss(self, images, target):
        """Return the criterion value of the network output against ``target``."""
        logits = self(images)
        return self._criterion(logits, target)
