import requests
import re
from pathlib import Path
from openai import OpenAI
from config.keys import get_openai_key, LlamaClient
from utils.prompts import generate_rule_verification_prompt
from utils.data_wrangling import calculate_accuracy_stats
import logging
import os 
from datetime import datetime
import pandas as pd
import pickle

class RuleVerifier:
    """Check an OpenAI model's answers about logic rules, with Llama as judge.

    For each rule, a prompt is built by ``generate_rule_verification_prompt``,
    sent to the OpenAI chat model, and the answer is then handed to a Llama
    client together with a verification prompt. A rule counts as correct when
    either the model's answer matches the expected answer directly or the
    Llama judge approves it.
    """

    # Shared prefix for the verdict patterns: either a negation (captured in
    # the ``neg`` group) or a plain word boundary. Searching with these
    # patterns finds the *first* verdict word in a response and tells us
    # whether it was negated ("incorrect", "not equivalent", "isn't correct").
    _NEGATION_PREFIX = r"(?:(?P<neg>\bnot\s+|n't\s+|\bnon[\s-]?|\bin)|\b)"
    _EQUIVALENT_VERDICT = re.compile(_NEGATION_PREFIX + r"equivalen(?:t|ce)\b")
    _CORRECT_VERDICT = re.compile(_NEGATION_PREFIX + r"correct(?:ness)?\b")

    def __init__(self):
        """Create the logger and the OpenAI / Llama clients."""
        self.logger = self._create_logger()
        self.openai_client = OpenAI(api_key=get_openai_key())
        self.model_name = "o4-mini"
        # Both attribute names refer to the same client object.
        self.llama_client = LlamaClient(self.logger)
        self.llama_endpoint = self.llama_client
        self.logger.info(f' We are working with {self.model_name} as our backend.')

    def _create_logger(self):
        """Return this module's logger, writing to a fresh timestamped file.

        The file is created under ``logs/``. Handlers left on the logger by a
        previous call are removed *and closed* so that creating several
        verifiers does not leak open log files or duplicate every record.
        """
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_filepath = os.path.join(log_dir, f"logger_{timestamp}.log")

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        # Iterate over a copy: removeHandler mutates logger.handlers.
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()

        # Explicit UTF-8: prompts and LLM responses may contain non-ASCII text.
        fh = logging.FileHandler(log_filepath, encoding="utf-8")
        fh.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s  - %(filename)s - %(module)s.%(funcName)s -  %(message)s')
        fh.setFormatter(formatter)
        logger.addHandler(fh)
        return logger

    def verify_rule(self, rule):
        """Run the full verification pipeline for a single rule.

        Args:
            rule: The rule text passed to ``generate_rule_verification_prompt``.

        Returns:
            A dict with the rule, its type, the raw OpenAI and Llama
            responses, the correctness flag, the OpenAI prompt and the model
            name.

        Raises:
            ValueError: If the Llama client returns an empty response.
        """
        prompt_info = generate_rule_verification_prompt(rule)

        # Get OpenAI response
        openai_response = self._ask_openai(prompt_info["prompt"])
        normalized_model = self._normalize_response(openai_response)

        # Get Llama verification
        llama_prompt = self._build_llama_prompt(prompt_info, openai_response)
        llama_response = self._verify_with_llama(llama_prompt)
        normalized_llama = self._normalize_response(llama_response)

        # Check correctness
        is_correct = self._check_correctness(
            prompt_info,
            openai_response,
            llama_response
        )

        # Detailed logging
        self.logger.info("\n\n")  # Two empty lines for separation
        self.logger.info("========================= Rule Verification Details Start ====================")
        self.logger.info(f"Rule Type: {prompt_info['type']}")
        self.logger.info("\n--- OpenAI Prompt ---")
        self.logger.info(prompt_info["prompt"])
        self.logger.info("\n--- OpenAI Response ---")
        self.logger.info(f"Original: {openai_response}")
        self.logger.info(f"Normalized: {normalized_model}")
        self.logger.info("\n--- Llama Prompt ---")
        self.logger.info(llama_prompt)
        self.logger.info("\n--- Llama Response ---")
        self.logger.info(f"Original: {llama_response}")
        self.logger.info(f"Normalized: {normalized_llama}")
        self.logger.info("\n--- Verification Result ---")
        self.logger.info(f"Correctness: {'CORRECT' if is_correct else 'INCORRECT'}")
        self.logger.info("================ End of Verification =============")
        self.logger.info("\n\n")  # Two empty lines for separation

        return {
            "rule": rule,
            "type": prompt_info["type"],
            "openai_response": openai_response,
            "verification": llama_response,
            "is_correct": is_correct,
            "openai_input_prompt": prompt_info["prompt"],
            "model": self.model_name
        }

    def _ask_openai(self, prompt):
        """Send ``prompt`` to the OpenAI chat model and return its stripped reply.

        Any failure is logged (with traceback) and reported to the caller as
        the literal string ``"ERROR"`` instead of raising, so a single failed
        request does not abort a batch.
        """
        try:
            response = self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are a logical reasoning assistant."},
                    {"role": "user", "content": prompt}
                ]
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Logged at ERROR level with the traceback; INFO hid real failures.
            self.logger.exception(f"OpenAI API error: {e} for prompt {prompt}")
            return "ERROR"

    def _verify_with_llama(self, llama_prompt):
        """Return the Llama client's response to ``llama_prompt``.

        Raises:
            ValueError: If the client returns an empty / falsy result.
        """
        result = self.llama_client.generate(llama_prompt)
        if not result:
            raise ValueError(f"Empty verification response from Llama for {llama_prompt}.")
        return result

    def _build_llama_prompt(self, prompt_info, model_response):
        """Build the judge prompt for Llama, depending on the rule type.

        Args:
            prompt_info: Dict with at least ``"type"`` and ``"prompt"``;
                ``"expected_head"`` is also read for ``"definite_rule"``.
            model_response: The OpenAI model's raw answer.

        Returns:
            The prompt text asking Llama for a bracketed verdict.

        Raises:
            IndexError: If the constraint / fact prompt does not contain the
                marker text this method splits on.
        """
        # NOTE(review): the prompt wording below is sent to Llama verbatim and
        # is deliberately left unchanged (typos and the missing separator
        # before "Append one sentence" included) so results stay comparable.
        if prompt_info["type"] == "definite_rule":
            expected = prompt_info["expected_head"]
            return (
                "VERIFICATION TASK: Rule Head Matching\n\n"
                f"Original Rule Head: {expected}\n"
                f"Model Proposed Head: {model_response}\n\n"
                "Evaluation Criteria:\n"
                "1. Predicate names must match the original.\n"
                "2. Variable names/order must match the original.\n"
                "3. Small formatting differences (spaces, commas, the name fo the preedicate not exactly the same but very similar in meaning) are acceptable. \n\n"
                "If the model's head correctly captures the relationship from the original rule, "
                "respond with exactly: [Equivalent]\n"
                "If there are any substantive differences, respond with exactly: [Not Equivalent]\n"
                "Example: For original 'parent(X,Y)' and proposed 'parent(X,Y)', respond: [Equivalent]"
                "Append one sentence to describe your reasoning."
            )

        elif prompt_info["type"] == "constraint":
            # The constraint body is recovered from the OpenAI prompt text by
            # cutting between two marker phrases of that prompt.
            body = prompt_info['prompt'].split('following atoms are true:')[1].split('Can this')[0].strip()
            return (
                "VERIFICATION TASK: Evaluation of model response\n\n"
                f"Constraint Body: {body}\n"
                f"Model Evaluation: {model_response}\n\n"
                "Evaluation Criteria:\n"
                "1. The correct evaluation from the model should be [Impossible].\n"
                "2. Accept [Impossible] or [Inpossible] (common typo) or any model response that is equivalent in meaning to impossible.\n"
                "If the model's evaluation is essentially correct, "
                "you must start your response : [Correct]\n"
                "If the model's evaluation is incorrect or unclear, "
                "you must start your response :: [Incorrect]"
                "Append one sentence to describe your reasoning. Remember you want to verify if the model has evaluated that this body is impossible."
            )

        else:  # fact
            # Same marker-based extraction as for constraints.
            fact = prompt_info['prompt'].split('statement:\n')[1].split('\n\nIs this')[0].strip()
            return (
                "VERIFICATION TASK: Fact Validation\n\n"
                f"Fact Statement: {fact}\n"
                f"Model Evaluation: {model_response}\n\n"
                "Evaluation Criteria:\n"
                "1. Facts should generally evaluate to [True]\n"
                "2. Accept [True] or [Correct] or any model response that is equivalent in meaning to True.\n\n"
                "If the model's evaluation is essentially correct "
                "you must start your response : [Correct]  .\n"
                "If the model's evaluation is incorrect or unclear, "
                "you must start your response :: [Incorrect]  ."
                "Append one sentence to describe your reasoning."
            )

    def _llama_approves(self, prompt_type, llama_verification):
        """Return True when the first verdict word in Llama's reply is positive.

        ``definite_rule`` replies are scanned for "equivalent"; every other
        type is scanned for "correct". A negated first hit ("not equivalent",
        "incorrect", ...) or no hit at all counts as disapproval.
        """
        text = str(llama_verification or "").lower()
        if prompt_type == "definite_rule":
            pattern = self._EQUIVALENT_VERDICT
        else:
            pattern = self._CORRECT_VERDICT
        verdict = pattern.search(text)
        return verdict is not None and verdict.group("neg") is None

    def _check_correctness(self, prompt_info, model_response, llama_verification):
        """Decide whether the model's answer counts as correct.

        The answer is accepted when it matches the expected answer directly
        *or* when the Llama judge approves it.

        Args:
            prompt_info: Dict with ``"type"`` (and ``"expected_head"`` for
                definite rules).
            model_response: The OpenAI model's raw answer.
            llama_verification: The Llama judge's raw reply.

        Returns:
            bool: True if the answer is considered correct.
        """
        normalized_model = self._normalize_response(model_response)
        # The verdict is read from the raw Llama reply with negation-aware
        # matching: a plain substring test treats "incorrect" as "correct",
        # and the normalizer rewrites "equivalent", which hides that verdict.
        llama_match = self._llama_approves(prompt_info["type"], llama_verification)

        if prompt_info["type"] == "definite_rule":
            expected = self._normalize_response(prompt_info["expected_head"])
            model_match = (
                expected == normalized_model or
                f"({expected})" in normalized_model
            )
            return model_match or llama_match

        elif prompt_info["type"] == "constraint":
            model_match = (
                "impossible" in normalized_model or
                "inpossible" in normalized_model
            )
            return model_match or llama_match

        else:  # fact
            # NOTE(review): substring test, so an answer containing e.g.
            # "not true" also matches here - confirm against the answer
            # format requested by the OpenAI prompt.
            model_match = "true" in normalized_model
            return model_match or llama_match

    def _normalize_response(self, response):
        """Canonicalize a response string for comparison and logging.

        Lower-cases, collapses whitespace, drops every character except
        ``a-z``, ``0-9``, ``_``, parentheses, commas and spaces, applies a
        few word substitutions and removes spaces around ``(``, ``)`` and
        ``,``. Falsy input yields ``""``.
        """
        if not response:
            return ""

        # Convert to lowercase and remove extra spaces
        response = " ".join(str(response).lower().split())

        # Remove all non-alphanumeric except parentheses and commas (for predicates)
        response = re.sub(r'[^a-z0-9_(), ]', '', response)

        # Handle common variations
        response = response.replace("equivalent", "equivalence")  # Standardize
        response = response.replace("correctness", "correct")
        response = response.replace("incorrectness", "incorrect")

        # Remove extra spaces around punctuation
        response = re.sub(r'\s+([(),])', r'\1', response)
        response = re.sub(r'([(),])\s+', r'\1', response)

        return response.strip()

    def verify_rules_file(self, file_path):
        """Verify every rule in a text file and return the results table.

        Blank lines and lines starting with ``%`` are skipped; every other
        line is treated as one rule.

        Args:
            file_path: Path of the rules file (read as UTF-8).

        Returns:
            pandas.DataFrame: One row per rule, built from the dicts returned
            by ``verify_rule``.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            stripped_lines = (line.strip() for line in f)
            rules = [line for line in stripped_lines
                     if line and not line.startswith('%')]

        results = pd.DataFrame([self.verify_rule(rule) for rule in rules])
        calculate_accuracy_stats(results, self.logger)
        self.logger.info(f''' \n=========================\n Here are final results {results}            ''')
        return results

# Usage
if __name__ == "__main__":
    # Rule file to evaluate. The baseline constraint set is the active run;
    # for the NoRA run use "data/world_rules.txt" instead.
    rules_path = "data/world_rules_baseline_constr.txt"

    # Create the output directory *before* the long API run: to_pickle does
    # not create missing directories, so a missing "results/" would otherwise
    # discard the finished run at the very last step.
    os.makedirs("results", exist_ok=True)

    verifier = RuleVerifier()
    results = verifier.verify_rules_file(rules_path)

    # Save the results DataFrame as a timestamped pickle.
    timestamp = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    filename = f"results/result_{timestamp}.pkl"
    results.to_pickle(filename)
