import os
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from torchvision.models import inception_v3, Inception_V3_Weights
from sklearn import preprocessing

from vendi_score import data_utils


def get_inception(pretrained=True, pool=True):
    """Build a torchvision Inception-v3 network in evaluation mode.

    Args:
        pretrained: when true, load ``Inception_V3_Weights.DEFAULT``;
            otherwise the network is constructed with ``weights=None``.
        pool: when true, the final ``fc`` classifier is replaced by
            ``nn.Identity`` so the forward pass returns the features that
            would have been fed to the classifier instead of class logits.

    Returns:
        The Inception-v3 module, built with ``transform_input=True`` and
        switched to ``.eval()`` mode. It is not moved to any device here.
    """
    net = inception_v3(
        weights=Inception_V3_Weights.DEFAULT if pretrained else None,
        transform_input=True,
    )
    net.eval()
    if pool:
        # Drop the classification head; keep the penultimate features.
        net.fc = nn.Identity()
    return net


def inception_transforms():
    """Return the per-image preprocessing pipeline paired with ``get_inception``.

    Steps, in order:
      1. resize so the shorter side is 299 pixels;
      2. center-crop to 299x299;
      3. convert to a float tensor with ``ToTensor``;
      4. ``expand`` the channel axis to 3. A single-channel image is repeated
         across three channels; a three-channel image passes through unchanged.

    NOTE(review): no ``Normalize`` step is applied; confirm this matches what
    ``get_inception``'s ``transform_input=True`` expects.
    """
    def _to_three_channels(tensor):
        # expand() returns a view that repeats a singleton channel axis.
        return tensor.expand(3, -1, -1)

    steps = [
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Lambda(_to_three_channels),
    ]
    return transforms.Compose(steps)


def dinov2_transforms():
    """Return the per-image preprocessing pipeline used for the DINO extractor.

    Steps, in order:
      1. convert to grayscale, replicated to 3 identical channels;
      2. bicubic resize to exactly 224x224;
      3. center-crop to 224 (a no-op after step 2, kept for parity);
      4. ``ToTensor`` followed by per-channel normalization with the usual
         ImageNet mean/std.

    NOTE(review): step 1 discards colour information from RGB inputs. Confirm
    this is intended (e.g. for single-channel datasets) rather than accidental.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    bicubic_resize = transforms.Resize(
        (224, 224), interpolation=transforms.InterpolationMode.BICUBIC
    )
    pipeline = [
        transforms.Grayscale(num_output_channels=3),
        bicubic_resize,
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)


def _cache_paths(cache):
    """Return ``(requested_path, npy_path)`` for a user-supplied cache location.

    ``np.save`` appends ``.npy`` to file names that lack that suffix, so the
    file it writes for ``cache="emb"`` is ``emb.npy``. Looking up only the
    requested name would never find a cache this module wrote itself, so
    both names are checked on load and the ``.npy`` name is used on save.
    """
    requested = os.fspath(cache)
    npy_path = requested if requested.endswith(".npy") else requested + ".npy"
    return requested, npy_path


def _maybe_normalize(embeddings, normalize):
    """Row-normalize ``embeddings`` with sklearn's default (L2) norm if asked."""
    if normalize:
        return preprocessing.normalize(embeddings, axis=1)
    return embeddings


def _batch_to_2d(output):
    """Convert one batch of model output into a ``(batch, dim)`` NumPy array.

    Singleton feature axes are dropped. Examples are a trailing ``1 x 1``
    spatial map, or the extra axis from stacking per-image ``(1, dim)``
    embeddings. The batch axis is always kept, unlike a bare ``squeeze()``,
    which collapses a batch of one to ``(dim,)``.

    Raises:
        ValueError: if the output is 0-d or has more than one non-singleton
            feature axis.
    """
    array = output.cpu().numpy()
    feature_axes = [size for size in array.shape[1:] if size != 1]
    if array.ndim == 0 or len(feature_axes) > 1:
        raise ValueError(
            f"Expected embeddings of shape (batch, dim), got {array.shape}"
        )
    return array.reshape(array.shape[0], -1)


def compute_embeddings(
    images,
    model=None,
    transform=None,
    batch_size=64,
    device=torch.device("cpu"),
    extractor_name=None,
    cache=None,
    normalize=False,
):
    """Embed ``images`` with a feature extractor and stack the results.

    Args:
        images: collection split into batches by ``data_utils.to_batches``.
            For ``"inception"``/``"dino"`` each item is passed to
            ``transform``. For ``"dreamsim"`` items are used as-is and must
            already be tensors that ``torch.stack`` can combine.
        model: feature extractor. ``None`` builds the pretrained pooled
            Inception-v3 from ``get_inception``.
        transform: per-image callable producing a tensor. If ``None``, the
            default is ``inception_transforms()`` when ``model`` is ``None``,
            ``dinov2_transforms()`` for ``"dino"``, and ``ToTensor()``
            otherwise. It is not applied for ``"dreamsim"``.
        batch_size: number of images per forward pass.
        device: ``torch.device`` or anything ``torch.device()`` accepts
            (e.g. ``"cuda"``).
        extractor_name: one of ``"inception"``, ``"dino"``, ``"dreamsim"``.
            It defaults to ``"inception"`` when ``model`` is ``None``. For
            ``"dino"`` the model output is indexed with ``"pooler_output"``.
            For ``"dreamsim"`` the model's ``embed`` method is called once
            per image.
        cache: optional path of a ``.npy`` file. If it (or the same name
            with ``.npy`` appended) exists, it is loaded instead of
            recomputing. Otherwise the result is saved there.
        normalize: L2-normalize each embedding row. It is also applied to
            cached arrays, which is idempotent if they were saved normalized.

    Returns:
        np.ndarray of shape ``(num_images, dim)``.

    Raises:
        ValueError: for an unknown ``extractor_name``, a non-Inception
            ``extractor_name`` without a ``model``, an unexpected output
            shape, or an empty ``images``.
    """
    if cache is not None:
        requested_path, npy_path = _cache_paths(cache)
        for path in (requested_path, npy_path):
            if os.path.isfile(path):
                return _maybe_normalize(np.load(path), normalize)

    device = torch.device(device)

    if model is None and extractor_name is None:
        # model=None means "use the built-in Inception"; without this the
        # dispatch below would reject extractor_name=None.
        extractor_name = "inception"
    if extractor_name not in ("inception", "dino", "dreamsim"):
        raise ValueError(f"Unknown extractor name: {extractor_name}")

    if model is None:
        if extractor_name != "inception":
            raise ValueError(
                f"extractor_name={extractor_name!r} requires an explicit "
                "model; model=None only provides the built-in Inception"
            )
        model = get_inception(pretrained=True, pool=True)
        default_transform = inception_transforms()
    elif extractor_name == "dino":
        default_transform = dinov2_transforms()
    else:
        default_transform = transforms.ToTensor()
    if transform is None:
        transform = default_transform

    # Inputs are moved to ``device`` below, so the network must live there
    # too. A non-Module Inception callable is left untouched.
    if extractor_name == "dino" or (
        extractor_name == "inception" and isinstance(model, nn.Module)
    ):
        model = model.to(device)

    embeddings = []
    for batch in data_utils.to_batches(images, batch_size):
        if extractor_name == "dreamsim":
            # DreamSim inputs are expected to be preprocessed tensors already.
            x = torch.stack(list(batch), 0).to(device)
        else:
            x = torch.stack([transform(img) for img in batch], 0).to(device)
        with torch.no_grad():
            if extractor_name == "dino":
                output = model(x)["pooler_output"]
            elif extractor_name == "inception":
                output = model(x)
            else:
                output = torch.stack([model.embed(img) for img in x], 0)
        if isinstance(output, (list, tuple)):
            # e.g. multi-head outputs: keep the primary (first) result.
            output = output[0]
        embeddings.append(_batch_to_2d(output))

    if not embeddings:
        raise ValueError("compute_embeddings received no images")
    result = _maybe_normalize(np.concatenate(embeddings, 0), normalize)
    if cache is not None:
        np.save(npy_path, result)
    return result
