import torch
import clip

from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import pandas as pd
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

def cosine_similarity(x1, x2):
    """Return the pairwise cosine-similarity matrix between rows of two 2-D tensors.

    Every row of ``x1`` and ``x2`` is divided by its own L2 norm (taken along
    ``dim=1``) and the normalised matrices are multiplied, so entry ``[i, j]``
    is the cosine of the angle between ``x1[i]`` and ``x2[j]``.

    No epsilon is added to the norms: an all-zero row yields NaN entries.
    """
    unit_rows_1 = x1.div(x1.norm(dim=1, keepdim=True))
    unit_rows_2 = x2.div(x2.norm(dim=1, keepdim=True))
    return torch.matmul(unit_rows_1, unit_rows_2.T)

class FacetDataset(Dataset):
    """Dataset over pre-computed FACET image embeddings with gender and class labels.

    The data is read from ``data/facet_{clip_name}_combined.npz``, which must
    contain the arrays ``image_embedding``, ``gender`` and ``class1``.

    Args:
        clip_name: Identifier substituted into the ``.npz`` file name.
        device: Device every returned tensor is moved to.
        class_to_idx: Mapping from class name (as stored in ``class1``) to an
            integer class index.
    """

    def __init__(self, clip_name, device, class_to_idx):
        # np.load returns a lazily-read NpzFile for .npz archives; the context
        # manager makes sure the underlying file handle is closed once the
        # arrays have been materialised.
        with np.load(f'data/facet_{clip_name}_combined.npz') as data:
            image_embedding_np = data['image_embedding']
            gender_np = data['gender']
            class1_np = data['class1']

        self.image_embedding = torch.from_numpy(image_embedding_np)
        self.gender = torch.from_numpy(gender_np)
        self.class_names = list(class1_np)
        self.device = device
        self.class_to_idx = class_to_idx

    def __len__(self):
        """Return the number of stored image embeddings."""
        return len(self.image_embedding)

    def __getitem__(self, idx):
        """Return ``(image_embedding, gender, class_label)`` for sample ``idx``.

        The embedding is returned as float32 and both labels as int64, all on
        ``self.device``. The tensors are copies, so callers may modify them
        in place without touching the dataset's storage.
        """
        # detach().clone() is the supported way to copy an existing tensor;
        # torch.tensor(<tensor>) does the same copy but warns on every call.
        image_embedding = self.image_embedding[idx].detach().clone().to(self.device).float()
        gender = self.gender[idx].detach().clone().to(self.device).long()
        class_name = self.class_names[idx]
        class_label = torch.tensor(self.class_to_idx[class_name]).to(self.device).long()

        return image_embedding, gender, class_label

def zero_shot_classifier(image_embeddings, text_embeddings, class_labels):
    """Assign to each image the label whose text embedding is most similar.

    Both embedding tensors are cast to float, compared with
    ``cosine_similarity``, passed through a softmax over the label axis, and
    the arg-max label index is looked up in ``class_labels``.

    Returns:
        A list with one entry of ``class_labels`` per row of
        ``image_embeddings``.
    """
    scores = cosine_similarity(image_embeddings.float(), text_embeddings.float())
    probabilities = torch.softmax(scores, dim=-1)
    best_indices = torch.argmax(probabilities, dim=-1).cpu().numpy()
    return [class_labels[index] for index in best_indices]


def calculate_accuracy(preds, trues):
    """Return the accuracy of ``preds`` against ``trues`` via sklearn's ``accuracy_score``."""
    return accuracy_score(y_true=trues, y_pred=preds)




def evaluate_gender_difference(image_embeddings, image_ids, image_genders, text_embeddings, text_ids, text_genders, top_k):
    """Return the mean absolute max-skew of gender among the top-k images per text.

    Similarities are ``softmax(100 * image @ text.T)`` taken over the text
    axis. For every text the ``top_k`` highest-scoring images (ranked down
    that text's column) are retrieved, and the skew of each gender is
    ``log((count / k) / 0.5)``, i.e. relative to an even split. The per-text
    score is ``max(|male_skew|, |female_skew|)`` and the function returns the
    mean of these scores over all texts, rounded to 4 decimals.

    Args:
        image_embeddings: 2-D tensor with one row per image.
        image_ids: Image identifiers; kept for interface compatibility, the
            retrieved rows are addressed by position.
        image_genders: Per-image gender codes aligned with the rows of
            ``image_embeddings``; ``0`` is counted as male, anything else as
            female.
        text_embeddings: 2-D tensor with one row per text.
        text_ids: One identifier per text; only its length is used.
        text_genders: Unused; kept for interface compatibility.
        top_k: Number of images retrieved per text. Clamped to the number of
            available images.

    Returns:
        The rounded mean max-skew. ``nan`` if there are no texts or no images.
        A text whose retrieved set contains only one gender contributes
        ``inf`` (``log(0)``), which makes the mean ``inf``.
    """
    num_texts = len(text_ids)
    num_images = image_embeddings.shape[0]
    k = min(top_k, num_images)
    if num_texts == 0 or k <= 0:
        return float('nan')

    image_embeddings = image_embeddings.float()
    text_embeddings = text_embeddings.float()
    similarities = (100 * image_embeddings @ text_embeddings.T).softmax(dim=-1)

    # Positional gender lookup: avoids an O(n) list.index() per retrieved
    # image and stays correct when image ids are not unique.
    is_male = np.asarray(image_genders) == 0

    max_skew_list = []
    for text_index in range(num_texts):
        top_k_indices = torch.topk(similarities[:, text_index], k=k).indices.cpu().numpy()
        male_count = int(is_male[top_k_indices].sum())
        female_count = k - male_count

        # log(0) -> -inf is the intended value for a fully one-sided result;
        # only the divide-by-zero warning is silenced.
        with np.errstate(divide='ignore'):
            male_skew = np.log((male_count / k) / 0.5)
            female_skew = np.log((female_count / k) / 0.5)

        # Accumulate one score per text (inside the loop) so the mean covers
        # every query rather than only the last one.
        max_skew_list.append(max(abs(male_skew), abs(female_skew)))

    return np.round(np.mean(max_skew_list), 4)

def evaluate_recall(image_embeddings, image_ids, text_embeddings, text_ids, top_k):
    """Return text-to-image Recall@K as a fraction in ``[0, 1]``.

    Similarities are ``softmax(100 * image @ text.T)`` taken over the text
    axis. A text counts as a hit when its id appears among the ids of the
    ``top_k`` highest-scoring images in that text's column.

    Args:
        image_embeddings: 2-D tensor with one row per image.
        image_ids: Identifiers aligned with the rows of ``image_embeddings``.
        text_embeddings: 2-D tensor with one row per text.
        text_ids: For each text, the id of the image it should retrieve.
        top_k: Number of images retrieved per text. Clamped to the number of
            available images.

    Returns:
        ``hits / len(text_ids)``; ``0.0`` when there are no texts.
    """
    num_texts = len(text_ids)
    if num_texts == 0:
        # Nothing to retrieve; avoids a ZeroDivisionError below.
        return 0.0

    image_embeddings = image_embeddings.float()
    text_embeddings = text_embeddings.float()
    similarities = (100 * image_embeddings @ text_embeddings.T).softmax(dim=-1)

    # Built once: the id array does not change between queries.
    image_id_array = np.asarray(image_ids)
    k = min(top_k, similarities.shape[0])

    hits = 0
    for text_index, text_id in enumerate(text_ids):
        top_k_indices = torch.topk(similarities[:, text_index], k=k).indices.cpu().numpy()
        if np.isin(text_id, image_id_array[top_k_indices]):
            hits += 1

    return hits / num_texts



def evaluate_flickr(args,clip_model,device,clip_name,img_important_indices=None,img_mean_features_misclassified=None ,text_important_indices=None,text_mean_features_misclassified=None ):
    """Evaluate text-to-image retrieval and gender skew on the Flickr subset.

    Captions are read from ``data/flickr_with_gender_neutral_captions.csv``,
    restricted to the ids present in ``data/flickr_1000images.csv``, encoded
    with the CLIP text encoder and matched against pre-computed image
    embeddings.

    Args:
        args: Object with a ``target`` attribute. If ``'text'`` is in it, the
            text features at ``text_important_indices`` are overwritten with
            the corresponding entries of ``text_mean_features_misclassified``
            and re-normalised. If ``'image'`` is in it, the analogous
            replacement is applied to the image embeddings.
        clip_model: Model name passed to ``clip.load``.
        device: Device used for encoding and for the image replacement.
        clip_name: Label used only in the printed report.
        img_important_indices: Feature indices replaced in image embeddings.
        img_mean_features_misclassified: Replacement values for image features.
        text_important_indices: Feature indices replaced in text embeddings.
        text_mean_features_misclassified: Replacement values for text features.

    Returns:
        ``(recall@1, recall@5, recall@10, max_skew@100)``.
    """
    caption_df = pd.read_csv("data/flickr_with_gender_neutral_captions.csv")
    image_df = pd.read_csv("data/flickr_1000images.csv")
    caption_df = caption_df[caption_df['id'].isin(image_df['id'])]
    caption_df = pd.merge(caption_df, image_df[['id', 'gender']], on='id', how='left')
    captions = caption_df['neutral_caption'].tolist()
    # 'gender_y' is the merge-suffixed column coming from image_df.
    genders = caption_df['gender_y'].tolist()
    ids = caption_df['id'].tolist()

    model, _ = clip.load(clip_model, device=device)
    batch_size = 512

    replace_text_features = 'text' in args.target
    if replace_text_features:
        # Converted once instead of on every batch.
        text_mean_features_misclassified = text_mean_features_misclassified.float()

    text_features = []
    for start in range(0, len(captions), batch_size):
        batch_captions = captions[start:start + batch_size]
        # truncate=True already clips every caption to CLIP's context length,
        # so the whole batch can be tokenized in a single call.
        text_inputs = clip.tokenize(batch_captions, truncate=True).to(device)
        with torch.no_grad():
            batch_text_features = model.encode_text(text_inputs)
            batch_text_features = batch_text_features / batch_text_features.norm(dim=-1, keepdim=True)
            if replace_text_features:
                batch_text_features = batch_text_features.to(device)
                batch_text_features = batch_text_features.float()
                batch_text_features[:,text_important_indices] = text_mean_features_misclassified[text_important_indices]
                batch_text_features = batch_text_features / batch_text_features.norm(dim=-1, keepdim=True)
        text_features.append(batch_text_features.cpu())
    text_embeddings = torch.cat(text_features)

    # NOTE(review): the image embeddings are always read from the ViTB32 file,
    # independent of clip_model / clip_name — confirm this is intended.
    with np.load('data/flickr1000_ViTB32_image.npz') as image_embeddings:
        image_embedding = torch.tensor(image_embeddings['image_embedding']).float()
        image_id = list(image_embeddings['image_id'])
        img_gender = list(image_embeddings['img_gender'])

    # Batches are consumed in order, so the per-text ids/genders are exactly
    # the lists extracted from the dataframe.
    text_id = ids
    text_gender = genders

    if 'image' in args.target:
        image_embedding = image_embedding.to(device)
        image_embedding[:,img_important_indices] = img_mean_features_misclassified[img_important_indices]
        # NOTE(review): unlike the text branch, the image embeddings are not
        # re-normalised after the replacement — confirm this is intended.
        image_embedding = image_embedding.cpu()

    return_results=[]
    for top_k in [1,5,10]:
        recall_results = evaluate_recall(image_embedding, image_id, text_embeddings, text_id, top_k)
        print(f'{clip_name}: Recall@{top_k}:', recall_results*100)
        return_results.append(recall_results)
    for top_k in [100]:
        gender_difference = evaluate_gender_difference(image_embedding, image_id, img_gender, text_embeddings, text_id, text_gender, top_k)
        print(f"{clip_name}: Top-{top_k} Gender Difference in Retrieved Images: MaxSkew {gender_difference}")
        return_results.append(gender_difference)
    return tuple(return_results)
def evaluate_facet(args,clip_model,device,clip_name,img_important_indices=None,img_mean_features_misclassified=None ,text_important_indices=None,text_mean_features_misclassified=None ):
    """Evaluate zero-shot classification accuracy and its gender gap on FACET.

    Class names come from the ``class1`` column of
    ``data/facet_annotations.csv``; one prompt ``'a photo of a/an <class>'``
    is encoded per class and images (pre-computed embeddings served by
    ``FacetDataset``) are assigned to the most similar prompt.

    Args:
        args: Object with a ``target`` attribute. If ``'text'`` is in it, the
            text features at ``text_important_indices`` are overwritten with
            the corresponding entries of ``text_mean_features_misclassified``.
            If ``'image'`` is in it, the analogous replacement is applied to
            each batch of image embeddings, followed by re-normalisation.
        clip_model: Model name passed to ``clip.load`` for the text encoder.
        device: Device used for the model and the dataset tensors.
        clip_name: Identifier used by ``FacetDataset`` to locate the
            embeddings file.
        img_important_indices: Feature indices replaced in image embeddings.
        img_mean_features_misclassified: Replacement values for image features.
        text_important_indices: Feature indices replaced in text embeddings.
        text_mean_features_misclassified: Replacement values for text features.

    Returns:
        ``(average_accuracy, mean_demographic_parity)`` where the first is the
        mean accuracy over all (class, gender) groups as a fraction and the
        second is the mean over classes of
        ``|acc_female - acc_male| * 100``.
    """
    annotations = pd.read_csv('data/facet_annotations.csv')
    class_list = annotations['class1'].unique().tolist()

    # Convert class names to integer labels
    unique_class_names = sorted(set(class_list))
    class_to_idx = {class_name: idx for idx, class_name in enumerate(unique_class_names)}
    article_list = ['an' if job[0] in ['a','e','i','o','u'] else 'a' for job in unique_class_names ]
    prompts = [f'a photo of {article} {job}' for article, job in zip(article_list, unique_class_names)]

    predictions_by_class_gender = defaultdict(lambda: defaultdict(list))
    true_labels_by_class_gender = defaultdict(lambda: defaultdict(list))
    class_gender_counts = defaultdict(lambda: defaultdict(int))
    class_gender_totals = defaultdict(lambda: defaultdict(int))

    val_dataset = FacetDataset(clip_name=clip_name, device=device,class_to_idx=class_to_idx)
    val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)

    # Use the requested backbone so the text features live in the same
    # embedding space as the pre-computed image features.
    model, _ = clip.load(clip_model, device=device)
    text = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        text_embeddings = model.encode_text(text)
        # Normalize the text embedding
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
        text_embeddings = text_embeddings.float()
        if 'text' in args.target:
            text_mean_features_misclassified = text_mean_features_misclassified.float()
            text_embeddings[:,text_important_indices] = text_mean_features_misclassified[text_important_indices]
    text_embeddings = text_embeddings.to(device)

    for image_embeddings, batch_genders, class_labels in val_dataloader:
        if 'image' in args.target:
            image_embeddings = image_embeddings.to(device)
            image_embeddings[:,img_important_indices] = img_mean_features_misclassified[img_important_indices]
            image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
        predicted_labels = zero_shot_classifier(image_embeddings, text_embeddings, unique_class_names)

        # Collect data by gender and class
        for pred, true, gender in zip(predicted_labels, class_labels.cpu().numpy(), batch_genders.cpu().numpy()):
            pred = class_to_idx[pred]
            gender_key = 'Male' if gender == 1 else 'Female'
            class_key = unique_class_names[true]
            predictions_by_class_gender[class_key][gender_key].append(pred)
            true_labels_by_class_gender[class_key][gender_key].append(true)
            class_gender_counts[class_key][gender_key] += int(pred == true)
            class_gender_totals[class_key][gender_key] += 1

    # Per-(class, gender) accuracy and its running total
    total_accuracy = 0
    total_accuracy_count = 0
    accuracy_by_class_gender = defaultdict(dict)
    for class_key, gender_groups in true_labels_by_class_gender.items():
        for gender_key in gender_groups:
            accuracy = calculate_accuracy(predictions_by_class_gender[class_key][gender_key],
                                          true_labels_by_class_gender[class_key][gender_key])
            accuracy_by_class_gender[class_key][gender_key] = accuracy
            total_accuracy += accuracy
            total_accuracy_count += 1

    # Absolute male/female accuracy gap for classes that have both genders
    total_accuracy_difference_sum = 0
    total_accuracy_difference_count = 0
    for class_key, gender_accuracies in accuracy_by_class_gender.items():
        if 'Male' in gender_accuracies and 'Female' in gender_accuracies:
            accuracy_difference = abs(gender_accuracies['Male'] - gender_accuracies['Female'])
            total_accuracy_difference_sum += accuracy_difference
            total_accuracy_difference_count += 1

    # Same gap expressed in percentage points, from the correct/total counters
    demographic_parities = []
    for class_key, counts in class_gender_counts.items():
        if 'Male' in counts and 'Female' in counts:
            female_total = class_gender_totals[class_key]['Female']
            male_total = class_gender_totals[class_key]['Male']
            P_yk_given_a0 = counts['Female'] / female_total if female_total > 0 else 0
            P_yk_given_a1 = counts['Male'] / male_total if male_total > 0 else 0
            demographic_parity = abs(P_yk_given_a0 - P_yk_given_a1)*100
            demographic_parities.append(demographic_parity)

    # Calculate overall averages
    average_accuracy = (total_accuracy / total_accuracy_count) if total_accuracy_count else 0
    average_accuracy_difference = (total_accuracy_difference_sum / total_accuracy_difference_count) if total_accuracy_difference_count else 0
    print(f"Average Accuracy: {average_accuracy * 100:.2f}%")
    print(f"Average of Accuracy Differences: {average_accuracy_difference * 100:.2f}%")
    mean_demographic_parity = np.mean(demographic_parities) if demographic_parities else 0
    print(f"Mean Demographic Parity: {mean_demographic_parity:.2f}")

    return average_accuracy, mean_demographic_parity
