import itertools
from dataclasses import dataclass, field
from typing import Union, Any, cast

import torch
import torchvision
import pandas as pd
import pytorch_lightning as pl

from utils import persistence
from vis_models.architectures import ModelConfig, create_model
from utils.training import (
    TrainingConfig, training_experiment, fine_tuning_experiment
)
from .data import (
    TvIDataConfig,
    create_dataset_combination,
    get_item_name,
    ExpandedDataConfig,
)
from utils.eval import ACCURACY_METRIC


# Identifier for this experiment family.
# NOTE(review): not referenced elsewhere in this file as shown; tvi_experiment
# builds its run name from `config.exp_name` instead — confirm external users.
EXP_NAME = "transforms_vs_identity"

@dataclass
class TvIExperimentConfig:
    """Full configuration for one ``tvi_experiment`` run."""

    # Name components; combined with the data seeds by
    # `persistence.get_experiment_name` to form the run name.
    exp_name: list[str]
    # Settings for the initial training stage (`training_experiment`).
    training: TrainingConfig
    # Settings for the fine-tuning stage (`fine_tuning_experiment`).
    fine_tuning: TrainingConfig
    # Dataset settings; used to build the dataset combination and also
    # supplies `config_seed` and `sampling_seed` for the run name.
    data: TvIDataConfig
    # Architecture config passed to `create_model` for every model and to
    # the fine-tuning stage as `model_config`.
    model: ModelConfig

@dataclass
class TvIExperimentResult:
    """Results built by ``tvi_experiment`` and saved via ``persistence.save_result``."""

    # The configuration the run was executed with.
    config: TvIExperimentConfig
    # `expanded_config` of the dataset combination built from `config.data`.
    expanded_data_config: ExpandedDataConfig
    # ACCURACY_METRIC entry of the training stage's metrics.
    in_dist_performance: pd.DataFrame
    # ACCURACY_METRIC entry of the fine-tuning stage's result.
    transfer_performance: pd.DataFrame


def tvi_experiment(config: TvIExperimentConfig) -> TvIExperimentResult:
    """Run the transforms-vs-identity experiment end to end.

    Stages, in order:
      1. derive a run name from ``config.exp_name`` and the data seeds;
      2. build the dataset combination described by ``config.data``;
      3. create one model per dataset key and train them (``config.training``);
      4. fine-tune the trained models (``config.fine_tuning``);
      5. save and return the collected accuracy results.

    Args:
        config: Full experiment configuration.

    Returns:
        A ``TvIExperimentResult`` holding the config, the expanded data
        config, and the accuracy metric from the training and fine-tuning
        stages.
    """
    data_cfg = config.data
    run_name = persistence.get_experiment_name(
        config.exp_name, data_cfg.config_seed, data_cfg.sampling_seed
    )
    print("exp name:", run_name)

    combination = create_dataset_combination(data_cfg)
    datasets: dict[str, pl.LightningDataModule] = combination.data

    # A separate create_model(...) call for every dataset key; the key is
    # computed before its model, matching dict-comprehension evaluation order.
    models = {}
    for data_id in combination.data_keys:
        model_key = get_item_name(False, *data_id)
        models[model_key] = create_model(config.model)

    trained = training_experiment(
        [*run_name, "training"],
        config.training,
        models,
        datasets,
    )

    # Fine-tuning starts from the models returned by the training stage.
    fine_tuned = fine_tuning_experiment(
        [*run_name, "fine_tuning"],
        config.fine_tuning,
        model_config=config.model,
        models=trained.models,
        datasets=datasets,
    )

    # NOTE(review): training metrics are read via `.metrics[...]` while the
    # fine-tuning result is indexed directly — confirm the two helpers really
    # return different types.
    outcome = TvIExperimentResult(
        config=config,
        expanded_data_config=combination.expanded_config,
        in_dist_performance=trained.metrics[ACCURACY_METRIC],
        transfer_performance=fine_tuned[ACCURACY_METRIC],
    )
    persistence.save_result(run_name, outcome)
    return outcome
