from __future__ import annotations
from typing import Tuple, Optional, Callable

import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, Dataset
from functools import partial


from ..models.vision import TwoLayerMLP, TwoHiddenMLP, LeNet, Autoencoder8Sigmoid, TinyTransformerLM
from ..models.linear import LogisticLinear


def _build_model(
    model_kind: str,
    input_shape: Tuple[int, ...],
    num_classes: int,
    cfg: dict,
    device: torch.device,
    in_channels: int = 3,
    input_size: int = 32,
) -> nn.Module:
    """Top-level model factory matching build_from_cfg's branching logic.

    Returns a model instance placed on the requested device.
    
    Args:
        model_kind: Model type (e.g., 'alexnet', 'resnet18', 'lenet')
        input_shape: Input shape (flattened or spatial)
        num_classes: Number of output classes
        cfg: Configuration dictionary
        device: Torch device
        in_channels: Number of input channels (1 for grayscale, 3 for RGB)
        input_size: Spatial size of input images (e.g., 28 for MNIST, 32 for CIFAR)
    """
    model_kind = model_kind.lower()

    if model_kind == "mlp":
        return TwoLayerMLP(
            in_dim=input_shape[0],
            hidden=int(cfg.get("hidden", 128)),
            out_dim=num_classes,
        ).to(device)
    elif model_kind == "mlp2":
        return TwoHiddenMLP(
            in_dim=input_shape[0],
            hidden=int(cfg.get("hidden", 256)),
            out_dim=num_classes,
        ).to(device)
    elif model_kind == "autoencoder8":
        return Autoencoder8Sigmoid(
            in_dim=input_shape[0],
            widths=tuple(cfg.get("widths", [512, 256, 128, 64])),
        ).to(device)
    elif model_kind == "lenet":
        return LeNet(num_classes=num_classes, in_channels=in_channels, input_size=input_size).to(device)
    elif model_kind == "alexnet":
        from torchvision.models import alexnet

        model = alexnet(num_classes=num_classes)
        # Modify first layer for correct number of input channels
        if in_channels != 3:
            model.features[0] = nn.Conv2d(in_channels, 64, kernel_size=11, stride=4, padding=2)
        return model.to(device)
    elif model_kind == "vgg13":
        from torchvision.models import vgg13

        model = vgg13(num_classes=num_classes)
        if in_channels != 3:
            model.features[0] = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        return model.to(device)
    elif model_kind == "resnet18":
        from torchvision.models import resnet18

        model = resnet18(num_classes=num_classes)
        if in_channels != 3:
            model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        return model.to(device)
    elif model_kind == "resnet20":
        # torchvision doesn't ship resnet20; small CIFAR ResNet variant fallback
        try:
            from torchvision.models import resnet18 as _resnet18

            base = _resnet18(num_classes=num_classes)
            if in_channels != 3:
                base.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            return base.to(device)
        except Exception:
            raise ValueError("resnet20 not available; consider resnet18 as a close substitute")
    elif model_kind == "resnet50":
        from torchvision.models import resnet50

        return resnet50(num_classes=num_classes).to(device)
    elif model_kind in {"logistic", "logreg"}:
        return LogisticLinear(d=input_shape[0]).to(device)
    elif model_kind in {"transformer2", "tiny_transformer"}:
        vocab_size = int(cfg.get("vocab_size", 10000))
        d_model = int(cfg.get("d_model", 256))
        n_heads = int(cfg.get("n_heads", 2))
        return TinyTransformerLM(vocab_size=vocab_size, d_model=d_model, n_heads=n_heads, n_layers=2).to(device)
    else:
        raise ValueError(f"Unknown model kind: {model_kind}")


def _make_loss(loss_cls: type[nn.Module]) -> nn.Module:
    """Create a new loss instance from its class (assumes default ctor)."""
    return loss_cls()


# Backbones that are fed 224x224 inputs; small images are upsampled for them.
_IMAGENET_STYLE_MODELS = frozenset({"alexnet", "vgg13", "resnet18", "resnet20", "resnet50"})

# dataset kind -> (torchvision.datasets class name, native side length, channels, download flag)
# NOTE: FashionMNIST is the only entry that downloads on demand; kept as in the original code.
_SMALL_IMAGE_DATASETS = {
    "mnist": ("MNIST", 28, 1, False),
    "fashionmnist": ("FashionMNIST", 28, 1, True),
    "cifar10": ("CIFAR10", 32, 3, False),
}


def _small_image_transform(model_kind: str, native_size: int) -> Tuple[Callable, int]:
    """Return ``(transform, input_size)`` for a small-image dataset and model kind."""
    from torchvision import transforms

    if model_kind in _IMAGENET_STYLE_MODELS:
        # ImageNet-like backbones get inputs resized to 224.
        return transforms.Compose([transforms.Resize(224), transforms.ToTensor()]), 224
    return transforms.ToTensor(), native_size


def _first_n_subset(base_ctor: Callable[[], Dataset], num_items: int) -> Dataset:
    """Rebuild a dataset with ``base_ctor`` and keep only its first ``num_items`` samples."""
    from torch.utils.data import Subset

    return Subset(base_ctor(), list(range(num_items)))


def build_from_cfg(
    cfg: dict,
    device: Optional[str] = None,
) -> Tuple[
    nn.Module,
    Callable[[], nn.Module],
    Dataset,
    Callable[[], Dataset],
    nn.Module,
    Callable[[], nn.Module],
    Optional[np.ndarray],
]:
    """Build model, dataset and loss (plus constructors for each) from a config dict.

    Config keys read here: ``device``, ``dataset``, ``model``,
    ``subset_fraction``, ``data_root``; the whole ``cfg`` is also forwarded to
    ``_build_model``.

    Datasets handled by this function: ``mnist``, ``fashionmnist``,
    ``cifar10`` and ``imagenet`` (all via torchvision, training split). Any
    other value, including the default ``"synthetic_logreg"`` used when
    ``cfg`` has no ``dataset`` key, raises ``ValueError``.

    For MNIST, FashionMNIST and CIFAR-10 the images are resized to 224 when
    the model is one of alexnet/vgg13/resnet18/resnet20/resnet50; otherwise
    they keep their native size.

    When ``subset_fraction < 1`` only the first
    ``max(int(len(dataset) * subset_fraction), 1)`` samples are kept. The
    returned dataset, the dataset constructor and the labels all reflect that
    same truncation.

    Args:
        cfg: Configuration dictionary.
        device: Device to place the model on. When falsy, ``cfg["device"]`` is
            used, falling back to CUDA if available and CPU otherwise.

    Returns:
        A 7-tuple ``(model, model_ctor, dataset, dataset_ctor, loss_fn,
        loss_fn_ctor, labels)``. The ``*_ctor`` entries are zero-argument
        ``functools.partial`` objects over top-level callables so they can be
        pickled. ``labels`` is a NumPy array of dataset targets, or ``None``
        for ImageNet.

    Raises:
        ValueError: If the dataset kind (or, via ``_build_model``, the model
            kind) is unknown.
        RuntimeError: If the ImageNet dataset cannot be constructed from
            ``data_root``.
    """
    # Normalise to a torch.device so the model factory always sees the same type.
    if not device:
        device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    if not isinstance(device, torch.device):
        device = torch.device(device)

    ds_kind = str(cfg.get("dataset", "synthetic_logreg")).lower()
    model_kind = str(cfg.get("model", "mlp")).lower()
    subset_fraction = float(cfg.get("subset_fraction", 1))

    # Note: return top-level, picklable callables for *ctor outputs.
    # Dataset
    if ds_kind in _SMALL_IMAGE_DATASETS:
        from torchvision import datasets

        cls_name, native_size, in_channels, download = _SMALL_IMAGE_DATASETS[ds_kind]
        transform, input_size = _small_image_transform(model_kind, native_size)
        dataset_ctor = partial(
            getattr(datasets, cls_name),
            root=cfg.get("data_root", "./data"),
            train=True,
            download=download,
            transform=transform,
        )
        dataset = dataset_ctor()
        labels = np.array(dataset.targets)
        num_classes = 10
        # Flat dimension of the native (un-resized) image, used by the flat-input models.
        input_shape = (in_channels * native_size * native_size,)
    elif ds_kind in {"imagenet"}:
        # Attempt to use torchvision datasets ImageNet (expects manual download/ILSVRC2012 layout)
        from torchvision import datasets, transforms

        data_root = cfg.get("data_root", "./data/imagenet")
        transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
        dataset_ctor = partial(datasets.ImageNet, root=data_root, split="train", transform=transform)
        try:
            dataset = dataset_ctor()
        except Exception as e:
            raise RuntimeError(
                f"ImageNet dataset not found at {data_root}. Prepare ILSVRC2012 and set data_root. Original error: {e}"
            ) from e
        # labels are not extracted for ImageNet; skip stratification by default
        labels = None
        num_classes = 1000
        input_shape = (3 * 224 * 224,)
        in_channels = 3  # ImageNet is RGB
        input_size = 224
    else:
        raise ValueError(f"Unknown dataset kind: {ds_kind}")

    loss_fn = nn.CrossEntropyLoss()

    # Optional subset for torchvision datasets
    if subset_fraction < 1.0 and hasattr(dataset, "__len__"):
        from torch.utils.data import Subset

        n = int(len(dataset) * subset_fraction)
        idx = np.arange(len(dataset))[:max(n, 1)]
        dataset = Subset(dataset, idx.tolist())
        # Keep the constructor in sync with the truncated dataset; previously it
        # rebuilt the full dataset while `dataset` and `labels` were subsets.
        dataset_ctor = partial(_first_n_subset, dataset_ctor, len(idx))
        # Also subset labels if they exist
        if labels is not None:
            labels = labels[idx]

    # Model (use top-level factory + picklable ctor)
    model_ctor: Callable[[], nn.Module] = partial(
        _build_model, model_kind, input_shape, num_classes, cfg, device, in_channels, input_size
    )
    model = model_ctor()

    # Loss "ctor" that rebuilds a fresh loss instance
    loss_fn_ctor: Callable[[], nn.Module] = partial(_make_loss, loss_fn.__class__)

    return model, model_ctor, dataset, dataset_ctor, loss_fn, loss_fn_ctor, labels
