import importlib
import numpy as np
import os
import torch
import logging
import random
from glob import glob
import yaml
import sys
from enum import Enum
from types import SimpleNamespace
from torchinfo import summary
import ast


def load_config(path):
    """
    Parse the YAML file at ``path`` and return its contents.

    If the file does not exist, an error is written to stderr and the process
    exits with status 1.
    """
    try:
        handle = open(path, "r")
    except FileNotFoundError:
        print(f"Config file not found: {path}", file=sys.stderr)
        sys.exit(1)
    with handle:
        return yaml.safe_load(handle)


def setup_experiment_dir(args):
    """
    Create a new, uniquely numbered experiment folder under ``args.results_dir``.

    Entries in ``results_dir`` named ``<index>-<anything>`` are scanned. The new folder
    gets the next index (largest existing index + 1, or 0 if there are none),
    zero-padded to three digits. That is followed by ``args.exp_dir_name`` with every
    "/" replaced by "-". A ``checkpoints`` sub-folder is created inside it.

    Args:
        args: Namespace providing ``results_dir`` and ``exp_dir_name``.

    Returns:
        tuple[str, str]: ``(experiment_dir, checkpoint_dir)``.
    """
    os.makedirs(args.results_dir, exist_ok=True)

    # Only entries whose name starts with a numeric index count. Any other file
    # (e.g. "notes.txt") used to crash the int() conversion.
    prefixes = (
        os.path.basename(path).split("-", 1)[0]
        for path in glob(os.path.join(args.results_dir, "*"))
    )
    indices = [int(prefix) for prefix in prefixes if prefix.isdecimal()]
    experiment_index = max(indices) + 1 if indices else 0

    exp_dir_name = args.exp_dir_name.replace("/", "-")
    experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{exp_dir_name}"  # Create an experiment folder
    checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
    os.makedirs(checkpoint_dir, exist_ok=True)
    return experiment_dir, checkpoint_dir


def create_logger(logging_dir, filename="log.log"):
    """
    Configure the root logger to write to ``<logging_dir>/<filename>`` and stderr.

    The log file is opened in append mode with UTF-8 encoding. Handlers already
    attached to the root logger are removed first, so calling this again
    reconfigures logging instead of being silently ignored.

    Args:
        logging_dir (str): Directory that holds the log file (must exist).
        filename (str): Name of the log file.

    Returns:
        logging.Logger: The logger for this module.
    """
    log_path = os.path.join(logging_dir, filename)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)-5.5s]  %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.StreamHandler(),  # defaults to sys.stderr
            # basicConfig ignores its own ``encoding`` argument when ``handlers`` is
            # given, so the encoding has to be set on the FileHandler itself.
            logging.FileHandler(log_path, mode="a+", encoding="utf-8"),
        ],
        # Remove and close existing root handlers before installing the new ones.
        # The log file itself is appended to, not overwritten.
        force=True,
    )
    return logging.getLogger(__name__)


def seed_everything(seed):
    """
    Seed the Python ``random``, NumPy and PyTorch random number generators.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _normalize_patch_size(patch_size, default):
    """
    Return ``patch_size`` in the form handed to the model constructor.

    Args:
        patch_size: ``None``, a tuple, a list (YAML parses sequences as lists), or a
            string that ``ast.literal_eval`` can evaluate, e.g. ``"(4, 1)"``.
        default: Value returned when ``patch_size`` is ``None``.

    Returns:
        ``default`` for ``None``, a tuple for a list/tuple, or the evaluated
        literal for a string.

    Raises:
        ValueError: For any other type.
    """
    if patch_size is None:
        return default
    if isinstance(patch_size, str):
        return ast.literal_eval(patch_size)
    if isinstance(patch_size, (list, tuple)):
        return tuple(patch_size)
    raise ValueError(
        f"Invalid patch_size: {patch_size}. It should be a tuple, a list or a string representation of a tuple."
    )


def _autocast_settings(device):
    """
    Return ``(device_type, dtype)`` for running ``torch.amp.autocast`` on ``device``.

    ``torch.amp.autocast`` expects a bare device type ("cuda", "cpu"), not an
    indexed device such as "cuda:0", so the type is extracted with ``torch.device``.
    The CUDA bf16 capability is not queried for CPU devices, because some PyTorch
    versions raise there when no GPU is present.
    """
    device_type = torch.device(device).type
    if device_type == "cpu":
        # bfloat16 is the default dtype of CPU autocast.
        return device_type, torch.bfloat16
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return device_type, dtype


def instantiate_model(
    args: SimpleNamespace,
    timefreq_shape: torch.Size,
    tmps_covariates_shape: torch.Size,
    **model_kwargs: dict,
):
    """
    Instantiate a model based on the provided arguments.

    Fills model-specific fields of ``args.model`` from the data shapes and builds the
    class named ``args.model.name`` from the ``models`` module. The model is moved to
    ``args.device``, a torchinfo summary is stored in ``args.model_summary``, and the
    model is wrapped with ``torch.compile`` when ``args.compile`` is set.

    Note: ``args.model`` is removed from ``args``. Its entries (plus ``learn_sigma``)
    are copied onto ``args`` as flat attributes instead.

    Args:
        args (Namespace): Arguments containing model configuration.
        timefreq_shape (torch.Size): Shape of the time-frequency data, unpacked as
            (B, T, C, F, W).
        tmps_covariates_shape (torch.Size): Shape of the temporal covariates data;
            its last dimension is the number of covariates.
        **model_kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        torch.nn.Module: The instantiated model.
        args (Namespace): The updated arguments with model-specific attributes.

    Raises:
        ValueError: If the model name is missing, ``patch_size`` is invalid, or the
            model cannot be instantiated.
    """
    # B = batch size (unused here), T = time dimension, C = channels,
    # F = frequency dimension, W = number of covariates
    _, T, C, F, W = timefreq_shape

    args.model.num_classes = 1
    args.model.num_frames = T
    args.model.in_channels = C
    # setting tc_dim to 0 deactivates the time covariates
    args.model.tc_dim = tmps_covariates_shape[-1] if args.time_covariates else 0
    # input_size is the frequency and covariates dimensions
    args.model.input_size = (F, W)
    args.model.patch_size = _normalize_patch_size(args.model.patch_size, default=(F, 1))

    # Detach the model sub-namespace from args; its entries become constructor kwargs.
    model_args: dict = vars(args.__dict__.pop("model"))
    model_name = model_args.pop("name", None)  # Remove the model name if it exists
    if model_name is None:
        raise ValueError("Model name must be provided in the arguments.")
    # learn_sigma lives in the gaussian_diffusion config, but the model needs it too.
    model_args["learn_sigma"] = args.gaussian_diffusion.learn_sigma

    try:
        model_cls = getattr(importlib.import_module("models"), model_name)
        model: torch.nn.Module = model_cls(**model_args, **model_kwargs)
    except Exception as e:
        raise ValueError(f"Failed to instantiate model '{model_name}': {e}") from e

    # Expose the model hyper-parameters as flat attributes of args (e.g. args.num_frames).
    for key, value in model_args.items():
        setattr(args, key, value)

    device_type, dtype = _autocast_settings(args.device)
    model = model.to(args.device)
    print(f"Model {model.__class__.__name__} instantiated")

    # Dummy batch used only to build the summary: (input, diffusion timestep t, time
    # covariates tc). torchinfo moves input_data to ``device``.
    with torch.amp.autocast(device_type=device_type, dtype=dtype):
        args.model_summary = summary(
            model,
            input_data=[
                torch.rand((args.batch_size, args.num_frames, args.in_channels, *args.input_size)),  # input
                torch.randint(0, args.gaussian_diffusion.diffusion_steps, (args.batch_size,)),  # t
                torch.rand((args.batch_size, args.num_frames, tmps_covariates_shape[-1])),  # tc
            ],
            verbose=0,
            col_names=["input_size", "output_size", "num_params"],
            device=args.device,
            mode="train",  # Use "train" to include training-specific layers like Dropout and to enable forward/backward size computation
        )

    if args.compile:
        print("Compiling the model with torch.compile")
        # 'max-autotune' gives the best performance but has a longer compile time for the first batch.
        model = torch.compile(model, mode="max-autotune")

    return model, args


def instantiate_lrsched_optim(model, args):
    """
    Instantiate a learning rate scheduler based on the provided arguments.

    Args:
        model (torch.nn.Module): The model for which the optimizer is created.
        args (Namespace): Arguments containing scheduler configuration.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: The instantiated scheduler.
        torch.optim.Optimizer: The instantiated optimizer.
    """
    scheduler_args: dict = vars(args.scheduler)
    scheduler_name = scheduler_args.pop(
        "name", None
    )  # Remove the scheduler name if it exists
    if scheduler_name is None:
        raise ValueError("Scheduler name must be provided in the arguments.")
    scheduler_name = scheduler_name.lower()

    match scheduler_name:
        case "cos_ann_warmup_restarts":
            from lr_sched import CosineAnnealingWarmupRestarts
            scheduler_class = CosineAnnealingWarmupRestarts
        case "null":
            # This is a no-op scheduler, it does not change the learning rate.
            from lr_sched import NullScheduler
            scheduler_class = NullScheduler
            scheduler_args = {}
        case _:
            raise ValueError(
                f"Unknown scheduler type: {scheduler_name}. Supported types are: 'cos_ann_warm_restarts', 'cos_ann', 'reduce_on_plateau', 'cos_ann_warmup'."
            )

    opt_name = vars(args).get("optimizer", "adamw").lower()
    match opt_name:
        case "adamw":
            opt_class = torch.optim.AdamW
        case _:
            raise ValueError(
                f"Unknown optimizer type: {args.optimizer}. Supported types are: 'adamw'."
            )

    if scheduler_name == "cos_ann_warmup_restarts":
        optimizer = opt_class(
            model.parameters(),
            lr=scheduler_args["max_lr"],
            weight_decay=args.weight_decay,
        )
    else:
        optimizer = opt_class(
            model.parameters(),
            lr=args.maxlr,
            weight_decay=args.weight_decay,
        )

    return scheduler_class(optimizer=optimizer, **scheduler_args), optimizer


def find_latest_checkpoint(checkpoint_dir, return_int=True):
    """
    Find the checkpoint with the highest step number in ``checkpoint_dir``.

    Only files named ``<digits>.pt`` (e.g. ``0005000.pt``) are considered. Other
    ``.pt`` files, such as ``ema.pt`` or ``0005000.tmp.pt``, are ignored.

    Args:
        checkpoint_dir (str): Directory to search.
        return_int (bool): If True, return the step number as an ``int``; otherwise
            return the full path to the checkpoint file.

    Returns:
        int | str | None: The step number or path, or ``None`` if no checkpoint is found.
    """
    candidates = []
    for path in glob(os.path.join(checkpoint_dir, "*.pt")):
        stem = os.path.splitext(os.path.basename(path))[0]
        # isdecimal (unlike isdigit) guarantees that int() succeeds.
        if stem.isdecimal():
            candidates.append((int(stem), path))
    if not candidates:
        return None
    step, latest_path = max(candidates)
    return step if return_int else latest_path


class EMA:
    """
    Exponential moving average (EMA) of a model's trainable parameters.

    Averaged ("shadow") tensors are stored in ``self.shadow`` and keyed by the names
    from ``model.named_parameters()``. Only parameters with ``requires_grad=True``
    are tracked.
    """

    def __init__(self, model, decay):
        """
        Initialize EMA class to manage exponential moving average of model parameters.

        Args:
            model (torch.nn.Module): The model for which EMA will track parameters.
            decay (float): Decay rate, typically a value close to 1, e.g., 0.999.
        """
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        # Start the average from the current parameter values.
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """
        Blend the current parameters into the shadow:
        ``shadow = decay * shadow + (1 - decay) * param``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # The arithmetic already produces a fresh tensor, so no clone() is needed.
                self.shadow[name] = (
                    1.0 - self.decay
                ) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self):
        """
        Copy the shadow (EMA) values into the model, backing up the current values.

        Call ``restore()`` to put the backed-up values back.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                # copy_ (instead of rebinding ``param.data``) keeps the parameter's
                # own storage. In-place changes made while the shadow is applied
                # therefore cannot silently modify ``self.shadow``.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """
        Restore original model parameters from backup.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])

    def state_dict(self):
        """
        Get the state dictionary of the EMA.

        Returns:
            dict: The live shadow dictionary (not a copy), mapping parameter
            names to averaged tensors.
        """
        return self.shadow

    def load_state_dict(self, state_dict, strict=True):
        """
        Load the state dictionary into the EMA.

        Args:
            state_dict (dict): The state dictionary to load.
            strict (bool): If True, raise ``KeyError`` when a tracked parameter
                is missing from ``state_dict``. If False, warn and keep the
                current shadow value. Extra keys in ``state_dict`` are ignored
                either way.
        """
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if name not in state_dict:
                if strict:
                    raise KeyError(f"Key {name} not found in provided state_dict.")
                print(
                    f"Warning: Key {name} not found in provided state_dict. Skipping."
                )
                continue
            # Clone so later updates never modify the caller's state_dict.
            self.shadow[name] = state_dict[name].clone()


def is_model_compiled(model: torch.nn.Module) -> bool:
    """
    Return True when ``model`` is a ``torch.compile`` wrapper.

    Compiled wrappers expose the original module through an ``_orig_mod``
    attribute. Probing for that attribute is the detection strategy used here.
    """
    try:
        model._orig_mod
    except AttributeError:
        return False
    return True


def is_checkpoint_from_compiled_model(state_dict: dict) -> bool:
    """
    Return True if any key of ``state_dict`` starts with the ``torch.compile``
    prefix ``"_orig_mod."``, i.e. the state was saved from a compiled model.
    """
    for key in state_dict:
        if key.startswith("_orig_mod."):
            return True
    return False


class CkptKey(Enum):
    """Top-level keys of the checkpoint dictionary written by ``save_checkpoint``."""

    MODEL = "model"  # model.state_dict()
    EMA = "ema"  # EMA shadow parameters
    OPTIMIZER = "opt"  # optimizer.state_dict()
    SCHEDULER = "scheduler"  # learning-rate scheduler state_dict()


def prepare_state_dict(
    state_dict: dict, prefix: str, model_compiled: bool, ckpt_compiled: bool
) -> dict:
    """
    Adapt state_dict keys between a ``torch.compile``-wrapped model and a plain one.

    Args:
        state_dict (dict): The state dict read from the checkpoint.
        prefix (str): The key prefix added by ``torch.compile`` (e.g. ``"_orig_mod."``).
        model_compiled (bool): Whether the model that will load the state is compiled.
        ckpt_compiled (bool): Whether the checkpoint was saved from a compiled model.

    Returns:
        dict: A new dict with ``prefix`` stripped from or added to every key. If both
        flags agree, ``state_dict`` itself is returned unchanged.
    """
    if ckpt_compiled and not model_compiled:
        # Strip only a *leading* prefix. str.replace would also delete occurrences
        # in the middle of a key (e.g. from nested compiled sub-modules).
        return {k.removeprefix(prefix): v for k, v in state_dict.items()}
    if model_compiled and not ckpt_compiled:
        # Add the compile prefix to the state_dict keys
        return {f"{prefix}{k}": v for k, v in state_dict.items()}
    # If both are compiled or neither is compiled, use the state_dict as is.
    return state_dict


def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    ema: EMA = None,
    optimizer: torch.optim.Optimizer = None,
    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
):
    """
    Loads a full training checkpoint to resume from.

    This function should be called AFTER the model and optimizer have been
    initialized and the model has been moved to the target device. The
    ``torch.compile`` key prefix (``_orig_mod.``) is added or stripped as needed,
    so a checkpoint saved from a compiled model can be loaded into a plain one
    and vice versa.

    Args:
        checkpoint_path (str): The path to the checkpoint file.
        model (torch.nn.Module): The model instance (already on the target device).
        ema (EMA, optional): The EMA handler instance.
        optimizer (torch.optim.Optimizer, optional): The optimizer instance.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The LR scheduler instance.

    Returns:
        A tuple containing:
            - model (torch.nn.Module): The model with loaded state.
            - ema (EMA): The EMA handler with loaded state.
            - optimizer (torch.optim.Optimizer): The optimizer with loaded state.
            - scheduler (torch.optim.lr_scheduler._LRScheduler): The scheduler with loaded state.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        KeyError: If the checkpoint contains no model state.
        RuntimeError: If the model state does not match the model (strict loading).
    """
    if not os.path.exists(checkpoint_path):
        # Whether to start from scratch is the caller's decision; here we only report.
        print(f"No checkpoint found at '{checkpoint_path}'.")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Infer the target device from the model so tensors are mapped straight to it.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # This handles the case of a model with no parameters.
        print(
            "Warning: Model has no parameters, cannot infer device. Defaulting to CPU."
        )
        device = torch.device("cpu")

    # Prefer weights_only=True, which refuses to unpickle arbitrary objects.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except Exception as e:
        # SECURITY: weights_only=False can execute arbitrary code embedded in the
        # file. This fallback is only safe for checkpoints from a trusted source.
        print(f"Could not load checkpoint with weights_only=True. Error: {e}")
        print("Attempting to load with weights_only=False (use with caution)...")
        checkpoint = torch.load(
            checkpoint_path, map_location=device, weights_only=False
        )

    # --- Load Model State ---
    # Keys are adjusted for a compiled/uncompiled mismatch; strict=True reports any
    # remaining missing or unexpected keys.
    ckpt_model_state_dict = checkpoint[CkptKey.MODEL.value]
    compile_prefix = "_orig_mod."

    model_compiled = is_model_compiled(model)
    ckpt_compiled = is_checkpoint_from_compiled_model(ckpt_model_state_dict)
    print(f"Model compiled: {model_compiled}, Checkpoint compiled: {ckpt_compiled}")

    prepared_model_state_dict = prepare_state_dict(
        ckpt_model_state_dict, compile_prefix, model_compiled, ckpt_compiled
    )
    model.load_state_dict(prepared_model_state_dict, strict=True)
    print("Model state loaded successfully.")

    # NOTE: the optimizer does not use named_parameters, it relies on the model's parameters order. Therefore, we do not need to look for a cleaned state_dict for the optimizer.
    # --- Load Optimizer State ---
    if CkptKey.OPTIMIZER.value in checkpoint and optimizer is not None:
        optimizer.load_state_dict(checkpoint[CkptKey.OPTIMIZER.value])
        print("Optimizer state loaded successfully.")
    elif optimizer is not None:
        print(
            "Warning: No optimizer state found in checkpoint. Optimizer will start from scratch."
        )

    # --- Load EMA State ---
    # The EMA keys are assumed to carry the same compile prefix as the model keys
    # saved alongside them. A failure here is reported, not raised.
    if CkptKey.EMA.value in checkpoint and ema is not None:
        prepared_ema_state_dict = prepare_state_dict(
            checkpoint[CkptKey.EMA.value], compile_prefix, model_compiled, ckpt_compiled
        )
        try:
            ema.load_state_dict(prepared_ema_state_dict, strict=True)
            print("EMA state loaded successfully.")
        except Exception as e:
            print(f"Error loading EMA state. EMA will start from scratch. Error: {e}")
    elif ema is not None:
        print("Warning: No EMA state found in checkpoint. EMA will start from scratch.")

    # --- Load Scheduler State ---
    if CkptKey.SCHEDULER.value in checkpoint and scheduler is not None:
        scheduler.load_state_dict(checkpoint[CkptKey.SCHEDULER.value])
        print("Scheduler state loaded successfully.")
    elif scheduler is not None:
        print(
            "Warning: No scheduler state found in checkpoint. Scheduler will start from scratch."
        )

    return model, ema, optimizer, scheduler


def save_checkpoint(
    model: torch.nn.Module,
    ema: EMA,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    file_path: str,
):
    """
    Saves the state of the model, EMA, optimizer and scheduler to a single file.

    The checkpoint is first written to ``<file_path>.tmp`` and then moved into place
    with ``os.replace``. That rename is atomic on POSIX, so an interrupted save
    does not leave a truncated file at ``file_path``.

    Args:
        model: The PyTorch model (can be compiled).
        ema: The ExponentialMovingAverage handler.
        optimizer: The optimizer.
        scheduler: The learning-rate scheduler.
        file_path (str): The full path to save the checkpoint file (e.g., 'path/to/my_model.pt').
    """
    checkpoint_dir = os.path.dirname(file_path)
    # dirname is "" for a bare filename (current directory); makedirs("") would raise.
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Created checkpoint directory: {checkpoint_dir}")

    # Create the checkpoint dictionary with only the essential components.
    checkpoint = {
        CkptKey.MODEL.value: model.state_dict(),
        CkptKey.OPTIMIZER.value: optimizer.state_dict(),
        CkptKey.EMA.value: ema.state_dict(),
        CkptKey.SCHEDULER.value: scheduler.state_dict(),
    }

    # The temporary name does not end in ".pt", so checkpoint globs never pick it up.
    tmp_path = f"{file_path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, file_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
