import numpy as np
from typing import Tuple, Callable, List
import itertools

import torch
from torch import optim, Tensor
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential
import torch.nn.functional as F

from model.manifold.transform import DataTransformer
from utils import AverageMeter
from dataset import Adult, Compas, FeatIndex, CompData


class Residual(Module):
    """Dense block of the generator whose output keeps its own input.

    The input goes through Linear -> BatchNorm1d -> ReLU and the activated
    features are concatenated in front of the untouched input, so the output
    width is ``o + i``.
    """

    def __init__(self, i, o):
        """Build the block for `i` input features and `o` new features."""
        super().__init__()
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input_):
        """Return ``[relu(bn(fc(input_))), input_]`` joined along dim 1."""
        hidden = self.relu(self.bn(self.fc(input_)))
        return torch.cat((hidden, input_), dim=1)


class Generator(Module):
    """Generator network of the Synthesizer.

    A stack of `Residual` blocks, one per entry of `generator_dim`, followed
    by a final `Linear` projection to `data_dim`. Every residual block
    concatenates its input to its output, so the running width grows by the
    block's hidden size at each step.
    """

    def __init__(self, embedding_dim, generator_dim, data_dim):
        """Assemble the layer stack for an input of width `embedding_dim`."""
        super().__init__()
        layers = []
        width = embedding_dim
        for hidden in generator_dim:
            layers.append(Residual(width, hidden))
            width += hidden  # residual blocks append their input to their output
        layers.append(Linear(width, data_dim))
        self.seq = Sequential(*layers)

    def forward(self, input_):
        """Run `input_` through the layer stack and return the raw output."""
        return self.seq(input_)


class Discriminator(Module):
    """Discriminator (critic) for the CTGANSynthesizer.

    Rows of the incoming batch are grouped `pac` at a time and flattened into
    one vector of length ``input_dim * pac`` before going through the MLP, so
    the batch size must be a multiple of `pac`.
    """

    def __init__(self, input_dim, discriminator_dim, pac=10):
        """Build the critic.

        Args:
            input_dim: width of a single (unpacked) sample.
            discriminator_dim: iterable with the width of each hidden layer.
            pac: number of samples packed into one critic input.
        """
        super(Discriminator, self).__init__()
        dim = input_dim * pac
        self.pac = pac
        self.pacdim = dim
        seq = []
        for item in list(discriminator_dim):
            seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
            dim = item

        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)

    def calc_gradient_penalty(self, real_data, fake_data, device, pac=None, lambda_=10):
        """Compute the gradient penalty on random real/fake interpolates.

        Args:
            real_data: batch of real rows, shape ``(batch, input_dim)``.
            fake_data: batch of generated rows with the same shape. The
                interpolates must be part of an autograd graph, so at least
                one of the two batches has to require gradients.
            device: device on which the mixing coefficients are created.
            pac: pack size used to group rows. Defaults to the pack size this
                discriminator was built with; passing a different value makes
                the penalty group rows differently from `forward`.
            lambda_: weight of the penalty term.

        Returns:
            Scalar tensor ``lambda_ * mean((||grad||_2 - 1) ** 2)`` where the
            norm is taken over each pack of `pac` rows.
        """
        if pac is None:
            # Fall back to the pack size used by forward() instead of a fixed
            # constant that may not match this instance.
            pac = self.pac
        n_feat = real_data.size(1)

        # One mixing coefficient per pack, broadcast to every row/column of it.
        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, n_feat)
        alpha = alpha.view(-1, n_feat)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        disc_interpolates = self(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates, inputs=interpolates,
            # ones_like keeps the device and dtype of the critic output.
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True, retain_graph=True,
        )[0]

        # Norm is computed per pack, matching how forward() groups the rows.
        gradients_view = gradients.view(-1, pac * n_feat).norm(2, dim=1) - 1
        gradient_penalty = (gradients_view ** 2).mean() * lambda_

        return gradient_penalty

    def forward(self, input_):
        """Apply the Discriminator to the `input_`."""
        assert input_.size()[0] % self.pac == 0, \
            "batch size %d is not a multiple of pac=%d" % (input_.size()[0], self.pac)
        return self.seq(input_.view(-1, self.pacdim))


class BaseSynthesizer():
    """Persistence helpers shared by synthesizers.

    Subclasses are expected to store their generator in `_generator` and to
    provide a `set_device(device)` method, which `load` calls.
    """

    def save(self, path: str):
        """ Save generator for inference purpose

        The whole synthesizer object is pickled to `path` with `torch.save`.
        """
        assert hasattr(self, "_generator"), "No generator in Synthesizer"
        print("=> Saving Synthesizer to %s" % path)
        torch.save(self, path)
        return

    @classmethod
    def load(cls, path: str, device: str = "cuda:1"):
        """ Load generator for inference purpose

        Args:
            path: file previously written by `save`.
            device: device the restored synthesizer is moved to.

        Returns:
            The unpickled synthesizer after `set_device(device)` was applied.
        """
        print("\n=> Loading Synthesizer from %s" % path)
        device = torch.device(device)
        # SECURITY NOTE: the checkpoint is a pickled Python object, so loading
        # it can execute arbitrary code. Only load files from trusted sources.
        try:
            # map_location lets a checkpoint saved on another device (for
            # example a GPU that is absent here) be restored directly onto
            # `device`. weights_only=False is required on recent PyTorch
            # releases, which otherwise refuse to unpickle full objects.
            model = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not know the `weights_only` argument.
            model = torch.load(path, map_location=device)
        model.set_device(device)
        return model


class Synthesizer(BaseSynthesizer):
    """ Synthesizer to generate individually fair sample """

    def __init__(
            self,
            embedding_dim: int = 128,
            generator_dim=(256, 256),
            discriminator_dim=(256, 256),
            generator_lr=2e-4,
            generator_decay=1e-6,
            discriminator_lr=2e-4,
            discriminator_decay=0,
            batch_size=4096,
            discriminator_steps=1,
            verbose=False,
            epochs=500,
            pac=16,
            device: str = "cuda:1",
    ):
        """Record the hyper-parameters; the networks themselves are built in `fit`.

        Args:
            embedding_dim: width of the Gaussian noise vector fed to the generator.
            generator_dim: hidden widths of the generator's residual blocks.
            discriminator_dim: hidden widths of the discriminator.
            generator_lr: learning rate of the generator optimizer.
            generator_decay: weight decay of the generator optimizer.
            discriminator_lr: learning rate of the discriminator optimizer.
            discriminator_decay: weight decay of the discriminator optimizer.
            batch_size: number of sample pairs drawn per training step.
            discriminator_steps: discriminator updates per generator update.
            verbose: stored as `_verbose`.
            epochs: number of training epochs.
            pac: pack size handed to the discriminator.
            device: name of the torch device used for training and sampling.
        """
        # Where tensors live and how raw features are transformed.
        self._device = torch.device(device)
        self._data_trans = DataTransformer()

        # Network shapes.
        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim
        self.pac = pac

        # Optimizer settings.
        self._generator_lr, self._generator_decay = generator_lr, generator_decay
        self._discriminator_lr, self._discriminator_decay = discriminator_lr, discriminator_decay

        # Training schedule.
        self._batch_size = batch_size
        self._epochs = epochs
        self._discriminator_steps = discriminator_steps
        self._verbose = verbose

    def _apply_activate(self, data: Tensor, feat_idx: FeatIndex):
        """ Apply proper activation function to the output of the generator """

        data_t = []
        st = 0
        for feat, idx_list in feat_idx.feat2idx.items():
            ed = st + len(idx_list)
            if feat in feat_idx.cat_feat:
                transformed = F.gumbel_softmax(data[:, st:ed], tau=0.2)
                data_t.append(transformed)
            elif feat in feat_idx.num_feat:
                data_t.append(torch.tanh(data[:, st:st + 1]))  # the first index is the normalized value
                transformed = F.gumbel_softmax(data[:, st + 1:ed], tau=0.2)
                data_t.append(transformed)
            else:
                raise ValueError("Unknown feature %s" % feat)
            st = ed

        return torch.cat(data_t, dim=1)

    def _select_sample(self, X, comp_data, cond: bool):
        comp_idx = comp_data.idx_sen_or

        if cond:
            true_idx, false_idx = comp_data.loaders[1].cond_idx
            total_idx = true_idx + false_idx
            sample_idx = np.random.choice(total_idx, self._batch_size, replace=False)
            # true_idx_ = np.random.choice(true_idx, self._batch_size // 2, replace=False)
            # false_idx_ = np.random.choice(false_idx, self._batch_size // 2, replace=False)
            # sample_idx = np.concatenate([true_idx_, false_idx_])
        else:
            sample_idx = np.random.choice(np.arange(len(comp_idx[0])), self._batch_size, replace=False)

        if np.random.random() <= 0.5:
            base_sample = X[comp_idx[0][sample_idx]]
            comp_sample = X[comp_idx[1][sample_idx]]
        else:
            base_sample = X[comp_idx[1][sample_idx]]
            comp_sample = X[comp_idx[0][sample_idx]]

        return base_sample, comp_sample

    def _cond_loss(self, fake, comp_sample, feat_idx: FeatIndex):
        """ Compute the cross entropy loss on sensitive column """
        loss = torch.zeros(1).to(self._device)
        for feat in feat_idx.sen_feat:
            idx = feat_idx.feat2idx[feat]
            gt = (comp_sample[:, idx] == 1.).nonzero(as_tuple=False)[:, 1].to(self._device)
            loss += F.cross_entropy(fake[:, idx], gt)
        return loss

    def fit(self, X: np.ndarray, feat_idx: FeatIndex, comp_data: CompData, cond: bool = True,
            comp_func: Callable = None):
        """
        Fit the Synthesizer on pairs of individually comparable samples.

        Feed with normalized numerical features and unnormalized one-hot categorical features.

        Args:
            X: training matrix; it is passed through the internal data transformer before use.
            feat_idx: feature index used to fit the data transformer.
            comp_data: source of comparable pairs, forwarded to `_select_sample`.
            cond: forwarded to `_select_sample` to choose which index pool pairs are drawn from.
            comp_func: optional callback evaluated once per epoch on the reverse-transformed
                generated and base samples of the last generator step of that epoch; it is
                called with `return_ratio=True` and must return three values.

        Returns:
            Three lists holding, per epoch, the three values returned by `comp_func`
            (all three are empty when `comp_func` is None).
        """

        # numerical feature transformation
        self._data_trans.fit(X, feat_idx)
        X = self._data_trans.transform(X)
        # Feature index describing the transformed columns; used everywhere below.
        update_feat_idx = self._data_trans.feat_idx

        dim_cond_vec = len(update_feat_idx.sen_idx)
        data_dim = len(update_feat_idx.cat_idx) + len(update_feat_idx.num_idx)

        # Generator input = noise + base sample + conditional (sensitive) vector.
        self._generator = Generator(
            embedding_dim=self._embedding_dim + data_dim + dim_cond_vec,
            generator_dim=self._generator_dim,
            data_dim=data_dim,
        ).to(self._device)
        # Discriminator input = [candidate, base, candidate - base], hence 3 * data_dim.
        # It is a local variable: only the generator is kept on the instance.
        discriminator = Discriminator(
            data_dim * 3,
            self._discriminator_dim,
            pac=self.pac,
        ).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )
        optimizerD = optim.Adam(
            discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )
        # schedulerG = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=self._epochs // 4, gamma=0.2)
        # schedulerD = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=self._epochs // 4, gamma=0.2)

        # Parameters of the standard normal noise, one row per sample of a batch.
        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device, dtype=torch.float32)
        std = mean + 1

        cat_ratio_list = []
        num_ratio_list = []
        sen_ratio_list = []

        # At least one step per epoch, even when X has fewer rows than a batch.
        steps_per_epoch = max(len(X) // self._batch_size, 1)
        for i in range(self._epochs):
            # Running averages of the losses, reset every epoch for the report below.
            loss_d_record, pen_record = AverageMeter(), AverageMeter()
            cond_loss_record, loss_g_record = AverageMeter(), AverageMeter()
            for j in range(steps_per_epoch):

                """ Train discriminator D with gradient penalty """
                for n in range(self._discriminator_steps):
                    base_sample, comp_sample = self._select_sample(X, comp_data, cond=cond)
                    base_sample = torch.from_numpy(base_sample).float().to(self._device)
                    comp_sample = torch.from_numpy(comp_sample).float().to(self._device)

                    fakez = torch.normal(mean=mean, std=std)
                    # Condition on the sensitive columns of the comparable sample.
                    cond_vec = comp_sample[:, update_feat_idx.sen_idx]

                    fakez = torch.cat([fakez, base_sample, cond_vec], dim=1)
                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake, update_feat_idx)

                    fake_cat = torch.cat([fakeact, base_sample, fakeact - base_sample], dim=1)
                    real_cat = torch.cat([comp_sample, base_sample, comp_sample - base_sample], dim=1)
                    y_fake = discriminator(fake_cat)
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
                        real_cat, fake_cat, self._device, self.pac,
                    )
                    # Critic loss: raise the score of real pairs relative to generated ones.
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad()
                    # Both backward passes traverse the graph that produced fake_cat, so the
                    # first one has to retain it. fake is not detached, so gradients also
                    # accumulate in the generator; optimizerG.zero_grad() clears them before
                    # the generator update below.
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                    loss_d_record.update(loss_d.cpu().item(), self._batch_size)
                    pen_record.update(pen.cpu().item(), self._batch_size)

                """ Train generator G with sensitive attribute conditional loss """

                # A fresh batch of pairs and fresh noise are drawn for the generator step.
                base_sample, comp_sample = self._select_sample(X, comp_data, cond=cond)
                base_sample = torch.from_numpy(base_sample).float().to(self._device)
                comp_sample = torch.from_numpy(comp_sample).float().to(self._device)

                fakez = torch.normal(mean=mean, std=std)
                cond_vec = comp_sample[:, update_feat_idx.sen_idx]

                fakez = torch.cat([fakez, base_sample, cond_vec], dim=1)
                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake, update_feat_idx)
                y_fake = discriminator(torch.cat([fakeact, base_sample, fakeact - base_sample], dim=1))

                # The conditional loss is computed on the raw (pre-activation) generator output.
                cond_loss = self._cond_loss(fake, comp_sample, update_feat_idx)
                loss_g = -torch.mean(y_fake)

                optimizerG.zero_grad()
                (loss_g + cond_loss).backward()
                optimizerG.step()

                loss_g_record.update(loss_g.cpu().item(), self._batch_size)
                cond_loss_record.update(cond_loss.cpu().item(), self._batch_size)

            # schedulerG.step()
            # schedulerD.step()

            if comp_func is not None:
                # Evaluated on the batch of the last generator step of this epoch.
                cat_ratio, num_ratio, sen_ratio = comp_func(
                    self._data_trans.reverse_transform(fakeact.detach().cpu().numpy()),
                    self._data_trans.reverse_transform(base_sample.detach().cpu().numpy()),
                    return_ratio=True,
                )
                cat_ratio_list.append(cat_ratio)
                num_ratio_list.append(num_ratio)
                sen_ratio_list.append(sen_ratio)

            # Per-epoch report of the averaged losses.
            print("Epoch: [%d/%d];\n"
                  "Loss D: %.5f; Penalty: %.5f; Total: %.5f;\n"
                  "Loss G: %.5f; Cond.: %.5f; Total: %.5f;"
                  % (i + 1, self._epochs, loss_d_record.avg, pen_record.avg, loss_d_record.avg + pen_record.avg,
                     loss_g_record.avg, cond_loss_record.avg, loss_g_record.avg + cond_loss_record.avg))

        return cat_ratio_list, num_ratio_list, sen_ratio_list

    def sample(self, X: np.ndarray) -> np.ndarray:
        """ Sample comparable data that randomly perturb the sensitive attribute

        `X` is processed in chunks of `self._batch_size` rows. Each chunk is
        transformed, paired with a conditional vector from `sen_perturb`, and
        fed to the generator together with Gaussian noise.

        Args:
            X: raw (untransformed) samples, one per row.

        Returns:
            The generated samples mapped back through the data transformer's
            `reverse_transform`, one per input row.

        Raises:
            ValueError: if `X` has no rows.
        """
        n_rows = X.shape[0]
        if n_rows == 0:
            raise ValueError("Cannot sample from an empty input")

        self._generator.eval()

        batch_size = self._batch_size
        # Ceiling division: `n_rows // batch_size + 1` would add an empty
        # trailing chunk whenever n_rows is a multiple of the batch size, and
        # an empty chunk cannot be perturbed or concatenated.
        steps = (n_rows + batch_size - 1) // batch_size
        data = []
        update_feat_idx = self._data_trans.feat_idx
        mean = torch.zeros(batch_size, self._embedding_dim, device=self._device, dtype=torch.float32)
        std = mean + 1
        for i in range(steps):
            curr_X = X[i * batch_size: (i + 1) * batch_size]
            curr_X = self._data_trans.transform(curr_X)
            cond_vec = sen_perturb(curr_X, update_feat_idx)

            curr_X = torch.from_numpy(curr_X).float().to(self._device)
            cond_vec = torch.from_numpy(cond_vec).float().to(self._device)

            fakez = torch.normal(mean=mean, std=std)
            if len(curr_X) != batch_size:
                # Last, partially filled chunk: keep only as much noise as rows.
                fakez = fakez[:len(curr_X)]

            fakez = torch.cat([fakez, curr_X, cond_vec], dim=1)

            # Inference only: neither the generator nor the activation needs a graph.
            with torch.no_grad():
                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake, update_feat_idx)
            data.append(fakeact.cpu().numpy())

        data = np.concatenate(data, axis=0)

        return self._data_trans.reverse_transform(data)

    def sample_all(self, X: np.ndarray) -> List:
        """ Sample comparable data with all possible sensitive attribute

        For every row of `X`, one sample is generated for each combination of
        sensitive attribute values other than the row's own, as enumerated by
        `sen_all_diff`.

        Args:
            X: raw (untransformed) samples, one per row.

        Returns:
            A list with one array per alternative combination; entry `j` holds
            the reverse-transformed samples generated for the j-th alternative
            conditional vector of every input row.

        Raises:
            ValueError: if `X` has no rows.
        """
        n_rows = X.shape[0]
        if n_rows == 0:
            raise ValueError("Cannot sample from an empty input")

        self._generator.eval()

        batch_size = self._batch_size
        # Ceiling division: `n_rows // batch_size + 1` would add an empty
        # trailing chunk whenever n_rows is a multiple of the batch size.
        steps = (n_rows + batch_size - 1) // batch_size
        update_feat_idx = self._data_trans.feat_idx
        # Every combination of sensitive values except the sample's own one.
        N_cond_vec = int(
            np.prod([len(update_feat_idx.feat2idx[feat]) for feat in update_feat_idx.sen_feat])
        ) - 1
        data = [[] for _ in range(N_cond_vec)]

        mean = torch.zeros(batch_size, self._embedding_dim, device=self._device, dtype=torch.float32)
        std = mean + 1

        for i in range(steps):
            curr_X = X[i * batch_size: (i + 1) * batch_size]
            curr_X = self._data_trans.transform(curr_X)
            cond_vec = sen_all_diff(curr_X, update_feat_idx)

            curr_X = torch.from_numpy(curr_X).float().to(self._device)
            for j in range(N_cond_vec):
                curr_cond_vec = torch.from_numpy(cond_vec[:, j]).float().to(self._device)

                # Fresh noise is drawn for every alternative.
                fakez = torch.normal(mean=mean, std=std)
                if len(curr_X) != batch_size:
                    # Last, partially filled chunk: keep only as much noise as rows.
                    fakez = fakez[:len(curr_X)]

                fakez = torch.cat([fakez, curr_X, curr_cond_vec], dim=1)

                # Inference only: neither the generator nor the activation needs a graph.
                with torch.no_grad():
                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake, update_feat_idx)

                data[j].append(fakeact.cpu().numpy())

        data = [np.concatenate(list_, axis=0) for list_ in data]

        return [self._data_trans.reverse_transform(array) for array in data]

    def set_device(self, device):
        """Set the `device` to be used ('GPU' or 'CPU)."""
        self._device = device
        if self._generator is not None:
            self._generator.to(self._device)


def sen_perturb(X: np.ndarray, feat_idx: FeatIndex) -> np.ndarray:
    """ Randomly perturb one categorical sensitive feature in X and return it as conditional vector

    For every row, one sensitive feature is picked at random and its one-hot
    block is replaced by a one-hot of a different, randomly chosen category.
    The input array itself is left untouched.

    Args:
        X: samples, one per row, with one-hot encoded sensitive features.
        feat_idx: feature index providing `sen_feat`, `feat2idx` and `sen_idx`.

    Returns:
        Array of shape ``(len(X), len(feat_idx.sen_idx))`` holding the
        sensitive columns of every row after the perturbation. An empty input
        yields an empty array with zero rows.

    Raises:
        ValueError: raised by `np.random.choice` when the picked feature has a
            single category, since no alternative value exists.
    """

    X = np.copy(X)  # rows are modified in place below; keep the caller's data intact
    # A list is required for NumPy fancy indexing whatever sequence type sen_idx is.
    sen_idx = list(feat_idx.sen_idx)
    cond_vec = np.empty((len(X), len(sen_idx)), dtype=X.dtype)
    if len(X) == 0:
        return cond_vec

    selected_feat = np.random.choice(feat_idx.sen_feat, size=len(X))
    for row, (x, sen_feat) in enumerate(zip(X, selected_feat)):
        columns = feat_idx.feat2idx[sen_feat]
        curr_val = np.argmax(x[columns])
        avail_set = [i for i in range(len(columns)) if i != curr_val]
        perturb_val = np.random.choice(avail_set)

        # one-hot encoding
        perturb = np.zeros(len(columns))
        perturb[perturb_val] = 1.
        x[columns] = perturb
        cond_vec[row] = x[sen_idx]

    return cond_vec


def sen_all_diff(X: np.ndarray, feat_idx: FeatIndex) -> np.ndarray:
    """ Sample every different conditional vector

    Enumerates every combination of categories over the sensitive features as
    a concatenation of one-hot blocks (in the order of `feat_idx.sen_feat`),
    and returns, for each row of `X`, all combinations except the row's own.

    Args:
        X: samples, one per row, with one-hot encoded sensitive features.
        feat_idx: feature index providing `sen_feat`, `feat2idx` and `sen_idx`.

    Returns:
        Array of shape ``(len(X), N - 1, len(feat_idx.sen_idx))`` where `N` is
        the number of combinations. The shape is kept for an empty `X`, so the
        result can always be indexed as ``result[:, j]``.

    Raises:
        IndexError: if the sensitive columns of a row match none of the
            enumerated combinations (e.g. they are not exact one-hot blocks).
    """

    X = np.asarray(X)  # rows are only read, so no copy is needed
    sen_idx = list(feat_idx.sen_idx)
    sizes = [len(feat_idx.feat2idx[feat]) for feat in feat_idx.sen_feat]
    N = int(np.prod(sizes))

    # Position of every category inside the conditional vector: each feature's
    # block starts where the previous feature's block ends.
    idx_list = []
    offset = 0
    for size in sizes:
        idx_list.append([offset + k for k in range(size)])
        offset += size

    # One row per combination, with a 1 at the chosen category of each feature.
    all_avail_vec = np.zeros((N, len(sen_idx)))
    for row, combination in enumerate(itertools.product(*idx_list)):
        all_avail_vec[row, list(combination)] = 1

    cond_vec = np.empty((len(X), max(N - 1, 0), len(sen_idx)))
    for row, sample in enumerate(X):
        current = sample[sen_idx]
        same_vec_idx = np.where((all_avail_vec == current).all(axis=1))[0][0]
        # np.delete returns a new array, so the shared table is never modified.
        cond_vec[row] = np.delete(all_avail_vec, same_vec_idx, axis=0)

    return cond_vec


if __name__ == "__main__":
    # Smoke run: enumerate the alternative conditional vectors for the Adult
    # training split; the result is discarded.
    adult = Adult()
    features, _labels = adult.train_data()
    sen_all_diff(features, adult.feat_idx)
