import os
import sys
import json
import random
import subprocess
import tempfile
import numpy as np
import torch
import re

from refinegen.base import iterative_refine
from refinegen.itergen.itergen.main import IterGen
# Import the desired checker (or multiple checkers)
from refinegen.checkers.library_specific.ppl_checker import PyMCChecker

# Absolute path of the directory one level above this file's directory.
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# Append it to the import search path once (presumably so the `commons`
# import that follows can be resolved -- confirm against the repo layout).
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from commons.model_pymc import models_info, build_prompt_generic


# Matches the diagnostics line printed by the generated script, capturing a
# well-formed float: optional sign, digits with optional fraction (or a bare
# fraction), and an optional exponent such as "e+03".
_ELPD_PATTERN = re.compile(
    r"ELPD using LOO:\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"
)


def extract_elpd(output_text):
    """Return the ELPD value reported in *output_text*, or None if absent.

    The first occurrence of ``ELPD using LOO: <number>`` is used. Unlike a
    plain ``[\\d.-]+`` character class, the pattern understands scientific
    notation (``-1.5e+03`` is read as ``-1500.0`` rather than ``-1.5``) and
    never captures text that ``float()`` would reject (a lone ``-`` or a
    number followed by a sentence-ending period).

    Args:
        output_text: Captured stdout of the executed script.

    Returns:
        The parsed float, or None when no numeric ELPD line is present or
        when *output_text* is not a string.
    """
    if not isinstance(output_text, str):
        return None
    elpd_match = _ELPD_PATTERN.search(output_text)
    return float(elpd_match.group(1)) if elpd_match else None

def extract_trace_name(snippet):
    """Return the first key of *snippet* whose value is ``"pymc.sample"``.

    Falls back to the literal name ``'trace'`` when no entry has that value.
    """
    matching_names = (
        name for name, origin in snippet.items() if origin == "pymc.sample"
    )
    return next(matching_names, 'trace')

def run_pymc_code(full_code: str, timeout=None):
    """Execute *full_code* as a standalone script and report the outcome.

    The code is written to ``generated_model.py`` inside a temporary
    directory and run with the same interpreter that runs this module
    (``sys.executable``), so the child sees the same installed packages
    instead of whatever ``python`` happens to be first on PATH.

    Args:
        full_code: Complete Python source to execute.
        timeout: Optional limit in seconds for the child process; None
            (the default) waits indefinitely.

    Returns:
        A ``(success, output)`` tuple. On success ``output`` is the child's
        stdout. On failure it is a description of the error followed by the
        captured stdout and stderr, so the traceback is available to callers
        that persist the output.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        script_path = os.path.join(tmpdir, "generated_model.py")
        with open(script_path, "w") as f:
            f.write(full_code)
        try:
            completed = subprocess.run(
                [sys.executable or "python", script_path],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout,
            )
        except (OSError, subprocess.TimeoutExpired) as e:
            # Interpreter could not be launched, or the script ran too long.
            print("Error running generated code:", e)
            return False, str(e)

    # stderr is captured for the failure report; echo it so warnings and
    # tracebacks from the child remain visible on the console.
    if completed.stderr:
        sys.stderr.write(completed.stderr)

    if completed.returncode != 0:
        error = subprocess.CalledProcessError(completed.returncode, completed.args)
        print("Error running generated code:", error)
        details = [str(error), completed.stdout or "", completed.stderr or ""]
        return False, "\n".join(part for part in details if part)
    return True, completed.stdout

def store_elpd(seed, dataset, llm_model, elpd_value, aggregate_dict):
    """Record *elpd_value* at ``aggregate_dict[dataset][llm_model][seed]``.

    Missing intermediate dictionaries are created in place; an existing
    value for the same seed is overwritten.
    """
    per_dataset = aggregate_dict.setdefault(dataset, {})
    per_model = per_dataset.setdefault(llm_model, {})
    per_model[seed] = elpd_value

def store_total_tokens(seed, dataset, llm_model, total_tokens, aggregate_dict):
    """Record *total_tokens* at ``aggregate_dict[dataset][llm_model][seed]``.

    Missing intermediate dictionaries are created in place; an existing
    value for the same seed is overwritten.
    """
    seeds_for_model = aggregate_dict.setdefault(dataset, {}).setdefault(llm_model, {})
    seeds_for_model[seed] = total_tokens

def insert_model_code(boilerplate: str, raw_snippet, trace_name) -> str:
    """Splice a generated model body into *boilerplate* and append diagnostics.

    Every boilerplate line is emitted with its leading whitespace removed.
    Directly after each line that starts with ``with pm.Model() as m:`` the
    snippet is inserted: triple-backtick fences are deleted, blank lines are
    kept empty, and every other line is left-stripped and prefixed with one
    tab. A tab-indented diagnostics section (R-hat, ESS, LOO/ELPD on
    *trace_name*) is appended at the very end.

    Args:
        boilerplate: Script text containing the ``with pm.Model() as m:`` line.
        raw_snippet: Generated model code, as a string or a list of string
            fragments that are concatenated.
        trace_name: Name of the variable holding the sampling result.

    Returns:
        The assembled script text.

    Raises:
        ValueError: If no boilerplate line starts with the model marker.
    """
    snippet_text = "".join(raw_snippet) if isinstance(raw_snippet, list) else raw_snippet
    snippet_text = snippet_text.replace("```", "")

    # One tab of indentation for real lines, nothing at all for blank ones.
    indented_body = "\n".join(
        "\t" + row.lstrip() if row.strip() else ""
        for row in snippet_text.splitlines()
    )

    assembled = []
    marker_seen = False
    for row in boilerplate.split('\n'):
        assembled.append(row.lstrip())
        if row.strip().startswith("with pm.Model() as m:"):
            assembled.append(indented_body)
            marker_seen = True

    if not marker_seen:
        raise ValueError("'with pm.Model() as m:' not found in boilerplate.")

    diagnostic_rows = [
        "\t# Posterior diagnostics\n",
        f"\tprint(\"Gelman-Rubin statistic (R-hat):\", az.rhat({trace_name}))\n",
        f"\tprint(\"Effective Sample Size (ESS):\", az.ess({trace_name}))\n",
        f"\tloo = az.loo({trace_name})\n",
        f"\tprint(\"ELPD using LOO:\", loo.elpd_loo)\n",
    ]
    assembled.append("".join(diagnostic_rows))

    return "\n".join(assembled)

# Module-level accumulators filled by process_experiment_output and written
# to JSON at the end of run_experiments_across_models.
# Both are nested as {dataset: {llm_model: {seed: value}}}.
elpd_aggregate = {}         # value = ELPD parsed from the executed script's output
total_token_aggregate = {}  # value = total token count reported for the run

def process_experiment_output(seed, dataset, llm_model, exec_output, total_tokens):
    """Record the token count and, when available, the ELPD of one run.

    The token count is always stored in the module-level token aggregate.
    The ELPD is parsed from *exec_output*; if none is found a notice is
    printed and nothing is added to the ELPD aggregate.
    """
    store_total_tokens(seed, dataset, llm_model, total_tokens, total_token_aggregate)

    parsed_elpd = extract_elpd(exec_output)
    if parsed_elpd is None:
        print(f"ELPD not found for seed {seed}, dataset {dataset}, model {llm_model}")
        return
    store_elpd(seed, dataset, llm_model, parsed_elpd, elpd_aggregate)

def save_elpd_aggregate(aggregate_dict, output_path):
    """Write *aggregate_dict* to *output_path* as 2-space-indented JSON.

    The destination file is overwritten and a confirmation line is printed.
    """
    with open(output_path, mode="w") as json_file:
        json.dump(aggregate_dict, json_file, indent=2)
    print(f"Aggregated ELPD data saved to {output_path}")

def save_total_tokens_aggregate(aggregate_dict, output_path):
    """Write *aggregate_dict* to *output_path* as 2-space-indented JSON.

    The destination file is overwritten and a confirmation line is printed.
    """
    with open(output_path, mode="w") as json_file:
        json.dump(aggregate_dict, json_file, indent=2)
    print(f"Aggregated total tokens saved to {output_path}")

def set_seed(seed=0):
    """Seed Python's, NumPy's and PyTorch's random generators with *seed*.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    auto-tuner.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Checker class handed to iterative_refine for every generation run.
checker_cls = PyMCChecker
# Symbol table passed to iterative_refine as its starting `symboltable`
# (presumably maps the alias `pm` to the module name `pymc` -- confirm
# against the checker's expectations).
pymc_symboltable = {'pm': 'pymc'}

def get_next_experiment_folder(base_dir, modelsize='medium', temperature=0.3):
    """Return ``base_dir/<modelsize>/<temperature>``, creating it if needed.

    The directory is created with ``exist_ok=True`` rather than a separate
    existence check, so calling this repeatedly (or from concurrent runs
    sharing the same results tree) cannot fail because another caller
    created the folder in between.

    Args:
        base_dir: Root directory for experiment results.
        modelsize: Name of the model-size bucket used as the first subfolder.
        temperature: Sampling temperature; its ``str()`` form names the
            second subfolder.

    Returns:
        The path of the (now existing) experiment directory.
    """
    new_base_dir = os.path.join(base_dir, modelsize, str(temperature))
    os.makedirs(new_base_dir, exist_ok=True)
    return new_base_dir

# This function runs experiments for a given seed using the already loaded iter_gen.
def run_experiment_for_seed(iter_gen, seed, llm_model, parent_folder):
    """Generate, execute and log code for every entry of ``models_info``.

    All random generators are seeded with *seed*, then for each model entry
    the prompt is built, refined through ``iterative_refine`` with the given
    *iter_gen*, assembled into a full script, executed, and its ELPD and
    token count recorded. Artifacts are written to
    ``parent_folder/seed_<seed>/<model_name>/`` and a text summary to
    ``parent_folder/seed_<seed>/overall_results.txt``.

    Args:
        iter_gen: Already constructed generator object shared across seeds.
        seed: Integer seed for this run.
        llm_model: Identifier of the LLM; ``/`` is replaced by ``_`` when it
            is used as an aggregate key.
        parent_folder: Directory under which the seed folder is created.

    Returns:
        A dict with the keys ``seed``, ``total_codes``, ``compiled_success``,
        ``compiled_failure`` and ``experiment_folder``.
    """

    def _write_text(path, text):
        # Overwrite `path` with `text`.
        with open(path, "w") as handle:
            handle.write(text)

    def _describe(record):
        # One bullet line of the summary file for a single result record.
        return f" - LLM: {record['llm_model']}, Model: {record['model_name']}, Folder: {record['folder']}\n"

    set_seed(seed)
    experiment_folder = os.path.join(parent_folder, f"seed_{seed}")
    os.makedirs(experiment_folder, exist_ok=True)
    _write_text(os.path.join(experiment_folder, "seed.txt"), f"Seed used: {seed}")
    print(f"Storing logs in: {experiment_folder} for seed {seed}")

    overall_results = []
    aggregate_key = llm_model.replace("/", "_")
    for model_entry in models_info:
        model_name = model_entry["name"]
        pair_folder = os.path.join(experiment_folder, model_name)
        os.makedirs(pair_folder, exist_ok=True)

        # Persist the prompt before generation so it is on disk even if a
        # later step fails.
        prompt = build_prompt_generic(model_entry)
        _write_text(os.path.join(pair_folder, f"{model_name}_prompt.txt"), prompt)

        final_content, symboltable, interim_content = iterative_refine(
            prompt=prompt,
            iter_gen=iter_gen,
            checker_cls=checker_cls,
            unit_name='function_call',
            max_iter=35,
            checker_config={},
            symboltable=pymc_symboltable,
            dedent=True,
        )
        _write_text(os.path.join(pair_folder, "interim_output.txt"), interim_content)

        # Assemble the runnable script around the generated model body.
        full_code = insert_model_code(
            prompt, final_content, extract_trace_name(symboltable)
        )
        _write_text(os.path.join(pair_folder, "final_code.py"), full_code)

        compiled, exec_output = run_pymc_code(full_code)
        _write_text(os.path.join(pair_folder, "exec_output.txt"), str(exec_output))

        process_experiment_output(
            seed,
            model_name,
            aggregate_key,
            exec_output,
            iter_gen._metadata['total_tokens'],
        )

        overall_results.append({
            "llm_model": llm_model,
            "model_name": model_name,
            "folder": pair_folder,
            "compiled": compiled,
        })
        print(f"Completed: {model_name} for {llm_model}")

    succeeded = [rec for rec in overall_results if rec["compiled"]]
    failed = [rec for rec in overall_results if not rec["compiled"]]
    total_codes = len(overall_results)
    compiled_success = len(succeeded)
    compiled_failure = total_codes - compiled_success

    report = [
        "=== OVERALL RESULTS SUMMARY ===\n",
        f"Seed: {seed}\n",
        f"Total generated codes    : {total_codes}\n",
        f"Compiled successfully    : {compiled_success}\n",
        f"Compilation failed       : {compiled_failure}\n\n",
    ]
    if compiled_success > 0:
        report.append("Successful compilations:\n")
        report.extend(_describe(rec) for rec in succeeded)
    if compiled_failure > 0:
        report.append("\nFailed compilations:\n")
        report.extend(_describe(rec) for rec in failed)
    _write_text(os.path.join(experiment_folder, "overall_results.txt"), "".join(report))

    print(f"\n=== PIPELINE COMPLETE FOR SEED {seed} ===\n")
    return {
        "seed": seed,
        "total_codes": total_codes,
        "compiled_success": compiled_success,
        "compiled_failure": compiled_failure,
        "experiment_folder": experiment_folder,
    }

###############################################################################
# Running multiple seeds for each LLM model (load each model once)
###############################################################################
def run_experiments_across_models(num_seeds=10, temperature=0.3, modelsize="medium"):
    """Run the full experiment grid: every LLM of a size class x every seed.

    Each LLM is loaded once into an ``IterGen`` instance, used for seeds
    ``1..num_seeds`` and then released. Results go under
    ``results/exec/<modelsize>/<temperature>/``: one folder per LLM, an
    ``aggregated_summary.txt``, and the ELPD / token aggregates as JSON.

    Args:
        num_seeds: Number of seeds to run per model (seeds start at 1).
        temperature: Sampling temperature passed to ``IterGen``.
        modelsize: ``'small'`` selects the small model list; any other value
            selects the medium list.
    """
    multi_seed_folder = get_next_experiment_folder(
        base_dir="results/exec", modelsize=modelsize, temperature=temperature
    )
    print(f"All experiments will be stored under: {multi_seed_folder}")

    # Optional interactive experiment description (currently disabled):
    # expt_description = input("Please enter a description for the experiment: ")
    # description_file = os.path.join(multi_seed_folder, "desc.txt")
    # with open(description_file, "w") as df:
    #     df.write(expt_description)

    small_models = [
        "microsoft/Phi-3.5-mini-instruct",
        "Qwen/Qwen2.5-Coder-3B",
    ]
    medium_models = [
        "meta-llama/Meta-Llama-3-8B",
        "google/codegemma-7b",
        "Qwen/Qwen2.5-Coder-7B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    ]
    if modelsize == 'small':
        models = small_models
    else:
        models = medium_models

    aggregated_summaries = {}
    for llm_model in models:
        print(f"\n=== Processing LLM model: {llm_model} ===")
        llm_folder = os.path.join(multi_seed_folder, llm_model.replace("/", "_"))
        os.makedirs(llm_folder, exist_ok=True)
        per_seed_summaries = aggregated_summaries[llm_model] = []

        # The model is loaded a single time and shared by all seeds.
        iter_gen = IterGen(
            grammar='refinegen/itergen/itergen/syncode/syncode/parsers/grammars/python_grammar.lark',
            # grammar='python',
            model_id=llm_model,
            device='cuda',
            do_sample=True,
            temperature=temperature
        )
        for seed in range(1, num_seeds + 1):
            print(f"\n--- Running seed: {seed} for model: {llm_model} ---")
            per_seed_summaries.append(
                run_experiment_for_seed(iter_gen, seed, llm_model, parent_folder=llm_folder)
            )

        # Drop the reference and release cached GPU memory before the next model.
        del iter_gen
        torch.cuda.empty_cache()

    # Flatten all per-seed summaries to total the counters across models.
    all_runs = [run for runs in aggregated_summaries.values() for run in runs]
    total_codes_all = sum(run["total_codes"] for run in all_runs)
    total_success_all = sum(run["compiled_success"] for run in all_runs)
    total_failure_all = sum(run["compiled_failure"] for run in all_runs)

    aggregated_summary_file = os.path.join(multi_seed_folder, "aggregated_summary.txt")
    summary_rows = [
        "=== AGGREGATED SUMMARY ACROSS MODELS AND SEEDS ===\n",
        f"Number of models: {len(models)}\n",
        f"Number of seeds per model: {num_seeds}\n\n",
        f"Total generated codes (all): {total_codes_all}\n",
        f"Total compiled successfully: {total_success_all}\n",
        f"Total compilation failed   : {total_failure_all}\n",
    ]
    with open(aggregated_summary_file, "w") as summary_handle:
        summary_handle.write("".join(summary_rows))
    print(f"\nAggregated summary saved at: {aggregated_summary_file}")

    # Persist the module-level ELPD and token aggregates next to the summary.
    save_elpd_aggregate(
        elpd_aggregate, os.path.join(multi_seed_folder, "aggregated_elpd.json")
    )
    save_total_tokens_aggregate(
        total_token_aggregate, os.path.join(multi_seed_folder, "total_tokens.json")
    )

if __name__ == "__main__":
    # Positional command-line arguments, all optional:
    #   1: sampling temperature (float, default 0.3)
    #   2: model size class     (str,   default "medium")
    #   3: number of seeds      (int,   default 10)
    cli_args = sys.argv[1:]
    temperature = float(cli_args[0]) if len(cli_args) >= 1 else 0.3
    modelsize = cli_args[1] if len(cli_args) >= 2 else "medium"
    seeds = int(cli_args[2]) if len(cli_args) >= 3 else 10

    run_experiments_across_models(
        num_seeds=seeds,
        temperature=temperature,
        modelsize=modelsize,
    )

