from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import random
import torch
import sys
import os

sys.path.append(os.path.abspath(".."))
from groq_req import judge



# Runs at import time: ask PyTorch to release any cached CUDA memory.
torch.cuda.empty_cache()

# Seed for Python's `random` module; also embedded in the batch/result file
# names built by the functions in this file.
seed = 2
# Value embedded in the batch/result file names built by the functions in
# this file.
# NOTE(review): extract_params assigns its own local `alpha`; confirm whether
# the two values are meant to agree.
alpha = 25
random.seed(seed)



def create_batch_file(prompts, groundtruth, joint_size):
    """Write the batch request file and its ground-truth companion file.

    Both files are created under ``batch_input/``. Their names embed
    ``joint_size``, the hard-coded model name and the module-level ``seed``
    and ``alpha`` values.

    * The ``.jsonl`` file holds one chat-completion request per prompt, with
      ``custom_id`` values ``request-1``, ``request-2``, ...
    * The ``.json`` file maps every ``custom_id`` to its prompt and its
      ground-truth tuple (written by ``json`` as a list).

    Args:
        prompts: Prompt strings; one request is written per entry.
        groundtruth: Sequence indexed in parallel with ``prompts``.
        joint_size: Value embedded in both file names.
    """
    model = "llama-3.3-70b-versatile"
    # Alternative models used with this script:
    # model = "gpt-4o-mini"
    # model = "gpt-4.1-mini"
    requests_path = f"batch_input/batch_file_joint_mode_samples_{joint_size}_{model}_seed{seed}_alpha{alpha}.jsonl"
    groundtruth_path = f"batch_input/joint_mode_samples_{joint_size}_{model}_seed{seed}_groundtruth_alpha{alpha}.json"

    gts = {}
    with open(requests_path, "w") as requests_file:
        for position, content in enumerate(prompts):
            # Request ids are 1-based.
            request_id = f"request-{position + 1}"
            body = {
                "model": model,
                "messages": [
                    {"role": "user", "content": content}
                ]
            }
            record = {
                "custom_id": request_id,
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": body
            }
            # Compact separators keep each request on a single short line.
            requests_file.write(json.dumps(record, separators=(",", ":")) + "\n")
            gts[request_id] = {
                "prompt": content,
                "groundtruth": tuple(groundtruth[position])
            }

    with open(groundtruth_path, "w") as groundtruth_file:
        json.dump(gts, groundtruth_file, indent=4)



def evaluate_api_batch(prompts, groundtruth, expected, joint_size):
    """Score the responses of a finished API batch and save a summary.

    Prompts and ground truth are reloaded from the ground-truth JSON file so
    that they line up with the ``custom_id`` values found in the batch output.
    The ``prompts``, ``groundtruth`` and ``expected`` arguments are therefore
    ignored; they are kept only so existing callers keep working.

    Every response is first checked with ``check_pattern``. Only when that
    check fails is ``judge`` called, and a judgement of ``0`` counts as a
    mistake. Requests for which the batch output holds no usable completion
    are counted as mistakes without calling the judge.

    The summary is printed and written as JSON under
    ``../eval_results/mode/``.

    Args:
        prompts: Ignored (rebuilt from the ground-truth file).
        groundtruth: Ignored (rebuilt from the ground-truth file).
        expected: Ignored (rebuilt from the ground-truth file).
        joint_size: Value embedded in the input and output file names.

    Raises:
        ValueError: If the ground-truth file contains no requests.
        KeyError: If a ``request-<k>`` id is missing from the ground-truth
            file.
    """
    model = "llama-3.3-70b-versatile"
    # Alternative models used with this script:
    # model = "gpt-4o-mini"
    # model = "gpt-4.1-mini"

    gt_path = f"batch_input/joint_mode_samples_{joint_size}_{model}_seed{seed}_groundtruth_alpha{alpha}.json"
    with open(gt_path, "r") as f:
        data = json.load(f)
    if not data:
        raise ValueError(f"no requests found in {gt_path}")

    prompts = []
    groundtruth = []
    # Ids are 1-based and contiguous: request-1 ... request-N.
    for i in range(1, len(data) + 1):
        entry = data[f"request-{i}"]
        prompts.append(entry["prompt"])
        groundtruth.append(tuple(entry["groundtruth"]))

    expected = ["Mode = " + str(gt).replace("'", "") for gt in groundtruth]

    res_file = f"../batch_output/batch_output_joint_mode_samples_{joint_size}_{model}_seed{seed}_alpha{alpha}.jsonl"
    # Slots stay None for requests that have no usable completion.
    responses = [None] * len(prompts)
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines such as a trailing newline
            record = json.loads(line)
            custom_id = record.get("custom_id", "")
            i = int(custom_id.split("-")[1])
            try:
                content = record["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                # Record without a completion (e.g. a failed request).
                content = None
            responses[i - 1] = content

    num_prompts = len(prompts)
    incorrect_responses = []
    num_mistake = 0
    num_calls = 0
    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts

    for i, response in enumerate(responses):
        if response is None:
            # Nothing to grade: count as a mistake, do not spend a judge call.
            incorrect_responses.append(f"Expected: {expected[i]}\t<no response>")
            num_mistake += 1
            continue

        res = response.replace(prompts[i], '')
        if check_pattern(res, groundtruth[i]):
            continue

        # The pattern check failed; fall back to the judge.
        judgement = judge(prompts[i].split('Instruction')[0], expected[i], res)
        num_calls += 1
        print(f'num_calls: {num_calls}')
        if judgement == 0:
            incorrect_responses.append(f"Expected: {expected[i]}\t" + res)
            num_mistake += 1

    print(f'num calls: {num_calls}')
    print(f'num mistakes: {num_mistake}')
    summary = {
        "num_mistakes": num_mistake,
        "num_judge_calls": num_calls,
        "num_correct": (num_prompts - num_mistake),
        "accuracy": (num_prompts - num_mistake) / num_prompts,
    }
    out[model] = {**summary, "incorrect_responses": incorrect_responses}
    print(summary)

    with open(f"../eval_results/mode/joint_mode_samples_jointsize_{joint_size}_{model}_seed{seed}_alpha{alpha}.json", "w") as file:
        json.dump(out, file, indent=4)
    


def extract_params(freqs, num_permut=10):
    """Derive prompt ingredients from a table of joint outcome counts.

    Args:
        freqs: Mapping from outcome tuples (one string label per variable)
            to integer counts. The mapping itself is not modified.
        num_permut: Number of shuffled sample sets to generate.

    Returns:
        A 6-tuple ``(num_vars, sample_size, rvs, description, freqs_text,
        groundtruth)`` where

        * ``num_vars`` is the number of variables,
        * ``sample_size`` is the number of sample lines in each sample set,
        * ``rvs`` maps variable names (``X``, ``Y``, ...) to the set of
          labels seen for that variable,
        * ``description`` is the sentence fragment listing each variable's
          outcomes,
        * ``freqs_text`` holds ``num_permut`` shuffled sample listings, one
          ``- (..)`` line per sample,
        * ``groundtruth`` holds the outcome with the largest count for each
          listing.
    """
    rv_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    outcomes = list(freqs.keys())
    counts = list(freqs.values())

    # One label set per variable; the first outcome fixes the variable count.
    rvs = {rv_names[idx]: set() for idx in range(len(outcomes[0]))}
    for outcome in outcomes:
        for pos, name in enumerate(rvs):
            rvs[name].add(outcome[pos])

    num_vars = len(rvs)
    clauses = []
    for name, labels in rvs.items():
        ordered = sorted(labels)
        clauses.append(f"{name} can take outcomes from " + "{" + f"{', '.join(ordered)}" + "}")
    description = ", ".join(clauses)

    # NOTE(review): this local alpha is independent of the module-level
    # `alpha` used in output file names; confirm that this is intended.
    alpha = 9
    sample_size = sum(counts) + (len(outcomes) * (alpha - 5))

    freqs_text = []
    groundtruth = []
    for _ in range(num_permut):
        # Reassign the same counts to the outcomes in a new random order.
        random.shuffle(counts)
        assignment = dict(zip(outcomes, counts))
        groundtruth.append(max(assignment, key=assignment.get))

        samples = []
        for outcome in outcomes:
            line = str(outcome).replace("'", "")
            samples.extend([line] * (assignment[outcome] + alpha - 5))
        random.shuffle(samples)
        freqs_text.append('\n'.join(f"- {line}" for line in samples))

    return num_vars, sample_size, rvs, description, freqs_text, groundtruth

    


def check_pattern(input_string, gt):
    """Return True if ``input_string`` states ``gt`` as the mode.

    Two forms are accepted:

    * ``Mode`` or ``mode`` (optionally wrapped in ``**``), followed by
      ``=``, ``is`` or ``:``, optionally by backslash-escaped opening
      parentheses, and then the labels of ``gt`` as a parenthesised,
      comma-separated tuple with arbitrary whitespace around each label;
    * the literal text ``boxed{(a, b, ...)}`` with labels joined by ``", "``.

    Labels are converted with ``str`` and matched literally, so labels that
    contain regular-expression metacharacters are handled correctly.

    Args:
        input_string: Model response text to search.
        gt: Sequence of ground-truth labels, one per variable.

    Returns:
        ``True`` if either form occurs anywhere in ``input_string``.
    """
    labels = [str(label) for label in gt]
    # Raw strings throughout: the backslashes belong to the regex, not Python.
    pattern = (
        r"(\*\*)?(Mode|mode)(\*\*)?\s*(=|is|:)\s*(\\\()*\s*\("
        r"\s*" + r"\s*,\s*".join(re.escape(label) for label in labels) + r"\s*\)"
    )
    if re.search(pattern, input_string) is not None:
        return True
    return ("boxed{(" + ", ".join(labels) + ")}") in input_string


def batch_prompts(prompts, batch_size):
    """Lazily yield consecutive slices of ``prompts`` of length ``batch_size``.

    The final slice is shorter when ``len(prompts)`` is not a multiple of
    ``batch_size``; an empty ``prompts`` yields nothing.
    """
    offsets = range(0, len(prompts), batch_size)
    yield from (prompts[start:start + batch_size] for start in offsets)




def generate_prompts(n, freq, permute=False):
    """Build ``n`` mode-identification prompts and their ground truths.

    The sample listings and ground truths come from ``extract_params(freq)``
    and are reused cyclically when ``n`` exceeds the number of listings.

    Args:
        n: Number of prompts to build.
        freq: Mapping of outcome tuples to counts, passed to
            ``extract_params``.
        permute: Accepted but not used by this function.

    Returns:
        ``(prompts, groundtruth)``: two lists of length ``n``; each prompt is
        stripped of leading and trailing whitespace.
    """
    num_vars, sample_size, rvs, s, freqs_text, gt = extract_params(freq)

    # Pieces that are identical for every prompt.
    var_list = ', '.join(rvs.keys())
    blank_slots = ', '.join([''] * num_vars)

    prompts = []
    groundtruth = []
    for idx in range(n):
        groundtruth.append(gt[idx % len(gt)])
        samples_block = freqs_text[idx % len(freqs_text)]
        prompt = f"""Consider {num_vars} discrete random variables {var_list}, where {s}. Below are {sample_size} independent samples drawn from their joint distribution P({var_list}). The samples are:

{samples_block}

Task: Count the frequency of each outcome and identify the mode (most probable outcome) of the joint distribution P({var_list}).

Instructions:
1. Do not write any code or pseudocodes.
2. Strictly follow the output format below.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. You get negative penalty for not following the output format.

Output Format that you should strictly follow:

Mode = ({blank_slots})
"""
        prompts.append(prompt.strip())

    return prompts, groundtruth




def eval_joint_mode_samples(models, freqs, num_prompts, permute=True, mode="local"):
    """Generate joint-mode prompts and run them in the selected mode.

    Modes:
        * ``"api_create_batch"``: write the batch request files via
          ``create_batch_file`` and return.
        * ``"api_eval_batch"``: score a batch output via
          ``evaluate_api_batch`` and return.
        * anything else (default ``"local"``): load every model in ``models``
          on CUDA, generate a completion for each prompt and print it.

    Scoring of local generations (judge calls, accuracy, result file) is
    currently commented out, so in local mode this function only prints and
    returns ``None``.

    Args:
        models: Iterable of model identifiers passed to ``from_pretrained``.
        freqs: Mapping of outcome tuples to counts used to build the prompts.
        num_prompts: Number of prompts to generate.
        permute: Only recorded in the summary dict used by the disabled
            scoring code.
        mode: Execution mode, see above.
    """
    prompts, groundtruth = generate_prompts(num_prompts, freqs)
    print(f'prompt sample: \n{prompts[0]}\n\n\n\n')
    print(f'groundtruth: {groundtruth[0]}\n\n\n\n')
    expected = ["Mode = " + str(groundtruth[i]).replace("'", "") for i in range(num_prompts)]

    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, groundtruth, joint_size)
        return

    elif mode == "api_eval_batch":
        evaluate_api_batch(prompts, groundtruth, expected, joint_size)
        return

    batch_size = num_prompts if num_prompts < 55 else 30

    # Summary skeleton; only consumed by the disabled scoring code below.
    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    out["permute"] = permute

    # print(f'prompt sample: \n{prompts[0]}\n\nExpected: {expected[0]}\n\n')

    # Generation budget per prompt (an earlier size-dependent 1024/2048
    # setting was always overridden by this value).
    max_tokens = 4000
    for model_id in models:

        # Drop references to the previous model before loading the next one.
        model, tokenizer = None, None
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
        # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.pad_token = tokenizer.eos_token
        # A causal LM continues from the last input position, so prompts of
        # unequal length in a batch must be padded on the left; otherwise
        # shorter prompts would be continued from pad tokens.
        tokenizer.padding_side = "left"

        print(f'Model: {model_id}')

        responses = []
        num_mistake = 0
        num_calls = 0
        incorrect_responses = []
        for batch in batch_prompts(prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            attention_mask = inputs["attention_mask"]

            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_tokens,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.eos_token_id
            )

            # Decoded text contains the prompt followed by the completion.
            respon = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            responses.extend(respon)

        for i in range(len(responses)):
            # Strip the echoed prompt; if the decoded prompt does not match
            # exactly, fall back to cutting at the output-format template.
            res = responses[i].replace(prompts[i], '')
            t = 'Output Format that you should strictly follow:\n\nMode = ('
            if t in res:
                res = res.split(t)[1]
            print(res)
            print('-'*100)
            # Result is only used by the disabled scoring code below.
            check = check_pattern(res, groundtruth[i])
            # print(f'expected: {expected[i]}')
            # print(f'groundtruth: {groundtruth[i]}')
            # print(f'pattern-check: {check}')
            # print('-'*100)
            # if not check: 
            #     # print(res)
            #     # print('-'*100)
            #     judgement = judge(prompts[i].split('Instruction')[0], expected[i], res)
            #     num_calls += 1
            #     print(f'num_calls: {num_calls}')
            #     if judgement == 0:
            #         incorrect_responses.append(res)
            #         num_mistake += 1

    #     print(f'num calls: {num_calls}')
    #     print(f'num mistakes: {num_mistake}')
    #     print('\n\n\n\n\n\n')
    #     out[model_id] = {"num_mistakes": num_mistake, "num_judge_calls": num_calls, "num_correct": (num_prompts - num_mistake), "accuracy": (num_prompts - num_mistake) / num_prompts, "incorrect_responses": incorrect_responses}
    #     print(f'Accuracy: {(num_prompts - num_mistake) / num_prompts}')
    # # with open(f"../eval_results/mode/joint_mode_{'_'.join(list(map(str, freqs.values())))}.json", "w") as file:
    # #     json.dump(out, file, indent=4)

    # with open(f"../eval_results/Rebuttal/joint_mode_samples_{len(freqs.keys())}_deepseek_seed{seed}.json", "w") as file:
    #     json.dump(out, file, indent=4)
 
    # return out


