from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import random
import torch
torch.cuda.empty_cache()
import os
import sys
sys.path.append(os.path.abspath(".."))
from groq_req import judge

# Global RNG seed: seeds the `random` module (used for the shuffles in this
# file) and is embedded in the batch input/output file names.
seed = 2
random.seed(seed)



def create_batch_file(prompts, expected, joint_size):
    """Write batch-API request files plus the matching expected-answer files.

    For every model in the hard-coded list, two files are written under
    ``batch_input/Rebuttal``:

    * a ``.jsonl`` file with one ``/v1/chat/completions`` request per prompt,
      identified by ``custom_id`` ``request-<k>`` (1-based);
    * a ``.json`` file mapping each ``request-<k>`` to its prompt and expected
      answer.

    Args:
        prompts: Sequence of prompt strings.
        expected: Sequence of expected answers, aligned with ``prompts``.
        joint_size: Size of the joint distribution; only used in file names.

    Raises:
        ValueError: If ``expected`` has fewer entries than ``prompts``.
    """
    # Fail before touching the filesystem instead of leaving a half-written
    # batch file behind when an expected answer is missing.
    if len(expected) < len(prompts):
        raise ValueError(
            f"expected has {len(expected)} entries but there are {len(prompts)} prompts"
        )

    # model = "llama-3.3-70b-versatile"
    # model = "gpt-4o-mini"
    # model = "gpt-4.1-mini"
    models = ["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]

    # The output directory may not exist on a fresh checkout.
    os.makedirs("batch_input/Rebuttal", exist_ok=True)

    # The expected-answer mapping does not depend on the model; build it once.
    expected_outputs = {
        f"request-{idx}": {"prompt": content, "expected": answer}
        for idx, (content, answer) in enumerate(zip(prompts, expected), start=1)
    }

    for model in models:
        with open(f"batch_input/Rebuttal/batch_file_cond_mode_{joint_size}_oneshot_{model}_seed{seed}.jsonl", "w", encoding="utf-8") as f:
            for idx, content in enumerate(prompts, start=1):
                record = {
                    "custom_id": f"request-{idx}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": [
                            {"role": "user", "content": content}
                        ]
                    }
                }
                f.write(json.dumps(record, separators=(",", ":")) + "\n")

        with open(f"batch_input/Rebuttal/cond_mode_{joint_size}_oneshot_expected_{model}_seed{seed}.json", "w", encoding="utf-8") as f:
            json.dump(expected_outputs, f, indent=4)



def evaluate_api_batch(prompts, queries, expected, joint_size):
    """Score a completed batch-API run against the stored expected answers.

    The prompts and expected answers are re-read from the JSON file written at
    batch-creation time, so the ``prompts`` and ``expected`` arguments are
    superseded by the file contents; ``queries`` is not used.

    Each response is first checked with ``check_pattern``; only when that
    fails is the external ``judge`` consulted. A judgement of ``0`` counts as
    a mistake. Requests with no usable response in the output file are counted
    as mistakes without calling the judge.

    Args:
        prompts: Ignored (replaced by the stored prompts).
        queries: Unused.
        expected: Ignored (replaced by the stored expected answers).
        joint_size: Size of the joint distribution; only used in file names.

    Returns:
        dict with the first prompt, the number of prompts, and per-model
        statistics (mistakes, judge calls, correct count, accuracy and the
        list of incorrect responses).
    """
    model = "llama-3.3-70b-versatile"
    # model = "gpt-4o-mini"
    # model = "gpt-4.1-mini"
    # NOTE(review): the expected file is read relative to the working
    # directory while the results are read from "../batch_output" - confirm
    # both resolve correctly from the directory this script is run in.
    with open(f"batch_input/Rebuttal/cond_mode_{joint_size}_oneshot_expected_{model}_seed{seed}.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    prompts = []
    expected = []
    for i in range(1, len(data) + 1):
        # Direct indexing gives a clear KeyError if the ids are not contiguous.
        entry = data[f"request-{i}"]
        prompts.append(entry["prompt"])
        expected.append(entry["expected"])

    res_file = f"../batch_output/Rebuttal/batch_output_cond_mode_oneshot_{joint_size}_{model}_seed{seed}.jsonl"
    responses = [None] * len(prompts)
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            record = json.loads(line)
            custom_id = record.get("custom_id", "")
            i = int(custom_id.split("-")[1])
            try:
                content = record["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                # Failed request (no response body); handled as missing below.
                content = None
            responses[i-1] = content

    num_prompts = len(prompts)
    incorrect_responses = []
    num_mistake = 0
    num_calls = 0
    num_missing = 0
    out = {}
    out["prompt"] = prompts[0] if prompts else None
    out["num_prompts"] = num_prompts

    for prompt, target, response in zip(prompts, expected, responses):
        if response is None:
            # No answer was produced for this request: count it as wrong.
            num_missing += 1
            num_mistake += 1
            incorrect_responses.append(response)
            continue

        res = response.replace(prompt, '')
        # The expected answer is parsed positionally, assuming the form
        # "Most probable value of <Q> given <C>=<label> is <answer>", i.e.
        # tokens 4, 6 and 8 when split on single spaces. Labels containing
        # spaces would break this - TODO confirm they never occur.
        split_expec = target.split(' ')
        cond_parts = split_expec[6].split('=')
        check = check_pattern(res, split_expec[4], cond_parts[0], cond_parts[1], split_expec[8])
        if not check:
            judgement = judge(prompt.split('Instruction')[0], target, res)
            num_calls += 1
            print(f'num_calls: {num_calls}')
            if judgement == 0:
                incorrect_responses.append(res)
                num_mistake += 1

    num_correct = num_prompts - num_mistake
    # Avoid ZeroDivisionError when the expected file is empty.
    accuracy = num_correct / num_prompts if num_prompts else 0.0
    stats = {"num_mistakes": num_mistake, "num_judge_calls": num_calls, "num_correct": num_correct, "accuracy": accuracy}

    if num_missing:
        print(f'num missing responses: {num_missing}')
    print(f'num calls: {num_calls}')
    print(f'num mistakes: {num_mistake}')
    print(stats)
    print(f'Accuracy: {accuracy}')
    print('\n\n\n\n\n\n')
    out[model] = {**stats, "incorrect_responses": incorrect_responses}

    # with open(f"../eval_results/Rebuttal/cond_mode_jointsize_{joint_size}_mushroom_{model}_seed{seed}.json", "w") as file:
    #     json.dump(out, file, indent=4)

    return out




def get_cond_queries(rv, labels, freqs):
    """Build all conditional-mode queries that have a unique answer.

    For every ordered pair of distinct variables (Q, C) and every label ``c``
    of C, the frequencies of the outcomes with ``C == c`` are aggregated by
    the value of Q. The query is kept only if the mode is unique (the two
    largest aggregated counts differ, or Q has a single label).

    Args:
        rv: Sequence of variable names; position ``i`` corresponds to
            position ``i`` of every outcome key in ``freqs``.
        labels: Sequence of label collections, ``labels[i]`` being the
            possible values of ``rv[i]``.
        freqs: Mapping from outcome (indexable, one label per variable) to
            its observed count.

    Returns:
        A shuffled list of dicts with keys ``"Q_rv"``, ``"cond_rv"``,
        ``"cond_label"`` and ``"answer"``.
    """
    queries = []
    for i, q in enumerate(rv):
        for j, cond in enumerate(rv):
            if i == j:
                continue
            # Iterate labels in sorted order: they may be sets, whose
            # iteration order varies between runs and would otherwise make
            # the seeded shuffle below non-reproducible.
            for cond_label in sorted(labels[j], key=str):
                counts = {label_q: 0 for label_q in labels[i]}
                for outcome, count in freqs.items():
                    # Compare at the conditioning variable's own position; a
                    # plain membership test would also match the same label
                    # appearing under a different variable.
                    if outcome[j] == cond_label:
                        counts[outcome[i]] += count
                ranked = sorted(counts.values(), reverse=True)
                if len(ranked) < 2 or ranked[0] != ranked[1]:
                    queries.append({"Q_rv": q, "cond_rv": cond, "cond_label": cond_label, "answer": max(counts, key=counts.get)})
    random.shuffle(queries)
    return queries


def extract_params(freqs, num_permut=10):
    """Derive prompt ingredients from a joint frequency table.

    Variables are named positionally from ``X, Y, Z, T, U, V, W`` and their
    label sets are collected from the outcome keys. The frequency values are
    then shuffled ``num_permut`` times (in place, cumulatively); each
    permutation yields one rendered frequency table and one conditional query
    (the first one returned by ``get_cond_queries``).

    Args:
        freqs: Mapping from outcome key (indexable, one label per variable)
            to count.
        num_permut: Number of shuffled frequency tables to generate.

    Returns:
        Tuple ``(num_vars, sample_size, rvs, support_text, freqs_text,
        cond_queries)`` where ``rvs`` maps variable name to its set of labels,
        ``support_text`` describes each variable's outcomes, ``freqs_text`` is
        a list of rendered tables and ``cond_queries`` the matching queries.

    Raises:
        IndexError: If ``freqs`` is empty, has more than seven variables, or a
            permutation admits no query with a unique answer.
    """
    rv_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    keys = list(freqs.keys())
    values = list(freqs.values())
    sample_size = sum(values)

    # One label set per position of the outcome key.
    rvs = {rv_names[pos]: set() for pos in range(len(keys[0]))}
    for outcome in keys:
        for pos, name in enumerate(rvs):
            rvs[name].add(outcome[pos])

    num_vars = len(rvs)
    rv = list(rvs)
    labels = list(rvs.values())

    # "X can take outcomes from {A, B}, Y can take outcomes from {C, D}"
    support_text = ", ".join(
        f"{name} can take outcomes from " + "{" + ", ".join(sorted(outcomes)) + "}"
        for name, outcomes in zip(rv, labels)
    )

    freqs_text = []
    cond_queries = []
    for _ in range(num_permut):
        random.shuffle(values)
        permuted = dict(zip(keys, values))
        cond_queries.append(get_cond_queries(rv, labels, permuted)[0])
        # Render keys without quotes, e.g. ('A', 'D') -> (A, D).
        rendered = {str(outcome).replace("'", ""): count for outcome, count in permuted.items()}
        freqs_text.append("\n".join(f"- {outcome}: {count}" for outcome, count in rendered.items()))

    return num_vars, sample_size, rvs, support_text, freqs_text, cond_queries
    
  

def check_pattern(input_string, q_rv, cond_rv, cond_label, gt):
    """Check whether a response states the expected conditional mode.

    A response is accepted if it contains either

    * a sentence of the form "Most probable [value of] Q given [that] C=c is
      [:] [Q =] <gt>" (optionally with LaTeX ``\\(`` ``\\)`` delimiters), or
    * the literal ``boxed{<gt>}``.

    Args:
        input_string: Model response to inspect.
        q_rv: Name of the queried variable.
        cond_rv: Name of the conditioning variable.
        cond_label: Value the conditioning variable is fixed to.
        gt: Ground-truth answer label.

    Returns:
        True if the response matches one of the accepted forms, else False.
    """
    # Escape the interpolated values so labels containing regex
    # metacharacters are matched literally.
    q = re.escape(str(q_rv))
    c_rv = re.escape(str(cond_rv))
    c_label = re.escape(str(cond_label))
    answer = re.escape(str(gt))
    # The trailing (?!\w) stops the answer matching as a mere prefix of a
    # longer token (e.g. gt "1" inside "10").
    pattern = fr"(Most|most)(\s*probable\s*)(\s*value of\s*)?(\\\(\s*)?(\s*{q}\s*)(\\\))?(\s*given\s*)(\s*that\s*)?(\\\(\s*)?(\s*{c_rv}\s*=\s*{c_label}\s*)(\\\))?(\s*is\s*)(\s*:\s*)?(\[)?(\\\()?(\s*{q}\s*=\s*)?\s*{answer}(?!\w)"
    return ((re.search(pattern, input_string) is not None) or (("boxed{" + gt + "}") in input_string) )



def batch_prompts(prompts, batch_size):
    """Yield consecutive slices of ``prompts`` holding at most ``batch_size`` items."""
    starts = range(0, len(prompts), batch_size)
    yield from (prompts[start:start + batch_size] for start in starts)


def generate_prompts(n, freqs, permute=False):
    """Build ``n`` zero-shot conditional-mode prompts from a frequency table.

    Ten shuffled frequency tables and queries are obtained from
    ``extract_params`` and cycled through to fill ``n`` prompts.

    Args:
        n: Number of prompts to generate.
        freqs: Mapping from outcome key to count.
        permute: Unused.

    Returns:
        Tuple ``(prompts, out_queries)``: the prompt strings and, aligned with
        them, the query dict each prompt asks about.
    """
    num_vars, sample_size, rvs, s, freqs_text, cond_queries = extract_params(freqs, 10)
    var_names = ', '.join(rvs.keys())

    prompts = []
    out_queries = []
    for idx in range(n):
        # Cycle through the available permutations.
        query = cond_queries[idx % len(cond_queries)]
        table = freqs_text[idx % len(freqs_text)]
        out_queries.append(query)
        q_rv = query['Q_rv']
        c_rv = query['cond_rv']
        c_label = query['cond_label']
        text = f"""Consider {num_vars} discrete random variables {var_names}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({var_names}), the observed frequencies are:

{table}

Task: Identify the mode (most probable outcome) of the conditional distribution P({q_rv} | {c_rv}={c_label}).

Instructions:
1. Do not write any code or pseudocodes.
2. Strictly follow the output format below.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. You get negative penalty for not following the output format.

Output Format that you should strictly follow:

Most probable value of {q_rv} given {c_rv}={c_label} is 
"""
        prompts.append(text.strip())

    return prompts, out_queries


def generate_oneshot_prompts(n, freqs, permute=False):
    """Build ``n`` copies of a one-shot prompt with a worked example.

    The frequency table comes from a single permutation produced by
    ``extract_params``; the worked example and the main task text are fixed.

    NOTE(review): the task in the prompt is hard-coded to P(X | Y = E) and the
    example's numbers are hard-coded too, while the returned query is whatever
    ``extract_params`` sampled - confirm callers account for this.

    Args:
        n: Number of (identical) prompts to return.
        freqs: Mapping from outcome key to count.
        permute: Unused.

    Returns:
        Tuple ``(prompts, out_queries)``, each a list of length ``n``.
    """
    num_vars, sample_size, rvs, s, freqs_text, cond_queries = extract_params(freqs, 1)
    var_names = ', '.join(rvs.keys())
    table = freqs_text[0]
    text = f"""Consider {num_vars} discrete random variables {var_names}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({var_names}), the observed frequencies are:

{table}

--- EXAMPLE ---
Example Task (solved): Identify the mode (most probable outcome) of the conditional distribution P(Z | X = A).

Short step-by-step solution:
1. Extract all observed samples where X = A (all rows above starting with A).
2. Aggregate counts by Z among those samples:
   - Count(Z = F | X = A) = (A,D,F,I)+(A,D,F,J)+(A,D,F,K)+(A,E,F,I)+(A,E,F,J)+(A,E,F,K)
     = 6 + 3 + 1 + 8 + 5 + 12 = 35.
   - Count(Z = G | X = A) = (A,D,G,I)+(A,D,G,J)+(A,D,G,K)+(A,E,G,I)+(A,E,G,J)+(A,E,G,K)
     = 7 + 10 + 3 + 15 + 10 + 9 = 54.
   - Count(Z = H | X = A) = (A,D,H,I)+(A,D,H,J)+(A,D,H,K)+(A,E,H,I)+(A,E,H,J)+(A,E,H,K)
     = 1 + 7 + 2 + 4 + 5 + 2 = 21.
3. Total samples with X = A = 35 + 54 + 21 = 110.
4. Compute empirical conditional probabilities and round to 4 decimals:
   - P(Z = F | X = A) = 35 / 110 = 0.3181818... → 0.3182
   - P(Z = G | X = A) = 54 / 110 = 0.4909090... → 0.4909
   - P(Z = H | X = A) = 21 / 110 = 0.1909090... → 0.1909

   - Final conditional probabilities:
    P(Z=F | X=A) = 0.3182
    P(Z=G | X=A) = 0.4909
    P(Z=H | X=A) = 0.1909

5. Identify mode: Z = G (highest probability 0.4909).

Most probable value of Z given X=A is G.
--- END EXAMPLE ---

Main Task:
Task: Identify the mode (most probable outcome) of the conditional distribution P(X | Y = E).

Instructions:
1. Do not write any code or pseudocodes.
2. Strictly follow the output format below.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. You get negative penalty for not following the output format.

Output Format that you should strictly follow:

Most probable value of X given Y=E is
"""
    shared_prompt = text.strip()

    # Every returned prompt/query is the same object repeated n times.
    return [shared_prompt] * n, [cond_queries[0]] * n





def generate_oneshot_prompts2(n, freqs, permute=False):
    """Build ``n`` copies of a fully fixed one-shot prompt.

    Both the worked example and the main task (including its frequency table)
    are literal text; ``freqs`` only influences the query that is returned
    alongside the prompts.

    NOTE(review): the prompt always asks for P(X | Y = E) on the embedded
    table, whereas the returned query is the one sampled by ``extract_params``
    from ``freqs`` - confirm callers account for this.

    Args:
        n: Number of (identical) prompts to return.
        freqs: Mapping from outcome key to count, forwarded to
            ``extract_params``.
        permute: Unused.

    Returns:
        Tuple ``(prompts, out_queries)``, each a list of length ``n``.
    """
    # Only the sampled query is kept. The call is retained unchanged because
    # extract_params also shuffles with the module-level RNG.
    _, _, _, _, _, cond_queries = extract_params(freqs, 1)
    sampled_query = cond_queries[0]

    text = """Task Definition:
Consider the joint probability distribution of a set of discrete random variables. You will be provided with the frequencies of each outcome of the joint distribution in a set of independently drawn samples.  
Your task is to identify the mode (the most probable outcome) of a specific conditional distribution.  


--- EXAMPLE with Solution---
Consider discrete random variables X, Y, Z with joint distribution P(X, Y, Z) where X can take outcomes from {A, B}, Y can take outcomes from {C, D}, Z can take outcomes from {E, F, G}. 
In 60 independent samples drawn from the joint distribution P(X, Y, Z), the observed frequencies are:
- (A, C, E): 5  
- (A, C, F): 7  
- (A, C, G): 3  
- (A, D, E): 10  
- (A, D, F): 2  
- (A, D, G): 3  
- (B, C, E): 6  
- (B, C, F): 8  
- (B, C, G): 2  
- (B, D, E): 3  
- (B, D, F): 3  
- (B, D, G): 8 

Example Task (solved): Identify the mode (most probable outcome) of the conditional distribution P(Z | X = A).

Step-by-step solution:
1. Extract rows where X = A:
   (A,C,E):5, (A,C,F):7, (A,C,G):3, (A,D,E):10, (A,D,F):2, (A,D,G):3.
2. Aggregate counts by Z among those rows:
   - Count(Z = E | X = A) = 5 + 10 = 15.
   - Count(Z = F | X = A) = 7 + 2 = 9.
   - Count(Z = G | X = A) = 3 + 3 = 6.
3. Total samples with X = A = 15 + 9 + 6 = 30.
4. Compute empirical conditional probabilities.
   - P(Z = E | X = A) = 15 / 30 = 0.5
   - P(Z = F | X = A) = 9 / 30 = 0.3
   - P(Z = G | X = A) = 6 / 30 = 0.2
5. Identify the mode: Z = E (highest probability 0.5).

Final Answer:
Most probable value of Z given X=A is E
--- END EXAMPLE ---

Main Task:
Consider 4 discrete random variables X, Y, Z, T, where X can take outcomes from {A, B, C}, Y can take outcomes from {D, E}, Z can take outcomes from {F, G, H}, T can take outcomes from {I, J, K}. 
In 270 independent samples drawn from their joint distribution P(X, Y, Z, T), the observed frequencies are:
- (A, D, F, I): 6
- (A, D, F, J): 3
- (A, D, F, K): 1
- (A, D, G, I): 7
- (A, D, G, J): 10
- (A, D, G, K): 3
- (A, D, H, I): 1
- (A, D, H, J): 7
- (A, D, H, K): 2
- (A, E, F, I): 8
- (A, E, F, J): 5
- (A, E, F, K): 12
- (A, E, G, I): 15
- (A, E, G, J): 10
- (A, E, G, K): 9
- (A, E, H, I): 4
- (A, E, H, J): 5
- (A, E, H, K): 2
- (B, D, F, I): 1
- (B, D, F, J): 5
- (B, D, F, K): 7
- (B, D, G, I): 1
- (B, D, G, J): 1
- (B, D, G, K): 5
- (B, D, H, I): 1
- (B, D, H, J): 2
- (B, D, H, K): 5
- (B, E, F, I): 1
- (B, E, F, J): 11
- (B, E, F, K): 5
- (B, E, G, I): 2
- (B, E, G, J): 2
- (B, E, G, K): 12
- (B, E, H, I): 2
- (B, E, H, J): 4
- (B, E, H, K): 3
- (C, D, F, I): 13
- (C, D, F, J): 2
- (C, D, F, K): 2
- (C, D, G, I): 13
- (C, D, G, J): 1
- (C, D, G, K): 9
- (C, D, H, I): 4
- (C, D, H, J): 12
- (C, D, H, K): 3
- (C, E, F, I): 4
- (C, E, F, J): 2
- (C, E, F, K): 11
- (C, E, G, I): 3
- (C, E, G, J): 1
- (C, E, G, K): 3
- (C, E, H, I): 4
- (C, E, H, J): 2
- (C, E, H, K): 1

Main Task (to solve): Identify the mode (most probable outcome) of the conditional distribution P(X | Y = E).

Instructions:
1. Do not write any code or pseudocodes.
2. Strictly follow the output format below.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. You get negative penalty for not following the output format.

Output Format that you should strictly follow:
Most probable value of X given Y=E is
"""
    shared_prompt = text.strip()

    return [shared_prompt] * n, [sampled_query] * n


# def generate_code_prompts(n, freqs, permute=False):
#     prompts, out_queries = [], []
#     num_vars, sample_size, rvs, s, freqs_text, cond_queries = extract_params(freqs)
    
#     for i in range(n):
#         query = cond_queries[i%len(cond_queries)]
#         out_queries.append(query)
#         prompt = f"""Consider {num_vars} discrete random variables {', '.join(rvs.keys())}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({', '.join(rvs.keys())}), the observed frequencies are:

# {freqs_text[i%len(freqs_text)]}

# Task: Write Python code to identify and print the mode (most probable outcome) of the conditional distribution P({query['Q_rv']} | {query['cond_rv']}={query['cond_label']}). The output of your code should be "Mode: <most_probable_value>".
# """
#         prompts.append(prompt.strip())
    
#     return prompts, out_queries



def extract_code_answer(response):
    """Return the stripped bodies of all ```python fenced blocks, or None if there are none."""
    fenced = re.finditer(r"```python(.*?)```", response, flags=re.DOTALL)
    snippets = [match.group(1).strip() for match in fenced]
    if not snippets:
        return None
    return snippets




def eval_cond_mode(models, freqs, num_prompts, permute, mode="local"):
    """Evaluate models on the one-shot conditional-mode task.

    The prompt, query and expected answer are fixed (P(X | Y = E) with answer
    A, matching the table embedded in ``generate_oneshot_prompts2``);
    ``freqs`` only contributes the joint size used in batch file names.

    Args:
        models: Iterable of Hugging Face model ids (used in local mode only).
        freqs: Mapping from outcome key to count.
        num_prompts: Number of (identical) prompts to evaluate.
        permute: Forwarded to the prompt generator and stored in the results.
        mode: ``"api_create_batch"`` writes batch request files,
            ``"api_eval_batch"`` scores a finished batch, anything else runs
            the models locally on CUDA.

    Returns:
        None for ``"api_create_batch"``; the result of ``evaluate_api_batch``
        for ``"api_eval_batch"``; otherwise a dict with the prompt, settings
        and per-model statistics.

    Raises:
        ValueError: In local mode when ``num_prompts`` is not positive.
    """
    # prompts, queries = generate_prompts(num_prompts, freqs)
    prompts, queries = generate_oneshot_prompts2(num_prompts, freqs, permute)
    # The one-shot prompt hard-codes its task, so the sampled query returned
    # above is replaced by the one the prompt actually asks.
    queries = [{'Q_rv': 'X', 'cond_rv': 'Y', 'cond_label': 'E', 'answer': 'A'}]*num_prompts

    # expected = [(prompts[i].split('\n')[-1]+" "+queries[i]['answer'])for i in range(len(prompts))]
    expected = ["Most probable value of X given Y=E is A"]*num_prompts

    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, expected, joint_size)
        return None

    if mode == "api_eval_batch":
        return evaluate_api_batch(prompts, queries, expected, joint_size)

    # Local evaluation indexes prompts[0] and divides by num_prompts.
    if num_prompts <= 0:
        raise ValueError("num_prompts must be positive for local evaluation")

    batch_size = min(num_prompts, 40)

    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    out["permute"] = permute

    print(f'prompt sample: \n{prompts[0]}: {expected[0]}\n\n\n\n')

    # Generation budget (the original size-dependent value was always
    # overridden by this constant).
    max_tokens = 3000
    for model_id in models:

        # Drop references to the previous model before loading the next one.
        model, tokenizer = None, None
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
        # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.pad_token = tokenizer.eos_token
        # Decoder-only models must be left-padded for batched generation so
        # that every prompt ends right where generation starts.
        tokenizer.padding_side = "left"

        print(f'Model: {model_id}')

        responses = []
        num_mistake = 0
        num_calls = 0
        incorrect_responses = []
        for batch in batch_prompts(prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            prompt_len = inputs["input_ids"].shape[1]

            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_tokens,
                attention_mask=inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id
            )

            # Decode only the newly generated tokens; re-decoding the prompt
            # is not guaranteed to reproduce the prompt text exactly, which
            # makes stripping it by string replacement unreliable.
            responses.extend(
                tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
                for output in outputs
            )

        print("Responses are generated. Now evaluating...")
        for i, response in enumerate(responses):
            # Still remove a verbatim echo of the prompt, if the model produced one.
            res = response.replace(prompts[i], '')
            query = queries[i]
            check = check_pattern(res, query['Q_rv'], query['cond_rv'], query['cond_label'], query['answer'])
            print(f'res {i}, pattern-check: {check}')
            if not check:
                # Fall back to the external judge; a judgement of 0 is a mistake.
                judgement = judge(prompts[i].split('Instruction')[0], expected[i], res)
                num_calls += 1
                print(f'num_calls: {num_calls}')
                if judgement == 0:
                    incorrect_responses.append(res)
                    num_mistake += 1

        print(f'num calls: {num_calls}')
        print(f'num mistakes: {num_mistake}')

        out[model_id] = {"num_mistakes": num_mistake, "num_judge_calls": num_calls, "num_correct": (num_prompts - num_mistake), "accuracy": (num_prompts - num_mistake) / num_prompts, "incorrect_responses": incorrect_responses}
        print(f'Accuracy: {out[model_id]["accuracy"]}')
        print('\n\n\n\n\n\n')

    # with open(f"../eval_results/Rebuttal/cond_mode_16_mushroom_local.json", "w") as file:
    #     json.dump(out, file, indent=4) 

    # with open(f"../eval_results/Rebuttal/cond_mode_12_skewed_local.json", "w") as file:
    #     json.dump(out, file, indent=4)  

    # with open(f"../eval_results/Rebuttal/cond_mode_permu_jointsize_{len(freqs.keys())}_seed{seed}.json", "w") as file:
    #     json.dump(out, file, indent=4)

    return out








