from openai import OpenAI
import os
import argparse
import logging
from utils import *
from PIL import Image
from data import COCO
from google import genai
from google.genai.types import HttpOptions
import cohere


def get_args():
    """Parse the command-line options of the transfer evaluation script.

    Options (all optional on the command line):
        --victim_model   one of "gpt", "aya", "gemini"; the API model queried.
        --source_model   one of "qwen", "llava", "glm4.1v-thinking"; only
                         recorded in the log name / summary by this file.
        --images_dir     root folder holding ``<object>/images/<file>``.
        --explain        flag; asks the model to explain its answer.
        --coco_path      path handed to the COCO dataset wrapper.
        --cache_path     path handed to the CLIP model loader.
        --seed           integer seed (default 42).
        --no_chunks      number of chunks to split the images into
                         (default 0 = no chunking).
        --chunk_id       index of the chunk to process (default 0).

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    # Declared as data so the full option list can be read at a glance;
    # the order matches the order shown in ``--help``.
    option_specs = (
        ("--victim_model", {"type": str, "choices": ["gpt", "aya", "gemini"]}),
        ("--source_model", {"type": str, "choices": ["qwen", "llava", "glm4.1v-thinking"]}),
        ("--images_dir", {"type": str}),
        ("--explain", {"action": 'store_true'}),
        ("--coco_path", {"type": str}),
        ("--cache_path", {"type": str}),
        ("--seed", {"type": int, "default": 42}),
        ("--no_chunks", {"type": int, "default": 0}),
        ("--chunk_id", {"type": int, "default": 0}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def get_log_name(args):
    """Build the run identifier used as the log file stem.

    Args:
        args: parsed arguments; reads ``victim_model``, ``source_model``,
            ``explain``, ``chunk_id`` and ``no_chunks``.

    Returns:
        str: ``<victim>_from_<source>_explain=<bool>_chunk=<id>-<total>``.
    """
    model_pair = f"{args.victim_model}_from_{args.source_model}"
    chunk_tag = f"{args.chunk_id}-{args.no_chunks}"
    return f"{model_pair}_explain={args.explain}_chunk={chunk_tag}"


def get_api_client(model):
    """Create the API client for the requested victim model.

    Args:
        model: victim model key: "gpt", "gemini" or "aya".

    Returns:
        An ``OpenAI`` client for "gpt", a ``genai.Client`` for "gemini" or a
        ``cohere.ClientV2`` for "aya". Each is constructed without arguments,
        so credentials are presumably picked up from the environment
        (TODO confirm against the respective SDKs).

    Raises:
        ValueError: if ``model`` is not one of the supported keys (this
            includes ``None``, which is what argparse yields when
            ``--victim_model`` is omitted).
    """
    if model == "gpt":
        return OpenAI()
    if model == "gemini":
        return genai.Client()
    if model == "aya":
        return cohere.ClientV2()
    # Fail with a clear message instead of an UnboundLocalError on `client`.
    raise ValueError(
        f"Unsupported victim model {model!r}; expected one of 'gpt', 'gemini', 'aya'."
    )


def get_api_response(client, model, prompt, image_path):
    """Query the victim model with a text prompt and one image.

    Args:
        client: API client matching ``model``.
        model: victim model key: "gpt", "gemini" or "aya".
        prompt: text question sent alongside the image.
        image_path: path of the image file to send.

    Returns:
        Whatever the model-specific helper (``get_gpt_output``,
        ``get_gemini_output`` or ``get_aya_output``) returns; callers in this
        file treat it as the answer text, possibly ``None``.

    Raises:
        ValueError: if ``model`` is not one of the supported keys.
    """
    if model == "gpt":
        # Image.open is lazy and keeps the file handle open; convert() yields
        # an independent in-memory copy, so the file can be closed right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        return get_gpt_output(client, prompt, image)
    if model == "gemini":
        return get_gemini_output(client, prompt, image_path)
    if model == "aya":
        return get_aya_output(client, prompt, image_path)
    # An unknown model used to return None silently, which downstream code
    # would count as "no hallucination".
    raise ValueError(
        f"Unsupported victim model {model!r}; expected one of 'gpt', 'gemini', 'aya'."
    )


def _collect_image_paths(images_dir):
    """Return every image path under ``<images_dir>/<object>/images``, sorted."""
    image_paths = []
    for obj in sorted(os.listdir(images_dir)):
        folder_dir = os.path.join(images_dir, obj, "images")
        if not os.path.isdir(folder_dir):
            # Stray files or folders without an "images" sub-folder are not
            # object folders; skip them instead of aborting the whole run.
            logger.warning(f"Skipping {obj}: no images folder at {folder_dir}")
            continue
        found = sorted(
            os.path.join(folder_dir, img)
            for img in os.listdir(folder_dir)
            if img.endswith(('.png', '.jpg', '.jpeg'))
        )
        logger.info(f"{obj}: {len(found)} images")
        image_paths.extend(found)
    image_paths.sort()  # stable global ordering keeps chunk boundaries reproducible
    return image_paths


def _chunk_bounds(total_images, no_chunks, chunk_id):
    """Return the ``(start, end)`` slice bounds of chunk ``chunk_id``."""
    if not 0 <= chunk_id < no_chunks:
        raise ValueError(
            f"chunk_id must be in [0, {no_chunks - 1}] when no_chunks={no_chunks}, got {chunk_id}"
        )
    chunk_size = total_images // no_chunks
    start_idx = chunk_id * chunk_size
    if chunk_id == no_chunks - 1:
        # The last chunk also takes the remainder of the floor division,
        # otherwise the trailing `total_images % no_chunks` images are never
        # processed by any chunk.
        end_idx = total_images
    else:
        end_idx = start_idx + chunk_size
    return start_idx, end_idx


def _says_yes(response):
    """Return True when ``response`` contains "yes" as a standalone word."""
    import re  # local import: the module does not import `re` at file level

    if not response:
        # The API helpers may hand back None / an empty answer.
        return False
    # Word boundaries avoid false positives such as "eyes" in an explanation.
    return re.search(r"\byes\b", response.lower()) is not None


def _image_index_from_path(image_path):
    """Return the integer prefix of the file name (``<idx>_...`` or ``<idx>``)."""
    # Work on the file name only: splitting the full path on "_" breaks as
    # soon as any directory name contains an underscore.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return int(stem.split('_')[0])


def main(args):
    """Measure how often the victim API model answers "yes" for an absent object.

    For every image under ``args.images_dir`` (laid out as
    ``<object>/images/<idx>_*.{png,jpg,jpeg}``) the victim model is asked
    "Is there a <object> in the image?". A "yes" is counted as a
    hallucination. The same question is then asked about the corresponding
    original COCO image, which is looked up through the numeric prefix of the
    file name, and counted separately.

    Relies on the module-level ``logger`` configured by the script entry
    point, and on ``get_clip_model`` / ``load_and_compute_similarity``, which
    are presumably provided by ``from utils import *``.

    Args:
        args: parsed command-line arguments (see ``get_args``).

    Raises:
        ValueError: if chunking is enabled and ``args.chunk_id`` is not in
            ``[0, args.no_chunks)``, or if the victim model is unsupported.
    """
    # Load Images
    image_paths = _collect_image_paths(args.images_dir)
    total_images = len(image_paths)
    logger.info(f"Total images found: {total_images}")
    if args.no_chunks > 0:
        start_idx, end_idx = _chunk_bounds(total_images, args.no_chunks, args.chunk_id)
        image_paths = image_paths[start_idx:end_idx]
        logger.info(f"Processing chunk {args.chunk_id}/{args.no_chunks}: images {start_idx} to {end_idx}")
    else:
        logger.info("Processing all images without chunking.")
    # Initialize the API client of the victim model
    client = get_api_client(args.victim_model)
    # Load CLIP
    logger.info("Loading CLIP Model...")
    clip_model, _ = get_clip_model(args.cache_path)
    clip_model.eval().cuda()
    # Load COCO
    logger.info("Loading COCO...")
    dset = COCO(args.coco_path, split='train')
    present = [x for cat in dset.get_all_supercategories() for x in dset.get_categories(cat)]

    # Per-object cache of the candidate list, so the dataset query and the
    # CLIP similarity computation run only once per object.
    cat_spur_all_dic = {}

    hallucination_count = 0
    original_hallucination_count = 0
    for image_path in image_paths:
        logger.info(f"Processing image: {image_path}")

        # The object name is the folder two levels above the image file:
        # <images_dir>/<object>/images/<file>
        object_name = os.path.basename(os.path.dirname(os.path.dirname(image_path))).lower()
        prompt = f"Is there a {object_name} in the image? Answer yes or no."
        if args.explain:
            prompt += " Explain your answer."
        logger.info(f"Prompt: {prompt}")
        # Get response from victim model
        response = get_api_response(client, args.victim_model, prompt, image_path)
        logger.info(f"Victim model response: {response}")
        if _says_yes(response):
            hallucination_count += 1
            logger.info(f"Hallucination detected for image: {image_path}")

        if object_name not in cat_spur_all_dic:
            candidates = dset.get_imgIds_by_class(present_classes=present, absent_classes=[object_name])
            # NOTE(review): presumably returns the ids re-ordered by CLIP
            # similarity to the object name -- confirm in utils.
            candidates, _ = load_and_compute_similarity(clip_model, candidates, object_name, embeddings_path="clip_embeddings.pt")
            cat_spur_all_dic[object_name] = candidates
        cat_spur_all = cat_spur_all_dic[object_name]

        # The numeric file-name prefix indexes into the candidate list.
        image_idx = _image_index_from_path(image_path)
        _, path = dset[int(cat_spur_all[image_idx])]
        # Get response for the original image
        response = get_api_response(client, args.victim_model, prompt, path)
        logger.info(f"Original image path: {path}")
        logger.info(f"Victim model response on original image: {response}")
        if _says_yes(response):
            logger.info("Hullicnation detected for original image")
            original_hallucination_count += 1
        logger.info("--------------------------------------------------")

    processed = len(image_paths)
    logger.info(f"Total original hallucinations: {original_hallucination_count} out of {processed}")
    logger.info(f"Total hallucinations: {hallucination_count} out of {processed}")
    if processed:
        logger.info(f"Hallucination Rate: {hallucination_count / processed * 100}%")
    else:
        # Empty folder or empty chunk: avoid a ZeroDivisionError at the very end.
        logger.info("Hallucination Rate: n/a (no images processed)")
    logger.info(f"Victim Model: {args.victim_model}, Source Model: {args.source_model}")

    



if __name__ == "__main__":
    # Parse the CLI first: the log file name is derived from the arguments.
    args = get_args()
    run_name = get_log_name(args)

    # create folder (makedirs also creates the intermediate "logs" directory)
    os.makedirs("logs/transfer", exist_ok=True)

    # Root logging configuration: messages are wrapped as "### ... ###".
    logging.basicConfig(format="### %(message)s ###")

    # `logger` must stay a module-level name: main() logs through it.
    logger = logging.getLogger("Tranfer_API")
    logger.setLevel(level=logging.INFO)

    # Additionally write every record to a per-run file, overwritten on each run.
    file_handler = logging.FileHandler(f"logs/transfer/{run_name}.txt", mode='w')
    logger.addHandler(file_handler)

    # Setting Seed
    set_seed(args.seed)

    logger.info(run_name)
    logger.info(f"Arguments: {args}")

    main(args)