import os
import pickle
from PIL import Image
from tqdm import tqdm
import numpy as np
import pandas as pd
import argparse

import sys; sys.path.append(".")

from modules.logdet_utils import normalize_embedding, get_normL1_prob, compute_logdet, \
    compute_probL1_logdet, get_normalized_entropy, compute_eigenscore


# Command-line interface: the input pickle, the output directory, and the
# jitter forwarded to the log-determinant computation.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--generation_file",
    type=str,
    required=True,
    help="Path to the generation file",
)
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory to save the output files",
)
parser.add_argument(
    "--jitter",
    type=float,
    default=1e-8,
    help="Jitter value for numerical stability in logdet computation",
)

# Parsed once at import time; the rest of the script reads from `args`.
args = parser.parse_args()

### Load Generations ###
file_path = args.generation_file
# Fail fast with a clear message: previously a missing file silently skipped
# the load and surfaced later as an unrelated NameError on `llava_results`.
if not os.path.isfile(file_path):
    raise FileNotFoundError(f"Generation file not found: {file_path}")

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only point --generation_file at files produced by a trusted pipeline.
with open(file_path, 'rb') as r:
    llava_results = pickle.load(r)

# from_dict is a classmethod, so no throwaway empty DataFrame is needed.
image_df = pd.DataFrame.from_dict(llava_results)

# Ensure the 'embedding' column exists; 'internal_embedding' is accepted as an
# alternative name and renamed to 'embedding'.
if 'internal_embedding' in image_df.columns:
    image_df = image_df.rename(columns={'internal_embedding': 'embedding'})
if 'embedding' not in image_df.columns:
    raise ValueError("The 'embedding' column is missing from the DataFrame.")

### Compute Uncertainty Metrics ###
# Validate inputs and prepare the output location before the expensive
# per-row work, so a bad invocation fails immediately rather than at the end.
if 'generations_log_likelihood' not in image_df.columns:
    raise ValueError(
        "The 'generations_log_likelihood' column is missing from the DataFrame."
    )
# to_pickle does not create missing directories; create it if needed.
os.makedirs(args.output_dir, exist_ok=True)

# Normalize embeddings
image_df['norm_embedding'] = image_df['embedding'].apply(normalize_embedding)

# Compute logdet of each row's Gram matrix (x @ x.T); --jitter is passed to
# compute_logdet as its `alpha` argument.
image_df['logdet'] = image_df['norm_embedding'].apply(
    lambda emb: compute_logdet(np.matmul(emb, emb.T), alpha=args.jitter)
)

# Compute adaptive prob_alpha: the absolute ratio of the median logdet to the
# median of get_normL1_prob over the per-row log-likelihoods.
prob_values = image_df['generations_log_likelihood'].apply(get_normL1_prob)
logdet_values = image_df['logdet']
prob_alpha = np.abs(logdet_values.median() / prob_values.median())
prob_param = prob_alpha
print("apdative prob alpha", prob_param)

# Compute probL1_logdet row by row; the helper receives the full row.
image_df['umpire'] = image_df.apply(
    lambda row: compute_probL1_logdet(row, alpha=prob_param), axis=1
)

# Save the updated DataFrame with uncertainty metrics
image_df.to_pickle(os.path.join(args.output_dir, 'image_df_with_uncertainty.pkl'))
