#!/usr/bin/env python3
import os
import json
import torch
import gc
from transformers import AutoTokenizer
from typing import Tuple, Union
import syncode.infer as infer
import sys
import random
import numpy as np
import argparse
sys.path.append("..")
from commons.data_pymc import datas_info, build_prompt_generic
from commons.utils import run_pymc_code

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    default generator with ``seed``. When CUDA is available, the generators
    of all CUDA devices are seeded as well.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def insert_model_code(template: str, raw_snippet: Union[str, list],
                      trace_var: str) -> str:
    """Splice a generated model snippet into a template, truncating at sampling.

    The snippet is stripped of Markdown code fences, each non-blank line is
    re-indented with a single tab, and everything after the first line that
    contains ``pm.sample`` is discarded (that line itself is kept). The
    result is inserted right after every template line that starts with
    ``with pm.Model() as m:``. All template lines are left-stripped, and a
    final tab-indented ``summary = az.summary(<trace_var>)`` line is appended.

    Args:
        template: Source text containing a ``with pm.Model() as m:`` line.
        raw_snippet: Generated code, either a string or a list of string
            pieces that are concatenated without a separator.
        trace_var: Name of the variable passed to ``az.summary``.

    Returns:
        The assembled source code as a single string.

    Raises:
        ValueError: If the template has no ``with pm.Model() as m:`` line.
    """
    if isinstance(raw_snippet, list):
        raw_snippet = ''.join(raw_snippet)

    body = []
    for line in raw_snippet.replace('```', '').splitlines():
        # Blank lines stay empty; original indentation is replaced by one tab.
        body.append('\t' + line.lstrip() if line.strip() else '')
        if 'pm.sample' in line:
            # Keep the sampling call but drop whatever the model wrote after it.
            break
    snippet_code = '\n'.join(body)

    out_lines, inserted = [], False
    for template_line in template.splitlines():
        out_lines.append(template_line.lstrip())
        if template_line.strip().startswith('with pm.Model() as m:'):
            out_lines.append(snippet_code)
            inserted = True
    if not inserted:
        raise ValueError("'with pm.Model() as m:' not found in template")

    out_lines.append(f'\tsummary = az.summary({trace_var})')
    return '\n'.join(out_lines)


def insert_model_code_NT(template: str, raw_snippet: Union[str, list], trace_var: str) -> str:
    """Splice a generated model snippet into a template without truncating it.

    Unlike ``insert_model_code``, the whole snippet is kept. Markdown code
    fences are removed, every non-blank snippet line is re-indented with a
    single tab, and the result is placed after each template line starting
    with ``with pm.Model() as m:``. Template lines are left-stripped and a
    tab-indented ``summary = az.summary(<trace_var>)`` line is appended last.

    Args:
        template: Source text containing a ``with pm.Model() as m:`` line.
        raw_snippet: Generated code as a string, or a list of string pieces
            that are concatenated without a separator.
        trace_var: Name of the variable passed to ``az.summary``.

    Returns:
        The assembled source code as a single string.

    Raises:
        ValueError: If the template has no ``with pm.Model() as m:`` line.
    """
    marker = 'with pm.Model() as m:'

    snippet_text = ''.join(raw_snippet) if isinstance(raw_snippet, list) else raw_snippet

    # Blank lines stay empty; everything else gets exactly one leading tab.
    indented = [
        '\t' + row.lstrip() if row.strip() else ''
        for row in snippet_text.replace('```', '').splitlines()
    ]
    snippet_code = '\n'.join(indented)

    assembled = []
    found_marker = False
    for row in template.splitlines():
        assembled.append(row.lstrip())
        if row.strip().startswith(marker):
            assembled.append(snippet_code)
            found_marker = True

    if not found_marker:
        raise ValueError("'with pm.Model() as m:' not found in template")

    assembled.append(f'\tsummary = az.summary({trace_var})')
    return '\n'.join(assembled)


def run_experiment(temperature, num_seeds=10, use_grammar=False, model_size='medium', use_nt=False):
    """Generate, save and execute model code for every model/seed/dataset.

    For each model in the selected size group, a SynCode instance is created,
    and for every seed in ``1..num_seeds`` and every entry of ``datas_info`` a
    snippet is generated, spliced into the entry's template, written to
    ``generated_code.py`` and run through ``run_pymc_code``. Per-dataset
    success counts and token counts are written to ``results.json`` under
    ``results/exec/<grammar|original>/<model_size>[-NT]/<temperature>/``.

    Args:
        temperature: Sampling temperature passed to SynCode; also used as the
            name of the results sub-directory.
        num_seeds: Number of seeds to run, numbered from 1.
        use_grammar: If true, SynCode runs in ``grammar_strict`` mode,
            otherwise in ``original`` mode.
        model_size: ``'medium'`` selects the medium model list; any other
            value selects the small list.
        use_nt: If true, snippets are inserted with ``insert_model_code_NT``
            instead of ``insert_model_code``.
    """
    # Select models based on size
    if model_size == 'medium':
        models = [
            "meta-llama/Meta-Llama-3-8B", "google/codegemma-7b",
            "Qwen/Qwen2.5-Coder-7B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
        ]
    else:  # small
        models = ["microsoft/Phi-3.5-mini-instruct", "Qwen/Qwen2.5-Coder-3B"]

    execution_counts = {}
    total_executions = 0
    token_counts = {}

    # Results directory layout encodes the configuration of this run.
    base_path = os.path.join('results/exec', 'grammar' if use_grammar else 'original')
    base_path = os.path.join(base_path, f'{model_size}-NT' if use_nt else model_size)
    temp_results_dir = os.path.join(base_path, f'{temperature}')
    os.makedirs(temp_results_dir, exist_ok=True)

    # The NT flag only changes how the snippet is spliced into the template.
    build_full_code = insert_model_code_NT if use_nt else insert_model_code

    for llm in models:
        llm_key = llm.replace('/', '_')
        print(f"\n=== Loading model {llm_key} onto GPU ===")

        # Bound up front so the cleanup in ``finally`` can tell whether the
        # model was actually loaded.
        local_llm = None
        try:
            local_llm = infer.Syncode(
                mode='grammar_strict' if use_grammar else 'original',
                model=llm,
                do_sample=True,
                temperature=temperature,
                grammar='syncode/parsers/grammars/python_grammar.lark',
                device='cuda',
                max_new_tokens=400
            )

            execution_counts[llm_key] = {}
            token_counts[llm_key] = {}

            for seed in range(1, num_seeds + 1):
                print(f"\nRunning seed {seed} for {llm_key}")
                set_seed(seed)

                seed_key = f"seed_{seed}"
                execution_counts[llm_key][seed_key] = {}
                token_counts[llm_key][seed_key] = {}

                for entry in datas_info:
                    dataset = entry['name']
                    execution_counts[llm_key][seed_key][dataset] = 0

                    # Create the output directory before anything can fail, so
                    # the error handler below always writes error.txt into the
                    # directory of the combination that actually failed.
                    save_dir = os.path.join(temp_results_dir, llm_key, dataset, seed_key)
                    os.makedirs(save_dir, exist_ok=True)

                    try:
                        prompt = build_prompt_generic(entry)
                        snippet = local_llm.infer(prompt)
                        token_counts[llm_key][seed_key][dataset] = local_llm.model.total_tokens

                        full_code = build_full_code(entry['template_code'], snippet, 'trace')

                        # Save the generated code
                        with open(os.path.join(save_dir, 'generated_code.py'), 'w') as f:
                            f.write(full_code)

                        compiled, _ = run_pymc_code(full_code)
                        if compiled:
                            execution_counts[llm_key][seed_key][dataset] += 1
                            total_executions += 1

                    except Exception as e:
                        # Best-effort: record the failure and move on to the
                        # next dataset instead of aborting the whole model.
                        print(f"Error with {dataset} for {llm_key} seed {seed}: {str(e)}")
                        with open(os.path.join(save_dir, 'error.txt'), 'w') as f:
                            f.write(str(e))

        except Exception as e:
            print(f"Error loading model {llm_key}: {str(e)}")

        finally:
            # Release the model even when the run above failed part-way, so
            # the next model does not have to share the GPU with this one.
            if local_llm is not None:
                print(f"=== Unloading model {llm_key} from GPU ===")
                try:
                    local_llm.model.cpu()
                except Exception:
                    # Moving to CPU is only an optimisation; dropping the
                    # reference below is what actually frees the model.
                    pass
                local_llm = None
                torch.cuda.empty_cache()
                gc.collect()

    # Save results for this temperature
    with open(os.path.join(temp_results_dir, 'results.json'), 'w') as f:
        json.dump({
            'execution_counts': execution_counts,
            'total_executions': total_executions,
            'token_counts': token_counts
        }, f, indent=2)

    print(f"\nTotal successful executions at temperature {temperature}: {total_executions}/{len(models) * num_seeds * len(datas_info)}")

if __name__ == '__main__':
    # Command-line options choose the generation mode, the model group and
    # the snippet insertion variant.
    cli = argparse.ArgumentParser()
    cli.add_argument('--grammar', action='store_true', help='Use grammar-based generation')
    cli.add_argument('--model-size', choices=['small', 'medium'], default='medium', help='Size of models to use')
    cli.add_argument('--nt', action='store_true', help='Use NT version of code insertion')
    options = cli.parse_args()

    # Run the full experiment once per sampling temperature.
    temperatures = [0.3, 0.4]
    for temperature in temperatures:
        print(f"\nRunning experiment with temperature {temperature}")
        run_experiment(
            temperature,
            num_seeds=10,
            use_grammar=options.grammar,
            model_size=options.model_size,
            use_nt=options.nt,
        )
