from _common import *

# Module-level logger named after this module; used by main() for progress and status messages.
log = logging.getLogger(__name__)

from src.clip_eval import eval_single_dataset
from src.ties_merging_utils import *

from clip_checkpoint_path import CHECKPOINT_DIR, finetuned_model_path, pretrained_model_path, sam_retraining_model_path
from src.draw_distribution_tie import draw_distribution_tie
from src.datasets.common import maybe_dictionarize
from src.datasets.registry import get_dataset
from torch.utils.data import DataLoader


@hydra.main(config_path=str(CONFIG_DIR), config_name="default", version_base=None)
def main(cfg: DictConfig):
    """Sequentially TIES-merge per-dataset task vectors and plot merged-vs-finetuned distributions.

    Steps:
      1. Load the pretrained checkpoint and one finetuned checkpoint per entry
         in ``cfg.datasets`` (SAM-retrained ones when ``cfg.sam_retraining`` is set).
      2. Flatten every state dict and form task vectors ``finetuned - pretrained``.
      3. Fold the task vectors together pairwise, in ``cfg.datasets`` order, with
         ``ties_merging``. This is done once, because the result does not depend
         on the scaling coefficient.
      4. For each scaling coefficient in ``np.linspace(0, 1, 11)`` rebuild the merged
         model and, per dataset, the correspondingly scaled single-task model, and
         pass both to ``draw_distribution_tie``.
      5. Write collected accuracies to ``sequential_ties_merging.csv`` (only when
         any were collected; accuracy evaluation is currently commented out).

    Args:
        cfg: Hydra config. Reads ``model``, ``datasets``, ``num_workers``,
            ``sam_retraining`` and ``device``; writes ``save`` and ``data_location``.

    Raises:
        ValueError: if ``cfg.datasets`` is empty.
    """
    import copy  # local import so this block does not depend on what `_common` re-exports

    cfg.save = str(CHECKPOINT_DIR / cfg.model)
    cfg.data_location = str(DATA_DIR)
    model = cfg.model

    dataset_names = list(cfg.datasets)
    if not dataset_names:
        raise ValueError("cfg.datasets must name at least one dataset")

    pretrained_checkpoint = pretrained_model_path(model)
    pretrained_model: nn.Module = torch.load(pretrained_checkpoint)

    datasets = {
        dataset_name: get_dataset(
            dataset_name,
            pretrained_model.val_preprocess,
            location=cfg.data_location,
            batch_size=4,
            num_workers=cfg.num_workers,
        )
        for dataset_name in dataset_names
    }
    shuffled_test_loaders: Dict[str, DataLoader] = {
        dataset_name: dataset.test_loader_shuffle for dataset_name, dataset in datasets.items()
    }

    if cfg.sam_retraining:
        log.info("SAM retrained model is used")
        _finetuned_model_path = sam_retraining_model_path
    else:
        _finetuned_model_path = finetuned_model_path

    # Load pretrained and finetuned checkpoints
    ft_checks: List[StateDict] = [
        torch.load(_finetuned_model_path(model, dataset_name), map_location="cpu").state_dict()
        for dataset_name in tqdm(dataset_names, "load finetuned checkpoints")
    ]
    ptm_check: StateDict = torch.load(pretrained_checkpoint, map_location="cpu").state_dict()
    check_parameterNamesMatch(ft_checks + [ptm_check])

    remove_keys = []
    print("Flattening out Checkpoints")
    flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
    flat_fts = [state_dict_to_vector(check, remove_keys) for check in ft_checks]

    # One task vector per dataset, in the same order as `dataset_names`.
    # These are never overwritten, so they stay usable for the per-dataset models below.
    tvs = [flat - flat_ptm for flat in flat_fts]

    K = 20
    merge_func = "dis-sum"
    results = {"scaling_coef": [], "dataset": [], "acc": []}

    # Sequential merge: fold each task vector into the running result. It is
    # independent of the scaling coefficient, so it is computed once up front
    # rather than re-merging already-merged vectors on every coefficient.
    # With a single dataset the merged vector is simply that task vector.
    merged_tv = tvs[0]
    for next_tv in tvs[1:]:
        merged_tv = ties_merging(
            torch.stack([merged_tv, next_tv]),
            reset_thresh=K,
            merge_func=merge_func,
        )

    # .tolist() yields plain Python floats for tensor arithmetic and the results table.
    for scaling_coef_ in np.linspace(0, 1, 11).tolist():
        # Apply scaling and reconstruct merged weights
        merged_check = flat_ptm + scaling_coef_ * merged_tv
        merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys)

        # Copy the already-loaded pretrained model instead of re-reading the checkpoint.
        image_encoder: nn.Module = copy.deepcopy(pretrained_model)
        image_encoder.load_state_dict(merged_state_dict, strict=False)

        total_acc = 0.0
        num_evaluated = 0
        for dataset_idx, dataset in enumerate(dataset_names):
            # NOTE: accuracy evaluation is disabled; only the distribution plots
            # are produced. To re-enable it, uncomment:
            # metrics = eval_single_dataset(image_encoder, dataset, cfg)
            # total_acc += metrics["top1"]
            # num_evaluated += 1
            # log.info(str(dataset) + ":" + str(metrics))
            # results["scaling_coef"].append(scaling_coef_)
            # results["dataset"].append(dataset)
            # results["acc"].append(metrics["top1"])

            # Single-task reference model: pretrained weights plus this dataset's
            # own (unmerged) task vector, scaled by the same coefficient.
            sft_check = flat_ptm + scaling_coef_ * tvs[dataset_idx]
            sft_state_dict = vector_to_state_dict(sft_check, ptm_check, remove_keys)
            sft_model: nn.Module = copy.deepcopy(pretrained_model)
            sft_model.load_state_dict(sft_state_dict, strict=False)

            draw_distribution_tie(
                merged_model=image_encoder,
                SFT_model=sft_model,
                dataloader=shuffled_test_loaders[dataset],
                device=cfg.device,
                dataset_name=dataset,
                type="sequential_ties_merging",
            )

        # Report an average only when something was evaluated, so a hard-coded
        # 0.0 is never logged as if it were a measured accuracy.
        if num_evaluated:
            log.info("Final: Avg ACC:%s", total_acc / num_evaluated)

    # Save results
    if not results["acc"]:
        # Do not overwrite an existing results file with an empty table.
        log.warning("No accuracies were collected; skipping sequential_ties_merging.csv")
        return

    if cfg.sam_retraining:
        save_dir = RESULTS_DIR / "sam_retraining" / cfg.model
    else:
        save_dir = RESULTS_DIR / cfg.model
    os.makedirs(save_dir, exist_ok=True)
    df = pd.DataFrame(results)
    df.to_csv(save_dir / "sequential_ties_merging.csv", index=False)


# Script entry point: main is wrapped by @hydra.main, so it is invoked without
# arguments here and only when this file is executed directly (not on import).
if __name__ == "__main__":
    main()
