import os.path as osp
import os
import json
import statistics
from tqdm import tqdm
import pandas as pd
import openpyxl

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

from .utils import cosine_loss_3d, cal_MTIL_metrics

from continuum.metrics import Logger
from TSPD.utils import build_cosine_scheduler
from TSPD.datasets import parse_sample
from TSPD.svd import compute_svd_base, compute_task_id
from TSPD.utils import get_transform, set_random_seed
from TSPD.MTIL_datasets.tiny_imagenet import ImagenetTiny
import sys

# Module-level tokenizer instance; PromptProcessor uses it to count the tokens of each class prompt.
_tokenizer = _Tokenizer()


def load_clip_to_cpu(cfg, with_ori=False):
    """Build the CLIP model(s) used for training from the pretrained checkpoint.

    The checkpoint registered under ``cfg.model_backbone_name`` is obtained
    through ``clip._download`` and read with ``map_location="cpu"``.

    Args:
        cfg: Configuration object. Reads ``model_backbone_name``, ``nb_task``,
            ``temperature`` and the ``TSPD.prompt_depth_vision``,
            ``TSPD.prompt_depth_text``, ``TSPD.n_ctx_vision`` and
            ``TSPD.n_ctx_text`` fields.
        with_ori: When true, additionally build a second model whose design
            details set every prompt depth and context length to 0.

    Returns:
        The model built with the configured design details, or the tuple
        ``(train_model, ori_model)`` when ``with_ori`` is true.
    """
    checkpoint_url = clip._MODELS[cfg.model_backbone_name]
    checkpoint_path = clip._download(checkpoint_url)

    try:
        # First attempt: treat the checkpoint as a JIT archive.
        model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        # Not loadable as a JIT archive: read the file with torch.load instead.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    else:
        state_dict = None

    prompted_details = {
        "vision_depth": cfg.TSPD.prompt_depth_vision,
        "language_depth": cfg.TSPD.prompt_depth_text,
        "vision_ctx": cfg.TSPD.n_ctx_vision,
        "language_ctx": cfg.TSPD.n_ctx_text,
        "pool_size": cfg.nb_task,
        "temperature": cfg.temperature,
    }
    train_model = clip.build_model(state_dict or model.state_dict(), prompted_details)

    if not with_ori:
        return train_model

    # Second model built from the same weights with all prompt settings zeroed.
    plain_details = {
        "vision_depth": 0,
        "language_depth": 0,
        "vision_ctx": 0,
        "language_ctx": 0,
    }
    ori_model = clip.build_model(state_dict or model.state_dict(), plain_details)
    return train_model, ori_model


class TextEncoder(nn.Module):
    """Text branch of a CLIP model operating on already-embedded prompts.

    Re-uses the transformer, positional embedding, final layer norm and text
    projection of the given ``clip_model``.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.dtype = clip_model.dtype
        self.positional_embedding = clip_model.positional_embedding
        self.transformer = clip_model.transformer
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts, indice, ori_features=None, ori_output=False):
        """Encode prompt embeddings into one projected feature vector per prompt.

        Args:
            prompts: Prompt embeddings; the positional embedding is added to them.
            tokenized_prompts: Token ids of the prompts; their per-row argmax
                selects which sequence position is read out.
            indice: Forwarded to the transformer unchanged.
            ori_features: Forwarded to the transformer unchanged.
            ori_output: Forwarded to the transformer as a keyword argument.

        Returns:
            The selected hidden state of every prompt multiplied by the text
            projection.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        # The transformer works on sequence-first tensors: NLD -> LND and back.
        hidden = self.transformer(hidden.permute(1, 0, 2), indice, ori_features, ori_output=ori_output)
        hidden = self.ln_final(hidden.permute(1, 0, 2)).type(self.dtype)

        # Read out, for every prompt, the position holding its largest token id.
        row_ids = torch.arange(hidden.shape[0])
        readout_pos = tokenized_prompts.argmax(dim=-1)
        return hidden[row_ids, readout_pos] @ self.text_projection


class PromptProcessor(nn.Module):
    """Stores the token embeddings of all class prompts and serves one task's subset.

    Class names are supplied per task. Each name is expanded with that task's
    template, tokenized with ``clip.tokenize`` and embedded once with
    ``clip_model.token_embedding``. ``update_classnames`` selects which task's
    rows are returned by ``forward``; until it is called the buffers cover the
    prompts of every task.
    """

    def __init__(self, cfg, classnames, templates, clip_model):
        """Tokenize and embed every class prompt.

        Args:
            cfg: Configuration; only ``cfg.input_size[0]`` is read here.
            classnames: List with one list of class names per task.
            templates: Per-task callables turning a class name into prompt text.
            clip_model: CLIP model providing ``visual.input_resolution``,
                ``token_embedding`` and ``dtype``.

        Raises:
            NotImplementedError: If ``classnames[0]`` is not a list.
        """
        super().__init__()

        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.input_size[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if not isinstance(classnames[0], list):
            raise NotImplementedError

        self.n_cls = 0
        self.class_ids_per_task = []
        self.classnames = []
        for task_idx, task_classnames in enumerate(classnames):
            cur_n = len(task_classnames)
            # Global class ids of this task are contiguous and follow the previous task's ids.
            self.class_ids_per_task.append(list(range(self.n_cls, self.n_cls + cur_n)))
            self.classnames += [templates[task_idx](name) for name in task_classnames]
            self.n_cls += cur_n

        self.cur_n_cls = 0
        self.classnames = [name.replace("_", " ") for name in self.classnames]
        self.all_name_lens = [len(_tokenizer.encode(name)) for name in self.classnames]
        # Defined up front so the attribute exists before update_classnames() is
        # called; like the buffers below it initially covers every class.
        self.name_lens = list(self.all_name_lens)

        self.register_buffer("all_tokenized_prompts", torch.cat([clip.tokenize(p) for p in self.classnames]))
        with torch.no_grad():
            self.register_buffer("all_embedding", clip_model.token_embedding(self.all_tokenized_prompts).type(clip_model.dtype))
        self.register_buffer("token_prefix", self.all_embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", self.all_embedding[:, 1:, :])  # CLS, EOS
        self.register_buffer("tokenized_prompts", self.all_tokenized_prompts.clone())

    def forward(self):
        """Return ``(prompts, tokenized_prompts)`` for the currently selected classes."""
        prefix = self.token_prefix  # [n_cls, 1, ctx_dim]
        suffix = self.token_suffix  # [n_cls, ..., ctx_dim]
        prompts = torch.cat([prefix, suffix], dim=1)  # [n_cls, 77, ctx_dim]
        tokenized_prompts = self.tokenized_prompts  # [n_cls, 77] or [n_cls, 77, tkn_dim]
        return prompts, tokenized_prompts

    def update_classnames(self, task_id):
        """Restrict the served prompts to the classes of task ``task_id``.

        Args:
            task_id: Index into ``class_ids_per_task``.
        """
        class_idx = self.class_ids_per_task[task_id]
        # torch.long is accepted as an index dtype by every PyTorch release,
        # whereas int32 index tensors are rejected by older ones.
        class_idx_tensor = torch.tensor(class_idx, dtype=torch.long, device=self.all_embedding.device)
        self.token_prefix = self.all_embedding[class_idx_tensor, :1, :]
        self.token_suffix = self.all_embedding[class_idx_tensor, 1:, :]
        self.tokenized_prompts = self.all_tokenized_prompts[class_idx_tensor]
        self.name_lens = [self.all_name_lens[idx] for idx in class_idx]
        self.cur_n_cls = len(class_idx)


class CustomCLIP(nn.Module):
    """CLIP classifier combining a prompted model with an unmodified reference encoder.

    The image encoder of ``clip_model_ori`` produces reference image features
    on every forward pass. When no task id is supplied they are passed to
    ``compute_task_id`` together with the stored per-task ``singular_matrix``
    entries to pick the task index. The encoders of ``clip_model`` then
    produce the features used for the logits.
    """

    def __init__(self, cfg, classnames, templates, clip_model, clip_model_ori=None):
        """
        Args:
            cfg: Configuration; reads ``nb_task``, ``threshold``, ``log_path``,
                ``energy``, ``TSPD.prompt_depth_vision`` and
                ``TSPD.batchwise_prompt`` (plus what ``PromptProcessor`` reads).
            classnames: Per-task lists of class names.
            templates: Per-task prompt templates.
            clip_model: Model providing the trainable encoders.
            clip_model_ori: Model whose ``visual`` encoder yields the reference
                image features. Required despite the ``None`` default.

        Raises:
            ValueError: If ``clip_model_ori`` is ``None``.
        """
        super().__init__()
        if clip_model_ori is None:
            # forward() unconditionally runs the reference encoder, so fail here
            # with a clear message instead of an AttributeError on ``None.visual``.
            raise ValueError("clip_model_ori is required: its visual encoder is used in forward()")
        self.prompt_processor = PromptProcessor(cfg, classnames, templates, clip_model)
        self.image_encoder = clip_model.visual
        self.image_encoder_ori = clip_model_ori.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.vis_dim = clip_model.visual.output_dim
        self.pool_size = cfg.nb_task
        self.visual_prompt = cfg.TSPD.prompt_depth_vision > 0
        self.batchwise_prompt = cfg.TSPD.batchwise_prompt
        self.threshold = cfg.threshold
        self.log_path = cfg.log_path
        self.energy = cfg.energy

        # Per-task statistics; allocated uninitialised with one slot per task.
        self.register_buffer("singular_matrix", torch.empty(self.pool_size, self.vis_dim, self.vis_dim, dtype=torch.float))
        self.register_buffer("means", torch.empty(self.pool_size, self.vis_dim, dtype=torch.float))
        self.register_buffer("covars", torch.empty(self.pool_size, self.vis_dim, self.vis_dim, dtype=torch.float))
        # Number of leading slots of the buffers above that are in use.
        self.register_buffer("task_learnt", torch.tensor(0, dtype=torch.int))

    def forward(self, image, task_id=None, ori_output=False):
        """Compute class logits for a batch of images.

        Args:
            image: Batch of input images.
            task_id: Task index to use. When ``None`` the index is chosen from
                the reference image features via ``compute_task_id``.
            ori_output: Forwarded to the text encoder and, when visual prompts
                are enabled, to the image encoder.

        Returns:
            dict with ``"image_features"`` (normalised reference features),
            ``"indice"`` (task index used), ``"outputs"`` (logits) and, when the
            index was selected automatically, ``"raw_indice"`` (per-sample ids).

        Raises:
            NotImplementedError: If ``task_id`` is ``None`` while
                ``batchwise_prompt`` is disabled; no index is defined for that
                combination.
        """
        res = {}
        with torch.no_grad():
            image_features = self.image_encoder_ori(image.type(self.dtype), ori_output=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            res["image_features"] = image_features.detach()

        if task_id is not None:
            indice = task_id
        else:
            if not self.batchwise_prompt:
                # Previously this path fell through and crashed with an
                # UnboundLocalError on ``indice``; report it explicitly.
                raise NotImplementedError(
                    "task_id=None requires batchwise_prompt=True; per-sample prompt selection is not implemented"
                )
            # Only the bases of tasks learnt so far take part in the selection.
            task_ids_svd, _ = compute_task_id(image_features, self.singular_matrix[:self.task_learnt.item()], threshold=self.threshold)
            # Majority vote: the most frequent per-sample id is used for the whole batch.
            prompt_id, id_counts = torch.unique(task_ids_svd, return_counts=True, sorted=True)
            _, major_idx = torch.topk(id_counts, k=1)
            indice = prompt_id[major_idx].item()
            res["raw_indice"] = task_ids_svd
        res["indice"] = indice

        prompts, tokenized_prompts = self.prompt_processor()  # [bs*n_cls, 77, ctx_dim]
        text_features = self.text_encoder(prompts, tokenized_prompts, indice, ori_output=ori_output)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        if self.visual_prompt:
            # Replace the reference features with those of the prompted image encoder.
            image_features = self.image_encoder(image.type(self.dtype), indice, res["image_features"], ori_output=ori_output)  # [bs, model_dim]
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()  # [bs, n_cls]
        res["outputs"] = logits
        return res

    def update_classnames(self, task_id):
        """Select the class prompts of task ``task_id`` in the prompt processor."""
        self.prompt_processor.update_classnames(task_id)


class TSPD:
    def __init__(self, cfg, device, classes_names, templates, load_file=None):
        """Create the trainer; all arguments are handed to ``build_model`` unchanged."""
        self.build_model(
            cfg=cfg,
            device=device,
            classes_names=classes_names,
            templates=templates,
            load_file=load_file,
        )


    def build_model(self, cfg, device, classes_names, templates, load_file=None):
        """Construct the CustomCLIP model, freeze it and place it on the device(s).

        Only parameters whose name contains "lora" or "gumbel" are left
        untouched; every other parameter has its gradient disabled. The names
        of the parameters that remain trainable are appended to
        ``<cfg.log_path>/output.txt``.

        Args:
            cfg: Configuration; reads ``model_backbone_name`` and ``log_path``
                here and is passed on to the model constructors.
            device: Sequence of devices. The model is moved to ``device[0]``
                and wrapped in ``DataParallel`` when more than one is given.
            classes_names: Per-task lists of class names.
            templates: Per-task prompt templates.
            load_file: Optional checkpoint path loaded before the model is
                moved to the device.
        """
        print(f"Loading CLIP (backbone: {cfg.model_backbone_name})")
        clip_model, clip_model_ori = load_clip_to_cpu(cfg, with_ori=True)
        print("Building custom CLIP")
        model = CustomCLIP(cfg, classes_names, templates, clip_model, clip_model_ori)
        print("Turning off gradients in both the image and the text encoder")
        names_to_update = ("lora", "gumbel")

        for name, param in model.named_parameters():
            if not any(tag in name for tag in names_to_update):
                param.requires_grad_(False)

        enabled = {name for name, param in model.named_parameters() if param.requires_grad}
        para_log = f"Parameters to be updated: {enabled}"
        # Context manager so the log file is closed even if the write fails.
        with open(osp.join(cfg.log_path, 'output.txt'), 'a') as f:
            f.write(para_log + '\n')

        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("Total trainable params: {}.".format(total_params))

        self.model = model
        self.devices = device
        self.device = device[0]

        if load_file:
            self.load_model(None, None, load_file)

        self.model.to(device[0])
        if len(device) > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=device)

        # Handle to the underlying module regardless of DataParallel wrapping.
        self.model_wo_dp = self.model.module if len(device) > 1 else self.model


    def save_model(self, cfg, task_id):
        save_dict = {}
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                save_dict[name] = para
        for name, para in self.model.named_buffers():  # for gaussian parameters
            if "means" in name or "covars" in name or "task_learnt" in name or "singular_matrix" in name:
                save_dict[name] = para
        save_dir = os.path.join(cfg.log_path, 'ckpt')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(save_dict, os.path.join(save_dir, f'task_{task_id}.pt'))
    
    
    def load_model(self, cfg, task_id, load_file=None):
        if load_file is None:
            load_file = os.path.join(cfg.log_path, 'ckpt', f'task_{task_id}.pt')
        if not osp.exists(load_file):
            raise FileNotFoundError('Model not found at "{}"'.format(load_file))

        state_dict = torch.load(load_file, map_location="cpu")

        print(f"Loading weights from {load_file}")
        self.model.load_state_dict(state_dict, strict=False)

        return [i for i in state_dict.keys()]    

    def train_and_eval(self, cfg, datasets):
        """Run the experiment selected by ``cfg`` and log metrics to ``metrics.json``.

        Exactly one of three modes is executed:

        * ``cfg.zero_shot``: evaluate every task with ``ori_output=True`` and stop.
        * ``cfg.eval_only``: run ``eval_all`` once and stop.
        * otherwise: for each task train it, reload the best validation
          checkpoint when a validation set exists, evaluate all tasks, and
          finally write the aggregated ``cal_MTIL_metrics`` results.

        Args:
            cfg: Configuration; reads ``log_path``, ``zero_shot``, ``eval_only``,
                ``nb_task``, ``gumbel_lr`` and ``TSPD.optim.lr``.
            datasets: Mapping with ``'train'``, ``'val'`` and ``'test'`` entries.
        """
        acc_list = []
        metric_logger = Logger(list_subsets=["train", "test"])
        # The context manager closes the metrics file on every exit path,
        # including the early returns and any exception raised while training.
        with open(os.path.join(cfg.log_path, 'metrics.json'), 'w') as metric_writer:
            if cfg.zero_shot:
                self._zero_shot_eval(cfg, datasets, metric_logger, metric_writer, acc_list)
                return

            if cfg.eval_only:
                self.eval_all(cfg, datasets, metric_logger, metric_writer, acc_list)
                return

            for task_id in range(cfg.nb_task):
                print("Training for task {} has started. learning rate for lora is {}, gumbel_fc is {}.".format(task_id, cfg.TSPD.optim.lr[task_id], cfg.gumbel_lr))
                self.train_one_task(cfg, task_id, datasets, metric_logger)

                if datasets['val']:
                    # Continue from the weights of the best validation epoch.
                    keys = self.load_model(cfg, task_id)
                    log = f"Load best epoch weight (epoch {self.best_epoch}), parameters {keys}."
                    with open(osp.join(cfg.log_path, 'output.txt'), 'a') as f:
                        f.write(log + '\n')

                print(f"Evaluation for task {task_id} has started.")
                self.eval_all(cfg, datasets, metric_logger, metric_writer, acc_list, global_task_id=task_id)

            res = cal_MTIL_metrics(acc_list)
            for key in ("transfer", "avg", "last", "results_mean"):
                metric_writer.write(json.dumps(res[key]) + '\n')
            metric_writer.flush()

    def _zero_shot_eval(self, cfg, datasets, metric_logger, metric_writer, acc_list):
        """Evaluate every task with ``ori_output=True`` and log per-task accuracy."""
        with torch.no_grad():
            for cur_task in tqdm(range(cfg.nb_task)):
                self.update_classnames(cur_task)
                eval_loader = self.get_dataloader(cfg, datasets['test'], cur_task, is_train=False)
                for sample in eval_loader:
                    inputs, targets, task_ids = parse_sample(sample, is_train=False, task_id=cur_task, cfg=cfg)
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    res = self.model(inputs, cur_task, ori_output=True)
                    outputs = res["outputs"]
                    metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
            cur_all_task_acc = metric_logger.accuracy_per_task
            acc_list.append(cur_all_task_acc)
            log = {'acc_per_task': [round(100 * acc_t, 2) for acc_t in cur_all_task_acc]}
            metric_writer.write(json.dumps(log) + '\n')
            metric_writer.flush()
            print(log)

    def train_one_task(self, cfg, task_id, datasets, metric_logger):
        """Train the "lora" and "gumbel" parameters on a single task.

        Steps, in order:

        1. Build one SGD optimizer (and optional scheduler) for parameters
           whose name contains "lora" and another for those containing "gumbel".
        2. Encode the whole training set with the original image encoder, pass
           the normalised features to ``compute_svd_base`` and store the result
           in ``singular_matrix[task_id]``; increment ``task_learnt``.
        3. Train with cross-entropy for ``cfg.TSPD.optim.max_epoch`` epochs,
           appending accuracy and loss to ``<cfg.log_path>/output.txt``.
        4. When a validation set is given, evaluate after every epoch and save
           a checkpoint whenever validation accuracy improves.

        Args:
            cfg: Configuration; reads ``log_path``, ``gumbel_lr`` and the
                ``TSPD.optim`` fields ``name``, ``lr``, ``weight_decay``,
                ``lr_scheduler`` and ``max_epoch``.
            task_id: Index of the task to train.
            datasets: Mapping with ``'train'`` and ``'val'`` entries.
            metric_logger: Logger receiving the training predictions.

        Raises:
            NotImplementedError: For an optimizer other than ``'SGD'`` or a
                scheduler other than ``'cosine'`` / ``'no'``.
        """
        train_dataset, val_dataset = datasets['train'], datasets['val']
        train_loader = self.get_dataloader(cfg, train_dataset, task_id, is_train=True)
        self.update_classnames(task_id)
        self.model.train()
        log_file = osp.join(cfg.log_path, 'output.txt')

        per_epoch_steps = len(train_loader)
        if cfg.TSPD.optim.name == 'SGD':
            lora_params = [p for n, p in self.model.named_parameters() if "lora" in n]
            gumbel_params = [p for n, p in self.model.named_parameters() if "gumbel" in n]
            lora_optimizer = torch.optim.SGD(
                lora_params,
                lr=cfg.TSPD.optim.lr[task_id],
                weight_decay=cfg.TSPD.optim.weight_decay
            )
            gumbel_optimizer = torch.optim.SGD(
                gumbel_params,
                lr=cfg.gumbel_lr,
                weight_decay=cfg.TSPD.optim.weight_decay
            )
        else:
            raise NotImplementedError

        if cfg.TSPD.optim.lr_scheduler == 'cosine':
            total_step = cfg.TSPD.optim.max_epoch * per_epoch_steps
            lora_scheduler = build_cosine_scheduler(lora_optimizer, lr=cfg.TSPD.optim.lr[task_id], total_step=total_step)
            gumbel_scheduler = build_cosine_scheduler(gumbel_optimizer, lr=cfg.gumbel_lr, total_step=total_step)
        elif cfg.TSPD.optim.lr_scheduler == 'no':
            lora_scheduler = None
            gumbel_scheduler = None
        else:
            raise NotImplementedError
        self.best_epoch = -1
        self.best_acc = -1

        # Collect the original-encoder features of the whole training set.
        all_image_features = torch.empty([0, self.model_wo_dp.vis_dim], dtype=self.model_wo_dp.dtype, device=self.device)
        with torch.no_grad():
            for sample in train_loader:
                inputs, _, _ = parse_sample(sample, is_train=False, task_id=task_id, cfg=cfg)
                image_features = self.model_wo_dp.image_encoder_ori(inputs.type(self.model_wo_dp.dtype).to(self.device), ori_output=True)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                all_image_features = torch.cat([all_image_features, image_features.detach()], dim=0)
        all_image_features = all_image_features.type(torch.float)  # to avoid precision problems
        # Read ``energy`` from the unwrapped module: a DataParallel wrapper does
        # not expose the wrapped module's plain attributes.
        mu_S = compute_svd_base(all_image_features, energy=self.model_wo_dp.energy)
        self.model_wo_dp.singular_matrix[task_id] = mu_S
        self.model_wo_dp.task_learnt += 1

        for epoch in tqdm(range(cfg.TSPD.optim.max_epoch)):
            main_loss_tot = 0
            loss_num = 0
            for idx, sample in enumerate(train_loader):
                if lora_scheduler and gumbel_scheduler:
                    cur_iter_idx = epoch*per_epoch_steps+idx
                    lora_scheduler.step(cur_iter_idx)
                    gumbel_scheduler.step(cur_iter_idx)
                inputs, targets, task_ids = parse_sample(sample, is_train=True, task_id=task_id, cfg=cfg)
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                res = self.model(inputs, task_id)
                outputs = res["outputs"]
                loss_main = F.cross_entropy(outputs, targets)
                loss = loss_main
                lora_optimizer.zero_grad()
                gumbel_optimizer.zero_grad()
                loss.backward()
                lora_optimizer.step()
                gumbel_optimizer.step()
                main_loss_tot += loss_main.item()
                loss_num += 1
                metric_logger.add([outputs.detach().cpu().argmax(dim=1), targets.cpu(), task_ids], subset="train")

            acc_log = f"\ntask{task_id}_epoch{epoch}:\n"
            acc_log += f"train acc: {metric_logger.online_accuracy}"
            metric_logger.end_epoch()
            loss_log = f"avg main loss {round(main_loss_tot/loss_num, 5)}"
            # One context-managed append so the handle is always closed.
            with open(log_file, 'a') as f:
                f.write(acc_log + '\n')
                f.write(loss_log + '\n')

            if val_dataset:
                self.model.eval()
                self.update_classnames(task_id)
                val_loader = self.get_dataloader(cfg, val_dataset, task_id, is_train=False)
                cur_right = torch.FloatTensor([0]).to(self.device)
                cur_all = torch.FloatTensor([0]).to(self.device)
                with torch.no_grad():
                    for sample in val_loader:
                        inputs, targets, task_ids = parse_sample(sample, is_train=False, task_id=task_id, cfg=cfg)
                        inputs, targets = inputs.to(self.device), targets.to(self.device)
                        res = self.model(inputs, task_id)
                        outputs = res["outputs"]
                        cur_right += torch.sum((outputs.argmax(dim=1)==targets))
                        cur_all += targets.size(0)
                cur_acc = cur_right/cur_all
                # Keep the checkpoint of the epoch with the best validation accuracy.
                if cur_acc > self.best_acc:
                    self.best_epoch = epoch
                    self.best_acc = cur_acc
                    self.save_model(cfg, task_id)
                self.update_classnames(task_id)
                self.model.train()


    def eval_all(self, cfg, datasets, metric_logger, metric_writer, acc_list):
        eval_dataset = datasets['test']
        self.model.eval()
        set_random_seed(cfg.seed)

        for cur_task in tqdm(range(cfg.nb_task)):
            self.update_classnames(cur_task)
            eval_loader = self.get_dataloader(cfg, eval_dataset, cur_task, is_train=False)
            self.evaluate(cfg, eval_loader, cur_task, metric_logger)

        cur_all_task_acc = metric_logger.accuracy_per_task
        acc_list.append(cur_all_task_acc)
        log = {'acc_per_task': [round(100 * acc_t, 2) for acc_t in cur_all_task_acc]}
        metric_writer.write(json.dumps(log) + '\n')
        metric_writer.flush()
        print(log)
        metric_logger.end_task()


    def evaluate(self, cfg, loader, task_id, metric_logger=None):
        """Run the model over ``loader`` without a task id and return the accuracy.

        The model is called with the inputs only, so it selects the task index
        itself.

        Args:
            cfg: Configuration forwarded to ``parse_sample``.
            loader: Iterable of samples for one task.
            task_id: Task index forwarded to ``parse_sample``.
            metric_logger: Optional logger that receives the predictions under
                the ``"test"`` subset.

        Returns:
            float: Fraction of correctly classified samples, or ``0.0`` when
            the loader yields no samples.
        """
        right_num = 0
        sample_num = 0
        with torch.no_grad():
            for sample in loader:
                inputs, targets, task_ids = parse_sample(sample, is_train=False, task_id=task_id, cfg=cfg)
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                res = self.model(inputs)

                predictions = res["outputs"].argmax(dim=1)
                if metric_logger:
                    metric_logger.add([predictions.cpu(), targets.cpu(), task_ids], subset="test")
                right_num += torch.sum(predictions == targets).item()
                sample_num += inputs.size(0)
        if sample_num == 0:
            # Empty loader: report zero accuracy instead of dividing by zero.
            return 0.0
        return right_num / sample_num


    def get_dataloader(self, cfg, dataset, task_id, is_train):
        batch_size = cfg.TSPD.optim.batch_size
        if isinstance(dataset, list):
            if cfg.TSPD.batchwise_prompt and (not is_train):
                batch_size *= 2
            loader = DataLoader(dataset[task_id], batch_size=int(batch_size), shuffle=is_train, num_workers=8)
        else:
            raise NotImplementedError
        return loader


    def update_classnames(self, task_id):
        """Select the class prompts of task ``task_id`` on the underlying model.

        A ``DataParallel`` wrapper is looked through so the call reaches the
        wrapped module.
        """
        wrapped = isinstance(self.model, torch.nn.DataParallel)
        target = self.model.module if wrapped else self.model
        target.update_classnames(task_id)
