from dataclasses import dataclass, field
from typing import Optional

import pandas as pd
import pytorch_lightning as pl

from vis_models.architectures import ModelConfig, create_model
from utils import persistence
from utils.training import (
    TrainingConfig, training_experiment
)
from experiments.cross_transforms.data import (
    create_transforms_datasets,
)
from experiments.cross_transforms.experiment import EXP_NAME as CT_EXP_NAME
from .invariance_estimation import evaluate_invariance
from .data import RIDataConfig


# Top-level name of this experiment. Not referenced elsewhere in the visible
# code of this module — presumably consumed by callers / result lookup (verify).
EXP_NAME = "representation_impact"
# Final path component appended to the experiment name to form the first
# argument of evaluate_invariance in ri_experiment (presumably a persistence
# key/path for the representation-distance results — verify).
REPRESENTATION_DISTANCE_KEY = "representation_dists"

@dataclass
class RIExperimentConfig:
    """Configuration for the representation-impact (RI) experiment.

    Consumed by ``ri_experiment``; see the per-field comments for where each
    setting is forwarded.
    """
    # Base experiment name; combined with data.config_seed and
    # data.sampling_seed by persistence.get_experiment_name.
    exp_name: list[str]
    # Training settings forwarded unchanged to training_experiment.
    training: TrainingConfig
    # Data settings: used to build the cross-transform datasets and also
    # passed to evaluate_invariance as data_config.
    data: RIDataConfig
    # Architecture settings; every model (trained and untrained) is built
    # from this via create_model, and it is passed to evaluate_invariance.
    model: ModelConfig
    # Forwarded positionally to evaluate_invariance; its meaning (e.g. which
    # invariance to evaluate) is defined there — TODO confirm.
    eval_invariance: str
    # The next three are forwarded as same-named keyword arguments to
    # evaluate_invariance.
    n_distance_samples: int
    n_object_samples: int
    dist_metric: str

@dataclass
class RIExperimentResult:
    """Result of ``ri_experiment``; the same object is passed to
    ``persistence.save_result``."""
    # The configuration that produced this result.
    config: Optional[RIExperimentConfig]
    # Disabled fields kept for reference (not part of the dataclass):
    # transforms: list[str]
    # objects: list[int]
    # Value returned by evaluate_invariance. Declared Optional; when None is
    # expected is not visible here — verify against loaders/callers.
    representation_distances: Optional[pd.DataFrame]


def ri_experiment(
    config: RIExperimentConfig,
    *,
    ct_model_key: str = "obj2d_resnet-18",
) -> RIExperimentResult:
    """Run the representation-impact experiment and persist its result.

    Steps:
      1. Derive the experiment name from ``config.exp_name`` and the data
         seeds via ``persistence.get_experiment_name``.
      2. Build the cross-transform datasets and one fresh model per data key
         (named ``m_<key>``).
      3. Call ``training_experiment`` with a path under the cross-transforms
         (CT) experiment, intending to reuse the models trained there.
      4. Evaluate invariance of a freshly created untrained model together
         with all trained models.
      5. Save the result under the experiment name and return it.

    Args:
        config: Full experiment configuration.
        ct_model_key: Second component of the CT training path. That path is
            ``[CT_EXP_NAME, ct_model_key, *exp_name[2:], "training"]``, i.e.
            the first two components of this experiment's name are swapped
            for the CT experiment name and this key. The default is the value
            that was previously hard-coded here.

    Returns:
        The ``RIExperimentResult`` that was also passed to
        ``persistence.save_result``.
    """
    exp_name = persistence.get_experiment_name(
        config.exp_name, config.data.config_seed, config.data.sampling_seed
    )
    print("exp name:", exp_name)

    cross_transforms_data = create_transforms_datasets(config.data)
    datasets: dict[str, pl.LightningDataModule] = cross_transforms_data.data

    # One freshly initialised model per data key.
    models = {
        f"m_{transform_name}": create_model(config.model)
        for transform_name in cross_transforms_data.data_keys
    }

    # Point training at the CT experiment's storage so its pre-trained models
    # are reused (assumes training_experiment loads existing results for this
    # path instead of retraining — TODO confirm).
    ct_training_path = [CT_EXP_NAME, ct_model_key, *exp_name[2:], "training"]
    training_res = training_experiment(
        ct_training_path,
        config.training,
        models,
        datasets,
    )

    # An untrained model is evaluated alongside the trained ones, presumably
    # as a reference point. Trained models override it on a key clash.
    eval_models = {
        "untrained": create_model(config.model),
        **training_res.models,
    }
    rep_dists = evaluate_invariance(
        [*exp_name, REPRESENTATION_DISTANCE_KEY],
        config.eval_invariance,
        eval_models,
        model_config=config.model,
        training_objects=cross_transforms_data.objects,
        data_config=config.data,
        n_distance_samples=config.n_distance_samples,
        n_object_samples=config.n_object_samples,
        dist_metric=config.dist_metric,
    )

    result = RIExperimentResult(
        config=config,
        representation_distances=rep_dists,
    )
    persistence.save_result(exp_name, result)
    return result
