import argparse
import copy
import pickle

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from avalanche.benchmarks.classic import SplitImageNet
from avalanche.evaluation.metrics import accuracy_metrics, forgetting_metrics, Accuracy
from avalanche.logging import TextLogger
from avalanche.models import as_multitask
from avalanche.training import Naive, EWC, Replay
from avalanche.training.plugins import EvaluationPlugin
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision.models import resnet18, resnet50, resnext50_32x4d, vit_l_16, ResNet18_Weights, ResNet50_Weights, \
    ResNeXt50_32X4D_Weights, ViT_L_16_Weights, vit_b_16, ViT_B_16_Weights

from common import set_random_seed

# NOTE(review): assigning PYTHONUTF8 here only creates a module-level Python
# variable; it does NOT switch on Python's UTF-8 mode. UTF-8 mode is read from
# the PYTHONUTF8 *environment variable* (or ``-X utf8``) at interpreter
# start-up. To really avoid Windows locale issues, set the environment
# variable before launching, or pass ``encoding=`` explicitly to ``open``.
PYTHONUTF8=1 # intended as a Windows locale fix; has no runtime effect (see note above)
# Presumably the confidence parameter of the bounds; it enters them as log(1 / DELTA).
DELTA=0.05
# Default standard deviation of the Gaussian weight noise, used both for the
# noisy-loss estimate and for the KL surrogate.
DEFAULT_NOISE=0.008

def calculate_kl_surrogate(mt_model, prev_model, bound_lambda, noise_level=DEFAULT_NOISE):
    """Return the weight-space KL surrogate between two models, divided by ``bound_lambda``.

    For every parameter ``w_prev`` of ``prev_model`` and the same-named entry
    ``w`` of ``mt_model``'s state dict, this accumulates
    ``0.5 / noise_level**2 * ||w - w_prev||^2 / bound_lambda``. This matches the
    KL divergence between two isotropic Gaussians ``N(w, s^2 I)`` and
    ``N(w_prev, s^2 I)`` with ``s = noise_level``, scaled by ``1 / bound_lambda``.

    Args:
        mt_model: current model. Its ``state_dict()`` must contain every
            parameter name of ``prev_model``.
        prev_model: reference (previous) model whose parameters are iterated.
        bound_lambda: scale that the KL term is divided by.
        noise_level: standard deviation ``s`` of the weight noise.

    Returns:
        The accumulated sum: a scalar tensor, or ``0.0`` if ``prev_model`` has
        no parameters.
    """
    noise_const = 0.5 / (noise_level ** 2)
    # Build the state dict once. ``state_dict()`` constructs a fresh mapping of
    # all entries on each call, so calling it per parameter was quadratic in
    # the number of parameter tensors.
    current_state = mt_model.state_dict()
    kl_surrogate = 0.0
    for name, param in prev_model.named_parameters():
        new_par = current_state[name]
        kl_surrogate += ((noise_const * (new_par - param).pow(2)).sum() / bound_lambda)
    return kl_surrogate

def calculate_noisy_train_loss(mt_model, dataset, noise_level=DEFAULT_NOISE, n_trials=2):
    """Estimate the 0-1 loss of ``mt_model`` under Gaussian weight noise.

    For each trial, every parameter of ``mt_model`` is perturbed in place by
    i.i.d. ``N(0, noise_level**2)`` noise. The error rate ``1 - accuracy`` is
    then measured over ``dataset``, and the original weights are restored
    exactly, even if evaluation raises.

    Args:
        mt_model: multitask model, called as ``mt_model(x, task_ids)``.
        dataset: dataset whose items unpack as ``(x, y, task_id)``.
        noise_level: standard deviation of the weight noise.
        n_trials: number of independent noise draws (2 by default, as before).

    Returns:
        A list with one entry per trial, each equal to ``1.0 - Accuracy.result()``.
    """
    params = list(mt_model.parameters())
    # Evaluate on the model's own device. Fall back to the global choice only
    # for a parameter-less model.
    if params:
        device = params[0].device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    losses = []
    # In-place updates of leaf parameters are only legal with autograd
    # disabled, so do not rely on the caller having entered ``no_grad``.
    with torch.no_grad():
        for _ in range(n_trials):
            # Keep exact copies of the weights. Subtracting the noise back out is
            # not bit-exact in floating point and lets the model drift a little
            # on every call.
            originals = [param.clone() for param in params]
            try:
                for param in params:
                    param.add_(torch.normal(0, noise_level, param.shape, device=param.device))
                # calculate accuracy/01 loss
                acc_metric = Accuracy()
                train_loader = torch.utils.data.DataLoader(
                    dataset, batch_size=256, shuffle=True
                )
                for x, y, t_id in train_loader:
                    x = x.to(device)
                    y = y.to(device)
                    acc_metric.update(mt_model(x, t_id), y)
                losses.append(1.0 - acc_metric.result())
            finally:
                # Restore the noise-free weights whether or not evaluation succeeded.
                for param, original in zip(params, originals):
                    param.copy_(original)
    return losses


def calculate_online_bound(experience, mt_model, prev_model, bound_lambda, k):
    """Compute the per-task terms of the online bound for one experience.

    The returned 4-tuple is consumed by ``get_total_online_bound``:
      * mean noisy 0-1 loss of ``mt_model`` on ``experience.dataset``;
      * KL surrogate between ``mt_model`` and ``prev_model`` (already divided
        by ``bound_lambda``);
      * ``const1 = log(1/DELTA) / lambda / t + lambda * k**2 / (8 * m)``;
      * ``const2 = k * (1 + log(1/DELTA)) / lambda + k * sqrt(sqrt(t) / (2 * m))``.

    Here ``t`` is the 1-based task index (``current_experience + 1``) and ``m``
    is the number of samples in the experience's dataset.
    """
    train_set = experience.dataset
    mean_noisy_loss = np.mean(calculate_noisy_train_loss(mt_model, train_set))
    kl_term = calculate_kl_surrogate(mt_model, prev_model, bound_lambda)

    n_samples = len(train_set)
    task_index = experience.current_experience + 1
    log_conf = np.log(1 / DELTA)
    const1 = log_conf / bound_lambda / task_index + bound_lambda * (k ** 2) / (8 * n_samples)
    const2 = k * (1 + log_conf) / bound_lambda + k * np.sqrt(np.sqrt(task_index) / (2 * n_samples))
    return mean_noisy_loss, kl_term, const1, const2

def get_total_online_bound(online_bounds):
    """Combine per-task online-bound terms into the bound after the last task.

    Each entry of ``online_bounds`` is a 4-tuple
    ``(loss, kl_surrogate, const1, const2)``, as produced by
    ``calculate_online_bound``. The mean loss over all entries is combined with
    the summed KL terms in two ways, using the constants of the *last* entry:
    ``kl_sum / t + const1`` and ``kl_sum + const2``. The smaller candidate is
    used; on a tie, the ``const2`` form is returned. The intermediate terms are
    printed for debugging.
    """
    n_steps = len(online_bounds)
    mean_loss = sum(step_loss / n_steps for step_loss, _, _, _ in online_bounds)
    kl_sum = sum(step_kl for _, step_kl, _, _ in online_bounds)
    _, _, last_const1, last_const2 = online_bounds[-1]
    print(mean_loss, kl_sum, last_const1, last_const2)

    averaged_kl_wins = kl_sum / n_steps + last_const1 < kl_sum + last_const2
    if not averaged_kl_wins:
        return mean_loss + kl_sum + last_const2
    return mean_loss + kl_sum / n_steps + last_const1

def calculate_forget_bound(train_stream, validation_stream, curr_task_id, mt_model, prev_model, bound_lambda, k):
    """Compute a forgetting (backward-transfer) bound for task ``curr_task_id``.

    The returned value is the sum of:
      * the mean noisy 0-1 loss of ``mt_model`` on the current task's training data;
      * the KL surrogate between ``mt_model`` and ``prev_model`` (already divided
        by ``bound_lambda``);
      * ``log(1/DELTA) / lambda + lambda * k**2 / (8 * m)``, where ``m`` is the
        size of the current task's training set;
      * the average over earlier tasks ``i < curr_task_id`` of a log-mean-exp
        discrepancy between ``prev_model``'s noisy loss on task ``i`` and on the
        current task (both measured on ``validation_stream``).

    Side effect: ``prev_model`` is mutated in place. The current task's
    classifier head and its ``active_units_T{curr_task_id}`` buffer are
    deep-copied into it from ``mt_model``.

    Args:
        train_stream: indexable stream of training experiences.
        validation_stream: stream of evaluation experiences, indexable and iterable.
        curr_task_id: index of the task just trained.
        mt_model: current multitask model.
        prev_model: model from before this task (mutated, see above).
        bound_lambda: bound scale lambda.
        k: constant entering the ``lambda * k**2 / (8 * m)`` term.

    Returns:
        The bound. It is typically a tensor, because the KL surrogate is one.

    NOTE(review): the discrepancy is divided by ``t - 1 == curr_task_id``, so
    ``curr_task_id == 0`` raises ZeroDivisionError. Only call for tasks > 0.
    """
    loss = np.mean(calculate_noisy_train_loss(mt_model, train_stream[curr_task_id].dataset))
    kl = calculate_kl_surrogate(mt_model, prev_model, bound_lambda)
    total_bound =  loss+kl
    # copy missing classifier head and allow new task id
    # (presumably prev_model predates this task and has no head for it yet;
    # without the copy it could not be evaluated on the current task)
    prev_model.__getattr__(prev_model.classifier_name).classifiers[f"{curr_task_id}"] = copy.deepcopy(mt_model.__getattr__(mt_model.classifier_name).classifiers[f"{curr_task_id}"])
    prev_model.__getattr__(prev_model.classifier_name)._buffers[f"active_units_T{curr_task_id}"] = copy.deepcopy(mt_model.__getattr__(mt_model.classifier_name)._buffers[f"active_units_T{curr_task_id}"])
    # noisy 0-1 loss of prev_model on the current task's validation data
    prev_loss = calculate_noisy_train_loss(prev_model, validation_stream[curr_task_id].dataset)
    discrepency = 0.0
    EPS=1e-7  # small guard added inside the log below
    for experience in validation_stream:
        # only earlier tasks contribute; assumes validation_stream is ordered
        # by task id (TODO confirm)
        if experience.current_experience >= curr_task_id:
            break
        prev_loss_i = calculate_noisy_train_loss(prev_model, experience.dataset)
        # trial t on task i is paired with trial t on the current task
        # (each trial uses its own independent noise draw)
        diffs = torch.tensor(
            [ prev_loss_i[t] - prev_loss[t] for t in range(len(prev_loss))])
        # the log-discrepency is calculated using log-sum-exp trick for numerical stability
        # i.e. log(mean(exp(lambda * diffs))) / lambda, with the max factored out
        max_diff = torch.max(diffs) * bound_lambda
        discrepency_i = torch.mean(torch.exp(diffs * bound_lambda - max_diff))
        assert not torch.isinf(discrepency_i)
        log_disc =  (max_diff + torch.log(discrepency_i + EPS))/bound_lambda
        discrepency += log_disc

    t = curr_task_id + 1
    const = np.log(1 / DELTA) / bound_lambda + bound_lambda * (k**2) / (8 * len(train_stream[curr_task_id].dataset))
    print(curr_task_id, loss, kl, const, prev_loss, discrepency/(t-1))
    # discrepancy is averaged over the t - 1 earlier tasks
    return total_bound+const + discrepency/(t-1)

def get_parser():
    """Build the command-line parser for the continual-learning bound experiments.

    ``--model`` and ``--strategy`` are restricted to the values the script
    actually handles. A typo now fails at parse time instead of silently
    falling through to the last branch of the model/strategy selection: ViT-B/16
    for models, Replay for strategies.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default="resnet18", type=str,
                        choices=["resnet18", "resnet50", "resnext50", "vit_b"],
                        help="Model: resnet18, resnet50, resnext50, vit_b")
    parser.add_argument('--strategy', default="naive", type=str,
                        choices=["naive", "ewc", "replay"],
                        help="Training strategy: naive, ewc, replay")
    parser.add_argument('--batch_size', default=128, type=int,
                        help="Gradient batch size")
    parser.add_argument('--tasks', default=20, type=int,
                        help="Number of tasks, must be denominator of 1000, min 2, max 500")
    parser.add_argument('--lr', default=1e-3, type=float,
                        help="Learning rate")
    parser.add_argument('--bound_lambda', default=1e7, type=float,
                        help="Bound parameter")
    parser.add_argument('--seed', type=int, default=42, help="Random seed")
    # [42, 11, 451, 1337, 805287]
    return parser

if __name__ == '__main__':
    args = get_parser().parse_args()
    set_random_seed(args.seed)
    batch_size = args.batch_size
    learning_rate = args.lr
    n_tasks = args.tasks
    bound_lambda = args.bound_lambda
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_classes = 1000//n_tasks
    # passed as ``k`` to the bound computations
    K = 1.0 - 1.0/n_classes
    dataset = SplitImageNet(n_experiences=n_tasks, seed=args.seed, dataset_root="./data/inet",
                            class_ids_from_zero_in_each_exp=True, return_task_id=True) # task-aware heads for easier comparison

    # ImageNet-pretrained backbone; any other --model value falls through to ViT-B/16
    model = resnet18(weights=ResNet18_Weights.DEFAULT) if args.model == "resnet18" \
        else resnet50(weights=ResNet50_Weights.DEFAULT) if args.model == "resnet50" \
        else resnext50_32x4d(weights=ResNeXt50_32X4D_Weights.DEFAULT) if args.model == "resnext50" \
        else vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    # replace the 1000-way ImageNet classifier with an n_classes-way one
    lin_layer_name = "heads" if "vit" in args.model else "fc"
    if lin_layer_name == "fc":
        model.fc = nn.Linear(in_features=model.fc.in_features, out_features=n_classes)
    else:
        model.heads = nn.Linear(in_features=model.heads[0].in_features, out_features=n_classes)
    mt_model = as_multitask(model, classifier_name=lin_layer_name)
    mt_model.to(device)

    # log to text file; the context manager guarantees the log is flushed and
    # closed even if training crashes part-way
    with open(f'run_metrics_{n_tasks}_{bound_lambda}_{args.strategy}_{args.model}.txt', 'a',
              encoding='utf-8') as textfile:
        text_logger = TextLogger(textfile)
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(trained_experience=True, experience=True),
            forgetting_metrics(stream=True),
            loggers=[text_logger],
            strict_checks=False)

        # The multitask classifier is reached through mt_model (see
        # classifier_name), so optimise mt_model's parameters; model.parameters()
        # may not include the task heads after as_multitask.
        strategy_kwargs = {"model": mt_model,
                           "optimizer": SGD(mt_model.parameters(), lr=learning_rate, weight_decay=learning_rate / 10),
                           "criterion": CrossEntropyLoss(), "train_mb_size": batch_size, "train_epochs": 1,
                           "eval_mb_size": batch_size, "evaluator": eval_plugin, "device": device}
        if args.strategy == "naive":
            cl_strategy = Naive(**strategy_kwargs)
        elif args.strategy == "ewc":
            cl_strategy = EWC(ewc_lambda=40, **strategy_kwargs)
        else:
            cl_strategy = Replay(mem_size=1000, **strategy_kwargs)

        # TRAINING LOOP
        print('Starting experiment...')
        textfile.write(f"{n_tasks}, {bound_lambda}, {args.strategy}, {args.model}\n")
        results = []
        online_bounds = []  # tuples of l, kl/lambda, const for t (version 1), const for t(version 2)
        online_losses = []  # running sum of per-task test 0-1 losses after each task
        bwt_bounds = []  # BWT bound (only filled by the disabled block below)
        prev_model = copy.deepcopy(mt_model)
        for experience in dataset.train_stream:
            print("Start of experience: ", experience.current_experience)
            print("Current Classes: ", experience.classes_in_this_experience)

            cl_strategy.train(experience)
            print('Training completed')

            print('Computing accuracy on the whole test set')
            # test also returns a dictionary which contains all the metric values
            results.append(cl_strategy.eval(dataset.test_stream))

            print('Computing bounds')
            # eval bounds; cl_strategy.eval above leaves mt_model in eval mode,
            # and prev_model only contributes its weights (KL term) here
            with torch.no_grad():
                online_bounds.append(calculate_online_bound(experience, mt_model, prev_model, bound_lambda, k=K))
                total_online_bound = get_total_online_bound(online_bounds)
                textfile.write(f"task_id:{experience.current_experience}, online bound:{total_online_bound}\n")
                online_loss = 0.0
                for i, r in enumerate(results):
                    # test loss on task i, measured right after task i was trained
                    online_loss += 1.0- r[f"Top1_Acc_Exp/eval_phase/test_stream/Task{i:03d}/Exp{i:03d}"]
                textfile.write(f"task_id:{experience.current_experience}, online loss:{online_loss/(experience.current_experience+1)}\n")
                online_losses.append(online_loss)
                # if experience.current_experience > 0:
                #     bwt_bounds.append(calculate_forget_bound(dataset.train_stream, dataset.test_stream, experience.current_experience, mt_model, prev_model, bound_lambda, k=K))
                #     textfile.write(f"task_id:{experience.current_experience}, BWT bound:{bwt_bounds[-1]}\n")
                #     trained_loss = results[-1][f"Accuracy_On_Trained_Experiences/eval_phase/test_stream/Task{(n_tasks-1):03d}"]
                #     textfile.write(f"task_id:{experience.current_experience}, forget bound:{bwt_bounds[-1]-1.0+trained_loss}\n")

            prev_model = copy.deepcopy(mt_model)
            textfile.flush()

    # Save and plot
    with open(f"results_{n_tasks}_{args.seed}_{args.strategy}_{args.model}.pkl", "wb") as fptr:
        pickle.dump((results, online_losses, online_bounds), fptr)

    plot_range = np.arange(1, n_tasks+1)
    loss_bounds_pretty = [get_total_online_bound(online_bounds[:j]).cpu().item() for j in plot_range]
    plt.figure()
    plt.scatter(plot_range, loss_bounds_pretty,
                label=f"{args.strategy}_{args.model} bound",
                marker="+", color='b')
    # online_losses holds running sums; divide by the task count to get the mean
    plt.scatter(plot_range,
                np.asarray(online_losses) / plot_range,
                label=f"{args.strategy}_{args.model} cumulative loss", color='b')
    plt.legend()
    plt.xlabel("Task number")
    plt.ylabel("test loss")
    plt.savefig(f"cumulative_loss_{n_tasks}_{bound_lambda}_{args.strategy}_{args.model}.jpg")