import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from omegaconf import DictConfig, open_dict

from src.models.task_vectors import ImageEncoder, NonLinearTaskVector
from src.utils.tallmask_utils import construct_consensus_mask, construct_tall_mask, load_tall_mask
from src.utils.ties_utils import ties_merging
from src.utils.subspace_boosting_utils import subspace_boosting
from src.utils.ho_gsvd_utils import ho_gsvd
from src.utils.cart_utils import cart
from src.utils.tsvm_utils import compute_and_sum_svd_mem_reduction
from src.utils.utils import (
    check_parameterNamesMatch,
    check_state_dicts_equal,
    state_dict_to_vector,
    topk_values_mask,
    vector_to_state_dict,
)
from src.utils.variables_and_paths import get_finetuned_path, get_zeroshot_path, get_averaged_path


def get_all_checkpoints(config: DictConfig) -> Tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
    """
    Load the fine-tuned checkpoint of every validation dataset plus the zero-shot checkpoint.

    Args:
        config (DictConfig): Configuration providing ``model_location`` (checkpoint root directory),
            ``DATASETS_VAL`` (datasets whose fine-tuned checkpoints are loaded) and ``model`` (model name).

    Returns:
        Tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]: ``(ft_checks, ptm_check)``.
            ``ft_checks`` is a list of fine-tuned state dicts, ordered by first appearance of each dataset in
            ``config.DATASETS_VAL``. ``ptm_check`` is the zero-shot (pre-trained) state dict.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable fine-tuned checkpoint
        (e.g. ``FileNotFoundError``).

    Note:
        If the zero-shot checkpoint cannot be loaded, the state dict of a fresh ``ImageEncoder(config.model)`` is
        used instead and saved to the zero-shot path, creating its directory if needed.
    """
    model_dir = config.model_location
    print("I am getting out all the checkpoints")
    print("datasets:", config.DATASETS_VAL)
    print("model:", config.model)

    # Resolve each path once. The dict is keyed by dataset, so a dataset listed twice contributes a
    # single checkpoint.
    ft_paths = {
        dataset: get_finetuned_path(model_dir, dataset, model=config.model) for dataset in config.DATASETS_VAL
    }

    # Report availability up front so a missing checkpoint is visible in the logs before the load fails.
    for path in ft_paths.values():
        if os.path.exists(path):
            print(f"{path} exists")
        else:
            print(f"{path} does not exist")

    ft_checks = [torch.load(path, map_location="cpu") for path in ft_paths.values()]

    # The zero-shot checkpoint path is resolved via the "MNISTVal" dataset name.
    # NOTE(review): presumably the pre-trained weights are the same for every dataset -- confirm.
    zeroshot_path = get_zeroshot_path(model_dir, "MNISTVal", model=config.model)
    try:
        ptm_check = torch.load(zeroshot_path, map_location="cpu")
    except Exception:
        # Any load failure (missing or unreadable file) falls back to regenerating the checkpoint.
        # Catching Exception instead of a bare except keeps Ctrl-C / SystemExit working.
        ptm_check = ImageEncoder(config.model).state_dict()
        parent_dir = os.path.dirname(zeroshot_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(ptm_check, zeroshot_path)

    return ft_checks, ptm_check


def _l1_norm(vector) -> float:
    """Return the L1 norm of a tensor, or the summed L1 norm of all tensors in a dict of tensors."""
    # Merging methods return either a flat tensor or (for "tsvm") a dict of tensors.
    # NOTE(review): assumes every value of such a dict is a tensor -- confirm for new methods.
    if isinstance(vector, dict):
        return sum(float(t.abs().sum().item()) for t in vector.values())
    return vector.abs().sum().item()


def _save_averaged_checkpoint(
    config: DictConfig, flat_ft: torch.Tensor, ptm_check: Dict[str, torch.Tensor], remove_keys: list
) -> None:
    """Save the element-wise mean of the flattened fine-tuned weights as a state dict, unless the file exists."""
    file_path = get_averaged_path(config.model_location, config.num_tasks, config.model)
    if Path(file_path).exists():
        return

    # Mean over the stacked fine-tuned checkpoints (dim 0), i.e. the averaged model weights, not a task vector.
    averaged_flat_ft = flat_ft.mean(dim=0)
    state_dict = vector_to_state_dict(averaged_flat_ft, ptm_check, remove_keys)

    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(state_dict, file_path)


def create_task_vector(config: DictConfig) -> Tuple[NonLinearTaskVector, Optional[Dict[str, torch.Tensor]]]:
    """
    Build the merged multi-task vector for the merging method named in ``config.method.name``.

    Loads the fine-tuned and zero-shot checkpoints, flattens them, and forms the task vectors
    ``theta_t - theta_0``. It then merges them with the configured method.

    Args:
        config (DictConfig): Experiment configuration. Reads ``model``, ``model_location``, ``num_tasks``,
            ``replace_layers``, ``replace_components`` and the ``method`` sub-config. The sub-config supplies
            ``name``, ``k`` and ``apply_lines``, plus method-specific fields such as ``use_ties``,
            ``load_mask``, ``prun_thre_k``, ``base_method``, ``svd_thresh``, ``attn_svd_thresh`` and ``cumsum``.
            When ``method.apply_lines`` is set, ``config.norm_mtv`` and ``config.norm_summed_tvs`` are written
            into ``config``.

    Returns:
        Tuple[NonLinearTaskVector, Optional[Dict[str, torch.Tensor]]]: The merged task vector, plus the evaluation
            masks for the ``tall_mask`` method (``None`` for every other method).

    Raises:
        ValueError: If ``config.method.name`` is not a supported merging method.
    """
    ft_checks, ptm_check = get_all_checkpoints(config)
    check_parameterNamesMatch(ft_checks + [ptm_check])

    remove_keys = []
    eval_masks = None  # only the "tall_mask" method produces evaluation masks
    method = config.method.name

    print("Flattening out Checkpoints")
    flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
    flat_ptm = state_dict_to_vector(ptm_check, remove_keys)

    # compute the task vectors as {\theta_t - \theta_0}.
    tv_flat_checks = flat_ft - flat_ptm

    if method == "cart":
        # The "cart" branch below works on deviations from the mean fine-tuned weights;
        # persist that mean model on disk once.
        _save_averaged_checkpoint(config, flat_ft, ptm_check, remove_keys)

    # Sanity check: the flattened vectors must convert back to the original state dicts.
    assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
    assert all(
        check_state_dicts_equal(vector_to_state_dict(flat_vec, ptm_check, remove_keys), check)
        for flat_vec, check in zip(flat_ft, ft_checks)
    )
    print(f"MODEL: {config.model}, METHOD {config.method.name}")

    if method == "ties":
        # TIES Merging
        merged_tv = ties_merging(tv_flat_checks, reset_thresh=config.method.k, merge_func="dis-mean")
    elif method in ["sum", "zeroshot", "average"]:
        # "sum" corresponds to Task Arithmetic (TA).
        # TA, zeroshot and weight averaging all build the task vector as a sum; they differ only in the
        # scaling factor applied later.
        tv_flat_checks = topk_values_mask(tv_flat_checks, K=config.method.k)
        merged_tv = tv_flat_checks.sum(dim=0)
    elif method == "tall_mask":
        # construct multi-task vector (note: the TIES path uses a fixed reset_thresh of 20)
        if config.method.use_ties:
            print("Using TIES for constructing multi-task vector")
            merged_tv = ties_merging(tv_flat_checks, reset_thresh=20, merge_func="dis-sum")
        else:
            print("Using Task Arithmetic for constructing multi-task vector")
            tv_flat_checks = topk_values_mask(tv_flat_checks, K=config.method.k)
            merged_tv = tv_flat_checks.sum(dim=0)
        # get TALL masks
        if config.method.load_mask:
            # load tall masks directly from storage
            eval_masks = load_tall_mask(remove_keys, ptm_check, config)
        else:
            print("=== Constructing TALL Mask ===")
            eval_masks = construct_tall_mask(
                tv_flat_checks, flat_ft, flat_ptm, merged_tv, ptm_check, remove_keys, config
            )
    elif method == "consensus":
        # construct consensus mask (assumes the TALL masks have already been constructed)
        consensus_mask = construct_consensus_mask(ptm_check, config.method.prun_thre_k, config, remove_keys)
        # construct multi-task vector
        if config.method.use_ties:
            merged_tv = ties_merging(tv_flat_checks, reset_thresh=20, merge_func="dis-sum")
        else:
            tv_flat_checks = topk_values_mask(tv_flat_checks, K=config.method.k)  # top-k magnitude filtering
            merged_tv = tv_flat_checks.sum(dim=0)
        # apply the consensus mask to filter the multi-task vector
        merged_tv = merged_tv * consensus_mask
    elif method == "subspace_boosting":
        merged_tv = subspace_boosting(
            tv_flat_checks,
            ptm_check,
            base_method=config.method.base_method,
            config=config,
            reset_thresh=config.method.k,
            svd_thresh=config.method.svd_thresh,
            attn_svd_thresh=config.method.attn_svd_thresh,
            cumsum=config.method.cumsum,
        )
    elif method == "tsvm":
        task_vectors = [NonLinearTaskVector(config.model, ptm_check, check) for check in ft_checks]
        # Returns a dict of tensors (consumed as a state dict below), not a flat vector.
        merged_tv = compute_and_sum_svd_mem_reduction(task_vectors, config)
    elif method == "cart":
        mean_tv_flat_checks = flat_ft - flat_ft.mean(dim=0)
        merged_tv = cart(mean_tv_flat_checks, ptm_check)
    elif method == "ho_gsvd":
        merged_tv = ho_gsvd(tv_flat_checks, ptm_check, config)
    else:
        raise ValueError(f"Method {config.method.name} not defined.")

    if config.method.apply_lines:
        # L1 norms used for LiNeS scaling. _l1_norm also handles the dict-valued "tsvm" result,
        # on which a direct .abs() call would fail.
        with open_dict(config):
            config.norm_mtv = _l1_norm(merged_tv)
            config.norm_summed_tvs = _l1_norm(tv_flat_checks.sum(dim=0))

    if method == "tsvm":
        task_vector = NonLinearTaskVector(model_name=config.model, vector=merged_tv)
        # Move tensors to CPU, where the pre-trained checkpoint was loaded, for merging later on.
        with torch.no_grad():
            task_vector.vector = {k: v.cpu() for k, v in task_vector.vector.items()}
    else:
        merged_tv_state_dict = vector_to_state_dict(
            merged_tv,
            ptm_check,
            remove_keys=remove_keys,
            replace_layers=config.replace_layers,  # Exclude layers and components
            replace_components=config.replace_components,
        )
        task_vector = NonLinearTaskVector(model_name=config.model, vector=merged_tv_state_dict)

    print("Norm of task vector: ", task_vector.norm())

    return task_vector, eval_masks
