from __future__ import annotations

from typing import Any, Optional
from pathlib import Path
import warnings

import torch

from .core import CMPluginConfig, CMPathManager


class CMRefinerBase:
    """Common interface for CM refiners.

    The base implementation is an identity transform: ``refine`` hands the
    state back untouched. Subclasses override ``refine`` to do real work.
    """

    def __init__(self, enabled: bool = False) -> None:
        """Record whether this refiner is switched on.

        Args:
            enabled: Flag stored on the instance as ``self.enabled``.
        """
        self.enabled = enabled

    def refine(self, state, context: Optional[dict[str, Any]] = None):
        """Return ``state`` unchanged.

        Args:
            state: The value to refine.
            context: Optional extra information; ignored by this base class.

        Returns:
            The very same ``state`` object that was passed in.
        """
        return state


class NoOpCMRefiner(CMRefinerBase):
    """Refiner that adds nothing to ``CMRefinerBase``.

    It inherits the base constructor and the base ``refine``, so refining
    returns the input state unchanged.
    """


class TorchCMRefiner(CMRefinerBase):
    """Refiner that delegates to a wrapped ``torch.nn.Module``.

    The module is switched to evaluation mode on construction and is always
    invoked inside ``torch.no_grad()``.
    """

    def __init__(self, model: torch.nn.Module, enabled: bool = True) -> None:
        """Wrap ``model`` and put it in evaluation mode.

        Args:
            model: The module called by ``refine``.
            enabled: Flag forwarded to ``CMRefinerBase.__init__``.
        """
        super().__init__(enabled=enabled)
        self.model = model.eval()

    def _ensure_device(self, state):
        """Move the model onto ``state``'s device when they differ.

        Non-tensor states, and models that expose no parameters, leave the
        model where it is.

        Args:
            state: The value about to be passed to the model.

        Returns:
            ``state``, unchanged.
        """
        if not torch.is_tensor(state):
            return state
        # Probe the first parameter once; None means the model has no
        # parameters, so there is no device to compare against.
        first_param = next(self.model.parameters(), None)
        if first_param is not None and first_param.device != state.device:
            self.model = self.model.to(state.device)
        return state

    def refine(self, state, context: Optional[dict[str, Any]] = None):
        """Run the wrapped model on ``state`` without tracking gradients.

        Calling conventions are tried in this order, moving on to the next
        one whenever the call raises ``TypeError``:

        1. ``model(state, context)`` - context as a positional dict
        2. ``model(state, **context)`` - context expanded into keyword args
        3. ``model(state)`` - state only

        Note that a ``TypeError`` raised from inside the model's own code
        also triggers the fallback, not only a signature mismatch.

        Args:
            state: Input handed to the model.
            context: Optional extra inputs; ``None`` is treated as ``{}``.

        Returns:
            Whatever the first successful model call returns.

        Raises:
            TypeError: If the final ``model(state)`` attempt also fails.
        """
        state = self._ensure_device(state)
        context = context or {}
        with torch.no_grad():
            try:
                return self.model(state, context)
            except TypeError:
                # Deliberate: fall through to the next calling convention.
                pass
            # With an empty context, model(state, **context) is the same call
            # as model(state); skip it so a failing model is not run twice
            # with identical arguments.
            if context:
                try:
                    return self.model(state, **context)
                except TypeError:
                    # Deliberate: fall through to the state-only call.
                    pass
            return self.model(state)


def build_cm_refiner(config: CMPluginConfig, device: Optional[torch.device] = None) -> CMRefinerBase:
    """Build the refiner described by ``config``.

    When ``config.enabled`` is falsy, or the resolved model file is empty, a
    disabled ``NoOpCMRefiner`` is returned. Otherwise the file is loaded
    first as TorchScript and, if that fails, via ``torch.load`` as a
    serialized ``torch.nn.Module``.

    Args:
        config: Plugin configuration; ``config.enabled`` is read here and the
            object is passed to ``CMPathManager`` to resolve the load path.
        device: Target passed as ``map_location`` when loading; ``"cpu"`` is
            used when this is ``None``.

    Returns:
        A ``TorchCMRefiner`` wrapping the loaded model, or a disabled
        ``NoOpCMRefiner``.

    Raises:
        FileNotFoundError: If the resolved model path does not exist.
        ValueError: If the file loads via ``torch.load`` but the result is
            not a ``torch.nn.Module``.
    """
    if not config.enabled:
        return NoOpCMRefiner(enabled=False)

    paths = CMPathManager(config)
    load_path = Path(paths.resolve_load_path())
    if not load_path.exists():
        raise FileNotFoundError(f"CM model not found: {load_path}")
    if load_path.stat().st_size == 0:
        # stacklevel=2 attributes the warning to the caller of this function.
        warnings.warn(f"CM model file is empty, fallback to NoOp: {load_path}", stacklevel=2)
        return NoOpCMRefiner(enabled=False)

    map_location = device or "cpu"
    try:
        # Only the TorchScript load is guarded, so a failure while wrapping
        # the model is not mistaken for "not a TorchScript file".
        model = torch.jit.load(str(load_path), map_location=map_location)
    except Exception as jit_error:
        # Broad catch is deliberate: any TorchScript load failure falls back
        # to the generic loader; errors from torch.load itself propagate.
        # NOTE(security): torch.load is pickle-based and can execute code
        # embedded in the file; only point this at trusted model files.
        obj = torch.load(str(load_path), map_location=map_location)
        if not isinstance(obj, torch.nn.Module):
            raise ValueError(
                "Unsupported CM model format. Provide a TorchScript file or a serialized nn.Module."
            ) from jit_error
        model = obj
    return TorchCMRefiner(model=model, enabled=True)


def maybe_refine(state, cm_refiner: Optional[CMRefinerBase], context: Optional[dict[str, Any]] = None):
    """Refine ``state`` only when an enabled refiner is supplied.

    Args:
        state: The value to refine.
        cm_refiner: Refiner to use, or ``None``. An object without an
            ``enabled`` attribute is treated as disabled.
        context: Optional extra information forwarded to ``refine`` as the
            ``context`` keyword argument.

    Returns:
        ``cm_refiner.refine(state, context=context)`` when the refiner is
        present and enabled; otherwise ``state`` unchanged.
    """
    is_active = cm_refiner is not None and getattr(cm_refiner, "enabled", False)
    if is_active:
        return cm_refiner.refine(state, context=context)
    return state
