import argparse
import json
import torch
import tiktoken
from model import GPT, inference
from tqdm import tqdm 
from datetime import datetime
import logging
import sys
import random
import re

# Import templates (assuming templates.py is in the same directory)
from templates import (
    birth_date_question_templates,
    birth_city_question_templates,
    university_question_templates,
    major_question_templates,
    employer_question_templates,
    company_city_question_templates
)

def setup_logging(log_file):
    """Configure root logging to write INFO+ records to a file and to stdout.

    Args:
        log_file: Path of the log file; records are appended to it.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so this only takes effect on the first configuration.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Explicit UTF-8: the log receives raw model output and profile
            # names, and the platform default encoding may not represent them.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )

def get_profiles_from_file(file_path):
    """Load profile records from a UTF-8 JSON file.

    The file may hold a single JSON object (one profile) or a JSON array
    (several profiles).

    Args:
        file_path: Path to the JSON file.

    Returns:
        A list of profiles. A top-level object is wrapped in a one-element
        list; a top-level array is returned unchanged.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON (a subclass
            of ValueError).
        ValueError: If the top-level JSON value is neither an object nor an
            array.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict):
        return [data]
    if isinstance(data, list):
        return data
    # The file parsed successfully, so say what was actually found instead of
    # claiming it is "not JSON".
    raise ValueError(
        f"expected a JSON object or array in {file_path}, got {type(data).__name__}"
    )

def check_attribute_coverage(model_output, true_value):
    """Return True if the model output mentions the expected attribute value.

    Matching is a case-insensitive substring test of ``str(true_value)``
    (surrounding whitespace removed) within ``model_output``.

    Args:
        model_output: Text produced by the model.
        true_value: Expected attribute value; non-strings are converted with
            ``str()``.

    Returns:
        True if the normalized true value occurs in the output. False
        otherwise, including when the true value is None or blank.
    """
    if true_value is None:
        return False
    needle = str(true_value).strip().lower()
    # An empty needle is a substring of every string and would always count
    # as a hit, inflating the reported accuracy.
    if not needle:
        return False
    return needle in model_output.lower()

def fill_template(template, profile):
    """Fill a question template with a profile's name and pronouns.

    Args:
        template: A ``str.format`` template that may reference ``{full_name}``,
            ``{pronoun}`` and ``{possessive_pronoun}``.
        profile: Profile dict; the optional "pronoun" and "full_name" keys
            are read.

    Returns:
        The formatted prompt. ``{pronoun}`` is capitalized (e.g. "She") and
        ``{possessive_pronoun}`` is lowercase ("his", "her" or "their").

    Raises:
        KeyError: If the template references a placeholder not supplied here.
    """
    # `or` also covers an explicit None/"" value, which .get's default would not.
    pronoun = (profile.get("pronoun") or "They").strip().lower()
    # Only "he"/"she" have gendered possessives; anything else (they/them or
    # an unrecognized value) falls back to the neutral "their" instead of "her".
    possessive_pronoun = {"he": "his", "she": "her"}.get(pronoun, "their")
    full_name = profile.get("full_name", "")

    return template.format(
        full_name=full_name,
        pronoun=pronoun.capitalize(),
        possessive_pronoun=possessive_pronoun,
    )

def evaluate_profiles(
    profiles,
    model,
    enc,
    attributes_to_test,
    desc,
    question_templates,
    *,
    max_samples=1501,
    max_new_tokens=100000,
    stop_token=198,
):
    """Ask the model one templated question per attribute and score the answers.

    At most ``max_samples`` profiles are evaluated. When there are more, a
    random subset is drawn with ``random.sample``. For each evaluated profile,
    every attribute in ``attributes_to_test`` that the profile contains and
    that has at least one template gets one question. The template is picked
    with ``random.choice``, filled by ``fill_template`` and sent to
    ``inference`` with temperature 0. An answer is a success when
    ``check_attribute_coverage`` finds the true value in it.

    Args:
        profiles: List of profile dicts.
        model: Model forwarded to ``inference``.
        enc: Tokenizer forwarded to ``inference``.
        attributes_to_test: Attribute keys to query.
        desc: Label for the tqdm progress bar.
        question_templates: Mapping of attribute key -> list of templates.
        max_samples: Maximum number of profiles to evaluate.
        max_new_tokens: Generation limit forwarded to ``inference``.
        stop_token: Token id forwarded to ``inference``. The default 198 is
            presumably the newline token of the GPT-2 encoding; confirm.

    Returns:
        A tuple ``(details, total_profiles)``. ``details`` maps each attribute
        to ``{"total": int, "success": int}``. Questions whose processing
        raised an exception are logged and left out of both counts.
        ``total_profiles`` is the number of profiles evaluated.
    """
    cumulative_details_ques = {attr: {"total": 0, "success": 0} for attr in attributes_to_test}

    if len(profiles) > max_samples:
        sampled_profiles = random.sample(profiles, max_samples)
    else:
        sampled_profiles = profiles

    total_profiles = 0
    for profile in tqdm(sampled_profiles, desc=desc):
        total_profiles += 1
        true_values = {attr: profile[attr] for attr in attributes_to_test if attr in profile}

        # Question-based prompting: one randomly chosen template per attribute.
        for attr, true_value in true_values.items():
            templates = question_templates.get(attr, [])
            if not templates:
                continue

            prompt = None
            try:
                # Build the prompt inside the try so that a malformed template
                # or profile (e.g. a KeyError from str.format) skips only this
                # question instead of aborting the whole evaluation run.
                prompt = fill_template(random.choice(templates), profile)
                response = inference(
                    model=model,
                    input_text=prompt,
                    tokenizer=enc,
                    max_new_tokens=max_new_tokens,
                    stop_token=stop_token,
                    temperature=0,
                )
                logging.info(f"Question Prompt: {prompt}")
                logging.info(f"Model Output: {response}")
                logging.info(f"True Value for {attr}: {true_value}")

                stats = cumulative_details_ques[attr]
                stats["total"] += 1
                if check_attribute_coverage(response, true_value):
                    stats["success"] += 1

            except Exception as e:
                # logging.exception logs at ERROR level and includes the traceback.
                logging.exception(f"Error processing question prompt {prompt}: {str(e)}")

    return cumulative_details_ques, total_profiles

def _load_profiles_or_none(path, label):
    """Read profiles for main(); log the reason and return None if unusable.

    Args:
        path: JSON file path given on the command line.
        label: Word used in the "no valid profiles" warning ("first"/"new").

    Returns:
        The non-empty list of profiles, or None when the file could not be
        read or parsed, or contained no profiles.
    """
    try:
        profiles = get_profiles_from_file(path)
    except Exception as e:
        logging.error(f"Unable to read file {path}: {str(e)}")
        return None
    if not profiles:
        logging.warning(f"No valid profiles found in the {label} file.")
        return None
    return profiles


def _report_results(title, path, details, total_profiles):
    """Log per-attribute accuracy for one evaluated profile file.

    Args:
        title: Capitalized set name used in the header ("Old"/"New").
        path: Profile file path, echoed in the header.
        details: Mapping attr -> {"total": int, "success": int}.
        total_profiles: Number of profiles that were evaluated.
    """
    if total_profiles <= 0:
        logging.warning(f"No valid profiles processed in the {title.lower()} file (question).")
        return
    logging.info(f"\n{title} Profiles ({path}) - Question Prompting:")
    logging.info(f"Total Profiles: {total_profiles}")
    logging.info("\nAttribute Accuracy:")
    for attr, stats in details.items():
        total = stats["total"]
        success = stats["success"]
        accuracy = success / total if total > 0 else 0
        logging.info(f"{attr.capitalize()}: {success}/{total} ({accuracy:.2%})")


def main():
    """Command-line entry point: evaluate a model on two profile files.

    Parses arguments, configures logging to ``<script>.log`` and stdout, and
    loads both profile files. It then loads the model, runs question-based
    attribute evaluation on each file, and logs per-attribute accuracy for
    both. Returns early (after logging why) if either profile file is
    unreadable or empty.
    """
    parser = argparse.ArgumentParser(description="Attribute Prediction Evaluation with GPT")
    parser.add_argument("-m", "--model", type=str, required=True, help="Load model from this path")
    parser.add_argument("-f", "--file_path", type=str, required=True, help="Path to profiles JSON file")
    parser.add_argument("-nf", "--new_file_path", type=str, required=True, help="Path to profiles new JSON file")
    args = parser.parse_args()

    setup_logging(sys.argv[0] + ".log")

    # Validate both inputs before any expensive work. Previously the second
    # file was read only after the first (slow) evaluation, so a bad
    # --new_file_path threw that run away without reporting its results.
    all_profiles = _load_profiles_or_none(args.file_path, "first")
    if all_profiles is None:
        return
    all_new_profiles = _load_profiles_or_none(args.new_file_path, "new")
    if all_new_profiles is None:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT.from_pretrained(args.model, device)
    enc = tiktoken.get_encoding("gpt2")

    logging.info("\nModel loaded! Starting attribute prediction evaluation...")

    attributes_to_test = [
        "birth_date", "birth_city",
        "university", "major",
        "employer", "company_city",
    ]

    # Map attributes to their question templates.
    question_templates = {
        "birth_date": birth_date_question_templates,
        "birth_city": birth_city_question_templates,
        "university": university_question_templates,
        "major": major_question_templates,
        "employer": employer_question_templates,
        "company_city": company_city_question_templates,
    }

    cumulative_details_old_ques, total_profiles_old = evaluate_profiles(
        all_profiles, model, enc, attributes_to_test, "Evaluating old profiles", question_templates
    )
    cumulative_details_new_ques, total_profiles_new = evaluate_profiles(
        all_new_profiles, model, enc, attributes_to_test, "Evaluating new profiles", question_templates
    )

    _report_results("Old", args.file_path, cumulative_details_old_ques, total_profiles_old)
    _report_results("New", args.new_file_path, cumulative_details_new_ques, total_profiles_new)

if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not on import.
    main()