from itertools import chain
import torch
import math
from omegaconf import OmegaConf
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule
from sklearn.metrics.cluster import _supervised
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning import Trainer, seed_everything
import os
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from torchvision.datasets import MNIST
from torch.utils.data import Dataset
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score


class TotalCodingRate(torch.nn.Module):
    """Negative total coding rate of a batch of features.

    Adapted from https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py

    For ``x`` with one sample per row (``[B, d]``), ``forward`` returns
    ``-0.5 * logdet(I_d + d / (B * eps) * x^T x)``.
    """

    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps

    def compute_discrimn_loss(self, W):
        """Return ``0.5 * logdet(I + p / (m * eps) * W W^T)`` for ``W`` of shape ``[p, m]``."""
        dim, n_samples = W.shape
        gram = W @ W.T
        identity = torch.eye(dim, device=W.device)
        coeff = dim / (n_samples * self.eps)
        return torch.logdet(identity + coeff * gram) / 2.

    def forward(self, x):
        """Return the negated coding rate of ``x`` (rows are samples)."""
        return -self.compute_discrimn_loss(x.T)


class MaximalCodingRateReduction(torch.nn.Module):
    """Maximal coding rate reduction (MCR^2) terms.

    Adapted from https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/loss.py

    ``forward`` returns the pair ``(discrimn_loss, compress_loss)``; combining
    them is left to the caller. ``gamma`` is stored but not used inside this
    module.
    """

    def __init__(self, eps=0.01, gamma=1):
        super(MaximalCodingRateReduction, self).__init__()
        self.eps = eps
        self.gamma = gamma

    def compute_discrimn_loss(self, W):
        """Return ``0.5 * logdet(I + p / (m * eps) * W W^T)`` for ``W`` of shape ``[p, m]``."""
        p, m = W.shape
        I = torch.eye(p, device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.

    def compute_compress_loss(self, W, Pi):
        """Return the membership-weighted sum of per-cluster coding rates.

        Args:
            W: features of shape ``[p, m]`` (one column per sample).
            Pi: soft memberships of shape ``[k, 1, m]``.
        """
        p, m = W.shape
        k, _, _ = Pi.shape
        I = torch.eye(p, device=W.device).expand((k, p, p))
        # [k, 1]; the epsilon keeps the scale finite for empty clusters.
        trPi = Pi.sum(2) + 1e-8
        scale = (p / (trPi * self.eps)).view(k, 1, 1)
        W = W.view((1, p, m))
        log_det = torch.logdet(I + scale * W.mul(Pi).matmul(W.transpose(1, 2)))
        compress_loss = (trPi.squeeze(1) * log_det / (2 * m)).sum()
        return compress_loss

    def forward(self, X, Y, num_classes=None):
        """Return ``(discrimn_loss, compress_loss)`` for features ``X``.

        Args:
            X: features with one row per sample (``[m, p]``).
            Y: either integer labels of shape ``[m]`` or a membership /
                probability matrix of shape ``[m, k]``.
            num_classes: number of clusters ``k``; inferred from ``Y`` if None.
        """
        if Y.dim() == 1:
            # Integer labels -> one-hot memberships Pi[label, 0, sample], built
            # with a single scatter instead of a per-sample Python loop.
            if num_classes is None:
                num_classes = int(Y.max()) + 1
            n = Y.shape[0]
            Pi = torch.zeros((num_classes, 1, n), device=Y.device)
            Pi[Y, 0, torch.arange(n, device=Y.device)] = 1
        else:
            # Y is a [m, k] membership/probability matrix.
            if num_classes is None:
                num_classes = Y.shape[1]
            Pi = Y.T.reshape((num_classes, 1, -1))

        W = X.T
        discrimn_loss = self.compute_discrimn_loss(W)
        compress_loss = self.compute_compress_loss(W, Pi)
        return discrimn_loss, compress_loss


class BaseModule(LightningModule):
    """Gated autoencoder + MCR^2 clustering model trained with manual optimization.

    ``training_step`` selects a phase from ``self.current_epoch``:

    1. ``current_epoch <= cfg.ae_pretrain_epochs``: autoencoder pretraining
       (``ae_step``); the local gates join in once
       ``current_epoch > cfg.ae_non_gated_epochs``.
    2. Afterwards: the clustering head is trained with the MCR^2 loss
       (``mcrr_loss``) and, once ``current_epoch > cfg.ae_pretrain_epochs +
       cfg.start_global_gates_training_on_epoch``, the per-cluster global gates
       (``global_gates_step``).

    Subclasses must set ``self.dataset`` (read by ``configure_optimizers``)
    and provide the dataloaders.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.save_hyperparameters()
        self.best_evaluation_stats = {}
        self.ae_train = False
        self.automatic_optimization = False
        # np.infty was removed in NumPy 2.0; np.inf works on every version.
        self.best_accuracy = -np.inf
        self.gating_net = GatingNet(cfg)
        self.clustering_net = Clustering(cfg)
        self.mcrr = MaximalCodingRateReduction(eps=cfg.eps, gamma=cfg.gamma)
        self.total_coding = TotalCodingRate(eps=cfg.eps)

        # Per-epoch validation buffers, reset in on_validation_epoch_start.
        self.val_cluster_list = []
        self.val_cluster_list_gated = []
        self.val_label_list = []
        self.open_gates = []

    def global_gates_step(self, x):
        """Return the global-gates loss for a flattened batch ``x``.

        Samples are grouped by their predicted cluster. For each group the
        auxiliary classifier sees ``x * local_gates * global_gates[cluster]``
        and is trained (cross-entropy) to recover the cluster id; a sparsity
        regularizer on that cluster's global gates is added. Both terms are
        averaged over the predicted clusters.
        """
        _, _, gates = self.gating_net(x)
        ae_emb = self.clustering_net.encoder(x * gates)
        cluster_logits = self.clustering_net.clustering_head(ae_emb)
        y_hat = cluster_logits.argmax(dim=-1)
        clusters = y_hat.unique()
        reg_loss = 0
        aux_loss = 0
        for y in clusters:
            in_cluster = y_hat == y
            x_i = x[in_cluster]
            gates_i = gates[in_cluster]
            n_i = x_i.size(0)
            glob_gates_mu, glob_gates = self.gating_net.global_forward(n_i, y)
            reg_loss = reg_loss + self.gating_net.regularization(glob_gates_mu)
            aux_y_hat = self.clustering_net.aux_classifier(x_i * gates_i * glob_gates)
            aux_loss = aux_loss + F.cross_entropy(aux_y_hat, y.reshape(1).repeat(n_i))
        aux_loss = aux_loss / len(clusters)
        reg_loss = reg_loss / len(clusters)
        self.log('train/glob_gates_reg_loss', reg_loss.item())
        self.log('train/glob_gates_ce_loss', aux_loss.item())
        return aux_loss + self.cfg.global_gates_reg_lambda * reg_loss

    def ae_step(self, x):
        """Return the pretraining loss for a flattened batch ``x``.

        Sums four L1 reconstruction terms: from ``x``, from gated ``x``, from a
        randomly masked ``x`` and from a multiplicatively-noised embedding.
        After ``cfg.ae_non_gated_epochs`` it also adds the local-gate sparsity
        regularizer (cosine-ramped weight) and the total-coding-rate term of
        the gates; before that the gates are all ones.
        """
        if self.current_epoch > self.cfg.ae_non_gated_epochs:
            mu, _, gates = self.gating_net(x)
            reg_loss = self.gating_net.regularization(mu)
            tcr_loss = self.total_coding(gates) / x.size(0)
            self.log("pretrain/gates_reg_loss", reg_loss.item())
            self.log("pretrain/gates_tcr_loss", tcr_loss.item())
            loss = self.local_gates_reg_lambda(0., self.cfg.local_gates_reg_lambda) * reg_loss + tcr_loss
        else:
            gates = torch.ones_like(x, device=x.device).float()
            loss = 0

        # task 1: reconstruct x from x
        x_recon = self.clustering_net.pretrain_forward(x)
        x_recon_loss = F.l1_loss(x_recon, x)
        self.log("pretrain/x_recon_loss", x_recon_loss.item())

        # task 2: reconstruct x from gated x
        x_recon_from_gated = self.clustering_net.pretrain_forward(x * gates)
        x_from_gated_x_recon_loss = F.l1_loss(x_recon_from_gated, x)
        self.log("pretrain/x_from_gated_x_recon_loss", x_from_gated_x_recon_loss.item())

        # task 3: reconstruct x from randomly masked x. Each entry is dropped
        # with probability cfg.mask_percentage; the uniform draw is made on the
        # CPU and then moved to x.device, which keeps the seeded RNG stream.
        mask_rnd = torch.rand(x.size()).to(x.device)
        mask = (mask_rnd >= self.cfg.mask_percentage).float()
        x_recon_masked = self.clustering_net.pretrain_forward(x * mask)
        input_noised_recon_loss = F.l1_loss(x_recon_masked, x)
        self.log("pretrain/input_noised_recon_loss", input_noised_recon_loss.item())

        # task 4: reconstruct x from an embedding scaled by N(1, latent_noise_std^2) noise
        e = self.clustering_net.encoder(x)
        e = e * torch.normal(mean=1., std=self.cfg.latent_noise_std, size=e.size(), device=e.device)
        recon_noised = self.clustering_net.decoder(e)
        noised_aug_loss = F.l1_loss(recon_noised, x)
        self.log("pretrain/latent_noised_recon_loss", noised_aug_loss.item())

        # combined loss
        # NOTE(review): tasks 2 and 3 are weighted by cfg.local_gates_reg_lambda,
        # the same constant that scales the gate regularizer - confirm intended.
        loss = loss + x_recon_loss + self.cfg.local_gates_reg_lambda * x_from_gated_x_recon_loss + \
               self.cfg.local_gates_reg_lambda * input_noised_recon_loss + noised_aug_loss
        return loss

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step; see the class docstring for the phases."""
        ae_opt, clust_opt, glob_gates_opt = self.optimizers()
        pretrain_sched, sch = self.lr_schedulers()
        x, _ = batch
        x = x.reshape(x.size(0), -1)

        # Phase 1: reconstruction + local gates training.
        if self.current_epoch <= self.cfg.ae_pretrain_epochs:
            ae_opt.zero_grad()
            loss = self.ae_step(x)
            self.manual_backward(loss)
            ae_opt.step()
            pretrain_sched.step()
            return

        # Phase 2a: clusters compression step. Only clust_opt (the clustering
        # head) is stepped, so the encoder output is detached before the head:
        # the head's gradients are unchanged and no backward pass runs through
        # the (frozen) encoder.
        clust_opt.zero_grad()
        _, _, gates = self.gating_net(x)
        ae_emb = self.clustering_net.encoder(x * gates.detach()).detach()
        cluster_logits = self.clustering_net.clustering_head(ae_emb)
        loss = self.mcrr_loss(ae_emb, cluster_logits)
        self.log("train/mcrr_loss", loss.item())
        self.manual_backward(loss)
        clust_opt.step()

        # Phase 2b: global gates training.
        if self.current_epoch > self.cfg.ae_pretrain_epochs + self.cfg.start_global_gates_training_on_epoch:
            glob_gates_opt.zero_grad()
            loss = self.global_gates_step(x)
            self.manual_backward(loss)
            glob_gates_opt.step()
        sch.step()

    def configure_optimizers(self):
        """Create the three optimizers and the two cosine LR schedules.

        Returns ``[pretrain, cluster, global_gates]`` optimizers and
        ``[pretrain_sched, cluster_sched]``; the schedules are stepped manually
        once per training step in ``training_step``.
        """
        pretrain_optimizer = torch.optim.Adam(
            params=chain(
                self.clustering_net.encoder.parameters(),
                self.clustering_net.decoder.parameters(),
                self.gating_net.net.parameters(),
            ),
            lr=1e-3)

        cluster_optimizer = torch.optim.Adam(
            params=chain(
                self.clustering_net.clustering_head.parameters(),
            ),
            lr=1e-2)

        glob_gates_opt = torch.optim.SGD(
            params=chain(
                self.clustering_net.aux_classifier.parameters(),
                self.gating_net.global_gates_net.parameters(),
            ),
            lr=1e-1)

        steps_per_epoch = len(self.dataset) // self.cfg.batch_size
        # max(1, ...) keeps CosineAnnealingLR from dividing by zero when a phase
        # has no epochs configured (e.g. cfg.ae_pretrain_epochs == 0, where
        # training_step still steps pretrain_sched during epoch 0).
        steps = max(1, steps_per_epoch * (self.cfg.trainer.max_epochs - self.cfg.ae_pretrain_epochs))
        # NOTE(review): training_step pretrains while current_epoch <= ae_pretrain_epochs,
        # i.e. for ae_pretrain_epochs + 1 epochs, so this schedule is stepped past
        # T_max during the last pretraining epoch.
        pretrain_steps = max(1, steps_per_epoch * self.cfg.ae_pretrain_epochs)
        print(f"Cosine annealing LR scheduling is applied during {steps} steps")
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=cluster_optimizer,
            T_max=steps,
            eta_min=1e-6)
        pretrain_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=pretrain_optimizer,
            T_max=pretrain_steps,
            eta_min=1e-6)
        return [pretrain_optimizer, cluster_optimizer, glob_gates_opt], [pretrain_sched, sched]

    def local_gates_reg_lambda(self, min_val, max_val):
        """Cosine ramp: ``min_val`` at epoch ``cfg.ae_non_gated_epochs``,
        ``max_val`` at epoch ``cfg.ae_pretrain_epochs``.

        Because cosine is even, the value falls again after
        ``cfg.ae_pretrain_epochs``; it is only used during pretraining.
        """
        epoch = self.current_epoch - self.cfg.ae_pretrain_epochs
        total_epochs = self.cfg.ae_pretrain_epochs - self.cfg.ae_non_gated_epochs
        return min_val + 0.5 * (max_val - min_val) * (1. + np.cos(epoch * math.pi / total_epochs))

    def validation_step(self, batch, batch_idx):
        """Collect cluster predictions, labels and open-gate counts for the epoch.

        Every 50th epoch it also writes gate images for each cluster and for
        the first two samples of the batch (28x28 images are assumed).
        """
        x, y = batch
        # Flatten like training_step so image-shaped batches also work.
        x = x.reshape(x.size(0), -1)
        gates = self.gating_net.get_gates(x)
        gated_x = x * gates
        ae_emb = self.clustering_net.encoder(gated_x)
        cluster_logits = self.clustering_net.clustering_head(ae_emb)
        y_hat = cluster_logits.argmax(dim=-1)
        self.val_cluster_list.append(y_hat.cpu())
        self.val_label_list.append(y.cpu())
        self.open_gates.append(self.gating_net.num_open_gates(x))

        # plot samples and gates
        if (self.current_epoch + 1) % 50 == 0:
            for i in range(self.cfg.n_clusters):
                _, glob_gates = self.gating_net.global_forward(1, torch.tensor(i, device=x.device).long())
                glob_gates_img = glob_gates.reshape(28, 28).cpu().numpy()
                self.save_image(glob_gates_img, f'global_gates_y_{i}.png', 28)

            samples = zip(gates[:2], y[:2], x[:2], gated_x[:2], y_hat[:2])
            for i, (gx, label, xx, gated_xx, y_hat_i) in enumerate(samples):
                # .item() turns 0-d tensors into plain ints; formatting the tensor
                # itself would put "tensor(3, device='cuda:0')" into the file name.
                label = label.item()
                y_hat_i = y_hat_i.item()
                gates_img = gx.reshape(28, 28).cpu().numpy()
                orig_img = xx.reshape(28, 28).cpu().numpy()
                gated_img = gated_xx.reshape(28, 28).cpu().numpy()
                self.save_image(gates_img, f'gates_sample_{i}_y_{label}_y_hat_{y_hat_i}_batch_idx_{batch_idx}.png', 28)
                self.save_image(orig_img, f'orig_sample_{i}_y_{label}_y_hat_{y_hat_i}_batch_idx_{batch_idx}.png', 28)
                self.save_image(gated_img, f'result_sample_{i}_y_{label}_y_hat_{y_hat_i}_batch_idx_{batch_idx}.png', 28)

    def save_image(self, array, name, size=None):
        """Save ``array`` as a grayscale PNG under ``./outputs_<script>/epoch_<n>/<name>``.

        ``size`` reshapes the array first: an int gives a square, a tuple gives
        ``(rows, cols)``, None keeps the array's shape.
        """
        out_dir = f"./outputs_{os.path.basename(__file__)}/epoch_{self.current_epoch}"
        os.makedirs(out_dir, exist_ok=True)
        if isinstance(size, tuple):
            image = array.reshape((size[0], size[1]))
        elif size is not None:
            image = array.reshape((size, size))
        else:
            image = array
        fig, ax = plt.subplots()
        try:
            ax.set_title(name)
            ax.get_yaxis().set_visible(False)
            ax.get_xaxis().set_visible(False)
            ax.imshow(image, cmap='gray')
            fig.savefig(f"{out_dir}/{name}")
        finally:
            # Close exactly the figure created here so none are left open.
            plt.close(fig)

    def on_validation_epoch_start(self):
        self.val_cluster_list = []
        self.val_cluster_list_gated = []
        self.val_label_list = []
        self.open_gates = []

    @staticmethod
    def clustering_accuracy(labels_true, labels_pred):
        """Compute clustering accuracy under the optimal (Hungarian) label matching."""
        from scipy.optimize import linear_sum_assignment
        labels_true, labels_pred = _supervised.check_clusterings(labels_true, labels_pred)
        value = _supervised.contingency_matrix(labels_true, labels_pred)
        [r, c] = linear_sum_assignment(-value)
        return value[r, c].sum() / len(labels_true)

    @staticmethod
    def cluster_match(cluster_mtx, label_mtx, n_classes=10, print_result=True):
        """Greedy cluster-to-label matching accuracy.

        Author: https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/func.py

        Repeatedly picks, among the not-yet-matched clusters, the one whose
        most frequent not-yet-assigned label has the largest count, and credits
        that count as correct.

        Returns:
            ``(total_correct, total_sample, acc)``. With ``print_result=True``
            the result is also printed (it used to be returned only when
            ``print_result`` was False).
        """
        # verified to be consistent to optimal assignment problem based algorithm
        cluster_indx = list(cluster_mtx.unique())
        assigned_label_list = []
        assigned_count = []
        while (len(assigned_label_list) <= n_classes) and len(cluster_indx) > 0:
            max_label_list = []
            max_count_list = []
            for indx in cluster_indx:
                mask = cluster_mtx == indx
                label_elements, counts = label_mtx[mask].unique(return_counts=True)
                for assigned_label in assigned_label_list:
                    counts[label_elements == assigned_label] = 0
                max_count_list.append(counts.max())
                max_label_list.append(label_elements[counts.argmax()])

            max_label = torch.stack(max_label_list)
            max_count = torch.stack(max_count_list)
            assigned_label_list.append(max_label[max_count.argmax()])
            assigned_count.append(max_count.max())
            cluster_indx.pop(max_count.argmax().item())
        total_correct = torch.tensor(assigned_count).sum().item()
        total_sample = cluster_mtx.shape[0]
        acc = total_correct / total_sample
        if print_result:
            print('{}/{} ({}%) correct'.format(total_correct, total_sample, acc * 100))
        return total_correct, total_sample, acc

    def on_validation_epoch_end(self):
        """Log ACC/NMI/ARI and gate statistics; checkpoint the best and last models.

        Based on https://github.com/zengyi-li/NMCE-release/blob/main/NMCE/func.py
        Skipped for epoch 0, for epochs before ``cfg.ae_pretrain_epochs - 1``,
        and (when ``self.ae_train``) during pretraining.
        """
        skip_pretrain = self.ae_train and self.current_epoch < self.cfg.ae_pretrain_epochs
        if skip_pretrain or self.current_epoch <= 0:
            return
        if self.current_epoch < self.cfg.ae_pretrain_epochs - 1:
            return

        cluster_mtx = torch.cat(self.val_cluster_list, dim=0)
        label_mtx = torch.cat(self.val_label_list, dim=0)
        _, _, acc_single = self.cluster_match(
            cluster_mtx,
            label_mtx,
            n_classes=label_mtx.max() + 1,
            print_result=False)
        if self.best_accuracy < acc_single:
            print("New best accuracy:", acc_single)
            self.best_accuracy = acc_single
            meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_net.state_dict()}
            torch.save(meta_dict, 'sparse_model_best.pth')
        nmi = normalized_mutual_info_score(label_mtx.numpy(), cluster_mtx.numpy())
        ari = adjusted_rand_score(label_mtx.numpy(), cluster_mtx.numpy())
        format_str = ''  # '_kmeans' if self.current_epoch == 9 else ''
        self.log(f'val/acc_single{format_str}', acc_single)  # this is ACC
        self.log(f'val/NMI{format_str}', nmi)
        self.log(f'val/ARI{format_str}', ari)
        self.log("val/num_open_gates", np.mean(self.open_gates).item())
        self.log("val/num_open_global_gates", self.gating_net.open_global_gates())
        meta_dict = {"gating": self.gating_net.state_dict(), "clustering": self.clustering_net.state_dict()}
        torch.save(meta_dict, 'sparse_model_last.pth')

    def mcrr_loss(self, c, logits):
        """Return ``gamma * compress - discrimn`` for embeddings ``c`` and cluster ``logits``.

        Soft memberships come from a Gumbel-softmax with temperature
        ``self.tau()``; both terms are divided by the embedding width.
        """
        logprobs = torch.log_softmax(logits, dim=-1)
        prob = GumbleSoftmax(self.tau())(logprobs)
        discrimn_loss, compress_loss = self.mcrr(F.normalize(c), prob, num_classes=self.cfg.n_clusters)
        discrimn_loss = discrimn_loss / c.size(1)
        compress_loss = compress_loss / c.size(1)
        self.log('train/discrimn_loss', -discrimn_loss.item())
        self.log('train/compress_loss', compress_loss.item())
        return self.cfg.gamma * compress_loss - discrimn_loss

    def tau(self):
        """Gumbel-softmax temperature (``cfg.tau``)."""
        return self.cfg.tau


class MNISTTab(Dataset):
    """In-memory tabular dataset: items are ``(float32 features, int64 label)``."""

    def __init__(self, x, y):
        super().__init__()
        self.data = x
        self.targets = y

    def __getitem__(self, index: int):
        return torch.tensor(self.data[index]).float(), torch.tensor(self.targets[index]).long()

    @classmethod
    def setup(cls, data_dir):
        """Load (downloading if needed) the MNIST training split.

        Images are flattened to 784 features and scaled to [0, 1]. The
        torchvision dataset is constructed once and reused for data and targets.
        """
        mnist = MNIST(data_dir, train=True, download=True)
        x_train = mnist.data.reshape(-1, 784).cpu().numpy() / 255.
        y_train = mnist.targets.cpu().numpy()
        return cls(x_train, y_train)

    def __len__(self) -> int:
        return len(self.data)

    def num_classes(self):
        """Number of distinct labels present in ``targets``."""
        return len(np.unique(self.targets))

    def num_features(self):
        """Size of the last axis of ``data``."""
        return self.data.shape[-1]


class MNISTClustering(BaseModule):
    """BaseModule trained and validated on the flattened MNIST training split."""

    def __init__(self, cfg):
        # Load the data first: BaseModule.__init__ builds GatingNet/Clustering
        # from cfg.input_dim / cfg.n_clusters, so they must already reflect the
        # dataset (updating cfg afterwards had no effect on the networks).
        dataset = MNISTTab.setup(cfg.data_dir)
        print(f"Dataset length: {len(dataset)}")
        cfg.input_dim = dataset.num_features()
        cfg.n_clusters = dataset.num_classes()
        super().__init__(cfg)
        self.dataset = dataset

    def train_dataloader(self):
        # drop_last keeps every batch at cfg.batch_size (matches the step count
        # used by configure_optimizers).
        return DataLoader(self.dataset,
                          batch_size=self.cfg.batch_size,
                          drop_last=True,
                          shuffle=True,
                          num_workers=0)

    def val_dataloader(self):
        # Validation runs over the same (training) split, in order, without dropping samples.
        return DataLoader(self.dataset,
                          batch_size=self.cfg.batch_size,
                          drop_last=False,
                          shuffle=False,
                          num_workers=0)


class GumbleSoftmax(torch.nn.Module):
    """Gumbel-softmax sample over dim 1 of a batch of log-probabilities.

    With ``straight_through=True`` the returned value is (numerically) the
    one-hot argmax of the perturbed logits, while gradients flow through the
    soft sample.
    """

    def __init__(self, tau, straight_through=False):
        super().__init__()
        self.tau = tau
        self.straight_through = straight_through

    def forward(self, logps):
        uniform = torch.rand_like(logps)
        # Standard Gumbel noise: -log(-log(U)), U ~ Uniform[0, 1).
        perturbed = logps + -torch.log(-torch.log(uniform))
        soft = torch.softmax(perturbed / self.tau, dim=1)
        if self.straight_through:
            hard = torch.softmax(perturbed * 1e8, dim=1).detach()
            return (hard - soft).detach() + soft
        return soft


class Clustering(torch.nn.Module):
    """Autoencoder, clustering head and auxiliary classifier.

    * encoder: ``cfg.input_dim`` -> ``cfg.n_clusters`` embedding (MLP with BatchNorm).
    * decoder: embedding -> ``cfg.input_dim`` reconstruction.
    * clustering_head: embedding -> ``cfg.n_clusters`` logits.
    * aux_classifier: ``cfg.input_dim`` inputs -> ``cfg.n_clusters`` logits.

    Every Linear layer gets N(0, 0.001^2) weights and zero bias.
    """

    def __init__(self, cfg):
        super(Clustering, self).__init__()
        self.cfg = cfg
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(cfg.input_dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 2048),
            torch.nn.BatchNorm1d(2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, cfg.n_clusters),
        )
        self.clustering_head = torch.nn.Sequential(
            torch.nn.Linear(cfg.n_clusters, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, cfg.n_clusters),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(cfg.n_clusters, 2048),
            torch.nn.BatchNorm1d(2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, cfg.input_dim),
        )
        self.aux_classifier = torch.nn.Sequential(
            torch.nn.Linear(cfg.input_dim, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, cfg.n_clusters),
        )
        self.encoder.apply(self.init_weights_normal)
        self.clustering_head.apply(self.init_weights_normal)
        self.decoder.apply(self.init_weights_normal)
        self.aux_classifier.apply(self.init_weights_normal)

    @staticmethod
    def init_weights_normal(m):
        """Give a Linear layer N(0, 0.001^2) weights and zero bias; ignore other modules."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.001)
            # nn.Linear registers ``bias`` in ``_parameters`` (None when
            # bias=False), not in ``__dict__``; a ``'bias' in vars(m)`` test is
            # therefore always False and would leave the default bias in place.
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def pretrain_forward(self, x):
        """Reconstruct ``x`` through the encoder and decoder."""
        return self.decoder(self.encoder(x))


class GatingNet(torch.nn.Module):
    """Stochastic feature gates in [0, 1].

    * Local gates: an MLP maps each input to ``mu``; in training the gates are
      ``hard_sigmoid(mu + 0.5 * noise)`` with ``noise ~ N(0, sigma^2)``.
    * Global gates: one learnable vector per cluster (an ``Embedding``),
      squashed by ``tanh`` and gated the same way.

    ``sigma`` is taken from ``cfg.sigma`` when present, otherwise 0.5.
    """

    def __init__(self, cfg):
        super(GatingNet, self).__init__()
        self.cfg = cfg
        self._sqrt_2 = math.sqrt(2)
        # cfg.sigma used to be ignored in favour of a hard-coded 0.5.
        sigma = getattr(cfg, 'sigma', None)
        self.sigma = 0.5 if sigma is None else float(sigma)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(cfg.input_dim, cfg.gates_hidden_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(cfg.gates_hidden_dim, cfg.input_dim),
            torch.nn.Tanh()
        )
        self.net.apply(self.init_weights)
        self.global_gates_net = torch.nn.Embedding(self.cfg.n_clusters, self.cfg.input_dim)
        torch.nn.init.normal_(self.global_gates_net.weight, std=0.01)

    @staticmethod
    def init_weights(m):
        """Give a Linear layer N(0, 0.001^2) weights and zero bias; ignore other modules."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.001)
            # ``bias`` lives in ``_parameters``, not ``__dict__``, so a
            # ``'bias' in vars(m)`` test never fires; check the attribute instead.
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def global_forward(self, batch_size, y):
        """Return ``(mu, gates)`` for the global gates of cluster ``y``.

        ``mu = tanh(embedding[y])``; ``gates`` repeats it ``batch_size`` times
        (``y`` is expected to be a single cluster id) and applies
        ``hard_sigmoid`` after adding ``0.5 * noise`` in training mode.
        """
        noise = torch.normal(mean=0, std=self.sigma, size=(batch_size, self.cfg.input_dim),
                             device=self.global_gates_net.weight.device)
        mu = torch.tanh(self.global_gates_net(y))
        # Noise is sampled in eval mode too but multiplied by self.training (0).
        z = mu.reshape(1, -1).repeat(batch_size, 1) + .5 * noise * self.training
        gates = self.hard_sigmoid(z)
        return mu, gates

    def open_global_gates(self):
        """Mean over clusters of the summed deterministic global gates."""
        return self.hard_sigmoid(torch.tanh(self.global_gates_net.weight)).sum(dim=1).mean().cpu().item()

    def forward(self, x):
        """Return ``(mu, x * gates, gates)``; noise is added to ``mu`` only in training mode."""
        noise = torch.normal(mean=0, std=self.sigma, size=x.size(), device=x.device)
        mu = self.net(x)
        z = mu + .5 * noise * self.training
        gates = self.hard_sigmoid(z)
        sparse_x = x * gates
        return mu, sparse_x, gates

    @staticmethod
    def hard_sigmoid(x):
        """``clamp(x + 0.5, 0, 1)``."""
        return torch.clamp(x + .5, 0.0, 1.0)

    def regularization(self, mu, reduction_func=torch.mean):
        """Reduce ``Phi((mu + 0.5) / sigma)`` over all entries (Phi = standard normal CDF).

        NOTE(review): forward/global_forward perturb ``mu`` with ``0.5 * noise``
        (std ``0.5 * sigma``), while this term uses std ``sigma`` - confirm intended.
        """
        return reduction_func(0.5 - 0.5 * torch.erf((-1 / 2 - mu) / (self.sigma * self._sqrt_2)))

    def get_gates(self, x):
        """Deterministic gates ``hard_sigmoid(net(x))``, computed without gradient."""
        with torch.no_grad():
            gates = self.hard_sigmoid(self.net(x))
        return gates

    def num_open_gates(self, x, ):
        """Median over the batch of the per-sample sum of deterministic gates."""
        return self.get_gates(x).sum(dim=1).cpu().median(dim=0)[0].item()


if __name__ == "__main__":
    config = OmegaConf.create(dict(
        # GatingNet
        sigma=0.5,
        gates_hidden_dim=784,
        global_gates_reg_lambda=10,
        local_gates_reg_lambda=100,
        start_global_gates_training_on_epoch=100,

        # Autoencoder:
        ae_pretrain_epochs=100,
        ae_non_gated_epochs=10,
        mask_percentage=0.9,
        latent_noise_std=0.01,

        # MCRR:
        gamma=4,
        eps=0.1,

        # Dataset:
        dataset="MNIST",
        data_dir=".",
        input_dim=784,
        n_clusters=10,
        batch_size=256,
        repitions=20,
        tau=100,

        # NOTE(review): `gpus` / `auto_select_gpus` were removed in Lightning 2.0
        # (use `accelerator` / `devices` there) - confirm the pinned version.
        trainer=dict(
            gpus=1,
            auto_select_gpus=True,
            max_epochs=700,
            deterministic=True,
            logger=True,
            log_every_n_steps=20,
            check_val_every_n_epoch=10,
            enable_checkpointing=False,
        )
    ))

    # With deterministic algorithms enabled, cuBLAS (CUDA >= 10.2) raises unless
    # a fixed workspace size is configured; keep any value the user already set.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    seed_everything(777)
    np.random.seed(777)
    os.makedirs(config.dataset, exist_ok=True)
    model = MNISTClustering(config)
    logger = TensorBoardLogger(config.dataset, name=os.path.basename(__file__), log_graph=False)
    trainer = Trainer(**config.trainer, callbacks=[LearningRateMonitor(logging_interval='step')])
    trainer.logger = logger
    trainer.fit(model)
