import json
from pathlib import Path
import torch
import time

from uq_diagcfm.paths import CHECKPOINTS_DIR

RUN_INFO_FILENAME = "run_info.json"
CHECKPOINT_FILENAME = "model_checkpoint.pth"


def name_run(dataset_name: str, run_info: dict) -> tuple[Path, str]:
    """
    Build a timestamped run name from hyperparameters and create its directory.

    The name starts with the current local time (``YYYYmmdd-HHMMSS``) followed
    by a hyperparameter summary whose format depends on
    ``run_info["model_type"]`` (defaults to ``"CFM"``):

    - ``"INN"``: ``<ts>_INN_nb<num_blocks>_hd<hidden_dim>_sd<subnet_depth>_bs.._ep.._lr..``
    - anything else (CFM / Diag-CFM):
      ``<ts>_<diagcfm|vanilla>_<shuffled<seed>|unshuffled>_hd.._d.._bs.._ep.._lr..``

    Args:
        dataset_name: Sub-directory of ``CHECKPOINTS_DIR`` the run belongs to.
        run_info: Hyperparameter dictionary. Must contain the keys used by the
            selected naming scheme (a missing key raises ``KeyError``).

    Returns:
        Tuple ``(run_path, run_name)`` where ``run_path`` is
        ``CHECKPOINTS_DIR / dataset_name / run_name``. The directory (and its
        parents) is created.

    Note:
        Two calls within the same second with identical hyperparameters yield
        the same name; because ``exist_ok=True`` is used, the existing
        directory is reused rather than an error being raised.
    """
    time_stamp = time.strftime("%Y%m%d-%H%M%S")

    # Handle different model types
    model_type = run_info.get("model_type", "CFM")

    if model_type == "INN":
        # INN model naming
        run_name = (
            f"{time_stamp}_INN_"
            f"nb{run_info['num_blocks']}_hd{run_info['hidden_dim']}_sd{run_info['subnet_depth']}_"
            f"bs{run_info['batch_size']}_ep{run_info['epochs']}_lr{run_info['learning_rate']}"
        )
    else:
        # CFM/Diag-CFM model naming (original logic)
        diag_cfm_str = "diagcfm" if run_info["diag_cfm"] else "vanilla"
        shuffle_str = (
            f"shuffled{run_info['shuffle_params_seed']}"
            if run_info.get("shuffle_params_seed") is not None
            else "unshuffled"
        )
        run_name = (
            f"{time_stamp}_{diag_cfm_str}_{shuffle_str}_"
            f"hd{run_info['model_hidden_dimension']}_d{run_info['model_depth']}_"
            f"bs{run_info['batch_size']}_ep{run_info['epochs']}_lr{run_info['learning_rate']}"
        )

    run_path = CHECKPOINTS_DIR / dataset_name / run_name
    run_path.mkdir(parents=True, exist_ok=True)

    return run_path, run_name


def save_run_info(run_path, run_info):
    """
    Write ``run_info`` as 4-space-indented JSON into the run directory.

    The file is named ``RUN_INFO_FILENAME`` and placed directly inside
    ``run_path``; an existing file is overwritten.

    Args:
        run_path (str | Path): Run directory that receives the JSON file.
        run_info (dict): JSON-serializable description of the run.
    """
    info_file = f"{run_path}/{RUN_INFO_FILENAME}"
    with open(info_file, "w") as handle:
        json.dump(run_info, handle, indent=4)


def load_run_info(run_path):
    """
    Read the JSON run description stored in a run directory.

    Args:
        run_path (str | Path): Run directory containing ``RUN_INFO_FILENAME``.

    Returns:
        dict: The decoded run information.
    """
    info_file = f"{run_path}/{RUN_INFO_FILENAME}"
    with open(info_file) as handle:
        return json.load(handle)


def save_model_checkpoint(run_path, model, filename=None):
    """
    Persist the model's ``state_dict`` inside the run directory.

    Args:
        run_path (str | Path): Run directory the checkpoint is written into.
        model (torch.nn.Module): Model whose parameters/buffers are saved.
        filename (str | None): Checkpoint file name; ``None`` selects
            ``CHECKPOINT_FILENAME``.
    """
    target_name = CHECKPOINT_FILENAME if filename is None else filename
    state = model.state_dict()
    torch.save(state, f"{run_path}/{target_name}")


def save_epoch_checkpoint(run_path, model, epoch):
    """
    Persist the model's ``state_dict`` under an epoch-specific file name.

    Args:
        run_path (str | Path): Run directory the checkpoint is written into.
        model (torch.nn.Module): Model whose parameters/buffers are saved.
        epoch (int): Epoch number embedded in the file name.

    Returns:
        str: The file name used, ``model_checkpoint_epoch<epoch>.pth``.
    """
    checkpoint_name = f"model_checkpoint_epoch{epoch}.pth"
    destination = f"{run_path}/{checkpoint_name}"
    torch.save(model.state_dict(), destination)
    return checkpoint_name


def find_checkpoints_by_criteria(
    criteria: dict, max_results: int | None = None
) -> list[tuple[Path, dict]]:
    """
    Find checkpoint directories matching the given criteria.

    A run directory ``CHECKPOINTS_DIR/<dataset>/<run>`` is a candidate when it
    contains both ``CHECKPOINT_FILENAME`` and ``RUN_INFO_FILENAME``. A candidate
    matches when every ``key: value`` pair in ``criteria`` is present in its
    run info with an equal value. A truthy ``"dataset"`` criterion also
    restricts the search to that dataset's directory.

    Args:
        criteria: Dictionary containing criteria to filter runs.
        max_results: Maximum number of results to return. ``None`` (or any
            falsy value such as ``0``) returns all matches.

    Returns:
        List of ``(checkpoint_dir, run_info)`` tuples for matching runs. The list is
        ordered by dataset directory name, then run directory name. Run names start
        with a timestamp, so within a dataset this is chronological.
    """
    matching_checkpoints = []
    dataset_name = criteria.get("dataset")

    if dataset_name:
        # Search in specific dataset directory
        search_paths = [CHECKPOINTS_DIR / dataset_name]
    elif CHECKPOINTS_DIR.is_dir():
        # Search all dataset directories. Sort them so results, and truncation
        # by max_results, do not depend on filesystem listing order.
        search_paths = sorted(d for d in CHECKPOINTS_DIR.iterdir() if d.is_dir())
    else:
        # No checkpoints directory yet, so there is nothing to find.
        search_paths = []

    for dataset_dir in search_paths:
        if not dataset_dir.is_dir():
            continue

        # Iterate over alphabetically sorted run directories
        for run_dir in sorted(dataset_dir.iterdir()):
            if not run_dir.is_dir():
                continue

            # Only consider complete runs: a checkpoint and its run info.
            # A run missing run info would otherwise crash the whole search.
            if not (run_dir / CHECKPOINT_FILENAME).exists():
                continue
            if not (run_dir / RUN_INFO_FILENAME).exists():
                continue

            run_info = load_run_info(run_dir)

            # Check if run matches all criteria
            match = all(
                key in run_info and run_info[key] == value
                for key, value in criteria.items()
            )

            if match:
                matching_checkpoints.append((run_dir, run_info))
                if max_results and len(matching_checkpoints) >= max_results:
                    return matching_checkpoints

    return matching_checkpoints


def load_run_info_according_to_criteria(criteria: dict) -> list[dict]:
    """
    Collect the run-info dictionaries of every run that satisfies ``criteria``.

    Thin wrapper over ``find_checkpoints_by_criteria`` that discards the
    checkpoint directories and keeps only the run information.

    Args:
        criteria (dict): Key/value pairs every returned run must match.

    Returns:
        list[dict]: Run-info dictionaries, in the order the search yields them.
    """
    return [entry[1] for entry in find_checkpoints_by_criteria(criteria)]


# Dataset types for which load_ensemble knows how to rebuild an architecture.
_SUPPORTED_DATASETS = ("gas_turbine", "unifoil", "dtlz")


def _build_inn_model(dataset_type: str, run_info: dict) -> torch.nn.Module:
    """Instantiate the (untrained) INN architecture described by ``run_info``."""
    from uq_diagcfm.models_for_datasets import (
        conditional_inn_for_unifoil,
        inn_for_dtlz,
        inn_for_gas_turbine,
    )

    inn_kwargs = dict(
        num_blocks=run_info["num_blocks"],
        hidden_dim=run_info["hidden_dim"],
        subnet_depth=run_info["subnet_depth"],
        clamp=run_info.get("clamp", 2.0),
        activation=run_info.get("activation", "LeakyReLU"),
    )
    if dataset_type == "gas_turbine":
        return inn_for_gas_turbine(**inn_kwargs)
    if dataset_type == "unifoil":
        # Unifoil uses ConditionalINN due to conditioning on physical params
        return conditional_inn_for_unifoil(**inn_kwargs)
    if dataset_type == "dtlz":
        # DTLZ requires extra dimension parameters from run_info
        return inn_for_dtlz(
            num_design_params=run_info["num_design_params"],
            num_objectives=run_info["num_objectives"],
            **inn_kwargs,
        )
    raise ValueError(f"Unknown dataset type for INN: {dataset_type}")


def _build_cfm_model(dataset_type: str, run_info: dict) -> torch.nn.Module:
    """Instantiate the (untrained) CFM / Diag-CFM architecture described by ``run_info``."""
    cfm_kwargs = dict(
        diag_cfm=run_info.get("diag_cfm", True),
        model_hidden_dimension=run_info["model_hidden_dimension"],
        model_depth=run_info["model_depth"],
        dropout=run_info.get("dropout", 0.0),
        model_activation=run_info.get("model_activation", "LeakyReLU"),
    )
    if dataset_type == "gas_turbine":
        from uq_diagcfm.models_for_datasets import models_for_gas_turbine

        return models_for_gas_turbine(**cfm_kwargs)
    if dataset_type == "unifoil":
        from uq_diagcfm.models_for_datasets import models_for_unifoil

        return models_for_unifoil(**cfm_kwargs)
    if dataset_type == "dtlz":
        from uq_diagcfm.models_for_datasets import models_for_dtlz

        # DTLZ requires extra dimension parameters from run_info
        return models_for_dtlz(
            num_design_params=run_info["num_design_params"],
            num_objectives=run_info["num_objectives"],
            **cfm_kwargs,
        )
    raise ValueError(f"Unknown dataset type: {dataset_type}")


def _build_model_from_run_info(run_info: dict) -> torch.nn.Module:
    """Instantiate the architecture recorded in ``run_info`` (INN or CFM) for its own dataset."""
    dataset_type = run_info["dataset"]
    if run_info.get("model_type", "") == "INN":
        return _build_inn_model(dataset_type, run_info)
    return _build_cfm_model(dataset_type, run_info)


def load_ensemble(
    criteria: dict,
    device: str | torch.device | None = None,
    verbose: bool = True,
    max_models: int | None = None,
) -> tuple[list[torch.nn.Module], list[dict], list[str]]:
    """
    Load an ensemble of models matching the given criteria.

    Every matching run is rebuilt from its OWN run info (dataset and model
    type). This keeps the load correct when ``criteria`` does not pin a single
    ``"dataset"`` and matches span several datasets.

    Args:
        criteria: Dictionary containing criteria to filter runs (e.g., {'dataset': 'gas_turbine', 'diag_cfm': True})
        device: Device to load models on (defaults to auto-detection)
        verbose: Whether to print loading progress
        max_models: Maximum number of models to load (None = load all matching)

    Returns:
        Tuple of (models, run_infos, checkpoint_names) where:
            - models: List of loaded PyTorch models in eval mode
            - run_infos: List of run info dictionaries for each model
            - checkpoint_names: List of checkpoint directory names
        All three lists are empty when nothing matches.

    Raises:
        ValueError: If any matching run records an unsupported dataset type.
            This is checked before any weights are loaded.
    """
    if device is None:
        from uq_diagcfm.utils import get_device

        device = get_device()

    # Find all checkpoint directories matching criteria
    checkpoints = find_checkpoints_by_criteria(criteria, max_results=max_models)

    if not checkpoints:
        return ([], [], [])

    # Validate every run's dataset up front so we fail before loading weights.
    # dict.fromkeys de-duplicates while keeping first-seen order.
    dataset_types = list(dict.fromkeys(info["dataset"] for _, info in checkpoints))
    for dataset_type in dataset_types:
        if dataset_type not in _SUPPORTED_DATASETS:
            raise ValueError(f"Unknown dataset type: {dataset_type}")

    if verbose:
        print(
            f"Loading {len(checkpoints)} models for "
            f"{', '.join(dataset_types)} ensemble..."
        )

    models = []
    run_infos = []
    checkpoint_names = []

    for ckpt_dir, run_info in checkpoints:
        checkpoint_names.append(ckpt_dir.name)
        run_infos.append(run_info)

        model = _build_model_from_run_info(run_info)

        # Load weights (a state_dict) and prepare for inference
        state_dict = torch.load(ckpt_dir / CHECKPOINT_FILENAME, map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        models.append(model)

        if verbose:
            print(f"  Loaded: {ckpt_dir.name}")

    if verbose:
        print(f"Successfully loaded {len(models)} models")

    return models, run_infos, checkpoint_names


if __name__ == "__main__":
    import sys

    # Ad-hoc manual checks: `python <this file> <command>`.
    command = sys.argv[1] if len(sys.argv) == 2 else None

    if command == "run_info_with_criteria":
        criteria = {
            "epochs": 20,
        }
        run_infos = load_run_info_according_to_criteria(criteria)
        for run_info in run_infos:
            # name_run does not read these keys for INN runs, so they may be
            # absent; use .get to avoid a KeyError on such runs.
            print(run_info.get("diag_cfm"), run_info.get("shuffle_params_seed"))

    elif command == "load_ensemble":
        criteria = {
            "dataset": "gas_turbine",
            "diag_cfm": True,
            "shuffle_params_seed": None,
            "epochs": 20,
        }
        # load_ensemble returns a 3-tuple (models, run_infos, checkpoint_names).
        models, _run_infos, _checkpoint_names = load_ensemble(criteria, verbose=True)
        print(f"Loaded {len(models)} models for ensemble.")

    else:
        sys.exit(f"Usage: python {sys.argv[0]} {{run_info_with_criteria|load_ensemble}}")
