import torch
import clip
from torch.utils.data import DataLoader
import random
from collections import defaultdict
import pandas as pd
import numpy as np
from tqdm import tqdm
from glob import glob
from pycocoevalcap.spice.spice import Spice
import ast
import re
import os
from pprint import pprint
import subprocess
import threading
from sklearn.utils import resample
from pycocoevalcap.meteor import meteor as meteor_module  # Import the module

# Absolute directory of the imported pycocoevalcap `meteor` module; it is also
# used as the working directory of the Java subprocess started by `Meteor`.
meteor_dir = os.path.split(os.path.abspath(meteor_module.__file__))[0]
# Full path at which 'meteor-1.5.jar' is expected, i.e. inside that directory.
METEOR_JAR = os.path.join(meteor_dir, 'meteor-1.5.jar')
class Meteor:
    """Scorer that talks to a long-running METEOR Java process over stdio.

    The process is started once in ``__init__`` from ``METEOR_JAR`` and fed
    ``SCORE`` / ``EVAL`` request lines on stdin; each request is answered by
    one or more lines on stdout.  All exchanges with the process happen while
    holding ``self.lock`` so that request/response pairs of different threads
    never interleave.
    """

    def __init__(self):
        # Created first so that __del__ can always rely on it, even when
        # launching the Java process below raises (e.g. `java` not installed).
        self.lock = threading.Lock()
        self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR,
                           '-', '-', '-stdio', '-l', 'en', '-norm']
        # NOTE(review): stderr is piped but never read anywhere in this class;
        # a very chatty JVM could presumably fill that pipe -- confirm.
        self.meteor_p = subprocess.Popen(self.meteor_cmd,
                                         cwd=meteor_dir,
                                         stdin=subprocess.PIPE,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE)

    def compute_score(self, gts, res):
        """Score every hypothesis in ``res`` against its references in ``gts``.

        Args:
            gts: mapping id -> list of reference caption strings.
            res: mapping id -> list holding exactly one hypothesis string.
                Must have the same keys as ``gts``.

        Returns:
            ``(score, scores)``: the last value reported by the process for
            the EVAL request and the list of per-id values read before it, in
            the iteration order of ``gts``.

        Raises:
            AssertionError: if the keys differ or an id does not have exactly
                one hypothesis.
            ValueError: if the process answers with a non-numeric line.
        """
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        # `with` guarantees the lock is released even when an assertion or a
        # parse error interrupts the exchange.
        with self.lock:
            eval_parts = ['EVAL']
            for i in imgIds:
                assert(len(res[i]) == 1)
                eval_parts.append(self._stat(res[i][0], gts[i]))

            self._send_line(' ||| '.join(eval_parts))
            scores = [float(self._read_line()) for _ in range(len(imgIds))]
            score = float(self._read_line())

        return score, scores

    def method(self):
        """Return the name of this metric."""
        return "METEOR"

    def _send_line(self, line):
        """Write one request line to the process and flush it immediately."""
        self.meteor_p.stdin.write('{}\n'.format(line).encode())
        self.meteor_p.stdin.flush()

    def _read_line(self):
        """Read one response line from the process as stripped text."""
        return self.meteor_p.stdout.readline().decode().strip()

    def _stat(self, hypothesis_str, reference_list):
        """Send a SCORE request and return the statistics line it yields.

        The caller is expected to hold ``self.lock``; this method does not
        acquire it itself.
        """
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        # '|||' is the field separator, so it must not occur in the hypothesis.
        hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        # The protocol is line based: embedded line breaks would split a request.
        score_line = score_line.replace('\n', '').replace('\r', '')
        self._send_line(score_line)
        return self._read_line()

    def _score(self, hypothesis_str, reference_list):
        """Return the METEOR value for a single hypothesis.

        Args:
            hypothesis_str: the candidate caption.
            reference_list: list of reference caption strings.

        Returns:
            The second of the two numbers the process prints for the EVAL
            request.
        """
        with self.lock:
            stats = self._stat(hypothesis_str, reference_list)
            # EVAL ||| stats
            self._send_line('EVAL ||| {}'.format(stats))
            # Per the original author's note the jar answers with two values
            # (one average, one overall), so two lines are consumed and the
            # second one is returned.
            float(self._read_line())
            score = float(self._read_line())
        return score

    def __del__(self):
        # __init__ may have failed before the process (or, for instances
        # built without __init__, the lock) existed; then nothing to clean up.
        proc = getattr(self, 'meteor_p', None)
        lock = getattr(self, 'lock', None)
        if proc is None or lock is None:
            return
        with lock:
            proc.stdin.close()
            proc.kill()
            proc.wait()
def misclassification_rate(df):
    """Summarise gender misclassification in ``df`` as percentages.

    Reads the ``ground_truth_gender`` and ``detected_gender`` columns.  A row
    counts as misclassified when the ground truth is 'Male' and the detection
    is 'Female', or the other way round; rows with any other ground-truth
    value are ignored.

    Returns:
        dict with the male, female and overall misclassification rates, the
        composite rate ``sqrt(overall**2 + (female - male)**2)`` and the
        absolute male/female difference, each multiplied by 100 and rounded
        to two decimals.  A rate whose denominator is zero is reported as 0.
    """
    truth = df['ground_truth_gender']
    detected = df['detected_gender']
    is_male = truth == 'Male'
    is_female = truth == 'Female'

    # Plain ints keep the arithmetic below in ordinary Python numbers.
    n_male = int(is_male.sum())
    n_female = int(is_female.sum())
    male_errors = int((is_male & (detected == 'Female')).sum())
    female_errors = int((is_female & (detected == 'Male')).sum())

    def safe_ratio(numerator, denominator):
        # An empty group yields a rate of 0 instead of a division error.
        if denominator > 0:
            return numerator / denominator
        return 0

    male_rate = safe_ratio(male_errors, n_male)
    female_rate = safe_ratio(female_errors, n_female)
    overall_rate = safe_ratio(male_errors + female_errors, n_male + n_female)

    # Composite Misclassification Rate (MR_C): overall rate combined with the
    # gap between the two per-gender rates.
    composite_rate = np.sqrt(overall_rate**2 + (female_rate - male_rate)**2)

    # Gap between male and female rates (percentage points once scaled below).
    gap = abs(male_rate - female_rate)

    def as_percent(value):
        return round(value * 100, 2)

    return {
        'Male Misclassification Rate': as_percent(male_rate),
        'Female Misclassification Rate': as_percent(female_rate),
        'Overall Misclassification Rate': as_percent(overall_rate),
        'Composite Misclassification Rate': as_percent(composite_rate),
        'Absolute Difference': as_percent(gap),
    }

# Function to convert string list representations to actual lists
def convert_str_to_list(str_list):
    """Parse the textual form of a Python list, e.g. ``"['a', 'b']"``.

    Args:
        str_list: the string to parse.  A value that already is a list is
            returned unchanged, so applying the function twice is harmless.

    Returns:
        The value produced by ``ast.literal_eval``, or an empty list when the
        input cannot be parsed (malformed text, or a non-string such as NaN).
    """
    if isinstance(str_list, list):
        return str_list
    try:
        return ast.literal_eval(str_list)
    except (ValueError, SyntaxError):
        # literal_eval raises SyntaxError for text that is not valid Python
        # and ValueError for valid syntax that is not a plain literal.
        return []  # Returns an empty list in case of error


def neutralize_gender(text):
    """Return ``text`` with a fixed set of gendered nouns made neutral.

    Whole-word, case-insensitive matches are replaced; the replacement is
    always inserted in lower case.  Substitutions are applied one pattern
    after another, in the order listed below.
    """
    substitutions = (
        (r"\bman\b", "person"),
        (r"\bguy\b", "person"),
        (r"\bson\b", "child"),
        (r"\bboy\b", "child"),
        (r"\bwoman\b", "person"),
        (r"\blady\b", "person"),
        (r"\bmen\b", "people"),
        (r"\bwomen\b", "people"),
    )
    neutralized = text
    for pattern, neutral_word in substitutions:
        matcher = re.compile(pattern, re.IGNORECASE)
        neutralized = matcher.sub(neutral_word, neutralized)
    return neutralized


def evaluate_captions_max(df):
    """Evaluate captions taking the maximum score between original and neutralized ground truths.

    Every row of ``df`` is scored twice per metric: once against its
    ``gt_captions`` and once against the gender-neutralized version of those
    captions.  The better of the two per-row scores is kept and the kept
    scores are averaged.

    Args:
        df: DataFrame with a ``gt_captions`` column (list of reference
            strings per row) and a ``generated_text`` column.

    Returns:
        dict mapping "METEOR" and "SPICE" to the mean of the per-row maxima.
        Both are 0 when ``df`` has no rows.
    """
    gts = {}
    res = {}
    gts_neutral = {}
    # Key the rows by position, not by index label: a frame that was resampled
    # with replacement carries duplicate labels, and label keys would collapse
    # those repeated rows into a single dictionary entry.
    for position, (references, generated) in enumerate(
            zip(df['gt_captions'], df['generated_text'])):
        gts[position] = references
        res[position] = [generated]
        # Neutralize each caption in the ground truths
        gts_neutral[position] = [neutralize_gender(caption) for caption in references]

    results = {"METEOR": 0, "SPICE": 0}
    if not gts:
        # Nothing to score; also avoids starting the scorers and dividing by zero.
        return results

    scorers = [
        (Meteor(), "METEOR"),
        (Spice(), "SPICE"),
    ]

    for scorer, metric_name in tqdm(scorers):
        _, scores_orig = scorer.compute_score(gts, res)
        _, scores_neutral = scorer.compute_score(gts_neutral, res)

        if metric_name == "SPICE":
            # SPICE returns one dict per item; compare on its overall F-score.
            scores_orig = [score['All']['f'] for score in scores_orig]
            scores_neutral = [score['All']['f'] for score in scores_neutral]

        max_scores = [max(orig, neut) for orig, neut in zip(scores_orig, scores_neutral)]
        results[metric_name] = sum(max_scores) / len(max_scores)
    return results
def bootstrap(df, num_samples=100, sample_size=10000):
    """Run ``report_df`` on ``num_samples`` resamples of ``df``.

    Args:
        df: the DataFrame to resample.
        num_samples: number of resamples to draw.
        sample_size: accepted but not used -- every resample contains
            ``len(df)`` rows regardless of the value passed.

    Returns:
        list of ``(rates, results)`` tuples, one per resample, where
        ``rates['Absolute Difference']`` is set to the absolute difference of
        the male and female misclassification rates of that resample.
    """
    rows_per_sample = len(df)
    collected = []
    for _ in tqdm(range(num_samples)):
        rates, results = report_df(resample(df, n_samples=rows_per_sample))

        # Store the male/female gap alongside the individual rates.
        rates['Absolute Difference'] = abs(
            rates['Male Misclassification Rate'] - rates['Female Misclassification Rate']
        )

        collected.append((rates, results))
    return collected


def calculate_confidence_intervals(bootstrap_results, confidence_level=0.95):
    """Compute percentile confidence bounds for each reported metric.

    Args:
        bootstrap_results: sequence of ``(rates, results)`` pairs of dicts.
            A metric is read from ``rates`` when present there, otherwise
            from ``results``.
        confidence_level: central coverage of the interval, e.g. 0.95 uses
            the 2.5th and 97.5th percentiles.

    Returns:
        ``(ci_lower, ci_upper)``: two dicts mapping each metric name to its
        lower and upper percentile bound.
    """
    metric_names = (
        'Male Misclassification Rate',
        'Female Misclassification Rate',
        'Overall Misclassification Rate',
        'Composite Misclassification Rate',
        'METEOR',
        'SPICE',
        'Absolute Difference',
    )

    # The two percentile ranks depend only on the confidence level.
    lower_rank = (1 - confidence_level) / 2 * 100
    upper_rank = (1 + confidence_level) / 2 * 100

    ci_lower = {}
    ci_upper = {}
    for name in metric_names:
        samples = []
        for entry in bootstrap_results:
            # Rates live in the first dict, caption metrics in the second.
            source = entry[0] if name in entry[0] else entry[1]
            samples.append(source[name])

        ci_lower[name] = np.percentile(samples, lower_rank)
        ci_upper[name] = np.percentile(samples, upper_rank)

    return ci_lower, ci_upper

def report_df(df):
    """Compute misclassification rates and caption metrics for ``df``.

    The ``gt_captions`` column is parsed with ``convert_str_to_list`` on a
    copy of the frame, so the caller's DataFrame is left untouched and can be
    passed in again later.

    Args:
        df: DataFrame whose ``gt_captions`` column holds the references in
            string form, plus the columns read by ``misclassification_rate``
            and ``evaluate_captions_max``.

    Returns:
        ``(rates, results)``: the dict returned by ``misclassification_rate``
        and the dict returned by ``evaluate_captions_max``.
    """
    # assign() returns a new DataFrame instead of writing into the argument.
    parsed = df.assign(gt_captions=df['gt_captions'].apply(convert_str_to_list))
    rates = misclassification_rate(parsed)
    results = evaluate_captions_max(parsed)

    return rates, results

# Turn an interval [lower, upper] into "centre ± half-width" form
def mean_margin(lower, upper):
    """Return ``(midpoint, half_width)`` of the interval ``[lower, upper]``."""
    half_width = (upper - lower) / 2
    midpoint = (lower + upper) / 2
    return midpoint, half_width

def evaluate_image_captioning(file_path):
    """Bootstrap-evaluate one captioning results CSV and store a summary row.

    The CSV at ``file_path`` is loaded, passed to ``bootstrap``, and the
    resulting confidence bounds are turned into ``"centre ± half-width"``
    strings (centre and half-width of the percentile interval).  METEOR and
    SPICE bounds are multiplied by 100 first.  The row is appended to
    ``<name>_eval.csv`` next to the input, creating that file when it does
    not exist yet, and is also pretty-printed.

    Args:
        file_path: path of the CSV file to evaluate.
    """
    print(f'Evaluating Image Captioning for {file_path}')
    df = pd.read_csv(file_path)

    # Run bootstrapping and calculate confidence intervals
    bootstrap_results = bootstrap(df)
    ci_lower, ci_upper = calculate_confidence_intervals(bootstrap_results)

    def summarize(metric, scale=1):
        # Format one metric's interval as "centre ± half-width".
        centre, margin = mean_margin(ci_lower[metric] * scale, ci_upper[metric] * scale)
        return f"{centre:.2f} ± {margin:.2f}"

    new_row = {
        'file_path': file_path,
        'Male Misclassification Rate': summarize('Male Misclassification Rate'),
        'Female Misclassification Rate': summarize('Female Misclassification Rate'),
        'Overall Misclassification Rate': summarize('Overall Misclassification Rate'),
        'Composite Misclassification Rate': summarize('Composite Misclassification Rate'),
        'METEOR': summarize('METEOR', scale=100),
        'SPICE': summarize('SPICE', scale=100),
        '|Male-Female|': summarize('Absolute Difference'),
    }

    # Only the extension is rewritten.  A path without a ".csv" extension gets
    # the suffix appended, so the output can never be the input file itself.
    root, extension = os.path.splitext(file_path)
    output_file = (root if extension.lower() == ".csv" else file_path) + "_eval.csv"

    # Load existing results if the file exists and append the new row
    new_row_df = pd.DataFrame([new_row])
    try:
        existing_df = pd.read_csv(output_file)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # No previous results (or an empty file): start with just this row.
        results_df = new_row_df
    else:
        results_df = pd.concat([existing_df, new_row_df], ignore_index=True)

    # Save results back to CSV
    results_df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

    pprint(new_row)
    