# Silence every warning category for the remainder of the process.
import warnings
warnings.filterwarnings("ignore")

import argparse

# Command-line options, registered in the order they appear in --help.
_CLI_OPTIONS = (
    ('--mode', dict(default='logit', type=str)),
    ('--gpu_id', dict(default='0', type=str)),
    # BooleanOptionalAction also registers the negated flag (--no-neutral).
    ('--neutral', dict(default=False,
                       action=argparse.BooleanOptionalAction,
                       help='True: Neutralization, False: Bias Alignment')),
    ('--lam', dict(default=2, type=float)),
    ('--t', dict(default=0.1, type=float, help='threshold tau')),
    ('--pred_cap_path', dict(default="external/clipcap/oscar_preds.pkl", type=str)),
    ('--image_dir', dict(default="../data/COCO/val2014", type=str)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()
import os

# These variables must be set before torch initialises CUDA. Only the GPU(s)
# named by --gpu_id are exposed to this process, enumerated in PCI bus order,
# so the first selected GPU is addressed as index 0 from here on.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": str(args.gpu_id),
})

import torch

# "cuda:0" is the remapped index of the selected GPU; fall back to the CPU
# when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

import sys
sys.path.append('./')

import numpy as np
import random
import pandas as pd 
import pickle
import json
import clip
from PIL import Image
from joblib import dump, load
from tqdm import tqdm
import json
from sklearn.linear_model import LogisticRegression
from transformers import GPT2Tokenizer
from evaluation import evaluate_image_captioning
from external.clipcap import clipcap_model
from external.clipcap.clipcap_utils import decide_gender, generate
from utils import load_and_normalize_beta,predict_gender
# Request the NLTK resources needed below; nltk.download is called for each
# of them on every run, in this order.
import nltk
for _nltk_resource in (
    'wordnet',
    'averaged_perceptron_tagger',
    'punkt',
    'punkt_tab',
    'omw-1.4',
):
    nltk.download(_nltk_resource)


# CLIP ViT-B/32 encoder; its image embedding is later passed to the captioner
# as `embed`.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Load the COCO val2014 caption annotations (JSON is read as UTF-8 regardless
# of the platform's default encoding).
with open('../data/COCO/annotations/captions_val2014.json', 'r', encoding='utf-8') as json_data:
    d = json.load(json_data)
annotations = d['annotations']

# Group the reference captions by image id, preserving annotation order.
id_to_captions = {}
for ann in annotations:
    id_to_captions.setdefault(ann['image_id'], []).append(ann['caption'])

# Mapping image id -> gender label. The file is opened in a context manager
# so the handle is closed deterministically.
# NOTE: pickle can execute arbitrary code while loading; only point this at a
# trusted, locally produced file.
with open('external/clipcap/val_imid_gender.pkl', 'rb') as gender_file:
    imid_2_gender = pickle.load(gender_file)
filtered_image_ids = set(imid_2_gender.keys())

# Keep only the images that have a gender label.
filtered_id_to_captions = {image_id: captions for image_id, captions in id_to_captions.items() if image_id in filtered_image_ids}
results = []

# Image ids to exclude from the run.
# NOTE: `remove_id` is a pandas Series; `x in series` tests the index labels,
# not the stored values, so membership checks must go through the values
# (e.g. a set built from them).
remove_id = pd.read_csv("external/remove_df.csv")
remove_id = remove_id['remove_id']

# Placeholders; nothing in the visible script assigns or reads them afterwards.
text_important_indices=None
text_mean_features_lowconfidence=None

# ClipCap captioner: GPT-2 tokenizer plus the caption model built with a
# prefix length of 10.
prefix_length = 10
model_path = 'external/clipcap/clip_cap_coco_weight.pt'
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = clipcap_model.ClipCaptionModel(prefix_length, device=device)
# map_location remaps the checkpoint tensors onto the selected device, so a
# checkpoint saved from a GPU still loads when the script falls back to CPU.
# strict=False tolerates missing/unexpected keys in the checkpoint.
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)

# Inference only: switch to eval mode and move the weights to the device.
model = model.eval()
model = model.to(device)
bias_vector = None

# Default output file; the 'logit' mode replaces it with a more specific name.
results_filename = f"external/clipcap/result/clip_cap_{args.mode}.csv"

if args.mode=='logit':
    # The output name encodes --neutral and --lam.
    # NOTE(review): --t is not part of the name, so a run that only changes
    # the threshold reuses an existing results file — confirm this is intended.
    results_filename = f"external/clipcap/result/clip_cap_logit_features_{args.neutral}_{args.lam}.csv"

    # Text-side classifier handed to generate() as `text_classifier`.
    vocab_size = tokenizer.vocab_size
    max_length = 50
    text_classifier = clipcap_model.SimpleTransformerClassifier(vocab_size=vocab_size, max_length=max_length)
    text_classifier_path = os.path.join("../nlp_classification/gender_model_gpt2_pytorch_generated", "pytorch_model.bin")
    # map_location keeps the load working on CPU-only hosts; the module-level
    # `device` is reused so every model lives on the same device.
    text_classifier.load_state_dict(torch.load(text_classifier_path, map_location=device))
    text_classifier.to(device)
    text_classifier.eval()

    # Per-token values handed to generate() as `token_bias`.
    token_bias = load_and_normalize_beta("../nlp_classification/importance_dict_gpt2_pytorch_generated.json")

    # Image-side classifier whose prediction is passed to generate() as
    # `s_value`. A cached copy is reused when it can be loaded; otherwise the
    # classifier is trained on the stored embeddings and cached.
    classifier_cache = "image_gender_classifier_clipcap.pkl"
    try:
        classifier = load(classifier_cache)
    except Exception as load_error:
        # Best effort: any load failure (missing or unreadable cache) falls
        # back to retraining, but Ctrl-C / SystemExit are no longer swallowed.
        print(f"Could not load {classifier_cache} ({load_error}); training a new classifier.")
        # The training embeddings are only read when training is required.
        embedding = torch.load('embedding/fairface_ViTB32_train.pt', map_location='cpu')
        X_train = embedding['image_embeddings']
        # Column 1 of the sensitive attributes is the training target
        # (presumably the gender label — confirm against the embedding file).
        y_train = embedding['sensitive_attributes'][:,1]
        classifier = LogisticRegression(max_iter=1000, solver='lbfgs')
        classifier.fit(X_train, y_train)
        dump(classifier, classifier_cache)

# Generate a caption for every retained image unless a results file for this
# configuration already exists, then evaluate the results file.
if not os.path.exists(results_filename):
    # `remove_id` is a pandas Series, and `x in series` tests the index
    # labels rather than the values. Build a set of the values so the
    # exclusion list is actually applied (and lookups are O(1)).
    excluded_ids = set(remove_id)

    # Create the output directory up front so the run cannot fail at the
    # final to_csv after all captions have been generated.
    output_dir = os.path.dirname(results_filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for image_id, gt_captions in tqdm(filtered_id_to_captions.items()):
        if image_id in excluded_ids:
            continue
        image_path = os.path.join(args.image_dir, f"COCO_val2014_{str(image_id).zfill(12)}.jpg")
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed right away instead of leaking one per image.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        ground_truth_gender = imid_2_gender[image_id]
        with torch.no_grad():
            # CLIP image embedding, cast to float32, used as the caption prefix.
            prefix = clip_model.encode_image(preprocess(image).unsqueeze(0).to(device)).float()
            if args.mode=='logit':
                # Predicted attribute of the image, passed to generate().
                s_value = predict_gender(classifier, prefix.detach().cpu().numpy())
                generated_text = generate(model, tokenizer, embed=prefix,
                                          mode=args.mode,s_value=s_value,
                                          text_classifier=text_classifier,token_bias=token_bias,
                                          lam=args.lam,neutral=args.neutral,
                                          threshold=args.t)
            else:
                generated_text = generate(model, tokenizer, embed=prefix)
        detected_gender = decide_gender(nltk.word_tokenize(generated_text))
        results.append({
            'image_id': image_id,
            'ground_truth_gender': ground_truth_gender,
            'detected_gender': detected_gender,
            'gt_captions': gt_captions,
            'generated_text': generated_text
        })
    df = pd.DataFrame(results)
    df.to_csv(results_filename, index=False)

else:
    print(f"Results file {results_filename} already exists. Skipping processing.")

evaluate_image_captioning(results_filename)
