import argparse
import re
from difflib import SequenceMatcher
from itertools import islice

import tiktoken
import torch

from model import GPT, inference

def extract_words(text):
    """Return the set of distinct lowercase words found in *text*.

    The text is lowercased and every character that is neither a word
    character nor whitespace (i.e. punctuation) is deleted before splitting
    on whitespace, so "Paris." and "paris" yield the same word.
    """
    lowered = text.lower()
    # Drop punctuation so it does not stick to neighbouring words.
    without_punct = re.sub(r'[^\w\s]', '', lowered)
    return set(without_punct.split())

def check_answer_coverage(model_answer, expected_answer, threshold=0.9):
    """Decide whether *model_answer* covers enough of *expected_answer*.

    Both strings are reduced to word sets with ``extract_words``. The answer
    is accepted when the fraction of distinct expected words that also occur
    in the model answer is at least *threshold*. An expected answer with no
    words is never considered covered.

    Returns:
        bool: True if the coverage ratio reaches *threshold*, else False.
    """
    produced = extract_words(model_answer)
    target = extract_words(expected_answer)

    # Guard against division by zero when the reference has no words.
    if not target:
        return False

    hits = len(target.intersection(produced))
    return hits / len(target) >= threshold

def main():
    """Evaluate a pretrained GPT checkpoint on a line-oriented Q&A file.

    Each non-blank line of the QA file must look like ``Q: <question> A: <answer>``.
    The question is passed to ``inference`` with ``temperature=0``, and the
    response is scored as correct when ``check_answer_coverage`` accepts it
    against the expected answer at ``--threshold``. A summary with the accuracy
    is printed at the end.

    All command-line arguments are optional. Their defaults reproduce the
    previously hard-coded checkpoint path, QA file and 200-line limit, so
    running the script without arguments behaves as before.
    """
    parser = argparse.ArgumentParser(description="Interactive Q&A with GPT model")
    parser.add_argument(
        "-i", "--model_path", type=str,
        default="/data/temp_log4/xs_pretrain_small_4/state_step046251.pt",
        help="Path to the model checkpoint",
    )
    parser.add_argument(
        "-t", "--text_path", type=str,
        default="hallucinate_small/pretrain4.txt",
        help="Path to QA file (one 'Q: ... A: ...' pair per line)",
    )
    parser.add_argument(
        "-n", "--max_lines", type=int, default=200,
        help="Only read this many lines from the start of the QA file",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.9,
        help="Minimum fraction of expected-answer words the response must contain",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true",
        help="Print per-question details",
    )
    args = parser.parse_args()
    if args.max_lines < 0:
        parser.error("--max_lines must be non-negative")

    # Initialize model and tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model from {args.model_path}...")
    model = GPT.from_pretrained(args.model_path, device)
    enc = tiktoken.get_encoding("gpt2")

    print("\nModel loaded! Starting QA evaluation...")

    # Only the file access is guarded here. Previously a missing file crashed
    # with AttributeError because the handler read an undefined args.text_path.
    try:
        with open(args.text_path, "r", encoding="utf-8") as f:
            # islice stops after max_lines instead of reading the whole file.
            lines = list(islice(f, args.max_lines))
    except FileNotFoundError:
        print(f"Error: Could not find file at {args.text_path}")
        return
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading file: {e}")
        return

    total_questions = 0
    correct_answers = 0

    for line in lines:
        line = line.strip()
        if not line:
            continue

        if not line.startswith("Q:") or "A:" not in line:
            print(f"Skipping malformed line: {line}")
            continue

        # Split on the first "A:" only, so an answer that itself contains "A:"
        # is kept whole. Strip just the leading "Q:" marker rather than every
        # "Q:" occurrence inside the question.
        question_part, _, expected_answer = line.partition("A:")
        question = question_part[len("Q:"):].strip()
        expected_answer = expected_answer.strip()

        # Counted before inference: a question whose processing fails is
        # scored as incorrect rather than dropped.
        total_questions += 1

        try:
            response = inference(
                model=model,
                input_text=question,
                tokenizer=enc,
                max_new_tokens=100,
                stop_token=198,  # GPT-2 BPE id of "\n": stop at the end of the answer line
                temperature=0,
            )

            # The response may echo the prompt; score only the continuation.
            if response.startswith(question):
                response = response[len(question):].strip()

            is_valid = check_answer_coverage(response, expected_answer, args.threshold)
            if is_valid:
                correct_answers += 1

            if args.verbose:
                print(f"\nQuestion: {question}")
                print(f"Model Answer: {response}")
                print(f"Expected Answer: {expected_answer}")
                print(f"Model Words: {extract_words(response)}")
                print(f"Expected Words: {extract_words(expected_answer)}")
                print(f"Result: {'VALID' if is_valid else 'INVALID'}")

        except Exception as e:
            # Best effort: report and continue with the remaining questions.
            print(f"Error processing question '{question}': {e}")

    if total_questions == 0:
        print("No valid questions found in the file.")
        return

    accuracy = correct_answers / total_questions
    print("\nEvaluation Summary:")
    print(f"Total Questions: {total_questions}")
    print(f"Correct Answers: {correct_answers}")
    print(f"Accuracy: {accuracy:.2%}")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()