import os.path
import os.path
from collections import OrderedDict

import argparser as MiBargparser
import tasks as MiBtasks
# MiB test
from segmentation_module import make_model

# Two option namespaces parsed from the same command line: one configured for
# MiB and one for PLOP. `.options` is used as the checkpoint directory (see
# calculate_parameter_diff) and `.data_root` as the VOC2012 root.
parser = MiBargparser.get_argparser()

opts = parser.parse_args()
mib_opts = MiBargparser.modify_command_options(opts)

mib_opts.dataset = "voc"
mib_opts.task = "1-1"
mib_opts.method = "MiB"
mib_opts.options = r"/home/gyf/QLY/CVPR2021_PLOP/checkpoints/not frozen backbone/"
mib_opts.data_root = r"./data/VOC2012"

# Parse again so plop_opts is a distinct namespace object from mib_opts.
parser = MiBargparser.get_argparser()

opts = parser.parse_args()
plop_opts = MiBargparser.modify_command_options(opts)

plop_opts.dataset = "voc"
plop_opts.task = "1-1"
plop_opts.method = "PLOP"
# Checkpoint directory. A following assignment that overwrote this with the
# data root ("./data/VOC2012") was removed: it made the PLOP checkpoint lookup
# search the dataset folder instead of the checkpoint folder.
plop_opts.options = r"/home/gyf/QLY/CVPR2021_PLOP/checkpoints/not frozen backbone/"
plop_opts.data_root = r"./data/VOC2012"
device = "cuda:0"
def filter_checkpoint(checkpoint):
    """Extract a plain model state dict from a training checkpoint.

    Reads ``checkpoint["model_state"]`` and strips the ``"module."`` prefix
    that ``DataParallel``/``DistributedDataParallel`` wrappers add to
    parameter names, so the result can be loaded into an unwrapped model.

    Args:
        checkpoint: mapping (as returned by ``torch.load``) that is expected
            to hold a ``"model_state"`` entry of name -> tensor.

    Returns:
        An ``OrderedDict`` of name -> tensor in checkpoint order. Names
        without the ``"module."`` prefix are kept unchanged (previously they
        were silently dropped, yielding an empty dict for checkpoints saved
        from an unwrapped model). Empty when ``"model_state"`` is absent.
    """
    prefix = "module."
    new_checkpoints = OrderedDict()
    if "model_state" not in checkpoint:
        return new_checkpoints
    for name, parameters in checkpoint["model_state"].items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        new_checkpoints[name] = parameters
    return new_checkpoints


def calculate_parameter_diff(task_t, task_i, _opts):
    """Return the per-parameter difference between two step checkpoints.

    Builds the step-``task_t`` model from ``_opts``, loads the step-``task_t``
    checkpoint into it, and subtracts the step-``task_i`` checkpoint tensor
    from every entry of ``model_t.named_parameters()``. Only parameters are
    compared; buffers (e.g. normalisation running statistics) are not part of
    ``named_parameters()`` and are therefore skipped.

    When the two tensors differ in shape, the step-``task_i`` tensor is
    zero-padded along dim 0 (the output axis of a conv weight/bias) up to the
    step-``task_t`` shape before subtracting. A parameter absent from the
    step-``task_i`` checkpoint must belong to the ``cls`` module; it is
    returned as-is (its name is printed).

    Args:
        task_t: step whose model/checkpoint is the minuend.
        task_i: step whose checkpoint is subtracted.
        _opts: namespace providing ``options`` (checkpoint directory),
            ``task``, ``dataset`` and ``method``.

    Returns:
        Tuple of tensors in ``model_t.named_parameters()`` order, detached
        from autograd.

    Raises:
        ValueError: if a parameter missing from the step-``task_i``
            checkpoint is not under the ``cls`` module.
    """

    def _checkpoint_path(step):
        # Checkpoints are named "<task>-<dataset>_<method>_<step>.pth".
        return os.path.join(
            _opts.options,
            f"{_opts.task}-{_opts.dataset}_{_opts.method}_{step}.pth",
        )

    # Use `_opts` (the original always used the global `mib_opts` here, so
    # PLOP differences were computed with a model built from MiB options).
    model_t = make_model(
        _opts,
        classes=MiBtasks.get_per_task_classes(_opts.dataset, _opts.task, task_t),
    ).to(device)
    params_t = filter_checkpoint(torch.load(_checkpoint_path(task_t), map_location=device))
    params_i = filter_checkpoint(torch.load(_checkpoint_path(task_i), map_location=device))
    model_t.load_state_dict(params_t)

    diff_parameters = []
    # The differences are used as constant direction vectors; tracking
    # autograd history for them only wastes memory.
    with torch.no_grad():
        for name, param_t in model_t.named_parameters():
            if name not in params_i:
                if name.split(".")[0] != "cls":
                    raise ValueError(
                        f"parameter {name!r} is missing from the step-{task_i} checkpoint"
                    )
                print(name)
                diff = param_t.detach()
            else:
                param_i = params_i[name]
                if param_t.shape == param_i.shape:
                    diff = param_t - param_i
                else:
                    # conv weight is [out, in, kh, kw], bias is [out]: later
                    # steps add output channels, so pad dim 0 with zeros.
                    param_i_padded = torch.zeros_like(param_t)
                    param_i_padded[: param_i.shape[0]] = param_i
                    diff = param_t - param_i_padded
            diff_parameters.append(diff)
    return tuple(diff_parameters)


def calculate_delta_t(task_t, opts):
    """Return the parameter change from step ``task_t - 1`` to step ``task_t``.

    Delegates to ``calculate_parameter_diff`` with the immediately preceding
    step as the subtrahend.
    """
    previous_step = task_t - 1
    delta = calculate_parameter_diff(task_t, previous_step, opts)
    return delta


from dataset import VOCSegmentationIncremental as dataset
from dataset import transform

from torch.nn import CrossEntropyLoss
from torch.func import functional_call
from torch.autograd.functional import hvp


def prepare_calculate(task_num, _opts):
    """Build the model, validation dataset and loss function for one step.

    Side effect: sets ``_opts.step = task_num`` on the (shared) namespace;
    the dataset construction below reads ``_opts.step`` back.

    NOTE(review): only ``make_model`` is called here; no step checkpoint is
    loaded in this function. Confirm that ``make_model`` restores trained
    weights, otherwise callers evaluate a freshly built model.
    NOTE(review): labels are taken for step ``task_num - 1`` while the index
    file is ``val-{task_num}.npy``; confirm this offset is intended.

    Args:
        task_num: incremental step to prepare.
        _opts: option namespace (dataset, task, crop_size, data_root, ...).

    Returns:
        ``(dataset, loss_fn, model)``: a ``VOCSegmentationIncremental``
        validation split, a ``CrossEntropyLoss`` ignoring label 255, and the
        model moved to ``device``.
    """
    _opts.step = task_num
    model_t = make_model(_opts, classes=MiBtasks.get_per_task_classes(_opts.dataset, _opts.task, task_num))
    val_transform = transform.Compose(
        [
            transform.Resize((_opts.crop_size, _opts.crop_size)),
            transform.ToTensor(),
            transform.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    model_t = model_t.to(device)

    labels, labels_old, path_base = MiBtasks.get_task_labels(_opts.dataset, _opts.task, _opts.step - 1)
    # Index files for the overlapped setting (overlap=True below) live under "<path_base>-ov".
    path_base += "-ov"
    dataset_t_1 = dataset(
        root=_opts.data_root,
        train=False,
        transform=val_transform,
        labels=list(labels),
        labels_old=list(labels_old),
        idxs_path=path_base + f"/val-{_opts.step}.npy",
        masking=not _opts.no_mask,
        overlap=True,
        disable_background=_opts.disable_background,
        data_masking=_opts.data_masking,
        step=_opts.step
    )
    loss_fn = CrossEntropyLoss(ignore_index=255)
    return dataset_t_1, loss_fn, model_t


import torch
from torch.utils.data import DataLoader
import tqdm
from torch.cuda.amp import autocast


def calculate_forgetting_rate_taylor_second_order(task_num, _opts):
    """Estimate the second-order Taylor term of forgetting caused by step ``task_num``.

    With ``delta = calculate_delta_t(task_num, _opts)``, for every earlier
    step ``task`` in ``1 .. task_num - 1`` this computes, batch by batch on
    that step's validation data, ``delta . (H @ delta)`` where ``H`` is the
    loss Hessian w.r.t. the step-``task`` model parameters (via ``hvp``). The
    batch values are averaged and halved (the 1/2 of the Taylor expansion),
    printed per step, and summed.

    Args:
        task_num: step whose parameter update is analysed.
        _opts: option namespace (see ``prepare_calculate``).

    Returns:
        The summed second-order term (also printed). Previously nothing was
        returned; returning the value is backward compatible.

    Raises:
        ValueError: if ``delta`` has fewer tensors than a step's model has
            parameters, or a step's validation loader yields no batches.
    """
    second_taylor = 0
    delta_t = calculate_delta_t(task_num, _opts)
    batch = 30
    for task in range(1, task_num):
        dataset_t, loss_fn, model_t = prepare_calculate(task, _opts)
        model_t.eval()
        params = tuple(model_t.parameters())
        params_name = tuple(name for name, _ in model_t.named_parameters())
        # Keep the leading delta tensors that line up with `params`. The old
        # `delta_t[:-len(delta_t) + len(params)]` evaluated to an empty tuple
        # whenever both lengths were equal.
        slice_delta_t = delta_t[: len(params)]
        if len(slice_delta_t) != len(params):
            raise ValueError(
                f"delta for step {task_num} has {len(delta_t)} tensors but the "
                f"step-{task} model has {len(params)} parameters"
            )
        dataloader = DataLoader(dataset_t, batch_size=batch, shuffle=False, num_workers=0)
        hessian_result_list = []

        for image, label in tqdm.tqdm(dataloader, total=len(dataloader)):
            image = image.to(device)
            label = label.long().to(device)

            def stateless_loss(*current_params, _image=image, _label=label):
                # Loss with `current_params` substituted for the model
                # parameters, so `hvp` can differentiate w.r.t. them.
                param_dict = dict(zip(params_name, current_params))
                outputs = functional_call(model_t, param_dict, (_image,))[0]
                return loss_fn(outputs, _label)

            with autocast(dtype=torch.float16):
                hvp_result = hvp(stateless_loss, params, slice_delta_t)[1]
                hessian_result = sum(torch.sum(d * h) for d, h in zip(slice_delta_t, hvp_result))
            hessian_result_list.append(hessian_result.item())
            del image, label, hessian_result, hvp_result, stateless_loss

        if not hessian_result_list:
            raise ValueError(f"validation loader for step {task} yielded no batches")
        hessian_result = sum(hessian_result_list) / (2 * len(hessian_result_list))
        second_taylor += hessian_result
        print(hessian_result)
        del dataset_t, slice_delta_t, model_t, params, dataloader, hessian_result_list
        torch.cuda.empty_cache()

    print(second_taylor)
    return second_taylor


def calculate_forgetting_rate_taylor_third_order(task_num, _opts):
    """Collect loss Hessian-vector products along the step-``task_num`` update.

    Unfinished analysis. With ``delta = calculate_delta_t(task_num, _opts)``,
    for every earlier step ``task`` in ``1 .. task_num - 1`` this computes,
    batch by batch on that step's validation data, ``H @ delta`` where ``H``
    is the loss Hessian w.r.t. the step-``task`` model parameters.

    TODO: the third-order Taylor term itself is not computed yet.

    Args:
        task_num: step whose parameter update is analysed.
        _opts: option namespace (see ``prepare_calculate``).

    Returns:
        List with one tuple of tensors per batch (all steps, in order), each
        aligned with the model parameters and moved to CPU. Previously the
        tuples were kept on the GPU and nothing was returned.

    Raises:
        ValueError: if ``delta`` has fewer tensors than a step's model has
            parameters.
    """
    delta_t = calculate_delta_t(task_num, _opts)
    hessian_result_list = []
    batch = 30
    for task in range(1, task_num):
        dataset_t, loss_fn, model_t = prepare_calculate(task, _opts)
        # Evaluate in eval mode, as the second-order estimate does, so that
        # normalisation/dropout layers do not run in training mode inside hvp.
        model_t.eval()
        params = tuple(model_t.parameters())
        params_name = tuple(name for name, _ in model_t.named_parameters())
        # Keep the leading delta tensors that line up with `params` (the old
        # `delta_t[:-len(delta_t) + len(params)]` was empty for equal lengths).
        slice_delta_t = delta_t[: len(params)]
        if len(slice_delta_t) != len(params):
            raise ValueError(
                f"delta for step {task_num} has {len(delta_t)} tensors but the "
                f"step-{task} model has {len(params)} parameters"
            )
        dataloader = DataLoader(dataset_t, batch_size=batch, shuffle=False, num_workers=0)

        for image, label in tqdm.tqdm(dataloader, total=len(dataloader)):
            image = image.to(device)
            label = label.long().to(device)

            def stateless_loss(*current_params, _image=image, _label=label):
                # Loss with `current_params` substituted for the model parameters.
                param_dict = dict(zip(params_name, current_params))
                outputs = functional_call(model_t, param_dict, (_image,))[0]
                return loss_fn(outputs, _label)

            with autocast(dtype=torch.float16):
                hvp_result = hvp(stateless_loss, params, slice_delta_t)[1]
            # Each result is as large as the whole parameter set; keeping one
            # per batch on the GPU quickly exhausts device memory.
            hessian_result_list.append(tuple(h.detach().cpu() for h in hvp_result))
            del image, label, hvp_result, stateless_loss

        del dataset_t, model_t, params, dataloader
        torch.cuda.empty_cache()
    return hessian_result_list




def total_calculate_forgetting_rate(_opts):
    """Evaluate each step's model on the validation data of every step up to it.

    Steps are numbered ``1 .. len(setting) - 1`` where ``setting`` is the task
    definition for ``_opts.task`` (VOC when ``_opts.dataset == "voc"``,
    otherwise ADE).

    Returns:
        Nested dict ``result[i][j]``: mean of the per-batch losses of the
        step-``i`` model on step-``j`` validation data, for ``1 <= j <= i``.
    """
    from tasks import tasks_voc, tasks_ade

    if _opts.dataset == "voc":
        setting = tasks_voc[_opts.task]
    else:
        setting = tasks_ade[_opts.task]
    step_count = len(setting)
    batch_size = 40
    results = {}
    for model_step in range(1, step_count):
        print(f"testing model {model_step}")
        _, loss_fn, model_i = prepare_calculate(model_step, _opts)
        model_i = model_i.to(device).eval()
        per_dataset = results[model_step] = {}
        for data_step in range(1, model_step + 1):
            print(f"testing model {model_step} in dataset {data_step}")
            # Only the dataset is needed; the model built alongside is discarded.
            dataset_j, _, _ = prepare_calculate(data_step, _opts)
            loader = DataLoader(dataset_j, batch_size=batch_size, shuffle=False, num_workers=0)
            batch_losses = []
            with torch.no_grad():
                for image, label in tqdm.tqdm(loader, total=len(loader)):
                    image = image.to(device)
                    label = label.long().to(device)
                    prediction = model_i(image)[0]
                    batch_losses.append(loss_fn(prediction, label).item())
            per_dataset[data_step] = sum(batch_losses) / len(batch_losses)
    return results

def calculate_forgetting_rate(task_num, opts):
    """Average increase in validation loss on earlier steps after step ``task_num``.

    For each earlier step ``task`` in ``1 .. task_num - 1``, both the
    step-``task`` model and the step-``task_num`` model are evaluated on step
    ``task``'s validation data (see ``prepare_calculate``). The recorded
    value is ``mean_loss(model_task_num) - mean_loss(model_task)``, where each
    mean is taken over per-batch losses.

    Args:
        task_num: the later step (needs at least one earlier step, so >= 2).
        opts: option namespace (see ``prepare_calculate``).

    Returns:
        ``(mean_rate, per_task_rates)``.

    Raises:
        ValueError: if ``task_num < 2`` or a step's loader yields no batches
            (both previously ended in ZeroDivisionError).
    """
    if task_num < 2:
        raise ValueError(f"task_num must be >= 2 to have an earlier step, got {task_num}")

    def _mean(values, step):
        # Explicit error instead of ZeroDivisionError on an empty split.
        if not values:
            raise ValueError(f"validation loader for step {step} yielded no batches")
        return sum(values) / len(values)

    task_forgetting_rate = []
    _, _, model_tau = prepare_calculate(task_num, opts)
    # Evaluation mode, as total_calculate_forgetting_rate uses; otherwise
    # normalisation/dropout layers run in training mode during evaluation.
    model_tau = model_tau.to(device).eval()
    batch = 30
    for task in range(1, task_num):
        loss_in_task_t = []
        loss_in_task_t_for_tau = []
        dataset_t, loss_fn, model_t = prepare_calculate(task, opts)
        model_t = model_t.to(device).eval()
        dataloader = DataLoader(dataset_t, batch_size=batch, shuffle=False, num_workers=0)
        print(f"In task {task}")
        with torch.no_grad():
            for image, label in tqdm.tqdm(dataloader, total=len(dataloader)):
                image, label = image.to(device), label.long().to(device)
                loss_in_task_t.append(loss_fn(model_t(image)[0], label).item())
                loss_in_task_t_for_tau.append(loss_fn(model_tau(image)[0], label).item())
        task_loss = _mean(loss_in_task_t, task)
        task_loss_for_tau = _mean(loss_in_task_t_for_tau, task)
        task_forgetting_rate.append(task_loss_for_tau - task_loss)
        # Release this step's model before the next one is built.
        del dataset_t, model_t, dataloader
    rate_ = sum(task_forgetting_rate) / len(task_forgetting_rate)
    return rate_, task_forgetting_rate


# calculate_forgetting_rate_taylor_second_order(5)
# rate_4 = calculate_forgetting_rate(4)
# rate_5 = calculate_forgetting_rate(5)
# rate_6 = calculate_forgetting_rate(6)
# rate_7 = calculate_forgetting_rate(11,mib_opts)
# calculate_forgetting_rate_taylor(4)
# print(rate_5-rate_4)
# mib_rate = []
# mib_rate_list = []
# plop_rate = []
# plop_rate_list= []
# for i in range(2, 20):
#     result, result_list = calculate_forgetting_rate(i, mib_opts)
#     # r,r_list = calculate_forgetting_rate(i, plop_opts)
#     mib_rate.append(result)
#     # plop_rate.append(r)
#     mib_rate_list.append(result_list)
#     # plop_rate_list.append(r_list)
# import matplotlib.pyplot as plt
#
# print(mib_rate_list)
# # print(plop_rate_list)
# # with open(r"./test.txt","x") as f:
# #     for i,j in zip(mib_rate_list,plop_rate_list):
# #         f.write(f"{i},{j}\n")
# with open(r"./test.txt", "w") as f:
#     for i in mib_rate_list:
#         f.write(f"{i}\n")
# x_data = range(len(mib_rate))
# plt.plot(x_data, mib_rate, label="mib", color="blue")
# # plt.plot(x_data, plop_rate,label="plop",color="red")
# plt.legend()
# plt.grid(True)
# plt.savefig("./mib_forgetting_rate.png")
import os
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# results = total_calculate_forgetting_rate(mib_opts)
# with open(r"./mib_1-1.txt","x") as f:
#     for i in results.keys():
#         f.write(f"{results[i]}\n")
import ast


def _load_results(path):
    """Parse one Python-literal result dict per non-blank line of *path*.

    Returns ``{1: <first entry>, 2: <second entry>, ...}``, numbering entries
    from 1 in file order (the format written by total_calculate_forgetting_rate
    output, one ``repr`` per line).
    """
    results = {}
    with open(path, "r", encoding="utf-8") as f:
        # strip() instead of the previous `line[:-1]`, which chopped the last
        # character (the closing brace) off a final line without a newline.
        lines = [line.strip() for line in f]
    for num, line in enumerate((ln for ln in lines if ln), start=1):
        results[num] = ast.literal_eval(line)
    return results


plop_results = _load_results("./plop_1-1.txt")
mib_results = _load_results("./mib_1-1.txt")
# print(results)
#
def calculate_from_dict(result_dict):
    """Average forgetting per step from a nested loss table.

    ``result_dict[i][j]`` is the loss of the step-``i`` model on step-``j``
    data, with steps numbered contiguously from 1. For every step ``i >= 2``
    the output holds the mean, over ``j < i``, of
    ``result_dict[i][j] - result_dict[j][j]``.
    """
    n_steps = len(result_dict)
    return [
        sum(result_dict[step][prev] - result_dict[prev][prev] for prev in range(1, step)) / (step - 1)
        for step in range(2, n_steps + 1)
    ]
#
mib_forgetting = calculate_from_dict(mib_results)
print(mib_forgetting)
y_data = [0] + mib_forgetting
plop_forgetting = calculate_from_dict(plop_results)
print(plop_forgetting)
y_data1 = [0] + plop_forgetting

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# calculate_from_dict starts at step 2, so each curve is prefixed with 0.
x_data = range(len(y_data))
plt.plot(x_data, y_data, label="mib", color="blue", marker="o")
plt.plot(x_data, y_data1, label="plop", color="red", marker="^")
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
plt.legend()
plt.grid(True)
plt.savefig("./forgetting_rate.png")
