import torch.optim.lr_scheduler as lr_scheduler
import torch
from torch.optim import Adam, SGD, AdamW  # Import AdamW
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR
from hydra.utils import get_class
import inspect
from omegaconf import OmegaConf
import inspect
import torch
import random
import numpy as np


def get_opt_scheduler(lr, wd, iters, scheduler, opt, model):
    """Build an optimizer and, optionally, a learning-rate scheduler for ``model``.

    Args:
        lr: Base learning rate passed to the optimizer.
        wd: Weight decay passed to the optimizer.
        iters: Total number of scheduler steps the schedule should span.
        scheduler: Scheduler name (case-insensitive). ``"lin_warmup_cosine"``
            builds a linear warm-up (~1% of ``iters``) followed by cosine
            annealing to 0; any other value disables scheduling.
        opt: Optimizer name (case-insensitive): ``"adam"``, ``"sgd"`` or ``"adamw"``.
        model: Module whose ``parameters()`` are optimized.

    Returns:
        Tuple ``(optimizer, scheduler)``; ``scheduler`` is ``None`` when the
        scheduler name is not ``"lin_warmup_cosine"``.

    Raises:
        ValueError: If ``opt`` is not a supported optimizer name.
    """
    optimizer_classes = {"adam": Adam, "sgd": SGD, "adamw": AdamW}
    optimizer_cls = optimizer_classes.get(opt.lower())
    if optimizer_cls is None:
        raise ValueError("Invalid optimizer specified in the configuration.")
    optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=wd)

    if scheduler.lower() != "lin_warmup_cosine":
        return optimizer, None

    # Warm up for ~1% of training. Clamp to >= 1: for iters < 100 the plain
    # ``iters // 100`` is 0, and LambdaLR evaluates the lambda at step 0 during
    # construction, which would divide 0 by 0.
    warmup_steps = max(1, iters // 100)
    # Keep the cosine period positive even for tiny ``iters``.
    cosine_steps = max(1, iters - warmup_steps)

    # Linearly ramps the LR from 0 up to the base LR over ``warmup_steps``.
    warmup_scheduler = LambdaLR(optimizer, lr_lambda=lambda step: step / warmup_steps)
    # Decreases the LR along a cosine curve once warm-up is done.
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=0)
    # Hands control from the warm-up to the cosine schedule at the milestone.
    combined = SequentialLR(
        optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps]
    )
    return optimizer, combined


class CustomNLLLoss(torch.nn.Module):
    """Negative log-likelihood loss computed from probabilities.

    Despite the ``log_probs`` argument name, ``forward`` expects plain
    probabilities: it adds ``epsilon`` and takes the natural log itself, so
    zero probabilities do not yield ``-inf``.
    """

    def __init__(self, epsilon=1e-6, reduction="mean"):
        """
        Args:
            epsilon: Constant added to the probabilities before the log.
            reduction: ``"mean"``, ``"sum"``; any other value returns the
                per-token losses unreduced.
        """
        super(CustomNLLLoss, self).__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, log_probs, targets):
        """Compute the NLL of ``targets`` under the given probabilities.

        Args:
            log_probs: Probabilities (not log-probabilities) of shape
                ``(batch, vocab, seq_len)``; the class axis is dim 1.
            targets: Integer class indices with ``batch * seq_len`` elements,
                reshapeable to ``(batch, seq_len)``.

        Returns:
            Scalar mean or sum, or a ``(batch, seq_len)`` tensor of per-token
            losses when ``reduction`` is neither ``"mean"`` nor ``"sum"``.
        """
        batch_size, _, seq_len = log_probs.size()

        # Add epsilon before the log for numerical stability.
        token_log_probs = torch.log(log_probs + self.epsilon)

        # Pick each target's log-probability along the class axis (dim 1).
        # A plain reshape(-1, vocab) of a (batch, vocab, seq) tensor would mix
        # vocab and sequence entries, so index the class dim directly instead.
        index = targets.reshape(batch_size, 1, seq_len).long()
        nll_loss = -token_log_probs.gather(1, index).squeeze(1)  # (batch, seq_len)

        if self.reduction == "mean":
            return nll_loss.mean()
        elif self.reduction == "sum":
            return nll_loss.sum()
        else:  # 'none'
            return nll_loss


def instantiate_filtered(main_cfg, *second_cfgs, overwrite_cfgs=None):
    """Instantiate the class named by ``main_cfg._target_`` from merged configs.

    ``main_cfg`` is merged with every config in ``second_cfgs`` and, last, with
    ``overwrite_cfgs``. Only merged keys that match a parameter name of the
    target class's ``__init__`` are forwarded to the constructor.

    Args:
        main_cfg: Primary config (dict or OmegaConf container) holding ``_target_``.
        *second_cfgs: Additional configs merged after ``main_cfg``.
        overwrite_cfgs: Optional config merged last. Its keys take precedence
            and are exempt from the duplicate-key check.

    Returns:
        An instance of the target class.

    Raises:
        ValueError: If ``main_cfg`` shares a key with any of ``second_cfgs``
            that is not also present in ``overwrite_cfgs``.
    """

    def to_cfg(raw):
        # A fresh flags dict per container, so no two configs share flag state.
        return OmegaConf.create(raw, flags={"allow_objects": True})

    base = to_cfg(main_cfg)
    extras = [to_cfg(raw) for raw in second_cfgs]
    overrides = None if overwrite_cfgs is None else to_cfg(overwrite_cfgs)

    # Resolve the class to build; ``_target_`` itself is not a constructor kwarg.
    target_cls = get_class(base.pop("_target_"))

    # Keys re-set by ``overrides`` may legitimately appear in several configs.
    override_keys = set() if overrides is None else set(overrides.keys())
    base_keys = set(base.keys())
    clashes = set()
    for extra in extras:
        clashes |= (base_keys & set(extra.keys())) - override_keys

    if clashes:
        clash_list = ", ".join(sorted(clashes))
        raise ValueError(
            f"The configurations have some keys in common ({clash_list}), "
            "this procedure will overwrite the keys in the first configuration "
            "with the ones in the subsequent configurations. Please make sure that "
            "the configurations do not have keys in common."
        )

    layers = [base, *extras]
    if overrides is not None:
        # Tell the user which existing keys the override config replaces.
        replaced = set()
        for layer in layers:
            replaced |= set(layer.keys()) & override_keys
        if replaced:
            print(
                f"Overwriting the following keys with overwrite_cfgs: {', '.join(sorted(replaced))}"
            )
        layers.append(overrides)

    # Later layers win, so ``overrides`` (if any) has the final say.
    merged = OmegaConf.merge(*layers)

    # Forward only the keys that the target's constructor actually accepts.
    accepted = inspect.signature(target_cls.__init__).parameters
    ctor_kwargs = {name: value for name, value in merged.items() if name in accepted}
    return target_cls(**ctor_kwargs)


def get_dtype(dtype):
    """Translate a dtype name from the configuration into a ``torch.dtype``.

    Supported names: ``"float32"``, ``"float64"``, ``"float16"``, ``"bfloat16"``.

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    supported = (
        ("float32", torch.float32),
        ("float64", torch.float64),
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
    )
    for name, torch_dtype in supported:
        if dtype == name:
            return torch_dtype
    raise ValueError(f"Invalid dtype {dtype} specified in the configuration.")


def save_checkpoint(checkpoint_path, model, optimizer, iteration):
    """Serialize training state to ``checkpoint_path`` with ``torch.save``.

    The file holds the model and optimizer ``state_dict``s and the iteration
    counter. It also holds the torch CPU, CUDA (all devices), NumPy and Python
    RNG states, so a resumed run can reproduce the same random stream.
    """
    payload = dict(
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        iteration=iteration,
    )
    # RNG snapshots, in the same key order the loader expects.
    payload.update(
        rng_state=torch.get_rng_state(),
        cuda_rng_state=torch.cuda.get_rng_state_all(),
        numpy_rng_state=np.random.get_state(),
        python_rng_state=random.getstate(),
    )
    torch.save(payload, checkpoint_path)


def save_tensor(checkpoint_path, tensor):
    """Write ``tensor`` to ``checkpoint_path`` using ``torch.save``."""
    torch.save(obj=tensor, f=checkpoint_path)


def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore a checkpoint holding model/optimizer state and RNG snapshots.

    Expects the keys ``model_state``, ``optimizer_state``, ``iteration``,
    ``rng_state``, ``cuda_rng_state``, ``numpy_rng_state`` and
    ``python_rng_state``.

    Args:
        checkpoint_path: Path of the file to load (tensors are mapped to CPU).
        model: Module to load weights into. ``strict=False`` means missing or
            unexpected keys are silently ignored.
        optimizer: Optimizer to restore, or ``None`` to skip optimizer state.

    Returns:
        The stored ``iteration`` value.
    """
    # weights_only=False: the checkpoint contains NumPy/Python RNG state
    # (ndarrays inside tuples), which the restricted unpickler that torch.load
    # uses by default since torch 2.6 refuses to load. This enables full
    # pickle, so only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state"], strict=False)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    iteration = checkpoint["iteration"]

    # Restore the RNG states exactly as they were.
    torch.set_rng_state(checkpoint["rng_state"])
    # Restore only as many CUDA generators as exist on this machine. A
    # checkpoint saved with more GPUs would otherwise address nonexistent
    # device indices.
    cuda_states = checkpoint["cuda_rng_state"]
    torch.cuda.set_rng_state_all(cuda_states[: torch.cuda.device_count()])
    np.random.set_state(checkpoint["numpy_rng_state"])
    random.setstate(checkpoint["python_rng_state"])

    return iteration
