import argparse
from data_generation.data_from_dict import load_data
from models.diffusion_model import ConditionalDiffusionModel
import matplotlib.pyplot as plt
import os
import json
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

from models.diffusion_model_resnet import ConditionalResnetDiffusionModel


def parse_args():
    """Parse the command-line options of the training script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` with the
    attributes ``epochs``, ``batch_size``, ``seed``, ``continue_training``,
    ``model_path``, ``save_dir``, ``visualize_only`` and ``config_file``.
    """
    cli = argparse.ArgumentParser(
        description="Train diffusion model with model saving"
    )

    # (flag, add_argument keyword options), registered in declaration order.
    option_specs = (
        (
            "--epochs",
            {"type": int, "default": 5000, "help": "Number of training epochs"},
        ),
        ("--batch_size", {"type": int, "default": 512, "help": "Batch size"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        (
            "--continue_training",
            {"action": "store_true", "help": "Continue training from a saved model"},
        ),
        (
            "--model_path",
            {
                "type": str,
                "default": None,
                "help": "Path to saved model for continued training",
            },
        ),
        (
            "--save_dir",
            {
                "type": str,
                "default": "saved_models",
                "help": "Directory to save models",
            },
        ),
        (
            "--visualize_only",
            {
                "action": "store_true",
                "help": "Only visualize results without training",
            },
        ),
        (
            "--config_file",
            {
                "type": str,
                "default": None,
                "help": "Path to JSON configuration file for data and model parameters",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def get_default_config(seed=42):
    """Build the built-in default configuration.

    Args:
        seed (int): Accepted for interface compatibility; the value is not
            used when building the defaults.

    Returns:
        tuple: ``(training_hyperparameters, data_generation_config)``, two
        freshly created dictionaries (safe for the caller to mutate).
    """
    # Noise schedule shared by both model architectures.
    beta_schedule = dict(type="quadratic", min=0.0001, max=0.02, timesteps=100)

    model_hyperparameters = dict(
        beta_schedule=beta_schedule,
        condition_dim=4,
        # Used for "plain"
        layer_sizes=[2048, 1024, 512],
        # Used for "resnet"
        hidden_dim=64,
        num_blocks=4,
    )

    # Description of the synthetic data-generating mechanism.
    mechanism = dict(
        X=dict(type="normal", length=100_000),
        transformation=dict(
            type="tanh",
            args=dict(num_hidden=10, num_parents=1),
        ),
        shape="sequence",
        depth=3,
        noise_type="normal",
    )
    data_config = dict(file_path=None, dictionary=mechanism)

    return model_hyperparameters, data_config


def get_config(args):
    """Return the run configuration for the parsed command-line arguments.

    When ``args.config_file`` is set, the JSON file is loaded and returned
    as-is (the other command-line values are not merged into it). Otherwise
    a configuration is assembled from the built-in defaults together with
    ``args.epochs``, ``args.batch_size`` and ``args.seed``.
    """
    # A configuration file takes precedence over everything else.
    if args.config_file is not None:
        return load_config(args.config_file)

    hyperparameters, data_config = get_default_config(seed=args.seed)
    return {
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "data_config": data_config,
        "training_hyperparameters": hyperparameters,
        "validation_split": 0.1,
        "seed": args.seed,
        "number_of_different_mechanisms": 1,
        "model_type": "resnet",  # default model type
    }


def save_config(config, path):
    """Write ``config`` to ``path`` as JSON indented by two spaces.

    An existing file at ``path`` is overwritten.
    """
    encoder = json.JSONEncoder(indent=2)
    with open(path, "w") as handle:
        # Stream the encoded chunks straight into the file.
        handle.writelines(encoder.iterencode(config))


def load_config(path):
    """Load a JSON configuration file and return the decoded object.

    The file is read as UTF-8 (the standard JSON encoding) instead of the
    platform's locale encoding, so files containing non-ASCII text decode
    the same way on every system.

    Args:
        path (str): Path to the JSON file.

    Returns:
        The decoded JSON content (a dict for configuration files).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def create_diffusion_model(model_type, config):
    """
    Create a new diffusion model with the given configuration

    Args:
        model_type (str): Architecture to build, either "plain" or "resnet"
        config (dict): Model hyperparameters. Both architectures read
            "condition_dim" and "beta_schedule"; "plain" also reads
            "layer_sizes", "resnet" also reads "hidden_dim" and "num_blocks"

    Returns:
        model: A new conditional diffusion model

    Raises:
        ValueError: If model_type is not one of the supported architectures
    """

    if model_type == "plain":
        return ConditionalDiffusionModel(
            input_dim=1,
            layer_sizes=config["layer_sizes"],
            condition_dim=config["condition_dim"],
            beta_schedule_args=config["beta_schedule"],
        )

    if model_type == "resnet":
        return ConditionalResnetDiffusionModel(
            input_dim=1,
            hidden_dim=config["hidden_dim"],
            num_blocks=config["num_blocks"],
            condition_dim=config["condition_dim"],
            beta_schedule_args=config["beta_schedule"],
        )

    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected 'plain' or 'resnet'"
    )


def load_model(model_path, config=None):
    """
    Load a model from a checkpoint

    Args:
        model_path (str): Path to the model checkpoint
        config (dict, optional): Configuration for model creation. When
            omitted, "config.json" is read from the directory two levels
            above the checkpoint file (checkpoints are written to
            "<run_dir>/models/" and the config to "<run_dir>/").

    Returns:
        model: The loaded model
        checkpoint (dict): Checkpoint information
        config (dict): The configuration used to build the model

    Raises:
        ValueError: If config is not given and cannot be read from the
            checkpoint's run directory
    """
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location="cpu")

    # If config not provided, try to load from checkpoint's directory
    if config is None:
        config_dir = os.path.dirname(os.path.dirname(model_path))
        print(f"Loading config from {config_dir}")
        config_path = os.path.join(config_dir, "config.json")
        try:
            config = load_config(config_path)
        except (OSError, ValueError) as err:
            # ValueError covers invalid JSON and undecodable bytes; keep the
            # original cause attached for debugging.
            raise ValueError(
                f"Config not provided and not found in checkpoint directory {config_path}"
            ) from err

    # The model is built purely from the hyperparameters, so the training
    # data does not need to be (re)loaded here.
    model = create_diffusion_model(
        config["model_type"], config["training_hyperparameters"]
    )

    # Load state dict
    model.load_state_dict(checkpoint["model_state_dict"])

    return model, checkpoint, config


def train_diffusion_model(args):
    """Train the diffusion model with the provided arguments

    Creates "<save_dir>/seed_<seed>/" (with a "models" sub-directory), writes
    the configuration to "config.json", trains the model while saving a
    checkpoint after every epoch, and records summary metrics in
    "metrics.json".

    Args:
        args: Parsed command-line arguments (see parse_args). Reads
            "save_dir", "continue_training" and "model_path" in addition to
            whatever get_config() consumes.

    Returns:
        tuple: (result, model_dir) where result is the value returned by
        train_diffusion and model_dir is the run directory.
    """
    config = get_config(args)

    # Resolve the architecture before persisting the config so the saved
    # config.json always contains the "model_type" that load_model() needs.
    config.setdefault("model_type", "plain")

    # Set up directories
    seed = config["seed"]
    torch.manual_seed(seed)

    model_dir = os.path.join(args.save_dir, f"seed_{seed}")
    model_dir_path = os.path.join(model_dir, "models")

    # makedirs creates every missing parent (save_dir, model_dir) as well.
    os.makedirs(model_dir_path, exist_ok=True)

    print(f"Model directory: {model_dir}")

    # Model paths
    y_given_x_model_path = os.path.join(model_dir_path, "y_given_x_best.pt")
    config_path = os.path.join(model_dir, "config.json")

    # Save configuration
    save_config(config, config_path)

    # Load data: one dataset per data-generating mechanism, concatenated
    datas = []
    for dgm_seed in range(config["number_of_different_mechanisms"]):
        config["data_config"]["seed"] = dgm_seed
        datas.append(torch.from_numpy(load_data(config["data_config"])))

    data = torch.cat(datas)

    # Create or load model. When continuing, the checkpoint is loaded once
    # and also provides the best validation loss reached so far.
    best_val_loss = float("inf")
    if args.continue_training and args.model_path:
        model, checkpoint, _ = load_model(args.model_path, config)
        best_val_loss = checkpoint.get("validation_loss", float("inf"))
    else:
        model = create_diffusion_model(
            config["model_type"], config["training_hyperparameters"]
        )

    def on_epoch_end(epoch, model, metrics):
        # Remember the running best so the "best" checkpoint is only
        # overwritten when the validation loss actually improves.
        nonlocal best_val_loss
        best_val_loss = save_checkpoint(
            epoch,
            model,
            metrics,
            model_dir_path,
            y_given_x_model_path,
            best_val_loss,
        )
        return best_val_loss

    print(f"Training with configuration\n{json.dumps(config, indent=4)}")
    # Train the model
    # NOTE(review): train_diffusion is not defined or imported in the visible
    # part of this file - confirm where it is supposed to come from.
    result = train_diffusion(
        model,
        TensorDataset(data[:, 1], data[:, 0]),  # Recover y given x
        epochs=config["epochs"],
        callbacks={"epoch_end": on_epoch_end},
        validation_split=config.get("validation_split", 0.1),
    )

    # Save final metrics
    metrics_path = os.path.join(model_dir, "metrics.json")

    # Try to load existing metrics if continuing training; an unreadable or
    # corrupt file is replaced rather than aborting after a finished run.
    metrics = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                metrics = json.load(f)
        except (OSError, ValueError):
            metrics = {}

    # Update metrics
    metrics.setdefault("y_given_x", {}).update(
        {
            f"run_0_to_{config['epochs']}": {
                "final_train_loss": result["losses"][-1],
                "best_train_loss": min(result["losses"]),
                "min_validation_loss": result.get("min_validation_loss", "N/A"),
                "test_loss": result.get("test_loss", "N/A"),
                "end_epoch": config["epochs"],
            }
        }
    )

    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    return result, model_dir


def save_checkpoint(
    epoch, model, metrics, model_dir_path, best_model_path, best_val_loss
):
    """Write the per-epoch checkpoint and refresh the best model if improved.

    A checkpoint named "checkpoint_epoch_<epoch>.pt" is always written to
    model_dir_path. When the epoch's validation loss (read from
    metrics["validation_loss"], infinity if absent) is strictly lower than
    best_val_loss, the same payload is also written to best_model_path.

    Returns:
        The best validation loss after this epoch: the new loss when it
        improved, otherwise best_val_loss unchanged.
    """
    current_val_loss = metrics.get("validation_loss", float("inf"))

    def build_payload():
        # Snapshot of the model plus the losses recorded for this epoch.
        return {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "validation_loss": current_val_loss,
            "train_loss": metrics.get("train_loss", None),
        }

    # The per-epoch checkpoint is written unconditionally.
    epoch_file = os.path.join(model_dir_path, f"checkpoint_epoch_{epoch}.pt")
    torch.save(build_payload(), epoch_file)

    # No improvement: the previous best stays in place.
    if not current_val_loss < best_val_loss:
        return best_val_loss

    print(
        f"Epoch {epoch}: Saving best y_given_x model with validation loss: {current_val_loss:.4f}"
    )
    torch.save(build_payload(), best_model_path)
    return current_val_loss


def visualize_results(result, model_dir, start_epoch=0):
    """Visualize training results

    Saves the loss curves to "training_curves_<start_epoch>.png" inside
    model_dir and prints a short summary of the losses.

    Args:
        result (dict): Training result. Must contain "losses"; the
            validation curve is drawn only when "validation_losses" is
            present. "validation_loss" and "test_loss" are printed when
            available.
        model_dir (str): Directory the figure is written to.
        start_epoch (int): Epoch number of the first loss value; used as the
            x-axis offset and in the output file name.
    """
    train_losses = result["losses"]

    # Plot training, validation losses
    fig = plt.figure(figsize=(12, 5))
    try:
        # Plot Y given X training loss
        plt.subplot(1, 2, 1)
        plt.plot(
            range(start_epoch, start_epoch + len(train_losses)),
            train_losses,
            label="Train",
            alpha=0.8,
        )
        plt.title("Y given X - Training Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.yscale("log")
        plt.legend()

        # Plot validation losses if available
        if "validation_losses" in result:
            val_losses = result["validation_losses"]
            plt.subplot(1, 2, 2)
            plt.plot(
                range(start_epoch, start_epoch + len(val_losses)),
                val_losses,
                label="Val",
                alpha=0.8,
            )
            plt.title("Y given X - Validation Loss")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.yscale("log")
            plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(model_dir, f"training_curves_{start_epoch}.png"))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Print metrics
    # NOTE(review): this reads "validation_loss" while the metrics file uses
    # "min_validation_loss" - confirm which key train_diffusion returns.
    print(f"Y given X - Final train loss: {train_losses[-1]:.4f}")
    print(f"Y given X - Best train loss: {min(train_losses):.4f}")
    print(f"Y given X - Validation loss: {result.get('validation_loss', 'N/A')}")
    print(f"Y given X - Test loss: {result.get('test_loss', 'N/A')}")


def generate_samples_from_model(model, num_samples=5):
    """Sample Y from the model for randomly drawn conditioning inputs.

    Draws ``num_samples`` standard-normal condition vectors of width
    ``model.condition_dim`` and passes them to ``model.sample``. Gradient
    tracking is disabled for the whole operation.

    Returns:
        tuple: ``(x, y)``, the sampled conditions and the model's output.
    """
    with torch.no_grad():
        conditions = torch.randn(num_samples, model.condition_dim)
        generated = model.sample(conditions)
    return conditions, generated


def main():
    """Script entry point: parse arguments, train the model, plot the results.

    train_diffusion_model() resolves the configuration itself, creates the
    run directory and writes config.json there, so none of that is repeated
    here.
    """
    # NOTE(review): --visualize_only is parsed but not consulted anywhere in
    # this flow - confirm whether a visualize-only mode is still intended.
    args = parse_args()

    # Train the model (configuration is resolved and saved inside)
    result, model_dir = train_diffusion_model(args)

    # Visualize results
    visualize_results(result, model_dir)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
