import json
import os

from numpy import mean, sqrt
import torch
from torch.utils.data import ConcatDataset
from sacred import Experiment
from sacred.observers import FileStorageObserver
import os

# Presumably enables PyTorch's CPU fallback for operators that the Apple MPS
# backend does not implement (see PyTorch's MPS environment-variable docs).
# NOTE(review): torch is already imported above; confirm the variable is still
# honoured when it is set after the import.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from models.diffusion_model import ConditionalDiffusionModel
from models.training import (
    evaluate_full_diffusion_loss,
    train_diffusion,
)

from torch.utils.data import DataLoader, TensorDataset, random_split

from data_generation.data_from_dict import load_data
from models.diffusion_model import (
    ConditionalDiffusionModel,
)
from models.diffusion_model_resnet import ConditionalResnetDiffusionModel
from models.model_utils import save_checkpoint, save_checkpoint_if_validation_improved
from train_large_diffusion import load_model

# Directory where the file-storage observer writes each run's records.
_RESULTS_DIR = "results/bidd"

# Single Sacred experiment for this script, observed on the local filesystem.
ex = Experiment("diffusion_discovery")
ex.observers.append(FileStorageObserver(_RESULTS_DIR))


@ex.config
def config():
    """Default Sacred configuration for the diffusion discovery experiment.

    Each local below is a config entry; ``main`` receives the entries whose
    names match its parameters. ``main`` also accepts ``save_last_model``,
    which is not declared here, so main's own default (False) applies unless
    it is supplied some other way.
    """
    # Number of training epochs passed to train_diffusion (also stored in the
    # saved last-model checkpoint).
    epochs = 200
    # Batch size for training and for every evaluation DataLoader.
    # NOTE(review): 4048 may be a typo for 4096 - confirm.
    batch_size = 4048 * 100
    # Not read by main() in this file.
    verbose = False
    # Architecture, loss and optimizer settings for the model constructors
    # and train_diffusion.
    training_hyperparameters = {
        # Forwarded to the model as ``beta_schedule_args``.
        "beta_schedule": {
            "type": "linear",
            "min": 0.0001,
            "max": 0.02,
            "timesteps": 256,
        },
        # Only passed to the "resnet" model.
        "condition_embed": True,
        # Passed to both model types.
        "condition_dim": 4,
        # Training loss for train_diffusion.
        "train_loss": "mse",
        # Loss given to train_diffusion and used by every post-training evaluation.
        "test_loss": "mse_and_hsic",
        # Loss weighting schedule; interpreted by train_diffusion.
        "train_loss_schedule": {
            "min_beta": 0,
            "max_beta": 0,
            "start_beta": 0,
            "type": "constant",  # constant, cyclical(M, R), linear, adaptive, delayed_linear(M, R)
            "mi_threshold": 0.001,
            "adaptive_factor": 1.0001,
            "M": 1,
            "R": 0.1,
            "hold_fraction": 0.5,
        },
        # Used for "plain"
        "layer_sizes": [128, 64],
        # Used for "resnet"
        "hidden_dim": 512,
        "num_blocks": 2,
        # Passed to train_diffusion as ``lr``.
        "learning_rate": 1e-4,
        "optimizer_settings": {
            "optimizer": "adamw",  # adam, adamw,
            "scheduler": "cosine",  # cosine, none
            "scheduler_args": {
                "eta_min": 1e-5,
            },
        },
    }

    # Passed unchanged to load_data; its schema is defined there (not visible
    # in this file), so the nested keys below are interpreted by load_data.
    data_config = {
        "file_path": None,
        "name": None,
        "id": None,
        "n_samples": 1000,
        "dictionary": {
            "X": {
                "type": "uniform",
                "length": 1500,
            },
            "transformation": {
                "type": "neural_network",
                "args": {
                    "num_hidden": 10,
                    "num_parents": 1,
                    # "alpha": 1,
                },
            },
            "shape": "one_mediator",
            "depth": 2,
            "seed": 3,
            "standardize": True,
            "noise_type": "uniform",
            "noise_parameters": None,
            # NOTE(review): a "uniform" noise type parameterised by mean/std -
            # confirm load_data reads these keys for uniform noise.
            "mediator_noise_type": "uniform",
            "mediator_noise_parameters": {
                "mean": 0,
                # std = sqrt(0.5), i.e. variance 0.5.
                "std": sqrt(0.5),
            },
        },
    }
    # Optional pretrained checkpoint. When set, main() loads it with load_model
    # (using the config.json in the same directory) instead of building a
    # fresh model, and model_type is ignored.
    model_path = None  # e.g. "results/big/19/last_model_y_given_x.pt"
    # "plain" -> ConditionalDiffusionModel, "resnet" -> ConditionalResnetDiffusionModel.
    model_type: str = "resnet"
    # Checkpoint the best-validation model during training and also report its
    # test loss afterwards.
    use_best_validation = False
    # Passed to train_diffusion.
    validation_split = 0.0
    # Fraction of samples used for training; the remainder is the test set.
    train_test_split = 0.8
    # Save validation/evaluation predictions and register them as run artifacts.
    store_prediction = False
    # How many times the test set is repeated for each test evaluation.
    test_evaluation_passes = 10
    # Evaluate on random test subsets of size 5, 10, 50, 100 and 200 instead of
    # on the whole test set.
    different_test_sets = False


# Sizes of the random test subsets evaluated when ``different_test_sets`` is set.
_TEST_SUBSET_SIZES = (5, 10, 50, 100, 200)


def _build_model(model_type, training_hyperparameters, model_path):
    """Return a fresh diffusion model, or a pretrained one loaded from ``model_path``.

    Args:
        model_type: ``"plain"`` or ``"resnet"``; ignored when ``model_path`` is set.
        training_hyperparameters: Config dict providing the architecture settings.
        model_path: Optional checkpoint path; the ``config.json`` stored next to
            it is read and forwarded to ``load_model``.

    Raises:
        ValueError: If ``model_path`` is None and ``model_type`` is unknown.
    """
    if model_path is not None:
        # os.path.join keeps a bare filename relative to the working directory
        # instead of producing the absolute path "/config.json".
        config_path = os.path.join(os.path.dirname(model_path), "config.json")
        with open(config_path, "r") as f:
            model_config = json.load(f)
        model, _, _ = load_model(model_path, config=model_config)
        print(f"Loaded model from {model_path}")
        return model

    if model_type == "plain":
        return ConditionalDiffusionModel(
            input_dim=1,
            layer_sizes=training_hyperparameters["layer_sizes"],
            condition_dim=training_hyperparameters["condition_dim"],
            beta_schedule_args=training_hyperparameters["beta_schedule"],
        )
    if model_type == "resnet":
        return ConditionalResnetDiffusionModel(
            input_dim=1,
            hidden_dim=training_hyperparameters["hidden_dim"],
            num_blocks=training_hyperparameters["num_blocks"],
            condition_dim=training_hyperparameters["condition_dim"],
            beta_schedule_args=training_hyperparameters["beta_schedule"],
            condition_embed=training_hyperparameters["condition_embed"],
        )
    raise ValueError(f"Invalid model type: {model_type!r}")


def _build_callbacks(
    run,
    model_dir_path,
    best_model_path,
    direction,
    use_best_validation,
    store_prediction,
):
    """Return the callback dict handed to ``train_diffusion``.

    ``epoch_end`` (only with ``use_best_validation``) delegates to
    ``save_checkpoint_if_validation_improved``. ``store_validation_prediction``
    is a no-op unless ``store_prediction`` is set, in which case it saves the
    epoch's validation predictions and registers them as a run artifact.
    ``direction`` is bound as an argument, so the callbacks never depend on a
    loop variable of the caller.
    """
    callbacks = {}
    if use_best_validation:
        callbacks["epoch_end"] = (
            lambda epoch, model, metrics: save_checkpoint_if_validation_improved(
                epoch,
                model,
                metrics,
                model_dir_path,
                best_model_path,
            )
        )

    callbacks["store_validation_prediction"] = lambda epoch, data: None
    if store_prediction:

        def store_validation_prediction(epoch, data):
            path = os.path.join(
                model_dir_path,
                f"epoch_{epoch}_predictions_{direction}.pt",
            )
            torch.save(data, path)
            run.add_artifact(path)

        callbacks["store_validation_prediction"] = store_validation_prediction
    return callbacks


def _evaluate_repeated(model, test_loss, dataset, passes, batch_size, **eval_kwargs):
    """Evaluate ``model`` on ``dataset`` concatenated ``passes`` times.

    Presumably the repetition reduces the variance of the stochastic diffusion
    loss. Extra keyword arguments are forwarded to
    ``evaluate_full_diffusion_loss``.
    """
    repeated = ConcatDataset([dataset] * passes)
    loader = DataLoader(repeated, batch_size, shuffle=False)
    return evaluate_full_diffusion_loss(model, test_loss, loader, **eval_kwargs)


def _draw_test_subsets(test_set):
    """Draw one random subset of ``test_set`` per size in ``_TEST_SUBSET_SIZES``.

    Sizes larger than the test set are reported and skipped; ``random_split``
    would otherwise raise on the negative remainder length.
    """
    subsets = {}
    for size in _TEST_SUBSET_SIZES:
        if len(test_set) < size:
            print(f"Test set is too small for size {size}")
            continue
        subsets[size], _ = random_split(test_set, [size, len(test_set) - size])
    return subsets


def _evaluate_test_subsets(model, test_loss, subsets, passes, batch_size, key_prefix):
    """Evaluate ``model`` on every subset; return ``{f"{key_prefix}_{size}": loss}``."""
    evaluations = {}
    for size, subset in subsets.items():
        evaluation = _evaluate_repeated(model, test_loss, subset, passes, batch_size)
        print(f"Test loss for size {len(subset) * passes}: {evaluation}")
        evaluations[f"{key_prefix}_{size}"] = evaluation
    return evaluations


@ex.automain
def main(
    _run,
    data_config,
    batch_size,
    epochs,
    training_hyperparameters,
    train_test_split,
    model_path,
    use_best_validation,
    validation_split,
    model_type: str = "resnet",
    store_prediction: bool = False,
    save_last_model: bool = False,
    test_evaluation_passes: int = 5,
    different_test_sets: bool = False,
):
    """Train and evaluate one diffusion model per conditioning direction.

    All arguments come from the Sacred config (``_run`` is the active run).
    The data are split once into train/test. Then, for "x_given_y" (column 0
    given column 1) and "y_given_x" (column 1 given column 0), a model is
    built or loaded and trained with ``train_diffusion``. It is then evaluated
    on the full data, on the repeated test set (or on random test subsets when
    ``different_test_sets`` is set) and, with ``use_best_validation``, again
    after reloading the best-validation checkpoint.

    Returns:
        ``{"results_recovering_x_given_y": ..., "results_recovering_y_given_x": ...}``,
        each value being the dict returned by ``train_diffusion`` extended with
        the evaluation results.
    """
    data = torch.from_numpy(load_data(data_config))

    print(f"Data config: {data_config}")
    print(f"Training hyperparemeters: {training_hyperparameters}")

    train_split, test_split = random_split(
        range(len(data)), [train_test_split, 1 - train_test_split]
    )
    # Index the tensor with the Subsets' plain index lists, not the Subset objects.
    train_data = data[train_split.indices]
    test_data = data[test_split.indices]

    test_loss = training_hyperparameters["test_loss"]
    has_test_set = train_test_split < 1.0

    model_dir_path = "saved_models/unnamed"
    # Checkpoints and predictions are saved here; torch.save does not create it.
    os.makedirs(model_dir_path, exist_ok=True)

    results = []
    # (direction, target column, conditioning column)
    for direction, target, condition in (
        ("x_given_y", 0, 1),
        ("y_given_x", 1, 0),
    ):
        train_set = TensorDataset(train_data[:, target], train_data[:, condition])
        test_set = TensorDataset(test_data[:, target], test_data[:, condition])
        # One best-validation checkpoint per direction, so the second direction
        # never loads (or competes against) the first direction's model.
        best_model_path = os.path.join(model_dir_path, f"best_model_{direction}.pt")

        model = _build_model(model_type, training_hyperparameters, model_path)
        callbacks = _build_callbacks(
            _run,
            model_dir_path,
            best_model_path,
            direction,
            use_best_validation,
            store_prediction,
        )

        result = train_diffusion(
            model,
            training_hyperparameters["train_loss"],
            test_loss,
            training_hyperparameters["train_loss_schedule"],
            train_set,
            epochs,
            callbacks=callbacks,
            lr=training_hyperparameters["learning_rate"],
            optimizer_settings=training_hyperparameters["optimizer_settings"],
            batch_size=batch_size,
            validation_split=validation_split,
        )

        if save_last_model:
            path = os.path.join(model_dir_path, f"last_model_{direction}.pt")
            torch.save(
                {
                    "epoch": epochs,
                    "model_state_dict": model.state_dict(),
                },
                path,
            )
            _run.add_artifact(path)

        # Draw the random test subsets once so the last and the best model are
        # scored on identical samples.
        test_subsets = _draw_test_subsets(test_set) if different_test_sets else {}
        result.update(
            _evaluate_test_subsets(
                model,
                test_loss,
                test_subsets,
                test_evaluation_passes,
                batch_size,
                "test_loss_last_model",
            )
        )

        prediction_save_path = None
        if store_prediction:
            prediction_save_path = os.path.join(
                model_dir_path,
                f"full_data_epoch_{epochs}_predictions_{direction}",
            )
        full_dataset = ConcatDataset([train_set, test_set])
        full_dataloader = DataLoader(full_dataset, batch_size, shuffle=False)
        evaluation = evaluate_full_diffusion_loss(
            model,
            test_loss,
            full_dataloader,
            prediction_save_path=prediction_save_path,
            add_artifact=_run.add_artifact,
        )
        # Report the number of samples, not the number of batches.
        print(f"Test loss for size {len(full_dataset)}: {evaluation}")
        result["test_loss_train_and_test_last_model"] = evaluation

        if has_test_set and not different_test_sets:
            prediction_save_path = None
            if store_prediction:
                prediction_save_path = os.path.join(
                    model_dir_path,
                    f"test_data_epoch_{epochs}_predictions_{direction}",
                )
            result["test_full_loss_last_model"] = _evaluate_repeated(
                model,
                test_loss,
                test_set,
                test_evaluation_passes,
                batch_size,
                prediction_save_path=prediction_save_path,
                add_artifact=_run.add_artifact,
            )

        # Only reload the best checkpoint when there is something to score it on.
        best_model_has_work = bool(test_subsets) if different_test_sets else has_test_set
        if use_best_validation and best_model_has_work:
            print(f"Loading best model from {best_model_path}")
            checkpoint = torch.load(best_model_path, weights_only=False)
            model.load_state_dict(checkpoint["model_state_dict"])

            if different_test_sets:
                result.update(
                    _evaluate_test_subsets(
                        model,
                        test_loss,
                        test_subsets,
                        test_evaluation_passes,
                        batch_size,
                        "test_loss_best_model",
                    )
                )
            else:
                prediction_save_path = None
                if store_prediction:
                    # Separate file so the last-model test predictions are kept.
                    prediction_save_path = os.path.join(
                        model_dir_path,
                        f"test_data_best_model_predictions_{direction}",
                    )
                result["test_full_loss_best_model"] = _evaluate_repeated(
                    model,
                    test_loss,
                    test_set,
                    test_evaluation_passes,
                    batch_size,
                    prediction_save_path=prediction_save_path,
                    add_artifact=_run.add_artifact,
                )

        results.append(result)

    return {
        "results_recovering_x_given_y": results[0],
        "results_recovering_y_given_x": results[1],
    }
