import argparse
import json
import torch
import tiktoken
from tqdm import tqdm 
from datetime import datetime
from model import GPT, inference  
import logging
import sys
import random
import re
from multiprocessing import Pool
import os

def setup_logging(log_file):
    """Send INFO-and-above log records to both ``log_file`` and stdout.

    The log file is written as UTF-8, matching the explicit UTF-8 encoding this
    script uses for every other file it reads or writes; prompts and model outputs
    are logged verbatim and may contain non-ASCII text. As with
    ``logging.basicConfig`` itself, nothing is changed if the root logger already
    has handlers.

    Args:
        log_file: path of the log file (opened in append mode by FileHandler).
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            # Without an explicit encoding FileHandler falls back to the locale
            # encoding, which may be unable to encode some names/outputs.
            logging.FileHandler(log_file, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )

def get_profiles_from_file(file_path):
    """Load profile records from a UTF-8 JSON file.

    A top-level JSON object is treated as a single profile and wrapped in a list;
    a top-level JSON array is returned as-is (its elements are not validated).

    Args:
        file_path: path to the JSON file.

    Returns:
        A list of profiles.

    Raises:
        ValueError: if the top-level JSON value is neither an object nor an array.
            Malformed JSON raises ``json.JSONDecodeError``, itself a ValueError.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if isinstance(data, dict):
        return [data]
    if isinstance(data, list):
        return data
    # The file parsed fine, so "not JSON" would be misleading: report what was found.
    raise ValueError(
        f"expected a JSON object or array at top level, got {type(data).__name__}"
    )

def process_line(line):
    """Parse one ``Q: <question> A: <answer>`` line into ``(question, answer)``.

    Surrounding whitespace is stripped from the line and from both parts. The line
    must start with ``Q:`` and is split at the first ``A:`` occurrence.

    Args:
        line: a single line of text.

    Returns:
        A ``(question, answer)`` tuple of strings, or None if the line does not
        match the format.
    """
    line = line.strip()
    if not (line.startswith("Q:") and "A:" in line):
        return None
    q, a = line.split("A:", 1)
    # Remove only the leading "Q:" marker; str.replace would also delete any
    # "Q:" that appears inside the question text (e.g. "FAQ:").
    question = q[len("Q:"):].strip()
    answer = a.strip()
    return (question, answer)

def load_qa_from_file(file_path):
    """Parse every ``Q: ... A: ...`` line of a UTF-8 text file into QA pairs.

    Each line is handed to ``process_line``; lines it rejects (returns None for)
    are skipped. Pairs are returned in file order.

    Args:
        file_path: path to the QA text file.

    Returns:
        A list of ``(question, answer)`` tuples.
    """
    # Parsing a line is a couple of string operations, so a process pool only adds
    # overhead (and under the "spawn" start method each worker would re-import this
    # module and its heavy dependencies). The old worker-count expression also
    # raised TypeError whenever os.cpu_count() returned None. A single sequential
    # pass yields the same ordered result.
    with open(file_path, 'r', encoding='utf-8') as f:
        return [pair for pair in map(process_line, f) if pair is not None]

def check_attribute_coverage(model_output, true_value):
    """Report whether ``true_value`` is mentioned anywhere in ``model_output``.

    The true value is converted with ``str`` and both sides are lower-cased, so
    this is a case-insensitive substring test. It is deliberately lenient: a true
    value of 1, for example, also matches an output containing "1990".
    """
    needle = str(true_value).lower()
    haystack = model_output.lower()
    return needle in haystack

def _generate_completion(model, prompt, enc, max_new_tokens):
    """Run ``inference`` on ``prompt`` with the settings used for profile evaluation.

    Uses temperature 0 (presumably greedy decoding — confirm in ``inference``) and
    stop token 198, which is "\\n" in the GPT-2 encoding that ``main`` loads.
    """
    return inference(
        model=model,
        input_text=prompt,
        tokenizer=enc,
        max_new_tokens=max_new_tokens,
        stop_token=198,
        temperature=0,
    )

def _record_outcomes(details, binary_results, attributes_to_test, outcomes):
    """Commit one profile's per-attribute outcomes to running counts and binary columns.

    ``outcomes`` maps attribute -> bool (hit/miss). Attributes absent from it get a
    "" entry and are not counted, so every column grows by exactly one entry.
    """
    for attr in attributes_to_test:
        if attr in outcomes:
            hit = outcomes[attr]
            details[attr]["total"] += 1
            if hit:
                details[attr]["success"] += 1
            binary_results[attr].append("1" if hit else "0")
        else:
            binary_results[attr].append("")

def evaluate_profiles(profiles, model, enc, attributes_to_test, desc, new_model=None, sample_size=200, max_new_tokens=10000):
    """Measure how often a model's completion of a person's name mentions their attributes.

    Up to ``sample_size`` profiles are drawn at random (without replacement) from
    ``profiles``; if there are no more than that, all are used in their original order.
    Each prompt is the profile's ``full_name`` (or "" if missing). For every attribute
    present in the profile, a hit means ``check_attribute_coverage`` finds the true
    value in the output.

    Args:
        profiles: list of profile dicts.
        model: model passed to ``inference``.
        enc: tokenizer passed to ``inference``.
        attributes_to_test: attribute keys to score, in report order.
        desc: progress-bar label.
        new_model: optional second model scored on the same prompts.
        sample_size: maximum number of profiles to evaluate.
        max_new_tokens: generation budget per prompt.

    Returns:
        ``(details, total_profiles, binary_model, binary_new_model, details_new)``:
        ``details``/``details_new`` map attribute -> {"total", "success"} counts
        (``details_new`` is None without ``new_model``). ``binary_*`` map attribute
        -> one entry per evaluated profile: "1" hit, "0" miss, "" when the attribute
        is missing or the profile failed; ``binary_new_model`` lists stay empty
        without ``new_model``. ``total_profiles`` counts every sampled profile,
        including ones whose evaluation raised.
    """
    has_new_model = new_model is not None
    cumulative_details = {attr: {"total": 0, "success": 0} for attr in attributes_to_test}
    cumulative_details_new = {attr: {"total": 0, "success": 0} for attr in attributes_to_test} if has_new_model else None
    binary_results_model = {attr: [] for attr in attributes_to_test}
    binary_results_new_model = {attr: [] for attr in attributes_to_test}
    total_profiles = 0

    if sample_size < len(profiles):
        sampled_profiles = random.sample(profiles, sample_size)
    else:
        sampled_profiles = profiles

    for profile in tqdm(sampled_profiles, desc=desc):
        total_profiles += 1
        prompt = ""
        # Outcomes are gathered first and committed afterwards, so an exception part
        # way through a profile can never leave counters and binary columns out of step.
        outcomes_model, outcomes_new = {}, {}

        try:
            prompt = profile.get("full_name", "")
            true_values = {attr: profile[attr] for attr in attributes_to_test if attr in profile}

            response = _generate_completion(model, prompt, enc, max_new_tokens)
            logging.info(f"Prompt: {prompt}")
            logging.info(f"Model Output: {response}")
            logging.info(str(profile))

            new_response = None
            if has_new_model:
                new_response = _generate_completion(new_model, prompt, enc, max_new_tokens)
                logging.info(f"New Model Output: {new_response}")

            for attr, true_value in true_values.items():
                outcomes_model[attr] = check_attribute_coverage(response, true_value)
                if has_new_model:
                    outcomes_new[attr] = check_attribute_coverage(new_response, true_value)

        except Exception as e:
            logging.error(f"Error processing {prompt}: {str(e)}")
            # Record the profile as "no data" rather than dropping it, keeping the
            # binary columns positionally aligned across attributes and models.
            outcomes_model, outcomes_new = {}, {}

        _record_outcomes(cumulative_details, binary_results_model, attributes_to_test, outcomes_model)
        if has_new_model:
            _record_outcomes(cumulative_details_new, binary_results_new_model, attributes_to_test, outcomes_new)

    return cumulative_details, total_profiles, binary_results_model, binary_results_new_model, cumulative_details_new

def _log_attribute_accuracy(header, total_profiles, accuracy_title, details):
    """Log a section header, the profile count and each attribute's success rate.

    ``details`` maps attribute -> {"total": int, "success": int}; an attribute with
    no scored profiles is reported as 0/0 (0.00%).
    """
    logging.info(header)
    logging.info(f"Total Profiles: {total_profiles}")
    logging.info(accuracy_title)
    for attr, stats in details.items():
        total = stats["total"]
        success = stats["success"]
        accuracy = success / total if total > 0 else 0
        logging.info(f"{attr.capitalize()}: {success}/{total} ({accuracy:.2%})")

def _write_binary_record(record_file, attributes_to_test, binary_model, binary_new_model, has_new_model):
    """Write each attribute's space-separated per-profile results for both models to ``record_file``."""
    with open(record_file, 'w', encoding='utf-8') as rf:
        for attr in attributes_to_test:
            rf.write(f"Attribute: {attr} - Original Model Binary Results:\n")
            rf.write(" ".join(binary_model[attr]) + "\n")
            if has_new_model and binary_new_model[attr]:
                rf.write(f"Attribute: {attr} - New Model Binary Results:\n")
                rf.write(" ".join(binary_new_model[attr]) + "\n")
            rf.write("\n")
    logging.info(f"Binary comparison results written to {record_file}")

def _load_qa_sample(sft_path, sample_size=200):
    """Load QA pairs from ``sft_path`` and return at most ``sample_size`` of them at random.

    Returns an empty list (after logging the error) if the file cannot be read.
    """
    try:
        qa_pairs = load_qa_from_file(sft_path)
    except Exception as e:
        logging.error(f"Cannot Read {sft_path}: {str(e)}")
        return []
    sample_size = min(sample_size, len(qa_pairs))
    sampled = random.sample(qa_pairs, sample_size) if sample_size < len(qa_pairs) else qa_pairs
    logging.info(f"Loaded {len(qa_pairs)} QA pairs, sampled {sample_size} for evaluation.")
    return sampled

def _evaluate_qa(model, enc, qa_pairs):
    """Prompt ``model`` with each question and count outputs that contain the true answer.

    Returns ``(total, success)``. A pair whose evaluation raises is logged and counted
    in ``total`` but not in ``success``.
    """
    total_qa = 0
    success_qa = 0
    for question, true_answer in tqdm(qa_pairs, desc="Evaluating QA pairs"):
        # Same "Q: ... A:" layout that load_qa_from_file parses.
        prompt = f"Q: {question} A:"
        total_qa += 1

        try:
            response = inference(
                model=model,
                input_text=prompt,
                tokenizer=enc,
                max_new_tokens=10000,
                stop_token=198,
                temperature=0,
            )
            logging.info(f"QA Prompt: {prompt}")
            logging.info(f"Model Output: {response}")
            logging.info(f"True Answer: {true_answer}")

            if check_attribute_coverage(response, true_answer):
                success_qa += 1

        except Exception as e:
            logging.error(f"Error processing QA {prompt}: {str(e)}")
    return total_qa, success_qa

def main():
    """Command-line entry point: evaluate attribute recall and, optionally, SFT QA accuracy.

    1. Score ``--model`` (and ``--new_model`` if given) on a sample of ``--file_path``
       profiles, writing per-profile binary results to ``<log_dir>/record.txt``.
    2. Score ``--model`` on a sample of ``--new_file_path`` profiles.
    3. If ``--sft_path`` is given, score ``--model`` on a sample of its QA pairs.

    Logs go to ``<log_dir>/log.txt`` and stdout. Failure to read the first profile file
    aborts the run; a problem with the second file is logged and step 3 still runs.
    """
    parser = argparse.ArgumentParser(description="Attribute Prediction Evaluation with GPT")
    parser.add_argument("-m", "--model", type=str, required=True, help="Load model from this path")
    parser.add_argument("-nm", "--new_model", type=str, help="Load new model from this path for comparison")
    parser.add_argument("-f", "--file_path", type=str, required=True, help="Path to profiles JSON file")
    parser.add_argument("-nf", "--new_file_path", type=str, required=True, help="Path to profiles new JSON file")
    parser.add_argument("-s", "--sft_path", type=str, help="Path to SFT QA file")
    parser.add_argument("-l", "--log_dir", type=str, required=True, help="Directory for log and record files")
    parser.add_argument("--seed", type=int, default=None, help="Seed for random sampling of profiles and QA pairs")
    args = parser.parse_args()

    # Without --seed the sampling stays unseeded, as before.
    if args.seed is not None:
        random.seed(args.seed)

    # Both the log FileHandler and the record writer need the directory to exist.
    os.makedirs(args.log_dir, exist_ok=True)
    log_file = os.path.join(args.log_dir, "log.txt")
    record_file = os.path.join(args.log_dir, "record.txt")
    setup_logging(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT.from_pretrained(args.model, device)
    new_model = GPT.from_pretrained(args.new_model, device) if args.new_model else None
    has_new_model = new_model is not None
    enc = tiktoken.get_encoding("gpt2")

    logging.info("\nModel loaded! Starting attribute prediction evaluation...")

    attributes_to_test = [
        "birth_date", "birth_city", "birth_day", "birth_month", "birth_year",
        "university", "major", "employer", "company_city",
        "pronoun",
    ]

    try:
        all_profiles = get_profiles_from_file(args.file_path)
    except Exception as e:
        logging.error(f"Cannot read {args.file_path}: {str(e)}")
        return

    if not all_profiles:
        logging.warning("No valid profiles found in the first file.")
        return

    cumulative_details_old, total_profiles_old, binary_model, binary_new_model, cumulative_details_new = evaluate_profiles(
        all_profiles, model, enc, attributes_to_test, "Evaluating old profiles", new_model
    )

    _write_binary_record(record_file, attributes_to_test, binary_model, binary_new_model, has_new_model)

    if total_profiles_old > 0:
        _log_attribute_accuracy(
            f"\nOld Profiles ({args.file_path}) - Original Model:",
            total_profiles_old,
            "\nAttribute Accuracy (Original Model):",
            cumulative_details_old,
        )
    else:
        logging.warning("No valid profiles processed in the old file (Original Model).")

    if has_new_model:
        if total_profiles_old > 0:
            _log_attribute_accuracy(
                f"\nOld Profiles ({args.file_path}) - New Model:",
                total_profiles_old,
                "\nAttribute Accuracy (New Model):",
                cumulative_details_new,
            )
        else:
            logging.warning("No valid profiles processed in the old file (New Model).")

    # A problem with the second profile file is reported but does not end the run:
    # the SFT QA evaluation below is independent of it.
    try:
        all_new_profiles = get_profiles_from_file(args.new_file_path)
    except Exception as e:
        logging.error(f"Cannot read {args.new_file_path}: {str(e)}")
        all_new_profiles = None
    else:
        if not all_new_profiles:
            logging.warning("No valid profiles found in the new file.")

    if all_new_profiles:
        cumulative_details_new_profiles, total_profiles_new, _, _, _ = evaluate_profiles(
            all_new_profiles, model, enc, attributes_to_test, "Evaluating new profiles"
        )
        if total_profiles_new > 0:
            _log_attribute_accuracy(
                f"\nNew Profiles ({args.new_file_path}):",
                total_profiles_new,
                "\nAttribute Accuracy:",
                cumulative_details_new_profiles,
            )
        else:
            logging.warning("No valid profiles processed in the new file.")

    sampled_qa_pairs = _load_qa_sample(args.sft_path) if args.sft_path else []
    total_qa, success_qa = _evaluate_qa(model, enc, sampled_qa_pairs)

    if total_qa > 0:
        accuracy_qa = success_qa / total_qa
        logging.info("\nEvaluation Summary for QA Pairs:")
        logging.info(f"Total QA Pairs: {total_qa}")
        logging.info(f"Correct Predictions: {success_qa}")
        logging.info(f"Overall Accuracy: {accuracy_qa:.2%}")
    else:
        logging.warning("No valid QA pairs processed.")

if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not when it is imported.
    main()