import itertools
from dataclasses import dataclass, field
from typing import Union, Any, cast

import torch
import torchvision
import pandas as pd
import pytorch_lightning as pl

from utils import persistence
from vis_models.architectures import ModelConfig, create_model
from utils.training import (
    TrainingConfig, training_experiment, fine_tuning_experiment
)
from .data import (
    TMDataConfig,
    create_tm_dataset,
    ExpandedDataConfig,
)
from utils.eval import ACCURACY_METRIC


# Label for this experiment family. Not referenced in the visible part of this
# module -- presumably consumed by code that imports it (TODO confirm).
EXP_NAME = "transforms_mismatch"

@dataclass
class TMExperimentConfig:
    """Settings bundle consumed by `tm_experiment`.

    Groups the experiment name, the per-stage training settings, the dataset
    settings and the model settings for a single run.
    """

    # Name components; `tm_experiment` passes them, together with
    # `data.config_seed` and `data.sampling_seed`, to
    # `persistence.get_experiment_name`.
    exp_name: list[str]
    # Settings for the initial `training_experiment` stage.
    training: TrainingConfig
    # Settings for the "quant_mismatch" fine-tuning stage.
    quant_mismatch: TrainingConfig
    # Disabled together with the corresponding stage in `tm_experiment`:
    # qual_mismatch: TrainingConfig
    # Settings for the fine-tuning stage run on the shuffled-order datasets.
    order_mismatch: TrainingConfig
    # Dataset settings handed to `create_tm_dataset`.
    data: TMDataConfig
    # Model settings handed to `create_model` and, as `model_config`, to both
    # fine-tuning stages.
    model: ModelConfig

@dataclass
class TMExperimentResult:
    """Record built, persisted and returned by `tm_experiment`."""

    # The configuration the experiment was run with.
    config: TMExperimentConfig 
    # `expanded_config` of the object returned by `create_tm_dataset`.
    expanded_data_config: ExpandedDataConfig
    # `ACCURACY_METRIC` entry of the initial training stage's metrics.
    in_dist_performance: pd.DataFrame
    # `ACCURACY_METRIC` entry of the "quant_mismatch" fine-tuning stage.
    # TODO: rename to transfer_performance
    quant_mismatch_performance: pd.DataFrame
    # Disabled together with the corresponding stage in `tm_experiment`:
    # qual_mismatch_performance: pd.DataFrame
    # `ACCURACY_METRIC` entry of the fine-tuning stage on shuffled-order data.
    order_mismatch_performance: pd.DataFrame


def tm_experiment(config: TMExperimentConfig) -> TMExperimentResult:
    """Run one transforms-mismatch experiment and persist its result.

    Stages, in order:

    1. Build the datasets with `create_tm_dataset` and create one model per
       dataset configuration name (keyed ``"m_<config name>"``).
    2. Train the models via `training_experiment`.
    3. Fine-tune the trained models on the same datasets
       (stage label ``"quant_mismatch"``).
    4. Fine-tune the trained models, except ``"m_1"``, on the shuffled-order
       datasets (stage label ``"shuffle"``).
    5. Collect the `ACCURACY_METRIC` entry of each stage into a
       `TMExperimentResult` and store it with `persistence.save_result`.

    Args:
        config: Name, per-stage training, data and model settings.

    Returns:
        The result record that was also handed to `persistence.save_result`.
    """
    run_name = persistence.get_experiment_name(
        config.exp_name,
        config.data.config_seed,
        config.data.sampling_seed,
    )
    print("exp name:", run_name)

    combination = create_tm_dataset(config.data)
    data_modules: dict[str, pl.LightningDataModule] = combination.data
    print("expanded config:", combination.expanded_config)

    # One freshly created model per dataset configuration.
    initial_models = {}
    for cfg_name in combination.config_names:
        initial_models[f"m_{cfg_name}"] = create_model(config.model)

    trained = training_experiment(
        [*run_name, "training"],
        config.training,
        initial_models,
        data_modules,
    )

    quant_metrics = fine_tuning_experiment(
        [*run_name, "quant_mismatch"],
        config.quant_mismatch,
        model_config=config.model,
        models=trained.models,
        datasets=data_modules,
    )

    # A "qual_mismatch" fine-tuning stage (using `config.qual_mismatch` and
    # the combination's `mismatch_data`) is currently disabled.

    # Fine-tune and evaluate on the datasets with shuffled transformation
    # order. "m_1" is left out of this stage; the reason is not stated here
    # (presumably there is no order to shuffle for it -- TODO confirm).
    shuffle_candidates = {}
    for key, net in trained.models.items():
        if key == "m_1":
            continue
        shuffle_candidates[key] = net

    # NOTE(review): the keyword is spelled `eval_cross_roduct`; confirm it
    # matches the parameter name declared by `fine_tuning_experiment`.
    order_metrics = fine_tuning_experiment(
        [*run_name, "shuffle"],
        config.order_mismatch,
        model_config=config.model,
        models=shuffle_candidates,
        datasets=combination.shuffle_data,
        eval_cross_roduct=False,
    )

    expanded = combination.expanded_config
    in_dist = trained.metrics[ACCURACY_METRIC]
    quant_perf = quant_metrics[ACCURACY_METRIC]
    order_perf = order_metrics[ACCURACY_METRIC]

    outcome = TMExperimentResult(
        config=config,
        expanded_data_config=expanded,
        in_dist_performance=in_dist,
        quant_mismatch_performance=quant_perf,
        order_mismatch_performance=order_perf,
    )
    persistence.save_result(run_name, outcome)
    return outcome
