import argparse
import torch
import tiktoken
from model import GPT, inference
from difflib import SequenceMatcher
import re
import os
from datetime import datetime

def load_list(file_path):
    """Read a keyword file into a list of lower-cased, stripped lines.

    Blank lines are skipped and file order is preserved. The file is decoded
    as UTF-8 so the result does not depend on the platform's default locale
    encoding, matching the other file readers in this module.

    Args:
        file_path: Path of the text file, one keyword per line.

    Returns:
        A list of non-empty lower-cased lines, or an empty list when the
        file does not exist.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return [line.strip().lower() for line in f if line.strip()]
    except FileNotFoundError:
        return []

def load_companies(file_path):
    """Read a comma-separated company file into two parallel keyword lists.

    Each non-blank line is split on commas. Lines with at least two fields
    contribute their first field to the company list and every remaining
    field to the extra-information list; lines with a single field are
    ignored. Fields are kept exactly as written (no case folding and no
    whitespace trimming beyond stripping the line itself).

    Args:
        file_path: Path of the CSV-like file.

    Returns:
        A ``(companies, extra)`` tuple of lists. When the file does not
        exist both lists are empty, so callers can always unpack two values.
    """
    companies = []
    extra = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                parts = stripped.split(',')
                if len(parts) >= 2:
                    companies.append(parts[0])
                    extra.extend(parts[1:])
    except FileNotFoundError:
        # Keep the two-list shape even on failure: returning a single empty
        # list would break ``companies, extra = load_companies(...)``.
        return [], []
    return companies, extra

# Keyword vocabularies consulted by check_answer_coverage; loaded once at
# import time.
cities, universities, majors = (
    load_list(path)
    for path in ('data/cities.txt', 'data/universities.txt', 'data/majors.txt')
)
companies, extra = load_companies('data/companies.csv')
# Drop duplicate extra-information keywords so each is counted at most once.
extra = [*set(extra)]
def extract_birth_date(text):
    """Return the first "<month name> <day>[,] <year>" date found in ``text``.

    The text is split on ``'. '`` and each piece is searched
    case-insensitively (by lower-casing it) for a full English month name
    followed by a one- or two-digit day, an optional comma, and a four-digit
    year.

    Returns:
        A ``datetime.date`` for the first piece that contains a match, or
        ``None`` when nothing matches. If that first match is not a real
        calendar date, a message is printed and ``None`` is returned without
        looking at later pieces.
    """
    months = r'(january|february|march|april|may|june|july|august|september|october|november|december)'
    date_regex = re.compile(rf'{months}\s+(\d{{1,2}})(?:,)?\s+(\d{{4}})')

    for piece in text.split('. '):
        piece = piece.strip()
        if not piece:
            continue
        found = date_regex.search(piece.lower())
        if found is None:
            continue
        month, day, year = found.groups()
        date_str = f"{month} {day}, {year}"
        try:
            parsed = datetime.strptime(date_str, '%B %d, %Y')
        except ValueError:
            print(f"Failed to parse birth date: {date_str}")
            return None
        return parsed.date()
    return None
def _count_keyword_hits(keywords, response_text, reference_text):
    """Count keywords found in the response and, of those, in the reference.

    Returns a ``(total, success)`` pair: ``total`` is the number of keywords
    that occur as substrings of ``response_text``; ``success`` is how many of
    those also occur in ``reference_text``. Empty keywords are skipped
    because the empty string is a substring of every text and would count
    as an automatic hit.
    """
    total = 0
    success = 0
    for keyword in keywords:
        if not keyword:
            continue
        if keyword in response_text:
            total += 1
            if keyword in reference_text:
                success += 1
    return total, success


def check_answer_coverage(response, second_half, keyword_coverage_threshold=0.85):
    """Score how many facts mentioned in ``response`` also appear in ``second_half``.

    Every known city, university, major, company, and company-extra keyword
    that occurs in the response is counted; it is a success when the same
    keyword also occurs in ``second_half``. City, university, and major
    keywords are matched against lower-cased text, while company and
    company-extra keywords are matched case-sensitively against the raw
    text. A birth date extracted from the response counts as one more item
    and succeeds when ``second_half`` yields the same date.

    Per-category counts and the overall ratio are printed for non-empty
    input.

    Args:
        response: Generated text to evaluate.
        second_half: Reference text the response is compared against.
        keyword_coverage_threshold: Minimum ``success / total`` ratio for the
            response to be considered correct.

    Returns:
        A dict ``{"result": bool, "details": {...}}`` where ``details`` maps
        each category name to ``{"total": int, "success": int}``. The same
        shape is returned for empty input (with ``result`` False and all
        counts zero) so callers can always index the result.
    """
    categories = (
        "cities", "universities", "majors",
        "companies", "company_extra", "birth_date",
    )
    has_sentence = bool(response) and any(
        s.strip() for s in response.split('. ')
    )
    if not has_sentence or not second_half:
        # Nothing to evaluate: report a failure with zeroed statistics
        # instead of a bare False, which callers cannot subscript.
        return {
            "result": False,
            "details": {name: {"total": 0, "success": 0} for name in categories},
        }

    response_lower = response.lower()
    second_half_lower = second_half.lower()

    city_total, city_success = _count_keyword_hits(
        cities, response_lower, second_half_lower)
    univ_total, univ_success = _count_keyword_hits(
        universities, response_lower, second_half_lower)
    major_total, major_success = _count_keyword_hits(
        majors, response_lower, second_half_lower)
    # Company keywords are not lower-cased when loaded, so they are compared
    # against the original-case texts.
    company_total, company_success = _count_keyword_hits(
        companies, response, second_half)
    extra_total, extra_success = _count_keyword_hits(
        extra, response, second_half)

    response_birth_date = extract_birth_date(response)
    second_half_birth_date = extract_birth_date(second_half)
    birth_date_match = (
        response_birth_date is not None
        and response_birth_date == second_half_birth_date
    )
    birth_total = 1 if response_birth_date else 0
    birth_success = 1 if birth_date_match else 0

    total_items = (city_total + univ_total + major_total
                   + company_total + extra_total + birth_total)
    success_items = (city_success + univ_success + major_success
                     + company_success + extra_success + birth_success)
    # A response that mentions no known item scores 0 rather than dividing by zero.
    coverage_ratio = success_items / total_items if total_items > 0 else 0

    print(f"Cities: {city_success}/{city_total}")
    print(f"Universities: {univ_success}/{univ_total}")
    print(f"Majors: {major_success}/{major_total}")
    print(f"Companies: {company_success}/{company_total}")
    print(f"Companiy Extra Info: {extra_success}/{extra_total}")
    print(f"Birth Date Match: {birth_date_match} (Response: {response_birth_date}, Second Half: {second_half_birth_date})")
    print(f"Coverage ratio: {coverage_ratio:.3f}")

    details = {
        "cities": {"total": city_total, "success": city_success},
        "universities": {"total": univ_total, "success": univ_success},
        "majors": {"total": major_total, "success": major_success},
        "companies": {"total": company_total, "success": company_success},
        "company_extra": {"total": extra_total, "success": extra_success},
        "birth_date": {"total": birth_total, "success": birth_success},
    }
    return {
        "result": coverage_ratio >= keyword_coverage_threshold,
        "details": details,
    }

def load_list(file_path):
    """Read a UTF-8 text file into a set of lower-cased, stripped lines.

    Blank lines are ignored. When the file cannot be found, an error message
    is printed and an empty set is returned.

    Note: this definition replaces the earlier ``load_list`` in this module
    for every call made after this point, and it returns a set, not a list.
    """
    entries = set()
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                cleaned = raw_line.strip()
                if cleaned:
                    entries.add(cleaned.lower())
    except FileNotFoundError:
        print(f"Error: Could not find file '{file_path}'")
        return set()
    return entries

# Name vocabularies used by split_sentence to recognise
# "<first> <middle> <last>" sequences; loaded once at import time.
first_names, middle_names, last_names = (
    load_list(path)
    for path in (
        'data/first_names.txt',
        'data/middle_names.txt',
        'data/last_names.txt',
    )
)

def split_sentence(sentence):
    """Find a person's name in ``sentence`` to use as a generation prompt.

    Words are scanned left to right. Whenever a word is a known first name
    it starts a new candidate; the next word is appended if it is a known
    middle name, and the word after that is appended if it is a known last
    name (a trailing ``'s`` is removed first). The scan stops at the first
    complete first/middle/last match. If no complete match exists, the
    candidate started by the last first-name word seen is used, which may be
    only a partial name.

    Returns:
        ``(name, text)`` where ``name`` is the space-joined candidate, or
        ``None`` when no first name was found, and ``text`` is the whole
        stripped sentence. ``(None, None)`` is returned for an empty
        sentence or one with no words.
    """
    if not sentence:
        return None, None
    tokens = sentence.strip('\n').split()
    if not tokens:
        return None, None

    name_parts = []
    for idx, token in enumerate(tokens):
        if token.lower() not in first_names:
            continue
        # Every first-name hit restarts the candidate name.
        name_parts = [token]

        following = tokens[idx + 1:idx + 3]
        if not following or following[0].lower() not in middle_names:
            continue
        name_parts.append(following[0])

        if len(following) < 2:
            continue
        raw_last = following[1]
        candidate = raw_last.lower()
        if candidate.endswith("'s") and candidate[:-2] in last_names:
            # Possessive form: keep the surname without the trailing 's.
            name_parts.append(raw_last[:-2])
            break
        if candidate in last_names:
            name_parts.append(raw_last)
            break

    whole_sentence = sentence.strip()
    if name_parts:
        return ' '.join(name_parts), whole_sentence
    return None, whole_sentence


def get_sentences_from_file(file_path, max_lines=2000):
    """Return the non-empty, stripped lines among the first lines of a file.

    Args:
        file_path: Path of a UTF-8 text file with one sentence per line.
        max_lines: Maximum number of raw lines to read from the start of the
            file. Blank lines count toward this limit even though they are
            not returned. ``None`` reads the whole file.

    Returns:
        A list of stripped, non-empty lines. An empty list is returned, after
        printing a message, when the file is missing or cannot be read.
    """
    from itertools import islice

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            stripped_lines = (raw.strip() for raw in islice(f, max_lines))
            return [line for line in stripped_lines if line]
    except FileNotFoundError:
        print(f"Error: Could not find file '{file_path}'")
        return []
    except Exception as e:
        # Best-effort reader: any other failure (e.g. a decoding error) is
        # reported and treated as "no sentences".
        print(f"Error reading file: {str(e)}")
        return []


def main():
    """Run the sentence-completion evaluation and print a summary.

    For every input sentence in which ``split_sentence`` finds a name, the
    name is given to the model as a prompt, the completion is scored with
    ``check_answer_coverage`` against the full sentence, and per-category
    statistics are accumulated. Overall accuracy and the per-category totals
    are printed at the end.

    Command-line options (both optional; defaults are the previously
    hard-coded paths):
        --model-path: checkpoint passed to ``GPT.from_pretrained``.
        --input-file: text file with one sentence per line.
    """
    parser = argparse.ArgumentParser(description="Sentence Completion Evaluation with GPT")
    parser.add_argument(
        "--model-path",
        default="/data/temp_log3/xs_pretrain_small_3/state_step139976.pt",
        help="Checkpoint file to load the GPT model from.",
    )
    parser.add_argument(
        "--input-file",
        default="hallucinate_small/pretrain_perturbed3.txt",
        help="Text file containing one evaluation sentence per line.",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT.from_pretrained(args.model_path, device)
    enc = tiktoken.get_encoding("gpt2")

    print("\nModel loaded! Starting sentence completion evaluation...")

    all_sentences = get_sentences_from_file(args.input_file)
    if not all_sentences:
        print("No valid sentences found in the file.")
        return

    total_sentences = 0
    correct_predictions = 0
    categories = (
        "cities", "universities", "majors",
        "companies", "company_extra", "birth_date",
    )
    cumulative_details = {name: {"total": 0, "success": 0} for name in categories}

    for sentence in all_sentences:
        first_half, second_half = split_sentence(sentence)
        if not first_half or not second_half:
            # No recognisable name to prompt with; skip without counting.
            continue
        total_sentences += 1
        try:
            response = inference(
                model=model,
                input_text=first_half,
                tokenizer=enc,
                max_new_tokens=100,
                # NOTE(review): presumably the newline token id in the GPT-2
                # encoding, so generation stops at end of line; confirm.
                stop_token=198,
                temperature=0,
            )
            # Score only the generated continuation when the prompt is echoed.
            if response.startswith(first_half):
                response = response[len(first_half):].strip('\n')
            result_dict = check_answer_coverage(response, second_half)
            details = result_dict["details"]
            for category, stats in cumulative_details.items():
                stats["total"] += details[category]["total"]
                stats["success"] += details[category]["success"]

            if result_dict["result"]:
                correct_predictions += 1

        except Exception as e:
            # Deliberately broad: one failing sentence must not abort the
            # whole run. The sentence still counts toward the total.
            print(f"Error processing sentence '{first_half}': {str(e)}")

    if total_sentences > 0:
        accuracy = correct_predictions / total_sentences
        print("\nEvaluation Summary:")
        print(f"Total Sentences: {total_sentences}")
        print(f"Correct Predictions: {correct_predictions}")
        print(f"Accuracy: {accuracy:.2%}")
        for category, stats in cumulative_details.items():
            print(f"{category.capitalize()}: {stats['success']}/{stats['total']}")
    else:
        print("No valid sentences found for processing.")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()