import pickle
import itertools
from collections import defaultdict
from tqdm import tqdm
# ML Imports
import pandas as pd
import numpy as np
import tensorflow_hub as hub
import torch
from torch import nn, optim
from sklearn.metrics import confusion_matrix
# Plotting imports
import matplotlib.pyplot as plt
import seaborn as sns


class MultiEncoder(nn.Module):
    """
    Stacked encoder as described in the paper.

    Applies the given encoders in order, feeding each one's output into the next.

    :param encoders: Iterable of encoder modules (e.g. ``Encoder``); each must expose ``input_size``,
        ``hidden_size`` and ``encoded_size`` attributes.
    """

    def __init__(self, encoders, **kwargs):
        super().__init__()
        # Materialise once: a generator would otherwise be exhausted by ModuleList and leave the
        # size tuples below empty.
        encoders = list(encoders)
        self.encoders = nn.ModuleList(encoders)
        # Per-stage sizes, one entry per encoder in stacking order.
        self.input_size = tuple(e.input_size for e in encoders)
        self.hidden_size = tuple(e.hidden_size for e in encoders)
        self.encoded_size = tuple(e.encoded_size for e in encoders)

    def forward(self, features):
        """Run ``features`` through every encoder in order and return the final latent code."""
        for encoder in self.encoders:
            # Call the module itself (not its ``.model``) so forward hooks fire and any nn.Module works.
            features = encoder(features)
        return features


class MultiDecoder(nn.Module):
    """
    Stack of decoders applied in the given order.

    Class is not used in experiments/code, merely here for experimental purposes.

    :param decoders: Iterable of decoder modules (e.g. ``Decoder``); each must expose ``input_size``,
        ``hidden_size`` and ``encoded_size`` attributes.
    """

    def __init__(self, decoders, **kwargs):
        super().__init__()
        # Materialise once: a generator would otherwise be exhausted by ModuleList and leave the
        # size tuples below empty.
        decoders = list(decoders)
        self.decoders = nn.ModuleList(decoders)
        # Per-stage sizes, one entry per decoder in the order given.
        self.input_size = tuple(d.input_size for d in decoders)
        self.hidden_size = tuple(d.hidden_size for d in decoders)
        self.encoded_size = tuple(d.encoded_size for d in decoders)

    def forward(self, features):
        """Run ``features`` through every decoder in order and return the last output."""
        for decoder in self.decoders:
            # Call the module itself (not its ``.model``) so forward hooks fire and any nn.Module works.
            features = decoder(features)
        return features


class Encoder(nn.Module):
    """
    Simple encoder: linear layer(s) followed by a tanh, producing a code of size ``encoded_size``.

    :param input_size: Number of input features.
    :param hidden_size: Width of an optional ReLU hidden layer; None means a single linear layer.
    :param encoded_size: Size of the latent code.
    """

    def __init__(self, input_size=784, hidden_size=None, encoded_size=128, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.encoded_size = encoded_size

        if hidden_size is None:
            layers = [nn.Linear(input_size, encoded_size)]
        else:
            layers = [
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, encoded_size),
            ]
        # Every variant ends in tanh, so codes lie in (-1, 1).
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, features):
        """Encode ``features`` into the latent space."""
        return self.model(features)


class Decoder(nn.Module):
    """
    Simple decoder: maps a latent code of size ``encoded_size`` back to ``input_size`` features.

    :param input_size: Number of reconstructed features.
    :param hidden_size: Width of an optional ReLU hidden layer; None means a single linear layer.
    :param encoded_size: Size of the latent code.
    :param activation: Output activation, one of 'sigmoid', 'tanh' or 'none' (identity).
    :raises ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, input_size=784, hidden_size=None, encoded_size=128, activation='tanh', **kwargs):
        # Factories rather than instances: only the selected activation gets built.
        activation_factories = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'none': nn.Identity}
        if activation not in activation_factories:
            raise ValueError(f'Activation needs to be in {activation_factories.keys()}.')

        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.encoded_size = encoded_size
        self.activation = activation

        if hidden_size is None:
            body = [nn.Linear(encoded_size, input_size)]
        else:
            body = [
                nn.Linear(encoded_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, input_size),
            ]
        self.model = nn.Sequential(*body, activation_factories[activation]())

    def forward(self, features):
        """Decode latent ``features`` into a reconstruction."""
        return self.model(features)


class MLPBinaryClassifier(nn.Module):
    """
    Simple MLP classifier that outputs one sigmoid probability per sample.

    :param input_size: Number of input features.
    :param hidden_size: Width of an optional ReLU hidden layer; None gives logistic regression.
    """

    def __init__(self, input_size=512, hidden_size=None, **kwargs):
        super().__init__()
        self.hidden_size = hidden_size

        if hidden_size is None:
            layers = [nn.Linear(input_size, 1)]
        else:
            layers = [nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)]
        self.model = nn.Sequential(*layers, nn.Sigmoid())

    def forward(self, features):
        """Return the predicted probability (shape ``(..., 1)``) for ``features``."""
        return self.model(features)


def requires_grad(model, requires_grad=True):
    """
    Freeze (False) or unfreeze (True) every parameter of ``model`` by setting its ``requires_grad`` flag.
    """
    params = model.parameters()
    for p in params:
        p.requires_grad = requires_grad


def _adversarial_dampening(adversary_accuracy, n_positive, batch_samples):
    """
    Compute the ALFR-DS dampening factor for one batch.

    The factor is the adversary's accuracy gain over always predicting the batch's majority class,
    divided by the largest possible gain, and clipped below at 0. It is 0 when the batch holds a
    single class (the majority baseline is already perfect, so the ratio is undefined).

    :param adversary_accuracy: Fraction of the batch the adversary classified correctly.
    :param n_positive: Sum of the batch's binary targets (number of 1-labels).
    :param batch_samples: Number of samples in the batch.
    :return: Dampening factor in [0, 1].
    """
    positive_rate = n_positive / batch_samples
    zero_acc = max(1 - positive_rate, positive_rate)  # accuracy of the majority-class baseline
    try:
        return max(adversary_accuracy - zero_acc, 0) / (1 - zero_acc)
    except ZeroDivisionError:
        return 0


def alfr_ds(train_loader, adversary_hidden_size=64, epochs=30, latent_dimensions=128, clip_gradients=False, alpha=None,
            output_activation='tanh', device='cpu'):
    """
    ALFR-DS as explained in the paper.

    Encoders are stacked one at a time: for every entry of ``epochs`` a new encoder is put on top of the
    previously trained (frozen) ones and trained, together with a decoder shared by all stacks, against a
    freshly initialised adversary that tries to predict ``targets`` from the latent code.

    :param train_loader: Training data loader yielding ``(inputs, targets)``; the first batch's inputs must
        be 2-D and their second dimension is used as the input size.
    :param adversary_hidden_size: Hidden size of the adversary.
    :param epochs: Number of epochs (int, a single encoder) or a list with the number of epochs for each
        stacked encoder.
    :param latent_dimensions: Latent dimensions (also the hidden size of the encoders and the decoder).
    :param clip_gradients: Whether or not to clip gradients to norm 1 for extra stabilization (default False).
    :param alpha: If None performs dampening (ALFR-DS); otherwise uses this explicit alpha and alternates
        auto-encoder and adversary updates per batch (algorithm = ALFR-S).
    :param output_activation: Output activation of the decoder ('sigmoid', 'tanh' or 'none').
    :param device: Device (cpu/gpu).
    :return: Triple of (MultiEncoder over the trained stack, decoder, DataFrame of per-epoch results indexed
        by 'epoch'; the epoch counter restarts for each stack, which is recorded in the 'stack' column).
    """
    result = []

    loss_decoder = nn.MSELoss()  # Mean Squared Error for reconstruction loss
    loss_adversary = nn.BCELoss()  # Binary Cross Entropy for binary classification loss

    cur_encoders = []

    # The unpacking also enforces (batch, features) shaped inputs.
    _, input_size = next(iter(train_loader))[0].shape
    # One decoder is shared by all stacks and keeps training while encoders are added.
    decoder = Decoder(input_size=input_size, hidden_size=latent_dimensions, encoded_size=latent_dimensions,
                      activation=output_activation).to(device)

    if isinstance(epochs, int):
        epochs = [epochs]

    for stack, stack_epochs in enumerate(epochs):
        # New encoder on top of the stack: raw inputs for the first one, the previous latent code afterwards.
        encoder = Encoder(input_size=input_size if not cur_encoders else latent_dimensions,
                          hidden_size=latent_dimensions, encoded_size=latent_dimensions).to(device)
        cur_encoders.append(encoder)

        # Only the newest encoder (plus the decoder) is optimised in this stack.
        auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

        # A fresh adversary per stack.
        adversary = MLPBinaryClassifier(input_size=latent_dimensions, hidden_size=adversary_hidden_size)
        adversary.to(device)
        adversary_optimizer = optim.Adam(adversary.parameters())
        adversary_turn = False  # ALFR-S only: alternates every mini-batch, carried over between epochs

        for epoch in tqdm(range(stack_epochs)):
            total_reconstruction_error = 0
            total_adversary_loss = 0
            total_correct = 0
            total_samples = 0
            for inputs, targets in train_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                requires_grad(decoder, True)
                requires_grad(adversary, True)
                requires_grad(encoder, True)

                # PyTorch accumulates gradients on subsequent backward passes
                auto_encoder_optimizer.zero_grad()
                adversary_optimizer.zero_grad()

                # Forward through the whole stack, then reconstruct and let the adversary predict.
                encoded = inputs
                for e in cur_encoders:
                    encoded = e(encoded)
                decoded = decoder(encoded)
                adversary_prediction = adversary(encoded)

                reconstruction_error = loss_decoder(decoded, inputs)
                adversary_loss = loss_adversary(adversary_prediction, targets)

                batch_correct = torch.sum(adversary_prediction.round() == targets).item()
                batch_samples = targets.size(0)

                if alpha is not None:
                    # Only the optimizer whose parameters received gradients this turn steps: if zero_grad()
                    # leaves zero tensors (default before PyTorch 2.0), Adam would otherwise keep moving the
                    # frozen side through its momentum buffers.
                    if adversary_turn:
                        requires_grad(decoder, False)
                        requires_grad(adversary, True)
                        requires_grad(encoder, False)
                        adversary_loss.backward()
                        if clip_gradients:
                            torch.nn.utils.clip_grad_norm_(adversary.parameters(), 1)
                        adversary_optimizer.step()
                    else:
                        requires_grad(decoder, True)
                        requires_grad(adversary, False)
                        requires_grad(encoder, True)
                        # Reconstruct well while maximising the adversary's loss.
                        ((alpha * -adversary_loss) + reconstruction_error).backward()
                        if clip_gradients:
                            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1)
                        auto_encoder_optimizer.step()

                    adversary_turn = not adversary_turn
                else:
                    adversarial_dampening = _adversarial_dampening(
                        batch_correct / batch_samples, torch.sum(targets).item(), batch_samples)

                    # Auto-encoder update: adversary frozen, encoder/decoder unfrozen.
                    requires_grad(decoder, True)
                    requires_grad(adversary, False)
                    requires_grad(encoder, True)

                    # The dampening:
                    # - lets the adversary 'catch up' when its accuracy is low (little adversarial pressure);
                    # - stabilizes training: at or below the majority-class baseline the factor is 0 and the
                    #   auto-encoder only minimises reconstruction error.
                    # retain_graph: the adversary's backward pass below reuses this graph.
                    ((adversarial_dampening * -adversary_loss) + reconstruction_error).backward(retain_graph=True)
                    if clip_gradients:
                        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1)

                    # Adversary update: encoder/decoder frozen.
                    requires_grad(decoder, False)
                    requires_grad(adversary, True)
                    requires_grad(encoder, False)
                    ((1 - adversarial_dampening) * adversary_loss).backward()
                    if clip_gradients:
                        torch.nn.utils.clip_grad_norm_(adversary.parameters(), 1)

                    # Both sides received gradients in this batch.
                    auto_encoder_optimizer.step()
                    adversary_optimizer.step()

                total_reconstruction_error += reconstruction_error.item()
                total_adversary_loss += adversary_loss.item()
                total_correct += batch_correct
                total_samples += batch_samples

            result.append({
                'epoch': epoch,
                'stack': stack,
                'latent_dimensions': latent_dimensions,
                'adversary_hidden_size': adversary_hidden_size,
                'adversary_loss': total_adversary_loss / len(train_loader),
                'adversary_accuracy': total_correct / total_samples,
                'reconstruction_error': total_reconstruction_error / len(train_loader),
                'output_activation': output_activation
            })
        # Freeze the finished encoder; later stacks only train the encoder added on top of it.
        requires_grad(encoder, False)

    return MultiEncoder(cur_encoders), decoder, pd.DataFrame(result).set_index('epoch')


def alfr(train_loader, epochs, adversary_hidden_size=64, latent_dimensions=128,
         alpha=1, clip_gradients=False, output_activation='tanh', device='cpu'):
    """
    Basic ALFR as described in the original paper.

    A single encoder/decoder pair is trained against one adversary; auto-encoder and adversary updates
    alternate every mini-batch.

    :param train_loader: Train loader yielding ``(inputs, targets)``; the first batch's inputs must be 2-D
        and their second dimension is used as the input size.
    :param epochs: Number of epochs.
    :param adversary_hidden_size: Hidden size of adversary.
    :param latent_dimensions: Latent dimensions (also the hidden size of the encoder and the decoder).
    :param alpha: Alpha as described in the paper (default is 1): weight of the adversarial term in the
        auto-encoder objective.
    :param clip_gradients: Whether or not to clip gradients to norm 1.
    :param output_activation: Output activation of the decoder ('sigmoid', 'tanh' or 'none').
    :param device: cpu/gpu
    :return: Triple containing encoder, decoder and a DataFrame of per-epoch results indexed by 'epoch'.
    """
    result = []
    # The unpacking also enforces (batch, features) shaped inputs.
    _, input_size = next(iter(train_loader))[0].shape

    encoder = Encoder(input_size=input_size, hidden_size=latent_dimensions, encoded_size=latent_dimensions).to(device)
    decoder = Decoder(input_size=input_size, hidden_size=latent_dimensions, encoded_size=latent_dimensions,
                      activation=output_activation).to(device)

    loss_decoder = nn.MSELoss()  # Mean Squared Error for reconstruction loss
    loss_adversary = nn.BCELoss()  # Binary Cross Entropy for binary classification loss

    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))

    adversary = MLPBinaryClassifier(input_size=latent_dimensions, hidden_size=adversary_hidden_size)
    adversary.to(device)
    adversary_optimizer = optim.Adam(adversary.parameters())
    adversary_turn = False  # alternates every mini-batch, carried over between epochs

    for epoch in tqdm(range(epochs)):
        total_reconstruction_error = 0
        total_adversary_loss = 0
        total_correct = 0
        total_samples = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            requires_grad(decoder, True)
            requires_grad(adversary, True)
            requires_grad(encoder, True)

            # PyTorch accumulates gradients on subsequent backward passes
            auto_encoder_optimizer.zero_grad()
            adversary_optimizer.zero_grad()

            encoded = encoder(inputs)
            decoded = decoder(encoded)
            adversary_prediction = adversary(encoded)

            reconstruction_error = loss_decoder(decoded, inputs)
            adversary_loss = loss_adversary(adversary_prediction, targets)

            # Only the optimizer whose parameters received gradients this turn steps: if zero_grad()
            # leaves zero tensors (default before PyTorch 2.0), Adam would otherwise keep moving the
            # frozen side through its momentum buffers.
            if adversary_turn:
                requires_grad(decoder, False)
                requires_grad(adversary, True)
                requires_grad(encoder, False)
                adversary_loss.backward()
                if clip_gradients:
                    torch.nn.utils.clip_grad_norm_(adversary.parameters(), 1)
                adversary_optimizer.step()
            else:
                requires_grad(decoder, True)
                requires_grad(adversary, False)
                requires_grad(encoder, True)
                # Reconstruct well while maximising the adversary's loss.
                ((alpha * -adversary_loss) + reconstruction_error).backward()
                if clip_gradients:
                    torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1)
                auto_encoder_optimizer.step()

            adversary_turn = not adversary_turn

            total_reconstruction_error += reconstruction_error.item()
            total_adversary_loss += adversary_loss.item()
            total_correct += torch.sum(adversary_prediction.round() == targets).item()
            total_samples += targets.size(0)

        result.append({
            'epoch': epoch,
            'alpha': alpha,
            'clip_gradients': clip_gradients,
            'adversary_hidden_size': adversary_hidden_size,
            'adversary_loss': total_adversary_loss / len(train_loader),
            'adversary_accuracy': total_correct / total_samples,
            'reconstruction_error': total_reconstruction_error / len(train_loader),
        })
    return encoder, decoder, pd.DataFrame(result).set_index('epoch')


def train_adversary(train_loader, encoder, adversary, optimizer, epochs=2, device='cpu'):
    """
    Method to train an adversary on the latent code of a frozen encoder.

    :param train_loader: Training data yielding ``(inputs, targets)``.
    :param encoder: Encoder to use; its parameters are frozen by this call.
    :param adversary: The adversary to train (not moved to ``device`` here).
    :param optimizer: The optimizer to use, presumably over ``adversary``'s parameters.
    :param epochs: Number of epochs
    :param device: Device to use cpu/gpu
    :return: None; ``adversary`` is updated in place.
    """
    requires_grad(encoder, False)
    loss_adversary = nn.BCELoss()

    for _ in tqdm(range(epochs)):
        for inputs, targets in train_loader:
            # Flatten each sample to a vector (e.g. (B, 1, 28, 28) -> (B, 784)) rather than assuming
            # 784 features, so inputs with any feature count keep their batch dimension intact.
            inputs = inputs.reshape(inputs.size(0), -1).to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            encoded = encoder(inputs)
            adversary_prediction = adversary(encoded)
            adversary_loss = loss_adversary(adversary_prediction, targets)
            adversary_loss.backward()
            optimizer.step()


def test_against_adversaries(train_loader, val_loader, encoder, epochs=10, device='cpu'):
    """
    Method to test an encoder against freshly trained adversaries. Returns dataframe with results.

    Three adversaries are trained jointly on the frozen encoder's latent code: logistic regression and
    MLPs with ``latent_size // 4`` and ``latent_size // 2`` hidden units. They are then evaluated on
    ``val_loader``.

    :param train_loader: Train loader yielding ``(inputs, targets)``.
    :param val_loader: Validation loader yielding ``(inputs, targets)``.
    :param encoder: The encoder to test (``Encoder`` or ``MultiEncoder``); its parameters are frozen.
    :param epochs: Number of epochs to train
    :param device: cpu/gpu
    :return: Dataframe with one row per adversary: confusion-matrix counts tn/fp/fn/tp (targets as ground
        truth), validation accuracy, summed BCE loss ('total_loss', a float) and BCE loss per sample.
    """
    result = []
    loss_adversary = nn.BCELoss()  # Binary Cross Entropy for binary classification loss
    # MultiEncoder reports a tuple of per-stage sizes; the last one is the final latent size.
    latent_size = encoder.encoded_size if isinstance(encoder.encoded_size, int) else encoder.encoded_size[-1]

    # Define adversaries we wish to test
    adversaries = [
        MLPBinaryClassifier(input_size=latent_size, hidden_size=None).to(device),
        MLPBinaryClassifier(input_size=latent_size, hidden_size=latent_size // 4).to(device),
        MLPBinaryClassifier(input_size=latent_size, hidden_size=latent_size // 2).to(device),
    ]
    encoder.to(device)
    adversaries_params = [p for adversary in adversaries for p in adversary.parameters()]

    adversaries_optimizer = optim.Adam(adversaries_params)
    requires_grad(encoder, False)

    for _ in tqdm(range(epochs)):
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            adversaries_optimizer.zero_grad()
            encoded = encoder(inputs)
            for adversary in adversaries:
                adversary_loss = loss_adversary(adversary(encoded), targets)
                adversary_loss.backward()
            adversaries_optimizer.step()

    loss_sum = nn.BCELoss(reduction='sum')  # Binary Cross Entropy sum
    # Evaluation only: no_grad keeps per-batch graphs from being retained in the running loss.
    with torch.no_grad():
        for adversary in adversaries:
            tn, fp, fn, tp = 0, 0, 0, 0
            total_loss = 0.0
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                adversary_prediction = adversary(encoder(inputs))

                # sklearn expects (y_true, y_pred); labels=[0, 1] forces a 2x2 matrix even when a batch
                # contains a single class, so no batch has to be skipped.
                tn_batch, fp_batch, fn_batch, tp_batch = confusion_matrix(
                    targets.cpu().numpy().ravel(),
                    adversary_prediction.round().cpu().numpy().ravel(),
                    labels=[0, 1],
                ).ravel()

                tn += tn_batch
                fp += fp_batch
                fn += fn_batch
                tp += tp_batch
                total_loss += loss_sum(adversary_prediction, targets).item()

            n_samples = tn + fp + fn + tp
            result.append({
                'adversary_hidden_size': adversary.hidden_size,
                'n_epochs_trained': epochs,
                'tn': tn,
                'fp': fp,
                'fn': fn,
                'tp': tp,
                'val_accuracy': (tp + tn) / n_samples if n_samples else float('nan'),
                'total_loss': total_loss,
                'mean_loss': total_loss / n_samples if n_samples else float('nan'),
            })
    return pd.DataFrame(result)


class Run:
    """
    Container for the results of one experimental run.

    :param train_results: Per-epoch training results.
    :param evaluation_results: Evaluation results, e.g. the DataFrame returned by ``test_against_adversaries``.
    :param trial: Optional identifier of the trial/algorithm this run belongs to.
    """

    def __init__(self, train_results, evaluation_results, trial=None):
        self.train_results = train_results
        self.evaluation_results = evaluation_results
        self.trial = trial

    def get_score(self, metric='val_accuracy'):
        """
        Return ``metric`` for this run.

        If ``evaluation_results`` already contains ``metric`` it is returned as-is. Otherwise
        'val_accuracy', 'f1_score', 'mcc' and 'total_loss' are derived from the tn/fp/fn/tp counts and
        stored columns; any other name yields None.
        """
        results = self.evaluation_results
        if metric in results:
            return results[metric]

        tn = results.tn
        fp = results.fp
        fn = results.fn
        tp = results.tp

        # Lazily evaluated so only the requested metric is computed.
        derived = {
            'val_accuracy': lambda: (tp + tn) / (tn + fp + fn + tp),
            'f1_score': lambda: tp / (tp + 0.5 * (fp + fn)),
            'mcc': lambda: (tp * tn - fp * fn) / ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5,
            'total_loss': lambda: results.total_loss,
        }
        compute = derived.get(metric)
        return None if compute is None else compute()


class Experiment:
    """
    Simple class which contains experimental results: a list of runs plus helpers to summarise,
    plot, save and load them.
    """

    def __init__(self):
        self.runs = []

    def add_run(self, r: 'Run'):
        """Append run ``r`` to this experiment."""
        self.runs.append(r)

    def evaluate(self):
        """Placeholder; currently does nothing and returns None."""
        return None

    def remove_trial(self, trial):
        """Drop every run whose ``trial`` equals ``trial``."""
        self.runs = [run for run in self.runs if run.trial != trial]

    def save(self, out_path):
        """Pickle this experiment (including all runs) to ``out_path``."""
        with open(out_path, 'wb') as f:
            pickle.dump(self, f)

    def plot_reconstruction_loss_per_trial(self, ylim=None, legend=None, tofile=None):
        """
        Plot reconstruction MSE against epoch, colored by trial ('Algorithm').

        :param ylim: Optional y-limits passed to ``plt.ylim``.
        :param legend: Optional keyword arguments for ``plt.legend``.
        :param tofile: Optional path; the figure is saved there as PDF before being shown.
        """
        if ylim is not None:
            plt.ylim(ylim)

        data = pd.concat([run.train_results for run in self.runs])
        data['MSE'] = data['reconstruction_error']
        # drop() is not in-place, so the result must be assigned.
        data = data.drop('reconstruction_error', axis=1)
        data['Algorithm'] = [run.trial for run in self.runs for _ in range(len(run.train_results))]
        # Epochs renumbered 1..n per run, independent of the results' own 'epoch' index.
        data['Epoch'] = [epoch for run in self.runs for epoch in range(1, len(run.train_results) + 1)]
        data = data.reset_index()

        sns.lineplot(data=data, x="Epoch", y="MSE", hue='Algorithm', palette='bright')
        if legend is not None:
            plt.legend(**legend)

        if tofile is not None:
            plt.savefig(tofile, format="pdf", transparent=True, bbox_inches='tight')

        plt.show()

    def get_score_against_adversaries(self, metric='val_accuracy'):
        """
        Summarise ``metric`` per trial as 'mean ± std' strings (rounded to 2 decimals) per adversary.

        The column labels match the three adversaries of ``test_against_adversaries`` for the default
        latent size of 128 (hidden sizes None, 32 and 64).

        :return: DataFrame indexed by 'Algorithm' (the run's trial), one column per adversary.
        """
        labels = ['Logistic Regression', 'MLP(32)', 'MLP(64)']
        data = defaultdict(list)
        for run in self.runs:
            data[run.trial].append(run.get_score(metric))
        result = []
        for trial, scores in data.items():
            df_trial = pd.DataFrame(scores)
            summary = df_trial.mean().round(2).astype(str) + ' ± ' + df_trial.std().round(2).astype(str)
            row = dict(zip(labels, summary))
            row['Algorithm'] = trial
            result.append(row)
        return pd.DataFrame(result).set_index('Algorithm')

    def plot_accuracy_against_adversaries_per_trial(self, ymin=0):
        """
        Bar chart of mean validation accuracy per adversary, one group of bars per adversary and one
        bar per trial.

        :param ymin: Lower y-limit (the upper limit is 1).
        """
        labels = ['0', '32', '64']
        data = defaultdict(list)
        for run in self.runs:
            data[run.trial].append(run.evaluation_results.val_accuracy)

        x = np.arange(len(labels))  # the label locations
        n_trials = len(data)
        width = 1 / max(n_trials, 1)  # all bars of a group together span one unit
        for i, trial in enumerate(data):
            avg_data = np.array(data[trial]).mean(axis=0)
            # Offset by position (not by trial value) so each group is centred on its tick.
            plt.bar(x + (i - (n_trials - 1) / 2) * width, avg_data, width)
        plt.xticks(x, labels)
        plt.xlabel("Opponents")
        plt.ylabel("Accuracy")
        plt.ylim([ymin, 1])

        plt.legend([f"Trial {d}" for d in data])
        plt.show()

    @classmethod
    def load(cls, in_path):
        """
        Load an experiment previously written by ``save``.

        Security: this unpickles ``in_path``; only load files from trusted sources.
        """
        with open(in_path, 'rb') as f:
            return pickle.load(f)

