import os

import json
import numpy as np
import mauve
import random
import sys
import matplotlib.pyplot as plt
import re

def calculate_mauve_score(predictions, references, gpu_id, max_text_length, batch_size):
    """Compute the MAUVE score of ``predictions`` measured against ``references``.

    Args:
        predictions: Generated texts; passed to MAUVE as ``p_text``.
        references: Reference texts; passed to MAUVE as ``q_text``.
        gpu_id: Forwarded to ``mauve.compute_mauve`` as ``device_id``.
        max_text_length: Forwarded to ``mauve.compute_mauve``.
        batch_size: Forwarded to ``mauve.compute_mauve``.

    Returns:
        The ``mauve`` attribute of the MAUVE result, or ``np.nan`` when either
        list is empty, either list contains a non-string element, or MAUVE
        (or reading its result) raises. Every failure is reported on stderr.
    """
    # Guard: both text lists must be non-empty.
    if not (predictions and references):
        print("Warning: Predictions or references list is empty.", file=sys.stderr)
        return np.nan

    # Guard: every element of each list must be a string; predictions are checked first.
    string_checks = (
        (predictions, "Error: All predictions must be strings."),
        (references, "Error: All references must be strings."),
    )
    for texts, message in string_checks:
        if any(not isinstance(text, str) for text in texts):
            print(message, file=sys.stderr)
            return np.nan

    try:
        result = mauve.compute_mauve(
            p_text=predictions,
            q_text=references,
            device_id=gpu_id,
            max_text_length=max_text_length,
            batch_size=batch_size,
            verbose=False,
        )
        # Attribute access stays inside the try so a malformed result also maps to NaN.
        return result.mauve
    except Exception as exc:
        print(f"Error computing MAUVE: {exc}", file=sys.stderr)
        return np.nan

def _paired_bootstrap_scores(source_texts, reference_texts, num_bootstraps, rng, score_fn):
    """Score ``num_bootstraps`` paired resamples and return the scores that are not NaN.

    Each resample draws ``len(reference_texts)`` indices uniformly with
    replacement (via ``rng.randint``) and applies the same indices to both
    lists, so aligned ``(source[i], reference[i])`` pairs stay together.
    ``source_texts`` must be at least as long as ``reference_texts``.
    """
    n_samples = len(reference_texts)
    scores = []
    for _ in range(num_bootstraps):
        indices = [rng.randint(0, n_samples - 1) for _ in range(n_samples)]
        score = score_fn(
            [source_texts[idx] for idx in indices],
            [reference_texts[idx] for idx in indices],
        )
        # A failed MAUVE run comes back as NaN; drop it rather than poisoning the std.
        if not np.isnan(score):
            scores.append(score)
    return scores


def bootstrap_mauve_from_file(data_file_path, num_bootstraps, gpu_id, max_text_length, batch_size, seed=None):
    """Score every text source in a JSON file against its "original" texts using MAUVE.

    The file must contain a JSON object with an ``"original"`` key that maps to a
    non-empty list of strings. Every other key is treated as a source whose
    value should be a non-empty list of strings.

    For each source:

    * ``direct_mauve_score`` is MAUVE over the full source list versus the full
      original list.
    * ``std_bootstrap_score`` is the ``np.std`` (population std, ``ddof=0``) of
      MAUVE over ``num_bootstraps`` paired resamples. It is only computed when
      the source has exactly as many texts as "original", because item ``i`` of
      the source is resampled together with item ``i`` of the original.

    Args:
        data_file_path: Path to the UTF-8 JSON file.
        num_bootstraps: Number of bootstrap resamples per source.
        gpu_id: Forwarded to ``calculate_mauve_score``.
        max_text_length: Forwarded to ``calculate_mauve_score``.
        batch_size: Forwarded to ``calculate_mauve_score``.
        seed: Optional seed for a private ``random.Random`` used for resampling.
            With the default ``None`` the global ``random`` module is used,
            exactly as before.

    Returns:
        ``{source_key: {"direct_mauve_score": float, "std_bootstrap_score": float}}``.
        Scores that could not be computed are ``np.nan``. Returns ``{}`` (after
        printing to stderr) if the file cannot be read or parsed, or if
        "original" is missing or invalid.
    """
    try:
        with open(data_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {data_file_path}", file=sys.stderr)
        return {}
    except json.JSONDecodeError as e:
        print(f"Error: Could not decode JSON from {data_file_path}. Details: {e}", file=sys.stderr)
        return {}
    except (OSError, UnicodeDecodeError) as e:
        # e.g. the path is a directory, access is denied, or the file is not UTF-8.
        print(f"Error: Could not read {data_file_path}. Details: {e}", file=sys.stderr)
        return {}

    # `"original" in data` also succeeds on a JSON string or list, so require an object first.
    if not isinstance(data, dict):
        print(f"Error: Top-level JSON in {data_file_path} must be an object.", file=sys.stderr)
        return {}

    if "original" not in data:
        print("Error: 'original' key not found in the data.", file=sys.stderr)
        return {}

    original_texts = data["original"]
    if not isinstance(original_texts, list) or not all(isinstance(text, str) for text in original_texts):
        print("Error: 'original' texts must be a list of strings.", file=sys.stderr)
        return {}

    n_samples_original = len(original_texts)
    if n_samples_original == 0:
        print("Error: 'original' texts list is empty.", file=sys.stderr)
        return {}

    # The `random` module exposes the same `randint` as a Random instance, so seed=None
    # reproduces the original global-RNG draws exactly.
    rng = random if seed is None else random.Random(seed)

    def score(preds, refs):
        return calculate_mauve_score(
            predictions=preds,
            references=refs,
            gpu_id=gpu_id,
            max_text_length=max_text_length,
            batch_size=batch_size,
        )

    all_results = {}

    for key, texts_source in data.items():
        if key == "original":
            continue

        if not isinstance(texts_source, list) or not all(isinstance(text, str) for text in texts_source) or not texts_source:
            print(f"Warning: Texts for source '{key}' are invalid. Skipping.", file=sys.stderr)
            all_results[key] = {"direct_mauve_score": np.nan, "std_bootstrap_score": np.nan}
            continue

        direct_score_val = score(texts_source, original_texts)

        std_bs_score = np.nan
        if len(texts_source) != n_samples_original:
            print(f"Warning: Source '{key}' has {len(texts_source)} samples, but 'original' has {n_samples_original}. Skipping bootstrap.", file=sys.stderr)
        else:
            bootstrapped_scores = _paired_bootstrap_scores(
                texts_source, original_texts, num_bootstraps, rng, score
            )
            if bootstrapped_scores:
                std_bs_score = np.std(bootstrapped_scores)
            else:
                print(f"  Could not compute any valid MAUVE scores for bootstrap for source '{key}'.", file=sys.stderr)

        all_results[key] = {
            "direct_mauve_score": direct_score_val,
            "std_bootstrap_score": std_bs_score
        }

    return all_results

def plot_mauve_comparison_lines(results, model_name_simple, output_dir):
    """Plot MAUVE score versus k for the Top-K and Primal decoding variants and save a PDF.

    Args:
        results: Mapping ``method_name -> {"direct_mauve_score", "std_bootstrap_score"}``,
            as produced by ``bootstrap_mauve_from_file``. A method is plotted only
            if its name is one of the recognised prefixes followed by an integer k
            (e.g. ``"top_k_50"``, ``"primal_alpha_1.5_k_10"``) and its direct
            score is neither None nor NaN. A None or NaN std is drawn as 0 (no band).
        model_name_simple: Used in the plot title and the output filename.
        output_dir: Directory for ``mauve_comparison_plot_<model_name_simple>.pdf``.
            It is created if it does not exist.

    Nothing is written (a message goes to stderr) when no method matches. Save
    errors are printed to stderr rather than raised. The figure is always closed.
    """
    # (result-key prefix, series label, line colour), checked in this order.
    series_specs = (
        ("top_k_", "Top-K", "blue"),
        ("primal_alpha_1.5_k_", "Primal (alpha=1.5)", "green"),
        ("primal_alpha_2.0_k_", "Primal (alpha=2.0)", "red"),
    )
    plot_data = {
        label: {"k_values": [], "scores": [], "stds": [], "color": color}
        for _, label, color in series_specs
    }

    for method_name, data in results.items():
        score = data["direct_mauve_score"]
        # Guard against None as well as NaN; np.isnan(None) would raise TypeError.
        if score is None or np.isnan(score):
            continue
        std = data["std_bootstrap_score"]
        if std is None or np.isnan(std):
            std = 0

        for prefix, label, _ in series_specs:
            if not method_name.startswith(prefix):
                continue
            try:
                k_value = int(method_name[len(prefix):])
            except ValueError:
                break  # Recognised prefix but non-integer k: skip this method.
            series = plot_data[label]
            series["k_values"].append(k_value)
            series["scores"].append(score)
            series["stds"].append(std)
            break

    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        found_data_to_plot = False
        for label, series in plot_data.items():
            if not series["k_values"]:
                continue
            order = np.argsort(series["k_values"])
            k_values = np.asarray(series["k_values"])[order]
            scores = np.asarray(series["scores"])[order]
            stds = np.asarray(series["stds"])[order]

            ax.plot(k_values, scores, label=label, marker='o', color=series["color"])
            # Shaded band of +/- one bootstrap std around each score.
            ax.fill_between(k_values, scores - stds, scores + stds, alpha=0.2, color=series["color"])
            found_data_to_plot = True

        if not found_data_to_plot:
            print("No data to plot for MAUVE comparison lines. Check method names and results.", file=sys.stderr)
            return

        ax.set_xlabel('k value')
        ax.set_ylabel('MAUVE Score')
        ax.set_title(f'MAUVE Score vs. k for {model_name_simple}')
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)

        # One tick per distinct k across all series; non-empty because something was plotted.
        ax.set_xticks(sorted({k for series in plot_data.values() for k in series["k_values"]}))

        fig.tight_layout()

        output_filename = os.path.join(output_dir, f"mauve_comparison_plot_{model_name_simple}.pdf")
        try:
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            fig.savefig(output_filename, format='pdf')
        except Exception as e:
            print(f"Error saving MAUVE comparison plot: {e}", file=sys.stderr)
    finally:
        # Close even if a plotting call raised, so pyplot does not accumulate open figures.
        plt.close(fig)
