#!/usr/bin/env python3
"""Parser for converting sleep COT JSONL files to clean format."""

import json
import re
import sys
import os
from pathlib import Path
from collections import Counter
from tqdm import tqdm

# Add the src directory to the path to import from the dataset class
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'src'))
from time_series_datasets.sleep.SleepEDFCoTQADataset import SleepEDFCoTQADataset

# We'll determine supported labels dynamically from the actual ground truth data
# Start with the dataset class labels as a fallback
FALLBACK_LABELS = SleepEDFCoTQADataset.get_labels()
SUPPORTED_LABELS = []  # Will be populated dynamically

def _canonicalize_label(text):
    """Return canonical label with stage 4 merged into stage 3.

    - Case-insensitive
    - Trims whitespace and trailing period
    - Merges "non-rem stage 4" -> "Non-REM stage 3"
    - Returns (canonical_label_str, is_supported_bool)
    """
    if text is None:
        return "", False

    cleaned = str(text).strip()
    # Remove any end-of-text tokens and trailing period
    cleaned = re.sub(r'<\|.*?\|>|<eos>$', '', cleaned).strip()
    cleaned = re.sub(r'\.$', '', cleaned).strip()

    lowered = cleaned.lower()

    # Normalize common variants and merge stage 4 into stage 3
    if "non-rem" in lowered or "nrem" in lowered:
        # unify spacing/hyphenation
        lowered = lowered.replace("nrem", "non-rem")
        lowered = lowered.replace("non rem", "non-rem")

    # Map stage 4 -> stage 3
    if "non-rem" in lowered and "stage 4" in lowered:
        canonical = "Non-REM stage 3"
    elif "non-rem" in lowered and "stage 3" in lowered:
        canonical = "Non-REM stage 3"
    elif "non-rem" in lowered and "stage 2" in lowered:
        canonical = "Non-REM stage 2"
    elif "non-rem" in lowered and "stage 1" in lowered:
        canonical = "Non-REM stage 1"
    elif "rem" in lowered and "sleep" in lowered:
        canonical = "REM sleep"
    elif lowered in {"wake", "awake"}:
        canonical = "Wake"
    elif "movement" in lowered or lowered == "mov" or lowered == "mt":
        canonical = "Movement"
    else:
        # If it exactly matches a supported label ignoring case, keep it
        # Use fallback labels if supported labels haven't been determined yet
        label_set = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS
        maybe = next((lab for lab in label_set if lab.lower() == lowered), "")
        canonical = maybe if maybe else cleaned

    # Use fallback labels if supported labels haven't been determined yet
    label_set = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS
    is_supported = canonical in label_set
    return canonical if canonical else cleaned, is_supported
def calculate_f1_score(prediction, ground_truth):
    """Calculate F1 score for single-label classification with supported labels.

    - Merges Non-REM stage 4 into stage 3
    - Only counts predictions within SUPPORTED_LABELS; unsupported predictions yield F1=0
    """
    pred_canon, pred_supported = _canonicalize_label(prediction)
    truth_canon, truth_supported = _canonicalize_label(ground_truth)

    # Exact-match after canonicalization for single-label F1
    f1 = 1.0 if pred_canon == truth_canon else 0.0

    return {
        'f1_score': f1,
        'precision': f1,
        'recall': f1,
        'prediction_normalized': pred_canon.lower().strip(),
        'ground_truth_normalized': truth_canon.lower().strip(),
        'prediction_supported': pred_supported,
        'ground_truth_supported': truth_supported,
    }

def calculate_f1_stats(data_points):
    """Calculate both macro-F1 and average F1 (micro-F1) statistics.

    - Only supported classes are included in the class set
    - Unsupported predictions contribute FN to the ground-truth class but do not
      create or contribute FP to an unsupported predicted class
    """
    if not data_points:
        return {}
    
    # Calculate average F1 (micro-F1) - simple average across all predictions
    f1_scores = [point.get("f1_score", 0) for point in data_points]
    average_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
    
    # Initialize class buckets for only supported classes (lowercased for consistency)
    # Use discovered labels if available, otherwise fall back to dataset labels
    labels_to_use = SUPPORTED_LABELS if SUPPORTED_LABELS else FALLBACK_LABELS
    supported_lower = {label.lower(): label for label in labels_to_use}
    class_predictions = {lab.lower(): {"tp": 0, "fp": 0, "fn": 0} for lab in labels_to_use}

    for point in data_points:
        gt_class = point.get("ground_truth_normalized", "")
        pred_class = point.get("prediction_normalized", "")
        pred_supported = point.get("prediction_supported", False)

        # Ensure ground truth is one of the supported classes; if not, skip counting it
        if gt_class not in class_predictions:
            # Skip entirely as requested: do not include it in the ground truth labels
            continue

        if pred_class == gt_class:
            class_predictions[gt_class]["tp"] += 1
        else:
            # Count FN for the ground truth class
            class_predictions[gt_class]["fn"] += 1
            # Count FP only if predicted class is supported
            if pred_supported and pred_class in class_predictions:
                class_predictions[pred_class]["fp"] += 1

    # Calculate F1 per class
    class_f1_scores = {}
    total_f1 = 0
    valid_classes = 0

    for class_name, counts in class_predictions.items():
        tp, fp, fn = counts["tp"], counts["fp"], counts["fn"]

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        # Use canonical casing in output keys
        pretty_name = supported_lower.get(class_name, class_name)
        class_f1_scores[pretty_name] = {
            "f1": f1,
            "precision": precision,
            "recall": recall,
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }

        total_f1 += f1
        valid_classes += 1

    macro_f1 = total_f1 / valid_classes if valid_classes > 0 else 0

    return {
        "average_f1": average_f1,
        "macro_f1": macro_f1,
        "class_f1_scores": class_f1_scores,
        "total_classes": valid_classes,
    }

def calculate_accuracy_stats(data_points):
    """Calculate accuracy statistics from data points"""
    if not data_points:
        return {}
    
    total = len(data_points)
    correct = sum(1 for point in data_points if point.get("accuracy", False))
    accuracy_percentage = (correct / total) * 100 if total > 0 else 0
    
    return {
        "total_samples": total,
        "correct_predictions": correct,
        "incorrect_predictions": total - correct,
        "accuracy_percentage": accuracy_percentage
    }

def parse_sleep_cot_jsonl(input_file, output_file=None):
    """Parse sleep COT JSONL file and extract JSON objects."""
    if output_file is None:
        input_path = Path(input_file)
        output_file = str(input_path.parent / f"{input_path.stem}.clean.jsonl")
    
    print(f"Parsing {input_file}")
    print(f"Output will be saved to {output_file}")
    
    # First, discover the actual labels from the ground truth data
    global SUPPORTED_LABELS
    discovered_labels = discover_ground_truth_labels(input_file)
    SUPPORTED_LABELS = discovered_labels
    
    print(f"Discovered {len(discovered_labels)} labels from ground truth data:")
    for label in sorted(discovered_labels):
        print(f"  - {label}")
    
    extracted_data = extract_structured_data(input_file)
    
    if extracted_data:
        print(f"Extracted {len(extracted_data)} data points")
        
        # Calculate and display accuracy statistics
        accuracy_stats = calculate_accuracy_stats(extracted_data)
        print(f"\nAccuracy Statistics:")
        print(f"Total samples: {accuracy_stats['total_samples']}")
        print(f"Correct predictions: {accuracy_stats['correct_predictions']}")
        print(f"Incorrect predictions: {accuracy_stats['incorrect_predictions']}")
        print(f"Accuracy: {accuracy_stats['accuracy_percentage']:.2f}%")
        
        # Calculate and display F1 statistics
        f1_stats = calculate_f1_stats(extracted_data)
        print(f"\nF1 Score Statistics:")
        print(f"Average F1 Score: {f1_stats['average_f1']:.4f}")
        print(f"Macro-F1 Score: {f1_stats['macro_f1']:.4f}")
        print(f"Total Classes: {f1_stats['total_classes']}")
        
        # Display per-class F1 scores
        if f1_stats['class_f1_scores']:
            print(f"\nPer-Class F1 Scores:")
            for class_name, scores in f1_stats['class_f1_scores'].items():
                print(f"  {class_name}: F1={scores['f1']:.4f}, P={scores['precision']:.4f}, R={scores['recall']:.4f}")
                pass
        
        with open(output_file, 'w', encoding='utf-8') as f:
            for item in extracted_data:
                f.write(json.dumps(item, indent=2) + "\n")
        
        print(f"\nData saved to {output_file}")
        return extracted_data
    else:
        print("No data could be extracted from the file.")
        return []

def discover_ground_truth_labels(input_file):
    """Discover actual labels from ground truth data in the JSONL file"""
    discovered_labels = set()
    
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                gold_text = data.get('gold', '')
                ground_truth_raw = extract_answer(gold_text)
                gt_canon, _ = _canonicalize_label(ground_truth_raw)
                if gt_canon:
                    discovered_labels.add(gt_canon)
            except (json.JSONDecodeError, Exception):
                continue
    
    return list(discovered_labels)

def extract_structured_data(input_file):
    """Extract structured data from JSONL file"""
    data_points = []
    
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, line in tqdm(enumerate(f, 1)):
            try:
                # Parse JSON line
                data = json.loads(line.strip())
                
                # Extract generated and gold fields
                generated_text = data.get('generated', '')
                gold_text = data.get('gold', '')
                
                # Extract answers from both fields
                model_prediction_raw = extract_answer(generated_text)
                ground_truth_raw = extract_answer(gold_text)
                # Canonicalize labels and merge stage 4 -> stage 3
                pred_canon, pred_supported = _canonicalize_label(model_prediction_raw)
                gt_canon, gt_supported = _canonicalize_label(ground_truth_raw)
                
                # Calculate accuracy (exact match)
                accuracy = (pred_canon == gt_canon) and gt_supported
                
                # Calculate F1 score
                f1_result = calculate_f1_score(model_prediction_raw, ground_truth_raw)
                
                data_point = {
                    "generated": generated_text,
                    "model_prediction": model_prediction_raw,
                    "ground_truth": ground_truth_raw,
                    "accuracy": accuracy,
                    "f1_score": f1_result['f1_score'],
                    "precision": f1_result['precision'],
                    "recall": f1_result['recall'],
                    "prediction_normalized": f1_result['prediction_normalized'],
                    "ground_truth_normalized": f1_result['ground_truth_normalized'],
                    "prediction_supported": f1_result['prediction_supported'],
                    "ground_truth_supported": f1_result['ground_truth_supported'],
                    "line_number": line_num
                }
                data_points.append(data_point)
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num}: {e}")
                continue
            except Exception as e:
                print(f"Unexpected error on line {line_num}: {e}")
                continue
    
    return data_points

def extract_answer(text):
    """Extract the final answer from text"""
    if "Answer: " not in text:
        return text
    
    answer = text.split("Answer: ")[-1].strip()
    # Remove any end-of-text tokens (including <eos> and <|...|>)
    answer = re.sub(r'<\|.*?\|>|<eos>$', '', answer).strip()
    # Remove trailing periods and normalize
    answer = re.sub(r'\.$', '', answer).strip()
    return answer

if __name__ == "__main__":
    current_dir = Path(__file__).parent
    input_file = current_dir / "llama_1b_flamingo_predictions.jsonl"
    clean_output = current_dir / "llama_1b_flamingo_predictions.clean.jsonl"
    
    parse_sleep_cot_jsonl(input_file, clean_output)
