import os
import sys
import warnings
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
# from torch.distributed.pipeline.sync.checkpoint import checkpoint
from tqdm import tqdm
import math
from torch import optim, nn
import logging
import copy
from torch.utils.data import DataLoader
import torch
import numpy as np
from utils import factory
from utils.data_manager import DataManager
from torch.distributions.multivariate_normal import MultivariateNormal
from utils.toolkit import count_parameters
from pathlib import Path
from ESN.networks import IncrementalViTOOD, _create_vision_transformer


def eval_cp(args):
    """Run ``test`` once per seed listed in ``args['seed']``.

    ``args`` is mutated: for every iteration ``args['seed']`` is set to the
    current seed and ``args['device']`` is reset to a snapshot of its
    original value (``test`` replaces that entry with converted devices, so
    it must be restored before each run).
    """
    seeds = copy.deepcopy(args['seed'])
    original_devices = copy.deepcopy(args['device'])

    for current_seed in seeds:
        args.update({'seed': current_seed, 'device': original_devices})
        test(args)


def test(args, cp_path=None, tasks=None):
    """Evaluate a saved checkpoint on the test split and print accuracies.

    Args:
        args: config dict. Keys read here: 'device', 'dataset', 'shuffle',
            'seed', 'init_cls', 'increment', 'batch_size' and optionally
            'cp_path'. ``args['device']`` is converted in place by
            ``_set_device``.
        cp_path: path to a checkpoint saved with ``torch.save(model)``. When
            None, falls back to ``args['cp_path']`` and then to a module-level
            ``cp_path`` global (the original code relied on that global but
            never defined it).
        tasks: iterable of task indices to evaluate; defaults to
            ``range(9, 10)``, the previously hard-coded value.

    Raises:
        ValueError: if no checkpoint path can be resolved.
    """
    if cp_path is None:
        cp_path = args.get('cp_path', globals().get('cp_path'))
    if cp_path is None:
        raise ValueError(
            "No checkpoint path given: pass cp_path=..., set args['cp_path'], "
            "or define a module-level cp_path."
        )
    if tasks is None:
        tasks = range(9, 10)

    _set_device(args)
    data_manager = DataManager(args['dataset'], args['shuffle'], args['seed'], args['init_cls'], args['increment'])
    device = args['device'][0]
    increment = args["increment"]

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. On torch >= 2.6 the default weights_only=True rejects full
    # pickled models; pass weights_only=False there if loading fails.
    # map_location lets a GPU-saved checkpoint be evaluated on CPU.
    model = torch.load(cp_path, map_location=device)
    model._network.eval()
    model._network.to(device)
    model._device = device

    is_tty = sys.stdout.isatty()  # only draw progress bars on a terminal
    acc_lis = []
    for task in tasks:
        n_seen_classes = (task + 1) * increment
        # Test classes [increment, n_seen_classes) -- the first `increment`
        # classes are excluded, as in the original code.
        test_dataset = data_manager.get_dataset(range(increment, n_seen_classes), source='test', mode='test')
        test_loader = DataLoader(test_dataset, batch_size=args["batch_size"], shuffle=False, num_workers=8)
        total = 0
        cor = 0
        val_bar = tqdm(test_loader, desc=f"测试 [Val]", leave=False, disable=not is_tty)
        for _, inputs, targets in val_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            # NOTE(review): the original computed the "true task id" as
            # (targets // increment) - (targets // increment), which is always
            # 0, so every sample uses prompt 0. That behavior is kept explicitly
            # here; confirm the intended formula (likely targets // increment,
            # possibly offset by one) before relying on these numbers.
            task_ids = torch.zeros_like(targets)
            gen_p = [
                _select_prompts(model._network.ts_prompts_1, task_ids),
                _select_prompts(model._network.ts_prompts_2, task_ids),
            ]
            with torch.no_grad():
                # Keep only logits for classes seen up to and including `task`.
                out_logits = model._network(inputs, gen_p, train=False)[:, :n_seen_classes]
            preds_final = torch.max(out_logits, dim=1)[1]
            cor += preds_final.eq(targets).sum().item()
            total += len(targets)

            val_bar.set_postfix(acc=100. * cor / total if total > 0 else 0.0)
        val_bar.close()
        acc = 100. * cor / total if total > 0 else 0.0  # guard empty loader
        acc_lis.append(acc)
        print(f"task{task}用ESN正确率为:{acc :.2f}")
    average = sum(acc_lis) / len(acc_lis) if acc_lis else 0.0
    print(f"平均准确率为{average :.2f}")
    print(f"acc_lis={acc_lis}")
    return


def _select_prompts(prompt_list, task_ids):
    """Gather one detached prompt weight per sample, indexed by task id.

    Args:
        prompt_list: indexable collection of modules exposing ``.weight``
            (here the network's ``ts_prompts_1`` / ``ts_prompts_2``).
        task_ids: integer tensor with one task index per sample.

    Returns:
        Tensor stacking the selected weights along a new leading dimension;
        each weight is detached and cloned so no gradient reaches the prompts.
    """
    return torch.stack(
        [prompt_list[j].weight.detach().clone() for j in task_ids.flatten().tolist()],
        dim=0,
    )

def _set_device(args):
    device_type = args['device']
    gpus = []

    for device in device_type:
        if device_type == -1:
            device = torch.device('cpu')
        else:
            device = torch.device('cuda:{}'.format(device))

        gpus.append(device)
    args['device'] = gpus