import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchmetrics import CohenKappa, SpearmanCorrCoef
import random
import argparse
import os
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from omegaconf import OmegaConf
import copy
from torch import nn, optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import torch.distributions as distributions
import numpy as np
import torch
from torchmetrics import Metric
from torch.distributions import Categorical
import torch.nn.functional as F
from torchmetrics.classification import MulticlassCalibrationError


class MetricsLogger(object):
    """Collect test metrics across runs and dump them to plain-text files.

    ``update`` appends one value per metric, ``write_intermediate`` writes a
    single result dict to ``<workdir>/RESULTS.txt`` and ``write`` appends the
    collected values plus their mean/std to ``<out_dir>/test_results.txt``.
    """

    def __init__(self, out_dir):
        # One list per metric name; every ``update`` call appends one entry.
        self.metrics = {
            'test_mae': [],
            'test_accuracy': [],
            'test_oneoff_accuracy': [],
            'test_unimodality': [],
            'test_entropy_ratio': [],
            'test_kappa': [],
            'test_spearman': [],
            'test_ece': [],
        }
        self.out_file = os.path.join(out_dir, "test_results.txt")

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python scalar.

        Objects exposing ``.item()`` (torch tensors, numpy scalars) are
        unwrapped through it; anything else goes through ``float()``.
        """
        return value.item() if hasattr(value, "item") else float(value)

    def update(self, result: dict):
        """Append every value of ``result`` to the list stored under its key.

        Values may be tensors, numpy scalars or plain numbers. Keys that were
        not pre-registered in ``__init__`` get a new list instead of raising
        ``KeyError``.
        """
        for k, v in result.items():
            self.metrics.setdefault(k, []).append(self._as_float(v))

    def write_intermediate(self, workdir, result):
        """Overwrite ``<workdir>/RESULTS.txt`` with one ``name value`` line per metric."""
        with open(f'{workdir}/RESULTS.txt', mode='w') as w:
            for k, v in result.items():
                w.write(f'{k} {self._as_float(v):.4f}\n')

    def write(self):
        """Append all collected values and their mean/std to ``self.out_file``."""
        with open(self.out_file, mode='a') as w:
            for k, v in self.metrics.items():
                w.write('{}: {}\n'.format(k, ['{:.4f}'.format(val) for val in v]))
                w.write('{}: mean = {:.4f}, std = {:.4f}\n'.format(k, np.mean(v).item(), np.std(v).item()))


class ExactAccuracy(Metric):
    """Fraction of samples whose arg-max prediction equals the target.

    ``compute_on_step`` and ``dist_sync_on_step`` are accepted for signature
    compatibility but are not forwarded to the base class.
    """

    def __init__(self, compute_on_step=True, dist_sync_on_step=False):
        super().__init__()

        # Both counters are summed when states are reduced across processes.
        for state_name in ("correct", "total"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate hits for one batch; classes are taken from the last dim of ``preds``."""
        self.to(preds.device)
        hits = preds.argmax(dim=-1) == target
        self.correct += hits.sum()
        self.total += target.numel()

    def compute(self):
        """Return accumulated hits divided by the number of targets seen."""
        return self.correct.float() / self.total


class OneOffAccuracy(Metric):
    """Fraction of samples whose arg-max prediction is at most one class away from the target.

    ``compute_on_step`` and ``dist_sync_on_step`` are accepted for signature
    compatibility but are not forwarded to the base class.
    """

    def __init__(self, compute_on_step=True, dist_sync_on_step=False):
        super().__init__()

        # Both counters are summed when states are reduced across processes.
        for state_name in ("correct", "total"):
            self.add_state(state_name, default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Count predictions equal to ``target``, ``target - 1`` or ``target + 1``."""
        self.to(preds.device)
        predicted = preds.argmax(dim=-1)
        exact = torch.sum(predicted == target)
        one_below = torch.sum(predicted == target - 1)
        one_above = torch.sum(predicted == target + 1)
        self.correct += exact + one_below + one_above
        self.total += target.numel()

    def compute(self):
        """Return accumulated near-hits divided by the number of targets seen."""
        return self.correct.float() / self.total


class MAE(Metric):
    """Mean absolute error between the arg-max class index and the target.

    ``compute_on_step`` and ``dist_sync_on_step`` are accepted for signature
    compatibility but are not forwarded to the base class.
    """

    def __init__(self, compute_on_step=True, dist_sync_on_step=False):
        super().__init__()
        # Running sum of absolute errors (float) and running sample count (int).
        self.add_state("loss_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add this batch's summed L1 distance between predicted class and target."""
        self.to(preds.device)
        predicted_class = preds.argmax(dim=-1).float()
        batch_error = F.l1_loss(predicted_class, target.float(), reduction='sum')
        self.loss_sum += batch_error
        self.total += target.numel()

    def compute(self):
        """Return the accumulated absolute error divided by the number of targets seen."""
        return self.loss_sum.float() / self.total


class EntropyRatio(Metric):
    """Mean entropy of incorrect predictions divided by mean entropy of correct ones.

    The metric accumulates per-sample entropy *sums* together with sample
    counts, so the result does not depend on the batch size or on how correct
    and incorrect samples are distributed over batches. (Summing per-batch
    means instead would weight every batch equally and let a batch without
    any incorrect sample contribute a zero to the numerator.)

    Args:
        compute_on_step: accepted for signature compatibility; not forwarded
            to the base class.
        dist_sync_on_step: accepted for signature compatibility; not forwarded
            to the base class.
        output_logits: when true, ``preds`` are passed through a softmax over
            the last dimension before use.
    """

    def __init__(self, compute_on_step=False, dist_sync_on_step=False, output_logits=False):
        super().__init__()

        # Summed entropies and sample counts, kept separately for the two groups.
        self.add_state("entropy_correct", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("entropy_incorrect", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("num_correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_incorrect", default=torch.tensor(0), dist_reduce_fx="sum")

        self.output_logits = output_logits

    def _entropy(self, pred: torch.Tensor):
        """Return the mean categorical entropy of the rows of ``pred`` (0 when empty).

        Not used by ``update`` any more; retained for existing callers.
        """
        if pred.nelement() == 0: return torch.tensor(0, device=pred.device)
        return Categorical(probs=pred).entropy().mean()

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate entropy sums and counts for one batch."""
        self.to(preds.device)
        if preds.nelement() == 0:
            # Nothing to accumulate; Categorical is not built for empty input.
            return
        if self.output_logits:
            preds = torch.softmax(preds, -1)
        is_correct = torch.argmax(preds, dim=-1) == target
        entropies = Categorical(probs=preds).entropy()
        self.entropy_correct += entropies[is_correct].sum()
        self.entropy_incorrect += entropies[~is_correct].sum()
        self.num_correct += is_correct.sum()
        self.num_incorrect += (~is_correct).sum()

    def compute(self):
        """Return mean incorrect entropy over mean correct entropy.

        A group without samples has mean entropy 0; a zero mean for the
        correct group therefore yields ``inf`` or ``nan``.
        """
        mean_correct = self.entropy_correct / self.num_correct.clamp(min=1)
        mean_incorrect = self.entropy_incorrect / self.num_incorrect.clamp(min=1)
        return mean_incorrect / mean_correct


class Unimodality(Metric):
    """Fraction of predicted distributions that are unimodal around their arg-max.

    ``compute_on_step`` and ``dist_sync_on_step`` are accepted for signature
    compatibility but are not forwarded to the base class. When
    ``output_logits`` is true, inputs are soft-maxed over the last dimension.
    """

    def __init__(self, compute_on_step=True, dist_sync_on_step=False, output_logits=False):
        super().__init__()

        self.add_state("unimodal", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.output_logits = output_logits

    def _is_unimodal(self, pred: torch.Tensor):
        """Return whether ``pred`` is non-decreasing up to its arg-max and non-increasing after it."""
        values = pred.cpu().numpy()
        peak = np.argmax(values)
        # Left of the peak: every entry must be >= its left neighbour.
        rising = np.all(values[1:peak + 1] >= values[:peak])
        # From the peak on: every entry must be >= its right neighbour.
        falling = np.all(values[peak:-1] >= values[peak + 1:])
        return rising & falling

    def update(self, preds: torch.Tensor):
        """Check every row of ``preds`` and accumulate the number of unimodal rows."""
        self.to(preds.device)
        probs = torch.softmax(preds, -1) if self.output_logits else preds
        for row in probs:
            self.unimodal += torch.tensor(self._is_unimodal(row), device=probs.device)
        self.total += probs.size(0)

    def compute(self):
        """Return the number of unimodal rows divided by the number of rows seen."""
        return self.unimodal.float() / self.total


def parse_args():
    """Parse the command line and return the namespace.

    The only option is ``--config`` (string, default ``"cfg.yaml"``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="cfg.yaml")
    return cli.parse_args()


class OTLoss(nn.Module):
    """Expected transport cost between predicted class probabilities and the target class.

    The cost of moving mass from class ``i`` to class ``j`` is
    ``|i - j| / n_classes`` (``cost='linear'``) or its square
    (``cost='quadratic'``).

    Args:
        n_classes: number of classes.
        cost: ``'linear'`` or ``'quadratic'``.

    Raises:
        ValueError: if ``cost`` is neither ``'linear'`` nor ``'quadratic'``
            (a misspelt value would otherwise silently train with the linear cost).
    """

    def __init__(self, n_classes, cost='linear'):
        super().__init__()
        if cost not in ('linear', 'quadratic'):
            raise ValueError(f"Unsupported cost {cost!r}; expected 'linear' or 'quadratic'")
        self.num_classes = n_classes
        positions = np.arange(n_classes) / self.num_classes
        # C[i, j] = |j - i| / n_classes; kept as a numpy array for existing readers of ``C``.
        self.C = np.abs(positions[None, :] - positions[:, None])
        if cost == 'quadratic':
            self.C = self.C ** 2
        # Float32 tensor copy built once. Non-persistent so that the module's
        # state_dict keys stay exactly as before.
        self.register_buffer('_cost_matrix', torch.from_numpy(self.C).float(), persistent=False)

    def forward(self, output_probs, target_class):
        """Return the batch mean of ``sum_j C[target, j] * output_probs[:, j]``.

        ``target_class`` is cast to ``long`` and used to pick one cost row per sample.
        """
        # No-op when the buffer already lives on the right device in float32.
        C = self._cost_matrix.to(device=output_probs.device, dtype=torch.float32)
        costs = C[target_class.long()]
        transport_costs = torch.sum(costs * output_probs, dim=1)
        return torch.mean(transport_costs)


class UnimodalNormal(nn.Module):
    """Output head that turns features into class probabilities via a binned 1-D distribution.

    A location ``mu`` (tanh-bounded) and a scale ``sig`` are predicted from the
    input features. The interval ``[-bins_limit, bins_limit]`` is split into
    ``num_classes`` equal-width bins and the probability of a class is the
    mass the distribution ``torch.distributions.<dist_func>(mu, sig)`` puts on
    its bin, renormalised so that every row sums to one.

    Args:
        num_classes: number of bins / classes.
        input_dim: size of the incoming feature vector.
        dist_func: name of a class in ``torch.distributions`` constructed as
            ``cls(mu, sig)`` and queried through ``cdf``.
        min_sigma: lower clamp applied to the predicted scale (upper clamp is 1e2).
        bins_limit: half-width of the binned interval; ignored when
            ``learn_bin_limit`` is true.
        sigma_scaling_dim: hidden width of the scale network.
        learn_bin_limit: when true the half-width is
            ``sigmoid(param).clamp_min(0.5)`` with a learnable ``param``.
    """

    def __init__(self, num_classes, input_dim, dist_func, min_sigma, bins_limit, sigma_scaling_dim,
                 learn_bin_limit=False):
        super().__init__()
        self.num_classes = num_classes
        self.dist_func = dist_func
        self.min_sigma = min_sigma
        self.learn_bin_limit = learn_bin_limit
        if learn_bin_limit:
            self.bins_limit = nn.Parameter(torch.randn(1)).requires_grad_(True)
        else:
            self.bins_limit = bins_limit
        self.mu_output = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Tanh()
        )
        self.sigma = nn.Sequential(
            nn.Linear(input_dim, sigma_scaling_dim),
            nn.ReLU(),
            nn.Linear(sigma_scaling_dim, 1),
            nn.Softplus()
        )

        # Not used by any method of this class; kept so that parameter
        # initialisation order and state_dict keys stay unchanged.
        self.sigma_gp = nn.Sequential(
            nn.Linear(input_dim, 1),
            nn.Softplus()
        )

    def _bin_edges(self, device):
        """Return the ``num_classes + 1`` equally spaced bin edges on ``device``."""
        if self.learn_bin_limit:
            half_width = torch.sigmoid(self.bins_limit).clamp_min(0.5)
        else:
            half_width = self.bins_limit
        return torch.arange(0, self.num_classes + 1, device=device) / self.num_classes * (
                2 * half_width) - half_width

    def calc_normal_output_probs(self, mu, sig):
        """Return row-normalised bin probabilities of shape ``(mu.size(0), num_classes)``."""
        thresholds = self._bin_edges(mu.device)
        dist_func_instance = getattr(distributions, self.dist_func)(mu, sig)
        # Evaluate the CDF once per edge (interior edges are shared by two bins).
        cdf_at_edges = torch.stack(
            [dist_func_instance.cdf(edge).reshape(-1) for edge in thresholds], dim=1)
        probs = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]).float()
        # Element-wise renormalisation: avoids building a batch x batch diagonal
        # matrix and keeps the division in full float32 even when TF32 matmul
        # is enabled.
        return probs / torch.sum(probs, dim=1, keepdim=True)

    def calc_output_probs(self, x):
        """Return only the class probabilities for features ``x``."""
        return self.get_outputs(x)[0]

    def forward(self, x, get_outputs=False):
        """Return probabilities, or ``(probabilities, mu, sig)`` when ``get_outputs`` is true."""
        return self.get_outputs(x) if get_outputs else self.calc_output_probs(x)

    def get_outputs(self, x):
        """Return ``(probabilities, mu, sig)`` for features ``x``."""
        mu = self.mu_output(x)
        sig = self.sigma(x).clamp(min=self.min_sigma, max=1e2)
        output_probs = self.calc_normal_output_probs(mu=mu, sig=sig)
        return output_probs, mu, sig


class MultilayerPerceptron(torch.nn.Module):
    """Two hidden blocks of bias-free Linear -> LeakyReLU -> Dropout(0.2) -> BatchNorm1d."""

    def __init__(self, num_features,
                 num_hidden_1, num_hidden_2):
        super().__init__()
        widths = (num_features, num_hidden_1, num_hidden_2)
        stack = []
        # One block per consecutive pair of widths: features -> hidden 1 -> hidden 2.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                torch.nn.Linear(fan_in, fan_out, bias=False),
                torch.nn.LeakyReLU(),
                torch.nn.Dropout(0.2),
                torch.nn.BatchNorm1d(fan_out),
            ]
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through both hidden blocks and return the result."""
        return self.layers(x)


class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields indices class by class in round-robin order.

    Classes with fewer samples than the largest class are oversampled once,
    at construction time, by repeatedly drawing ``random.choice`` from their
    own index list until every class has ``balanced_max`` entries.

    Iteration keeps no state on the instance, so every ``iter()`` call yields
    the full sequence even if a previous iterator was abandoned half-way.

    Args:
        dataset: sized object; only ``len(dataset)`` is used.
        labels: indexable collection of labels, one per sample. When ``None``
            every sample is grouped under the single label ``None``.
    """
    # adapted from https://github.com/galatolofederico/pytorch-balanced-batch/blob/master/sampler.py

    def __init__(self, dataset, labels=None):
        self.labels = labels
        self.dataset = dict()
        # Save all the indices for all the classes
        for idx in range(len(dataset)):
            self.dataset.setdefault(self._get_label(dataset, idx), []).append(idx)
        self.balanced_max = max((len(members) for members in self.dataset.values()), default=0)

        # Oversample the classes with fewer elements than the max
        for members in self.dataset.values():
            while len(members) < self.balanced_max:
                members.append(random.choice(members))
        self.keys = list(self.dataset.keys())
        # No longer mutated by iteration; retained for code that reads them.
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        """Yield one index per class in turn, ``balanced_max`` rounds in total."""
        for position in range(self.balanced_max):
            for key in self.keys:
                yield self.dataset[key][position]

    def _get_label(self, dataset, idx, labels=None):
        """Return the label of sample ``idx`` as a hashable Python value.

        Entries exposing ``.item()`` (e.g. tensor elements) are unwrapped;
        other entries are returned as they are. Returns ``None`` when no
        labels were given.
        """
        if self.labels is None:
            return None
        label = self.labels[idx]
        return label.item() if hasattr(label, "item") else label

    def __len__(self):
        """Return the number of indices produced by one full iteration."""
        return self.balanced_max * len(self.keys)


class FiremanTabularDataset(torch.utils.data.Dataset):
    """Map-style dataset over one pre-loaded ``(features, labels)`` split.

    ``data_splits`` maps a task name to an ``(X, Y)`` pair; ``task`` selects
    the pair. ``config`` is accepted but not used.
    """

    def __init__(self, config, task, data_splits):
        super().__init__()
        features, targets = data_splits[task]
        self.X = features
        self.Y = targets

    def get_labels(self):
        """Return the labels of this split as a numpy array."""
        return self.Y.numpy()

    def __len__(self):
        return len(self.X)

    @staticmethod
    def split_dataset():
        """Read the train/valid/test CSV files and return them as tensors.

        Returns a dict with keys ``'train'``, ``'test'`` and ``'val'``; each
        value is ``(X, Y)`` with ``X`` the float columns ``V1``..``V10`` and
        ``Y`` the ``response`` column cast to long. Split sizes and the number
        of distinct training labels are printed.
        """
        feature_list = ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10']

        def load(csv_path):
            # Returns the raw frame too: the class count is derived from it.
            frame = pd.read_csv(csv_path)
            targets = torch.from_numpy(frame['response'].values).long()
            inputs = torch.from_numpy(frame[feature_list].values).float()
            return frame, inputs, targets

        df_train, X_train, Y_train = load("fireman_splits_balanced/fireman_balanced_train.csv")
        _, X_val, Y_val = load("fireman_splits_balanced/fireman_balanced_valid.csv")
        _, X_test, Y_test = load("fireman_splits_balanced/fireman_balanced_test.csv")

        num_classes = len(np.unique(df_train['response'].values))

        print('size(X_train)={}'.format(len(X_train)))
        print('size(X_test)={}'.format(len(X_test)))
        print('size(X_val)={}'.format(len(X_val)))
        print(f'Num of classes: {num_classes}')
        return {
            'train': (X_train, Y_train),
            'test': (X_test, Y_test),
            'val': (X_val, Y_val),
        }

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'image': features, 'label': label}``."""
        return {
            'image': self.X[idx],
            'label': self.Y[idx]
        }


class UNICORNN(pl.LightningModule):
    """LightningModule: MLP backbone -> ReLU -> ``UnimodalNormal`` head, trained with ``OTLoss``.

    The module loads the Fireman data splits itself and builds its own data
    loaders. Validation metrics are logged per epoch; test metrics are
    accumulated over the test epoch and stored in ``self.test_metrics``.

    ``config`` must provide the attributes read below (``num_classes``,
    ``ordinal_input_dim``, ``dist_func``, ``min_sigma``, ``bins_limit``,
    ``sigma_scaling_dim``, ``output_logits``, ``sigma_scaling``, ``lr``,
    ``wd``, ``gamma``, ``lr_sched``, ``error_analysis`` and the per-task
    ``*_batch_size`` / ``*_workers`` entries).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()
        self.transition_layer = torch.nn.ReLU(True)
        self.output_layers = UnimodalNormal(config.num_classes, config.ordinal_input_dim, config.dist_func,
                                            config.min_sigma, config.bins_limit, config.sigma_scaling_dim)
        self.loss_func = OTLoss(config.num_classes)
        self.transforms = None
        torch.set_printoptions(precision=10)
        # Validation metrics that are accumulated over an epoch and reset in
        # ``on_validation_epoch_end``.
        self.val_ece = MulticlassCalibrationError(num_classes=config.num_classes, n_bins=10)
        self.val_entropy_ratio = EntropyRatio(output_logits=config.output_logits)
        # Test metrics, accumulated over the test epoch.
        self.test_mae = MAE()
        self.test_ece = MulticlassCalibrationError(num_classes=config.num_classes, n_bins=10)
        self.test_accuracy = ExactAccuracy()
        self.test_one_off_accuracy = OneOffAccuracy()
        self.test_entropy_ratio = EntropyRatio(output_logits=config.output_logits)
        self.test_unimodality = Unimodality(output_logits=config.output_logits)
        self.test_kappa = CohenKappa(task="multiclass", num_classes=config.num_classes, weights='quadratic')
        self.test_spearman = SpearmanCorrCoef()
        self.test_metrics = {}
        # Never assigned inside this class; every use is guarded against None.
        self.summary_writer = None
        self.sigma_scaling = config.sigma_scaling
        self.dataset_class = FiremanTabularDataset
        self.data_splits = FiremanTabularDataset.split_dataset()
        self.backbone_model = MultilayerPerceptron(self.data_splits['train'][0].shape[1], 300, 300)

    def configure_optimizers(self):
        """Return AdamW plus a per-epoch MultiStepLR built from ``config.lr_sched`` (comma-separated epochs)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.lr, weight_decay=self.config.wd)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                gamma=self.config.gamma,
                milestones=[int(v) for v in self.config.lr_sched.split(',')]
            ),
            'interval': 'epoch',
            'frequency': 1}
        return [optimizer], [scheduler]

    def build_data_loader(self, task):
        """Return the DataLoader for ``task`` (``'train'``, ``'val'`` or ``'test'``).

        The training loader uses ``BalancedBatchSampler`` and drops the last
        partial batch. The test loader keeps every sample so that test metrics
        cover the whole split.

        Raises:
            ValueError: if ``task`` is not one of the three known names.
        """
        if task == 'train':
            batch_size = self.config.train_batch_size
            workers = self.config.train_workers
        elif task == 'test':
            batch_size = self.config.test_batch_size
            workers = self.config.test_workers
        elif task == 'val':
            batch_size = self.config.val_batch_size
            workers = self.config.val_workers
        else:
            raise ValueError(f"Unknown task {task!r}; expected 'train', 'val' or 'test'")
        transformed_dataset = self.dataset_class(self.config, task, self.data_splits)

        if task == 'train':
            return DataLoader(transformed_dataset,
                              batch_size=batch_size,
                              num_workers=workers,
                              drop_last=True,
                              sampler=BalancedBatchSampler(transformed_dataset, transformed_dataset.Y))

        # NOTE(review): validation still drops the last partial batch because
        # validation_step logs per-batch kappa/Spearman values, which may be
        # ill-defined on a very small final batch - confirm before changing.
        return DataLoader(transformed_dataset,
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=workers,
                          drop_last=(task == 'val'))

    def forward(self, x, get_outputs=False):
        """Run backbone, transition ReLU and output head; ``get_outputs`` is passed to the head."""
        x = self.backbone_model(x)
        x = self.transition_layer(x)
        x = self.output_layers(x, get_outputs)
        return x

    def train_dataloader(self):
        return self.build_data_loader('train')

    def val_dataloader(self):
        return self.build_data_loader('val')

    def test_dataloader(self):
        return self.build_data_loader('test')

    def _to_probs(self, y_hat):
        """Convert a model output to a probability tensor.

        Uses ``self.loss_func.to_proba`` when the loss defines it and unwraps
        tuple results to their first element.
        """
        # getattr instead of try/except so an AttributeError raised *inside*
        # to_proba is not silently swallowed.
        to_proba = getattr(self.loss_func, 'to_proba', None)
        if to_proba is not None:
            y_hat = to_proba(y_hat)
        if isinstance(y_hat, tuple):
            y_hat = y_hat[0]
        return y_hat

    def training_step(self, batch, batch_idx):
        """Compute and log the OT loss for one training batch."""
        x, y = batch['image'], batch['label'].float()
        y_hat = self(x)
        loss = self.loss_func(y_hat, y)
        self.log_dict({'train_loss': loss}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def log_image_to_tb(self, images, true_labels, pred_labels):
        """Write each image to ``self.summary_writer`` tagged with predicted and true label.

        Does nothing when no summary writer has been attached.
        """
        if self.summary_writer is None:
            return
        for i, image in enumerate(images):
            true_label = true_labels[i]
            pred_label = pred_labels[i]
            self.summary_writer.add_image(f'pred: {pred_label} label: {true_label}', image, self.current_epoch)
        self.summary_writer.flush()

    def validation_step(self, batch, batch_idx):
        """Update epoch-level validation metrics and log per-batch ones."""
        x, y = batch['image'], batch['label'].float()
        y_hat = self(x)
        y_probs = self._to_probs(y_hat)
        self.val_entropy_ratio(y_probs, y)
        self.val_ece(y_probs, y)
        # These metrics are computed on this batch only; with on_epoch=True
        # the logged epoch value is the logger's aggregate of the batch values.
        self.log_dict({
            'val_loss': self.loss_func(y_hat, y),
            'val_mae': MAE()(y_probs, y),
            'val_accuracy': ExactAccuracy()(y_probs, y),
            'val_oneoff_accuracy': OneOffAccuracy()(y_probs, y),
            'val_unimodality': Unimodality(output_logits=self.config.output_logits)(y_probs),
            'val_kappa': CohenKappa(
                task="multiclass",
                num_classes=self.config.num_classes,
                weights='quadratic').to(y_probs.device)(y_probs, y.int()),
            'val_spearman': SpearmanCorrCoef().to(y_probs.device)(y_probs.argmax(dim=1).float(), y.float()),
        }, on_step=False, on_epoch=True, logger=True)

    def on_validation_epoch_end(self):
        """Log the accumulated ECE and entropy ratio, then reset both metrics."""
        self.log_dict({
            'val_ece': self.val_ece.compute(),
            'val_entropy_ratio': self.val_entropy_ratio.compute(),
        }, on_step=False, on_epoch=True)
        # compute() is called manually above, so the metrics must be reset by
        # hand; otherwise their state would carry over into the next epoch.
        self.val_ece.reset()
        self.val_entropy_ratio.reset()

    def test_step(self, batch, batch_idx):
        """Accumulate all test metrics for one batch."""
        x, y = batch['image'], batch['label'].float()
        y_hat = self._to_probs(self(x))
        if self.config.error_analysis:
            preds_classes = torch.argmax(y_hat, dim=-1)
            wrong = preds_classes != y
            self.log_image_to_tb(x[wrong], y[wrong], preds_classes[wrong])

        self.test_entropy_ratio(y_hat, y)
        self.test_mae(y_hat, y)
        self.test_accuracy(y_hat, y)
        self.test_one_off_accuracy(y_hat, y)
        self.test_unimodality(y_hat)
        self.test_ece(y_hat, y)
        self.test_kappa.to(y_hat.device)(y_hat, y.int())
        self.test_spearman.to(y_hat.device)(y_hat.argmax(dim=1).float(), y.float())

    def on_test_epoch_end(self):
        """Compute, store and log every test metric, then reset them."""
        tracked = {
            'test_mae': self.test_mae,
            'test_accuracy': self.test_accuracy,
            'test_oneoff_accuracy': self.test_one_off_accuracy,
            'test_unimodality': self.test_unimodality,
            'test_entropy_ratio': self.test_entropy_ratio,
            'test_ece': self.test_ece,
            'test_kappa': self.test_kappa,
            'test_spearman': self.test_spearman,
        }
        self.test_metrics = {name: metric.compute() for name, metric in tracked.items()}
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True)
        # Reset so that a second test run does not accumulate on top of this one.
        for metric in tracked.values():
            metric.reset()

    def on_train_end(self):
        """Close the summary writer when error analysis is on and a writer exists."""
        if self.config.error_analysis and self.summary_writer is not None:
            self.summary_writer.close()


def sigma_scaling(model, config, logger):
    """Fine-tune only ``model.output_layers.sigma`` and keep the version with the lowest validation ECE.

    For ``config.scaling_epochs`` epochs the sigma network is trained with
    Adam on the training loader (the backbone forward pass runs under
    ``torch.no_grad``), then loss and calibration error are measured on the
    validation loader. A plot of the three curves is saved under the logger's
    version directory, and the sigma network of the best epoch is installed
    into the model. If no epoch yields a usable ECE (for example
    ``scaling_epochs == 0``), the current sigma network is left in place.

    Args:
        model: module exposing ``backbone_model``, ``transition_layer``,
            ``output_layers``, ``train_dataloader()`` and ``val_dataloader()``.
        config: reads ``criterion``, ``lr``, ``sigma_scaling_wd``,
            ``scaling_epochs`` and ``num_classes``.
        logger: reads ``root_dir`` and ``version`` to build the plot path.
    """
    # Use the GPU when there is one; otherwise run on the CPU instead of failing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_mse = config.criterion == 'mse'
    # NOTE(review): CrossEntropyLoss expects unnormalised logits, while
    # output_layers returns row-normalised probabilities - confirm this is intended.
    criterion = (nn.MSELoss() if use_mse else nn.CrossEntropyLoss()).to(device)
    model = model.to(device)
    optimizer = optim.Adam(model.output_layers.sigma.parameters(), lr=config.lr, weight_decay=config.sigma_scaling_wd)
    data_loader = model.train_dataloader()
    val_loader = model.val_dataloader()
    train_loss_list = []
    val_loss_list = []
    val_ece_list = []
    best_ece = 1000
    best_sigma_nn = None

    def encode_targets(labels):
        # One-hot floats for MSE, class indices for cross-entropy; both stay on ``device``.
        if use_mse:
            return F.one_hot(labels.to(torch.int64), num_classes=config.num_classes).float()
        return labels.long()

    print('Starting sigma scaling')
    for epoch in range(config.scaling_epochs):
        print('Starting epoch %d' % epoch)
        sum_train_losses, num_train_batches = 0.0, 0
        # NOTE(review): train() also switches the backbone's Dropout/BatchNorm to
        # training mode, so BatchNorm running statistics keep changing even though
        # the backbone forward pass runs under no_grad - confirm this is intended.
        model.train()
        for batch in tqdm(data_loader):
            inputs = batch["image"].to(device)
            targets = encode_targets(batch["label"].to(device))

            def closure():
                optimizer.zero_grad()
                with torch.no_grad():
                    features = model.transition_layer(model.backbone_model(inputs))
                preds = model.output_layers(features, False)
                loss = criterion(preds, targets)
                loss.backward()
                return loss

            # step() returns the value returned by the closure.
            batch_loss = optimizer.step(closure)
            sum_train_losses += batch_loss.item()
            num_train_batches += 1
        train_loss = sum_train_losses / max(num_train_batches, 1)
        train_loss_list.append(train_loss)
        print('Train Loss: %.3f' % train_loss)

        sum_val_losses = 0.0
        num_val_batches = 0
        ece = MulticlassCalibrationError(num_classes=config.num_classes, n_bins=10)
        model.eval()
        with torch.no_grad():
            for batch in tqdm(val_loader):
                inputs, labels = batch["image"].to(device), batch["label"].to(device)
                preds = model(inputs)
                loss = criterion(preds, encode_targets(labels))
                sum_val_losses += loss.item()
                num_val_batches += 1
                ece.update(preds, labels)
        val_loss = sum_val_losses / max(num_val_batches, 1)
        val_loss_list.append(val_loss)
        val_ece = ece.compute()
        val_ece_list.append(val_ece.item())
        if val_ece < best_ece:
            best_ece = val_ece
            best_sigma_nn = copy.deepcopy(model.output_layers.sigma)
        print('Validation Loss: %.3f' % val_loss)
        print('Validation ECE: %.3f' % val_ece)
    path = logger.root_dir + f'/version_{logger.version}/sigma_scaling.png'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    save_plot(train_loss_list, val_loss_list, val_ece_list, path)
    if best_sigma_nn is not None:
        model.output_layers.sigma = best_sigma_nn
    else:
        # Installing None here would break every later forward pass.
        print('Sigma scaling found no improved sigma network; keeping the current one')
    print('Finished sigma scaling')


def save_plot(train_loss, val_loss, val_ece, path):
    """Plot the three per-epoch curves against 1-based epoch numbers and save the figure to ``path``.

    The x-axis length is taken from ``train_loss``. The current pyplot figure
    is closed after saving.
    """
    epochs = range(1, len(train_loss) + 1)
    curves = (
        (train_loss, 'b', 'Training loss'),
        (val_loss, 'r', 'Validation loss'),
        (val_ece, 'g', 'Validation ECE'),
    )
    for values, colour, legend_label in curves:
        plt.plot(epochs, values, colour, label=legend_label)
    plt.title('Training and Validation Loss with ECE')
    plt.xlabel('Epochs')
    plt.ylabel('Loss/ECE')
    plt.legend()
    plt.savefig(path)
    plt.close()


def main():
    """Train UNICORNN, restore the best checkpoint, run sigma scaling, test and write the results."""
    args = parse_args()
    config = OmegaConf.load(args.config)
    metrics_logger = MetricsLogger(".")
    seed_everything(0)
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)

    trainer = Trainer(**config.trainer)
    trainer.logger = TensorBoardLogger(config.output_dir, name="UNICORNN", log_graph=False)
    lr_log_callback = LearningRateMonitor(logging_interval='epoch')
    # The monitored quantity comes from the config; it is always minimised.
    checkpoint_callback = ModelCheckpoint(
        monitor=config.checkpoint_on,
        save_top_k=1,
        mode="min",
        save_weights_only=True,
        verbose=True,
    )
    trainer.callbacks.extend([lr_log_callback, checkpoint_callback])

    model = UNICORNN(config)
    trainer.fit(model)

    best_checkpoint_path = checkpoint_callback.best_model_path
    print(f"Best checkpoint path: {best_checkpoint_path}")
    # Relies on a private Trainer attribute to reload the best checkpoint into the model.
    trainer._checkpoint_connector._restore_modules_and_callbacks(best_checkpoint_path)

    sigma_scaling(model, config, trainer.logger)
    trainer.test(model)

    metrics_logger.update(model.test_metrics)
    version_dir = f'{trainer.logger.save_dir}/{trainer.logger.name}/version_{trainer.logger.version}'
    metrics_logger.write_intermediate(version_dir, model.test_metrics)
    metrics_logger.write()


if __name__ == "__main__":
    main()
