import os
import sys
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
# from torch.distributed.pipeline.sync.checkpoint import checkpoint
from tqdm import tqdm
import math
from torch import optim, nn
import logging
import copy
from torch.utils.data import DataLoader
import torch
import numpy as np
from utils import factory
from utils.data_manager import DataManager
from torch.distributions.multivariate_normal import MultivariateNormal
from utils.toolkit import count_parameters
from pathlib import Path
from ESN.networks import IncrementalViTOOD, _create_vision_transformer
import time, datetime

def eval_esncp(args):
    """Run :func:`test` once for every seed listed in ``args['seed']``.

    ``args`` is mutated in place: before each run ``args['seed']`` is set to the
    current seed and ``args['device']`` is restored to the originally configured
    device spec, because ``test`` (via ``_set_device``) overwrites it with a list
    of ``torch.device`` objects.

    Args:
        args: config dict holding at least ``'seed'`` (iterable of seeds) and
            ``'device'`` (device spec understood by ``_set_device``).
    """
    seeds = copy.deepcopy(args['seed'])
    configured_devices = copy.deepcopy(args['device'])

    for current_seed in seeds:
        args['seed'] = current_seed
        # Undo the previous run's in-place conversion to torch.device objects.
        args['device'] = configured_devices
        test(args)


def test(args):
    """Evaluate a saved incremental-learning checkpoint after each task.

    For each of the (hard-coded) 10 incremental tasks, the test split of every
    class seen so far is evaluated.

    Per batch, a task index is selected by cosine similarity. One side is the
    query feature: token 0 of ``image_encoder(..., returnbeforepool=True)``. The
    other side is the first ``(task + 1) * model.args['increment']`` rows of
    ``model._network.keys``. The matching ``ts_prompts_1`` / ``ts_prompts_2``
    weights are fed to the network, and its logits are restricted to the classes
    seen so far.

    After the final task it prints the last accuracy, the mean accuracy over all
    tasks, and the average forgetting computed from the per-group accuracy table.

    Args:
        args: config dict. Reads ``'device'``, ``'dataset'``, ``'shuffle'``,
            ``'seed'``, ``'init_cls'``, ``'increment'`` and, optionally,
            ``'cp_path'`` (checkpoint path). ``args['device']`` is replaced in
            place by ``_set_device``.
    """
    _set_device(args)
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
    device = args['device'][0]
    increment = args["increment"]

    checkpoint_path = args.get('cp_path')
    if checkpoint_path is None:
        # NOTE(review): falls back to a module-level ``cp_path`` that is not
        # defined in the visible source of this file -- confirm where it is set.
        checkpoint_path = cp_path
    # NOTE(review): torch.load unpickles a whole model object -- only load trusted
    # checkpoints. On torch>=2.6 the default weights_only=True may reject such a
    # file; confirm the torch version in use.
    model = torch.load(checkpoint_path)
    model._network.eval()
    model._network.to(device)
    model._device = device

    key_increment = model.args["increment"]
    ts_prompts_1 = model._network.ts_prompts_1
    ts_prompts_2 = model._network.ts_prompts_2

    num_tasks = 10  # hard-coded; presumably matches the dataset split -- TODO confirm
    is_tty = sys.stdout.isatty()
    faa_accuracy_table = []
    acc_all = []
    for task in range(num_tasks):
        num_seen = (task + 1) * increment
        test_dataset = data_manager.get_dataset(np.arange(num_seen), source='test', mode='test')
        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8)
        val_bar = tqdm(test_loader, desc=f"测试 [Val]", leave=False, disable=not is_tty)

        # Keys of all tasks seen so far, L2-normalised once per task (not per batch)
        # and without building an autograd graph.
        with torch.no_grad():
            n_K = nn.functional.normalize(model._network.keys[:(task + 1) * key_increment], dim=1)

        total = 0
        cor = 0
        faa_y_true = []
        faa_pred = []
        for _, inputs, targets in val_bar:
            inputs, targets = inputs.to(device), targets.to(device)
            with torch.no_grad():
                x_querry = model._network.image_encoder(inputs, returnbeforepool=True)[:, 0, :]
                q = nn.functional.normalize(x_querry, dim=1)
                mk = torch.einsum('bd,kd->bk', q, n_K)
                # Most similar key index -> task id. keepdim gives shape (B, 1), so
                # each ``j`` below is a one-element tensor used as a prompt index.
                m = torch.max(mk, dim=1, keepdim=True)[1] // key_increment
                P1 = torch.cat([ts_prompts_1[j].weight.detach().clone().unsqueeze(0) for j in m], dim=0)
                P2 = torch.cat([ts_prompts_2[j].weight.detach().clone().unsqueeze(0) for j in m], dim=0)
                out_logits = model._network(inputs, [P1, P2], train=False)[:, :num_seen]
            preds_final = torch.max(out_logits, dim=1)[1]
            cor += preds_final.eq(targets).sum().item()
            total += len(targets)
            faa_pred.append(preds_final.cpu().numpy())
            faa_y_true.append(targets.cpu().numpy())
            val_bar.set_postfix(acc=100. * cor / total if total > 0 else 0.0)
        val_bar.close()

        faa_pred = np.concatenate(faa_pred)
        faa_y_true = np.concatenate(faa_y_true)
        faa_accuracy_table.append(_group_accuracies(faa_pred, faa_y_true, increment))
        acc_all.append(np.around(cor * 100 / total, decimals=2))

        if task == num_tasks - 1:
            forgetting = _average_forgetting(faa_accuracy_table, task)
            print(f"#########Task{task}训练后的结果如下：")
            print("Last-acc:{}".format(acc_all[-1]))
            print("Avg-acc:{:.3f}".format(np.mean(acc_all)))
            print("FF: {}".format(np.around(forgetting, decimals=2)))
    return


def _group_accuracies(pred, y_true, increment):
    """Return accuracy (%, 3 decimals) for each consecutive block of ``increment`` labels.

    Blocks start at label 0 and cover every label up to and including
    ``y_true.max()`` (the ``+ 1`` keeps a block that starts at the max label).
    """
    accs = []
    for class_id in range(0, np.max(y_true) + 1, increment):
        idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
        accs.append(np.around((pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=3))
    return accs


def _average_forgetting(accuracy_table, task):
    """Return mean (best-ever accuracy - accuracy after ``task``) over groups 0..task-1.

    ``accuracy_table[t]`` is the list of per-group accuracies measured after task ``t``.
    """
    acctable = np.zeros([task + 1, task + 1])
    for idxx, line in enumerate(accuracy_table):
        acctable[idxx, :len(line)] = np.array(line)
    # After transposing: rows = label group, columns = evaluation time (task).
    acctable = acctable.T
    return np.mean((np.max(acctable, axis=1) - acctable[:, task])[:task])


def _set_device(args):
    """Replace ``args['device']`` with a list of :class:`torch.device` objects.

    Each configured entry maps to ``cuda:<entry>``. The exception is ``-1``
    (int or the string ``"-1"``), which maps to the CPU. Entries that are already
    ``torch.device`` objects are kept unchanged, so calling this twice is safe.

    Args:
        args: config dict whose ``'device'`` entry is an iterable of GPU ids.
    """
    gpus = []
    for device in args['device']:
        if isinstance(device, torch.device):
            gpus.append(device)
        # The original compared the whole list to -1, so CPU was never selected.
        elif str(device) == '-1':
            gpus.append(torch.device('cpu'))
        else:
            gpus.append(torch.device('cuda:{}'.format(device)))
    args['device'] = gpus