import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2


def get_preprocess_mean_std(preprocess):
    """Return the normalization statistics of a preprocessing pipeline.

    The statistics are read from the ``mean`` and ``std`` attributes of the
    last entry in ``preprocess.transforms``.

    Args:
        preprocess: Object exposing a ``transforms`` sequence whose final
            element carries ``mean`` and ``std`` attributes.

    Returns:
        A ``(mean, std)`` tuple of NumPy arrays. Tensor-like statistics
        (anything supporting ``.detach().cpu().numpy()``) are converted
        without reshaping; any other value is converted with ``np.array``
        and reshaped to ``(3, 1, 1)``.

    Raises:
        AttributeError: If the final transform has no ``mean``/``std``.
        ValueError: If non-tensor statistics cannot be reshaped to
            ``(3, 1, 1)``.
    """
    normalize = preprocess.transforms[-1]
    try:
        mean = normalize.mean.detach().cpu().numpy()
        std = normalize.std.detach().cpu().numpy()
    except AttributeError:
        # Only the "not a tensor" case (no .detach/.cpu/.numpy) should take
        # the fallback; any other failure is a real error and must surface.
        mean = np.array(normalize.mean).reshape(3, 1, 1)
        std = np.array(normalize.std).reshape(3, 1, 1)
    return mean, std


def get_images(params):
    """Load the stimulus array from its fixed on-disk location.

    Args:
        params: Currently unused.

    Returns:
        The result of ``np.load`` on ``nsd_processed/nsd_stimuli1000.npy``
        (a relative path, so it resolves against the working directory).
    """
    stimuli_path = 'nsd_processed/nsd_stimuli1000.npy'
    return np.load(stimuli_path)


def get_activation(name, activations):
    """Build a forward hook that records a layer's output under ``name``.

    Args:
        name: Key under which the captured output is stored.
        activations: Mutable mapping that receives the captured output;
            an existing entry for ``name`` is overwritten on every call.

    Returns:
        A callable taking ``(module, inputs, output)``. It stores the
        detached output in ``activations[name]``; when the output is a
        tuple, only its first element is kept.
    """
    def hook(module, inputs, result):
        # Tuple-valued outputs: keep just the leading element.
        if isinstance(result, tuple):
            result = result[0]
        activations[name] = result.detach()

    return hook


def get_activations(preprocessed_images, model, layer_name):
    """Collect the outputs of one named layer over a set of images.

    The layer is looked up by name via ``named_modules()`` on ``model``,
    then on ``model.model`` and ``model.visual`` when those attributes
    exist. Images are pushed through ``model.encode_image`` in batches of
    32 with gradients disabled, and the layer output is captured by a
    temporary forward hook.

    Args:
        preprocessed_images: Batch-first tensor; it is sliced along its
            first dimension and each slice is moved to CUDA when available,
            otherwise CPU.
        model: Object providing ``encode_image`` and ``named_modules``.
            NOTE(review): the model itself is not moved here, so it is
            presumably already on the matching device — confirm at the
            call site.
        layer_name: Name of the layer as reported by ``named_modules()``.

    Returns:
        NumPy array of the captured outputs concatenated along axis 0.
        Outputs with three dimensions are averaged over axis 1 first
        (presumably the token/sequence axis — TODO confirm).

    Raises:
        ValueError: If ``layer_name`` is not found in any candidate module.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    activations = {}

    layer = None
    for candidate in (model, getattr(model, 'model', None), getattr(model, 'visual', None)):
        # Compare against None explicitly: container modules define
        # __len__, so truthiness is not a reliable "exists" check.
        if candidate is None:
            continue
        layer = dict(candidate.named_modules()).get(layer_name)
        if layer is not None:
            break
    if layer is None:
        raise ValueError(f"Layer {layer_name} not found in model.")

    batch_size, all_acts = 32, []
    handle = layer.register_forward_hook(get_activation(layer_name, activations))
    try:
        with torch.no_grad():
            for start in range(0, len(preprocessed_images), batch_size):
                batch = preprocessed_images[start:start + batch_size].to(device)
                model.encode_image(batch)
                out = activations[layer_name].cpu().numpy()
                if out.ndim == 3:
                    out = out.mean(axis=1)
                all_acts.append(out)
                # Drop the captured tensor so a batch that fails to trigger
                # the hook cannot silently reuse the previous output.
                activations.clear()
    finally:
        # Always detach the hook, even if the forward pass raised, so the
        # model is not left with a stale hook attached.
        handle.remove()
    return np.concatenate(all_acts, axis=0)


class SquarePad:
    """Transform that zero-pads an image so its width and height are equal.

    The image is kept as centred as possible; when the size difference is
    odd, the extra pixel of padding goes to the right or bottom edge so the
    result is exactly square.
    """

    @staticmethod
    def _padding_for(width, height):
        """Return ``(left, top, right, bottom)`` padding that squares the size."""
        side = max(width, height)
        extra_w, extra_h = side - width, side - height
        left, top = extra_w // 2, extra_h // 2
        # The far edge takes whatever is left, so an odd difference is fully
        # covered instead of leaving the image one pixel short of square.
        return (left, top, extra_w - left, extra_h - top)

    def __call__(self, image):
        """Pad ``image`` to a square with constant fill value 0.

        Args:
            image: Object whose ``size`` unpacks to ``(width, height)``,
                e.g. a PIL image.

        Returns:
            The result of ``v2.functional.pad`` applied to ``image``.
        """
        w, h = image.size
        padding = self._padding_for(w, h)
        return v2.functional.pad(image, padding, 0, 'constant')


def get_preprocess(mean, std):
    """Assemble the image preprocessing pipeline.

    The steps run in order: pad to a square, resize to 224x224, convert to
    RGB, convert to a tensor, then normalize with the given statistics.

    Args:
        mean: Per-channel mean forwarded to ``v2.Normalize``.
        std: Per-channel standard deviation forwarded to ``v2.Normalize``.

    Returns:
        A ``v2.Compose`` whose last transform is the ``v2.Normalize`` step.
    """
    def _to_rgb(img):
        return img.convert('RGB')

    # NOTE(review): recent torchvision releases appear to flag v2.ToTensor
    # as deprecated — confirm before upgrading torchvision.
    steps = [
        SquarePad(),
        v2.Resize((224, 224)),
        v2.Lambda(_to_rgb),
        v2.ToTensor(),
        v2.Normalize(mean=mean, std=std),
    ]
    return v2.Compose(steps)


def preprocess_images(imgs, mean, std, device):
    """Preprocess raw image arrays into one batched tensor on ``device``.

    Each array is cast to ``uint8``, wrapped as an RGB PIL image and passed
    through the pipeline built by ``get_preprocess(mean, std)``.

    Args:
        imgs: Iterable of NumPy image arrays (presumably height x width x 3
            — TODO confirm against the stimulus file).
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.
        device: Target device for the returned tensor.

    Returns:
        The per-image tensors stacked along a new leading dimension and
        moved to ``device``.
    """
    preprocess = get_preprocess(mean, std)
    tensors = [
        preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))
        for img in imgs
    ]
    # Stack first and transfer once: a single device copy of the whole batch
    # replaces one copy per image plus a redundant copy of the stacked result.
    return torch.stack(tensors).to(device)
