from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import random
import torch
import sys
import os

sys.path.append(os.path.abspath(".."))
from groq_req import judge

# Ask PyTorch's CUDA caching allocator to release unused cached memory at
# import time, before eval_mode loads any local models.
torch.cuda.empty_cache()

# Global RNG seed. It seeds the module-level ``random`` state consumed by
# random.shuffle in extract_params, so prompts/ground truth are reproducible
# for a given call order. It is also embedded (as ``seed{seed}``) in the
# batch-output filename that evaluate_api_batch reads.
seed = 2
random.seed(seed)

def create_batch_file(prompts, groundtruth, joint_size, models=None):
    """Write a batch-API request file and a ground-truth file for each model.

    For every model, two files are written under ``batch_input/Rebuttal``:

    * ``batch_file_joint_mode_{joint_size}_oneshot_{model}_{tag}.jsonl`` -- one
      ``POST /v1/chat/completions`` request per prompt, with ``custom_id``
      values ``request-1`` ... ``request-N``;
    * ``joint_mode_{joint_size}_oneshot_{model}_{tag}_groundtruth.json`` --
      maps each ``custom_id`` to ``{"prompt": ..., "groundtruth": [...]}``.
      ``evaluate_api_batch`` reads this file back.

    ``tag`` is ``groundtruth[0][0]``, the first label of the first ground-truth tuple.

    Args:
        prompts: Prompt strings; each becomes one request.
        groundtruth: Ground-truth mode tuples aligned index-by-index with ``prompts``.
        joint_size: Number of joint outcomes; used only in the file names.
        models: Model identifiers to target. Defaults to
            ``["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]``.

    Raises:
        ValueError: If ``groundtruth`` is empty or shorter than ``prompts``.
    """
    if models is None:
        models = ["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]
    if not groundtruth:
        raise ValueError("groundtruth must be non-empty (its first label names the output files)")
    if len(groundtruth) < len(prompts):
        # Fail before touching the disk instead of leaving a half-written batch file.
        raise ValueError(
            f"groundtruth has {len(groundtruth)} entries but there are {len(prompts)} prompts"
        )

    out_dir = "batch_input/Rebuttal"
    # open(..., "w") does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    tag = groundtruth[0][0]

    for model in models:
        gts = {}
        batch_path = f"{out_dir}/batch_file_joint_mode_{joint_size}_oneshot_{model}_{tag}.jsonl"
        with open(batch_path, "w", encoding="utf-8") as f:
            for idx, (content, gt) in enumerate(zip(prompts, groundtruth), start=1):
                custom_id = f"request-{idx}"
                record = {
                    "custom_id": custom_id,
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": [
                            {"role": "user", "content": content}
                        ]
                    }
                }
                f.write(json.dumps(record, separators=(",", ":")) + "\n")
                gts[custom_id] = {"prompt": content, "groundtruth": tuple(gt)}

        gt_path = f"{out_dir}/joint_mode_{joint_size}_oneshot_{model}_{tag}_groundtruth.json"
        with open(gt_path, "w", encoding="utf-8") as f:
            json.dump(gts, f, indent=4)



def evaluate_api_batch(prompts, groundtruth, expected, joint_size, model="llama-3.3-70b-versatile"):
    """Score one model's batch-API responses on the joint-mode task.

    Prompts and ground truth are re-read from the ground-truth file written by
    ``create_batch_file``. Responses are read from
    ``../batch_output/Rebuttal/batch_output_joint_mode_oneshot_{joint_size}_{model}_seed{seed}.jsonl``.
    Scoring works as follows:

    * A response that ``check_pattern`` accepts is counted as correct.
    * Otherwise ``judge`` is called, and a result of ``0`` counts as a mistake.
    * A request with no usable completion is counted as a mistake without
      calling the judge.

    Args:
        prompts: Unused; the stored prompts are authoritative. Kept for interface compatibility.
        groundtruth: In-memory ground truth. Only ``groundtruth[0][0]`` is used,
            to locate the ground-truth file.
        expected: Unused. The reference strings given to ``judge`` are rebuilt
            from the stored ground truth so they always match the stored prompts.
        joint_size: Number of joint outcomes; used in file names.
        model: Model identifier used in the input/output file names and as the
            results key.

    Returns:
        A dict with ``"prompt"`` (first prompt), ``"num_prompts"`` and, under key
        ``model``, the metrics and the list of incorrect responses.

    Raises:
        KeyError: If the ground-truth file is missing a ``request-i`` entry.
        ValueError: If the ground-truth file contains no prompts.
    """
    gt_file = f"batch_input/Rebuttal/joint_mode_{joint_size}_oneshot_{model}_{groundtruth[0][0]}_groundtruth.json"
    with open(gt_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    stored_prompts, stored_gt = [], []
    for i in range(1, len(data) + 1):
        entry = data[f"request-{i}"]
        stored_prompts.append(entry["prompt"])
        stored_gt.append(tuple(entry["groundtruth"]))
    num_prompts = len(stored_prompts)
    if num_prompts == 0:
        raise ValueError(f"no prompts found in {gt_file}")

    # Same "Mode = (a, b, ...)" format eval_mode uses, derived from the stored
    # ground truth so it cannot drift from the stored prompts.
    reference = ["Mode = " + str(gt).replace("'", "") for gt in stored_gt]

    res_file = f"../batch_output/Rebuttal/batch_output_joint_mode_oneshot_{joint_size}_{model}_seed{seed}.jsonl"
    responses = [None] * num_prompts
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            idx = int(record.get("custom_id", "").split("-")[1])
            # NOTE(review): a failed request presumably has no response body or
            # choices (error reported elsewhere in the record); treat it as missing.
            body = (record.get("response") or {}).get("body") or {}
            choices = body.get("choices") or []
            if choices:
                content = (choices[0].get("message") or {}).get("content")
                if content is not None:
                    responses[idx - 1] = content

    incorrect_responses = []
    num_mistake = 0
    num_calls = 0
    for i, raw in enumerate(responses):
        if raw is None:
            incorrect_responses.append(f"Expected: {reference[i]}\t<missing response>")
            num_mistake += 1
            continue
        res = raw.replace(stored_prompts[i], '')
        if check_pattern(res, stored_gt[i]):
            continue
        # Only the task statement (text before "Instruction...") is given to the judge.
        judgement = judge(stored_prompts[i].split('Instruction')[0], reference[i], res)
        num_calls += 1
        print(f'num_calls: {num_calls}')
        if judgement == 0:
            incorrect_responses.append(f"Expected: {reference[i]}\t" + res)
            num_mistake += 1

    num_correct = num_prompts - num_mistake
    metrics = {
        "num_mistakes": num_mistake,
        "num_judge_calls": num_calls,
        "num_correct": num_correct,
        "accuracy": num_correct / num_prompts,
    }
    print(f'num mistakes: {num_mistake}')
    print(metrics)
    out = {
        "prompt": stored_prompts[0],
        "num_prompts": num_prompts,
        model: {**metrics, "incorrect_responses": incorrect_responses},
    }
    print(f'Accuracy: {metrics["accuracy"]}')
    return out
    


def extract_params(freqs, num_permut=10):
    """Describe a joint frequency table and build ``num_permut`` shuffled variants.

    Each variant keeps the outcomes in ``freqs`` order but assigns them a random
    permutation of the counts, drawn with ``random.shuffle`` from the
    module-level ``random`` state. Each variant therefore usually has a
    different mode. If several outcomes share the maximum count, the mode is the
    first of them in ``freqs`` order.

    Args:
        freqs: Mapping from outcome tuples (one label per variable) to counts.
            Labels must be strings, since they are sorted and joined.
        num_permut: Number of shuffled variants to generate.

    Returns:
        A tuple ``(num_vars, sample_size, rvs, s, freqs_text, groundtruth)``:

        * ``num_vars``: Length of the first outcome tuple.
        * ``sample_size``: Sum of all counts.
        * ``rvs``: Maps variable names to the set of labels seen at that position.
          Names are X, Y, Z, T, U, V, W, then X7, X8, ... for more than seven variables.
        * ``s``: Text of the form ``"X can take outcomes from {A, B}, Y can ..."``.
        * ``freqs_text[k]``: The k-th variant as ``"- (a, b): count"`` lines.
        * ``groundtruth[k]``: The outcome tuple with the largest count in variant k.

    Raises:
        ValueError: If ``freqs`` is empty.
    """
    if not freqs:
        raise ValueError("freqs must contain at least one outcome")

    base_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    keys = list(freqs.keys())
    values = list(freqs.values())  # private copy: shuffling never touches ``freqs``
    sample_size = sum(values)
    num_vars = len(keys[0])

    rv_names = [base_names[j] if j < len(base_names) else f"X{j}" for j in range(num_vars)]
    rvs = {name: {key[j] for key in keys} for j, name in enumerate(rv_names)}

    s = ", ".join(
        f"{name} can take outcomes from {{{', '.join(sorted(labels))}}}"
        for name, labels in rvs.items()
    )

    freqs_text, groundtruth = [], []
    for _ in range(num_permut):
        random.shuffle(values)
        permuted = dict(zip(keys, values))
        groundtruth.append(max(permuted, key=permuted.get))
        # Render tuples without quotes, e.g. "('A', 'B')" -> "(A, B)".
        rows = {str(key).replace("'", ""): count for key, count in permuted.items()}
        freqs_text.append("\n".join(f"- {key}: {count}" for key, count in rows.items()))
    return num_vars, sample_size, rvs, s, freqs_text, groundtruth

    


def check_pattern(input_string, gt):
    r"""Return True if ``input_string`` states ``gt`` as the mode.

    Two answer forms are accepted:

    * ``Mode = (a, b, ...)``. Either "Mode" or "mode" works, optionally wrapped
      in ``**``, followed by ``=``, ``is`` or ``:``, any number of LaTeX ``\(``
      openers, and then the labels of ``gt`` in order inside parentheses.
      Whitespace around labels and commas is flexible.
    * The literal text ``boxed{(a, b, ...)}``.

    Labels are regex-escaped, so labels containing characters such as ``+``,
    ``.`` or ``?`` are matched literally.
    """
    labels = [str(v) for v in gt]
    prefix = r"(\*\*)?(Mode|mode)(\*\*)?\s*(=|is|:)\s*(\\\()*\s*\("
    body = r"\s*,\s*".join(re.escape(v) for v in labels)
    patt = prefix + r"\s*" + body + r"\s*\)"
    boxed = "boxed{(" + ", ".join(labels) + ")}"
    return (re.search(patt, input_string) is not None) or (boxed in input_string)


def batch_prompts(prompts, batch_size):
    """Lazily yield consecutive slices of ``prompts`` with at most ``batch_size`` items each."""
    starts = range(0, len(prompts), batch_size)
    yield from (prompts[begin:begin + batch_size] for begin in starts)




def generate_prompts(n, freq, permute=False):
    """Build ``n`` zero-shot prompts asking for the mode of the joint distribution ``freq``.

    ``extract_params(freq)`` supplies shuffled frequency tables and their modes.
    Prompt ``k`` shows table ``k % len(tables)`` and is paired with that table's mode.
    ``permute`` is accepted for interface compatibility and is not used.

    Returns:
        ``(prompts, groundtruth)``: two lists of length ``n``.
    """
    num_vars, sample_size, rvs, s, freqs_text, gt = extract_params(freq)
    var_list = ', '.join(rvs.keys())
    empty_slots = ', '.join([''] * num_vars)

    intro = (
        f"Consider {num_vars} discrete random variables {var_list}, where {s}. "
        f"In {sample_size} independent samples drawn from their joint distribution P({var_list}), "
        "the observed frequencies are:"
    )
    task_and_format = "\n".join([
        f"Task: Identify the mode (most probable outcome) of the joint distribution P({var_list}).",
        "",
        "Instructions:",
        "1. Do not write any code or pseudocodes.",
        "2. Strictly follow the output format below.",
        "3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. You get negative penalty for not following the output format.",
        "",
        "Output Format that you should strictly follow:",
        "",
        f"Mode = ({empty_slots})",
    ])

    groundtruth = [gt[k % len(gt)] for k in range(n)]
    prompts = [
        f"{intro}\n\n{freqs_text[k % len(freqs_text)]}\n\n{task_and_format}".strip()
        for k in range(n)
    ]
    return prompts, groundtruth




def generate_oneshot_prompts2(n, freqs, permute=False):
    """Build ``n`` identical one-shot prompts asking for the mode of ``freqs``.

    The prompt contains a fixed, solved example (a 3-variable table whose mode
    is (A, D, E)), followed by the main task. The main task is built from the
    single shuffled table that ``extract_params(freqs, 1)`` produces, so it is
    the same table whose mode is returned as ground truth.

    Previously the main task was a hard-coded 4-variable table, so prompt and
    answer key agreed only when the shuffle happened to reproduce that table.

    Args:
        n: Number of (identical) prompts to return.
        freqs: Mapping from outcome tuples to counts; see ``extract_params``.
        permute: Accepted for interface compatibility; not used.

    Returns:
        ``(prompts, groundtruth)``: ``n`` copies of the prompt and ``n`` copies
        of its mode tuple.
    """
    num_vars, sample_size, rvs, s, freqs_text, gt = extract_params(freqs, 1)
    groundtruth = [gt[0]] * n
    names = ", ".join(rvs.keys())
    # Empty answer template, e.g. "(, , ,)" for four variables.
    answer_slots = ", ".join([""] * num_vars).rstrip()

    example = """Task Definition:
Consider the joint probability distribution of a set of discrete random variables. You will be provided with the frequencies of each outcome of the joint distribution in a set of independently drawn samples.  
Your task is to identify the mode (the most probable outcome) of the joint distribution.  


--- EXAMPLE with Solution---
Consider discrete random variables X, Y, Z with joint distribution P(X, Y, Z) where X can take outcomes from {A, B}, Y can take outcomes from {C, D}, Z can take outcomes from {E, F, G}. 
In 60 independent samples drawn from the joint distribution P(X, Y, Z), the observed frequencies are:
- (A, C, E): 5  
- (A, C, F): 7  
- (A, C, G): 3  
- (A, D, E): 10  
- (A, D, F): 2  
- (A, D, G): 3  
- (B, C, E): 6  
- (B, C, F): 8  
- (B, C, G): 2  
- (B, D, E): 3  
- (B, D, F): 3  
- (B, D, G): 8 

Example Task (solved): Identify the mode (most probable outcome) of the joint distribution P(X, Y, Z).

Step-by-step solution:
1. The mode of a joint distribution is the single outcome that appears most frequently in the observed samples.
2. Look at the list of all outcomes and their frequencies.
(A, C, E): 5
(A, C, F): 7
(A, C, G): 3
(A, D, E): 10
(A, D, F): 2
(A, D, G): 3
(B, C, E): 6
(B, C, F): 8
(B, C, G): 2
(B, D, E): 3
(B, D, F): 3
(B, D, G): 8

3. Compare the frequencies to find the maximum.
   - The largest count is 10, which corresponds to the outcome (A, D, E).
4. Since no other outcome has a higher frequency than 10, this outcome is the most probable (i.e., the mode).
5. mode of the joint distribution is (A, D, E).

Final Answer:
Mode = (A, D, E)
--- END EXAMPLE ---

"""
    # Built line by line so that intentional trailing spaces stay visible.
    main_task = "\n".join([
        "Main Task: ",
        f"Consider {num_vars} discrete random variables {names}, where {s}. ",
        f"In {sample_size} independent samples drawn from their joint distribution P({names}), the observed frequencies are:",
        freqs_text[0],
        "",
        f"Main Task (to solve): Identify the mode (most probable outcome) of the joint distribution P({names}).",
        "",
        "Instructions:",
        "- You may explain your reasoning, but the final answer should be explicitly summarized at the end. You get negative penalty for not following the output format.",
        "",
        "Final Answer:",
        f"Mode = ({answer_slots})",
    ])
    prompt = (example + main_task).strip()
    prompts = [prompt] * n

    return prompts, groundtruth
 

 

# def generate_code_prompts(n, freq, permute=False):
#     prompts, groundtruth = [], []
#     num_vars, sample_size, rvs, s, freqs_text, gt = extract_params(freq)
#     for i in range(n):
#         groundtruth.append(gt[i%len(gt)])
#         prompt = f"""Consider {num_vars} discrete random variables {', '.join(rvs.keys())}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({', '.join(rvs.keys())}), the observed frequencies are:

# {freqs_text[i%len(freqs_text)]}

# Task: write a Python code to identify the mode (most probable outcome) of the joint distribution P({', '.join(rvs.keys())}).
# """
#         prompts.append(prompt.strip())
    
#     return prompts, groundtruth



def _generate_local_responses(model_id, prompts, batch_size, max_tokens):
    """Run a Hugging Face causal LM on ``prompts`` in batches and return only the generated text."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Decoder-only models must be left-padded for batched generation. With
    # right padding, shorter prompts would be continued after pad tokens.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    responses = []
    try:
        for batch in batch_prompts(prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_tokens,
                attention_mask=inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id,
            )
            # With left padding, every row's (padded) prompt fills the first
            # prompt_len positions. Decode only what follows, instead of
            # round-tripping the prompt through the tokenizer and stripping it.
            prompt_len = inputs["input_ids"].shape[1]
            responses.extend(
                tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
                for output in outputs
            )
    finally:
        # Free this model's GPU memory before the next model is loaded.
        del model
        torch.cuda.empty_cache()
    return responses


def eval_mode(models, freqs, num_prompts, permute=True, mode="local"):
    """Evaluate how well models identify the mode of the joint distribution ``freqs``.

    Builds ``num_prompts`` one-shot prompts with ``generate_oneshot_prompts2``,
    then acts according to ``mode``:

    * ``"api_create_batch"``: writes batch-API request files via
      ``create_batch_file`` and returns None.
    * ``"api_eval_batch"``: scores downloaded batch outputs via
      ``evaluate_api_batch`` and returns its result.
    * anything else (default ``"local"``): runs each model id in ``models``
      locally on CUDA and returns a results dict.

    A local response is correct if ``check_pattern`` finds the ground-truth mode
    in it or, failing that, if ``judge`` does not return 0.

    Args:
        models: Hugging Face model ids (local mode only).
        freqs: Mapping from outcome tuples to counts.
        num_prompts: Number of prompts to build; must be positive.
        permute: Recorded in the results only.
        mode: See above.

    Returns:
        In local mode, a dict with ``"prompt"``, ``"num_prompts"``,
        ``"permute"`` and, per model id, the metrics and incorrect responses
        (suitable for ``json.dump``).

    Raises:
        ValueError: If ``num_prompts`` is not positive.
    """
    if num_prompts <= 0:
        raise ValueError("num_prompts must be positive")

    prompts, groundtruth = generate_oneshot_prompts2(num_prompts, freqs)
    expected = ["Mode = " + str(gt).replace("'", "") for gt in groundtruth]
    joint_size = len(freqs)

    if mode == "api_create_batch":
        create_batch_file(prompts, groundtruth, joint_size)
        return None
    if mode == "api_eval_batch":
        return evaluate_api_batch(prompts, groundtruth, expected, joint_size)

    batch_size = min(num_prompts, 40)
    default_max_tokens = 1024 if len(freqs) < 30 else 2048

    out = {"prompt": prompts[0], "num_prompts": num_prompts, "permute": permute}
    print(f'prompt sample: \n{prompts[0]}\n\n\n\n')

    for model_id in models:
        # Decided per model: the Llama-3.1-8B override used to leak into
        # every model evaluated after it.
        max_tokens = 1000 if model_id == "meta-llama/Llama-3.1-8B-Instruct" else default_max_tokens
        print(f'Model: {model_id}')
        responses = _generate_local_responses(model_id, prompts, batch_size, max_tokens)

        num_mistake = 0
        num_calls = 0
        incorrect_responses = []
        for i, response in enumerate(responses):
            res = response.replace(prompts[i], '')
            # If the model echoes the output-format line, keep only what follows it.
            t = 'Output Format that you should strictly follow:\n\nMode = ('
            if t in res:
                res = res.split(t)[1]
            check = check_pattern(res, groundtruth[i])
            print(f'res {i}: pattern-check: {check}')
            if not check:
                print(res)
                print('-'*100)
                judgement = judge(prompts[i].split('Instruction')[0], expected[i], res)
                num_calls += 1
                print(f'num_calls: {num_calls}')
                if judgement == 0:
                    incorrect_responses.append(res)
                    num_mistake += 1

        print(f'num calls: {num_calls}')
        print(f'num mistakes: {num_mistake}')
        print('\n\n\n\n\n\n')
        num_correct = num_prompts - num_mistake
        out[model_id] = {
            "num_mistakes": num_mistake,
            "num_judge_calls": num_calls,
            "num_correct": num_correct,
            "accuracy": num_correct / num_prompts,
            "incorrect_responses": incorrect_responses,
        }
        print(out[model_id]['accuracy'])

    return out
 


