import os
import sys
import json
import gc
import random
import torch
import numpy as np
import xarray as xr
import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import PatternFill
import argparse
from pathlib import Path
from typing import Dict, List, Any, Tuple, Optional

from refinegen.main import iterative_refine
from refinegen.itergen.itergen.main import IterGen
from refinegen.checkers.library_specific.ppl_checker import PyMCChecker

import sys
sys.path.append("..")
from commons.data_pymc import datas_info, build_prompt_generic 
from commons.config import config
from commons.utils import convert_np_types, set_seed
# Configuration

# Model identifiers evaluated by run_batch: each one is passed as `model_id`
# to IterGen, and a '/'-free variant of it is used for output folder names.
# NOTE(review): these look like Hugging Face Hub repository ids — confirm.
model_ids = [
    "meta-llama/Meta-Llama-3-8B", "google/codegemma-7b",
    "Qwen/Qwen2.5-Coder-7B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
]

# Define thresholds based on the reliability score calculation.
# create_excel_with_highlighting compares each diagnostic column against the
# matching entry here and fills the cell red when the value fails the check
# described in the trailing comment.
THRESHOLDS = {
    "r_hat": 1.05,           # max_r_hat < 1.05 for score
    "ess_bulk": 400,         # min_ess_bulk >= 400 for score
    "ess_tail": 100,         # min_ess_tail >= 100 for score
    "divergences": 0,        # n_divergent == 0 for score
    "bfmi": 0.3,             # bfmi_values > 0.3 for score
    "pareto_k": 0.2,         # prop_high_pareto_k <= 0.2 for score
}

# Colors for highlighting: solid fill applied to cells that fail a threshold
# (colour code FFFF0000, i.e. red).
RED_FILL = PatternFill(start_color="FFFF0000", end_color="FFFF0000", fill_type="solid")

# Helper function to safely extract value from xarray.DataArray
def extract_xarray_value(obj):
    """Unwrap an ``xarray.DataArray`` into plain Python data.

    Args:
        obj: Any value. Only ``xr.DataArray`` instances are converted.

    Returns:
        ``float`` when ``obj`` is a DataArray holding exactly one element,
        the (possibly nested) list from ``obj.values.tolist()`` when it holds
        any other number of elements, and ``obj`` unchanged when it is not a
        DataArray.
    """
    if not isinstance(obj, xr.DataArray):
        return obj
    if obj.size == 1:
        # ``.item()`` extracts the single element whatever the array's rank;
        # calling ``float()`` directly on a one-element array with ndim > 0
        # is deprecated in recent NumPy releases.
        return float(obj.values.item())
    return obj.values.tolist()

def flatten_diagnostics(diagnostics: Dict[str, Any]) -> Dict[str, Any]:
    """
    Flatten the diagnostics dictionary into a single-level dictionary for easier DataFrame creation

    Args:
        diagnostics: Mapping of diagnostic name to value. Only the scalar keys
            listed below and ``"bfmi_values"`` are read; anything else is
            ignored.

    Returns:
        A new dict containing:
        - every scalar diagnostic that is present, passed through
          ``extract_xarray_value``;
        - ``min_bfmi``, ``max_bfmi`` and one ``bfmi_<i>`` key (1-based) per
          value when ``bfmi_values`` is a non-empty list, tuple or ndarray.
    """
    result = {}

    # Top-level diagnostics (xarray.DataArray values are unwrapped)
    scalar_keys = (
        "max_r_hat", "min_ess_bulk", "min_ess_tail", "n_divergent",
        "prop_high_pareto_k", "elpd_loo", "loo_se", "reliability_score",
    )
    for key in scalar_keys:
        if key in diagnostics:
            result[key] = extract_xarray_value(diagnostics[key])

    # Add BFMI values
    bfmi_values = diagnostics.get("bfmi_values")
    if isinstance(bfmi_values, np.ndarray):
        # ravel() so a 0-d or multi-dimensional array still yields a flat list
        bfmi_values = bfmi_values.ravel().tolist()
    # An empty sequence carries no information and would make min()/max() raise
    if isinstance(bfmi_values, (list, tuple)) and bfmi_values:
        result["min_bfmi"] = min(bfmi_values)
        result["max_bfmi"] = max(bfmi_values)
        for i, val in enumerate(bfmi_values, start=1):
            result[f"bfmi_{i}"] = val

    return result

def create_excel_with_highlighting(df: pd.DataFrame, filepath: str):
    """
    Create an Excel file with conditional formatting to highlight failing metrics

    The frame is written with ``DataFrame.to_excel`` (no index column), then
    reopened with openpyxl so that individual cells can be filled with
    ``RED_FILL`` when their value fails the matching ``THRESHOLDS`` entry:

    - ``max_r_hat``            fails when ``>= THRESHOLDS["r_hat"]``
    - ``min_ess_bulk``         fails when ``<  THRESHOLDS["ess_bulk"]``
    - ``min_ess_tail``         fails when ``<  THRESHOLDS["ess_tail"]``
    - ``n_divergent``          fails when ``>  THRESHOLDS["divergences"]``
    - every ``bfmi_*`` column  fails when ``<= THRESHOLDS["bfmi"]``
    - ``prop_high_pareto_k``   fails when ``>  THRESHOLDS["pareto_k"]``

    Rows whose ``reliability_score`` is null are never highlighted. If the
    frame has no ``reliability_score`` column at all, the file is written
    without any highlighting. Cells that are null or not real numbers are
    left untouched.

    Args:
        df: Data to export; one worksheet row per DataFrame row.
        filepath: Destination path of the Excel file (overwritten).
    """
    from openpyxl import load_workbook

    # Let pandas write the data, then reopen the file to apply cell fills
    df.to_excel(filepath, index=False)
    wb = load_workbook(filepath)
    ws = wb.active

    # Column name -> 1-based worksheet column
    col_indices = {col: i for i, col in enumerate(df.columns, start=1)}

    # (column, "does this value fail?") pairs, built once instead of per row
    checks = [
        ("max_r_hat", lambda v: v >= THRESHOLDS["r_hat"]),
        ("min_ess_bulk", lambda v: v < THRESHOLDS["ess_bulk"]),
        ("min_ess_tail", lambda v: v < THRESHOLDS["ess_tail"]),
        ("n_divergent", lambda v: v > THRESHOLDS["divergences"]),
    ]
    checks.extend(
        (col, lambda v: v <= THRESHOLDS["bfmi"])
        for col in df.columns
        if isinstance(col, str) and col.startswith("bfmi_")
    )
    checks.append(("prop_high_pareto_k", lambda v: v > THRESHOLDS["pareto_k"]))
    checks = [(col, fails) for col, fails in checks if col in col_indices]

    # Only real numeric scalars can be compared against a threshold; lists or
    # strings in a cell would otherwise raise inside the comparison.
    numeric_types = (int, float, np.integer, np.floating)

    if "reliability_score" in col_indices:
        # Worksheet row 1 is the header, so data starts at row 2
        for sheet_row, (_, record) in enumerate(df.iterrows(), start=2):
            if pd.isnull(record["reliability_score"]):
                continue  # Only process rows with reliability scores
            for col, fails in checks:
                value = record[col]
                if isinstance(value, numeric_types) and pd.notnull(value) and fails(value):
                    ws.cell(row=sheet_row, column=col_indices[col]).fill = RED_FILL

    # Save the workbook with formatting
    wb.save(filepath)

def extract_best_program(interim_program, reliability_threshold=6):
    """Pick the most reliable entry, breaking ties by highest ``elpd_loo``.

    An entry qualifies when its ``reliability_score`` is a number that is at
    least ``reliability_threshold`` and its ``diagnostics["elpd_loo"]`` is a
    real number (Python or NumPy scalar). Entries with a missing or ``None``
    score, or missing or ``None`` diagnostics, are skipped instead of raising.

    Args:
        interim_program: Iterable of entry dicts.
        reliability_threshold: Minimum ``reliability_score`` to qualify.

    Returns:
        The qualifying entry with the highest ``(reliability_score, elpd_loo)``
        pair (the earliest one on exact ties), or ``None`` if nothing
        qualifies.
    """
    numeric_types = (int, float, np.integer, np.floating)

    best_entry = None
    best_key = None
    for entry in interim_program:
        score = entry.get("reliability_score", 0)
        # ``or {}`` also covers an explicit ``"diagnostics": None``
        diagnostics = entry.get("diagnostics") or {}
        elpd_loo = diagnostics.get("elpd_loo")

        if not isinstance(score, numeric_types) or not isinstance(elpd_loo, numeric_types):
            continue
        # Written as ``not (>=)`` so that a NaN score is rejected as well
        if not (score >= reliability_threshold):
            continue

        # NaN compares false against everything, which would make the ranking
        # depend on input order; rank it below every real elpd_loo instead.
        rank_elpd = float("-inf") if np.isnan(elpd_loo) else elpd_loo
        key = (score, rank_elpd)
        # Strict ``>`` keeps the first entry among equals
        if best_key is None or key > best_key:
            best_entry, best_key = entry, key

    return best_entry


def get_next_experiment_folder(output_dir):
    """Return the path of the next unused ``expt_<n>`` folder in ``output_dir``.

    ``output_dir`` is created if it does not exist. Sub-directories named
    ``expt_<digits>...`` are scanned and the returned path uses the largest
    number found plus one (``expt_0`` when there is none). The returned
    folder itself is not created here.
    """
    os.makedirs(output_dir, exist_ok=True)

    highest = -1
    for name in os.listdir(output_dir):
        if not name.startswith("expt_"):
            continue
        if not os.path.isdir(os.path.join(output_dir, name)):
            continue
        # The piece right after the first underscore carries the run number
        suffix = name.split("_")[1]
        if suffix.isdigit():
            highest = max(highest, int(suffix))

    return os.path.join(output_dir, f"expt_{highest + 1}")


def process_interim_program(interim_program: List[Dict[str, Any]], model: str, dataset: str, seed: int, total_tokens: int) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Turn the entries of interim_program into a table, one row per entry

    Each row holds the model, dataset, seed, a 1-based entry number, the
    reliability score, ``cumulative_tokens`` (when the entry has it), the
    flattened diagnostics and the program text.

    Returns:
    - DataFrame with all entries (empty when there are no entries)
    - Dictionary for the best row, extended with ``total_tokens``, or None.
      The best row has the highest reliability score; when several rows share
      that score and at least one has an ``elpd_loo``, the one with the
      highest ``elpd_loo`` wins, otherwise the first of them.
    """
    rows = []
    for position, item in enumerate(interim_program, start=1):
        record = {
            "model": model,
            "dataset": dataset,
            "seed": seed,
            "entry_number": position,
            "reliability_score": item.get("reliability_score"),
        }

        # Optional per-entry token counter
        if "cumulative_tokens" in item:
            record["cumulative_tokens"] = item.get("cumulative_tokens")

        # Diagnostics become individual columns
        record.update(flatten_diagnostics(item.get("diagnostics", {})))

        # Program text is kept here; callers drop it before exporting
        record["program"] = item.get("program", "")

        rows.append(record)

    if not rows:
        return pd.DataFrame(), None

    df = pd.DataFrame(rows)

    best_entry = None
    if "reliability_score" in df.columns:
        scored = df.dropna(subset=["reliability_score"])
        if len(scored) > 0:
            top_score = scored["reliability_score"].max()
            top_rows = scored[scored["reliability_score"] == top_score]

            needs_tie_break = (
                len(top_rows) > 1
                and "elpd_loo" in top_rows.columns
                and not top_rows["elpd_loo"].isna().all()
            )
            if not needs_tie_break:
                best_entry = top_rows.iloc[0].to_dict()
            else:
                try:
                    winner = top_rows["elpd_loo"].idxmax()
                    best_entry = top_rows.loc[winner].to_dict()
                except Exception as e:
                    print(f"Warning: Error selecting by elpd_loo: {str(e)}. Using first row instead.")
                    best_entry = top_rows.iloc[0].to_dict()

    # Total token count is supplied by the caller, not read from the entries
    if best_entry:
        best_entry["total_tokens"] = total_tokens

    return df, best_entry


def _save_run_outputs(save_dir, interim_program):
    """Write the raw entries, per-entry JSON files and best program to ``save_dir``."""
    # 1) Save raw interim_program for debugging
    with open(os.path.join(save_dir, 'interim_program.txt'), 'w') as f:
        f.write(str(interim_program))

    # 2) Write out each entry as its own JSON
    for idx, entry in enumerate(interim_program, start=1):
        serial = convert_np_types(entry)
        if "program" in serial and isinstance(serial["program"], str):
            serial["program"] = serial["program"].strip("\n").replace("\t", "    ")

        with open(os.path.join(save_dir, f"entry_{idx}.json"), "w") as jf:
            json.dump(serial, jf, indent=4)

    # 3) Save the best single program & diagnostics
    best = extract_best_program(interim_program)
    if best:
        with open(os.path.join(save_dir, 'best_program.py'), 'w') as f:
            f.write(best.get('program', '').strip("\n"))
        with open(os.path.join(save_dir, 'best_program_diagnostics.txt'), 'w') as f:
            f.write(str(best.get('diagnostics', {})))


def _save_entry_analysis(experiment_dir, safe_model, dataset_name, seed, entries_df, best_entry):
    """Export the per-entry table and best entry; return the summary row or None."""
    if len(entries_df) == 0:
        return None

    analysis_dir = os.path.join(experiment_dir, "analysis", safe_model, dataset_name)
    os.makedirs(analysis_dir, exist_ok=True)

    # Remove program column for CSV/Excel (it's too long); drop() returns a
    # new frame, so the caller's DataFrame is not modified below.
    display_df = entries_df.drop(columns=["program"], errors="ignore")

    # Columns that should always be present in the export
    essential_cols = ["model", "dataset", "seed", "entry_number", "reliability_score",
                      "elpd_loo", "max_r_hat", "min_ess_bulk", "min_ess_tail",
                      "n_divergent", "prop_high_pareto_k"]
    for col in essential_cols:
        if col not in display_df.columns:
            display_df[col] = None

    display_df.to_csv(os.path.join(analysis_dir, f"seed_{seed}_entries.csv"), index=False)
    create_excel_with_highlighting(
        display_df, os.path.join(analysis_dir, f"seed_{seed}_entries.xlsx")
    )

    if not best_entry:
        return None

    # The row came out of a DataFrame and may still hold NumPy scalars, which
    # json cannot serialise; convert them the same way the per-entry dumps do.
    with open(os.path.join(analysis_dir, f"seed_{seed}_best.json"), 'w') as f:
        json.dump(convert_np_types(best_entry), f, indent=2)

    # Don't include the program in the summary
    summary_entry = best_entry.copy()
    summary_entry.pop("program", None)
    return summary_entry


def _write_token_reports(experiment_dir, token_usage):
    """Write token_usage.json, cumulative_tokens.json and token_budget.json."""
    # Running total per model/dataset, accumulated in ascending seed order
    cumulative_tokens = {}
    for model, datasets in token_usage.items():
        cumulative_tokens[model] = {}
        for dataset, seeds in datasets.items():
            running_sum = 0
            per_seed = {}
            for seed_val in sorted(seeds):
                # A run that reported no token count is stored as None;
                # count it as 0 instead of crashing the whole report.
                running_sum += seeds[seed_val] or 0
                per_seed[seed_val] = running_sum
            cumulative_tokens[model][dataset] = per_seed

    with open(os.path.join(experiment_dir, "token_usage.json"), 'w') as f:
        json.dump(token_usage, f, indent=2)

    with open(os.path.join(experiment_dir, "cumulative_tokens.json"), 'w') as f:
        json.dump(cumulative_tokens, f, indent=2)

    # Token budget: 120% of the maximum cumulative tokens per model/dataset
    token_budget = {}
    for model, datasets in cumulative_tokens.items():
        token_budget[model] = {}
        for dataset, seeds in datasets.items():
            if seeds:  # Only if there are seeds
                token_budget[model][dataset] = int(max(seeds.values()) * 1.2)  # 20% buffer

    with open(os.path.join(experiment_dir, "token_budget.json"), 'w') as f:
        json.dump(token_budget, f, indent=2)


def _write_best_summary(experiment_dir, all_best_entries):
    """Write the CSV/Excel summary of the best entry of every model/dataset."""
    if not all_best_entries:
        return

    analysis_dir = os.path.join(experiment_dir, "analysis")
    os.makedirs(analysis_dir, exist_ok=True)

    summary_df = pd.DataFrame(all_best_entries)

    # Ensure essential columns are present
    essential_cols = ["model", "dataset", "seed", "entry_number", "reliability_score",
                      "elpd_loo", "total_tokens", "max_r_hat", "min_ess_bulk", "min_ess_tail",
                      "n_divergent", "prop_high_pareto_k", "cumulative_tokens"]
    for col in essential_cols:
        if col not in summary_df.columns:
            summary_df[col] = None

    # Identification columns first, everything else alphabetically
    first_cols = ["model", "dataset", "seed", "entry_number", "reliability_score",
                  "elpd_loo", "total_tokens", "cumulative_tokens"]
    other_cols = [c for c in summary_df.columns if c not in first_cols]
    summary_df = summary_df[first_cols + sorted(other_cols)]

    summary_path = os.path.join(analysis_dir, "best_programs_summary.csv")
    summary_excel_path = os.path.join(analysis_dir, "best_programs_summary.xlsx")

    summary_df.to_csv(summary_path, index=False)
    create_excel_with_highlighting(summary_df, summary_excel_path)

    print(f"Summary saved to {summary_path}")


def run_batch(seed=0, output_dir="results/med"):
    """Run iterative refinement for every model in ``model_ids`` on every dataset.

    A fresh ``expt_<n>`` folder is created under ``output_dir``. For each
    model an ``IterGen`` instance is built on ``'cuda'`` and
    ``iterative_refine`` is called once per entry of ``datas_info``. Per
    model/dataset the function writes the raw entries, one JSON per entry,
    the best program, and CSV/Excel tables under ``analysis/``. A failure on
    one dataset is printed and logged to that dataset's ``error.txt``; the
    batch then continues with the next dataset. After all models have run,
    token usage reports and a summary of the best entries are written.

    Args:
        seed: Random seed passed to ``set_seed``, ``IterGen`` and
            ``iterative_refine``; also recorded in ``seed.txt``.
        output_dir: Parent directory in which the experiment folder is created.
    """
    import traceback  # only needed when a dataset run fails

    experiment_dir = get_next_experiment_folder(output_dir)
    print(f"Running batch with seed {seed} in {experiment_dir}")
    os.makedirs(experiment_dir, exist_ok=True)
    with open(os.path.join(experiment_dir, 'seed.txt'), 'w') as f:
        f.write(str(seed))

    # Best entry per model/dataset, collected for the summary CSV
    all_best_entries = []

    # token_usage[model][dataset][seed] -> total tokens reported for that run
    token_usage = {}

    for model_id in model_ids:
        set_seed(seed)

        safe_model = model_id.replace('/', '_')
        token_usage.setdefault(safe_model, {})

        iter_gen = IterGen(
            grammar='refinegen/itergen/itergen/syncode/syncode/parsers/grammars/python_grammar.lark',
            model_id=model_id,
            device='cuda',
            do_sample=True,
            temperature=config["temperature"],
            recurrence_penalty=config["recurrence_penalty"],
            seed=seed
        )

        for data in datas_info:
            dataset_name = data["name"]
            token_usage[safe_model].setdefault(dataset_name, {})

            prompt = build_prompt_generic(data)
            try:
                symboltable, interim_program = iterative_refine(
                    prompt=prompt,
                    template=data["template_code"],
                    model_name=data["name"],
                    iter_gen=iter_gen,
                    checker_cls=PyMCChecker,
                    unit_name=config["unit_name"],
                    max_iter=config["max_iter"],
                    checker_config={},
                    symboltable=config["pymc_symboltable"],
                    dedent=True,
                    seed=seed
                )

                save_dir = os.path.join(experiment_dir, safe_model, dataset_name)
                os.makedirs(save_dir, exist_ok=True)

                # --- get total_tokens and save ---
                # NOTE(review): the same iter_gen is reused for every dataset;
                # confirm whether this counter is per call or accumulates.
                total_tokens = iter_gen._metadata.get('total_tokens')
                with open(os.path.join(save_dir, 'total_tokens.txt'), 'w') as tf:
                    tf.write(str(total_tokens))

                # Track token usage
                token_usage[safe_model][dataset_name][seed] = total_tokens

                _save_run_outputs(save_dir, interim_program)

                # Process entries for detailed analysis
                entries_df, best_entry = process_interim_program(
                    interim_program, safe_model, dataset_name, seed, total_tokens
                )
                summary_entry = _save_entry_analysis(
                    experiment_dir, safe_model, dataset_name, seed, entries_df, best_entry
                )
                if summary_entry is not None:
                    all_best_entries.append(summary_entry)

            except Exception as e:
                # Deliberately broad: one failing dataset must not stop the batch
                print(f"Error processing {dataset_name} for {safe_model}: {str(e)}")
                # Log the error to a file
                error_dir = os.path.join(experiment_dir, safe_model, dataset_name)
                os.makedirs(error_dir, exist_ok=True)
                error_path = os.path.join(error_dir, 'error.txt')
                with open(error_path, 'w') as f:
                    f.write(f"Error: {str(e)}")
                    # Keep the traceback so the failure can be located later
                    f.write("\n\n" + traceback.format_exc())

        # cleanup: drop the model and collect garbage first so its memory is
        # actually unreferenced when the CUDA cache is emptied
        del iter_gen
        gc.collect()
        torch.cuda.empty_cache()

    _write_token_reports(experiment_dir, token_usage)
    _write_best_summary(experiment_dir, all_best_entries)


if __name__ == "__main__":
    # Command-line entry point: run one batch for the given seed.
    parser = argparse.ArgumentParser(description="Run batch with a given seed")
    parser.add_argument(
        "--seed", "-s",
        type=int,
        required=True,
        help="Random seed to use for this run"
    )
    # Optional; the default matches run_batch's own default, so existing
    # invocations that only pass --seed behave exactly as before.
    parser.add_argument(
        "--output-dir", "-o",
        default="results/med",
        help="Directory in which the expt_<n> folder is created (default: results/med)"
    )
    args = parser.parse_args()
    run_batch(args.seed, args.output_dir)