# Copyright (c) 2021 Copyright holder of the paper "Test-Time Adaptation to Distribution Shifts by Confidence Maximization and Input Transformation" submitted to NeurIPS 2021 for review
# All rights reserved.

"""Conduct test time adaptation"""

import os
import tempfile
import argparse
from tqdm import tqdm

import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import mlflow
import mlflow.pytorch
import yaml
import pandas as pd

from data import get_datasets, split_dataset, to_dataloader
from utils import AverageMeter, accuracy


class ModelAugmentation(nn.Module):
    """Prepend a trainable input-transformation module to a pre-trained model.

    The input is passed through a small convolutional residual path and
    blended with the untouched input using a learnable scalar weight,
    optionally followed by a learnable per-channel scale and bias.
    ``residual_weight`` starts at 0, ``scale`` at 1 and ``bias`` at 0, so the
    transformation initially returns the input unchanged.

    Args:
        net: the pre-trained model that receives the transformed input.
        config: dict-like with optional keys ``affine`` (default True),
            ``normalization`` (default True), ``depth`` (default 6),
            ``multiplicity`` (default 6) and ``kernel_size`` (default 3).
        in_channel: number of channels of the input images.
    """

    def __init__(self, net, config, in_channel=3):
        super().__init__()

        self.net = net  # pre-trained model

        self.affine = config.get("affine", True)
        self.normalization = config.get("normalization", True)
        self.depth = config.get("depth", 6)
        self.multiplicity = config.get("multiplicity", 6)
        self.kernel_size = config.get("kernel_size", 3)

        hidden_channel = in_channel * self.multiplicity
        respath_modules = []
        for i in range(self.depth):
            is_last = i == self.depth - 1
            n_in = in_channel if i == 0 else hidden_channel
            n_out = in_channel if is_last else hidden_channel
            # No device is chosen here: the layers are created on the default
            # device and follow the module when the caller moves it with
            # `.cuda()` / `.to(device)`, together with the normalization
            # layers and the parameters defined below.
            respath_modules.append(
                nn.Conv2d(
                    n_in,
                    n_out,
                    kernel_size=self.kernel_size,
                    padding=self.kernel_size // 2,
                    padding_mode="reflect",
                )
            )
            if self.normalization:
                # one group per channel
                respath_modules.append(
                    nn.GroupNorm(num_channels=n_out, num_groups=n_out)
                )

            # no non-linearity after the last layer
            if not is_last:
                respath_modules.append(nn.ReLU(inplace=True))

        self.residual_path = nn.Sequential(*respath_modules)

        # scalar blending weight between the input and the residual path
        self.residual_weight = nn.Parameter(torch.zeros(1, 1, 1, 1))

        if self.affine:
            # per-channel scale and bias applied after blending
            self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))
            self.bias = nn.Parameter(torch.zeros(1, in_channel, 1, 1))

    def preprocess(self, inputs):
        """Return the transformed inputs that are fed to the wrapped model."""
        residual = self.residual_path(inputs)
        combined = (1 - self.residual_weight) * inputs + self.residual_weight * residual
        if self.affine:
            return combined * self.scale + self.bias
        return combined

    def forward(self, inputs):
        """Apply the input transformation, then the wrapped model."""
        return self.net(self.preprocess(inputs))


def set_norm_layer_requires_grad(module, flag):
    """Switch gradient tracking of a normalization layer's weight and bias.

    Only ``nn.GroupNorm`` and ``nn.BatchNorm2d`` modules are touched; any
    other module is left unchanged. Layers without affine parameters
    (``weight``/``bias`` set to ``None``) are skipped.
    """
    if not isinstance(module, (nn.GroupNorm, nn.BatchNorm2d)):
        return
    for affine_param in (module.weight, module.bias):
        if affine_param is not None:
            affine_param.requires_grad = flag


def set_model_augmentation_requires_grad(module, flag):
    """Switch gradient tracking of every parameter owned by a ModelAugmentation.

    Covers the blending weight, all parameters of the residual path and,
    when the module was built with ``affine`` enabled, the scale and bias.
    Modules of any other type are left unchanged.
    """
    if not isinstance(module, ModelAugmentation):
        return

    module.residual_weight.requires_grad = flag
    for path_param in module.residual_path.parameters():
        path_param.requires_grad = flag

    if module.affine:
        for affine_param in (module.scale, module.bias):
            affine_param.requires_grad = flag


def get_optimizer(net, config):
    """Select the parameters to adapt and build the optimizer for them.

    All parameters are frozen first and then re-enabled according to
    ``config["parameters_to_update"]``: ``"affine"`` enables normalization
    layers, ``"input"`` enables the model augmentation module and ``"all"``
    enables everything. If ``config["freeze_top_layers"]`` is set, every
    parameter whose name contains one of the strings in ``config["freeze"]``
    is frozen again afterwards.

    Returns:
        An ``optim.SGD`` or ``optim.Adam`` instance over the trainable
        parameters, configured from ``config["optimizer"]``.

    Raises:
        ValueError: if ``config["optimizer"]["type"]`` is neither "SGD" nor "Adam".
    """
    # start from a fully frozen network
    for weight in net.parameters():
        weight.requires_grad = False

    update_spec = config["parameters_to_update"]
    if "affine" in update_spec:
        net.apply(lambda m: set_norm_layer_requires_grad(m, True))
    if "input" in update_spec:
        net.apply(lambda m: set_model_augmentation_requires_grad(m, True))
    if update_spec == "all":
        for weight in net.parameters():
            weight.requires_grad = True

    if config["freeze_top_layers"]:
        # freeze top layers of the network
        for param_name, weight in net.named_parameters():
            for pattern in config["freeze"]:
                if pattern in param_name:
                    weight.requires_grad = False
                    break

    trainable = (weight for weight in net.parameters() if weight.requires_grad)

    # Construct the optimizer from a configuration.
    optimizer_config = config["optimizer"]
    optimizer_type = optimizer_config["type"]
    if optimizer_type == "SGD":
        return optim.SGD(
            params=trainable,
            lr=optimizer_config["lr"],
            momentum=optimizer_config["momentum"],
            weight_decay=optimizer_config["weight_decay"],
        )
    if optimizer_type == "Adam":
        return optim.Adam(
            params=trainable,
            lr=optimizer_config["lr"],
            weight_decay=optimizer_config["weight_decay"],
        )
    raise ValueError("Unsupported optimizer.")


def compute_entropy_loss(probs, logits=None):
    """Return the entropy of ``probs``, summed over dim 1 and averaged over dim 0.

    Args:
        probs: probabilities; the class axis is dim 1 and the batch axis is dim 0.
        logits: optional logits. When given, the log-probabilities are taken
            from ``log_softmax(logits)`` instead of ``log(probs)``.

    Returns:
        A scalar tensor holding the mean entropy.
    """
    if logits is None:
        # Clamp only inside the log: a probability of exactly 0 then
        # contributes 0 to the entropy instead of 0 * log(0) = NaN, while all
        # normal positive values are left untouched.
        smallest = torch.finfo(probs.dtype).tiny
        log_probs = torch.log(probs.clamp_min(smallest))
    else:
        log_probs = F.log_softmax(logits, dim=1)
    return -(probs * log_probs).sum(dim=1).mean(dim=0)


def confidence_maximization(logits, probs, loss_type="hard_likelihood_ratio"):
    """Compute the confidence maximization loss of the requested type.

    Args:
        logits: model outputs with the class axis on dim 1.
        probs: softmax probabilities belonging to ``logits``.
        loss_type: one of "soft_likelihood_ratio", "hard_likelihood_ratio",
            "pseudolabels" or "TENT".

    Returns:
        A scalar tensor: the loss averaged over the batch.

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    delta = 0.025  # scaling factor of both likelihood ratio losses

    if loss_type == "soft_likelihood_ratio":
        logits = logits.double()  # we need double-precision here
        # Shift logits to stabilize log-sum-exp computation
        logits_max = logits.max(dim=1, keepdim=True)[0]
        logits_shifted = logits - logits_max
        # Compute exponentiated logits and their sum once
        logit_exp = torch.exp(logits_shifted)
        logit_exp_sum = torch.sum(logit_exp, dim=1, keepdim=True)
        # Vectorized computation of the logsumexp terms (excluding the respective class)
        logsumexp = torch.log(logit_exp_sum - logit_exp + 1e-50)
        # Compute expected probability ratio loss
        conf_max_loss = -delta * torch.sum(probs * (logits_shifted - logsumexp), dim=1)

    elif loss_type == "hard_likelihood_ratio":
        # sort the logits in descending order: column 0 is the predicted
        # class, the remaining columns are all other classes
        topk = torch.topk(logits, k=logits.shape[1], dim=1).values
        conf_max_loss = -delta * (topk[:, 0] - torch.logsumexp(topk[:, 1:], dim=1))

    elif loss_type == "pseudolabels":
        # negative log-probability of the predicted class
        conf_max_loss = -1.0 * (logits.max(1)[0] - torch.logsumexp(logits, dim=1))

    elif loss_type == "TENT":
        conf_max_loss = compute_entropy_loss(probs, logits)

    else:
        # fail with a clear message instead of an UnboundLocalError below
        raise ValueError(
            f"Unsupported confidence maximization loss: {loss_type!r}"
        )

    return torch.mean(conf_max_loss)


def running_diversity_maximization(probs, running_probs_mean, kappa):
    """Update the running class distribution and return the diversity loss L_div.

    Args:
        probs: probabilities of the current batch (batch axis on dim 0).
        running_probs_mean: running estimate from the previous call, or
            ``None`` on the first call.
        kappa: weight of the previous estimate in the running average.

    Returns:
        Tuple ``(loss, updated_mean)`` where ``loss`` is the negative entropy
        of the updated running mean.
    """
    batch_mean = probs.mean(dim=0, keepdim=True)

    if running_probs_mean is None:
        # first batch: the estimate is just the batch mean
        updated_mean = batch_mean
    else:
        # the previous estimate is detached, so gradients only flow through
        # the current batch
        blended = kappa * running_probs_mean.detach() + (1 - kappa) * batch_mean
        # normalize to probability dist
        blended = blended / blended.sum(dim=1, keepdim=True)
        updated_mean = torch.clamp(blended, 1e-7, 1)

    return -compute_entropy_loss(updated_mean), updated_mean


def compute_test_time_loss(model, inputs, config, running_probs_mean, used_labels):
    """Run the model on one batch and collect the enabled test-time losses.

    Args:
        model: network to adapt.
        inputs: batch of inputs; moved to the GPU before the forward pass.
        config: configuration with the keys ``confidence_maximization``,
            ``running_diversity_regularizer`` and ``kappa``.
        running_probs_mean: running class distribution from the previous
            batch (``None`` on the first batch).
        used_labels: class indices to keep in the logits, or ``None`` to use
            all classes.

    Returns:
        Tuple ``(losses, running_probs_mean)`` where ``losses`` maps loss
        names to scalar tensors.
    """
    logits = model(inputs.cuda())

    # if used_labels is None, then use all the 1000 ImageNet classes.
    # Otherwise mask the labels, which covers two different cases:
    # (1) when using ImageNet-R, this dataset contains only 200 ImageNet classes,
    # (2) when conducting ablation study on adaptation with subset of classes.
    if used_labels is not None:
        logits = logits[:, used_labels]

    probs = F.softmax(logits, dim=1)

    losses = {}  # dict to collect the losses

    conf_max_type = config["confidence_maximization"]
    if conf_max_type:
        losses["conf_max_loss"] = confidence_maximization(
            logits, probs, loss_type=conf_max_type
        )

    if config["running_diversity_regularizer"]:
        diversity_loss, running_probs_mean = running_diversity_maximization(
            probs, running_probs_mean, config["kappa"]
        )
        losses["running_diversity_loss"] = diversity_loss

    return losses, running_probs_mean


def test_time_adapt(model, loader_train, loader_eval, config, num_epochs, title):
    """Adapt the model on a target data.

    Args:
        model: network to adapt (updated in place).
        loader_train: loader yielding ``(inputs, _)`` batches used for adaptation.
        loader_eval: loader used to evaluate the model after the first epoch.
        config: configuration with the loss and optimizer settings.
        num_epochs: number of passes over ``loader_train``.
        title: name of the target data, used in printed messages.

    Returns:
        Tuple ``(model, epoch1_top1_acc)``. ``epoch1_top1_acc`` is the top-1
        accuracy measured after the first epoch, or ``None`` when
        ``num_epochs`` is 0.
    """
    print("\n", title)
    running_losses = {}
    if config["confidence_maximization"]:
        running_losses["conf_max_loss"] = AverageMeter()
    if config["running_diversity_regularizer"]:
        running_losses["running_diversity_loss"] = AverageMeter()

    optimizer = get_optimizer(model, config)
    model.train()

    running_probs_mean = None
    try:
        used_labels = loader_train.dataset.used_labels
    except AttributeError:
        used_labels = None

    # stays None if no epoch is run, so the return below never fails
    epoch1_top1_acc = None

    outer_pbar = tqdm(total=num_epochs, desc="Epoch", position=0)
    for epoch_no in range(num_epochs):
        inner_pbar = tqdm(loader_train, desc="Training", position=1)

        # evaluate() switches the model to eval mode, so re-enable train mode
        model.train()

        # cosine decay schedule of the learning rate
        if config["confidence_maximization"] != "TENT":
            initial_lr = config["optimizer"]["lr"]
            lr = 0.5 * initial_lr * (1 + np.cos(np.pi * epoch_no / num_epochs))
            print(" Learning rate decayed to {}".format(lr))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        # start the batch wise adaptation
        # (iterating over the tqdm wrapper already advances the progress bar,
        # so no manual update is needed inside the loop)
        for inputs, _ in inner_pbar:
            batch_size = inputs.size(0)
            optimizer.zero_grad()

            (losses, running_probs_mean) = compute_test_time_loss(
                model, inputs, config, running_probs_mean, used_labels
            )

            loss = 0
            for loss_name, loss_val in losses.items():
                loss = loss + loss_val
                # track a detached value so that the meter does not hold on
                # to the autograd graph of the batch
                running_losses[loss_name].update(loss_val.detach(), batch_size)

            loss.backward()
            optimizer.step()

        outer_pbar.update(1)

        # evaluate and print top1 accuracy after 1st epoch of adaptation
        if epoch_no == 0:
            epoch1_top1_acc = evaluate(
                model, loader_eval, title, epoch_no, num_epochs
            )

    outer_pbar.close()

    return model, epoch1_top1_acc


def evaluate(model, loader, title, current_epoch=-1, total_epochs=-1):
    """Evaluate the model on entire target data.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        loader: loader yielding ``(inputs, targets)`` batches. If its dataset
            has an ``unused_labels`` attribute, the outputs of those classes
            are set to -inf so that they can never be predicted.
        title: name of the target data, used in the printed message.
        current_epoch: zero-based epoch index, only used in the printed message.
        total_epochs: total number of epochs, only used in the printed message.

    Returns:
        The top-1 accuracy rounded to two decimals.
    """
    top1 = AverageMeter()
    model.eval()

    try:
        unused_labels = loader.dataset.unused_labels
    except AttributeError:
        unused_labels = None

    # Explicit None/length test instead of plain truthiness, which is
    # ambiguous for multi-element arrays or tensors of labels.
    mask_unused = unused_labels is not None and len(unused_labels) > 0

    with torch.no_grad():
        pbar = tqdm(loader, desc="Evaluation")
        for inputs, targets in pbar:
            batch_size = targets.size(0)
            test_inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(test_inputs)
            if mask_unused:
                outputs[:, unused_labels] = -float("inf")
            # only the top-1 precision is tracked
            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            top1.update(prec1[0], batch_size)

    title = f"Epoch-{current_epoch + 1}/{total_epochs} {title}"
    print(f"{title} Classification Top 1 Acc: {top1.avg.item():.2f}")

    return round(top1.avg.item(), 2)


def _log_preprocessing_figure(model, dataloader, key):
    """Plot one input next to its preprocessed version and log the figure to MLflow."""
    inputs, _ = next(iter(dataloader))

    plt.figure(0, figsize=(12, 7))
    plt.subplot(1, 2, 1)
    plt.imshow(inputs.cpu().numpy()[0].transpose(1, 2, 0))
    plt.title("Distorted input")
    plt.xticks([])
    plt.yticks([])

    plt.subplot(1, 2, 2)
    # only a visualization: no autograd graph is needed
    with torch.no_grad():
        preprocessed = model.preprocess(inputs.cuda())
    lowest = preprocessed.min()
    highest = preprocessed.max()
    # rescale to [0, 1] for display
    preprocessed = (preprocessed - lowest) / (highest - lowest)
    plt.imshow(preprocessed.cpu().numpy()[0].transpose(1, 2, 0))
    plt.title("Preprocessed input [%.2f, %.2f]" % (lowest, highest))
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "%s_preprocessed.png" % key)
        plt.savefig(path)
        mlflow.log_artifact(path)
    plt.close()


def main():
    """Run test-time adaptation on every configured dataset and log the results.

    Reads a YAML config given on the command line. For each dataset it loads
    the pre-trained model, optionally prepends the model augmentation module,
    evaluates before adaptation, adapts, and evaluates again. The accuracies
    are logged to MLflow as ``results.csv``, together with the config file, a
    preprocessing figure per dataset (if augmentation is enabled) and the
    model adapted on the last dataset.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=argparse.FileType("r"))
    parser.add_argument("--experiment-name", type=str, default="ttt")
    # Unknown CLI arguments are accepted but not used.
    # NOTE(review): they look intended as config overrides - confirm.
    args, _ = parser.parse_known_args()
    config = yaml.load(args.config, Loader=yaml.SafeLoader)

    # Load the right model type, but don't save the list of all the other models
    models = config.pop("models")
    model_type = config["model_type"]
    config["model"] = models[model_type]["model"]
    config["freeze"] = models[model_type]["freeze"]

    mlflow.set_tracking_uri(
        "file:///home/test-time-robustness/mlruns"
    )
    mlflow.set_experiment(args.experiment_name)
    mlflow.set_tag("mlflow.runName", config["run_name"])
    mlflow.log_artifact(args.config.name)

    datasets = get_datasets(config)
    num_epochs = config["test_time_epochs"]

    model = None  # stays None when there is no dataset to adapt on
    results = []
    for key, dataset in datasets.items():
        use_augmentation = "model_augmentation" in config and config[
            "model_augmentation"
        ].get("enabled", True)

        # every dataset starts from a freshly loaded pre-trained model
        model = mlflow.pytorch.load_model(config["model"])

        if use_augmentation:
            # Prepend trainable model augmentation module
            print("Prepending augmentation module...")
            model = ModelAugmentation(model, config=config["model_augmentation"])

        model = model.cuda()

        if "data_split" in config:
            # ablation study to investigate the effect of dataset size (Figure 4 in the main manuscript)
            dataset_train, _ = split_dataset(dataset, **config["data_split"])
        else:
            dataset_train = dataset

        dataloader_train = to_dataloader(dataset_train, config, shuffle=True)
        dataloader_eval = to_dataloader(dataset, config)

        top1_before = evaluate(model, dataloader_eval, title=key)
        model, epoch1_top1_acc = test_time_adapt(
            model, dataloader_train, dataloader_eval, config, num_epochs, title=key
        )
        top1_after = evaluate(model, dataloader_eval, title=key)

        # save the results to a list; the dataset key is split on "-" to fill
        # the "corruption" and "severity" columns
        key_parts = key.split("-")
        results.extend(
            [
                [*key_parts, "pre-adaption", "0", top1_before],
                [*key_parts, "post-adaption", "1", epoch1_top1_acc],
                [*key_parts, "post-adaption", str(num_epochs), top1_after],
            ]
        )

        if use_augmentation:
            # Plot preprocessing
            _log_preprocessing_figure(model, dataloader_eval, key)

    columns = ["corruption", "severity", "adaption", "epochs", "metric"]
    df = pd.DataFrame(results, columns=columns)
    df["severity"] = df["severity"].map(int)

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "results.csv")
        df.to_csv(path, index=False)
        mlflow.log_artifact(path)

    # only the model adapted on the last dataset is logged
    if model is not None:
        mlflow.pytorch.log_model(model, "model")


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
