#!/usr/bin/env python3
import os
import re
import sys
import json
import time
import random
import gc
import argparse
from typing import Tuple, Union, Dict, Any, List
import pandas as pd
from openpyxl import Workbook
from openpyxl.styles import PatternFill

import torch
import numpy as np
import xarray as xr
import arviz as az
import tiktoken
from transformers import AutoTokenizer

import syncode.infer as infer
sys.path.append("..")
from commons.data_pymc import datas_info, build_prompt_generic 
from commons.config import config
from commons.utils import convert_np_types, set_seed, check_model_reliability, run_pymc_code

# Pass/fail cut-offs used when highlighting diagnostics in the Excel report.
# Each value is the boundary a metric must respect to count as "passing":
#   r_hat       -> max_r_hat must stay strictly below it
#   ess_bulk    -> min_ess_bulk must be at least it
#   ess_tail    -> min_ess_tail must be at least it
#   divergences -> n_divergent must not exceed it
#   bfmi        -> every BFMI value must be strictly above it
#   pareto_k    -> prop_high_pareto_k must not exceed it
THRESHOLDS = dict(
    r_hat=1.05,
    ess_bulk=400,
    ess_tail=100,
    divergences=0,
    bfmi=0.3,
    pareto_k=0.2,
)

# Solid red cell fill used to flag failing metrics in the Excel report.
_RED_ARGB = "FFFF0000"  # 8-digit ARGB hex string as accepted by openpyxl colours
RED_FILL = PatternFill(fill_type="solid", start_color=_RED_ARGB, end_color=_RED_ARGB)

# Hugging Face model identifiers evaluated for each --model-size choice.
SMALL_MODELS = [
    "microsoft/Phi-3.5-mini-instruct",
    "Qwen/Qwen2.5-Coder-3B",
]

MEDIUM_MODELS = [
    "meta-llama/Meta-Llama-3-8B",
    "google/codegemma-7b",
    "Qwen/Qwen2.5-Coder-7B",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
]

# ——————————————————————————————————————————————————————————
# Output-folder and aggregate-persistence helpers.
def get_new_experiment_folder(base_dir: str, prefix: str) -> str:
    """Return the experiment root ``<base_dir>/metric-baseline/<prefix>_nt``.

    The directory (and any missing parents) is created when absent. Despite
    the name, an existing folder is reused rather than a fresh one being
    allocated, so repeated runs with the same ``prefix`` share one root. The
    ``_nt`` suffix is appended unconditionally.
    """
    parent = os.path.join(base_dir, "metric-baseline")
    experiment_root = os.path.join(parent, f"{prefix}_nt")
    os.makedirs(experiment_root, exist_ok=True)
    return experiment_root

def save_reliability_aggregate(agg, path):
    """Serialise the reliability aggregate ``agg`` to ``path`` as indented JSON.

    ``agg`` is passed through ``convert_np_types`` first (presumably turning
    NumPy values into JSON-friendly builtins — see commons.utils), then a
    confirmation line is printed.
    """
    with open(path, 'w') as out_file:
        serialisable = convert_np_types(agg)
        json.dump(serialisable, out_file, indent=2)
    print(f"Saved reliability aggregate → {path}")

def save_token_count_aggregate(agg, path):
    """Serialise the token-count aggregate ``agg`` to ``path`` as indented JSON.

    ``agg`` is passed through ``convert_np_types`` first (presumably turning
    NumPy values into JSON-friendly builtins — see commons.utils), then a
    confirmation line is printed.
    """
    with open(path, 'w') as out_file:
        serialisable = convert_np_types(agg)
        json.dump(serialisable, out_file, indent=2)
    print(f"Saved token count aggregate → {path}")

def store_token_count_aggregate(seed, dataset, llm_key, tokens, agg):
    """Record ``tokens`` for ``seed`` and update the running cumulative total.

    ``agg`` is mutated in place with the layout::

        agg[dataset][llm_key] = {'seeds': {seed: tokens},
                                 'cumulative': {seed: running_total}}

    The running total for ``seed`` is the largest cumulative value already
    stored for any smaller seed (0 if none) plus ``tokens``. Seeds stored out
    of order do not retroactively update larger seeds' totals.
    """
    bucket = agg.setdefault(dataset, {}).setdefault(
        llm_key, {'seeds': {}, 'cumulative': {}}
    )
    bucket['seeds'][seed] = tokens

    running = bucket['cumulative']
    earlier_totals = [total for prior_seed, total in running.items() if prior_seed < seed]
    baseline = max(earlier_totals) if earlier_totals else 0
    running[seed] = baseline + tokens

# Helper function to safely extract value from xarray.DataArray
def extract_xarray_value(obj):
    """Unwrap an ``xarray.DataArray`` into plain Python data.

    - size-1 DataArray (of any dimensionality) -> ``float``
    - any other DataArray -> nested ``list`` via ``ndarray.tolist()``
    - anything else -> returned unchanged

    The size-1 case goes through ``ndarray.item()``: calling ``float()``
    directly on an array with ``ndim > 0`` is deprecated since NumPy 1.25.
    """
    if not isinstance(obj, xr.DataArray):
        return obj
    values = obj.values
    if values.size == 1:
        return float(values.item())
    return values.tolist()


# Insert snippet and return code + token count (NT variant: the token count is
# read from the Syncode model instead of re-tokenising the snippet).
def insert_model_code_NT(template: str, raw_snippet: Union[str, list], trace_var: str,
                         local_llm: "infer.Syncode") -> Tuple[str, int]:
    """Insert an LLM-generated PyMC snippet into ``template``.

    Markdown code fences (including a language tag such as ``python``) are
    removed from the snippet, and every non-blank snippet line is re-indented
    with a single tab. The snippet is placed immediately after the template's
    ``with pm.Model() as m:`` line; every template line is left-stripped, and
    ``summary = az.summary(<trace_var>)`` is appended as the final line. The
    snippet is not truncated after ``pm.sample`` (unlike ``insert_model_code``).

    Args:
        template: Template source containing a ``with pm.Model() as m:`` line.
        raw_snippet: Generated code, either a string or a list of string
            pieces that are concatenated without separators.
        trace_var: Name of the variable holding the sampling result.
        local_llm: Syncode wrapper; its ``model.total_tokens`` is returned as
            the token count.

    Returns:
        ``(full_code, tok_count)``.

    Raises:
        ValueError: if ``template`` has no ``with pm.Model() as m:`` line.
    """
    if isinstance(raw_snippet, list):
        raw_snippet = ''.join(raw_snippet)
    # NOTE(review): presumably ``total_tokens`` counts the tokens of the
    # generation that produced ``raw_snippet``; if Syncode accumulates it
    # across calls, the per-seed cumulative totals built from it double count.
    tok_count = local_llm.model.total_tokens

    # Strip fences *with* their language tag: removing only the backticks left
    # a bare ``python`` line behind, which raises NameError when executed.
    cleaned = re.sub(r"```[\w+-]*", "", raw_snippet)

    # Every non-blank line gets exactly one tab. NOTE(review): this flattens
    # any nested indentation inside the snippet (e.g. a ``for`` body).
    processed_lines = []
    for line in cleaned.splitlines():
        processed_lines.append('\t' + line.lstrip() if line.strip() else '')
    snippet_code = '\n'.join(processed_lines)

    out_lines = []
    inserted = False
    for line in template.splitlines():
        out_lines.append(line.lstrip())
        if line.strip().startswith('with pm.Model() as m:'):
            out_lines.append(snippet_code)
            inserted = True

    if not inserted:
        raise ValueError("'with pm.Model() as m:' not found in template")

    out_lines.append(f'\tsummary = az.summary({trace_var})')
    return '\n'.join(out_lines), tok_count


# Insert snippet and return code + token count (token count computed by
# re-tokenising the kept snippet with the model's tokenizer).

# Hugging Face tokenizers keyed by model name, so the fallback path in
# insert_model_code loads each tokenizer once instead of on every call.
_HF_TOKENIZER_CACHE: Dict[str, Any] = {}


def insert_model_code(template: str, raw_snippet: Union[str, list],
                      trace_var: str, model_name: str) -> Tuple[str,int]:
    """Insert an LLM-generated PyMC snippet into ``template`` and count its tokens.

    The snippet is cleaned (Markdown fences and their language tag removed),
    each non-blank line is re-indented with a single tab, and everything after
    the first complete ``pm.sample(...)`` call is discarded. The result is
    placed right after the template's ``with pm.Model() as m:`` line; every
    template line is left-stripped, and ``summary = az.summary(<trace_var>)``
    is appended at the end.

    Args:
        template: Template source containing a ``with pm.Model() as m:`` line.
        raw_snippet: Generated code, either a string or a list of string
            pieces that are concatenated without separators.
        trace_var: Name of the variable holding the sampling result.
        model_name: Model identifier used to pick the tokenizer.

    Returns:
        ``(full_code, tok_count)`` where ``tok_count`` is the token count of
        the kept snippet per tiktoken (when it knows ``model_name``) or the
        model's Hugging Face tokenizer otherwise.

    Raises:
        ValueError: if ``template`` has no ``with pm.Model() as m:`` line.
    """
    if isinstance(raw_snippet, list):
        raw_snippet = ''.join(raw_snippet)
    # Strip fences *with* their language tag: removing only the backticks left
    # a bare ``python`` line behind, which raises NameError when executed.
    cleaned = re.sub(r"```[\w+-]*", "", raw_snippet)

    proc = []
    depth = None  # open-paren balance of the pm.sample(...) call once found
    for line in cleaned.splitlines():
        proc.append('' if not line.strip() else '\t' + line.lstrip())
        if depth is None:
            # Match the sampler call itself; the old substring test also hit
            # pm.sample_prior_predictive / pm.sample_posterior_predictive.
            call = re.search(r"\bpm\.sample\s*\(", line)
            if call is None:
                continue
            depth = 0
            segment = line[call.start():]
        else:
            segment = line
        # Ignore trailing comments so parentheses inside them are not counted.
        code_part = segment.split('#', 1)[0]
        depth += code_part.count('(') - code_part.count(')')
        if depth <= 0:
            # The pm.sample(...) call is complete (possibly spanning several
            # lines); drop whatever the model generated after it.
            break
    snippet_code = '\n'.join(proc)

    try:
        enc = tiktoken.encoding_for_model(model_name)
        tok_count = len(enc.encode(snippet_code))
    except Exception:
        # tiktoken only knows OpenAI model names; fall back to the HF tokenizer.
        tok = _HF_TOKENIZER_CACHE.get(model_name)
        if tok is None:
            tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            _HF_TOKENIZER_CACHE[model_name] = tok
        tok_count = len(tok(snippet_code).input_ids)

    out_lines, inserted = [], False
    for bl in template.splitlines():
        out_lines.append(bl.lstrip())
        if bl.strip().startswith('with pm.Model() as m:'):
            out_lines.append(snippet_code)
            inserted = True
    if not inserted:
        raise ValueError("'with pm.Model() as m:' not found in template")
    out_lines.append(f'\tsummary = az.summary({trace_var})')
    return '\n'.join(out_lines), tok_count

def flatten_diagnostics(diagnostics: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a diagnostics dict into one level for building a DataFrame row.

    Copies ``max_r_hat``, ``min_ess_bulk``, ``min_ess_tail``, ``n_divergent``,
    ``prop_high_pareto_k``, ``elpd_loo``, ``loo_se`` and ``reliability_score``
    when present, unwrapping xarray values via ``extract_xarray_value``.

    ``bfmi_values`` (list, tuple or ndarray — presumably one value per chain,
    TODO confirm) becomes ``min_bfmi``, ``max_bfmi`` and ``bfmi_1 .. bfmi_n``.
    An empty BFMI sequence yields no BFMI columns instead of crashing in
    ``min()``.
    """
    result = {}

    for key in ("max_r_hat", "min_ess_bulk", "min_ess_tail", "n_divergent",
                "prop_high_pareto_k", "elpd_loo", "loo_se", "reliability_score"):
        if key in diagnostics:
            result[key] = extract_xarray_value(diagnostics[key])

    bfmi_values = diagnostics.get("bfmi_values")
    if isinstance(bfmi_values, (list, tuple, np.ndarray)):
        # ravel() also copes with a 0-d array, whose tolist() is a bare float
        # that min()/max() cannot iterate.
        flat = np.ravel(bfmi_values).tolist()
        if flat:
            result["min_bfmi"] = min(flat)
            result["max_bfmi"] = max(flat)
            for i, val in enumerate(flat, start=1):
                result[f"bfmi_{i}"] = val

    return result

def create_excel_with_highlighting(df: pd.DataFrame, filepath: str):
    """Write ``df`` to ``filepath`` as .xlsx and paint failing metric cells red.

    Only rows with a non-null ``reliability_score`` are inspected. A cell gets
    ``RED_FILL`` when its value is non-null and fails its ``THRESHOLDS`` rule:

    * ``max_r_hat``          >= r_hat
    * ``min_ess_bulk``       <  ess_bulk
    * ``min_ess_tail``       <  ess_tail
    * ``n_divergent``        >  divergences
    * every ``bfmi_*`` column <= bfmi
    * ``prop_high_pareto_k`` >  pareto_k
    """
    df.to_excel(filepath, index=False)

    from openpyxl import load_workbook
    workbook = load_workbook(filepath)
    sheet = workbook.active

    # Column name -> 1-based Excel column index.
    column_of = {name: idx for idx, name in enumerate(df.columns, start=1)}

    # (column, "value fails" predicate) pairs; thresholds are read at call time.
    checks = [
        ("max_r_hat", lambda v: v >= THRESHOLDS["r_hat"]),
        ("min_ess_bulk", lambda v: v < THRESHOLDS["ess_bulk"]),
        ("min_ess_tail", lambda v: v < THRESHOLDS["ess_tail"]),
        ("n_divergent", lambda v: v > THRESHOLDS["divergences"]),
    ]
    checks.extend(
        (name, lambda v: v <= THRESHOLDS["bfmi"])
        for name in df.columns if name.startswith("bfmi_")
    )
    checks.append(("prop_high_pareto_k", lambda v: v > THRESHOLDS["pareto_k"]))

    for offset in range(len(df)):
        record = df.iloc[offset]
        if not pd.notnull(record["reliability_score"]):
            continue
        excel_row = offset + 2  # Excel rows are 1-based and row 1 is the header
        for name, fails in checks:
            if name not in column_of:
                continue
            value = record[name]
            if pd.notnull(value) and fails(value):
                sheet.cell(row=excel_row, column=column_of[name]).fill = RED_FILL

    workbook.save(filepath)

# Main experiment per seed
def _write_text_file(path: str, text: str) -> None:
    """Write ``text`` to ``path``, closing the file handle deterministically."""
    with open(path, 'w') as fh:
        fh.write(text)


def _upsert_seed_row(results_path: str, seed: int, row: Dict[str, Any]) -> pd.DataFrame:
    """Return the per-dataset results table with ``row`` stored for ``seed``.

    Any existing row for ``seed`` is replaced. The table is rebuilt from
    ``row`` alone when the CSV is missing, unreadable, or has no ``seed``
    column.
    """
    new_row = pd.DataFrame([row])
    if not os.path.exists(results_path):
        return new_row
    try:
        existing = pd.read_csv(results_path)
    except Exception as err:
        # Unreadable/corrupt CSV: start over, but say so instead of silently
        # discarding the old table.
        print(f"Could not read {results_path} ({err}); starting a new results table")
        return new_row
    if 'seed' not in existing.columns:
        return new_row
    # Drop-and-append instead of ``df.loc[mask] = pd.Series(row)``: that
    # assignment aligns the Series against the row index (not the columns), so
    # the replaced row could be written back as NaN, and keys that were not yet
    # columns were silently dropped.
    kept = existing[existing['seed'] != seed]
    if kept.empty:
        return new_row
    return pd.concat([kept, new_row], ignore_index=True)


def main_experiment(root: str, llm_key: str, llm_name: str,
                    local_llm: infer.Syncode, temperature: float, seed: int, use_nt: bool):
    """Run one seed of the generate -> execute -> diagnose pipeline for every dataset.

    For each entry of ``datas_info`` this writes ``prompt.txt``, ``snippet.txt``
    and ``final_code.py`` (plus ``diagnostics.json`` on success or
    ``error.txt`` on failure) under ``<root>/seed_<seed>/<llm_key>/<dataset>/``.
    It updates the module-level ``reliability_aggregate`` and
    ``token_count_aggregate``. It then upserts one row per dataset into
    ``<root>/<llm_key>/<dataset>/results.csv`` and re-renders ``results.xlsx``.
    Per-dataset errors are logged and do not stop the remaining datasets.

    Args:
        root: Experiment root folder.
        llm_key: Filesystem-safe model key (``/`` replaced by ``_``).
        llm_name: Model identifier, used for tokenisation in the non-NT path.
        local_llm: Loaded Syncode wrapper used for generation.
        temperature: Not used here; sampling temperature is fixed when
            ``local_llm`` is constructed.
        seed: Seed passed to ``set_seed`` and recorded in every output.
        use_nt: Select ``insert_model_code_NT`` instead of ``insert_model_code``.

    Returns:
        Mapping ``dataset name -> result row dict`` for this seed.
    """
    set_seed(seed)
    exp_dir = os.path.join(root, f'seed_{seed}', llm_key)
    os.makedirs(exp_dir, exist_ok=True)

    # dataset name -> result row for this seed
    seed_results = {}

    for entry in datas_info:
        dataset = entry['name']
        ds_dir = os.path.join(exp_dir, dataset)
        os.makedirs(ds_dir, exist_ok=True)

        if dataset not in seed_results:
            seed_results[dataset] = {
                "seed": seed,
                "model": llm_key,
                "dataset": dataset,
                "total_tokens": 0,
                "cumulative_tokens": 0,
                "code_compiles": False,
                "reliability_score": None
            }

        try:
            prompt = build_prompt_generic(entry)
            _write_text_file(os.path.join(ds_dir, 'prompt.txt'), prompt)

            snippet = local_llm.infer(prompt)
            # Normalise to plain text once. ``str()`` of a list is its repr, in
            # which newlines appear as a literal "\n", so the trace-variable
            # regex below could capture e.g. "ntrace" instead of "trace".
            snippet_text = ''.join(snippet) if isinstance(snippet, list) else str(snippet)
            _write_text_file(os.path.join(ds_dir, 'snippet.txt'), snippet_text)

            # Match the sampler call itself, not pm.sample_prior_predictive etc.
            m = re.search(r"(\w+)\s*=\s*pm\.sample\s*\(", snippet_text)
            trace_var = m.group(1) if m else 'trace'

            if use_nt:
                full_code, tk_count = insert_model_code_NT(
                    entry['template_code'], snippet_text, trace_var, local_llm
                )
            else:
                full_code, tk_count = insert_model_code(
                    entry['template_code'], snippet_text, trace_var, llm_name
                )

            _write_text_file(os.path.join(ds_dir, 'final_code.py'), full_code)
            seed_results[dataset]["total_tokens"] = tk_count

            compiled, out = run_pymc_code(full_code)
            seed_results[dataset]["code_compiles"] = compiled

            # Presumably ``out`` maps names defined by the executed code to
            # their values — TODO confirm against run_pymc_code.
            if compiled and isinstance(out, dict) and trace_var in out:
                rel, diag = check_model_reliability(out[trace_var])

                reliability_aggregate.setdefault(dataset, {}).setdefault(llm_key, {})[seed] = {
                    'reliability_score': rel,
                    'elpd': extract_xarray_value(diag.get('elpd_loo'))
                }

                seed_results[dataset]["reliability_score"] = rel
                # Flattened metrics may also carry "reliability_score" and
                # then take precedence, as before.
                seed_results[dataset].update(flatten_diagnostics(diag))

                diag_to_save = convert_np_types(diag)
                diag_to_save["reliability_score"] = rel
                with open(os.path.join(ds_dir, 'diagnostics.json'), 'w') as f:
                    json.dump(diag_to_save, f, indent=2)

            store_token_count_aggregate(
                seed, dataset, llm_key, tk_count, token_count_aggregate
            )

        except Exception as e:
            # Log and continue with the next dataset.
            error_msg = f"Error processing {dataset} for {llm_key} seed {seed}: {str(e)}"
            print(error_msg)
            _write_text_file(os.path.join(ds_dir, 'error.txt'), error_msg)

    for dataset in seed_results:
        cumulative = token_count_aggregate.get(dataset, {}).get(llm_key, {}).get('cumulative', {}).get(seed, 0)
        seed_results[dataset]["cumulative_tokens"] = cumulative

    # One results table per (model, dataset), one row per seed.
    for dataset, results in seed_results.items():
        ds_results_dir = os.path.join(root, llm_key, dataset)
        os.makedirs(ds_results_dir, exist_ok=True)

        results_path = os.path.join(ds_results_dir, 'results.csv')
        df = _upsert_seed_row(results_path, seed, results)

        # Critical columns first (added as empty if missing), the rest sorted.
        essential_cols = ["seed", "reliability_score", "total_tokens", "cumulative_tokens",
                          "code_compiles", "elpd_loo"]
        for col in essential_cols:
            if col not in df.columns:
                df[col] = None
        remaining_cols = [col for col in df.columns if col not in essential_cols]
        df = df[essential_cols + sorted(remaining_cols)]

        df = df.sort_values("seed")

        df.to_csv(results_path, index=False)
        create_excel_with_highlighting(df, os.path.join(ds_results_dir, 'results.xlsx'))

    return seed_results

# Combined multi-seed pipeline (modified to load/unload one model at a time)
def _json_default(obj):
    """``json.dump`` fallback converting NumPy scalars/arrays to builtins."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _leaderboard_sort_key(row: Dict[str, Any]) -> Tuple[Any, Any]:
    """Rank by reliability score, then ELPD; missing values get sentinels.

    Explicit ``is None`` checks are used because ``x or -1`` also turned a
    genuine reliability score of 0 (or an ELPD of 0.0) into "missing".
    """
    rel = row['reliability_score']
    elpd = row['elpd_loo']
    return (-1 if rel is None else rel,
            -float('inf') if elpd is None else elpd)


def _build_token_budget(buffer: float = 1.2) -> Dict[str, Dict[str, int]]:
    """Return ``{model: {dataset: int(max cumulative tokens * buffer)}}``."""
    budget = {}
    for dataset, models in token_count_aggregate.items():
        for model, data in models.items():
            if 'cumulative' not in data:
                continue
            cumulative = data['cumulative']
            peak = max(cumulative.values()) if cumulative else 0
            budget.setdefault(model, {})[dataset] = int(peak * buffer)
    return budget


def _write_model_reports(root: str, llm_key: str, num_seeds: int, datasets: List[str]) -> str:
    """Write ``leaderboard.json`` and ``progress_<dataset>.json`` for one model.

    Returns the model's report folder ``<root>/<llm_key>``.
    """
    lm_folder = os.path.join(root, llm_key)
    os.makedirs(lm_folder, exist_ok=True)

    lb = []
    for s in range(1, num_seeds + 1):
        for ds in datasets:
            met = reliability_aggregate.get(ds, {}).get(llm_key, {}).get(s)
            cum = token_count_aggregate.get(ds, {}).get(llm_key, {}).get('cumulative', {}).get(s)
            code_path = os.path.join(root, f'seed_{s}', llm_key, ds, 'final_code.py')
            lb.append({
                'seed': s,
                'dataset': ds,
                'reliability_score': met['reliability_score'] if met else None,
                'elpd_loo': met['elpd'] if met else None,
                'cumulative_tokens': cum,
                'code_path': code_path if os.path.exists(code_path) else None,
            })
    lb.sort(key=_leaderboard_sort_key, reverse=True)
    with open(os.path.join(lm_folder, 'leaderboard.json'), 'w') as f:
        json.dump(lb, f, indent=2, default=_json_default)

    # Reliability per seed alongside the running token total.
    for ds in datasets:
        cum_map = token_count_aggregate.get(ds, {}).get(llm_key, {}).get('cumulative', {})
        prog = [
            {
                'seed': s,
                'cumulative_tokens': cum_map[s],
                'reliability_score': reliability_aggregate.get(ds, {})
                                                          .get(llm_key, {})
                                                          .get(s, {})
                                                          .get('reliability_score'),
            }
            for s in sorted(cum_map)
        ]
        with open(os.path.join(lm_folder, f'progress_{ds}.json'), 'w') as pf:
            json.dump(prog, pf, indent=2, default=_json_default)

    return lm_folder


def run_multiple_seeds(num_seeds: int=10, temperature: float=0.3, model_size: str='medium', use_nt: bool=False):
    """Run every model of ``model_size`` over seeds ``1..num_seeds`` and write reports.

    Models are loaded one at a time and unloaded afterwards, even if a seed
    loop is interrupted by an unexpected exception. Afterwards this writes
    ``aggregated_reliability.json``, ``aggregated_token_count.json`` and
    ``token_budget.json`` (max cumulative tokens + 20%) to the experiment root,
    plus ``leaderboard.json`` / ``progress_<dataset>.json`` per model.

    Args:
        num_seeds: Number of seeds (1-based) per model.
        temperature: Sampling temperature given to ``infer.Syncode``.
        model_size: ``'small'`` or ``'medium'``.
        use_nt: Use ``insert_model_code_NT`` for code insertion.

    Returns:
        ``{dataset: {llm_key: {seed: result_row}}}`` for every seed that
        completed (previously built but discarded).

    Raises:
        ValueError: if ``model_size`` is not ``'small'`` or ``'medium'``.
    """
    # Validate before creating any output folders.
    if model_size == 'small':
        llm_list = SMALL_MODELS
    elif model_size == 'medium':
        llm_list = MEDIUM_MODELS
    else:
        raise ValueError(f"Invalid model size: {model_size}. Must be 'small' or 'medium'")

    root = get_new_experiment_folder('results', model_size)
    with open(os.path.join(root, 'desc.txt'), 'w') as desc:
        desc.write(f'Seeds={num_seeds},temp={temperature},model_size={model_size},use_nt={use_nt}')

    all_results = {}

    for llm in llm_list:
        llm_key = llm.replace('/', '_')
        print(f"\n=== Loading model {llm_key} onto GPU ===")
        try:
            local_llm = infer.Syncode(
                mode='original',
                model=llm,
                do_sample=True,
                temperature=temperature,
                grammar='syncode/parsers/grammars/python_grammar.lark',
                device='cuda',
                max_new_tokens=400
            )
        except Exception as e:
            print(f"Error loading model {llm_key}: {str(e)}")
            continue  # try the next model

        try:
            for s in range(1, num_seeds + 1):
                try:
                    seed_results = main_experiment(
                        root, llm_key, llm, local_llm, temperature, s, use_nt
                    )
                except Exception as e:
                    print(f"[Seed {s}][{llm_key}] ERROR: {str(e)}")
                    continue
                for dataset, result in seed_results.items():
                    all_results.setdefault(dataset, {}).setdefault(llm_key, {})[s] = result
        finally:
            print(f"=== Unloading model {llm_key} from GPU ===")
            try:
                local_llm.model.cpu()
            except Exception:
                # Best effort: NOTE(review) the Syncode ``model`` attribute may
                # not expose ``.cpu()``; dropping the reference below still
                # frees it.
                pass
            del local_llm
            # Collect garbage first so freed tensors can be released by the
            # CUDA caching allocator.
            gc.collect()
            torch.cuda.empty_cache()

    save_reliability_aggregate(
        reliability_aggregate,
        os.path.join(root, 'aggregated_reliability.json')
    )
    save_token_count_aggregate(
        token_count_aggregate,
        os.path.join(root, 'aggregated_token_count.json')
    )

    token_budget = _build_token_budget(1.2)  # 20% buffer
    budget_path = os.path.join(root, 'token_budget.json')
    with open(budget_path, 'w') as f:
        json.dump(token_budget, f, indent=2, default=_json_default)
    print(f"Generated token budget saved to {budget_path}")

    datasets = [e['name'] for e in datas_info]
    for llm in llm_list:
        llm_key = llm.replace('/', '_')
        lm_folder = _write_model_reports(root, llm_key, num_seeds, datasets)
        print(f"Wrote files for {llm_key} in {lm_folder}")

    print('All done.')
    return all_results

# Module-level accumulators shared by main_experiment and run_multiple_seeds:
#   reliability_aggregate[dataset][llm_key][seed] -> {'reliability_score': ..., 'elpd': ...}
#   token_count_aggregate[dataset][llm_key]       -> {'seeds': {seed: tokens},
#                                                     'cumulative': {seed: running total}}
reliability_aggregate: Dict[str, Any] = dict()
token_count_aggregate: Dict[str, Any] = dict()

if __name__ == '__main__':
    def _parse_cli() -> argparse.Namespace:
        """Parse the command-line options for a multi-seed run."""
        cli = argparse.ArgumentParser(
            description='Run experiments with different model sizes and NT option'
        )
        cli.add_argument('--model-size', type=str, choices=['small', 'medium'],
                         default='medium',
                         help='Size of models to use (small or medium)')
        cli.add_argument('--nt', action='store_true',
                         help='Use NT version of code insertion')
        cli.add_argument('--num-seeds', type=int, default=config['num_seeds'],
                         help='Number of seeds to run')
        cli.add_argument('--temperature', type=float, default=config['temperature'],
                         help='Temperature for sampling')
        return cli.parse_args()

    options = _parse_cli()
    run_multiple_seeds(
        num_seeds=options.num_seeds,
        temperature=options.temperature,
        model_size=options.model_size,
        use_nt=options.nt,
    )
