import pandas as pd
import torch
import pytorch_lightning as pl

from vis_models.architectures import ModelConfig
from utils.representations.representation_distances import (
    compute_representations,
    compute_representation_distances,
    PU_LAYER_NAME,
    RepDistanceMetric,
)
from utils import persistence

# Result name used with `persistence.save_result` / `persistence.load_result`
# for the model-by-dataset-combination distance table.
REP_DIST_RESULT_NAME: str = "rep_dist_res"

def compute_dataset_representation_distances(
    eval_rep_distances: bool,
    exp_name: list[str],
    models: dict[str, torch.nn.Module],
    model_config: ModelConfig,
    datasets: dict[str, tuple[pl.LightningDataModule, pl.LightningDataModule]],
    dist_metric: RepDistanceMetric,
    rep_dist_samples: int,
) -> pd.DataFrame:
    """Compute (or load) the mean representation distance per model and dataset pair.

    When ``eval_rep_distances`` is true, every model in ``models`` is evaluated
    on every dataset pair in ``datasets`` via ``_compute_mean_distance``. The
    resulting table is saved with ``persistence.save_result`` under
    ``REP_DIST_RESULT_NAME`` and returned. Otherwise the previously saved table
    is loaded with ``persistence.load_result`` and returned as-is.

    Args:
        eval_rep_distances: True to recompute and save; False to load the saved result.
        exp_name: Experiment identifier, passed through unchanged to ``persistence``.
        models: Mapping of model name -> model. The names form the row index.
        model_config: Config forwarded to the representation computation.
        datasets: Mapping of combination name -> (first, second) data module.
            The names form the columns.
        dist_metric: Metric forwarded to ``compute_representation_distances``.
        rep_dist_samples: How many leading representations of each dataset to compare.

    Returns:
        DataFrame indexed by model name with one column per dataset combination.
        Each cell holds the mean distance. A freshly computed table is float64.
        A loaded table is whatever ``persistence.load_result`` returns.
    """
    if not eval_rep_distances:
        return persistence.load_result(
            exp_name,
            REP_DIST_RESULT_NAME,
        )

    # Without dtype, a frame built from only index/columns is object-dtype, and
    # scalar assignments keep it object. Forcing float keeps the table numeric
    # (NaN until filled).
    model_distances = pd.DataFrame(
        index=list(models), columns=list(datasets), dtype=float
    )
    for dataset_combo_name, (dataset_1, dataset_2) in datasets.items():
        print("Computing representation distances for:", dataset_combo_name)
        for model_name, model in models.items():
            model_distances.at[model_name, dataset_combo_name] = _compute_mean_distance(
                model,
                model_config,
                dataset_1,
                dataset_2,
                dist_metric=dist_metric,
                rep_dist_samples=rep_dist_samples,
            )

    persistence.save_result(
        exp_name,
        model_distances,
        result_name=REP_DIST_RESULT_NAME,
    )
    return model_distances

def _compute_mean_distance(
    model: torch.nn.Module,
    model_config: ModelConfig,
    dataset_1: pl.LightningDataModule,
    dataset_2: pl.LightningDataModule,
    dist_metric: RepDistanceMetric,
    rep_dist_samples: int,
) -> float:
    """Return the mean distance between one model's representations of two datasets.

    For each dataset, ``compute_representations`` is called and its
    ``PU_LAYER_NAME`` entry is truncated to the first ``rep_dist_samples``
    items. Dataset 1 is processed before dataset 2. The two truncated
    representations are passed to ``compute_representation_distances`` with
    ``dist_metric``, and the mean of that result is returned via ``.item()``.
    """

    def _truncated_layer_reps(datamodule):
        # Keep only the PU layer's entry, limited to the requested sample count.
        layer_outputs = compute_representations(model, model_config, datamodule)
        return layer_outputs[PU_LAYER_NAME][:rep_dist_samples]

    # Arguments are evaluated left to right, so dataset_1 is still processed first.
    distances = compute_representation_distances(
        _truncated_layer_reps(dataset_1),
        _truncated_layer_reps(dataset_2),
        dist_metric,
    )
    return distances.mean().item()
