import os
import re
import pickle
import logging
from datetime import datetime
from pathlib import Path
import pandas as pd
from openai import OpenAI
from config.keys import get_openai_key
from utils.task_recall_complete_prompts import generate_recall_and_completion
from utils.data_wrangling import calculate_accuracy_stats
from sklearn.metrics import precision_recall_fscore_support
from math import sqrt
class ReCallAndCompletionTask:
    def __init__(self, data_dir, model_name="o4-mini"):
        """Set up one evaluation run rooted at ``data_dir``.

        Args:
            data_dir: Directory holding the input pickle for the run. Logs
                are written to ``<data_dir>/logs`` and reports to
                ``<data_dir>/results``.
            model_name: Name of the chat-completions model to query.
                Defaults to ``"o4-mini"``, the model that used to be
                hard-coded here, so existing callers are unaffected.
        """
        self.data_dir = data_dir
        # One timestamp per instance: it tags the log file and both report
        # files so the outputs of a run can be matched up afterwards.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Must come after data_dir and timestamp, which the logger path uses.
        self.logger = self._create_logger()
        self.openai_client = OpenAI(api_key=get_openai_key())
        self.model_name = model_name
        self.logger.info(f'Working with {self.model_name} as our backend.')
        # Failure counters, incremented by _ask_openai / verify_result.
        self.llm_api_errors = 0
        self.missing_query_label = 0
        self.missing_rules_used = 0

    def _create_logger(self):
        log_dir = os.path.join(self.data_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        log_filename = f"logger_{self.timestamp}.log"
        log_path = os.path.join(log_dir, log_filename)

        logging.basicConfig(filename=log_path,
                            level=logging.INFO,
                            format='%(asctime)s %(levelname)s %(message)s')
        return logging.getLogger("ReCallAndCompletionTask")

    def _ask_openai(self, prompt):
        try:
            response = self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are a logical reasoning assistant."},
                    {"role": "user", "content": prompt}
                ]
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            self.logger.info(f"OpenAI API error: {e} for prompt: {prompt}")
            self.llm_api_errors += 1
            return "ERROR"

    def verify_result(self, openai_response, correct_completion, correct_recall):
        '''
        Score one LLM response against the ground truth for both tasks.

        Task 1: Query Completion

            Given a story (a collection of entity pairs and their predicates) and a query (a pair of entities), the LLM is asked to predict the correct predicate (i.e., the relationship) between the query entities.

            The model's prediction is compared to the ground truth label in the dataset (query_relation).

            The outcome is binary: success if the predicted predicate matches the correct one, and failure otherwise.

            The overall success rate and its 95% confidence interval (using a binomial approximation) are reported.

        Task 2: Rule Recall

            Each example also has a known set of world rules that are required to derive the query predicate (these are logic rules indexed in the dataset).

            The LLM is asked to return a set of rule indices it believes were used to derive the answer to Task 1.

            The predicted rule set is compared against the ground truth using precision, recall, and F1-score, defined as:

                Precision = True Positives / (True Positives + False Positives)

                Recall = True Positives / (True Positives + False Negatives)

                F1-score = Harmonic mean of precision and recall

            The report provides the mean and 95% confidence intervals for these metrics across all examples.

        Columns in rule_recall_summary_<timestamp>.pkl

            rule_recall_precision_mean: Average precision score across all examples for Task 2.

            rule_recall_precision_ci: 95% confidence interval for the precision score, estimated using the standard error.

            rule_recall_recall_mean: Average recall score across all examples for Task 2.

            rule_recall_recall_ci: 95% confidence interval for the recall score.

            rule_recall_fscore_mean: Average F1-score across all examples for Task 2.

            rule_recall_fscore_ci: 95% confidence interval for the F1-score.

            query_completion_success_rate: Proportion of examples where the model correctly completed the query predicate in Task 1.

            query_completion_success_rate_ci: 95% confidence interval for the success rate in Task 1, using the binomial approximation.

            llm_api_error_rate: Proportion of examples where the OpenAI API call failed (e.g., network error, unsupported parameters).

            missing_query_label_rate: Proportion of examples where the LLM response did not contain a recognizable query_label: line.

            missing_rules_used_rate: Proportion of examples where the LLM response did not contain a recognizable rules_used: line.

        Args:
            openai_response: Raw text returned by the LLM. The label is read
                from a ``query_label: <label>`` line and the rules from a
                ``rules_used: {i, j, ...}`` span.
            correct_completion: Ground-truth predicate; compared to the
                predicted label by exact string equality.
            correct_recall: Iterable of ground-truth rule indices.

        Returns:
            dict with keys ``llm_label`` (str), ``llm_rules`` (set of int),
            ``success_label`` (bool), ``precision``, ``recall``, ``f_score``
            (floats, 0.0 when the denominator is zero), ``reasoning`` (text
            after ``reasoning:``) and ``raw_response``.

        Side effects:
            Increments ``self.missing_query_label`` / ``self.missing_rules_used``
            when the corresponding marker is absent from the response.
        '''
        label_match = re.search(r"query_label:\s*(.*)", openai_response)
        rules_match = re.search(r"rules_used:\s*\{([^}]*)\}", openai_response)

        if label_match:
            predicted_label = label_match.group(1).strip()
        else:
            self.missing_query_label += 1
            predicted_label = ""

        if rules_match:
            # Pull out the integers rather than int()-ing every comma-separated
            # token: an empty set ("{}") or decorated tokens ("rule 3") used to
            # raise ValueError and abort the whole evaluation run.
            predicted_rules = {
                int(token) for token in re.findall(r"-?\d+", rules_match.group(1))
            }
        else:
            self.missing_rules_used += 1
            predicted_rules = set()

        success_label = predicted_label == correct_completion

        true_rules = set(correct_recall)
        tp = len(predicted_rules & true_rules)
        fp = len(predicted_rules - true_rules)
        fn = len(true_rules - predicted_rules)

        # Each metric falls back to 0.0 when its denominator is zero.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            "llm_label": predicted_label,
            "llm_rules": predicted_rules,
            "success_label": success_label,
            "precision": precision,
            "recall": recall,
            "f_score": f_score,
            "reasoning": self._extract_reasoning(openai_response),
            "raw_response": openai_response
        }

    def _extract_reasoning(self, response):
        """Return the text following the first ``reasoning:`` marker.

        Everything after the marker up to the end of ``response`` is taken
        (newlines included) and stripped of surrounding whitespace. An empty
        string is returned when the marker is absent.
        """
        found = re.search(r"reasoning:\s*(.*)", response, re.DOTALL)
        if found is None:
            return ""
        reasoning_text = found.group(1)
        return reasoning_text.strip()

    def solve_examples(self, num_rows=12):
        results_dir = os.path.join(self.data_dir, "results")
        os.makedirs(results_dir, exist_ok=True)

        path = os.path.join(self.data_dir, "derivations_to_query.pkl")
        df = pd.read_pickle(path)
        if num_rows:
            sample_num = min(num_rows, df.shape[0])
            df = df.sample(n=sample_num, random_state=42).reset_index(drop=True)

        detailed_report = []
        for idx, row in df.iterrows():
            info = generate_recall_and_completion(row, self.data_dir)
            response = self._ask_openai(info['prompt'])
            self.logger.info(f"Example Info: source_file={row['source_file']}, story_index={row['story_index']}")
            self.logger.info(f"OPen AI Response:\n{response}")
            self.logger.info(f"Correct Response should be:\n{info['correct_completion']} \n {info['correct_recall']} \n {row['derivation_chain']} \n \n ")
            results = self.verify_result(response, info['correct_completion'], info['correct_recall'])

            detailed_report.append({
                "story": row['story_edges'],
                "story_index": row['story_index'],
                "source_file": row['source_file'],
                "query_edge": row['query_edge'],
                "query_relation": row['query_relation'],
                "set_of_world_used": row['set_of_world_used'],
                "llm_label": results['llm_label'],
                "llm_rules": results['llm_rules'],
                "llm_response": response,
                "input_prompt_to_llm": info['prompt'],
                "reasoning": results['reasoning'],
                "precision": results['precision'],
                "recall": results['recall'],
                "f_score": results['f_score'],
                "ASPDerivation":row['derivation_chain'],
                "successful_query_completion": results['success_label']
            })

        df_report = pd.DataFrame(detailed_report)
        detailed_report_path = os.path.join(results_dir, f"detailed_report_{self.timestamp}.pkl")
        df_report.to_pickle(detailed_report_path)

        # Global statistics
        success_rate = df_report['successful_query_completion'].mean()
        success_ci = 1.96 * sqrt(success_rate * (1 - success_rate) / len(df_report))
        prec_mean = df_report['precision'].mean()
        prec_sd = df_report['precision'].std()
        recall_mean = df_report['recall'].mean()
        recall_sd = df_report['recall'].std()
        f1_mean = df_report['f_score'].mean()
        f1_sd = df_report['f_score'].std()

        global_report = pd.DataFrame([{
            "rule_recall_precision_mean": prec_mean,
            "rule_recall_precision_sd": prec_sd,
            "rule_recall_precision_ci": 1.96 * prec_sd / (len(df_report) ** 0.5),
            "rule_recall_recall_mean": recall_mean,
            "rule_recall_recall_sd": recall_sd,
            "rule_recall_recall_ci": 1.96 * recall_sd / (len(df_report) ** 0.5),
            "rule_recall_fscore_mean": f1_mean,
            "rule_recall_fscore_sd": f1_sd,
            "rule_recall_fscore_ci": 1.96 * f1_sd / (len(df_report) ** 0.5),
            "query_completion_success_rate": success_rate,
            "query_completion_success_rate_ci": success_ci,
            "llm_api_error_rate": self.llm_api_errors / len(df_report),
            "missing_query_label_rate": self.missing_query_label / len(df_report),
            "missing_rules_used_rate": self.missing_rules_used / len(df_report),
            "model-used": self.model_name
        }])

        global_report_path = os.path.join(results_dir, f"rule_recall_summary_{self.timestamp}.pkl")
        global_report.to_pickle(global_report_path)

        self.logger.info(f" \n GLOBAL REPORT:\n"
                 f"Success rate: {success_rate*100:.2f}% ± {success_ci*100:.2f}%\n"
                 f"Precision: {prec_mean:.3f} ± {1.96 * prec_sd / (len(df_report) ** 0.5):.3f}\n"
                 f"Recall: {recall_mean:.3f} ± {1.96 * recall_sd / (len(df_report) ** 0.5):.3f}\n"
                 f"F1: {f1_mean:.3f} ± {1.96 * f1_sd / (len(df_report) ** 0.5):.3f}")
        
        self.logger.info(f"LLM API Errors: {self.llm_api_errors}")
        self.logger.info(f"Missing query_label in response: {self.missing_query_label}")
        self.logger.info(f"Missing rules_used in response: {self.missing_rules_used}")
        print("Evaluation complete.")

def find_example_dirs(base_dir="ASPdata", opec_values=None):
    """
    Find all example directories that contain the required pickle files.

    Args:
        base_dir (str): The root directory to search in (default: "ASPdata").
        opec_values (list): OPEC values to look for (default: [0, 3, 4, 5]).

    Returns:
        list: Absolute paths, as strings, of the valid example directories.
        Directories are grouped in the order of ``opec_values`` and sorted
        by path (lexicographically) within each OPEC directory, so the
        result does not depend on the file system's listing order.

    The function looks for directories matching
    {base_dir}/OPEC{i}_examples/chain_len*
    and keeps those that contain all four required pickle files:
    - derivations_to_query.pkl
    - world_rule_body_index.pkl
    - world_rule_head_index.pkl
    - world_rule_index.pkl

    A warning is printed for every missing OPEC directory and for every
    chain_len* directory that lacks some of the required files.
    """
    if opec_values is None:
        opec_values = [0, 3, 4, 5]

    example_dirs = []
    required_files = {
        'derivations_to_query.pkl',
        'world_rule_body_index.pkl',
        'world_rule_head_index.pkl',
        'world_rule_index.pkl'
    }
    base_dir = Path(base_dir).absolute()
    for i in opec_values:
        # Construct the OPEC examples directory path
        opec_dir = base_dir / f"OPEC{i}_examples"

        if not opec_dir.exists():
            print(f"Warning: Directory {opec_dir} does not exist")
            continue

        # Find all chain_len* subdirectories. glob() yields entries in
        # arbitrary order, so sort them to make the result reproducible.
        for chain_dir in sorted(opec_dir.glob("chain_len*")):
            if not chain_dir.is_dir():
                continue

            # Check if all required files exist
            files_present = {f.name for f in chain_dir.glob("*.pkl")}
            missing_files = required_files - files_present
            if missing_files:
                print(f"Warning: Directory {chain_dir} is missing files: {missing_files}")
            else:
                example_dirs.append(str(chain_dir.absolute()))

    return example_dirs

if __name__ == "__main__":
    # Single-directory smoke test (disabled):
    #     p_dir = 'ASPdata/OPEC0_examples/chain_len2.0'
    #     verifier = ReCallAndCompletionTask(p_dir)
    #     verifier.solve_examples(num_rows=2)

    # Discover the example directories for OPEC 0 and 3 and show them as found.
    example_dirs = find_example_dirs(opec_values=[0, 3])
    print(example_dirs)

    def _chain_length(dir_path):
        """Numeric suffix of the last 'chain_len<number>' component of a path."""
        suffix = dir_path.split('chain_len')[-1]
        return float(suffix.split('/')[0])

    # Show them again, ordered by ascending chain length.
    example_dirs = sorted(example_dirs, key=_chain_length)
    print(example_dirs)

    # Full batch run over every discovered directory (disabled):
    #     for p_dir in example_dirs:
    #         print(f"\nProcessing directory: {p_dir}")
    #         try:
    #             verifier = ReCallAndCompletionTask(p_dir)
    #             verifier.solve_examples()
    #         except Exception as e:
    #             print(f"Error processing {p_dir}: {str(e)}")
    #     print(example_dirs)
    #     print(str(example_dirs[0]))