from _common import *

# Module-level logger; `logging` is presumably provided by the `_common` star import.
log = logging.getLogger(__name__)

import types
from src.clip_eval import eval_single_dataset
from src.task_vectors import StateDict, TaskVector, state_dict_mean
from src.ties_merging_utils import check_parameterNamesMatch

from clip_checkpoint_path import CHECKPOINT_DIR, finetuned_model_path, pretrained_model_path

from src.task_vectors import NonLinearTaskVector
from singular_vector import TSVM_utils
from src.tasks.shortest_route_classification_heads import *


def _tsvm_merge_sequentially(task_vectors: List[StateDict]) -> StateDict:
    """Merge task vectors left to right with pairwise TSVM.

    Step k merges the running result with ``task_vectors[k + 1]`` via
    ``TSVM_utils.compute_and_sum_svd_mem_reduction``, so the final result folds
    in every vector in order. A single task vector is returned as-is because
    there is nothing to merge it with.

    Args:
        task_vectors: Task-vector state dicts in merge order. The list itself
            is not modified.

    Returns:
        The merged state dict (the last TSVM output, or the only input vector).

    Raises:
        ValueError: If ``task_vectors`` is empty.
    """
    if not task_vectors:
        raise ValueError("cannot merge an empty list of task vectors")

    merged = task_vectors[0]
    for i, next_vector in enumerate(task_vectors[1:]):
        print(f"Computing mean state dict for checkpoints {i} and {i + 1}")
        # Each call merges exactly two vectors, labelled "0" and "1"
        # (presumably TSVM_utils uses len(DATASETS) as the task count — confirm).
        # A fresh namespace per call, in case the utility mutates its config.
        tsvm_config = types.SimpleNamespace(DATASETS=["0", "1"], device="cuda:0")
        merged = TSVM_utils.compute_and_sum_svd_mem_reduction(
            [merged, next_vector], tsvm_config
        )
    return merged


@hydra.main(config_path=str(CONFIG_DIR), config_name="default", version_base=None)
def main(cfg: DictConfig) -> None:
    """Merge fine-tuned CLIP task vectors with sequential TSVM and evaluate them.

    Builds one ``TaskVector`` per dataset in ``cfg.datasets`` from the pretrained
    and fine-tuned checkpoints, folds them together with pairwise TSVM merging,
    applies the merged vector to the pretrained checkpoint, evaluates the result
    on every dataset and writes per-dataset top-1 accuracy (as a fraction) to
    ``RESULTS_DIR/<cfg.model>/averaging.csv``.

    Args:
        cfg: Hydra config. Reads ``model`` and ``datasets``; sets ``save`` and
            ``data_location`` (presumably consumed by ``eval_single_dataset``).

    Raises:
        ValueError: If ``cfg.datasets`` is empty.
    """
    cfg.save = str(CHECKPOINT_DIR / cfg.model)
    cfg.data_location = str(DATA_DIR)
    if not cfg.datasets:
        raise ValueError("cfg.datasets must name at least one dataset")

    pretrained_path = pretrained_model_path(cfg.model)

    log.info("load finetuned models")
    task_vectors: List[StateDict] = [
        TaskVector(
            pretrained_checkpoint=pretrained_path,
            finetuned_checkpoint=finetuned_model_path(cfg.model, dataset_name),
        ).vector
        for dataset_name in tqdm(cfg.datasets)
    ]
    check_parameterNamesMatch(task_vectors)

    # Also valid for a single dataset: the lone task vector is used directly.
    merged_vector = NonLinearTaskVector(vector=_tsvm_merge_sequentially(task_vectors))

    # Move the merged vector to CPU before applying it to the pretrained encoder.
    for key in merged_vector.vector:
        merged_vector.vector[key] = merged_vector.vector[key].cpu()
    image_encoder = merged_vector.apply_to(pretrained_path)

    results = {"dataset": [], "acc": []}  # "acc" holds top-1 as a fraction
    accs = []  # top-1 in percent, used for the average
    for dataset in cfg.datasets:
        metrics = eval_single_dataset(image_encoder, dataset, cfg)
        top1 = metrics["top1"]
        log.info(f"{dataset}:{top1 * 100}%")
        accs.append(top1 * 100)
        results["dataset"].append(dataset)
        results["acc"].append(top1)

    log.info(f"Avg ACC:{np.mean(accs)}%")

    log.info("Eval: All tasks accuracy:")
    print(" ".join(f"{acc:.3f}" for acc in results["acc"]))

    out_dir = RESULTS_DIR / cfg.model
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): the file name says "averaging" although the merge is TSVM;
    # kept unchanged so downstream readers of this path keep working.
    pd.DataFrame(results).to_csv(out_dir / "averaging.csv", index=False)


if __name__ == "__main__":
    # Called without arguments: the hydra.main decorator on `main` builds `cfg`
    # from the "default" config in CONFIG_DIR plus any command-line overrides.
    main()
