from dataclasses import dataclass, field

import pandas as pd
import pytorch_lightning as pl

# from vis_models.training import adversarial
from vis_models.architectures import ModelConfig, create_model
from utils import persistence
from utils.training import (
    TrainingConfig, training_experiment, fine_tuning_experiment
)
from utils.eval import ACCURACY_METRIC
from .data import CTDataConfig, create_transforms_datasets, get_random_configs


# Identifier for this experiment family.
# NOTE(review): not referenced anywhere in the visible part of this module —
# presumably consumed by importers; confirm before renaming or removing.
EXP_NAME = "cross_transforms"

@dataclass
class CTExperimentConfig:
    """Bundle of all settings consumed by ``ct_experiment``."""

    # Leading components of the experiment name; ``ct_experiment`` appends a
    # seed-derived component to these.
    exp_name: list[str]
    # Passed to ``training_experiment``.
    training: TrainingConfig
    # Passed to ``fine_tuning_experiment``.
    fine_tuning: TrainingConfig
    # Passed to ``create_transforms_datasets``; ``ct_experiment`` also reads
    # its ``config_seed``, ``sampling_seed`` and ``n_classes`` attributes.
    data: CTDataConfig
    # Passed to ``create_model``. ``ct_experiment`` overwrites its
    # ``num_classes`` attribute with ``data.n_classes`` before use.
    model: ModelConfig

@dataclass
class CTExperimentResult:
    """Outcome of a single cross-transforms experiment run.

    After construction, ``cross_transforms_performance`` holds the supplied
    frame reindexed so that both its row labels and its column labels are in
    ascending ``sorted()`` order. ``in_dist_performance`` is stored exactly
    as given.
    """

    # Configuration the experiment was run with.
    config: CTExperimentConfig
    # ``ct_experiment`` fills these two from the ``transforms`` and
    # ``objects`` attributes of the created datasets object.
    transforms: list[str]
    objects: list[int]
    # ``ct_experiment`` fills this from the training metrics.
    in_dist_performance: pd.DataFrame
    # ``ct_experiment`` fills this from the fine-tuning metrics.
    cross_transforms_performance: pd.DataFrame

    def __post_init__(self) -> None:
        """Put the cross-transforms table into sorted row/column order."""
        perf = self.cross_transforms_performance
        self.cross_transforms_performance = perf.reindex(
            index=sorted(perf.index),
            columns=sorted(perf.columns),
        )


def ct_experiment(
    config: CTExperimentConfig,
) -> CTExperimentResult:
    """Run one cross-transforms experiment end to end and persist its result.

    Steps, in order: build the datasets from ``config.data``, create one
    model per dataset key, run ``training_experiment`` on those models, hand
    the trained models to ``fine_tuning_experiment``, then save and return
    the accuracy metrics from both stages.

    Args:
        config: Full experiment configuration. ``config.model.num_classes``
            is overwritten in place with ``config.data.n_classes``.

    Returns:
        The ``CTExperimentResult`` that was passed to
        ``persistence.save_result``.
    """
    # The experiment name is the configured prefix plus a seed-derived tag.
    seed_tag = f"cs_{config.data.config_seed}_ss_{config.data.sampling_seed}"
    exp_name = config.exp_name + [seed_tag]
    print("exp name:", exp_name)

    ct_data = create_transforms_datasets(config.data)
    data_modules: dict[str, pl.LightningDataModule] = ct_data.data

    # ``model_config`` aliases ``config.model``; the assignment below is
    # therefore visible to the caller and in the persisted result's config.
    model_config = config.model
    model_config.num_classes = config.data.n_classes

    # One freshly created model per dataset key, named "m_<key>".
    models = {}
    for transform_name in ct_data.data_keys:
        models[f"m_{transform_name}"] = create_model(model_config)

    # NOTE: disabled sketch of adversarially trained variants, kept for
    # reference. It is stale: it reads ``config.model_type``, which
    # ``CTExperimentConfig`` does not define, and the ``adversarial`` import
    # at the top of the file is commented out. When enabled,
    # ``training_models`` was passed to ``training_experiment`` in place of
    # ``models``.
    # attack_configs = [
    #     ("l2", adversarial.AdversarialTrainingConfig("2", 0.5, 0.01, 10)),
    #     ("linf", adversarial.AdversarialTrainingConfig("inf", 0.05, 0.01, 10)),
    # ]
    # training_models = models.copy()
    # for attack_name, attack_conf in attack_configs:
    #     training_models[f"m_{attack_name}"] = adversarial.AdversarialTraining(
    #         model=create_model(
    #             config.model_type, config.data.n_classes, wrap_call=False,
    #         ),
    #         # TODO: use a proper, computed config here
    #         dataset_stats=adversarial.DatasetStats(
    #             torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]),
    #         ),
    #         config=attack_conf,
    #     )

    trained = training_experiment(
        exp_name + ["training"],
        config.training,
        models,
        data_modules,
    )
    fine_tuned_metrics = fine_tuning_experiment(
        exp_name + ["fine_tuning"],
        config.fine_tuning,
        model_config,
        trained.models,
        data_modules,
    )

    # Only the accuracy metric of each stage is kept in the result.
    in_dist_accuracy = trained.metrics[ACCURACY_METRIC]
    cross_accuracy = fine_tuned_metrics[ACCURACY_METRIC]
    result = CTExperimentResult(
        config=config,
        transforms=ct_data.transforms,
        objects=ct_data.objects,
        in_dist_performance=in_dist_accuracy,
        cross_transforms_performance=cross_accuracy,
    )
    persistence.save_result(exp_name, result)
    return result
