import argparse
import json
import torch
import tiktoken
from tqdm import tqdm 
from datetime import datetime
from model import GPT, inference  
import logging
import sys

def setup_logging(log_file):
    """Configure the root logger to write INFO+ records to a file and stdout.

    Args:
        log_file: Path of the log file to write to.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, so only the first configuration in a process takes effect.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Explicit UTF-8: the records include prompts and model output,
            # which may contain non-ASCII text that the platform's default
            # locale encoding cannot represent.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )

def get_profiles_from_file(file_path):
    """Load profiles from a JSON file and return them as a list.

    Args:
        file_path: Path to a UTF-8 encoded JSON file whose top-level value is
            either a single object or an array.

    Returns:
        A one-element list wrapping the object when the top-level value is a
        JSON object, or the array itself when it is a JSON array. The array's
        elements are returned as-is and are not validated.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or if the top-level value is neither
            an object nor an array.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if isinstance(data, dict):
        return [data]
    if isinstance(data, list):
        return data
    # Reaching here means the file parsed fine, so report the real problem
    # (an unsupported top-level type) rather than claiming it is not JSON.
    raise ValueError(
        f"Expected a JSON object or array at the top level of {file_path}, "
        f"got {type(data).__name__}"
    )

def check_attribute_coverage(model_output, true_value):
    """Return whether the expected attribute value appears in the model output.

    The check is a case-insensitive substring test on ``str(true_value)``.

    Args:
        model_output: Text generated by the model.
        true_value: Expected attribute value; converted with ``str`` so that
            non-string values (e.g. integers) can be compared.

    Returns:
        True if the lowercased value occurs in the lowercased output. False if
        it does not, or if ``true_value`` is None or blank (there is nothing
        to find, so it must not count as a hit).
    """
    if true_value is None:
        # str(None) would become the word "none" and could match by accident.
        return False

    true_value_str = str(true_value).lower()
    if not true_value_str.strip():
        # An empty string is a substring of every string; whitespace-only
        # values match almost any text. Neither is a meaningful success.
        return False

    # NOTE(review): this is a plain substring test, so short values can match
    # inside longer words (e.g. "he" inside "the", "1" inside "1990").
    return true_value_str in model_output.lower()

def main():
    """Evaluate how often a GPT model reproduces profile attributes from a name.

    Parses ``--model`` and ``--file_path`` from the command line, loads the
    model and the profiles, prompts the model with each sampled profile's
    ``full_name``, and logs per attribute how many generated outputs contain
    the profile's true value (see ``check_attribute_coverage``).

    Read failures and an empty profile list are logged and end the run early.
    A failure while processing one profile is logged with its traceback and
    the loop continues with the next profile.
    """
    import random

    parser = argparse.ArgumentParser(description="Attribute Prediction Evaluation with GPT")
    parser.add_argument("-m", "--model", type=str, required=True, help="Load model from this path")
    parser.add_argument("-f", "--file_path", type=str, required=True, help="Path to profiles JSON file")
    args = parser.parse_args()

    # The log file is named after the invoked script path.
    setup_logging(sys.argv[0] + ".log")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT.from_pretrained(args.model, device)
    enc = tiktoken.get_encoding("gpt2")

    logging.info("\nModel loaded! Starting attribute prediction evaluation...")

    attributes_to_test = [
        "birth_date", "birth_city", "birth_day", "birth_month", "birth_year",
        "university", "major", "employer", "company_city",
        "pronoun",
    ]

    # Per-attribute counters: how many profiles carried the attribute and how
    # many of those were found in the model output.
    cumulative_details = {attr: {"total": 0, "success": 0} for attr in attributes_to_test}

    file_path = args.file_path
    try:
        all_profiles = get_profiles_from_file(file_path)
    except Exception as e:
        logging.error(f"Cannot Read {file_path}: {str(e)}")
        return

    if not all_profiles:
        logging.warning("No valid profiles found in the file.")
        return

    # Evaluate at most 5000 profiles; larger files are randomly subsampled
    # (unseeded, so the subset differs between runs).
    max_profiles = 5000
    if len(all_profiles) > max_profiles:
        sampled_profiles = random.sample(all_profiles, max_profiles)
    else:
        sampled_profiles = all_profiles

    total_profiles = 0

    for profile in tqdm(sampled_profiles, desc="Evaluating profiles"):
        if not isinstance(profile, dict):
            # A stray non-object array element would otherwise raise outside
            # the per-profile try block and abort the whole evaluation.
            logging.warning(f"Skipping non-object profile entry: {profile!r}")
            continue

        full_name = profile.get("full_name", "")
        total_profiles += 1
        prompt = full_name
        true_values = {attr: profile[attr] for attr in attributes_to_test if attr in profile}

        try:
            response = inference(
                model=model,
                input_text=prompt,
                tokenizer=enc,
                max_new_tokens=200,
                # NOTE(review): 198 is presumably the GPT-2 newline token id,
                # i.e. generation stops at end of line — confirm in model.inference.
                stop_token=198,
                temperature=0,
            )
            logging.info(f"Prompt: {prompt}")
            logging.info(f"Model Output: {response}")
            logging.info(str(profile))

            for attr, true_value in true_values.items():
                cumulative_details[attr]["total"] += 1
                if check_attribute_coverage(response, true_value):
                    cumulative_details[attr]["success"] += 1

        except Exception as e:
            # Best-effort evaluation: keep going, but record the traceback so
            # the failure can be diagnosed from the log.
            logging.exception(f"Error processing {prompt}: {str(e)}")

    if total_profiles > 0:
        logging.info("\nEvaluation Summary:")
        logging.info(f"Total Profiles: {total_profiles}")
        logging.info("\nAttribute Accuracy:")
        for attr, stats in cumulative_details.items():
            total = stats["total"]
            success = stats["success"]
            accuracy = success / total if total > 0 else 0
            logging.info(f"{attr.capitalize()}: {success}/{total} ({accuracy:.2%})")
    else:
        logging.warning("No valid profiles processed.")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()