import glob
import os
from typing import List

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import AutoModel, AutoProcessor

from diffusion_arithmetics.noise_learning.distance_classification import get_noise_sample_by_distance_classification

# Hugging Face model identifier loaded by get_clip().
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
# Hugging Face model identifier loaded by get_dino().
DINO_MODEL_NAME = "facebook/dino-vits16"
# Number of images processed per forward pass in the feature-extraction helpers.
BS = 64
# Default device for models and inputs: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_clip(device=DEVICE):
    """Load the CLIP processor and model named by ``CLIP_MODEL_NAME``.

    Args:
        device: Device the model is moved to after loading.

    Returns:
        A ``(processor, model)`` tuple; only the model is moved to ``device``.
    """
    processor = AutoProcessor.from_pretrained(CLIP_MODEL_NAME)
    model = AutoModel.from_pretrained(CLIP_MODEL_NAME)
    model = model.to(device)
    return processor, model


def get_dino(device=DEVICE):
    """Load the DINO model named by ``DINO_MODEL_NAME`` without its pooling layer.

    Args:
        device: Device the model is moved to after loading.

    Returns:
        The loaded model, placed on ``device``.
    """
    model = AutoModel.from_pretrained(DINO_MODEL_NAME, add_pooling_layer=False)
    return model.to(device)


def get_clip_features(imgs: List[Image.Image], clip_processor, clip_model, device=DEVICE):
    """Compute CLIP image features for ``imgs`` in mini-batches of ``BS``.

    Args:
        imgs: PIL images to embed. Must be non-empty, since ``torch.cat``
            rejects an empty list of tensors.
        clip_processor: Processor invoked as
            ``clip_processor(images=batch, return_tensors="pt")``.
        clip_model: Model exposing ``get_image_features``; it is expected to
            already live on ``device`` (this function does not move it).
        device: Device the preprocessed pixel values are moved to.

    Returns:
        The per-batch feature tensors, detached, moved to CPU and concatenated
        along dim 0.
    """
    outs = []
    # Pure feature extraction: disabling autograd avoids building (and holding
    # in memory) a graph that would be discarded by ``detach()`` anyway.
    with torch.no_grad():
        for start in range(0, len(imgs), BS):
            batch = imgs[start : start + BS]
            pixel_values = clip_processor(images=batch, return_tensors="pt").pixel_values.to(device)
            feats = clip_model.get_image_features(pixel_values)
            outs.append(feats.detach().cpu())
    return torch.cat(outs)


def get_dino_features(imgs: List[Image.Image], dino_model, device=DEVICE):
    """Compute DINO features for ``imgs`` in mini-batches of ``BS``.

    Each image is resized to 256, center-cropped to 224, converted to a tensor
    and normalized with the mean/std constants below before being fed to the
    model.

    Args:
        imgs: PIL images to embed. Must be non-empty, since ``torch.cat``
            rejects an empty list of tensors.
        dino_model: Model whose output exposes ``last_hidden_state``; it is
            expected to already live on ``device``.
        device: Device the preprocessed batch is moved to.

    Returns:
        The ``last_hidden_state[:, 0, :]`` slice of every batch (the first
        token of each sequence -- presumably the CLS token, confirm against the
        model), detached, moved to CPU and concatenated along dim 0.
    """
    preprocess = transforms.Compose(
        [
            # interpolation=3 is the Pillow integer constant for bicubic resampling.
            transforms.Resize(256, interpolation=3),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    outs = []
    # Pure feature extraction: disabling autograd avoids building (and holding
    # in memory) a graph that would be discarded by ``detach()`` anyway.
    with torch.no_grad():
        for start in range(0, len(imgs), BS):
            batch = imgs[start : start + BS]
            # Stack on CPU and transfer once per batch rather than once per image.
            pixel_batch = torch.stack([preprocess(img) for img in batch]).to(device)
            features = dino_model(pixel_batch).last_hidden_state[:, 0, :]
            outs.append(features.detach().cpu())
    return torch.cat(outs)


def get_mean_cosine_sim(vec1, vec2):
    """Return the mean row-wise cosine similarity between two batches.

    Both inputs are flattened to ``(shape[0], -1)``, L2-normalized along dim 1,
    and the dot product of matching rows is averaged over dim 0.

    Args:
        vec1: Tensor whose first dimension indexes the items to compare.
        vec2: Tensor with the same number of elements per item as ``vec1``.

    Returns:
        A 0-dim tensor holding the mean cosine similarity.
    """
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs, e.g.
    # tensors produced by ``permute``/``transpose``; it copies only when needed.
    flat1 = vec1.reshape(vec1.shape[0], -1)
    flat2 = vec2.reshape(vec2.shape[0], -1)
    unit1 = F.normalize(flat1, dim=1)
    unit2 = F.normalize(flat2, dim=1)
    return torch.sum(unit1 * unit2, dim=1).mean()


def calc_angles(noises, samples, latents):
    """Compute mean cosine similarities between noise/sample/latent directions.

    Args:
        noises: Batch of noise tensors.
        samples: Batch of sample tensors, same shape as ``noises``.
        latents: Batch of latent tensors, same shape as ``noises``.

    Returns:
        A dict with three entries, each produced by ``get_mean_cosine_sim``:
        ``mean_cossim_img2noise_img2lat`` compares ``noises - samples`` with
        ``latents - samples``; ``mean_cossim_noise_image`` compares ``noises``
        with ``samples`` directly; ``mean_cossim_img2noise_lat2noise`` compares
        ``noises - samples`` with ``noises - latents``.
    """
    sample_to_noise = noises - samples
    sample_to_latent = latents - samples
    latent_to_noise = noises - latents

    return {
        "mean_cossim_img2noise_img2lat": get_mean_cosine_sim(sample_to_noise, sample_to_latent),
        "mean_cossim_noise_image": get_mean_cosine_sim(noises, samples),
        "mean_cossim_img2noise_lat2noise": get_mean_cosine_sim(sample_to_noise, latent_to_noise),
    }


def _mean_mse(tens1: torch.Tensor, tens2: torch.Tensor) -> torch.Tensor:
    """Return the batch mean of the per-item mean squared error.

    The squared error is first averaged over every non-batch dimension
    (all dims except dim 0), then the resulting per-item values are averaged.
    For 4-D inputs this matches reducing over dims ``[1, 2, 3]``.

    Args:
        tens1: First tensor; dim 0 is treated as the batch dimension.
        tens2: Second tensor, compatible with ``tens1`` for ``F.mse_loss``.

    Returns:
        A 0-dim tensor with the averaged error.
    """
    mse = F.mse_loss(tens1, tens2, reduction="none")
    # Reduce over whatever non-batch dims exist instead of assuming rank 4.
    item_dims = tuple(range(1, mse.dim()))
    per_item = mse.mean(dim=item_dims) if item_dims else mse
    return per_item.mean()


def calc_distances(noise, sample_from_noise, latent, sample_from_latent=None) -> dict:
    """Compute pairwise mean-MSE distances between noise, sample and latent.

    Args:
        noise: Batch of noise tensors.
        sample_from_noise: Batch of samples compared against everything else.
        latent: Batch of latent tensors.
        sample_from_latent: Optional second batch of samples; when given, its
            distance to ``sample_from_noise`` is reported as well.

    Returns:
        A dict of ``_mean_mse`` results keyed ``mean_mse_img_lat``,
        ``mean_mse_img_noise``, ``mean_mse_noise_lat`` and, only when
        ``sample_from_latent`` is provided, ``mean_mse_img_img2``.
    """
    distances = {}
    distances["mean_mse_img_lat"] = _mean_mse(tens1=sample_from_noise, tens2=latent)
    distances["mean_mse_img_noise"] = _mean_mse(tens1=sample_from_noise, tens2=noise)
    distances["mean_mse_noise_lat"] = _mean_mse(tens1=noise, tens2=latent)
    if sample_from_latent is None:
        return distances
    distances["mean_mse_img_img2"] = _mean_mse(tens1=sample_from_noise, tens2=sample_from_latent)
    return distances


def run_metrics(noises, latents, samples, samples2):
    """Print cosine-similarity and MSE diagnostics for the given batches.

    Also invokes ``get_noise_sample_by_distance_classification`` on clones of
    ``noises`` and ``samples``; its return value is not used here.

    Args:
        noises: Batch of noise tensors.
        latents: Batch of latent tensors, same shape as ``noises``.
        samples: Batch of sample tensors, same shape as ``noises``.
        samples2: Second batch of samples, same shape as ``samples``.

    Returns:
        None. All results are written to stdout.
    """
    # Reuse the shared helper instead of duplicating its flatten/normalize/dot
    # logic. Evaluated before the classification call, as the original did.
    mean_cossim = get_mean_cosine_sim(noises - samples, latents - samples)
    # Clones are passed so the callee cannot modify the caller's tensors.
    get_noise_sample_by_distance_classification(noises=noises.clone(), samples=samples.clone())
    print("COSINE SIMILARITY(V1,V2):", mean_cossim)
    print("MSE(NOISES,SAMPLES):", ((noises - samples) ** 2).mean())
    print("MSE(LATENTS,SAMPLES):", ((latents - samples) ** 2).mean())
    print("MSE(SAMPLES,SAMPLES2):", ((samples2 - samples) ** 2).mean())
    print("MSE(NOISES,LATENTS):", ((noises - latents) ** 2).mean())
