import numpy as np
import torch


def flat_mnist(X_np):
    """Flatten MNIST images into float32 row vectors scaled by 1/255.

    Args:
        X_np: Array-like of shape (N, 28, 28) or (N, 28, 28, 1). A trailing
            singleton channel axis is dropped before flattening.

    Returns:
        float32 ndarray of shape (N, 784) holding the pixel values divided
        by 255. An empty batch (N == 0) yields an array of shape (0, 784).

    Raises:
        AssertionError: if the input (after dropping the channel axis) is not
            of shape (N, 28, 28).
    """
    X = np.asarray(X_np)
    if X.ndim == 4 and X.shape[-1] == 1:
        X = X[..., 0]
    # Raised explicitly (rather than via `assert`) so the check survives
    # `python -O`; the exception type is kept for backward compatibility.
    if X.ndim != 3 or X.shape[1:] != (28, 28):
        raise AssertionError(f"Expected MNIST (N,28,28), got {X.shape}")
    # Explicit 28 * 28 instead of -1: reshape(0, -1) is ambiguous and raises
    # for an empty batch.
    return X.reshape(X.shape[0], 28 * 28).astype(np.float32) / 255.0


def precompute_D2(F_test, F_train):
    """Return pairwise squared Euclidean distances between two feature sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b so the whole
    matrix comes from one matrix product.

    Args:
        F_test: 2-D feature array, one row per test sample.
        F_train: 2-D feature array, one row per training sample.

    Returns:
        float32 ndarray whose entry [i, j] is the squared distance between
        test row i and training row j, clamped to be non-negative.
    """
    cross = F_test @ F_train.T
    sq_norm_test = np.sum(F_test**2, axis=1, keepdims=True)  # column vector
    sq_norm_train = np.sum(F_train**2, axis=1, keepdims=True).T  # row vector
    dist_sq = sq_norm_test + sq_norm_train - 2.0 * cross
    # Round-off can push near-zero distances slightly negative; clamp them.
    return np.maximum(dist_sq, 0.0).astype(np.float32)


def knn_predict(F_train, y_train, F_test, k=5):
    """Classify test rows by majority vote among their k nearest training rows.

    Distances are squared Euclidean, computed by ``precompute_D2``.

    Args:
        F_train: 2-D training feature array, one row per sample.
        y_train: Integer class labels aligned with the rows of ``F_train``
            (any array-like; converted with ``np.asarray``).
        F_test: 2-D test feature array, one row per sample.
        k: Number of neighbours that vote. Values larger than the number of
            training rows use every training row.

    Returns:
        int64 ndarray with one predicted label per test row. When several
        labels tie for the most votes, the smallest label wins.

    Raises:
        ValueError: if ``k`` is smaller than 1, if ``y_train`` does not have
            one label per training row, or if there are test rows but no
            training rows.
    """
    k = int(k)
    if k < 1:
        # A negative k would otherwise be taken as a slice bound and silently
        # drop the farthest neighbours instead of failing.
        raise ValueError(f"k must be a positive integer, got {k}")

    y_train = np.asarray(y_train)  # lists do not support fancy indexing
    D2 = precompute_D2(F_test, F_train)
    n_test, n_train = D2.shape
    if y_train.shape[:1] != (n_train,):
        raise ValueError(
            f"y_train has shape {y_train.shape} but F_train has {n_train} rows"
        )
    if n_test and not n_train:
        raise ValueError("cannot predict with an empty training set")

    nearest = np.argsort(D2, axis=1)[:, :k]
    nn_labels = y_train[nearest]
    preds = np.zeros(n_test, dtype=np.int64)
    for i, row in enumerate(nn_labels):
        # np.unique returns sorted labels, so argmax resolves vote ties in
        # favour of the smallest label.
        vals, counts = np.unique(row, return_counts=True)
        preds[i] = vals[np.argmax(counts)]
    return preds


def test_knn_scenario_raw(X_tr, y_tr, X_te, y_te, k=5):
    """Return kNN accuracy on raw (flattened, 1/255-scaled) MNIST pixels.

    Args:
        X_tr: Training images accepted by ``flat_mnist``.
        y_tr: Training labels.
        X_te: Test images accepted by ``flat_mnist``.
        y_te: Ground-truth test labels.
        k: Number of neighbours passed through to ``knn_predict``.

    Returns:
        Fraction of test predictions equal to ``y_te``, as a Python float.
    """
    X_tr, X_te, y_tr, y_te = (np.asarray(a) for a in (X_tr, X_te, y_tr, y_te))
    train_feats = flat_mnist(X_tr)
    test_feats = flat_mnist(X_te)
    predicted = knn_predict(train_feats, y_tr, test_feats, k=k)
    hits = predicted == y_te
    return float(np.mean(hits))


def build_concat_image(chunks):
    """Stack a sequence of (images, labels) chunks along the first axis.

    Args:
        chunks: Iterable of 2-item ``(X_part, y_part)`` pairs.

    Returns:
        Tuple ``(X_tr, y_tr)`` with all image parts and all label parts
        concatenated along axis 0, in iteration order.
    """
    # Materialise once so a one-shot iterator can feed both concatenations.
    pairs = [(images, labels) for images, labels in chunks]
    X_tr = np.concatenate([images for images, _ in pairs], axis=0)
    y_tr = np.concatenate([labels for _, labels in pairs], axis=0)
    return X_tr, y_tr


def to_pil_rgb(x):
    """Convert a single image array to an 8-bit RGB ``PIL.Image``.

    Accepted layouts:
        * (H, W)            grayscale
        * (H, W, 1) / (1, H, W)   single channel
        * (H, W, 3) / (3, H, W)   RGB

    A 3-D array whose first axis has length 1 or 3 is treated as
    channel-first and transposed to channel-last. Single-channel data is
    replicated across three channels so the RGB buffer has the right size.

    Args:
        x: Image array (any array-like) in one of the layouts above.

    Returns:
        ``PIL.Image.Image`` in mode "RGB".
    """
    from PIL import Image

    x = np.asarray(x)
    if x.ndim == 3 and x.shape[0] in (1, 3):
        x = np.moveaxis(x, 0, -1)
    if x.dtype != np.uint8:
        # Plain cast, no rescaling.
        # NOTE(review): float images in [0, 1] would truncate to 0 here;
        # presumably callers pass 0..255 pixel values -- confirm.
        x = x.astype(np.uint8)
    if x.ndim == 2:
        x = x[..., np.newaxis]
    if x.ndim == 3 and x.shape[-1] == 1:
        # An RGB image needs 3 bytes per pixel; a single-channel buffer
        # handed to fromarray(mode="RGB") is rejected as too short.
        x = np.repeat(x, 3, axis=-1)
    return Image.fromarray(x, mode="RGB")


def build_flatten_fn(name: str, device: str = "cpu"):
    """Return the embedding function for the backbone called ``name``.

    Args:
        name: Backbone identifier, matched case-insensitively. One of
            "vit_b16", "resnet18", "resnet34", "resnet50".
        device: Device string forwarded to the selected builder.

    Returns:
        Whatever the selected ``_build_*`` builder returns.

    Raises:
        ValueError: if ``name`` is not a known backbone.
    """
    name = name.lower()
    # Lambdas keep the lookup lazy: only the selected builder is touched.
    factories = {
        "vit_b16": lambda: _build_vit_b16(device),
        "resnet18": lambda: _build_resnet18(device),
        "resnet34": lambda: _build_resnet34(device),
        "resnet50": lambda: _build_resnet50(device),
    }
    make = factories.get(name)
    if make is None:
        raise ValueError(f"Unknown embed '{name}'")
    return make()


def _build_vit_b16(device: str):
    """Build an embedding function backed by timm's pretrained ViT-B/16.

    Args:
        device: Torch device string the model and input batches are moved to.

    Returns:
        A function ``fn(X)`` that takes an iterable of image arrays (each
        accepted by ``to_pil_rgb``) and returns a float32 ndarray with one
        L2-normalised feature row per image.
    """
    import timm
    from torchvision import transforms as T

    model = timm.create_model("vit_base_patch16_224", pretrained=True)
    model.eval().to(device)
    size = 224
    tf = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    @torch.no_grad()
    def fn(X: np.ndarray) -> np.ndarray:
        # The whole input is embedded in a single forward pass.
        imgs = torch.stack([tf(to_pil_rgb(x)) for x in X]).to(device)
        feat = model.forward_features(imgs)
        if isinstance(feat, (list, tuple)):
            feat = feat[-1]
        if feat.ndim == 3:
            # Newer timm releases return the unpooled token sequence
            # (batch, tokens, dim) here. Pool it to one vector per image,
            # otherwise the normalisation below would run over the token axis
            # and the result would not be a 2-D feature matrix.
            if hasattr(model, "forward_head"):
                feat = model.forward_head(feat, pre_logits=True)
            else:
                feat = feat[:, 0]  # fall back to the class token
        v = feat.detach().cpu().numpy().astype("float32")
        # Row-wise L2 normalisation; the epsilon guards against zero rows.
        v /= (np.linalg.norm(v, axis=1, keepdims=True) + 1e-12)
        return v

    return fn


def _make_resnet_embedder(model, device: str):
    """Wrap a torchvision ResNet as an image -> L2-normalised feature function.

    Shared by the ResNet-18/34/50 builders, which differ only in the model
    they construct.

    Args:
        model: A torchvision ResNet exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool``, ``layer1``..``layer4`` and ``avgpool``.
        device: Torch device string the model and input batches are moved to.

    Returns:
        A function ``fn(X)`` that takes an iterable of image arrays (each
        accepted by ``to_pil_rgb``) and returns a float32 ndarray with one
        L2-normalised row of pooled features per image.
    """
    from torchvision import transforms as T

    model.eval().to(device)
    size = 224
    tf = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    @torch.no_grad()
    def fn(X: np.ndarray) -> np.ndarray:
        # The whole input is embedded in a single forward pass.
        imgs = torch.stack([tf(to_pil_rgb(x)) for x in X])
        imgs = imgs.to(device)
        # Run the backbone up to (and including) global average pooling,
        # skipping the final classification layer.
        x = model.conv1(imgs)
        x = model.bn1(x)
        x = model.relu(x)
        x = model.maxpool(x)
        x = model.layer1(x)
        x = model.layer2(x)
        x = model.layer3(x)
        x = model.layer4(x)
        x = model.avgpool(x)
        v = torch.flatten(x, 1)
        v = v.detach().cpu().numpy().astype("float32")
        # Row-wise L2 normalisation; the epsilon guards against zero rows.
        v /= (np.linalg.norm(v, axis=1, keepdims=True) + 1e-12)
        return v

    return fn


def _build_resnet18(device: str):
    """Build the embedding function for ResNet-18 (IMAGENET1K_V1 weights)."""
    from torchvision.models import resnet18, ResNet18_Weights

    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    return _make_resnet_embedder(model, device)


def _build_resnet34(device: str):
    """Build the embedding function for ResNet-34 (IMAGENET1K_V1 weights)."""
    from torchvision.models import resnet34, ResNet34_Weights

    model = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
    return _make_resnet_embedder(model, device)


def _build_resnet50(device: str):
    """Build the embedding function for ResNet-50 (IMAGENET1K_V2 weights)."""
    from torchvision.models import resnet50, ResNet50_Weights

    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    return _make_resnet_embedder(model, device)


