import os.path as osp
import os
import json
import statistics
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

from .utils import cosine_loss_3d, cal_MTIL_metrics, select_task_by_semantic_center, select_task_by_gaussian_center, load_description_dicts, count_description_keys, average_logits_per_class, average_logits_per_description
from .utils import check_optimizer_params, check_gradients, check_optimizer_contents, set_trainable_params, log_trainable_params, visualize_logits, get_task_class_names, build_prompt, save_task_semantic_center
from .utils import save_predictions_to_json, build_default_prompt, fuse_logits, build_positive_prototypes_for_images, prototype_loss_multipos_infoNCE, compute_default_and_external_logits, visualize_correction
from .grad_cam import GradCAM, save_gradcam_batch_jet

from continuum.metrics import Logger
from CoFiCL.utils import build_cosine_scheduler
from CoFiCL.datasets import parse_sample

from torch.distributions.multivariate_normal import MultivariateNormal

_tokenizer = _Tokenizer()
from clip.adapter_manager import AdapterManager
from clip.prompt_manager import MultiPromptPool, ImagePromptAugmentor, TextPromptAugmentor
import random


# Module-level adapter managers, instantiated once when this module is imported.
# They are handed to clip.build_model() in load_clip_to_cpu() and are re-frozen
# per task through freeze_all_except_task() in CoFiCL.train_and_eval().
# NOTE(review): d_model 768 / 512 presumably match the vision / text transformer
# widths of the backbone, and num_tasks=11 the MTIL task count -- confirm.
vision_adapter = AdapterManager(num_layers=12, d_model=768, num_tasks=11, init_type="adapter", init_layers=[0,1,2,3,4,5,6,7,8,9,10,11])
text_adapter   = AdapterManager(num_layers=12, d_model=512, num_tasks=11, init_type="adapter", init_layers=[0,1,2,3,4,5,6,7,8,9,10,11])


# Result of get_task_class_names(), also evaluated at import time.
class_names = get_task_class_names() 

def load_clip_to_cpu(cfg, with_ori=False):
    """Download the CLIP checkpoint and build the model(s) on the CPU.

    Args:
        cfg: Config object; reads ``model_backbone_name``, ``nb_task`` and the
            ``CoFiCL`` prompt settings (``prompt_depth_vision``,
            ``prompt_depth_text``, ``n_ctx_vision``, ``n_ctx_text``).
        with_ori: When true, additionally build a second model whose prompt
            depths and context lengths are all zero and which receives no
            adapter managers.

    Returns:
        The model built with the module-level adapter managers, or the tuple
        ``(train_model, ori_model)`` when ``with_ori`` is true.
    """
    checkpoint_path = clip._download(clip._MODELS[cfg.model_backbone_name])

    state_dict = None
    try:
        # First try to read the checkpoint as a TorchScript (JIT) archive.
        model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        # Otherwise fall back to loading it with torch.load as a state dict.
        state_dict = torch.load(checkpoint_path, map_location="cpu")

    method_cfg = cfg.CoFiCL
    prompted_details = {
        "vision_depth": method_cfg.prompt_depth_vision,
        "language_depth": method_cfg.prompt_depth_text,
        "vision_ctx": method_cfg.n_ctx_vision,
        "language_ctx": method_cfg.n_ctx_text,
        "pool_size": cfg.nb_task,
    }
    # The weights expression is evaluated separately for each build, exactly
    # as before, so each build_model call gets its own evaluation of it.
    train_model = clip.build_model(
        state_dict or model.state_dict(),
        prompted_details,
        vision_adapter_manager=vision_adapter,
        text_adapter_manager=text_adapter,
    )
    if not with_ori:
        return train_model

    plain_details = dict.fromkeys(
        ("vision_depth", "language_depth", "vision_ctx", "language_ctx"), 0
    )
    ori_model = clip.build_model(state_dict or model.state_dict(), plain_details)
    return train_model, ori_model


class TextEncoder(nn.Module):
    """Text tower wrapper that encodes already-embedded prompts.

    Re-uses the transformer, positional embedding, final layer norm and text
    projection of the given CLIP model.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts, indices, batch_weight=None, cur_train_task_id: int = -1, domain_pred_task_id: int = -1):
        """Encode prompt embeddings and return one projected feature per prompt.

        Args:
            prompts: Prompt embeddings, batch-first (permuted to sequence-first
                before the transformer call).
            tokenized_prompts: Token ids; the position of the largest id in
                each row selects which token's feature is returned.
            indices, batch_weight, cur_train_task_id, domain_pred_task_id:
                Forwarded unchanged to the transformer.

        Returns:
            The selected token features multiplied by ``text_projection``.
        """
        positional = self.positional_embedding.type(self.dtype)
        sequence_first = (prompts + positional).permute(1, 0, 2)  # NLD -> LND
        encoded, _ = self.transformer(sequence_first, indices, batch_weight, cur_train_task_id, domain_pred_task_id)
        hidden = self.ln_final(encoded.permute(1, 0, 2)).type(self.dtype)  # LND -> NLD

        # The end-of-text token has the highest id in each sequence, so argmax
        # over the token ids yields its position.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        row_ids = torch.arange(hidden.shape[0])
        return hidden[row_ids, eot_positions] @ self.text_projection


class PromptProcessor(nn.Module):
    """Builds, tokenizes and embeds every text prompt, then serves a task subset.

    For each class one default prompt is created, followed by one prompt per
    description found in ``description_dicts``. All prompts receive consecutive
    global ids; ``update_classnames`` selects which ids are currently active
    and ``forward`` returns their embeddings repeated per batch entry.
    """

    # Dataset name of each task, in task order, for the supported benchmarks.
    _TASK_DATASETS = {
        "MTIL": ("Aircraft", "Caltech101", "CIFAR100", "DTD", "EuroSAT", "Flowers",
                 "Food", "MNIST", "Oxford_Pet", "Stanford_Cars", "SUN397"),
        "X-TAIL": ("Aircraft", "Caltech101", "DTD", "EuroSAT", "Flowers",
                   "Food", "MNIST", "Oxford_Pet", "Stanford_Cars", "SUN397"),
    }

    def __init__(self, cfg, classnames, description_dicts, clip_model):
        """Create all prompts and cache their token ids and embeddings.

        Args:
            cfg: Config; reads ``input_size`` and ``dataset``.
            classnames: One list of class names per task.
            description_dicts: ``{dataset_name: {class_name: [description, ...]}}``.
            clip_model: CLIP model supplying ``visual.input_resolution``,
                ``token_embedding`` and ``dtype``.

        Raises:
            ValueError: If ``cfg.dataset`` is neither ``"MTIL"`` nor ``"X-TAIL"``.
        """
        super().__init__()

        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.input_size[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        # Fail early with a clear message instead of an AttributeError on
        # ``task_to_dataset`` further down.
        if cfg.dataset not in self._TASK_DATASETS:
            raise ValueError(
                f"Unsupported dataset {cfg.dataset!r}; expected one of {sorted(self._TASK_DATASETS)}"
            )

        self.class_ids_per_task = []
        self.desc_dicts = description_dicts
        self.n_cls = 0
        self.dataset = cfg.dataset
        self.task_to_dataset = list(self._TASK_DATASETS[cfg.dataset])

        all_prompts = []
        self.prompts_per_class = []  # per task: number of prompts of each class
        for task_idx, task_classnames in enumerate(classnames):
            cur_ids = []
            dataset_name = self.task_to_dataset[task_idx]
            task_prompts_count = []

            for cls_name in task_classnames:
                # The default prompt always comes first within a class.
                cur_class_prompts = [build_default_prompt(task_idx, cls_name)]

                if cls_name in description_dicts[dataset_name]:
                    for desc_text in description_dicts[dataset_name][cls_name]:
                        # MNIST class names are formatted through an f-string
                        # before being passed on (kept from the original code).
                        label = f"{cls_name}" if dataset_name == "MNIST" else cls_name
                        cur_class_prompts.append(build_prompt(desc_text, label))

                # Every prompt gets the next free global id.
                for prompt in cur_class_prompts:
                    all_prompts.append(prompt)
                    cur_ids.append(self.n_cls)
                    self.n_cls += 1

                task_prompts_count.append(len(cur_class_prompts))

            self.prompts_per_class.append(task_prompts_count)
            self.class_ids_per_task.append(cur_ids)

        # clean prompt texts for the tokenizer
        self.classnames = [name.replace("_", " ") for name in all_prompts]
        self.all_name_lens = [len(_tokenizer.encode(name)) for name in self.classnames]

        # tokenize and embed every prompt once
        self.register_buffer("all_tokenized_prompts", torch.cat([clip.tokenize(p) for p in self.classnames]))
        with torch.no_grad():
            self.register_buffer("all_embedding", clip_model.token_embedding(self.all_tokenized_prompts).type(clip_model.dtype))

        # Start with every prompt selected; ``cur_n_cls`` stays 0 until
        # ``update_classnames`` has been called.
        self.register_buffer("token_prefix", self.all_embedding[:, :1, :])
        self.register_buffer("token_suffix", self.all_embedding[:, 1:, :])
        self.register_buffer("tokenized_prompts", self.all_tokenized_prompts.clone())
        self.cur_n_cls = 0

    def forward(self, indices):
        """Return the active prompts repeated once per row of ``indices``.

        Args:
            indices: Tensor whose first dimension gives the repeat count; its
                values are not read here.

        Returns:
            ``(prompts, tokenized_prompts)`` with first dimension
            ``indices.size(0) * cur_n_cls``.
        """
        batch_size = indices.size(0)
        prefix = self.token_prefix.unsqueeze(0).repeat(batch_size, 1, 1, 1)  # [bs, n_cls, 1, ctx_dim]
        suffix = self.token_suffix.unsqueeze(0).repeat(batch_size, 1, 1, 1)  # [bs, n_cls, seq_len - 1, ctx_dim]
        prompts = torch.cat([prefix, suffix], dim=2)  # [bs, n_cls, seq_len, ctx_dim]
        prompts = prompts.view(batch_size*self.cur_n_cls, prompts.size(2), prompts.size(3))  # [bs*n_cls, seq_len, ctx_dim]
        tokenized_prompts = self.tokenized_prompts.unsqueeze(0).repeat(batch_size, 1, 1).view(batch_size*self.cur_n_cls, -1)  # [bs*n_cls, seq_len]
        return prompts, tokenized_prompts

    def update_classnames(self, task_id):
        """Select the prompts that are active for ``task_id``.

        MTIL uses only the prompts of task ``task_id``; the X-TAIL benchmark
        uses the prompts of all tasks up to and including ``task_id``.

        Raises:
            ValueError: If ``self.dataset`` is not a supported benchmark.
        """
        if self.dataset == "MTIL":
            class_idx = self.class_ids_per_task[task_id]
        elif self.dataset == "X-TAIL":
            class_idx = [i for class_ids in self.class_ids_per_task[:task_id + 1] for i in class_ids]
        else:
            # Previously this fell through and crashed on an unbound local.
            raise ValueError(
                f"Unsupported dataset {self.dataset!r}; expected one of {sorted(self._TASK_DATASETS)}"
            )
        self.token_prefix = self.all_embedding[class_idx, :1, :]
        self.token_suffix = self.all_embedding[class_idx, 1:, :]
        self.tokenized_prompts = self.all_tokenized_prompts[class_idx]
        self.name_lens = [self.all_name_lens[idx] for idx in class_idx]
        self.cur_n_cls = len(class_idx)


class CustomCLIP(nn.Module):
    """CLIP wrapper combining a frozen encoder, prompts and task statistics.

    Holds the trainable image/text encoders, a second ("ori") image encoder
    that is only run under ``torch.no_grad()``, the prompt processor, the image
    prompter, and per-task mean/covariance buffers used for task selection at
    test time.
    """

    def __init__(self, cfg, classnames, templates, clip_model, clip_model_ori=None, img_prompter=None):
        """Assemble the model.

        Args:
            cfg: Config; reads ``nb_task``, ``CoFiCL.prompt_depth_vision`` and
                ``CoFiCL.batchwise_prompt`` (plus whatever PromptProcessor reads).
            classnames: Per-task class names, forwarded to PromptProcessor.
            templates: Accepted for interface compatibility; not used here.
            clip_model: CLIP model providing the trainable encoders.
            clip_model_ori: CLIP model whose visual tower is used for the
                no-grad feature extraction. Required despite the ``None`` default.
            img_prompter: Image prompt module called in ``forward``.

        Raises:
            ValueError: If ``clip_model_ori`` is ``None``.
        """
        super().__init__()
        # ``forward`` always runs the ori encoder, so a missing model is a
        # configuration error; report it clearly instead of failing on
        # ``None.visual``.
        if clip_model_ori is None:
            raise ValueError("CustomCLIP requires clip_model_ori: its visual encoder is used in forward().")

        description_dicts = load_description_dicts(["Aircraft", "Caltech101", "CIFAR100", "DTD", "EuroSAT", "Flowers", "Food", "MNIST", "Oxford_Pet", "Stanford_Cars", "SUN397"], "descriptions/gpt-4o/")
        self.prompt_processor = PromptProcessor(cfg, classnames, description_dicts, clip_model)
        self.image_encoder = clip_model.visual
        self.image_encoder_ori = clip_model_ori.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.vis_dim = clip_model.visual.output_dim
        self.pool_size = cfg.nb_task
        self.visual_prompt = cfg.CoFiCL.prompt_depth_vision > 0
        self.batchwise_prompt = cfg.CoFiCL.batchwise_prompt

        # Per-task Gaussian statistics of the ori image features; filled in by
        # the trainer and consumed by select_task_by_gaussian_center.
        self.register_buffer("means", torch.empty(self.pool_size, self.vis_dim, dtype=torch.float))
        self.register_buffer("covars", torch.empty(self.pool_size, self.vis_dim, self.vis_dim, dtype=torch.float))
        self.register_buffer("task_learnt", torch.tensor(0, dtype=torch.int))

        self.semantic_means = {}
        self.img_prompter = img_prompter
        self.cfg = cfg

    def forward(self, image, task_ids=None, cur_train_task_id: int = -1, test_cur_train_task_id: int = -1, test_cur_test_task_id: int = -1, cfg=None, target=None, debug_for_train_cnt: int = -1,
                positives=None, prototypes=None, train_flag=False):
        """Compute class logits for a batch of images.

        When ``task_ids`` is given (training/validation) the task is taken from
        it; all entries must be equal. When it is ``None`` (testing) the task
        is chosen by ``select_task_by_gaussian_center``.

        Args:
            image: Input image batch.
            task_ids: Tensor of task ids, or ``None`` for test-time selection.
            cur_train_task_id: Forwarded to the prompter and both encoders.
            test_cur_train_task_id, test_cur_test_task_id: Forwarded to the
                prompter; the latter also picks the MTIL prompt counts at test time.
            cfg, target, debug_for_train_cnt: Accepted but not used.
            positives, prototypes: When both are given, a prototype loss is
                added to the result.
            train_flag: Accepted for compatibility; the fusion weight was
                identical for both of its values, so it no longer branches.

        Returns:
            dict with ``"outputs"`` (fused logits), ``"image_features"``,
            ``"indices"`` and, depending on the path, ``"prompt_loss"``,
            ``"prototype_loss"``, ``"text_batch_weight"`` and ``"raw_indices"``.

        Raises:
            NotImplementedError: If task selection yields more than one index
                row (per-sample selection).
        """
        res = {}
        batch_weight = None
        text_batch_weight = None
        with torch.no_grad():
            image_features, image_features_768, _ = self.image_encoder_ori(image.type(self.dtype))
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            image_features_768 = image_features_768 / image_features_768.norm(dim=-1, keepdim=True)
            res["image_features"] = image_features.detach()

        if task_ids is not None:
            # Select prompts, concatenate them, and obtain the regularization term.
            prom_seq, _, reg_prompt = self.img_prompter(image_features_768, train_task_id=test_cur_train_task_id, test_task_id=test_cur_test_task_id, cur_train_task_id=cur_train_task_id)
            res["prompt_loss"] = reg_prompt
            test_flag = False
            task_ids = task_ids.type(torch.int).to(image.device)
            assert (task_ids == task_ids[0]).all()
            indices = task_ids[0:1].unsqueeze(1)  # size [1, 1]
            domain_pred_task_id = -1
        else:
            prom_seq, _ = self.img_prompter.select_from_all_tasks(image_features_768, test_cur_train_task_id=test_cur_train_task_id, test_cur_test_task_id=test_cur_test_task_id)
            test_flag = True
            raw_indices, indices, batch_weight, text_batch_weight, domain_pred_task_id = select_task_by_gaussian_center(
                image_features, self.means, self.covars, self.task_learnt.item(),
                self.prompt_processor, self.batchwise_prompt
            )
            res["text_batch_weight"] = text_batch_weight[0].item()
            res["raw_indices"] = raw_indices

        res["indices"] = indices

        # The per-sample path used to compute per-prompt logits but never
        # produced the fused class logits, so it crashed on an unbound
        # ``final_logits``. Returning the raw per-prompt logits would silently
        # mis-index classes, so fail early and explicitly instead.
        if indices.size(0) != 1:
            raise NotImplementedError(
                "Per-sample task indices are not supported: logits are only aggregated "
                "per class for a single batch-wise task index (enable batchwise_prompt)."
            )

        prompts, tokenized_prompts = self.prompt_processor(indices)  # [bs*n_cls, seq_len, ctx_dim]
        text_features = self.text_encoder(prompts, tokenized_prompts, indices, text_batch_weight, cur_train_task_id, domain_pred_task_id)  # [bs*n_cls, model_dim]
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # NOTE(review): ``image_features_prom`` is only defined when
        # ``visual_prompt`` is enabled, yet it is used below whenever a prompt
        # sequence or prototypes are present -- confirm that visual prompts are
        # always on in supported configurations.
        if self.visual_prompt:
            image_features, _, image_features_prom = self.image_encoder(image.type(self.dtype), indices, batch_weight, cur_train_task_id, domain_pred_task_id, prom_seq)  # [bs, model_dim]
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            image_features_prom = image_features_prom / image_features_prom.norm(dim=-1, keepdim=True)

        if prototypes is not None and positives is not None:
            prototype_loss = prototype_loss_multipos_infoNCE(image_features_prom, prototypes, positives, temperature=0.07)
            res["prototype_loss"] = prototype_loss

        logit_scale = self.logit_scale.exp()

        # Number of prompts per class for the label space that is active now.
        prompts_per_class = self.prompt_processor.prompts_per_class
        if test_flag:
            if self.cfg.dataset == "MTIL":
                prompts_per_class = prompts_per_class[test_cur_test_task_id]
            elif self.cfg.dataset == "X-TAIL":
                prompts_per_class = [p for prompts in prompts_per_class for p in prompts]
        else:
            if self.cfg.dataset == "MTIL":
                prompts_per_class = prompts_per_class[indices.item()]
            elif self.cfg.dataset == "X-TAIL":
                prompts_per_class = [p for prompts in prompts_per_class[:indices.item() + 1] for p in prompts]

        if prom_seq is not None:
            default_logits, external_logits_avg = compute_default_and_external_logits(image_features, image_features_prom, text_features,
                                                                                      prompts_per_class, logit_scale)
            # Training and evaluation used the same fusion weight.
            final_logits = fuse_logits(default_logits, external_logits_avg, alpha=0.8)
        else:
            logits = logit_scale * image_features @ text_features.t()  # [bs, n_prompts]
            default_logits, external_logits_avg = average_logits_per_class(logits, prompts_per_class)
            final_logits = fuse_logits(default_logits, external_logits_avg, alpha=1.0)

        res["outputs"] = final_logits
        return res

    def update_classnames(self, task_id):
        """Activate the prompts of ``task_id`` in the prompt processor."""
        self.prompt_processor.update_classnames(task_id)


class CoFiCL:
    def __init__(self, cfg, device, classes_names, templates, load_file=None):
        """Create the trainer; all construction work is done by ``build_model``."""
        self.build_model(
            cfg=cfg,
            device=device,
            classes_names=classes_names,
            templates=templates,
            load_file=load_file,
        )

    def build_model(self, cfg, device, classes_names, templates, load_file=None):
        """Build the CLIP models, the prompt pool and the wrapped CustomCLIP.

        Args:
            cfg: Config; reads ``model_backbone_name``, ``gpu_id`` and ``log_path``.
            device: Sequence of devices; the first is the primary device and
                more than one entry enables ``DataParallel``.
            classes_names: Per-task class names, forwarded to CustomCLIP.
            templates: Forwarded to CustomCLIP.
            load_file: Optional checkpoint path loaded before the model is
                moved to the device.
        """
        print(f"Loading CLIP (backbone: {cfg.model_backbone_name})")
        clip_model, clip_model_ori = load_clip_to_cpu(cfg, with_ori=True)

        print("Building custom CLIP")
        prompt_pool = MultiPromptPool(
            num_pools=1,
            M=32,
            Lp=4,
            D=768,
            Dk=None,
            device=torch.device(f"cuda:{cfg.gpu_id}")
        )
        self.img_prompter = ImagePromptAugmentor(prompt_pool, topk=2)
        model = CustomCLIP(cfg, classes_names, templates, clip_model, clip_model_ori, img_prompter=self.img_prompter)

        print("Turning off gradients in both the image and the text encoder")
        # Only the prompt-pool keys and values are requested as trainable.
        names_to_update = ["img_prompter.pool.keys", "img_prompter.pool.values"]
        self.model, enabled = set_trainable_params(model, names_to_update)
        self.grad_cam_default = GradCAM(model=self.model, target_module=self.model.image_encoder.transformer.resblocks[-3].attn, use_prom=False)
        para_log = f"Parameters to be updated: {enabled}"
        print(para_log)
        # Context manager so the log file is closed even if the write fails.
        with open(osp.join(cfg.log_path, 'output.txt'), 'a') as f:
            f.write(para_log + '\n')

        self.devices = device
        self.device = device[0]

        if load_file:
            self.load_model(None, None, load_file)

        self.model.to(device[0])
        if len(device) > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=device)

        # Handle to the underlying module regardless of DataParallel wrapping.
        self.model_wo_dp = self.model.module if len(device) > 1 else self.model

    def save_model(self, cfg, task_id):
        save_dict = {}
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                save_dict[name] = para
        for name, para in self.model.named_buffers():  # for gaussian parameters
            if "means" in name or "covars" in name or "task_learnt" in name:
                save_dict[name] = para
        save_dir = os.path.join(cfg.log_path, 'ckpt')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(save_dict, os.path.join(save_dir, f'task_{task_id}.pt'))

    def load_model(self, cfg, task_id, load_file=None):
        if load_file is None:
            load_file = os.path.join(cfg.log_path, 'ckpt', f'task_{task_id}.pt')
        if not osp.exists(load_file):
            raise FileNotFoundError('Model not found at "{}"'.format(load_file))

        state_dict = torch.load(load_file, map_location="cpu")

        print(f"Loading weights from {load_file}")
        # set strict=False
        self.model.load_state_dict(state_dict, strict=False)

        return [i for i in state_dict.keys()]

    def train_and_eval(self, cfg, datasets):
        """Run the configured protocol: zero-shot, eval-only, or train + eval.

        * ``cfg.zero_shot``: evaluate every task once without training.
        * ``cfg.eval_only``: run ``eval_all`` once.
        * otherwise: train the tasks in order, evaluating all tasks after each
          one, and finally write the aggregate metrics.

        Per-task accuracies and final metrics are written as JSON lines to
        ``<cfg.log_path>/metrics.json``.

        Args:
            cfg: Config; reads ``log_path``, ``zero_shot``, ``eval_only``,
                ``dataset`` and ``nb_task``.
            datasets: Mapping with at least ``'train'``, ``'val'`` and ``'test'``.
        """
        acc_list = []
        metric_logger = Logger(list_subsets=["train", "test"])

        # The context manager closes the metrics file on every exit path,
        # including the early returns below.
        with open(os.path.join(cfg.log_path, 'metrics.json'), 'w') as metric_writer:
            if cfg.zero_shot:
                with torch.no_grad():
                    if cfg.dataset == "X-TAIL":
                        self.update_classnames(cfg.nb_task - 1)
                    for cur_task in tqdm(range(cfg.nb_task)):
                        if cfg.dataset == "MTIL":
                            self.update_classnames(cur_task)
                        eval_loader = self.get_dataloader(cfg, datasets['test'], cur_task, is_train=False)
                        for sample in eval_loader:
                            inputs, targets, task_ids = parse_sample(sample, is_train=False, task_id=cur_task, cfg=cfg)
                            inputs, targets = inputs.to(self.device), targets.to(self.device)
                            res = self.model(inputs, task_ids)
                            outputs = res["outputs"]
                            metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
                    cur_all_task_acc = metric_logger.accuracy_per_task
                    acc_list.append(cur_all_task_acc)
                    log = {'acc_per_task': [round(100 * acc_t, 2) for acc_t in cur_all_task_acc]}
                    metric_writer.write(json.dumps(log) + '\n')
                    metric_writer.flush()
                    print(log)
                    return

            if cfg.eval_only:
                self.eval_all(cfg, datasets, metric_logger, metric_writer, acc_list)
                return

            for task_id in range(cfg.nb_task):
                # ================= Freeze Adapter =================
                vision_adapter.freeze_all_except_task(task_id)
                text_adapter.freeze_all_except_task(task_id)

                print(f"Training for task {task_id} has started.")

                log_trainable_params(self.model, cfg)

                # ================= Prompt Pool Initialisation =================
                prev_task_id = task_id - 1 if task_id > 0 else None
                # Currently fixed; 'inherit' would reload the previous task's values.
                init_strategy = 'task-special'
                if prev_task_id is not None and init_strategy == 'inherit':
                    self.img_prompter.pool.load_task_values(prev_task_id)
                else:
                    # Fresh pool: zero values and small random keys.
                    self.img_prompter.pool.values.data.zero_()
                    self.img_prompter.pool.keys.data = torch.randn_like(self.img_prompter.pool.keys) * 0.02

                self.train_one_task(cfg, task_id, datasets, metric_logger)

                self.img_prompter.pool.save_task_values(task_id)

                if prev_task_id is not None:
                    fuse_alpha = 0.0
                    self.img_prompter.pool.fuse_values(prev_task_id, task_id, alpha=fuse_alpha, strategy=init_strategy, fuse_keys=True)

                # ================= Diversity Logging =================
                self.img_prompter.pool.log_diversity(task_id=task_id)

                # ================= Load Best Model (if val exists) =================
                if datasets['val']:
                    keys = self.load_model(cfg, task_id)
                    log = f"Load best epoch weight (epoch {self.best_epoch}), parameters {keys}."
                    print(log)
                    with open(osp.join(cfg.log_path, 'output.txt'), 'a') as f:
                        f.write(log + '\n')

                # ================= Evaluation =================
                print(f"Evaluation for task {task_id} has started.")
                self.eval_all(cfg, datasets, metric_logger, metric_writer, acc_list, test_cur_train_task_id=task_id)

            self.img_prompter.pool.plot_diversity(cfg.log_path)

            res = cal_MTIL_metrics(acc_list)
            for key in ("transfer", "avg", "last", "results_mean"):
                metric_writer.write(json.dumps(res[key]) + '\n')
            metric_writer.flush()

    def train_one_task(self, cfg, task_id, datasets, metric_logger):
        """Train the model on one task and record its feature statistics.

        Steps:
          1. Build the optimizer and optional cosine scheduler.
          2. Run the frozen ("ori") image encoder over the training set and
             store the mean and covariance of its normalized features in the
             model's ``means`` / ``covars`` buffers.
          3. Load the cached class prototypes for the dataset.
          4. Train for ``max_epoch`` epochs with cross-entropy plus weighted
             prototype and prompt losses; after each epoch, if a validation
             set exists, save a checkpoint whenever validation accuracy improves.

        Args:
            cfg: Config; reads ``CoFiCL.optim.*``, ``dataset``, ``debug_flag``
                and ``log_path``.
            task_id: Index of the task to train (``-1`` forces 15 epochs).
            datasets: Mapping with ``'train'`` and ``'val'`` entries.
            metric_logger: Logger receiving the training predictions.

        Raises:
            NotImplementedError: For an unknown optimizer or scheduler name.
            ValueError: If ``cfg.dataset`` has no prototype cache layout.
        """
        max_epoch = 15 if task_id == -1 else cfg.CoFiCL.optim.max_epoch

        train_dataset, val_dataset = datasets['train'], datasets['val']
        train_loader = self.get_dataloader(cfg, train_dataset, task_id, is_train=True)
        self.update_classnames(task_id)
        self.model.train()

        # === Optimizer ===
        per_epoch_steps = len(train_loader)
        if cfg.CoFiCL.optim.name == 'SGD':
            optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.CoFiCL.optim.lr, weight_decay=cfg.CoFiCL.optim.weight_decay, momentum=getattr(cfg.CoFiCL.optim, "momentum", 0.9))
        elif cfg.CoFiCL.optim.name == 'AdamW':
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.CoFiCL.optim.lr, weight_decay=cfg.CoFiCL.optim.weight_decay, betas=getattr(cfg.CoFiCL.optim, "betas", (0.9, 0.999)), eps=1e-6)
        else:
            raise NotImplementedError

        # === LR Scheduler ===
        if cfg.CoFiCL.optim.lr_scheduler == 'cosine':
            lr_warmup_step = cfg.CoFiCL.optim.warmup_epoch * per_epoch_steps
            scheduler = build_cosine_scheduler(optimizer, lr=cfg.CoFiCL.optim.lr, total_step=max_epoch * per_epoch_steps, lr_warmup_step=lr_warmup_step)
        elif cfg.CoFiCL.optim.lr_scheduler == 'no':
            scheduler = None
        else:
            raise NotImplementedError

        self.best_epoch = -1
        self.best_acc = -1

        # === Gaussian statistics of the frozen image features ===
        # Batches are collected in a list and concatenated once, instead of
        # re-allocating the growing tensor on every iteration.
        feature_chunks = [torch.empty([0, self.model_wo_dp.vis_dim], dtype=self.model_wo_dp.dtype, device=self.device)]
        with torch.no_grad():
            for batch_idx, sample in enumerate(train_loader):
                # Debug mode processes only the first two batches.
                if batch_idx >= 2 and cfg.debug_flag == True:
                    break
                inputs, _, _ = parse_sample(sample, is_train=False, task_id=task_id, cfg=cfg)
                image_features, _, _ = self.model_wo_dp.image_encoder_ori(inputs.type(self.model_wo_dp.dtype).to(self.device))
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                feature_chunks.append(image_features.detach())
        all_image_features = torch.cat(feature_chunks, dim=0)

        all_image_features = all_image_features.type(torch.float)  # to avoid precision problems
        mean = all_image_features.mean(dim=0)
        delta = (all_image_features - mean.unsqueeze(0))
        covar = delta.t() @ delta / (all_image_features.size(0) - 1)
        covar += torch.eye(covar.size(0), device=covar.device, dtype=torch.float)*1e-7  # to avoid precision problems
        self.model_wo_dp.means[task_id] = mean
        self.model_wo_dp.covars[task_id] = covar
        self.model_wo_dp.task_learnt += 1

        # === Class prototypes ===
        # The X-TAIL branch previously referenced an undefined ``path``.
        # NOTE(review): it presumably shares the cache root of the MTIL
        # branch -- confirm the on-disk layout.
        proto_root = "proto_cache"
        if cfg.dataset == "MTIL":
            with open(f"{proto_root}/task{task_id}/class_proto_task{task_id}.json", "r") as f:
                class_proto_map = json.load(f)
            # Load on CPU first so the cache does not depend on the device it
            # was saved from; it is moved to the training device right after.
            prototypes = torch.load(f"{proto_root}/task{task_id}/prototypes_task{task_id}.pt", map_location="cpu").to(self.device)
        elif cfg.dataset == "X-TAIL":
            task_list = [0, 1, 3, 4, 5, 6, 7, 8, 9, 10]
            prototypes, class_proto_map = list(), dict()
            offset = 0
            for i in task_list:
                with open(f"{proto_root}/task{i}/class_proto_task{i}.json", "r") as f:
                    class_proto_task = json.load(f)
                # Shift class keys and prototype ids by the number of
                # prototypes already collected from earlier tasks.
                for key, value in class_proto_task.items():
                    class_proto_map[int(key) + offset] = {"proto_ids": [(pid + offset) for pid in value["proto_ids"]]}
                task_prototypes = torch.load(f"{proto_root}/task{i}/prototypes_task{i}.pt", map_location="cpu")
                offset += task_prototypes.shape[0]
                prototypes.append(task_prototypes)
            prototypes = torch.cat(prototypes, dim=0).to(self.device)
        else:
            raise ValueError(f"No prototype cache layout is defined for dataset {cfg.dataset!r}")

        # === Training loop ===
        for epoch in tqdm(range(max_epoch)):
            main_loss_tot = 0
            proto_loss_tot = 0.0
            prompt_loss_tot = 0.0
            total_loss_tot = 0.0
            loss_num = 0
            for idx, sample in enumerate(train_loader):
                if idx >= 2 and cfg.debug_flag == True:
                    break
                if scheduler:
                    cur_iter_idx = epoch*per_epoch_steps + idx
                    scheduler.step(cur_iter_idx)

                inputs, targets, task_ids = parse_sample(sample, is_train=True, task_id=task_id, cfg=cfg)
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                positives = build_positive_prototypes_for_images(sample[1], class_proto_map)
                res = self.model(inputs, task_ids, cur_train_task_id=task_id, positives=positives, prototypes=prototypes, train_flag=True)
                outputs = res["outputs"]
                loss_main = F.cross_entropy(outputs, targets)
                loss_prototype = res["prototype_loss"]
                prompt_loss = res["prompt_loss"]
                loss = loss_main + 0.1*loss_prototype + 0.1*prompt_loss
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                optimizer.step()

                main_loss_tot += loss_main.item()
                proto_loss_tot += loss_prototype.item()
                prompt_loss_tot += prompt_loss.item()
                total_loss_tot += loss.item()
                loss_num += 1

                metric_logger.add([outputs.detach().cpu().argmax(dim=1), targets.cpu(), task_ids], subset="train")

            acc_log = f"\ntask{task_id}_epoch{epoch}:\n"
            acc_log += f"train acc: {metric_logger.online_accuracy}"
            metric_logger.end_epoch()

            # Guard against an empty loader, which would divide by zero.
            steps = max(loss_num, 1)
            loss_log = (
                    f"avg main loss {round(main_loss_tot / steps, 5)} | "
                    f"avg proto loss {round(proto_loss_tot / steps, 5)} | "
                    f"avg prompt loss {round(prompt_loss_tot / steps, 5)} | "
                    f"avg total loss {round(total_loss_tot / steps, 5)}"
                )
            with open(osp.join(cfg.log_path, 'output.txt'), 'a') as f:
                f.write(acc_log + '\n')
                f.write(loss_log + '\n')

            # === Validation / best-checkpoint selection ===
            if val_dataset:
                self.model.eval()
                self.update_classnames(task_id)
                val_loader = self.get_dataloader(cfg, val_dataset, task_id, is_train=False)
                cur_right = torch.FloatTensor([0]).to(self.device)
                cur_all = torch.FloatTensor([0]).to(self.device)
                with torch.no_grad():
                    for val_idx, sample in enumerate(val_loader):
                        if val_idx >= 2 and cfg.debug_flag == True:
                            break
                        inputs, targets, task_ids = parse_sample(sample, is_train=False, task_id=task_id, cfg=cfg)
                        inputs, targets = inputs.to(self.device), targets.to(self.device)
                        res = self.model(inputs, task_ids, cur_train_task_id=task_id)
                        outputs = res["outputs"]
                        cur_right += torch.sum((outputs.argmax(dim=1)==targets))
                        cur_all += targets.size(0)
                cur_acc = cur_right/cur_all
                if cur_acc > self.best_acc:
                    self.best_epoch = epoch
                    self.best_acc = cur_acc
                    self.save_model(cfg, task_id)
                self.update_classnames(task_id)
                self.model.train()

    def eval_all(self, cfg, datasets, metric_logger, metric_writer, acc_list, test_cur_train_task_id=-1):
        """Evaluate every task on its test split and log per-task accuracy.

        Args:
            cfg: Config; reads ``dataset``, ``nb_task`` and ``log_path``.
            datasets: Mapping whose ``'test'`` entry holds the per-task test sets.
            metric_logger: Logger accumulating predictions; its task is closed
                at the end of this call.
            metric_writer: Open text file receiving one JSON line of accuracies.
            acc_list: List that the per-task accuracy vector is appended to.
            test_cur_train_task_id: Id of the most recently trained task,
                forwarded to ``evaluate`` and to the pool plotting call.
        """
        test_sets = datasets['test']
        self.model.eval()

        # X-TAIL activates the class names once, up to the last task; MTIL
        # switches them for every evaluated task.
        if cfg.dataset == "X-TAIL":
            self.update_classnames(cfg.nb_task - 1)

        for cur_task in tqdm(range(cfg.nb_task)):
            if cfg.dataset == "MTIL":
                self.update_classnames(cur_task)
            task_loader = self.get_dataloader(cfg, test_sets, cur_task, is_train=False)
            self.evaluate(cfg, task_loader, cur_task, metric_logger, test_cur_train_task_id)
            self.img_prompter.pool.plot_task_all_pools(train_task_id=test_cur_train_task_id, test_task_id=cur_task, normalize=True, log_path=cfg.log_path)

        per_task_acc = metric_logger.accuracy_per_task
        acc_list.append(per_task_acc)
        summary = {'acc_per_task': [round(100 * acc_t, 2) for acc_t in per_task_acc]}
        metric_writer.write(json.dumps(summary) + '\n')
        metric_writer.flush()
        print(summary)
        metric_logger.end_task()

    def evaluate(self, cfg, loader, task_id, metric_logger=None, test_cur_train_task_id=-1):
        right_num = 0
        sample_num = 0
        with torch.no_grad():
            debug_for_train_cnt = 0 
            for sample in loader:
                debug_for_train_cnt += 1 
                if debug_for_train_cnt > 2: 
                    if cfg.debug_flag == True:
                        break 
                inputs, targets, task_ids = parse_sample(sample, is_train=False, task_id=task_id, cfg=cfg)
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                res = self.model(inputs, test_cur_train_task_id=test_cur_train_task_id, test_cur_test_task_id=task_id, cfg=cfg, target=targets, debug_for_train_cnt=debug_for_train_cnt) 
                outputs = res["outputs"]
                if metric_logger:
                    metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
                right_num += torch.sum(outputs.argmax(dim=1) == targets).item()
                sample_num += inputs.size(0)
        return right_num / sample_num

    def get_dataloader(self, cfg, dataset, task_id, is_train):
        batch_size = cfg.CoFiCL.optim.batch_size
        if isinstance(dataset, list):
            if cfg.CoFiCL.batchwise_prompt and (not is_train):
                batch_size *= 2
            loader = DataLoader(dataset[task_id], batch_size=int(batch_size), shuffle=is_train, num_workers=8)
        else:
            raise NotImplementedError
        return loader

    def update_classnames(self, task_id):
        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.update_classnames(task_id)
        else:
            self.model.update_classnames(task_id)
