from lib2to3.pgen2 import token
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from typing import Any

from _common import *

from scripts._common import DictConfig

log = logging.getLogger(__name__)

from collections import defaultdict

import lightning as L
import lightning.fabric
import lightning.pytorch as pl
from flan_t5_checkpoint_path import finetuned_model_path
from flan_t5_individuals import Program as _Program
from flan_t5_individuals import metric_func
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator
from transformers.generation import GenerationConfig, GenerationMixin

from datasets import DatasetDict, load_dataset, load_from_disk
from src.adamerging import softmax_entropy
from src.concrete_mask import ConcreteMask
from src.task_wise_fusion import *
from src.tasks.arithmetic import state_dict_avg, state_dict_sub, state_dict_sum
from src.ties_merging_utils import *
from src.utils import num_devices, num_parameters, timeit_context
from src.tasks.shortest_route_mask_flant5 import *
from src.tasks.shortest_route_classification_heads import *

# Disable tokenizers parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class Program(_Program):
    """Sequentially fuse task vectors with shortest-route masks, then evaluate the merged model.

    At step ``i`` the (already fused) task vector ``i`` and task vector ``i + 1`` are
    combined by ``compute_sr_mask``. The difference between the returned merged state
    dict and the pretrained weights replaces task vector ``i + 1`` for the next step.
    The merged state dict of the final step is checkpointed to ``self.ckpt_path`` and
    evaluated on every dataset in ``cfg.test_datasets``.

    NOTE(review): ``load_models``/``load_datasets`` and the attributes they set
    (``pretrained_model``, ``task_vectors``, ``tokenizer``, ``test_loaders``,
    ``shuffled_test_loader_iters``) are presumably provided by ``_Program`` — not shown here.
    """

    def __init__(self, cfg: DictConfig):
        """Set up seeding, output paths and a single-device CUDA Fabric runtime.

        Args:
            cfg: Hydra config. Reads ``cfg.seed`` (optional), ``cfg.model.name``,
                ``cfg.peft.peft_config`` and ``cfg.peft.name``.
        """
        self.cfg = cfg
        if hasattr(cfg, "seed") and cfg.seed is not None:
            log.info(f"set seed to {cfg.seed}")
            L.seed_everything(cfg.seed)

        # Runs with and without a PEFT config write to separate result directories.
        if cfg.peft.peft_config is None:
            self.results_dir = RESULTS_DIR / cfg.model.name
        else:
            self.results_dir = RESULTS_DIR / (cfg.model.name + "_" + cfg.peft.name)
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.results_path = self.results_dir / "concrete_task_arithmetic.csv"
        self.ckpt_dir = self.results_dir / "concrete_task_arithmetic"
        self.ckpt_path = self.ckpt_dir / "ckpt.pt"
        self.individual_results_path = self.results_dir / "concrete_task_arithmetic_individuals.csv"

        self.fabric = L.Fabric(accelerator="cuda", devices=1)
        self.fabric.launch()
        # Passed to compute_sr_mask as `mask_alpha`; its exact meaning is defined there.
        self.mask_alpha = 0.82
        self.device = torch.device("cuda:0")

    def run(self):
        """Load models and data, fuse all task vectors, then evaluate the merged model."""
        # Task vectors are placed on a second GPU (cuda:1); fusion/eval use self.device (cuda:0).
        self.load_models(task_vector_device=torch.device("cuda:1"))
        self.load_datasets()
        self.Shortest_Route_Fusion()
        self.eval_individuals()

    def Shortest_Route_Fusion(self):
        """Fuse the task vectors pairwise in order and checkpoint the final merged state dict.

        Side effects:
            * freezes the parameters of ``self.pretrained_model``;
            * replaces entries ``1..n-1`` of ``self.task_vectors`` in place with the
              fused task vectors;
            * sets ``self.masks_pre_init``, ``self.masks_post_init``, ``self.masks_pre``,
              ``self.masks_post`` and ``self.merged_state_dict``;
            * writes ``{"merged_state_dict": ...}`` to ``self.ckpt_path``.
        """
        pretrained_model = self.pretrained_model
        task_vectors = self.task_vectors  # alias: writes below update self.task_vectors

        for p in pretrained_model.parameters():
            p.detach_().requires_grad_(False)

        # Mask generators are built from the first two task vectors and re-drawn every step.
        self.masks_pre_init = ShortestRouteMask(
            state_dict=task_vectors[0],
            init_value=1,
        )
        self.masks_post_init = ShortestRouteMask(
            state_dict=task_vectors[1],
            init_value=1,
        )

        self.masks_pre = self.masks_pre_init._draw_mask()
        self.masks_post = self.masks_post_init._draw_mask()

        # state_dict() builds a new dict on every call; the original loop called it twice
        # per key. Its tensors alias the (frozen) parameters, so one snapshot suffices.
        pretrained_state_dict = pretrained_model.state_dict()

        # Iteratively fuse the following tasks
        for i in range(len(task_vectors) - 1):
            pre_task_dataset = self.cfg.test_datasets[i]
            post_task_dataset = self.cfg.test_datasets[i + 1]
            task_vector_pre = task_vectors[i]
            task_vector_post = task_vectors[i + 1]

            pre_task_dataloader = self.shuffled_test_loader_iters[pre_task_dataset]
            post_task_dataloader = self.shuffled_test_loader_iters[post_task_dataset]
            self.merged_state_dict = compute_sr_mask(
                task_vector_pre=task_vector_pre,
                task_vector_post=task_vector_post,
                # Presumably a fresh copy so compute_sr_mask can modify it without
                # touching the shared pretrained model — confirm in compute_sr_mask.
                pretrained_model=deepcopy(pretrained_model),
                masks_pre=self.masks_pre,
                masks_post=self.masks_post,
                pre_task_dataloader=pre_task_dataloader,
                post_task_dataloader=post_task_dataloader,
                mask_alpha=self.mask_alpha,
                device=self.device,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # Fused task vector = merged weights - pretrained weights. Keys of
            # task_vector_pre that are not in the pretrained state dict keep their value.
            merged_vector = deepcopy(task_vector_pre)
            for k in task_vector_pre:
                if k in pretrained_state_dict:
                    merged_vector[k] = self.merged_state_dict[k].to(self.device) - pretrained_state_dict[k].to(self.device)

            task_vectors[i + 1] = merged_vector

            self.masks_pre = self.masks_pre_init._draw_mask()
            self.masks_post = self.masks_post_init._draw_mask()

        # Save checkpoint
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {
                "merged_state_dict": self.merged_state_dict,
            },
            self.ckpt_path,
        )

    @torch.no_grad()
    def eval_individuals(self):
        """Evaluate the checkpointed merged model on every dataset in ``cfg.test_datasets``.

        Loads ``merged_state_dict`` from ``self.ckpt_path`` into a copy of the pretrained
        model. It then scores that model with ``metric_func[dataset_name]`` on
        ``self.test_loaders[dataset_idx]`` and logs each score and the average. Per-dataset
        scores are written to ``self.individual_results_path`` with columns
        ``dataset`` and ``acc``.
        """
        log.info("start eval indivuduals")
        cfg = self.cfg
        loaded_dict = torch.load(
            self.ckpt_path, map_location="cpu"
        )["merged_state_dict"]

        merged_state_dict = {k: v.to(self.device) for k, v in loaded_dict.items()}

        # Every dataset is scored with the same merged weights, so build the model once
        # rather than deep-copying and reloading it for each dataset.
        model = deepcopy(self.pretrained_model)
        model.load_state_dict(merged_state_dict)
        model = model.to(self.device)
        model = self.fabric.setup_module(model)

        results = {"dataset": [], "acc": []}
        total_score = 0
        for dataset_idx, dataset_name in enumerate(tqdm(cfg.test_datasets, desc="Evaluating individual models")):
            score = metric_func[dataset_name](model, self.test_loaders[dataset_idx], self.tokenizer)
            log.info(f"Eval: {dataset_name} - Score: {score}")
            # Previously never recorded, which left the results CSV empty.
            results["dataset"].append(dataset_name)
            results["acc"].append(score)
            total_score += score

        num_evaluated = len(results["acc"])
        if num_evaluated:
            log.info("Eval: " + " Avg score:" + str(total_score / num_evaluated) + "\n")
        else:
            # Avoid ZeroDivisionError when no test datasets are configured.
            log.warning("Eval: no test datasets configured; no average score computed")
        pd.DataFrame(results).to_csv(self.individual_results_path, index=False)


@hydra.main(config_path=str(CONFIG_DIR), config_name="flan_t5_default", version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: build a ``Program`` from the composed config and run it."""
    program = Program(cfg)
    program.run()


if __name__ == "__main__":
    main()
