import preproc_warning
import re
import tqdm
import os
import multiprocessing
from peft import PeftModel

# Dataset prefixes to simulate: CSV file names from ``input_dir`` without the
# ".csv" extension. Entries can be pinned by hand here; every file whose name
# matches ``filename_pattern`` is appended further down.
data_prefixes = [
    # "20250221_162117_The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country_01JMMKPVR985AQZE97VA1PVWZC"
    # "20241206_172559_Regular_fasting_will_not_improve_your_health_01JEEEKX412AWG9A1KMSSFC8SD"
    # "20241015_020620_Economic_growth_always_occurs_when_taxes_are_lowered_01JA6XECGBRVBCQC7GF94VZQQ8",
    # "20241028_153911_Angels_are_real_01JB9V4TTHV4FRSNK025CW4X3N",
    # "20241028_153927_A__body_cleanse,__in_which_you_consume_only_particular_kinds_of_nutrients_over_1-3_days,_helps_your_body_to_eliminate_toxins_01JB9V4TTHV4FRSNK02H14T0X8",
    # "20241028_153951_Regular_fasting_will_not_improve_your_health_01JB9V4TTHV4FRSNK01WCDRM0E"
]

# When true, an "error.log" left over from a previous run is deleted at start-up.
error_stop = True

# With the "spawn" start method every worker process re-imports this module,
# so the clean-up is restricted to the launching process; otherwise a worker
# that starts late could delete a log written during the current run.
# Removing directly (instead of exists() + remove()) also avoids a race when
# several processes reach this point at the same time.
if error_stop and multiprocessing.current_process().name == "MainProcess":
    try:
        os.remove("error.log")
    except FileNotFoundError:
        pass  # Nothing to clean up.

input_dir = os.path.join("../../data", "processed_data")
# filename_pattern = r"202412(06|11)_.*\.csv"
# filename_pattern = r"202(4|5)(\d{4})_.*\.csv"
filename_pattern = r"2025(03|04|05|06).*\.csv"

# List of invalid files to exclude
invalid_files = [
    "20250307_005857_Regular_fasting_will_improve_your_health_01JNQ0VASHPVHHPWEMS9ZW45MK.csv",
    "20250221_162336_The_US_deficit_increased_after_President_Obama_was_elected_01JMMKSXQ28HBP1PQ0Z4FJ86V7.csv",
    "20250407_193709_The_position_of_the_planets_at_the_time_of_your_birth_can_influence_your_personality_01JR8T86G0GKQW3ZX1HY14N3D2.csv",
    "20250403_200510_The_US_deficit_increased_after_President_Obama_was_elected_01JQYJJF8919Z9P7QC2FXAZR1R.csv"
]

# fullmatch (rather than match) anchors the pattern at the end of the name, so
# only names that really end in ".csv" are selected and stripping the
# extension below cannot cut into an unrelated suffix such as ".csv.bak".
list_files = [
    f
    for f in os.listdir(input_dir)
    if re.fullmatch(filename_pattern, f) is not None
    and f not in invalid_files
    and os.path.isfile(os.path.join(input_dir, f))
]
# list_files = [f for f in os.listdir(input_dir) if f.endswith('.csv') and os.path.isfile(os.path.join(input_dir, f))]

# For testing only: randomly pick 20 files for testing
# import random
# random.seed(42)  # For reproducibility
# if len(list_files) > 20:
#     list_files = random.sample(list_files, 20)

# A data prefix is the file name without its ".csv" extension.
data_prefixes.extend(os.path.splitext(f)[0] for f in list_files)

# print(len(data_prefixes))
# data_prefixes = data_prefixes[:1]  # For testing
# list_already_processed = os.listdir("../../result/simulation")
# Create a new list excluding already processed items instead of modifying during iteration
# data_prefixes = [prefix for prefix in data_prefixes if prefix not in list_already_processed]

# print(len(list_already_processed))
# print(len(data_prefixes))
# exit()

# Models to simulate, one ``(api_type, model_name)`` tuple per entry; each
# entry is handed to execute_model_task() as a single pool task.
# An api_type of "huggingface" makes execute_task() load the model locally
# through get_model(); any other api_type passes model=None/tokenizer=None to
# the simulation. Uncomment entries to enable them.
models = [
    ("openai", "gpt-4o-mini-2024-07-18"),
    # ("openai", "ft:gpt-4o-mini-2024-07-18:camer:round-split-valid:BRTJQLtG"),
    
    # ("huggingface", "mistralai/Mistral-7B-v0.1"),
    # ("huggingface", "HuggingFaceH4/mistral-7b-sft-beta"),
    # ("huggingface", "HuggingFaceH4/zephyr-7b-beta"),
    # ("huggingface", "mistralai/Mistral-7B-Instruct-v0.1"),
    
    # ("huggingface", "meta-llama/Meta-Llama-3-8B"),
    # ("huggingface", "RLHFlow/LLaMA3-SFT"),
    # ("huggingface", "RLHFlow/LLaMA3-iterative-DPO-final"),
    # ("huggingface", "meta-llama/Meta-Llama-3-8B-Instruct"),
    
    # ("huggingface", "meta-llama/Llama-3.1-8B"),
    # ("huggingface", "allenai/Llama-3.1-Tulu-3-8B-SFT"),
    # ("huggingface", "allenai/Llama-3.1-Tulu-3-8B-DPO"),
    # ("huggingface", "allenai/Llama-3.1-Tulu-3-8B"),
    # ("huggingface", "meta-llama/Llama-3.1-8B-Instruct"),
    # Names starting with "mini-twitter/" are resolved by get_model() to a
    # directory under ../../finetuned_models.
    # ("huggingface", "mini-twitter/ft:Llama-3.1-8B-Instruct-250528:round-split-valid-5epochs"),
    
    # ("huggingface", "allenai/OLMo-2-1124-7B"),
    # ("huggingface", "allenai/OLMo-2-1124-7B-SFT"),
    # ("huggingface", "allenai/OLMo-2-1124-7B-DPO"),
    # ("huggingface", "allenai/OLMo-2-1124-7B-Instruct"),
    
    # ("huggingface", "mini-twitter/Llama-3.1-Tulu-3-8B-MT-DDPO-0129")
]


# Printed before the remaining imports (simulation module, transformers,
# torch) are loaded.
print("Loading Dependencies")

import simulate_conversation as sc
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import pathlib
import typing


# Module-level cache of loaded Hugging Face models, keyed by model name.
# Each value is ``(model, tokenizer, pending_prefixes)``: get_model() removes
# the requested data prefix from ``pending_prefixes`` on every call, and
# execute_task() deletes the entry once no prefixes remain.
model_cache: typing.Dict[str, typing.Tuple[AutoModelForCausalLM, AutoTokenizer, typing.List[str]]] = {}


def get_model(model_name, data_prefix):
    """Return the cached Hugging Face model and tokenizer for ``model_name``.

    The first call for a given model loads it and records a private list of
    the data prefixes that still have to be processed with it. Every call,
    including the first, removes ``data_prefix`` from that list.

    Args:
        model_name: Model identifier; see ``_load_model_and_tokenizer`` for
            how the name is resolved.
        data_prefix: Data prefix about to be processed with this model.

    Returns:
        A ``(model, tokenizer, remaining)`` tuple, where ``remaining`` is the
        number of prefixes still pending for this model after this call.

    Raises:
        ValueError: If ``data_prefix`` is not in the model's pending list
            (for example when it is requested twice).
    """
    if model_name not in model_cache:
        model, tokenizer = _load_model_and_tokenizer(model_name)
        # Store a snapshot, not the global list itself: removing entries from
        # the shared ``data_prefixes`` would leave it empty for any model
        # processed later in the same process.
        model_cache[model_name] = (model, tokenizer, list(data_prefixes))

    model, tokenizer, pending_prefixes = model_cache[model_name]
    pending_prefixes.remove(data_prefix)
    return model, tokenizer, len(pending_prefixes)


def _load_model_and_tokenizer(model_name):
    """Load a 4-bit quantized, compiled model and its tokenizer.

    Name resolution:
      * ``mini-twitter/<name>`` is loaded from ``../../finetuned_models/<name>``.
      * A name starting with ``../`` is treated as a PEFT adapter path that is
        applied on top of ``meta-llama/Llama-3.1-8B-Instruct``; the tokenizer
        is taken from that base model.
      * Anything else is passed to ``from_pretrained`` unchanged.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    is_adapter_path = model_name.startswith("../")
    base_model_path = 'meta-llama/Llama-3.1-8B-Instruct'

    if model_name.startswith("mini-twitter/"):
        model_path = pathlib.Path(os.path.join("../../finetuned_models", model_name.split("/")[1]))
    else:
        model_path = model_name

    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path if is_adapter_path else model_path,
        trust_remote_code=True,
    )

    model_family, _ = sc.Agent.get_model_family(model_name)
    # 4-bit NF4 weights with double quantization; computation in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    if is_adapter_path:
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map="auto",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).eval()
        model = PeftModel.from_pretrained(base_model, model_path)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            quantization_config=quantization_config,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        ).eval()

    model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
    if model_family == "mistral":
        # NOTE(review): presumably the Mistral generation config ships without
        # a pad token, so EOS is reused for padding - confirm.
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    return model, tokenizer


def execute_task(api_type, model_name, data_prefix):
    """Run the simulation for one (model, data prefix) pair.

    For ``api_type == "huggingface"`` the model and tokenizer are obtained
    from ``get_model``; for every other API type both are passed on as
    ``None``. The topic is parsed out of ``data_prefix``, which is expected to
    look like ``<8 digits>_<6 digits>_<topic>_<26-character id>``.

    A version is skipped when its output CSV already exists or when the data
    prefix contains one of the excluded timestamps. Once the last pending
    prefix of a Hugging Face model has been handled, the model is dropped
    from ``model_cache``.
    """
    uses_local_model = api_type == "huggingface"
    if uses_local_model:
        model, tokenizer, remaining_count = get_model(model_name, data_prefix)
    else:
        model = tokenizer = None

    # Topic text: underscores become spaces, runs of spaces are collapsed.
    raw_topic = re.search(r'\d{8}_\d{6}_(.*)_.{26}', data_prefix).group(1)
    topic = re.sub(r' +', ' ', raw_topic.replace('_', ' '))

    # Runs identified by these timestamps are never simulated.
    excluded_timestamps = (
        '20250422_211001',
        '20250420_204143',
        '20250426_023611',
        '20250425_030639',
        '20250430_153640',
    )

    # for version in ["v0", "v1", "v2"]:
    for version in ["v2"]:
        archive_name = sc.Agent.get_model_archive_name(model_name)
        output_path = os.path.join("../../result/simulation", data_prefix, archive_name, f"simulation-{version}-ablation.csv")

        # Do not redo work whose result file is already on disk.
        if os.path.exists(output_path):
            print(f"Skipping D: {data_prefix}, M: [{api_type}] {model_name}, V: {version} - output file already exists at {output_path}")
            continue

        print(f"D: {data_prefix}\nM: [{api_type}] {model_name}\nV: {version}")
        if any(stamp in data_prefix for stamp in excluded_timestamps):
            continue

        sc.main(api_type=api_type, model_name=model_name, data_prefix=data_prefix, topic=topic, version=version, model=model, tokenizer=tokenizer)

    # Release the cached model after its last data prefix.
    if uses_local_model and remaining_count == 0:
        del model_cache[model_name]
        del model, tokenizer


def execute_model_task(model_info):
    """Run ``execute_task`` for every data prefix with one model.

    Args:
        model_info: An ``(api_type, model_name)`` tuple, i.e. one entry of
            ``models``.
    """
    api_type, model_name = model_info
    print(f"Model Task: [{api_type}] {model_name}")

    # Iterate over a snapshot so that changes made to ``data_prefixes`` while
    # the tasks run cannot disturb the loop.
    pending_prefixes = list(data_prefixes)
    for prefix in pending_prefixes:
        execute_task(api_type, model_name, prefix)


if __name__ == "__main__":
    # Script entry point: run every configured model in a worker process.
    print("Simulation Pipeline")
    
    # Use multiprocessing
    # "spawn" starts each worker in a fresh interpreter that re-imports this
    # module instead of inheriting the parent's state.
    ctx = multiprocessing.get_context("spawn")

    # Each model is exactly one pool task, so more workers than models would
    # only start idle processes that still import torch/transformers.
    # Keep the previous cap of 10; Pool() requires at least one process.
    worker_count = max(1, min(10, len(models)))

    # maxtasksperchild=1 replaces a worker after it has finished one model.
    with ctx.Pool(processes=worker_count, maxtasksperchild=1) as pool:
        for _ in tqdm.tqdm(pool.imap_unordered(execute_model_task, models), total=len(models), desc="Simulation Models"):
            pass
    
    # Use single process
    # for model_info in models:
    #     execute_model_task(model_info)

    print("Simulation Complete.")