import torch
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
from scipy import linalg
from transformers import AutoFeatureExtractor, AutoModel, CLIPProcessor, CLIPModel
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
import requests
import clip
from torch.nn import functional as F
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
from scipy.stats import entropy
from torchvision.models.inception import inception_v3
from torchvision import transforms
from PIL import Image


class _ImageFolderDataset(Dataset):
    """Dataset over the image files located directly inside one directory.

    Defined at module level (not inside a method) so that ``DataLoader``
    worker processes can pickle it when the multiprocessing start method is
    ``spawn`` or ``forkserver``; locally defined classes cannot be pickled.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_path, transform):
        # Case-insensitive extension match so files such as "IMG.JPG" are
        # included; sorted because os.listdir() order is arbitrary.
        self.image_paths = sorted(
            os.path.join(image_path, name)
            for name in os.listdir(image_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        return self.transform(image)


class ImageEvaluator:
    """Compute FID, CLIP and DINO based scores for images.

    All models are loaded once in ``__init__`` and moved to CUDA when it is
    available, otherwise to the CPU.
    """

    def __init__(self, batch_size=32):
        """Load the models and set up the default preprocessing.

        Args:
            batch_size: Number of images per batch when activations are
                computed for a directory of images.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.inception_model = self._load_inception_model()
        self.clip_model, self.clip_processor = self._load_clip_model()
        self.dino_model, self.dino_processor = self._load_dino_model()
        # Default preprocessing for _get_activations (used with Inception):
        # resize to 299x299, convert to tensor, normalize per channel.
        self.transform = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.batch_size = batch_size

    def _load_inception_model(self):
        """Return a pretrained Inception v3 in eval mode on ``self.device``.

        The final fully connected layer is replaced by an identity so that the
        model outputs the features that would normally feed the classifier.
        """
        model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
        model.fc = torch.nn.Identity()
        model = model.to(self.device)
        model.eval()
        return model

    def _load_dino_model(self):
        """Load DINOv2 and its processor, with a dummy fallback.

        Returns:
            ``(model, processor)``. If loading fails with a request, OS or
            value error, ``model`` is an identity module and ``processor`` is
            a torchvision ``Compose`` (resize to 224x224, to tensor,
            normalize). ``get_DINO_score`` detects the fallback by the
            processor's type.
        """
        model_name = "facebook/dinov2-base"
        try:
            processor = AutoFeatureExtractor.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name)
            model = nn.DataParallel(model).to(self.device)
        except (requests.exceptions.RequestException, OSError, ValueError) as e:
            print(f"Error loading DINO model: {e}")
            print("Falling back to a dummy DINO model and processor.")
            processor = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
            model = nn.DataParallel(torch.nn.Identity()).to(self.device)
        model.eval()
        return model, processor

    def _forward_to_numpy(self, model, batch):
        """Run ``model`` on one batch and return the result as a NumPy array.

        ``batch`` may be a tensor (passed positionally) or a mapping of
        tensors (passed as keyword arguments). If the model does not return a
        plain tensor, its ``last_hidden_state`` is averaged over dimension 1.
        Callers are expected to wrap this call in ``torch.no_grad()``.
        """
        if isinstance(batch, torch.Tensor):
            outputs = model(batch.to(self.device))
        else:
            moved = {key: value.to(self.device) for key, value in batch.items()}
            outputs = model(**moved)
        if not isinstance(outputs, torch.Tensor):
            outputs = outputs.last_hidden_state.mean(dim=1)
        return outputs.cpu().numpy()

    def _get_activations(self, image_path, model, preprocess=None):
        """Compute model activations for one image file or a directory.

        Args:
            image_path: Path to a single image file, or to a directory whose
                ``.png``/``.jpg``/``.jpeg`` files (case-insensitive) are all
                processed.
            model: Callable model producing a tensor or an object with a
                ``last_hidden_state`` attribute.
            preprocess: Optional callable applied to each PIL image; defaults
                to ``self.transform``. It may return a tensor or a mapping of
                tensors (a mapping is handed to the model unchanged as keyword
                arguments).

        Returns:
            NumPy array of activations with one row per image.

        Raises:
            ValueError: If ``image_path`` is a directory without any
                supported image file.
        """
        transform = preprocess if preprocess is not None else self.transform

        if os.path.isfile(image_path):
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            sample = transform(image)
            if isinstance(sample, torch.Tensor):
                # Add the batch dimension the model expects.
                sample = sample.unsqueeze(0)
            with torch.no_grad():
                return self._forward_to_numpy(model, sample)

        dataset = _ImageFolderDataset(image_path, transform)
        if len(dataset) == 0:
            # Fail early with a clear message instead of the opaque
            # "need at least one array to concatenate" from NumPy.
            raise ValueError(f"No .png/.jpg/.jpeg images found in {image_path!r}")
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)

        with torch.no_grad():
            chunks = [self._forward_to_numpy(model, batch) for batch in dataloader]
        return np.concatenate(chunks, axis=0)

    def _calculate_fid(self, real_activations, generated_activations, eps=1e-6):
        """Compute the Frechet distance between two sets of activations.

        Each set is summarised by its mean and covariance; the result is
        ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.

        Args:
            real_activations: Array with one row per sample.
            generated_activations: Array with one row per sample and the same
                number of features.
            eps: Diagonal offset added to both covariances if the matrix
                square root of their product is not finite.

        Returns:
            The distance as a Python float.

        Raises:
            ValueError: If either set has fewer than two samples, because the
                sample covariance is undefined in that case.
        """
        real = np.asarray(real_activations, dtype=np.float64)
        generated = np.asarray(generated_activations, dtype=np.float64)
        if real.shape[0] < 2 or generated.shape[0] < 2:
            raise ValueError(
                "FID requires at least two samples in each set to estimate a covariance"
            )

        mu1, mu2 = real.mean(axis=0), generated.mean(axis=0)
        # np.cov returns a 0-d array for single-feature data; promote it to a
        # 1x1 matrix so the covariance term is never silently dropped.
        sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
        sigma2 = np.atleast_2d(np.cov(generated, rowvar=False))

        ssdiff = np.sum((mu1 - mu2) ** 2.0)

        covmean = linalg.sqrtm(sigma1.dot(sigma2))
        if not np.isfinite(covmean).all():
            # Near-singular product: regularise the diagonals and retry.
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can introduce a small imaginary component.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

    def _load_clip_model(self):
        """Load CLIP and its processor, with a dummy fallback.

        Returns:
            ``(model, processor)``. If loading fails for any reason, ``model``
            wraps an identity module and ``processor`` is ``None``.
        """
        try:
            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            model = nn.DataParallel(model).to(self.device)
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        except Exception as e:
            # Deliberate best-effort: any failure degrades to the dummy model.
            print(f"Error loading CLIP model: {e}")
            print("Falling back to a dummy CLIP model and processor.")
            processor = None
            model = nn.DataParallel(torch.nn.Identity()).to(self.device)
        model.eval()
        return model, processor

    def get_CLIP_score(self, image_path, text):
        """Return the CLIP image-text logit for one image and one text.

        Args:
            image_path: Path to the image file.
            text: Text to compare the image with; it is truncated to the
                tokenizer's maximum length.

        Returns:
            ``logits_per_image`` of the model as a float, or ``0.0`` when the
            dummy fallback model is in use.
        """
        if isinstance(self.clip_model.module, torch.nn.Identity) or self.clip_processor is None:
            print("Warning: Using dummy CLIP model. Score may not be meaningful.")
            return 0.0

        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # truncation=True keeps over-long prompts within the text encoder's
        # context length instead of failing inside the model.
        inputs = self.clip_processor(
            text=[text], images=image, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.clip_model(**inputs)

        return outputs.logits_per_image.item()

    def get_FID_score(self, real_image_path, generated_image_path):
        """Return the FID between two image sets using Inception features.

        Args:
            real_image_path: Directory (or file) with the reference images.
            generated_image_path: Directory (or file) with generated images.

        Raises:
            ValueError: If either side yields fewer than two images.
        """
        real_activations = self._get_activations(real_image_path, self.inception_model)
        generated_activations = self._get_activations(generated_image_path, self.inception_model)
        return self._calculate_fid(real_activations, generated_activations)

    def _dino_features(self, image_path, use_fallback):
        """Return flattened DINO features of shape ``(1, D)`` for one image."""
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        with torch.no_grad():
            if use_fallback:
                inputs = self.dino_processor(image).unsqueeze(0).to(self.device)
                features = self.dino_model(inputs)
            else:
                inputs = self.dino_processor(images=image, return_tensors="pt").to(self.device)
                features = self.dino_model(**inputs).last_hidden_state.mean(dim=1)

        # flatten() also works on non-contiguous tensors, unlike view().
        return features.flatten(start_dim=1)

    def get_DINO_score(self, image_path1, image_path2):
        """Return the cosine similarity between the DINO features of two images.

        With the fallback processor (a torchvision ``Compose``) the features
        are whatever the fallback model returns for the preprocessed image.

        Args:
            image_path1: Path to the first image file.
            image_path2: Path to the second image file.

        Returns:
            The cosine similarity as a float.
        """
        use_fallback = isinstance(self.dino_processor, transforms.Compose)
        if use_fallback:
            print("Using fallback DINO processor.")

        features1 = self._dino_features(image_path1, use_fallback)
        features2 = self._dino_features(image_path2, use_fallback)

        similarity = torch.nn.functional.cosine_similarity(features1, features2)
        # The mean of a single-element tensor is that element, so this covers
        # both the scalar and the multi-row case.
        return similarity.mean().item()