from SplitDataset import Partition
import torch
import numpy as np
import random
import torch.nn.functional as F

class MCMCSampler:
    """Generate label-conditioned samples via gradient-based MCMC with a replay buffer.

    Each call to ``sample_new_exmps`` builds a batch that is mostly drawn from
    previously generated samples (the buffer) plus a few fresh uniform-noise
    images in [-1, 1]. It refines the batch with ``generate_samples`` and
    pushes the result back into the buffer.
    """

    def __init__(self, model, img_shape, sample_size, label, mean_vector, max_len=8192):
        """
        Inputs:
            model - Neural network to use for modeling E_theta
            img_shape - Shape of a single image without the batch dimension (a tuple)
            sample_size - Batch size of the samples
            label - Label attached to every sample this sampler returns
            mean_vector - Target that the batch mean of ``-model(x)`` is pulled towards
                          (presumably a per-class mean output vector - confirm with caller)
            max_len - Maximum number of data points to keep in the buffer
        """
        super().__init__()
        self.model = model
        self.img_shape = img_shape
        self.sample_size = sample_size
        self.label = torch.tensor([label for _ in range(self.sample_size)])
        self.mean_vector = mean_vector
        self.max_len = max_len
        # Buffer of single-image tensors of shape (1, *img_shape), seeded with noise in [-1, 1].
        self.examples = [(torch.rand((1,) + img_shape) * 2 - 1) for _ in range(self.sample_size)]

    def sample_new_exmps(self, steps=60, step_size=10):
        """
        Function for getting a new batch of "fake" images.
        Inputs:
            steps - Number of iterations in the MCMC algorithm
            step_size - Multiplier applied to the (clamped) input gradient at each step
        Returns:
            List of ``(image, label)`` pairs, one per sample. Images are detached CPU tensors.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # On average 5% of the batch is fresh noise; the rest is re-drawn from the buffer.
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        old_imgs = random.choices(self.examples, k=self.sample_size - n_new)
        # Concatenate everything in one call: when n_new == sample_size, old_imgs is
        # empty, and torch.cat on an empty list would raise.
        inp_imgs = torch.cat([rand_imgs] + old_imgs, dim=0).detach().to(device)

        # Perform MCMC sampling
        inp_imgs = MCMCSampler.generate_samples(self.model, inp_imgs, self.mean_vector.to(device), steps=steps, step_size=step_size)

        # generate_samples leaves requires_grad=True on the batch. Detach so neither the
        # buffer nor the returned samples carry autograd history.
        new_imgs = inp_imgs.detach().to(torch.device("cpu"))

        # Add new images to the buffer (newest first) and drop the oldest beyond max_len.
        self.examples = list(new_imgs.chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[:self.max_len]
        label = int(self.label[0])
        return [(img, label) for img in new_imgs]

    @staticmethod
    def generate_samples(model, inp_imgs, mean_vector, steps=60, step_size=10, return_img_per_step=False):
        """
        Function for sampling images for a given model.
        Inputs:
            model - Neural network to use for modeling E_theta
            inp_imgs - Images to start from for sampling. If you want to generate new images, enter noise between -1 and 1.
                       Must be a leaf tensor: it is updated in place and gets requires_grad=True.
            mean_vector - Target for the batch mean of ``-model(inp_imgs)`` (MSE loss)
            steps - Number of iterations in the MCMC algorithm.
            step_size - Multiplier applied to the (clamped) input gradient at each step
            return_img_per_step - If True, we return the sample at every iteration of the MCMC
        Returns:
            ``inp_imgs`` (updated in place), or a tensor of shape (steps, *inp_imgs.shape)
            holding a snapshot after every step when ``return_img_per_step`` is True.
        The model's training mode, each parameter's requires_grad flag, and the global
        grad mode are restored on exit, even if the model raises.
        """
        is_training = model.training
        params = list(model.parameters())
        # Remember per-parameter flags so parameters frozen by the caller stay frozen.
        prev_requires_grad = [p.requires_grad for p in params]
        imgs_per_step = []
        try:
            # Only gradients w.r.t. the input are needed, so freeze the parameters.
            model.eval()
            for p in params:
                p.requires_grad = False
            inp_imgs.requires_grad = True

            # Gradients must be tracked even if the caller runs under torch.no_grad().
            with torch.enable_grad():
                # Reused noise buffer: cheaper than allocating a new tensor every step.
                noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)
                for _ in range(steps):
                    # Part 1: Add noise to the input.
                    noise.normal_(0, 0.005)
                    inp_imgs.data.add_(noise.data)
                    inp_imgs.data.clamp_(min=-1.0, max=1.0)

                    # Part 2: pull the batch-mean of -model(x) towards mean_vector.
                    out_imgs = -model(inp_imgs)
                    out_imgs = torch.mean(out_imgs, dim=0)
                    out = F.mse_loss(out_imgs, mean_vector)
                    out.backward()

                    # Clamp gradients for stability / to prevent too large steps.
                    inp_imgs.grad.data.clamp_(-0.03, 0.03)

                    # Apply gradients to the current samples and reset them for the next step.
                    inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
                    inp_imgs.grad.detach_()
                    inp_imgs.grad.zero_()
                    inp_imgs.data.clamp_(min=-1.0, max=1.0)

                    if return_img_per_step:
                        imgs_per_step.append(inp_imgs.clone().detach())
        finally:
            # Restore parameter flags and training mode exactly as the caller had them.
            for p, flag in zip(params, prev_requires_grad):
                p.requires_grad = flag
            model.train(is_training)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        return inp_imgs

class MCMCSet:
    """List-backed dataset of ``(sample, label)`` pairs, e.g. MCMC-generated images.

    Supports ``len()`` and indexing, plus helpers to subsample, merge and
    average the stored samples. Every helper returns a new ``MCMCSet``.
    """

    def __init__(self, data=None) -> None:
        # None instead of a [] default so instances never share one list object.
        # A list passed in is stored as-is (not copied).
        self.data = [] if data is None else data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def naive_core_set(self, number):
        """Return a uniformly random subset of ``min(len(self), number)`` distinct entries.

        Sampling is without replacement, consistent with ``weighted_core_set``.
        Returns an empty set when there is no data or ``number <= 0``.
        """
        size = max(0, min(len(self.data), number))
        if size == 0:
            return MCMCSet([])
        index = np.random.choice(len(self.data), size, replace=False)
        return MCMCSet([self.data[i] for i in index])

    def weighted_core_set(self, number):
        """Return ``min(len(self), number)`` distinct entries, biased towards later ones.

        Entry ``i`` is drawn with weight ``i + 1`` (without replacement). Later
        entries are therefore more likely to be picked.
        Returns an empty set when there is no data or ``number <= 0``.
        """
        size = max(0, min(len(self.data), number))
        if size == 0:
            # torch.multinomial rejects empty weights and zero samples.
            return MCMCSet([])
        weights = torch.arange(1, len(self.data) + 1, dtype=torch.float)
        index = torch.multinomial(weights, size).tolist()
        return MCMCSet([self.data[i] for i in index])

    def combine(self, new_set):
        """Return a new set with ``new_set``'s entries appended after this set's entries.

        ``new_set`` may be one of:
        - a ``Partition``: its ``data`` items are selected by its ``index``;
        - another ``MCMCSet``;
        - any iterable of ``(sample, label)`` pairs.
        """
        if isinstance(new_set, Partition):
            new_set = [new_set.data[i] for i in new_set.index]
        elif isinstance(new_set, MCMCSet):
            new_set = new_set.data
        return MCMCSet(self.data + list(new_set))

    def abstract(self, step=10):
        """Average consecutive groups of ``step`` samples into one sample each.

        The final group may be shorter than ``step``. Each averaged sample takes
        the label of the first entry in its group. Samples within a group must
        share one shape (they are stacked).
        """
        new_data = []
        for start in range(0, len(self.data), step):
            group = self.data[start:start + step]
            mean_sample = torch.stack([sample for sample, _ in group]).mean(dim=0)
            new_data.append((mean_sample, group[0][1]))
        return MCMCSet(new_data)

