import torch
import clip
from torch.utils.data import DataLoader
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from bert_score import score as bert_score
#from lavis.models import load_model_and_preprocess
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import torchvision.transforms as T
import torch.nn.functional as F

def compute_sbert_similarity(model, preds, refs):
    """Return the mean cosine similarity of paired prediction/reference embeddings.

    ``preds`` and ``refs`` are each embedded with
    ``model.encode(..., convert_to_tensor=True)``. The cosine similarity is
    taken along dimension 1 of the two embedding tensors (assumes ``encode``
    yields one row per input sentence -- TODO confirm), and the per-pair
    values are averaged into a single Python float.

    NOTE(review): another function with this same name is defined further
    down in this module and replaces this one once the module is imported.
    """
    embedded = [
        model.encode(sentences, convert_to_tensor=True)
        for sentences in (preds, refs)
    ]
    per_pair = F.cosine_similarity(embedded[0], embedded[1], dim=1)
    return float(per_pair.mean())

# --- Image-text retrieval evaluation ---
# Per-channel normalisation statistics for the fallback image pipelines that
# are used when the model does not expose its own ``preprocess`` callable.
_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Captions are cut to this many *characters* (a plain string slice, not a
# token count) before being handed to the model.
_MAX_CAPTION_CHARS = 77


def _select_image_preprocessor(model):
    """Pick the PIL-image -> tensor callable for *model*, or None if unsupported."""
    has_clip_api = hasattr(model, 'encode_image') and hasattr(model, 'encode_text')
    if has_clip_api and hasattr(model, 'preprocess'):
        return model.preprocess
    if has_clip_api:
        mean, std = _CLIP_MEAN, _CLIP_STD
    elif hasattr(model, 'extract_features'):
        mean, std = _IMAGENET_MEAN, _IMAGENET_STD
    else:
        return None

    # Imported lazily so torchvision is only needed on the fallback path.
    from torchvision import transforms
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


def _matching_ranks(score_rows):
    """Return the 1-based rank of entry ``i`` within row ``i``, best score first."""
    ranks = []
    for idx, row in enumerate(score_rows):
        descending = np.argsort(row)[::-1]
        ranks.append(int(np.where(descending == idx)[0][0]) + 1)
    return ranks


def evaluate_retrieval_clip(model, test_data, data_dir, k_values=(1, 5), device=None):
    """Compute image<->text retrieval metrics over ``test_data['multimodal']``.

    Every item must provide ``'image_path'`` (joined onto ``data_dir``) and
    ``'caption'``. Each image/caption pair is embedded with
    ``model.extract_features`` and item ``i``'s caption is treated as the
    ground-truth match of item ``i``'s image.

    Args:
        model: Object with ``eval()`` and ``extract_features(sample)``. If it
            has ``encode_image`` and ``encode_text`` its ``preprocess``
            callable is used when present, otherwise a 224x224 resize with
            CLIP normalisation. A model with only ``extract_features`` gets
            the 224x224 resize with ImageNet normalisation. Features are
            always read through ``extract_features``, so a model lacking it
            raises ``AttributeError``.
        test_data: Mapping holding the list of items under ``'multimodal'``.
        data_dir: Directory prefix for every ``item['image_path']``.
        k_values: Iterable of cut-offs ``k`` for the recall@k metrics.
        device: Device the image tensor is moved to; defaults to CUDA when
            available, otherwise CPU. The model itself is not moved.

    Returns:
        Dict with ``i2t_recall@k`` / ``t2i_recall@k`` for every ``k`` plus
        mean and median rank for both directions. An empty dict is returned
        when the model type is unsupported or no features were extracted.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    multimodal_data = test_data['multimodal']

    # The model type never changes between samples, so decide once up front
    # instead of re-probing (and re-building the transform) for every item.
    preprocess = _select_image_preprocessor(model)
    if preprocess is None:
        print("Model type not supported for feature extraction")
        return {}

    print(f"Extracting features for {len(multimodal_data)} test samples...")

    image_features = []
    text_features = []
    with torch.no_grad():
        for item in multimodal_data:
            img_path = os.path.join(data_dir, item['image_path'])
            # convert() returns an independent image, so the file handle can
            # be released as soon as the conversion is done.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')

            sample = {
                "image": preprocess(image).unsqueeze(0).to(device),
                "text_input": [item['caption'][:_MAX_CAPTION_CHARS]],
            }
            features = model.extract_features(sample)
            image_features.append(features["image_embeds"].cpu())
            text_features.append(features["text_embeds"].cpu())

    if not image_features:
        print("No valid features extracted!")
        return {}

    image_features = torch.cat(image_features, dim=0)
    text_features = torch.cat(text_features, dim=0)

    # Raw dot-product scores; rows index images, columns index captions.
    # NOTE(review): assumes each embedding is a (1, dim) tensor and that the
    # model already normalises them if cosine scores are wanted -- confirm.
    similarity = (image_features @ text_features.T).numpy()

    i2t_ranks = _matching_ranks(similarity)
    t2i_ranks = _matching_ranks(similarity.T)

    results = {}
    for k in k_values:
        results[f'i2t_recall@{k}'] = np.mean([rank <= k for rank in i2t_ranks])
        results[f't2i_recall@{k}'] = np.mean([rank <= k for rank in t2i_ranks])

    results['i2t_mean_rank'] = np.mean(i2t_ranks)
    results['t2i_mean_rank'] = np.mean(t2i_ranks)
    results['i2t_median_rank'] = np.median(i2t_ranks)
    results['t2i_median_rank'] = np.median(t2i_ranks)

    return results

from sentence_transformers import SentenceTransformer, util
def compute_sbert_similarity(sbert_model, preds, refs):
    """Return the cosine-similarity scores between prediction and reference embeddings.

    ``preds`` and ``refs`` are each embedded with
    ``sbert_model.encode(..., convert_to_tensor=True)`` and the two embedding
    tensors are passed to ``util.pytorch_cos_sim``, whose result is returned
    unchanged (no averaging or reduction is applied here).
    """
    pred_vectors = sbert_model.encode(preds, convert_to_tensor=True)
    ref_vectors = sbert_model.encode(refs, convert_to_tensor=True)
    return util.pytorch_cos_sim(pred_vectors, ref_vectors)

def compute_use_similarity(preds, refs):
    """Return the mean cosine similarity between paired sentence embeddings.

    ``preds`` and ``refs`` are embedded with the model returned by
    ``load_use_model()`` (defined outside this view; presumably the Universal
    Sentence Encoder -- TODO confirm). Entry ``i`` of ``preds`` is compared
    with entry ``i`` of ``refs`` along axis 1 and the per-pair scores are
    averaged. Higher means more similar: 1.0 for identical directions.
    """
    model = load_use_model()
    pred_embeds = model(preds)
    ref_embeds = model(refs)
    # tf.keras.losses.cosine_similarity is a *loss*: it yields the negated
    # cosine similarity (-1 for identical directions, range [-1, 1]). Adding 1
    # would give a cosine distance, so flip the sign to recover similarity.
    negated_sim = tf.keras.losses.cosine_similarity(pred_embeds, ref_embeds, axis=1)
    return np.mean(-negated_sim.numpy())