import functools
import logging
from typing import Dict, List, Optional, Tuple, cast

import torch
from hydra.utils import instantiate
from omegaconf import DictConfig

from crps_retrofitting.optim.distributed_shampoo.shampoo_types import (
    FSDPShampooConfig,
    HSDPShampooConfig,
)
from crps_retrofitting.optim.distributed_shampoo.utils.shampoo_fsdp_utils import (
    compile_fsdp_parameter_metadata,
)
from crps_retrofitting.optim.staged_lr_scheduler import StagedLRScheduler

logger = logging.getLogger(__name__)


def create_parameter_groups(
    model: torch.nn.Module,
    model_info: Optional[Dict] = None,
    new_params_lr: float = 1e-3,
    common_params_lr: float = 1e-4,
    new_params_kwargs: Optional[Dict] = None,
    common_params_kwargs: Optional[Dict] = None,
) -> Tuple[List[Dict], bool]:
    """
    Split the trainable parameters of ``model`` into optimizer parameter groups.

    Parameters whose names are listed under ``model_info["missing"]`` are
    treated as "new"; every other trainable parameter is treated as "common".
    Without a non-empty ``"missing"`` entry, or when none of the listed names
    matches a trainable parameter, all trainable parameters end up in a single
    "new" group.

    Args:
        model: The model whose parameters are grouped. Parameters with
            ``requires_grad == False`` are skipped.
        model_info: Optional mapping; only its ``"missing"`` entry (an iterable
            of parameter names) is read.
        new_params_lr: Learning rate stored in the "new" group.
        common_params_lr: Learning rate stored in the "common" group.
        new_params_kwargs: Extra entries merged into the "new" group dict
            (they take precedence over the default ``params``/``lr`` keys).
        common_params_kwargs: Extra entries merged into the "common" group dict
            (same precedence rule).

    Returns:
        Tuple of ``(parameter_groups, has_common_params)``. The "new" group,
        when non-empty, comes first; empty groups are omitted, so the list may
        be empty if the model has no trainable parameters.
    """
    staging_requested = (
        bool(model_info)
        and "missing" in model_info
        and len(model_info["missing"]) > 0
    )

    fresh: List[torch.nn.Parameter] = []
    pretrained: List[torch.nn.Parameter] = []

    if not staging_requested:
        # No staging information: every trainable parameter counts as new.
        fresh = [p for p in model.parameters() if p.requires_grad]
    else:
        missing_names = set(model_info["missing"])
        for param_name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            bucket = fresh if param_name in missing_names else pretrained
            bucket.append(param)

        if not fresh:
            # None of the listed names matched a trainable parameter, so fall
            # back to a single group trained with the "new" settings.
            logger.warning("No new parameters found! Disabling staged learning")
            fresh, pretrained = pretrained, []

    # Report how the parameters were distributed.
    if pretrained:
        logger.info(
            f"Staged learning: {len(fresh)} new params, {len(pretrained)} common params"
        )
    else:
        logger.info(
            f"No new layers detected: {len(fresh)} params (normal training)"
        )

    param_groups = []
    group_specs = (
        (fresh, new_params_lr, new_params_kwargs),
        (pretrained, common_params_lr, common_params_kwargs),
    )
    for members, lr, extra in group_specs:
        if not members:
            continue
        group = {"params": members, "lr": lr}
        if extra:
            group.update(extra)
        param_groups.append(group)

    return param_groups, bool(pretrained)


def setup_optimizer_and_scheduler(
    cfg: DictConfig,
    model: torch.nn.Module,
    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    model_info: Optional[Dict] = None,
    last_epoch: int = -1,
) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """Setup optimizer and learning rate scheduler with staged learning support.

    The optimizer class named by ``cfg.optimizer._target_`` is resolved and
    constructed directly with the parameter groups returned by
    ``create_parameter_groups``. Keys of ``cfg.optimizer`` other than
    ``_target_``, ``lr``, ``params``, ``new_params_lr``, ``common_params_lr``,
    ``new_params_kwargs`` and ``common_params_kwargs`` are forwarded to the
    optimizer constructor as keyword arguments.

    Args:
        cfg: Run configuration. Reads ``cfg.optimizer``, ``cfg.trainer``
            (``enable_staged_learning``, ``common_params_warmup_epochs``,
            ``lr_scheduler_per_step``, ``grad_acc_steps``, ``max_epoch``), the
            optional ``cfg.lr_scheduler``, ``cfg.distribution.distribution_type``
            (DistributedShampoo only) and
            ``cfg.data.module_parameters.max_samples`` (per-step scheduling
            only).
        model: Model whose trainable parameters are optimized.
        device_mesh: Device mesh handed to ``HSDPShampooConfig``; required when
            the optimizer is DistributedShampoo and the distribution type is
            ``HSDP``.
        model_info: Checkpoint loading information forwarded to
            ``create_parameter_groups``.
        last_epoch: Epoch index used when creating the schedulers; they are
            constructed with ``max(-1, last_epoch - 1)``.

    Returns:
        Tuple ``(optimizer, lr_scheduler)``. ``lr_scheduler`` is ``None`` when
        the configuration defines no (or a null) ``lr_scheduler``. It is a
        ``StagedLRScheduler`` wrapping the main scheduler when staged learning
        is enabled and a common-parameter group exists, and the main scheduler
        otherwise.

    Raises:
        ValueError: If no parameter group could be created, or if HSDP is
            requested for DistributedShampoo without a ``device_mesh``.
    """

    logger.info(f"Instantiate optimizer {cfg.optimizer._target_}")

    # Check if staged learning is enabled in config
    enable_staged_learning = getattr(cfg.trainer, "enable_staged_learning", False)

    # Extract LRs and kwargs for both groups from config. `or {}` keeps an
    # explicit `null` entry from crashing `dict(None)`.
    new_params_lr = cfg.optimizer.get("new_params_lr", cfg.optimizer.lr)
    common_params_lr = cfg.optimizer.get("common_params_lr", cfg.optimizer.lr)
    new_params_kwargs = dict(cfg.optimizer.get("new_params_kwargs") or {})
    common_params_kwargs = dict(cfg.optimizer.get("common_params_kwargs") or {})

    # Remove keys that must not override the group's own params / lr entries
    for group_kwargs in (new_params_kwargs, common_params_kwargs):
        for key in ("_target_", "lr", "params"):
            group_kwargs.pop(key, None)

    # Create parameter groups for staged learning
    param_groups, has_common_params = create_parameter_groups(
        model,
        model_info,
        new_params_lr=new_params_lr,
        common_params_lr=common_params_lr,
        new_params_kwargs=new_params_kwargs,
        common_params_kwargs=common_params_kwargs,
    )

    # Ensure we have valid parameter groups
    if not param_groups:
        raise ValueError("No parameter groups created!")

    # Create optimizer using direct instantiation (bypass Hydra issues)
    logger.info("Creating optimizer with direct instantiation...")

    from hydra.utils import get_class

    optimizer_class = get_class(cfg.optimizer._target_)

    # Config keys consumed here rather than by the optimizer constructor.
    reserved_keys = {
        "_target_",
        "lr",
        "params",
        "new_params_lr",
        "common_params_lr",
        "new_params_kwargs",
        "common_params_kwargs",
    }
    optimizer_kwargs = {
        key: value for key, value in cfg.optimizer.items() if key not in reserved_keys
    }

    logger.info(
        f"Creating {optimizer_class.__name__} with {len(param_groups)} parameter groups"
    )
    logger.info(f"Optimizer kwargs: {list(optimizer_kwargs.keys())}")

    # Handle DistributedShampoo special case
    if "DistributedShampoo" in cfg.optimizer._target_:
        distribution_type = cfg.distribution.distribution_type.upper()
        distributed_config = None

        if distribution_type == "FSDP":
            distributed_config = FSDPShampooConfig(
                param_to_metadata=compile_fsdp_parameter_metadata(model)
            )
        elif distribution_type == "HSDP":
            if device_mesh is None:
                raise ValueError("`device_mesh` is required for HSDP")
            distributed_config = HSDPShampooConfig(
                param_to_metadata=compile_fsdp_parameter_metadata(model),
                device_mesh=device_mesh,
                num_trainers_per_group=cfg.optimizer.distributed_config.num_trainers_per_group,
            )

        if distributed_config is not None:
            # `cfg.optimizer` may itself carry a raw `distributed_config` entry
            # (the HSDP branch reads it), which would also sit in
            # `optimizer_kwargs`. Passing both it and the constructed object
            # raises "got multiple values for keyword argument", so the
            # constructed object replaces the raw entry.
            optimizer_kwargs["distributed_config"] = distributed_config

    # Standard optimizers (AdamW, SGD, etc.) and DistributedShampoo share the
    # same call once the kwargs are settled.
    optimizer = optimizer_class(param_groups, **optimizer_kwargs)

    # Setup learning rate scheduler. A missing key and an explicit `null`
    # both mean "no scheduler".
    lr_scheduler = None
    if cfg.get("lr_scheduler") is not None:
        logger.info(f"Instantiate learning rate scheduler {cfg.lr_scheduler._target_}")

        # Calculate step multiplier for per-step scheduling
        if cfg.trainer.lr_scheduler_per_step:
            step_mult_factor = (
                cfg.data.module_parameters.max_samples / cfg.trainer.grad_acc_steps
            )
        else:
            step_mult_factor = 1

        # Both schedulers are created with the same, clamped epoch index.
        scheduler_last_epoch = max(-1, last_epoch - 1)

        # Create main LR scheduler using Hydra
        main_lr_scheduler = instantiate(
            cfg.lr_scheduler,
            optimizer=optimizer,
            max_epochs=cfg.trainer.max_epoch,
            step_mult_factor=step_mult_factor,
            last_epoch=scheduler_last_epoch,
        )

        # Add staged learning if we have common parameters and staged learning is enabled
        if has_common_params and enable_staged_learning:
            warmup_epochs = getattr(cfg.trainer, "common_params_warmup_epochs", 5)
            logger.info(
                f"Setting up staged learning: common params will start training at epoch {warmup_epochs}"
            )

            lr_scheduler = StagedLRScheduler(
                optimizer=optimizer,
                main_scheduler=main_lr_scheduler,
                warmup_epochs=warmup_epochs,
                last_epoch=scheduler_last_epoch,
            )
        else:
            lr_scheduler = main_lr_scheduler
    else:
        logger.info("No learning rate scheduler")

    return optimizer, lr_scheduler
