import subprocess
import os
import re
import functools
import time
from .llm_report_aggr import generate_llm_report_aggr
from . import util
# import random

def get_valid_data_prefixes(input_dir=None):
    """Return the data prefixes to evaluate.

    A data prefix is a processed-data CSV file name without its ``.csv``
    extension. The result is any manually pinned prefixes (currently none)
    followed by every regular file in *input_dir* that meets two conditions:
    its whole name matches ``2025(03|04|05|06)...csv``, and it is not on the
    known-invalid list. The scanned names are returned in sorted order.

    Args:
        input_dir: Directory to scan. Defaults to ``../../data/processed_data``,
            resolved relative to the current working directory.

    Returns:
        list[str]: Data prefixes (file names with ``.csv`` stripped).
    """
    # Prefixes evaluated regardless of the directory scan (uncomment to pin).
    data_prefixes = [
        # "20250418_185022_The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country_01JS520FN7Q0VBJHK3Z8HTEH6X",
        # "20250330_032252_The_position_of_the_planets_at_the_time_of_your_birth_can_influence_your_personality_01JQJFH58XWK9X4H5B9C8WG613",
        # "20250403_210428_The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country_01JQYN88BDNXKQQVD39SX8AWV7",
        # '20250407_193243_A_"body_cleanse,"_in_which_you_consume_only_particular_kinds_of_nutrients_over_1-3_days,_helps_your_body_to_eliminate_toxins_01JR8SPMQ4VTKFZ82TQ58S2N97',
        # "20250330_032118_Angels_are_real_01JQJFH2BS7DSH4TZ9CR6C7H91",
        # "20250331_201512_Regular_fasting_will_improve_your_health_01JQPVV1V9FD5ZWH8THKNE27MR",
        # "20250401_201754_Everything_that_happens_can_eventually_be_explained_by_science_01JQSEG2F203R2MB7CYNFQ1CBY",
        # "20250403_200404_The_US_deficit_increased_after_President_Obama_was_elected_01JQYG39C7BGG85M31Q3AWWV0N",
    ]

    if input_dir is None:
        input_dir = os.path.join("../../data", "processed_data")

    # fullmatch (not match): the name must END in ".csv". Otherwise e.g.
    # "...csv.bak" would be accepted and then mangled by the extension strip.
    filename_pattern = re.compile(r"2025(03|04|05|06).*\.csv")

    # Runs known to be unusable. Keep a comma after every entry: adjacent
    # string literals without one are silently concatenated by Python.
    invalid_files = {
        "20250307_005857_Regular_fasting_will_improve_your_health_01JNQ0VASHPVHHPWEMS9ZW45MK.csv",
        "20250221_162336_The_US_deficit_increased_after_President_Obama_was_elected_01JMMKSXQ28HBP1PQ0Z4FJ86V7.csv",
        "20250407_193709_The_position_of_the_planets_at_the_time_of_your_birth_can_influence_your_personality_01JR8T86G0GKQW3ZX1HY14N3D2.csv",
        "20250403_200510_The_US_deficit_increased_after_President_Obama_was_elected_01JQYJJF8919Z9P7QC2FXAZR1R.csv",
        "20250331_202343_The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country_01JQPWA5XP98NBXZYWQRQ3TA73.csv",
        "20250401_040642_The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country_01JQQPKVAARAHTCV41MJK27K5E.csv",
        "20250306_182733_Everything_that_happens_can_eventually_be_explained_by_science_01JNP8FT36R1JZ56F3ZP1H1V1T.csv",
        "20250403_191749_A_\"body_cleanse,\"_in_which_you_consume_only_particular_kinds_of_nutrients_over_1-3_days,_helps_your_body_to_eliminate_toxins_01JQYFRNFBG8RG0JRY47K2CGSP.csv",
        "20250404_195614_Everything_that_happens_can_eventually_be_explained_by_science_01JR144VYTRCJGQ76WYYCFCHXZ.csv",
        "20250403_212439_The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country_01JQYNW3ZWDPFM7ZN7JAJEDSV0.csv",
        # TODO: CHECK
        "20250330_032637_Angels_are_real_01JQJFH5YC773P14ZA6XQ51B28.csv",
    }

    # Sorted so the evaluation order does not depend on os.listdir ordering.
    scanned_files = sorted(
        f for f in os.listdir(input_dir)
        if filename_pattern.fullmatch(f)
        and f not in invalid_files
        and os.path.isfile(os.path.join(input_dir, f))
    )
    data_prefixes.extend(f[:-len(".csv")] for f in scanned_files)

    return data_prefixes


def get_typed_data_prefixes(data_prefixes, t, raw_data_dir=None):
    """Keep only the data prefixes that have a raw CSV in the phase directory for *t*.

    ``t == "breadth"`` selects ``phase_2_breadth_topics``; any other value
    selects ``phase_1_depth_topics``. A file ``<name>.csv`` or
    ``<name>_0.0.1.csv`` anywhere below that directory (searched recursively)
    marks ``<name>`` as belonging to that type.

    Args:
        data_prefixes: Candidate data prefixes.
        t: ``"breadth"`` or (anything else) depth.
        raw_data_dir: Root of the raw data. Defaults to ``../../data/raw_data``,
            resolved relative to the current working directory.

    Returns:
        list: The members of *data_prefixes* found for type *t*, in input order.
    """
    if raw_data_dir is None:
        raw_data_dir = os.path.join("../../data", "raw_data")
    phase_dirname = "phase_2_breadth_topics" if t == "breadth" else "phase_1_depth_topics"

    typed_set = set()
    for _root, _dirs, files in os.walk(os.path.join(raw_data_dir, phase_dirname)):
        for filename in files:
            if not filename.endswith(".csv"):
                continue
            base_name = filename[:-len(".csv")]
            # Remove the "_0.0.1" version tag only when it sits right before
            # ".csv"; str.replace would also strip it from the middle of a name.
            if base_name.endswith("_0.0.1"):
                base_name = base_name[:-len("_0.0.1")]
            typed_set.add(base_name)
    return [dp for dp in data_prefixes if dp in typed_set]

# data_prefixes = get_typed_data_prefixes(data_prefixes, "depth")

# list_already_processed = os.listdir("../../result/eval/human_llm")
# Create a new list excluding already processed items instead of modifying during iteration
# data_prefixes = [prefix for prefix in data_prefixes if prefix not in list_already_processed]

# random sample 5 data prefixes with seed 42
# random.seed(42)
# data_prefixes = random.sample(data_prefixes, 5)  # For testing

# Simulation models to evaluate, as (provider, model id) pairs.
# Uncomment an entry to include it in the run.
models = [
    # --- OpenAI ---
    # ("openai", "gpt-4o-mini-2024-07-18"),
    # ("openai", "ft:gpt-4o-mini-2024-07-18:camer:group-split-all:BOvZjzvU"),
    # ("openai", "ft:gpt-4o-mini-2024-07-18:camer:round-split-all:BOvS862Y"),
    # ("openai", "ft:gpt-4o-mini-2024-07-18:camer:topic-split-all:BOqtcdMB"),
    # ("openai", "ft:gpt-4o-mini-2024-07-18:camer:round-split-valid:BRTJQLtG"),

    # --- Mistral family ---
    # ("huggingface", "mistralai/Mistral-7B-v0.1"),
    # ("huggingface", "HuggingFaceH4/mistral-7b-sft-beta"),
    # ("huggingface", "HuggingFaceH4/zephyr-7b-beta"),
    # ("huggingface", "mistralai/Mistral-7B-Instruct-v0.1"),
    # ("huggingface", "mistralai/Mistral-7B-Instruct-v0.3"),

    # --- Llama 3 family ---
    # ("huggingface", "meta-llama/Meta-Llama-3-8B"),
    # ("huggingface", "RLHFlow/LLaMA3-SFT"),
    # ("huggingface", "RLHFlow/LLaMA3-iterative-DPO-final"),
    # ("huggingface", "meta-llama/Meta-Llama-3-8B-Instruct"),

    # --- Llama 3.1 family ---
    # ("huggingface", "meta-llama/Llama-3.1-8B"),
    # ("huggingface", "allenai/Llama-3.1-Tulu-3-8B-SFT"),
    # ("huggingface", "allenai/Llama-3.1-Tulu-3-8B-DPO"),
    # ("huggingface", "allenai/Llama-3.1-Tulu-3-8B"),
    ("huggingface", "meta-llama/Llama-3.1-8B-Instruct"),
    # ("huggingface", "ft:Llama-3.1-8B-Instruct:round-split-valid-5epochs"),
    # ("huggingface", "meta-llama/Llama-3.1-70B-Instruct"),
    # ("huggingface", "mini-twitter/ft:Llama-3.1-8B-Instruct-250528:round-split-valid-5epochs")

    # --- OLMo 2 family ---
    # ("huggingface", "allenai/OLMo-2-1124-7B"),
    # ("huggingface", "allenai/OLMo-2-1124-7B-SFT"),
    # ("huggingface", "allenai/OLMo-2-1124-7B-DPO"),
    # ("huggingface", "allenai/OLMo-2-1124-7B-Instruct"),

    # ("huggingface", "mini-twitter/Llama-3.1-Tulu-3-8B-MT-DDPO-0129"),

    # ("huggingface", "Qwen/Qwen2.5-32B-Instruct"),

    # --- In-house fine-tunes ---
    ("huggingface", "mini-twitter/ft:Llama-3.1-8B-Instruct-SFT-20250710:round-5epochs"),
    ("huggingface", "mini-twitter/ft:Llama-3.1-8B-Instruct-SFT-20250710:topic-5epochs"),
    ("huggingface", "mini-twitter/ft:Llama-3.1-8B-Instruct-SFT-20250711:group-5epochs"),
]


def _model_dir_name(full_name):
    """Return the short model name used in result paths.

    For an id with exactly one "/" (``org/name``) this is the part after the
    slash; any other id (no slash, or several) is returned unchanged.
    """
    _org, sep, rest = full_name.partition("/")
    if sep and "/" not in rest:
        return rest
    return full_name


model_names = [_model_dir_name(full_name) for _provider, full_name in models]

# Simulation output versions to evaluate.
versions = ["v0", "v1", "v2"]

# Models used to score opinions, as List[(type, model_name)].
# type: "openai" for an API model; for Hugging Face models, "seq2seq" loads an
# encoder-decoder model, and any other type is loaded as a causal LM ("generation").
# Disabled alternatives:
#   ("generation", "mistralai/Mistral-7B-Instruct-v0.3")
#   ("seq2seq", "google/flan-t5-large")
#   ("seq2seq", "google/flan-t5-xl")
#   ("seq2seq", "google/flan-t5-xxl")
eval_models = [("openai", "gpt-4o-mini-2024-07-18")]

# Source text of the generated constants.py module. Each INSERT_* placeholder
# is substituted by write_constants(). The placeholders sit inside double-quoted
# Python string literals, so substituted text must be escaped for such a literal.
# eval_model_save_name is computed when the generated module runs: the part
# after "/" for "org/name" ids, otherwise the whole name.
constants_template = r"""
data_prefix = "INSERT_DATA_PREFIX"
model_name = "INSERT_MODEL_NAME"
topic = "INSERT_TOPIC"
version = "INSERT_VERSION"
player_name_col = "worker_id"
eval_model_name = "INSERT_EVAL_MODEL_NAME"
eval_model_save_name = eval_model_name.split("/")[1] if len(eval_model_name.split("/")) == 2 else eval_model_name
"""


def write_constants(data_prefix: str, topic: str, model_name: str, version: str, eval_model_name: str,
                    *, path="constants.py", template=None):
    """Render the constants template for one run and write it to *path*.

    Every value is escaped so it remains a valid double-quoted Python string
    literal; some data prefixes (e.g. ``A_"body_cleanse,"_...``) contain double
    quotes. All placeholders are substituted in a single pass, so a value that
    happens to contain another placeholder's text is left untouched.

    Args:
        data_prefix: Value for ``INSERT_DATA_PREFIX``.
        topic: Value for ``INSERT_TOPIC``.
        model_name: Value for ``INSERT_MODEL_NAME``.
        version: Value for ``INSERT_VERSION``.
        eval_model_name: Value for ``INSERT_EVAL_MODEL_NAME``.
        path: Output file; defaults to ``constants.py`` in the current directory.
        template: Template text; defaults to the module-level
            ``constants_template``. Placeholders are assumed to appear inside
            double-quoted string literals.
    """
    if template is None:
        template = constants_template

    def _escape(value):
        # Make the text safe inside a double-quoted Python string literal.
        return (value.replace("\\", "\\\\")
                .replace('"', '\\"')
                .replace("\n", "\\n")
                .replace("\r", "\\r"))

    replacements = {
        "INSERT_DATA_PREFIX": data_prefix,
        "INSERT_TOPIC": topic,
        "INSERT_MODEL_NAME": model_name,
        "INSERT_VERSION": version,
        "INSERT_EVAL_MODEL_NAME": eval_model_name,
    }
    # Longest placeholder first so a shorter name can never match part of a longer one.
    placeholder_re = re.compile(
        "|".join(re.escape(key) for key in sorted(replacements, key=len, reverse=True))
    )
    constants = placeholder_re.sub(lambda m: _escape(replacements[m.group(0)]), template)
    with open(path, "w", encoding="utf-8") as f:
        f.write(constants)

def run_python(script_name, *args):
    """Run *script_name* with *args* in a child Python process; raise on failure.

    The child uses the current interpreter (``sys.executable``) rather than
    whatever ``python`` resolves to on PATH, so it sees the same environment
    (e.g. virtualenv) as this script.

    Raises:
        RuntimeError: If the child process exits with a non-zero status.
    """
    import sys

    ret = subprocess.call([sys.executable, script_name, *args])
    if ret != 0:
        raise RuntimeError(f"Error running {script_name} (exit code {ret})")


if __name__ == "__main__":

    # When False, the per-data-prefix evaluation loop (and its heavy imports)
    # is skipped.
    run_data_prefixes = False  # run for all data prefixes

    print("Loading Dependencies")


    if run_data_prefixes:
        # cache for eval model: {model_name: (model, tokenizer)}
        eval_model_cache = {}

        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, BitsAndBytesConfig
        import torch

        def get_eval_model(eval_model_type: str, eval_model_name: str):
            """Return the eval-model config tuple passed to opinion_proc.main.

            "openai" models yield ("openai", name). Any other type loads the
            Hugging Face tokenizer and model once: 4-bit quantized, with
            "seq2seq" loaded as a seq2seq LM and anything else as a causal LM.
            It is then torch.compile'd and cached in eval_model_cache, and the
            result is (type, name, model, tokenizer).
            """
            if eval_model_type == "openai":
                eval_model_config = ("openai", eval_model_name)
            else:  # huggingface
                if eval_model_name in eval_model_cache:
                    eval_model, eval_tokenizer = eval_model_cache[eval_model_name]
                    eval_model_config = (eval_model_type, eval_model_name, eval_model, eval_tokenizer)
                else:  # load
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.bfloat16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    )
                    tokenizer = AutoTokenizer.from_pretrained(eval_model_name)
                    if eval_model_type == "seq2seq":
                        model = AutoModelForSeq2SeqLM.from_pretrained(
                            eval_model_name,
                            device_map="auto",
                            quantization_config=quantization_config,
                            attn_implementation="flash_attention_2",
                            torch_dtype=torch.bfloat16,
                        ).eval()
                    else:  # generation
                        model = AutoModelForCausalLM.from_pretrained(
                            eval_model_name,
                            device_map="auto",
                            quantization_config=quantization_config,
                            attn_implementation="flash_attention_2",
                            torch_dtype=torch.bfloat16,
                        ).eval()
                    model = torch.compile(model, mode="max-autotune", fullgraph=True)
                    # Check the EVAL model being loaded, not the simulation
                    # model_name from the outer pipeline loop.
                    if "mistral" in eval_model_name:
                        model.generation_config.pad_token_id = model.generation_config.eos_token_id
                    eval_model_config = (eval_model_type, eval_model_name, model, tokenizer)
                    eval_model_cache[eval_model_name] = (model, tokenizer)
            return eval_model_config

        def skip_if_exists(file_path, days=3):
            """Return True (and log it) if file_path exists and was modified within the last `days` days."""
            if os.path.exists(file_path):
                # skip only if the timestamp is within the last d days
                if os.path.getmtime(file_path) > time.time() - days * 24 * 60 * 60:
                    print(f"Skipping {file_path} as it already exists.")
                    return True
            return False

        from . import opinion_proc
        from . import opinion_plot_human_llm
        from . import human_llm

    from . import opinion_plot_topics
    print("Evaluation Pipeline")

    data_prefixes = get_valid_data_prefixes()

    if run_data_prefixes:
        for data_prefix in data_prefixes:
            topic = util.dp_to_topic(data_prefix)
            # write_constants(data_prefix, topic, "invalid/invalid", "invalid", "invalid/invalid")

            # run_python("group.py")  # human-only: per group and across group

            for version in versions:
                if os.path.exists(f"../../result/eval/human_llm/{data_prefix}/llm_report_{version}.csv"):
                    os.remove(f"../../result/eval/human_llm/{data_prefix}/llm_report_{version}.csv")  # remove old report
                for model_name in model_names:
                    print(f"DP: {data_prefix}\nM: {model_name}\nV: {version}")

                    # Missing simulations are logged and skipped, not fatal.
                    if not os.path.exists(f"../../result/simulation/{data_prefix}/{model_name}/simulation-{version}.csv"):
                        with open('error.log', 'a+') as f:
                            f.write(f"Simulation not found for {data_prefix} {model_name} {version}\n")
                        continue

                    if not skip_if_exists(f"../../result/eval/human_llm/{data_prefix}/{model_name}/human_llm_score_{version}.csv"):
                        try:
                            human_llm.main(data_prefix, model_name, "worker_id", version)  # human-LLM: similarity
                        except Exception as e:
                            with open('error.log', 'a+') as f:
                                f.write(f"[human_llm] Error processing {data_prefix} {model_name} {version}: {e}\n")
                            continue

                    for eval_model_type, eval_model_name in eval_models:  # human-LLM: opinion
                        if skip_if_exists(f"../../result/eval/human_llm/{data_prefix}/{model_name}/opinion_memory_{eval_model_name}_{version}.csv"):
                            continue
                        eval_model_config = get_eval_model(eval_model_type, eval_model_name)
                        proc = functools.partial(opinion_proc.main, data_prefix, topic, eval_model_config, version, llm_name=model_name)

                        # proc(True, False, False)  # DISABLED: use human column (independent)
                        # proc(True, True, False)  # DISABLED: use LLM column (independent)
                        # run_python("opinion_plot_human_llm.py", "False")  # DISABLED: plot opinion trajectories (independent)
                        proc(True, False, True)  # use human column (with round memory)
                        proc(True, True, True)  # use LLM column (with round memory)
                        try:
                            opinion_plot_human_llm.main(data_prefix, model_name, eval_model_name, "worker_id", version, True)  # plot opinion trajectories (with round memory)
                        except Exception as e:
                            with open('error.log', 'a+') as f:
                                f.write(f"[opinion_plot_human_llm] Error processing {data_prefix} {model_name} {version}: {e}\n")
                            continue
                        # run_python("llm_report.py")  # human-LLM: averaged scores

                        # run_python("subtopic_list.py")  # DISABLED: subtopic discovery


def filter_valid_data_prefixes(data_prefixes, model_names, version):
    """Return the data prefixes that have a human-LLM score file for every model.

    For each prefix, the check stops at the first model whose
    ``../../result/eval/human_llm/<dp>/<model>/human_llm_score_<version>.csv``
    is missing; that prefix is reported once on stdout and dropped.
    Order of *data_prefixes* is preserved.
    """
    kept = []
    for prefix in data_prefixes:
        for name in model_names:
            expected = f"../../result/eval/human_llm/{prefix}/{name}/human_llm_score_{version}.csv"
            if not os.path.exists(expected):
                print(f"Skipping {prefix} - missing score file for model {name}")
                break
        else:
            # No model was missing its score file.
            kept.append(prefix)
    return kept


if __name__ == "__main__":
    # Aggregation stage: per version, build LLM reports and topic-level opinion
    # plots. Uses `data_prefixes` and the `opinion_plot_topics` module, both
    # bound unconditionally in the first __main__ section of this file.
    for version in versions:
        # Keep only data prefixes for which every model has a score file.
        # model_names is also the form used for result directory names.
        valid_prefixes = filter_valid_data_prefixes(data_prefixes, model_names, version)

        if not valid_prefixes:
            print(f"No valid data prefixes found for version {version}, skipping llm_report_aggr")
            continue

        for t in ["breadth", "depth"]:
            typed_valid_prefixes = get_typed_data_prefixes(valid_prefixes, t)
            # NOTE(review): eval model name is hard-coded here rather than taken from eval_models.
            generate_llm_report_aggr(typed_valid_prefixes, model_names,
                "gpt-4o-mini-2024-07-18", version, f"../../result/eval/human_llm/llm_report_{version}_{t}.csv")

        # Group data prefixes by topic for cross-experiment analysis. This does
        # not depend on the model, so compute it once per version.
        prefixes_by_topic = {}
        for dp in valid_prefixes:
            prefixes_by_topic.setdefault(util.dp_to_topic(dp), []).append(dp)

        for model_name in model_names:
            for _eval_model_type, eval_model_name in eval_models:  # human-LLM: opinion
                eval_save_name = eval_model_name.split("/")[1] if "/" in eval_model_name else eval_model_name

                # One plot per topic with all related data prefixes...
                for topic_name, related_prefixes in prefixes_by_topic.items():
                    print(f"Generating topic plot for: {topic_name} with {len(related_prefixes)} data prefixes")
                    opinion_plot_topics.main(topic_name, related_prefixes, model_name,
                                             eval_save_name, version, "worker_id")
                # ...plus one across all topics, for this same eval model.
                opinion_plot_topics.main("All", valid_prefixes, model_name,
                                         eval_save_name, version, "worker_id")
