from utils.myutils import AIClient, errorMessage
import json
from tqdm import tqdm
from collections import defaultdict
import concurrent.futures
import math

class TaskSuccessEvaluator:
    def __init__(self, config, save_folder):
        """Create the evaluator and its model client.

        Args:
            config: Mapping of settings. Must contain "diagnose_model" (the
                model name handed to AIClient). May contain
                "incremental_diagnosis"; the value is treated as enabled when
                its string form equals "true" case-insensitively.
            save_folder: Stored on the instance and passed to errorMessage
                when a diagnosis attempt raises.
        """
        self.config = config
        self.model = config["diagnose_model"]
        self.client = AIClient(model=self.model, max_tokens=128)
        self.save_folder = save_folder
        if "incremental_diagnosis" in config:
            # str() first so a value that is already a bool (e.g. True) is
            # accepted instead of raising AttributeError on .lower().
            self.incremental_diagnosis = str(config["incremental_diagnosis"]).lower() == "true"
        else:
            # Key absent: keep None (falsy), as callers may distinguish
            # "not configured" from an explicit false.
            self.incremental_diagnosis = None
        print("Incremental Diagnosis:", self.incremental_diagnosis)

    def diagnose(self, task, extracted_info, choices, max_retries=3):
        """Ask the model to choose one of ``choices`` for ``task``.

        The request is a system message plus a single user message built from
        the extracted information, the candidate options and the task text.

        Args:
            task: Task/question text placed in the user prompt.
            extracted_info: Information text the decision should be based on.
            choices: Candidate options, interpolated into the prompt as-is.
            max_retries: Maximum number of calls made to the client.

        Returns:
            The first non-empty response with surrounding whitespace removed,
            or the string "Diagnosis failed after multiple attempts." when no
            attempt produced one.
        """
        prompt_pieces = [
            f"Based on the following information: \n\" {extracted_info}\"\n",
            f"and the candidate options: {choices}:\n",
            f"{task}",
            "DO NOT provide further advice or details. Only respond with the exact option name and NOTHING ELSE.",
        ]
        conversation = [
            {
                "role": "system",
                "content": "You are a medical professional, making decisions based on the information provided. ",
            },
            {"role": "user", "content": "".join(prompt_pieces)},
        ]

        for attempt_number in range(1, max_retries + 1):
            try:
                reply = self.client.get_response(conversation)
                # An empty/falsy reply is not an error; it just uses up one attempt.
                if reply:
                    return reply.strip()
            except Exception as exc:
                # Record the failure and move on to the next attempt.
                errorMessage(f"Attempt {attempt_number} diagnosis failed: {str(exc)}", self.save_folder)

        return "Diagnosis failed after multiple attempts."

    def load_log_data(self, log_path):
        """Load a JSON Lines log file into a list of parsed records.

        Each non-blank line of the file is parsed as one JSON document.
        Blank or whitespace-only lines (such as a trailing newline at the end
        of the file) are skipped rather than failing the whole load.

        Args:
            log_path: Path of the UTF-8 encoded log file.

        Returns:
            A list with one parsed object per non-blank line, in file order.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        with open(log_path, 'r', encoding='utf-8') as file:
            return [json.loads(line) for line in file if line.strip()]

    def calculate_accuracy_for_case(self, data, incremental=False):
        """Diagnose one logged case and compare the result with its answer.

        Args:
            data: One parsed log record. Reads the required keys 'task',
                'choices', 'answer' and 'case_id', and the optional keys
                'interactions' (a sequence of (question, answer) pairs),
                'limit_info' (used as the information text when there are no
                interactions) and 'category'.
            incremental: When truthy and the record has interactions, diagnose
                on growing prefixes of the conversation instead of only once.

        Returns:
            A tuple ``(case_id, category, ground_truth, results)`` where
            ``results`` is a list of dicts with keys "round", "diagnosis" and
            "is_successful". In incremental mode there is one entry per
            evaluated prefix; the last entry always covers the complete
            conversation. Otherwise the list has exactly one entry.
        """
        task = data['task']
        choices = data['choices']
        ground_truth = data['answer']
        interactions = data.get('interactions', None)

        def render(pairs):
            # Shared formatting for full and partial conversations.
            return '\n'.join([f"Q: {q}\nA: {a}" for q, a in pairs])

        if incremental and interactions:
            total_rounds = len(interactions)
            # Evaluate every other round (1, 3, 5, ...) to limit model calls,
            # but always finish with the full conversation: otherwise a
            # single-interaction case yields no result at all and the last
            # entry would be judged on truncated information.
            checkpoints = list(range(1, total_rounds + 1, 2))
            if checkpoints[-1] != total_rounds:
                checkpoints.append(total_rounds)

            results = []
            for n in checkpoints:
                diag_n = self.diagnose(task, render(interactions[:n]), choices)
                results.append({
                    "round": n,
                    "diagnosis": diag_n,
                    "is_successful": diag_n == ground_truth
                })
        else:
            extracted_info = render(interactions) if interactions else data.get('limit_info', '')
            diagnosis = self.diagnose(task, extracted_info, choices)
            results = [{
                "round": len(interactions) if interactions else 1,
                "diagnosis": diagnosis,
                "is_successful": diagnosis == ground_truth
            }]

        return (
            data['case_id'],
            data.get('category', None),
            ground_truth,
            results
        )

    def calculate_robustness(self, category_scores):
        """Score how evenly accuracy is spread across categories.

        Computes each category's accuracy, then the count-weighted mean and
        count-weighted standard deviation of those accuracies, and returns
        ``1 - std / (mean + std + epsilon)``.

        Args:
            category_scores: Mapping from category to a list of per-case
                success flags (booleans or 0/1 values).

        Returns:
            The robustness value rounded to three decimals. It is 1.0 when
            every category has the same accuracy, and also when there is no
            data at all.
        """
        score_lists = list(category_scores.values())
        counts = [len(scores) for scores in score_lists]
        # A category without scores contributes accuracy 0 with weight 0.
        acc_values = [sum(scores) / len(scores) if scores else 0 for scores in score_lists]
        # `or 1` covers both "no categories" and "only empty categories", so
        # the divisions below can never hit zero.
        total_counts = sum(counts) or 1

        weighted_mean_acc = sum(acc * count for acc, count in zip(acc_values, counts)) / total_counts

        variance = sum(count * (acc - weighted_mean_acc) ** 2 for acc, count in zip(acc_values, counts)) / total_counts
        std_dev = math.sqrt(variance)

        epsilon = 1e-6  # keeps the ratio defined when mean and std are both 0
        robustness = 1 - std_dev / (weighted_mean_acc + std_dev + epsilon)

        return round(robustness, 3)

    def evaluate(self, log_path):
        """Evaluate every case in a log file and summarise the outcome.

        Cases are diagnosed concurrently in a thread pool. A case counts as
        successful when the last entry of its per-round results is
        successful; a case with no per-round results counts as unsuccessful.

        Args:
            log_path: Path of the JSON Lines log file to evaluate.

        Returns:
            A dict with "avg_scores" (overall "accuracy" and "robustness",
            each rounded to three decimals) and "details" (one dict per case,
            in the same order as the log file).
        """
        # Step 1: Load and parse the log data
        data_list = self.load_log_data(log_path)

        # Step 2: Accuracy calculations
        # Outcomes are stored by input position so the report order does not
        # depend on which thread happens to finish first.
        case_outcomes = [None] * len(data_list)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_index = {
                executor.submit(self.calculate_accuracy_for_case, data, self.incremental_diagnosis): index
                for index, data in enumerate(data_list)
            }

            for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(future_to_index), desc="Task Success Evaluation"):
                case_outcomes[future_to_index[future]] = future.result()

        all_results = []
        category_scores = defaultdict(list)

        for case_id, category, ground_truth, incremental_results in case_outcomes:
            # Guard against an empty round list instead of raising IndexError.
            is_successful = incremental_results[-1]['is_successful'] if incremental_results else False

            all_results.append({
                "case_id": case_id,
                "ground_truth": ground_truth,
                "category": category,
                "diagnosis": incremental_results,
                "is_successful": is_successful
            })

            category_scores[category].append(is_successful)

        # Step 3: Calculate accuracy and robustness
        accuracy = sum(result['is_successful'] for result in all_results) / len(all_results) if all_results else 0.0
        robustness = self.calculate_robustness(category_scores)

        # Step 4: Return results
        return {
            "avg_scores": {
                "accuracy": round(accuracy, 3),
                "robustness": round(robustness, 3)
            },
            "details": all_results
        }
    
    