import re
from _common import *

import logging
import os
import itertools
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from copy import deepcopy
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import hydra
from omegaconf import DictConfig
from src.draw_distribution import draw_distribution


from clip_checkpoint_path import (
    CHECKPOINT_DIR,
    finetuned_model_path,
    pretrained_model_path,
    sam_retraining_model_path,
)
from timer import timer
from src.adamerging import softmax_entropy
from src.clip_eval import eval_single_dataset, eval_single_dataset_preprocess_head
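# NOTE: ShortestRouteMask, compute_sr_mask and compute_sr_classification_heads_single are not
# imported by name; they are presumably provided by the star imports below
# (src/tasks/shortest_route_mask.py and src/tasks/shortest_route_classification_heads.py,
# not src/optimal_transport_mask.py) — TODO confirm.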
from src.tasks.shortest_route_mask import *
from src.tasks.shortest_route_classification_heads import *
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.heads import get_classification_head
from src.task_vectors import StateDict, TaskVector
from src.task_wise_fusion import *
from src.task_wise_fusion import check_parameterNamesMatch
from src.utils import num_parameters, timeit_context
from tqdm.autonotebook import tqdm

log = logging.getLogger(__name__)
 

class Program:
    def __init__(self, cfg: DictConfig):
        """Bind the run configuration and derive every output location.

        Args:
            cfg: Run configuration. ``cfg.model`` and ``cfg.version`` are read
                here; ``cfg.save`` and ``cfg.data_location`` are overwritten
                in place.
        """
        super().__init__()
        self.cfg = cfg
        cfg.save = str(CHECKPOINT_DIR / cfg.model)
        cfg.data_location = str(DATA_DIR)

        # Value forwarded to compute_sr_mask, and the device used for evaluation.
        self.mask_alpha = 0.84
        self.device = torch.device("cuda:3")

        # Outputs go to RESULTS_DIR/<model>, optionally below version_<n>.
        out_dir = RESULTS_DIR / cfg.model
        if cfg.version is not None:
            out_dir = out_dir / f"version_{cfg.version}"
        os.makedirs(out_dir, exist_ok=True)

        results_name = "clip_optimal_transport.csv"
        self.results_path = out_dir / results_name
        # The checkpoint sits in a sibling directory named after the results
        # file (extension stripped) and reuses its name with a .pt suffix.
        self.ckpt_dir = self.results_path.parent / results_name.split(".")[0]
        self.ckpt_path = self.ckpt_dir / results_name.replace(".csv", ".pt")
        self.individual_results_path = out_dir / "clip_optimal_transport_individuals.csv"
        log.info(f'Results will be saved to "{self.results_path}"')

        # ---- Backward transfer (BWT) bookkeeping ----
        # Accuracy measured right after a task was learned (R_{i,i}),
        # keyed by dataset name.
        self.acc_when_learned = {}
        # Accuracy of the final merged model on each task (R_{T,i}),
        # keyed by dataset name.
        self.final_acc = {}
        # Text file receiving the final BWT value.
        self.bwt_path = out_dir / "clip_optimal_transport_bwt.txt"

        # ---- Milestone BWTs ----
        # BWT_k is computed after the k-th task has been learned, for these k.
        self.bwt_milestones = [8, 14, 20]
        # {k: {dataset_name: accuracy at step k}}
        self.acc_step = {}
        # {k: BWT_k}
        self.bwt_step_values = {}
        self.bwt_steps_path = out_dir / "clip_optimal_transport_bwt_steps.txt"


    @torch.no_grad()
    def _eval_one_dataset_with_state_dict(self, merged_state_dict: dict, dataset_name: str) -> float:
        """Return the top-1 accuracy of the given weights on one dataset.

        ``merged_state_dict`` is loaded as-is into a copy of the pretrained
        model, i.e. it is treated as complete weights rather than a delta.

        Args:
            merged_state_dict: Mapping of parameter name to tensor.
            dataset_name: Key into ``self.classification_heads`` and
                ``self.test_loaders``.

        Returns:
            The ``"top1"`` entry of the evaluation metrics, as a float.
        """
        candidate = deepcopy(self.pretrained_model)

        # Move every tensor to the evaluation device before loading.
        weights_on_device = {}
        for param_name, tensor in merged_state_dict.items():
            weights_on_device[param_name] = tensor.to(self.device)
        candidate.load_state_dict(weights_on_device)
        candidate = candidate.to(self.device)

        head = self.classification_heads[dataset_name]
        loader = self.test_loaders[dataset_name]
        metrics = eval_single_dataset_preprocess_head(
            candidate,
            head,
            dataset_name,
            self.cfg,
            dataloader=loader,
        )
        top1 = metrics["top1"]
        return float(top1)

    def run(self):
        self.load_models()
        self.load_datasets()
        self.Shortest_Route_Fusion()
        self.eval_individuals()

    # ==== NEW: helper to compute & log BWT at step k ====
    def _compute_and_log_bwt_for_step(self, k: int):
        """
        使用 self.acc_step[k] 与 self.acc_when_learned 计算 BWT_k 并写入文件/日志。
        BWT_k = mean_{i in seen[1:]} ( R_{k,i} - R_{i,i} )
        其中 seen = cfg.datasets[:k]
        """
        try:
            datasets_seen = list(self.cfg.datasets)[:k]
            if len(datasets_seen) < 2:
                logging.warning(f"[BWT@{k}] Need at least 2 tasks to compute BWT.")
                return

            if k not in self.acc_step:
                logging.warning(f"[BWT@{k}] No acc_step found for step {k}.")
                return

            terms = []
            for name in datasets_seen[1:]:
                if name in self.acc_when_learned and name in self.acc_step[k]:
                    terms.append(self.acc_step[k][name] - self.acc_when_learned[name])

            if len(terms) == 0:
                logging.warning(f"[BWT@{k}] No valid terms (missing acc_when_learned or acc_step).")
                return

            bwt_k = sum(terms) / len(terms)
            self.bwt_step_values[k] = float(bwt_k)

            # 追加写入步骤汇总文件
            os.makedirs(self.bwt_steps_path.parent, exist_ok=True)
            with open(self.bwt_steps_path, "a") as f:
                f.write(f"{k}-task BWT: {bwt_k:.6f}\n")

            log.info(f"[BWT] {k}-task BWT = {bwt_k:.4f} (avg over {len(terms)} tasks)")
        except Exception as e:
            log.warning(f"[BWT@{k}] Failed to compute/log BWT: {e}")
            
        
    def eval_individuals(self):
        """Evaluate the saved merged checkpoint on every dataset and report BWT.

        Loads ``merged_state_dict`` from ``self.ckpt_path``, evaluates it with
        each dataset's classification head, records the accuracies in
        ``self.final_acc``, writes them to ``self.individual_results_path`` and
        finally writes the backward-transfer value to ``self.bwt_path``:

            BWT = mean over datasets[1:] of (final_acc - acc_when_learned)

        Datasets missing from either record are skipped.
        """
        log.info("Start eval individuals (optimal transport version)")
        cfg = self.cfg

        # NOTE: torch.load unpickles the file; only point this at checkpoints
        # produced by this program.
        checkpoint = torch.load(self.ckpt_path, map_location="cpu")
        merged_state_dict = {
            name: tensor.to(self.device)
            for name, tensor in checkpoint["merged_state_dict"].items()
        }

        # The merged weights are identical for every dataset, so the model is
        # built once instead of being copied and reloaded inside the loop.
        model = deepcopy(self.pretrained_model)
        model.load_state_dict(merged_state_dict)
        model = model.to(self.device)

        results = {"dataset": [], "acc": []}
        for dataset_name in tqdm(cfg.datasets, desc="Evaluating individual models"):
            metrics = eval_single_dataset_preprocess_head(
                model,
                self.classification_heads[dataset_name],
                dataset_name,
                cfg,
                dataloader=self.test_loaders[dataset_name],
            )
            acc = metrics["top1"]
            log.info(f"Eval: dataset: {dataset_name} ACC: {acc:.3f}")
            # Accuracy of the final merged model on this task (R_{T,i}).
            self.final_acc[dataset_name] = float(acc)

            results["dataset"].append(dataset_name)
            results["acc"].append(acc)

        accs = results["acc"]
        if accs:
            log.info(f"Eval: Avg ACC: {sum(accs)/len(accs):.3f}\n")
        else:
            # Avoid a ZeroDivisionError when no dataset is configured.
            log.warning("Eval: no datasets configured; skipping average accuracy.")

        # Print the accuracy of all tasks on one line; the row is terminated
        # with a newline so later output does not run into it.
        log.info("Eval: All tasks accuracy:")
        print(" ".join(f"{acc:.3f}" for acc in accs))
        log.info(f"device: {self.device}, mask_alpha: {self.mask_alpha}")
        pd.DataFrame(results).to_csv(self.individual_results_path, index=False)

        # ---- Backward transfer of the final model ----
        datasets = list(cfg.datasets)
        if len(datasets) < 2:
            log.warning("[BWT] Need at least 2 tasks to compute BWT.")
            return

        terms = [
            self.final_acc[name] - self.acc_when_learned[name]
            for name in datasets[1:]
            if name in self.acc_when_learned and name in self.final_acc
        ]
        if not terms:
            log.warning("[BWT] No valid terms to compute BWT (did not record acc_when_learned or final_acc).")
            return

        bwt = sum(terms) / len(terms)
        try:
            with open(self.bwt_path, "w", encoding="utf-8") as f:
                f.write(f"{bwt:.6f}\n")
        except OSError as e:
            # Only the file write can reasonably fail here; keep it non-fatal.
            log.warning(f'[BWT] Failed to write BWT to "{self.bwt_path}": {e}')
        log.info(f"[BWT] Backward Transfer = {bwt:.4f}  (averaged over {len(terms)} tasks)")


    @torch.no_grad()
    def eval_model_on_datasets(self, epoch_idx: int, results: dict):
        """Evaluate the fused model on every configured dataset.

        A copy of the pretrained model is overlaid with ``self.fused_state_dict``
        and evaluated with each dataset's classification head.

        ``self.fused_state_dict`` is treated as complete weights, not a delta.
        That matches the rest of this class, which loads the same dict directly
        with ``load_state_dict`` and derives task vectors from it by
        subtracting the pretrained weights. Only names that exist in the
        pretrained model are overlaid.

        Args:
            epoch_idx: Value stored in ``results["epoch"]`` for each row.
            results: Optional dict with list-valued ``"epoch"``, ``"dataset"``
                and ``"acc"`` entries that receive one row per dataset; pass
                ``None`` to only log.
        """
        model = deepcopy(self.pretrained_model)

        state_dict = model.state_dict()
        for name, tensor in self.fused_state_dict.items():
            if name in state_dict:
                # Keep each entry on the device of the tensor it replaces so
                # that load_state_dict never mixes devices.
                state_dict[name] = tensor.to(state_dict[name].device)
        model.load_state_dict(state_dict)
        model = model.to(self.device)

        # Evaluate the model built above (the previous code referenced
        # self.model, which is never assigned in this class).
        model.eval()
        total_acc = 0
        for dataset_name in self.cfg.datasets:
            classification_head = self.classification_heads[dataset_name]
            metrics = eval_single_dataset_preprocess_head(
                model,
                classification_head,
                dataset_name,
                self.cfg,
                dataloader=self.test_loaders[dataset_name],
            )
            total_acc += metrics["top1"]
            log.info(f"Eval: dataset: {dataset_name} ACC: {metrics['top1']:.2f}")

            if results is not None:
                results["epoch"].append(epoch_idx)
                results["dataset"].append(dataset_name)
                results["acc"].append(metrics["top1"])

        num_datasets = len(self.cfg.datasets)
        if num_datasets:
            log.info(f"Eval: Avg ACC: {total_acc/num_datasets:.2f}\n")
        else:
            # Avoid a ZeroDivisionError when no dataset is configured.
            log.warning("Eval: no datasets configured; skipping average accuracy.")

    def load_models(self):
        """Load the pretrained model, task vectors and classification heads.

        Sets ``self.pretrained_model`` (loaded on CPU), ``self.task_vectors``
        (one state-dict-like mapping per entry of ``cfg.datasets``) and
        ``self.classification_heads`` (one head per dataset, moved to
        ``self.device``).

        Task-vector placement depends on ``cfg.model``:
        ``ViT-B-32`` / ``ViT-B-16`` go to ``self.device``; ``ViT-L-14`` is
        spread over CUDA devices 4-7 in groups of five; any other model keeps
        its task vectors wherever ``TaskVector`` produced them.
        """
        cfg = self.cfg

        if cfg.sam_retraining:
            log.info("SAM retrained model is used")
            _finetuned_model_path = sam_retraining_model_path
        else:
            _finetuned_model_path = finetuned_model_path
        with timeit_context():
            log.info("load models")
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            pretrained_model: nn.Module = torch.load(pretrained_model_path(cfg.model), map_location="cpu")
            task_vectors: List[StateDict] = [
                TaskVector(
                    pretrained_checkpoint=pretrained_model_path(cfg.model),
                    finetuned_checkpoint=_finetuned_model_path(cfg.model, dataset_name),
                ).vector
                for dataset_name in tqdm(cfg.datasets)
            ]
            # Check that the parameter names in the task vectors match.
            check_parameterNamesMatch(task_vectors)

        self.pretrained_model = pretrained_model
        #! If GPU memory is not enough, comment out the placement below
        if cfg.model == "ViT-B-32" or cfg.model == "ViT-B-16":
            task_vectors = [{k: v.to(self.device) for k, v in tv.items()} for tv in task_vectors]
        elif cfg.model == "ViT-L-14":
            # (slice of task vectors, CUDA device index): five vectors per GPU,
            # with everything from index 15 onwards on the last one.
            shard_plan = (
                (slice(0, 5), 4),
                (slice(5, 10), 5),
                (slice(10, 15), 6),
                (slice(15, None), 7),
            )
            task_vectors = [
                {k: v.cuda(gpu_index, non_blocking=True) for k, v in tv.items()}
                for shard, gpu_index in shard_plan
                for tv in task_vectors[shard]
            ]
        self.task_vectors = task_vectors

        # Heads follow self.device rather than a second hard-coded device
        # string, so the two can never drift apart.
        self.classification_heads = {
            dataset_name: get_classification_head(cfg, dataset_name).to(self.device)
            for dataset_name in cfg.datasets
        }


    def Shortest_Route_Fusion(self):
        """Sequentially fuse all task vectors and save the final merged weights.

        For every consecutive pair ``(i, i+1)`` of task vectors:

        1. ``compute_sr_mask`` produces ``merged_state_dict`` from the pair.
        2. The classification head of task ``i`` (and, on the last step, of
           task ``i+1`` as well) is recomputed against the merged weights.
        3. BWT bookkeeping is recorded (best-effort).
        4. ``task_vectors[i+1]`` is overwritten, on CPU, with
           ``merged_state_dict - pretrained weights`` so the next step fuses
           on top of the current result.
        5. Both masks are redrawn from their initial mask objects.

        The merged weights of the last step are kept in
        ``self.fused_state_dict`` and saved to ``self.ckpt_path`` under the key
        ``"merged_state_dict"``. ``self.merged_state_dict`` is deleted at the
        end of every step.

        Raises:
            ValueError: If fewer than two task vectors are loaded.
        """
        pretrained_model = self.pretrained_model
        datasets = self.cfg.datasets
        num_tasks = len(self.task_vectors)
        if num_tasks < 2:
            # Previously this surfaced as an IndexError while building the masks.
            raise ValueError(
                f"Shortest_Route_Fusion needs at least 2 task vectors, got {num_tasks}."
            )

        for p in pretrained_model.parameters():
            p.detach_().requires_grad_(False)

        # state_dict() builds a fresh mapping on every call; the pretrained
        # model is not modified below (only deep copies are handed out), so
        # take it once instead of twice per parameter per step.
        pretrained_state = pretrained_model.state_dict()

        # Initialize the masks from the first two task vectors.
        self.masks_pre_init = ShortestRouteMask(
            state_dict=self.task_vectors[0],
            init_value=0,
        )
        self.masks_post_init = ShortestRouteMask(
            state_dict=self.task_vectors[1],
            init_value=0,
        )

        self.masks_pre = self.masks_pre_init._draw_mask()
        self.masks_post = self.masks_post_init._draw_mask()

        # Gradually fuse subsequent tasks.
        last_step = num_tasks - 2
        for i in range(num_tasks - 1):
            pre_name = datasets[i]
            post_name = datasets[i + 1]
            pre_task_dataloader = self.shuffled_test_loader_iters[pre_name]
            post_task_dataloader = self.shuffled_test_loader_iters[post_name]

            self.merged_state_dict = compute_sr_mask(
                task_vector_pre=self.task_vectors[i],
                task_vector_post=self.task_vectors[i + 1],
                pretrained_model=deepcopy(pretrained_model),
                masks_pre=self.masks_pre,
                masks_post=self.masks_post,
                pre_task_dataloader=pre_task_dataloader,
                post_task_dataloader=post_task_dataloader,
                mask_alpha=self.mask_alpha,
                device=self.device,
            )

            self.classification_heads[pre_name] = compute_sr_classification_heads_single(
                task_vector_pre=self.task_vectors[i],
                pretrained_model=deepcopy(pretrained_model),
                merged_state_dict=self.merged_state_dict,
                classification_head_pre=self.classification_heads[pre_name],
                pre_task_dataloader=pre_task_dataloader,
                device=self.device,
            )

            if i == last_step:
                # No later step will treat the final task as "pre", so its
                # head is recomputed here.
                self.classification_heads[post_name] = compute_sr_classification_heads_single(
                    task_vector_pre=self.task_vectors[i + 1],
                    pretrained_model=deepcopy(pretrained_model),
                    merged_state_dict=self.merged_state_dict,
                    classification_head_pre=self.classification_heads[post_name],
                    pre_task_dataloader=post_task_dataloader,
                    device=self.device,
                )

            self._record_bwt_trace(i)

            # Replace the next task vector by the delta of the merged weights
            # with respect to the pretrained model (computed on CPU).
            next_vector = self.task_vectors[i + 1]
            for name in self.merged_state_dict:
                if name in pretrained_state:
                    next_vector[name] = self.merged_state_dict[name].cpu() - pretrained_state[name].cpu()

            self.masks_pre = self.masks_pre_init._draw_mask()
            self.masks_post = self.masks_post_init._draw_mask()

            if i == last_step:
                # Last iteration: keep the final merged state dict.
                self.fused_state_dict = self.merged_state_dict

            # Drop per-step references before the next step allocates again.
            del self.merged_state_dict
            # If a RouteMergedModel was stored on self, release it as well.
            if hasattr(self, "route_merged_model"):
                del self.route_merged_model

            torch.cuda.empty_cache()

        # Save checkpoint
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {
                "merged_state_dict": self.fused_state_dict,
            },
            self.ckpt_path,
        )

    def _record_bwt_trace(self, i: int) -> None:
        """Record BWT accuracies after fusing tasks ``i`` and ``i+1`` (best-effort).

        Stores the accuracy of the current ``self.merged_state_dict`` on the
        task that was just learned (``datasets[i+1]``) in
        ``self.acc_when_learned``. If ``i + 2`` (the number of tasks learned
        so far) is one of ``self.bwt_milestones``, it also evaluates the first
        ``i + 2`` datasets into ``self.acc_step`` and computes BWT for that
        step. Evaluation failures are logged as warnings and not propagated.
        """
        datasets = self.cfg.datasets

        just_learned_dataset = datasets[i + 1]
        try:
            acc_ii = self._eval_one_dataset_with_state_dict(self.merged_state_dict, just_learned_dataset)
            self.acc_when_learned[just_learned_dataset] = acc_ii
            log.info(f"[BWT trace] After learning {just_learned_dataset}: ACC_i_i={acc_ii:.3f}")
        except Exception as e:
            log.warning(f"[BWT trace] Eval on {just_learned_dataset} failed during step {i+1}: {e}")

        # i starts at 0, so after fusing (i, i+1) the first i+2 tasks are learned.
        step = i + 2
        total_tasks = len(datasets)
        if step in self.bwt_milestones and step <= total_tasks:
            try:
                # Accuracy of the current merged model on the first `step` tasks (R_{k,i}).
                self.acc_step[step] = {}
                for name in datasets[:step]:
                    acc_ki = self._eval_one_dataset_with_state_dict(self.merged_state_dict, name)
                    self.acc_step[step][name] = float(acc_ki)
                    log.info(f"[BWT trace] Step {step}: R_k,i on {name} = {acc_ki:.3f}")

                # Compute and write BWT_k.
                self._compute_and_log_bwt_for_step(step)
            except Exception as e:
                log.warning(f"[BWT step {step}] evaluation failed: {e}")
        

    def load_datasets(self):
        """Build every configured dataset and its test loaders.

        Sets:
            self.datasets: ``{dataset_name: dataset}``.
            self.test_loaders: ``{dataset_name: dataset.test_loader}``.
            self.shuffled_test_loader_iters: ``{dataset_name: iterator}`` that
                yields batches from ``dataset.test_loader_shuffle`` without
                end, starting a new pass over the loader whenever the previous
                one is exhausted.

        The dataset registry is chosen by ``cfg.corruption``: the plain
        registry when it is ``None``, otherwise the corruption registry, which
        also receives ``corruption=cfg.corruption``.
        """
        cfg = self.cfg

        datasets_kwargs = dict(
            location=cfg.data_location,
            batch_size=16,
            num_workers=8,
        )
        if cfg.corruption is None:
            from src.datasets.registry import get_dataset
        else:
            from src.datasets.corruption.registry import get_dataset

            datasets_kwargs["corruption"] = cfg.corruption

        def _endless_batches(loader):
            # itertools.cycle keeps a copy of every batch it has yielded and
            # replays that fixed sequence, so memory grows with each batch
            # consumed and the order is never reshuffled. Re-iterating the
            # loader avoids both. An empty loader ends the iterator, as
            # cycle() would, instead of spinning forever.
            while True:
                yielded_any = False
                for batch in loader:
                    yielded_any = True
                    yield batch
                if not yielded_any:
                    return

        datasets = {
            dataset_name: get_dataset(
                dataset_name,
                self.pretrained_model.val_preprocess,
                **datasets_kwargs,
            )
            for dataset_name in cfg.datasets
        }

        self.datasets = datasets
        self.test_loaders = {name: ds.test_loader for name, ds in datasets.items()}
        self.shuffled_test_loader_iters = {
            name: _endless_batches(ds.test_loader_shuffle) for name, ds in datasets.items()
        }





@hydra.main(config_path=str(CONFIG_DIR), config_name="default", version_base=None)
def main(cfg: DictConfig) -> None:
    """Build a ``Program`` from the given config and run it end to end."""
    program = Program(cfg)
    program.run()


# Script entry point. main() is wrapped by @hydra.main above, so it is invoked
# here without arguments.
if __name__ == "__main__":
    main()
