from _common import *

# Module-level logger, named after this module's import path.
log = logging.getLogger(name=__name__)

from src.clip_eval import eval_single_dataset
from src.task_vectors import TaskVector
from src.draw_distribution import draw_distribution
from src.datasets.registry import get_dataset
from torch.utils.data import DataLoader


from clip_checkpoint_path import CHECKPOINT_DIR, finetuned_model_path, pretrained_model_path, sam_retraining_model_path


def _running_pairwise_average(vectors) -> dict:
    """Fold task-vector dicts left to right with a pairwise mean.

    Starting from ``vectors[0]``, each later vector ``v`` updates the running
    merge key by key as ``merged = (v + merged) / 2``. The last vector therefore
    ends up with weight 1/2 and the first with weight 1/2**(len(vectors) - 1).
    Keys are taken from the newer vector, so a key missing from the running
    merge raises KeyError.

    Args:
        vectors: non-empty sequence of ``{parameter_name: delta}`` mappings.

    Returns:
        A new mapping. The inputs are never modified; with a single input,
        that input itself is returned.
    """
    merged = vectors[0]
    for vector in vectors[1:]:
        merged = {key: (vector[key] + merged[key]) / 2 for key in vector}
    return merged


@hydra.main(config_path=str(CONFIG_DIR), config_name="default", version_base=None)
def main(cfg: DictConfig) -> None:
    """Merge per-dataset task vectors and draw each dataset's distribution.

    For every dataset in ``cfg.datasets`` a TaskVector is built from the
    pretrained checkpoint and the dataset's fine-tuned checkpoint (the
    SAM-retrained variant when ``cfg.sam_retraining`` is set). The vectors are
    folded with ``_running_pairwise_average`` and added to the pretrained
    weights to form ``merged_state_dict``. ``draw_distribution`` is then called
    once per dataset with that dataset's own task vector, the merged weights
    and the dataset's shuffled test loader.

    Accuracy evaluation is currently disabled. The accuracy summary and the
    ``task_arithmetic.csv`` export are skipped while no accuracies are
    collected, so existing result files are not overwritten with empty ones.

    Raises:
        ValueError: if ``cfg.datasets`` is empty.
    """
    cfg.save = str(CHECKPOINT_DIR / cfg.model)
    cfg.data_location = str(DATA_DIR)
    pretrained_checkpoint = pretrained_model_path(cfg.model)
    # NOTE(review): torch.load unpickles a whole nn.Module here; only load trusted checkpoints.
    pretrained_model: nn.Module = torch.load(pretrained_checkpoint)

    datasets = {
        dataset_name: get_dataset(
            dataset_name, pretrained_model.val_preprocess, location=cfg.data_location, batch_size=4, num_workers=cfg.num_workers
        )
        for dataset_name in cfg.datasets
    }
    shuffled_test_loaders: Dict[str, DataLoader] = {
        dataset_name: dataset.test_loader_shuffle for dataset_name, dataset in datasets.items()
    }

    if cfg.sam_retraining:
        log.info("SAM retrained model is used")
        _finetuned_model_path = sam_retraining_model_path
    else:
        _finetuned_model_path = finetuned_model_path

    task_vectors = [
        TaskVector(
            pretrained_checkpoint=pretrained_checkpoint,
            finetuned_checkpoint=_finetuned_model_path(cfg.model, dataset_name),
        )
        for dataset_name in cfg.datasets
    ]
    if not task_vectors:
        raise ValueError("cfg.datasets must contain at least one dataset")

    # Merge into a separate dict so task_vectors[k] still holds dataset k's own
    # vector when it is passed to draw_distribution below.
    merged_vector = _running_pairwise_average([task_vector.vector for task_vector in task_vectors])

    # state_dict() builds a fresh mapping on every call, so fetch it once.
    pretrained_state_dict = pretrained_model.state_dict()
    merged_state_dict = {
        key: pretrained_state_dict[key] + delta for key, delta in merged_vector.items()
    }

    # Accuracy evaluation (eval_single_dataset on the merged encoder) is disabled,
    # so these stay empty unless it is re-enabled and appends to them.
    results = {"scaling_coef": [], "dataset": [], "acc": []}
    accs = []

    for task_vector, dataset_name in zip(task_vectors, cfg.datasets):
        draw_distribution(
            task_vector=task_vector.vector,
            merged_state_dict=merged_state_dict,
            pretrained_model=pretrained_model,
            dataloader=shuffled_test_loaders[dataset_name],
            device=cfg.device,
            dataset_name=dataset_name,
            type="T-SNE-TA",
        )

    if not accs:
        log.info("No accuracies were collected; skipping the accuracy summary and CSV export.")
        return

    log.info("Avg ACC:" + str(np.mean(accs)) + "%")

    if cfg.sam_retraining:
        save_dir = RESULTS_DIR / "sam_retraining" / cfg.model
    else:
        save_dir = RESULTS_DIR / cfg.model
    os.makedirs(save_dir, exist_ok=True)
    pd.DataFrame(results).to_csv(save_dir / "task_arithmetic.csv", index=False)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() supplies
    # `cfg` from the "default" config (plus any command-line overrides).
    main()
