""""Evaluate catastrophic forgetting for neural networks of increasig width."""

import sys

import avalanche
import fire
import numpy as np
import torch
import wandb
import torchvision
from avalanche import benchmarks, evaluation
from avalanche.evaluation.metrics import (
    accuracy_metrics,
    forgetting_metrics,
    loss_metrics,
)

sys.path.insert(0, "../../experiments")
from catastrophic_forgetting import models
from evaluation import ParameterSpaceDistance
from helper import wandb_config


def experiment_catastrophic_forgetting(
    dataset: str,
    nn_width: int,
    nn_depth: int,
    train_mb_size=32,
    train_epochs=5,
    test_mb_size=32,
    lr: float = 1e-3,
    momentum: float = 0.9,
    weight_decay: float = 1e-4,
    seed: int = 42,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    experiment_name="catastrophic_forgetting",
):
    """Train a model with the Naive strategy on a benchmark stream and log metrics.

    After every training experience the model is evaluated on the complete
    test stream. Accuracy, loss, forgetting and parameter-space distance are
    reported to the console and to Weights & Biases. Nothing is returned.

    Args:
        dataset: Benchmark name, forwarded to ``init_dataset`` and ``init_model``.
        nn_width: Width setting forwarded to ``init_model``.
        nn_depth: Depth setting forwarded to ``init_model``.
        train_mb_size: Mini-batch size used during training.
        train_epochs: Number of epochs per training experience.
        test_mb_size: Mini-batch size used during evaluation.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        seed: Seed for the torch RNG; also forwarded to ``init_dataset``.
        device: Device handed to the training strategy. The default is
            resolved once, when this function is defined.
        experiment_name: Used as the W&B run group.
    """
    # Release cached CUDA memory before starting.
    torch.cuda.empty_cache()

    # Seed torch before the benchmark and the model are created.
    torch.random.manual_seed(seed)

    stream_benchmark = init_dataset(dataset=dataset, seed=seed)
    net = init_model(dataset=dataset, nn_width=nn_width, nn_depth=nn_depth)

    sgd = torch.optim.SGD(
        net.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    # Hyperparameters recorded alongside the W&B run.
    run_config = {
        "dataset": dataset,
        "seed": seed,
        "model": net.__class__.__name__,
        "nn_width": nn_width,
        "nn_depth": nn_depth,
        "optimizer": sgd.__class__.__name__,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "train_epochs": train_epochs,
        "train_mb_size": train_mb_size,
        "test_mb_size": test_mb_size,
    }
    wandb_init_params = {
        "entity": wandb_config.WANDB_ENTITY,
        "group": experiment_name,
        "reinit": True,
    }
    wandb_logger = avalanche.logging.WandBLogger(
        project_name=wandb_config.WANDB_PROJECT,
        run_name=None,
        params=wandb_init_params,
        config=run_config,
    )

    # Metrics are reported per experience and per stream only.
    # API docs: https://avalanche-api.continualai.org/en/v0.4.0/evaluation.html
    tracked_metrics = [
        accuracy_metrics(minibatch=False, epoch=False, experience=True, stream=True),
        loss_metrics(minibatch=False, epoch=False, experience=True, stream=True),
        forgetting_metrics(experience=True, stream=True),
        ParameterSpaceDistance(experience_for_initial_params=0),
        # ParameterSpaceDistance(experience_for_initial_params=1),
    ]
    active_loggers = [avalanche.logging.InteractiveLogger(), wandb_logger]
    evaluator = avalanche.training.plugins.EvaluationPlugin(
        *tracked_metrics,
        loggers=active_loggers,
        strict_checks=False,
        collect_all=True,
    )

    # Plain sequential fine-tuning strategy.
    strategy = avalanche.training.Naive(
        net,
        sgd,
        loss_fn,
        train_mb_size=train_mb_size,
        train_epochs=train_epochs,
        eval_mb_size=test_mb_size,
        evaluator=evaluator,
        device=device,
    )

    # Train on each experience in turn, then evaluate on the whole test stream.
    for experience in stream_benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(stream_benchmark.test_stream)


def _tiny_imagenet_transform():
    """Build the ToTensor -> Normalize -> Resize(32) pipeline for TinyImageNet."""
    return torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
            torchvision.transforms.Resize(32),
        ]
    )


def init_dataset(
    dataset: str,
    seed: int,
):
    """Create the continual-learning benchmark identified by ``dataset``.

    Supported names and their number of experiences:

    * ``"rotatedmnist"``: 9 (rotations evenly spaced from 0 to 180 degrees)
    * ``"splitmnist"``: 5
    * ``"splitfmnist"``: 10
    * ``"splitcifar10"``: 5
    * ``"splitcifar100"``: 20
    * ``"splittinyimagenet"``: 20 (images resized to 32 for train and eval)

    Every benchmark is created with ``return_task_id=True`` and the given seed.

    Args:
        dataset: Name of the benchmark to build.
        seed: Seed forwarded to the benchmark constructor.

    Returns:
        The benchmark object built by the matching ``benchmarks.classic``
        constructor.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "rotatedmnist":
        n_experiences = 9
        return benchmarks.classic.RotatedMNIST(
            n_experiences=n_experiences,
            # One rotation angle per experience, endpoints included.
            rotations_list=np.linspace(0, 180, n_experiences),
            return_task_id=True,
            seed=seed,
        )

    if dataset == "splitmnist":
        return benchmarks.classic.SplitMNIST(
            n_experiences=5,
            return_task_id=True,
            seed=seed,
        )

    if dataset == "splitfmnist":
        return benchmarks.classic.SplitFMNIST(
            n_experiences=10,
            return_task_id=True,
            seed=seed,
        )

    if dataset == "splitcifar10":
        return benchmarks.classic.SplitCIFAR10(
            n_experiences=5,
            return_task_id=True,
            seed=seed,
        )

    if dataset == "splitcifar100":
        return benchmarks.classic.SplitCIFAR100(
            n_experiences=20,
            return_task_id=True,
            seed=seed,
        )

    if dataset == "splittinyimagenet":
        # Train and eval use the same preprocessing; build one pipeline each.
        return benchmarks.classic.SplitTinyImageNet(
            n_experiences=20,
            return_task_id=True,
            train_transform=_tiny_imagenet_transform(),
            eval_transform=_tiny_imagenet_transform(),
            seed=seed,
        )

    supported = (
        "rotatedmnist",
        "splitmnist",
        "splitfmnist",
        "splitcifar10",
        "splitcifar100",
        "splittinyimagenet",
    )
    # Name the offending value so a mistyped CLI argument is easy to spot.
    raise ValueError(
        f"Unknown dataset {dataset!r}. Supported datasets: {', '.join(supported)}."
    )


def init_model(
    dataset: str,
    nn_width: int,
    nn_depth: int,
):
    """Create the model used for the benchmark identified by ``dataset``.

    * ``"rotatedmnist"``: a plain ``SimpleMLP`` (no multi-task wrapper).
    * ``"splitmnist"``: a ``SimpleMLP`` wrapped with ``as_multitask``.
    * ``"splitfmnist"``, ``"splitcifar10"``, ``"splitcifar100"``,
      ``"splittinyimagenet"``: a ``WideResNet`` with 10, 10, 100 and 200
      classes respectively, wrapped with ``as_multitask``.

    Args:
        dataset: Name of the benchmark the model is built for.
        nn_width: ``hidden_size`` of the MLP, or ``widen_factor`` of the
            WideResNet.
        nn_depth: The MLP gets ``hidden_layers = nn_depth - 1``; the
            WideResNet gets ``depth = nn_depth``.

    Returns:
        The constructed model, wrapped as a multi-task model on its
        ``"classifier"`` attribute for all datasets except ``"rotatedmnist"``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ("rotatedmnist", "splitmnist"):
        mlp = avalanche.models.SimpleMLP(
            num_classes=10,
            input_size=28 * 28,
            hidden_size=nn_width,
            hidden_layers=nn_depth - 1,
            drop_rate=0.0,
        )
        # Only rotatedmnist uses the MLP without a multi-task wrapper.
        if dataset == "rotatedmnist":
            return mlp
        return avalanche.models.as_multitask(mlp, "classifier")

    # The remaining datasets share a WideResNet and differ only in class count.
    if dataset in ("splitfmnist", "splitcifar10"):
        num_classes = 10
    elif dataset == "splitcifar100":
        num_classes = 100
    elif dataset == "splittinyimagenet":
        num_classes = 200
    else:
        supported = (
            "rotatedmnist",
            "splitmnist",
            "splitfmnist",
            "splitcifar10",
            "splitcifar100",
            "splittinyimagenet",
        )
        # Name the offending value so a mistyped CLI argument is easy to spot.
        raise ValueError(
            f"Unknown dataset {dataset!r}. Supported datasets: {', '.join(supported)}."
        )

    wide_resnet = models.WideResNet(
        depth=nn_depth,
        num_classes=num_classes,
        widen_factor=nn_width,
        drop_rate=0.0,
    )
    return avalanche.models.as_multitask(wide_resnet, "classifier")


# Script entry point: hand the experiment function to python-fire so its
# parameters can be supplied from the command line.
if __name__ == "__main__":
    fire.Fire(experiment_catastrophic_forgetting)
