from dataclasses import dataclass
from typing import Any, Optional, cast
from copy import deepcopy

import torch
import pandas as pd
import pytorch_lightning as pl

from vis_models.architectures import ModelConfig, create_model
from utils import persistence
# from utils.models import create_model
from utils.training import (
    TrainingConfig, training_experiment, fine_tuning_experiment
)
from utils.eval import ACCURACY_METRIC
from .data import (
    IFEDataConfig,
    create_irrelevant_features_dataset,
    create_cifar_only_dataset,
    CIFAR_ONLY_KEY,
    OBJECTS_ONLY_KEY,
    MIXED_CIFAR_KEY,
    MIXED_OBJECTS_KEY,
)
from .invariance_estimation import (
    compute_dataset_representation_distances,
)


# Module-level identifier for this experiment family.
# NOTE(review): not referenced by the functions in this module — presumably
# used by external callers/scripts; confirm before renaming or removing.
EXP_NAME = "irrelevant_feature_extraction"

@dataclass
class IFEExperimentConfig:
    """Configuration for one irrelevant-feature-extraction experiment run."""
    # Base experiment name components; ife_experiment combines them with the
    # data config/sampling seeds via persistence.get_experiment_name.
    exp_name: list[str]
    # Settings for training the per-dataset models (and the CIFAR-only model).
    training: TrainingConfig
    # Settings for fine-tuning the trained models on the transfer datasets.
    fine_tuning: TrainingConfig
    # Dataset configuration; its batch_size, seeds and n_classes are read by
    # the experiment code in this module.
    data: IFEDataConfig
    # Base model configuration; deep-copied before each model is created.
    model: ModelConfig
    # Transform names; the whole experiment is repeated once per transform.
    transforms: list[str]
    # Forwarded as the first argument of compute_dataset_representation_distances.
    eval_rep_dist: bool = True
    # Forwarded as rep_dist_samples to compute_dataset_representation_distances.
    rep_dist_samples: int = 2048
    # When False, the CIFAR-only model is never trained here (training is
    # disabled for it); when True and training.train is set, it is trained
    # once before iterating over the transforms.
    train_cifar_only: bool = True

@dataclass
class IFEExperimentResult:
    """Results saved by ife_experiment for a single transform."""
    # The configuration that produced this result.
    config: IFEExperimentConfig 
    # Objects reported by the irrelevant-features dataset for this transform;
    # presumably object class ids — TODO confirm.
    objects: list[int]
    # Training accuracy metrics: the CIFAR-only model's rows concatenated with
    # those of the per-dataset models.
    in_dist_performance: pd.DataFrame
    # Accuracy metrics after fine-tuning on the transfer datasets.
    transfer_performance: pd.DataFrame
    # Linear-CKA representation distances between dataset pairs; None until
    # they have been computed.
    cka_rep_dists: Optional[pd.DataFrame]


# Names of the two dataset pairs compared with representation distances in
# ife_experiment (presumably they become columns of the CKA result):
# - CIFAR_VS_NO_CIFAR_COL: objects-only data vs. mixed-objects data.
# - PATCH_VS_NO_PATCH_COL: CIFAR-only data vs. the "0"-suffixed mixed-CIFAR data.
CIFAR_VS_NO_CIFAR_COL = "cifar_vs_no_cifar"
PATCH_VS_NO_PATCH_COL = "patch_vs_no_patch"

def ife_experiment(config: IFEExperimentConfig) -> None:
    """Run the irrelevant-feature-extraction experiment for every transform.

    For each name in ``config.transforms`` this:

    1. builds the irrelevant-features datasets for that transform and trains
       one model per dataset key,
    2. fine-tunes those models, together with the CIFAR-only model, on a fixed
       set of transfer datasets,
    3. saves an ``IFEExperimentResult`` holding the training and transfer
       accuracies, and
    4. computes linear-CKA representation distances between two dataset pairs
       and saves the result again with ``cka_rep_dists`` filled in.

    Args:
        config: Full experiment configuration. It is not mutated; model
            configs are deep-copied before use.
    """
    exp_name = persistence.get_experiment_name(
        config.exp_name, config.data.config_seed, config.data.sampling_seed
    )
    print("exp name:", exp_name)

    # Only train the CIFAR-only model up front when this run is allowed to
    # train it; otherwise it is fetched (with training disabled) inside the
    # loop below.
    should_train_cifar_only = config.train_cifar_only and config.training.train
    if should_train_cifar_only:
        cifar_only_data, cifar_only_model, cifar_only_accuracy = (
            _get_cifar_only_model(config, exp_name)
        )
    else:
        cifar_only_data, cifar_only_model, cifar_only_accuracy = (
            None, None, None
        )

    # Mixed-CIFAR dataset with the "0" suffix (the no-correlation variant,
    # per its name); used for transfer and representation distances.
    cifar_no_correlation = f"{MIXED_CIFAR_KEY}0"

    for transform_name in config.transforms:
        transform_exp_name = [*exp_name, transform_name]

        irrelevant_features_data = create_irrelevant_features_dataset(
            config.data,
            transform=transform_name,
        )
        training_datasets = irrelevant_features_data.data

        models = {}
        dataset_classes = {CIFAR_ONLY_KEY: 10}
        for data_name in irrelevant_features_data.data_keys:
            # NOTE(review): transfer lookups use the suffixed key
            # f"{MIXED_CIFAR_KEY}0", so mixed-CIFAR data keys may never equal
            # MIXED_CIFAR_KEY exactly — confirm against
            # create_irrelevant_features_dataset.
            if data_name == MIXED_CIFAR_KEY:
                num_classes = 10
            else:
                num_classes = config.data.n_classes
            model_config = deepcopy(config.model)
            # Size each model for its dataset, mirroring _get_cifar_only_model;
            # without this the computed class count never reaches the model.
            model_config.num_classes = num_classes
            models[f"m_{data_name}"] = create_model(model_config)
            dataset_classes[data_name] = num_classes

        training_res = training_experiment(
            [*transform_exp_name, "training"],
            config.training,
            models,
            training_datasets,
        )

        if not should_train_cifar_only:
            # Putting this here again is a trick to delay loading the
            # cifar-only model until after the other models have been
            # trained. This way, if multiple versions of this script
            # run, only one of them has to train that model.
            cifar_only_data, cifar_only_model, cifar_only_accuracy = (
                _get_cifar_only_model(config, exp_name)
            )
        transfer_datasets = {
            CIFAR_ONLY_KEY: cifar_only_data,
            cifar_no_correlation: training_datasets[cifar_no_correlation],
            MIXED_OBJECTS_KEY: training_datasets[MIXED_OBJECTS_KEY],
            OBJECTS_ONLY_KEY: training_datasets[OBJECTS_ONLY_KEY],
        }
        pretrained_models = {
            f"m_{CIFAR_ONLY_KEY}": cifar_only_model,
            **training_res.models,
        }
        fine_tune_performance_res = fine_tuning_experiment(
            [*transform_exp_name, "fine_tuning"],
            config.fine_tuning,
            config.model,
            pretrained_models,
            transfer_datasets,
            # TODO: fix this
            dataset_classes=cast(Any, dataset_classes),
        )
        training_accuracies = pd.concat([
            cast(pd.DataFrame, cifar_only_accuracy),
            training_res.metrics[ACCURACY_METRIC],
        ])
        result = IFEExperimentResult(
            config=config,
            objects=irrelevant_features_data.objects,
            in_dist_performance=training_accuracies,
            transfer_performance=fine_tune_performance_res[ACCURACY_METRIC],
            cka_rep_dists=None,
        )
        # Save before the representation-distance step so the accuracy
        # results are persisted even if that step fails.
        persistence.save_result(
            transform_exp_name,
            result,
        )

        rep_dist_datasets = {
            PATCH_VS_NO_PATCH_COL: (
                transfer_datasets[CIFAR_ONLY_KEY],
                transfer_datasets[cifar_no_correlation],
            ),
            CIFAR_VS_NO_CIFAR_COL: (
                transfer_datasets[OBJECTS_ONLY_KEY],
                transfer_datasets[MIXED_OBJECTS_KEY],
            ),
        }
        cka_rep_dists = compute_dataset_representation_distances(
            config.eval_rep_dist,
            [*transform_exp_name, "rep_dist_eval"],
            # We can use the non-fine-tuned models here because we're only
            # interested in the penultimate layer representations
            pretrained_models,
            config.model,
            rep_dist_datasets,
            dist_metric="linear_cka",
            rep_dist_samples=config.rep_dist_samples,
        )

        # Overwrite the earlier saved result with the distances attached.
        result.cka_rep_dists = cka_rep_dists
        persistence.save_result(
            transform_exp_name,
            result,
        )

def _get_cifar_only_model(
    config: IFEExperimentConfig,
    exp_name: list[str]
) -> tuple[pl.LightningDataModule, torch.nn.Module, pd.DataFrame]:
    """Build the CIFAR-only dataset and train (or fetch) its 10-class model.

    Training is enabled only when both ``config.training.train`` and
    ``config.train_cifar_only`` are set; otherwise ``training_experiment`` is
    invoked with training disabled.

    Args:
        config: Experiment configuration. Its model and training configs are
            deep-copied, so ``config`` itself is not mutated.
        exp_name: Base experiment name; ``"cifar_only_training"`` is appended.

    Returns:
        ``(dataset, model, accuracy)``: the CIFAR-only data module, the model
        returned by ``training_experiment``, and its accuracy metrics.
    """
    dataset = create_cifar_only_dataset(config.data.batch_size)

    net_config = deepcopy(config.model)
    net_config.num_classes = 10  # the CIFAR-only data uses 10 classes
    untrained_net = create_model(net_config)

    train_cfg = deepcopy(config.training)
    train_cfg.train = train_cfg.train & config.train_cifar_only

    model_key = f"m_{CIFAR_ONLY_KEY}"
    outcome = training_experiment(
        [*exp_name, "cifar_only_training"],
        train_cfg,
        {model_key: untrained_net},
        {CIFAR_ONLY_KEY: dataset},
    )

    trained_net = outcome.models[model_key]
    accuracy = outcome.metrics[ACCURACY_METRIC]
    return dataset, trained_net, accuracy
