"""Out-of-distribution Detection Experiment."""

import copy
import sys
from pathlib import Path

# Prepend relative directories to the module search path before the imports
# below run (presumably so that the local `datasets` and `custom_models`
# modules can be found -- TODO confirm which directory provides which module).
# The entries are relative paths, so they resolve against the directory the
# script is launched from. Each insert goes to position 0, so the last line
# ends up with the highest lookup priority.
sys.path.insert(0, "../../experiments")
sys.path.insert(0, "analysis/experiments")
sys.path.insert(0, "../ood_detection")

import datasets
import fire
import inferno
import lightning as L
import matplotlib.pyplot as plt
import custom_models as models
import torch
import wandb
from lightning.pytorch.tuner import Tuner


def _resolve_attribute(namespace, dotted_name: str):
    """Look up ``dotted_name`` (e.g. ``"Foo"`` or ``"sub.Foo"``) on ``namespace``.

    Used instead of ``eval`` so that names supplied on the command line can only
    select existing attributes and can never execute arbitrary expressions.

    Raises:
        AttributeError: If any component of ``dotted_name`` does not exist.
    """
    target = namespace
    for part in dotted_name.split("."):
        target = getattr(target, part)
    return target


def train(
    seed: int,
    dataset: str,
    ood_dataset: str,
    model: str,
    hidden_size: int,
    parametrization: str,
    max_epochs: int,
    lr: float,
    batch_size: int,
    batch_size_test: int | None = 1024,
    bias: bool = True,
    momentum: float = 0.0,
    nesterov: bool = False,
    weight_decay: float = 0.0,
    precision: str = "32-true",
    pin_memory: bool = True,
    num_workers: int = 0,
    tune_lr: bool = False,
    tune_batch_size: bool = False,
    root_dir: str | Path = Path.cwd(),
    data_dir: str | Path = Path.cwd() / "../datasets",
    log_to_wandb: bool = True,
    log_to_csv: bool = True,
    experiment_name: str = "hyperparameter_transfer",
    project_name: str = "implicit_vi",
    scale_mean_input_init_weight: float = 1.0,
    scale_mean_input_init_bias: float = 1.0,
    scale_mean_input_lr_weight: float = 1.0,
    scale_mean_input_lr_bias: float = 1.0,
    scale_mean_input_forward_weight: float = 1.0,
    scale_mean_input_forward_bias: float = 1.0,
    scale_mean_output_init_weight: float = 1.0,
    scale_mean_output_init_bias: float = 1.0,
    scale_mean_output_lr_weight: float = 1.0,
    scale_mean_output_lr_bias: float = 1.0,
    scale_mean_output_forward_weight: float = 1.0,
    scale_mean_output_forward_bias: float = 1.0,
    scale_cov_input_init_weight: float = 1.0,
    scale_cov_input_init_bias: float = 1.0,
    scale_cov_input_lr_weight: float = 1.0,
    scale_cov_input_lr_bias: float = 1.0,
    scale_cov_input_forward_weight: float = 1.0,
    scale_cov_input_forward_bias: float = 1.0,
    scale_cov_output_init_weight: float = 1.0,
    scale_cov_output_init_bias: float = 1.0,
    scale_cov_output_lr_weight: float = 1.0,
    scale_cov_output_lr_bias: float = 1.0,
    scale_cov_output_forward_weight: float = 1.0,
    scale_cov_output_forward_bias: float = 1.0,
) -> None:
    """Train a model for out-of-distribution detection.

    The model is validated once before training, fitted, and finally tested on
    three dataloaders: the in-distribution test set, the out-of-distribution
    test set, and the in-distribution validation set (in that order).

    Args:
        seed: Seeds the global random state, the data generator (with a fixed
            offset) and is forwarded to the model.
        dataset: Name of the in-distribution dataset class in the ``datasets``
            module.
        ood_dataset: Name of the out-of-distribution dataset class in the
            ``datasets`` module.
        model: Name of the model class in the ``custom_models`` module; it is
            built through its ``from_dataset`` constructor.
        hidden_size: Width used for both hidden layers
            (``hidden_sizes=[hidden_size, hidden_size]``).
        parametrization: Name of a class in ``inferno.bnn.params``; instantiated
            without arguments and forwarded to the model.
        max_epochs: Forwarded to both the model and the trainer.
        lr: Learning rate forwarded to the model; also part of the log directory
            name.
        batch_size: Forwarded to both dataset constructors.
        batch_size_test: Forwarded to both dataset constructors.
        bias: Forwarded to the model.
        momentum: Forwarded to the model.
        nesterov: Forwarded to the model.
        weight_decay: Forwarded to the model.
        precision: Precision setting of the trainer.
        pin_memory: Forwarded to both dataset constructors.
        num_workers: Forwarded to both dataset constructors.
        tune_lr: Run the learning rate finder before training. The plot is
            saved when ``log_to_csv`` is set, and the results and the suggested
            rate are logged when ``log_to_wandb`` is set.
        tune_batch_size: Run the batch size scaler (``mode="power"``) before
            training; the result is logged when ``log_to_wandb`` is set.
        root_dir: Save directory of the loggers; also forwarded to the model.
        data_dir: Forwarded to both dataset constructors.
        log_to_wandb: Attach a Weights & Biases logger.
        log_to_csv: Attach a CSV logger.
        experiment_name: Weights & Biases group.
        project_name: Weights & Biases project.
        scale_mean_input_init_weight: This and all remaining ``scale_*``
            arguments are forwarded unchanged, under the same name, to the
            model's ``from_dataset`` constructor.

    Raises:
        AttributeError: If ``dataset``, ``ood_dataset``, ``model`` or
            ``parametrization`` does not name an existing attribute of its
            module.
    """
    # More appropriate local variable names
    model_name = model
    dataset_name = dataset
    ood_dataset_name = ood_dataset

    # Random state
    L.seed_everything(seed)

    # Datasets
    # Both dataset constructors receive identical arguments, including the very
    # same generator object.
    generator_data = torch.Generator().manual_seed(seed + 235347894)
    datamodule_kwargs = dict(
        batch_size=batch_size,
        batch_size_test=batch_size_test,
        data_dir=data_dir,
        num_workers=num_workers,
        pin_memory=pin_memory,
        generator=generator_data,
        data_augmentation_transform=None,
    )
    # Names come from the command line: resolve them by attribute lookup rather
    # than ``eval`` so that they cannot execute arbitrary code.
    dataset = _resolve_attribute(datasets, dataset_name)(**datamodule_kwargs)
    ood_dataset = datasets.OODDataset(
        id_dataset=copy.deepcopy(dataset),
        ood_dataset=_resolve_attribute(datasets, ood_dataset_name)(
            **datamodule_kwargs
        ),
    )
    ood_dataset.prepare_data()

    # Model
    model = _resolve_attribute(models, model_name).from_dataset(
        dataset=dataset,
        parametrization=_resolve_attribute(inferno.bnn.params, parametrization)(),
        lr=lr,
        momentum=momentum,
        nesterov=nesterov,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        seed=seed,
        root_dir=root_dir,
        hidden_sizes=[hidden_size, hidden_size],
        bias=bias,
        scale_mean_input_init_weight=scale_mean_input_init_weight,
        scale_mean_input_init_bias=scale_mean_input_init_bias,
        scale_mean_input_lr_weight=scale_mean_input_lr_weight,
        scale_mean_input_lr_bias=scale_mean_input_lr_bias,
        scale_mean_input_forward_weight=scale_mean_input_forward_weight,
        scale_mean_input_forward_bias=scale_mean_input_forward_bias,
        scale_mean_output_init_weight=scale_mean_output_init_weight,
        scale_mean_output_init_bias=scale_mean_output_init_bias,
        scale_mean_output_lr_weight=scale_mean_output_lr_weight,
        scale_mean_output_lr_bias=scale_mean_output_lr_bias,
        scale_mean_output_forward_weight=scale_mean_output_forward_weight,
        scale_mean_output_forward_bias=scale_mean_output_forward_bias,
        scale_cov_input_init_weight=scale_cov_input_init_weight,
        scale_cov_input_init_bias=scale_cov_input_init_bias,
        scale_cov_input_lr_weight=scale_cov_input_lr_weight,
        scale_cov_input_lr_bias=scale_cov_input_lr_bias,
        scale_cov_input_forward_weight=scale_cov_input_forward_weight,
        scale_cov_input_forward_bias=scale_cov_input_forward_bias,
        scale_cov_output_init_weight=scale_cov_output_init_weight,
        scale_cov_output_init_bias=scale_cov_output_init_bias,
        scale_cov_output_lr_weight=scale_cov_output_lr_weight,
        scale_cov_output_lr_bias=scale_cov_output_lr_bias,
        scale_cov_output_forward_weight=scale_cov_output_forward_weight,
        scale_cov_output_forward_bias=scale_cov_output_forward_bias,
    )

    # Logging
    loggers = {}
    logging_dir = (
        Path("training_logs")
        / Path(dataset.__class__.__name__)
        / model_name
        / parametrization
        / Path(f"hidden_size_{hidden_size}")
        / Path(f"lr_{lr}")
        / Path(f"seed_{seed}")
    )
    if log_to_csv:
        loggers["csv"] = L.pytorch.loggers.CSVLogger(
            save_dir=root_dir, name=logging_dir
        )

    if log_to_wandb:
        loggers["wandb"] = L.pytorch.loggers.WandbLogger(
            save_dir=root_dir,
            project=project_name,
            log_model=False,
            group=experiment_name,
            config={
                "OOD Dataset": ood_dataset_name,
                "Seed": seed,
            },
        )

    # Trainer setup
    accelerator = "auto"
    if "Laplace" in model_name:  # Laplace is not compatible with "mps" accelerator
        accelerator = "gpu" if torch.cuda.is_available() else "cpu"

    trainer = L.Trainer(
        accelerator=accelerator,  # Select accelerator ("cpu", "gpu", "mps", etc.)
        devices="auto",
        precision=precision,
        # Hand over a concrete list rather than a live dict view.
        logger=list(loggers.values()),
        callbacks=[
            L.pytorch.callbacks.ModelSummary(max_depth=-1),
        ],
        max_epochs=max_epochs,
        # NOTE(review): unlike the loggers, this path is not prefixed with
        # ``root_dir`` and is therefore relative to the working directory --
        # confirm that this is intended.
        default_root_dir=logging_dir,
    )
    tuner = Tuner(trainer)

    if tune_batch_size:
        # Select batch size based on available memory
        new_batch_size = tuner.scale_batch_size(model, datamodule=dataset, mode="power")

    if tune_lr:
        # Select initial learning rate
        lr_finder = tuner.lr_find(model, datamodule=dataset)

        lr_finder.plot(suggest=True, show=False)
        if log_to_csv:
            lr_plot_path = Path(root_dir) / logging_dir / "lr_finder.png"
            # Nothing guarantees that the log directory exists at this point;
            # create it so that saving the figure cannot fail on a missing path.
            lr_plot_path.parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(lr_plot_path)

    # Training
    dataset.setup(stage="fit")
    trainer.validate(model, datamodule=dataset)  # Measure performance prior to training
    trainer.fit(model, datamodule=dataset)

    # Test
    dataset.setup(stage="test")
    ood_dataset.setup(stage="test")
    trainer.test(
        model,
        (
            dataset.test_dataloader(),
            ood_dataset.test_dataloader(),
            dataset.val_dataloader(),
        ),
    )

    # Update logging configuration
    if log_to_wandb:
        if tune_lr:
            # ``rate`` (not ``lr``) to avoid shadowing the ``lr`` argument.
            lr_finder_plot_data = [
                [rate, loss]
                for rate, loss in zip(
                    lr_finder.results["lr"], lr_finder.results["loss"]
                )
            ]
            lr_finder_table = wandb.Table(
                data=lr_finder_plot_data,
                columns=["Learning Rate", "Loss"],
            )

            wandb.log(
                {
                    "lr_finder": wandb.plot.line(
                        lr_finder_table, "Learning Rate", "Loss", title="LR Finder"
                    )
                }
            )
            wandb.config.update({"lr": lr_finder.suggestion()}, allow_val_change=True)
        if tune_batch_size:
            wandb.config.update({"batch_size": new_batch_size}, allow_val_change=True)


# Script entry point: expose `train` as a command-line interface via
# python-fire, so that its parameters can be supplied as command-line arguments.
if __name__ == "__main__":
    fire.Fire(train)
