from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import os
import sys
import random
import torch
import numpy as np
import scipy.stats as st

# Release unoccupied memory held by PyTorch's CUDA caching allocator before
# any model is loaded.
torch.cuda.empty_cache()
# Fixed seed: extract_params() permutes frequency tables with random.shuffle,
# so a fixed seed makes the generated prompts/groundtruth reproducible
# (batch output files are tagged "seed5" accordingly).
random.seed(5)


def create_batch_file(prompts, joint_size, models=None, out_dir="batch_input/Rebuttal"):
    """Write one OpenAI-style batch-request JSONL file per model.

    Every prompt becomes a single ``/v1/chat/completions`` request whose
    ``custom_id`` is ``request-<k>``, with ``k`` the 1-based prompt index.
    ``evaluate_api_batch`` relies on this numbering to realign responses.

    Args:
        prompts: Prompt strings, one request each.
        joint_size: Number of joint outcomes; only used in the file name.
        models: Model identifiers to write a file for. Defaults to
            ``["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]``.
        out_dir: Directory receiving the files; created if it does not exist.

    Returns:
        The list of written file paths, in ``models`` order.
    """
    if models is None:
        models = ["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]
    if out_dir:
        # Opening a file inside a missing directory would raise FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)

    written = []
    for model in models:
        path = os.path.join(out_dir, f"batch_file_joint_mle_{joint_size}_mushroom_{model}.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            for idx, content in enumerate(prompts, start=1):
                record = {
                    "custom_id": f"request-{idx}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": [
                            {"role": "user", "content": content}
                        ],
                    },
                }
                # Compact separators keep each request on one short JSONL line.
                f.write(json.dumps(record, separators=(",", ":")) + "\n")
        written.append(path)
    return written




def evaluate_api_batch(prompts, groundtruth, joint_size, model="gpt-4.1-mini"):
    """Score a downloaded batch-API output file against the MLE groundtruth.

    Reads
    ``../batch_output/Rebuttal/batch_output_joint_mle_mushroom_<joint_size>_<model>_seed5.jsonl``
    and maps every record back to its prompt through the 1-based
    ``request-<k>`` ``custom_id`` written by ``create_batch_file``. It then
    extracts the predicted probabilities and computes their TVD to the
    groundtruth. The summary is printed and written to
    ``../eval_results/Rebuttal/joint_mle_<joint_size>_mushroom_<model>.json``.

    Some prompts are counted as mistakes instead of aborting the evaluation:
    those with no record in the output file, and those whose record has no
    message content (e.g. a null ``response``).

    Args:
        prompts: Prompts in request order.
        groundtruth: ``P(...) = p`` listings aligned with ``prompts``.
        joint_size: Number of joint outcomes; only used in file names.
        model: Model name used in both the input and output file names.

    Raises:
        ValueError: If a ``custom_id`` refers to an index outside
            ``1..len(prompts)`` (index 0 would otherwise silently overwrite
            the last slot).
    """
    res_file = f"../batch_output/Rebuttal/batch_output_joint_mle_mushroom_{joint_size}_{model}_seed5.jsonl"
    responses = [None] * len(prompts)
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            custom_id = record.get("custom_id", "")
            i = int(custom_id.split("-")[1])
            if not 1 <= i <= len(responses):
                raise ValueError(
                    f"custom_id {custom_id!r} does not match any of the {len(responses)} prompts"
                )
            try:
                content = record["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                # No usable body (e.g. failed request): treat as unanswered.
                content = None
            responses[i - 1] = content

    num_mistake = 0
    incorrect_responses = []
    tvds = []
    for prompt, response, expected in zip(prompts, responses, groundtruth):
        if response is None:
            num_mistake += 1
            incorrect_responses.append(None)
            continue
        res = response.replace(prompt, '')
        gt, pred = get_MLE_probs(res, expected)
        if pred is None:
            num_mistake += 1
            incorrect_responses.append(res)
            continue
        tvds.append(TVD(gt, pred))

    # None (JSON null) when no response could be parsed, instead of dividing by zero.
    avg_tvd = sum(tvds) / len(tvds) if tvds else None

    print(f'num mistakes: {num_mistake}')
    print('\n\n\n\n\n\n')
    out = {
        "prompt": prompts[0] if prompts else None,
        "num_prompts": len(prompts),
        model: {
            "num_mistakes": num_mistake,
            "Average TVD": avg_tvd,
            "incorrect_responses": incorrect_responses,
        },
    }
    print(f'Average TVD: {avg_tvd}')

    with open(f"../eval_results/Rebuttal/joint_mle_{joint_size}_mushroom_{model}.json", "w", encoding="utf-8") as file:
        json.dump(out, file, indent=4)




def extract_params(freqs, num_permut=10):
    """Derive the ingredients of an MLE prompt from a joint frequency table.

    Args:
        freqs: Mapping from outcome tuples (one component per variable) to
            observed counts. Variables are named X, Y, Z, T, U, V, W by tuple
            position, so at most seven variables are supported.
        num_permut: How many shuffled frequency tables to produce.

    Returns:
        ``(num_vars, sample_size, rvs, domains, freqs_text, groundtruth)``:

        * ``rvs`` maps each variable name to the set of values it takes.
        * ``domains`` is the comma-joined "X can take outcomes from {...}" text.
        * ``freqs_text[k]`` is the ``k``-th shuffled table as ``- (a, b): count``
          lines.
        * ``groundtruth[k]`` is that table's ``P(a, b) = p`` listing.

    Note:
        Each permutation reassigns the same counts to the fixed outcome order
        using ``random.shuffle``. Shuffles accumulate across permutations and
        consume the module-level RNG.
    """
    var_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    outcomes = list(freqs.keys())
    counts = list(freqs.values())
    sample_size = sum(counts)

    rvs = {var_names[pos]: set() for pos in range(len(outcomes[0]))}
    for outcome in outcomes:
        for pos, name in enumerate(rvs):
            rvs[name].add(outcome[pos])
    num_vars = len(rvs)

    domains = ", ".join(
        f"{name} can take outcomes from " + "{" + ", ".join(sorted(values)) + "}"
        for name, values in rvs.items()
    )

    freqs_text, groundtruth = [], []
    for _ in range(num_permut):
        random.shuffle(counts)
        permuted = dict(zip(outcomes, counts))
        groundtruth.append(get_expected_output(permuted))
        labelled = {str(outcome).replace("'", ""): count for outcome, count in permuted.items()}
        freqs_text.append("\n".join(f"- {label}: {count}" for label, count in labelled.items()))
    return num_vars, sample_size, rvs, domains, freqs_text, groundtruth



def get_expected_output(freqs):
    """Render the relative-frequency (MLE) distribution of ``freqs`` as text.

    Each outcome becomes one line ``P<outcome> = <p>``:

    * ``<outcome>`` is ``str(key)`` with single quotes removed, so
      ``('A', 'B')`` becomes ``(A, B)``.
    * ``p`` is ``count / total`` rounded to 4 decimals.

    Lines follow the mapping's iteration order and are newline-separated.
    """
    total = sum(freqs.values())
    rendered = []
    for outcome, count in freqs.items():
        label = str(outcome).replace("'", "")
        rendered.append(f"P{label} = {round(count / total, 4)}")
    return "\n".join(rendered)


def extract_probability(text: str, outcome: str, last: bool = False) -> "float | None":
    """Return the probability a response reports for ``outcome``, or ``None``.

    Scans ``text`` line by line for a line containing ``outcome`` (matched
    literally), followed later on the same line by a standalone decimal of
    the form ``0.<digits>``. The first such decimal on that line is returned.

    Args:
        text: Model response to search.
        outcome: Literal label to look for, e.g. ``"P(A, D, F, I)"``.
        last: If False (default), the first matching line wins. If True, the
            last matching line wins, which targets the final summary that the
            prompts ask the model to put at the end.

    Returns:
        The parsed probability, or ``None`` if no line matches. Values not
        written as ``0.<digits>`` (``1``, ``1.0``, ``.5``, percentages) are
        not recognised.
    """
    # The lookbehind keeps the number from being carved out of a larger one,
    # e.g. "0.5" out of "10.5" or "0.0" out of "270.0 = 0.0222".
    pattern = re.compile(rf"{re.escape(outcome)}.*?(?<![0-9.])(0\.[0-9]+)")
    lines = text.splitlines()
    if last:
        lines = reversed(lines)
    for line in lines:
        match = pattern.search(line)
        if match:
            return float(match.group(1))
    return None


def get_MLE_probs(response, expected):
    """Pair groundtruth and predicted probabilities outcome by outcome.

    ``expected`` holds newline-separated ``<label> = <p>`` lines, as produced
    by ``get_expected_output``. For each label the predicted value is looked
    up in ``response`` with ``extract_probability``.

    Returns:
        ``(groundtruth_probs, predicted_probs)`` as parallel lists. Returns
        ``(None, None)`` as soon as any label has no parsable prediction.
    """
    pairs = []
    for line in expected.split('\n'):
        label, true_value = line.split(" = ")
        guessed = extract_probability(response, label)
        if guessed is None:
            return None, None
        pairs.append((float(true_value), guessed))
    gts = [truth for truth, _ in pairs]
    preds = [guess for _, guess in pairs]
    return gts, preds


def TVD(P, Q):
    """Total variation distance between two aligned probability lists.

    Both inputs are first rounded to 3 decimals. The result is half the sum
    of absolute element-wise differences, rounded to 4 decimals. Pairs are
    formed with ``zip``, so extra trailing entries of the longer list are
    ignored.
    """
    def _to_3dp(dist):
        return [round(v, 3) for v in dist]

    rounded_p = _to_3dp(P)
    rounded_q = _to_3dp(Q)
    l1 = sum(abs(a - b) for a, b in zip(rounded_p, rounded_q))
    return round(0.5 * l1, 4)


def batch_prompts(prompts, batch_size):
    """Lazily yield consecutive slices of ``prompts`` of length ``batch_size``.

    The final slice may be shorter. An empty input yields nothing.
    """
    starts = range(0, len(prompts), batch_size)
    yield from (prompts[start:start + batch_size] for start in starts)



def generate_prompts(n, freqs, permute=False):
    """Build ``n`` zero-shot MLE prompts together with their groundtruth.

    Prompt ``i`` shows the permuted frequency table ``i % len(freqs_text)``
    returned by ``extract_params``, and ``groundtruth[i]`` is the matching
    ``P(...) = p`` listing.

    Args:
        n: Number of prompts to generate.
        freqs: Mapping from outcome tuple to observed count.
        permute: Unused; kept for interface compatibility.

    Returns:
        ``(prompts, groundtruth)``, two lists of length ``n``.
    """
    prompts, groundtruth = [], []
    num_vars, sample_size, rvs, s, freqs_text, gt = extract_params(freqs)
    var_list = ', '.join(rvs.keys())
    # One "P(<outcome>) = [value]" template line per outcome, in table order.
    q = ['P' + row[row.index('('):row.index(')') + 1] + ' = [value]'
         for row in freqs_text[0].split('\n')]
    if len(q) <= 6:
        # Small tables: list every outcome; the head/tail form below would
        # repeat lines or add a misleading "...".
        query = '\n'.join(q)
    else:
        query = '\n'.join(q[:4]) + '\n...' + '\n' + '\n'.join(q[-2:])
    for i in range(n):
        groundtruth.append(gt[i % len(gt)])
        prompt = f"""Consider {num_vars} discrete random variables {var_list}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({var_list}), the observed frequencies are:

{freqs_text[i % len(freqs_text)]}

Task: Predict the maximum likelihood estimation (MLE) of the joint probability distribution P({var_list}) based on these {sample_size} samples.

Instructions:
1. Think step by step and solve this mathematically using probability theory - do not write any code or pseudocodes as you get negative penalty for that.
2. Clearly state the final estimated probabilities for each ({var_list}) outcome. Output probabilities should be expressed as float numbers with up to four decimal points.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. only your final answer will be graded.

Output Format:
- Final probabilities as float numbers should be listed as:

{query}
"""
        prompts.append(prompt.strip())
    return prompts, groundtruth





def generate_oneshot_prompts2(n, freqs, permute=False):
    """Build ``n`` copies of the fixed one-shot MLE prompt and its groundtruth.

    The prompt is hard-coded: a worked 3-variable example followed by a
    4-variable main task over 270 samples. ``freqs`` does not influence its
    text. The groundtruth is therefore computed from the frequencies written
    in the prompt's "Main Task" section. Deriving it from a shuffled copy of
    ``freqs`` instead would make it agree with the prompt only by
    coincidence of RNG state.

    Args:
        n: Number of (identical) prompts to return.
        freqs: Not used to build the prompt; kept for interface compatibility.
        permute: Unused; kept for interface compatibility.

    Returns:
        ``(prompts, groundtruth)``: ``n`` identical prompt strings and ``n``
        identical ``P(...) = p`` groundtruth listings.
    """
    prompt = """Task Definition:
Consider the joint probability distribution of a set of discrete random variables. You will be provided with the frequencies of each outcome of the joint distribution in a set of independently drawn samples.  
Your task is to identify the predict the maximum likelihood estimation (MLE) of the joint distribution.  

Instructions:
1. Think step by step and solve this mathematically using probability theory - do not write any code or pseudocodes as you get negative penalty for that.
2. Clearly state the final estimated probabilities as float numbers with up to four decimal points.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. 
4. Final probabilities as float numbers should be listed at the end as shown in the example.

--- EXAMPLE with Solution---
Consider discrete random variables X, Y, Z with joint distribution P(X, Y, Z) where X can take outcomes from {A, B}, Y can take outcomes from {C, D}, Z can take outcomes from {E, F, G}. 
In 60 independent samples drawn from the joint distribution P(X, Y, Z), the observed frequencies are:
- (A, C, E): 5  
- (A, C, F): 7  
- (A, C, G): 3  
- (A, D, E): 10  
- (A, D, F): 2  
- (A, D, G): 3  
- (B, C, E): 6  
- (B, C, F): 8  
- (B, C, G): 2  
- (B, D, E): 3  
- (B, D, F): 3  
- (B, D, G): 8 

Example Task (solved):  Predict the maximum likelihood estimation (MLE) of the joint distribution P(X, Y, Z).

Step-by-step solution:
1. Total number of samples = 60.  
2. Compute probabilities for each joint outcome by dividing frequency by 60.  
   - P(A, C, E) = 5 / 60 = 0.0833  
   - P(A, C, F) = 7 / 60 = 0.1167  
   - P(A, C, G) = 3 / 60 = 0.0500  
   - P(A, D, E) = 10 / 60 = 0.1667  
   - P(A, D, F) = 2 / 60 = 0.0333  
   - P(A, D, G) = 3 / 60 = 0.0500  
   - P(B, C, E) = 6 / 60 = 0.1000  
   - P(B, C, F) = 8 / 60 = 0.1333  
   - P(B, C, G) = 2 / 60 = 0.0333  
   - P(B, D, E) = 3 / 60 = 0.0500  
   - P(B, D, F) = 3 / 60 = 0.0500  
   - P(B, D, G) = 8 / 60 = 0.1333

Final Answer:
P(A, C, E) = 0.0833  
P(A, C, F) = 0.1167  
P(A, C, G) = 0.0500  
P(A, D, E) = 0.1667  
P(A, D, F) = 0.0333  
P(A, D, G) = 0.0500  
P(B, C, E) = 0.1000  
P(B, C, F) = 0.1333  
P(B, C, G) = 0.0333  
P(B, D, E) = 0.0500  
P(B, D, F) = 0.0500  
P(B, D, G) = 0.1333 
--- END EXAMPLE ---

Main Task:
Consider 4 discrete random variables X, Y, Z, T, where X can take outcomes from {A, B, C}, Y can take outcomes from {D, E}, Z can take outcomes from {F, G, H}, T can take outcomes from {I, J, K}. 
In 270 independent samples drawn from their joint distribution P(X, Y, Z, T), the observed frequencies are:
- (A, D, F, I): 6
- (A, D, F, J): 3
- (A, D, F, K): 1
- (A, D, G, I): 7
- (A, D, G, J): 10
- (A, D, G, K): 3
- (A, D, H, I): 1
- (A, D, H, J): 7
- (A, D, H, K): 2
- (A, E, F, I): 8
- (A, E, F, J): 5
- (A, E, F, K): 12
- (A, E, G, I): 15
- (A, E, G, J): 10
- (A, E, G, K): 9
- (A, E, H, I): 4
- (A, E, H, J): 5
- (A, E, H, K): 2
- (B, D, F, I): 1
- (B, D, F, J): 5
- (B, D, F, K): 7
- (B, D, G, I): 1
- (B, D, G, J): 1
- (B, D, G, K): 5
- (B, D, H, I): 1
- (B, D, H, J): 2
- (B, D, H, K): 5
- (B, E, F, I): 1
- (B, E, F, J): 11
- (B, E, F, K): 5
- (B, E, G, I): 2
- (B, E, G, J): 2
- (B, E, G, K): 12
- (B, E, H, I): 2
- (B, E, H, J): 4
- (B, E, H, K): 3
- (C, D, F, I): 13
- (C, D, F, J): 2
- (C, D, F, K): 2
- (C, D, G, I): 13
- (C, D, G, J): 1
- (C, D, G, K): 9
- (C, D, H, I): 4
- (C, D, H, J): 12
- (C, D, H, K): 3
- (C, E, F, I): 4
- (C, E, F, J): 2
- (C, E, F, K): 11
- (C, E, G, I): 3
- (C, E, G, J): 1
- (C, E, G, K): 3
- (C, E, H, I): 4
- (C, E, H, J): 2
- (C, E, H, K): 1

Main Task (to solve): Predict the maximum likelihood estimation (MLE) of the joint distribution P(X, Y, Z, T).

Provide step-by-step solution similar to the example and then present the final probabilities.

Output Format:
- Final probabilities as float numbers should be listed as:
P(A, D, F, I) = [value]  
P(A, D, F, J) = [value] 
...
P(C, E, H, K) = [value]
"""
    # Only the main-task table is graded; the worked example before it uses
    # other outcomes and must not contribute to the groundtruth.
    main_task = prompt.split("Main Task:", 1)[1]
    main_freqs = {
        f"({m.group(1)})": int(m.group(2))
        for m in re.finditer(r"^- \(([^)]*)\): (\d+)", main_task, re.MULTILINE)
    }
    expected = get_expected_output(main_freqs)

    prompts = [prompt.strip()] * n
    groundtruth = [expected] * n
    return prompts, groundtruth








def _generate_responses(model, tokenizer, prompts, batch_size, max_new_tokens):
    """Run ``model.generate`` over ``prompts`` in batches; return only the generated text."""
    responses = []
    for batch in batch_prompts(prompts, batch_size):
        inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
        outputs = model.generate(
            inputs["input_ids"],
            max_new_tokens=max_new_tokens,
            attention_mask=inputs["attention_mask"],
            pad_token_id=tokenizer.eos_token_id,
        )
        # For causal LMs generate() returns prompt + continuation. With left
        # padding, every row's continuation starts at the padded prompt length.
        prompt_len = inputs["input_ids"].shape[1]
        responses.extend(
            tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
            for output in outputs
        )
    return responses


def eval_mle(models, freqs, num_prompts, permute, mode="local", max_batch_size=35, context_length=3000):
    """Evaluate models on estimating a joint distribution's MLE from counts.

    Prompts and groundtruth come from ``generate_oneshot_prompts2``. What
    happens next depends on ``mode``:

    * ``"api_create_batch"``: write batch-API request files; returns None.
    * ``"api_eval_batch"``: score a downloaded batch-API output file; returns None.
    * anything else (default ``"local"``): load each Hugging Face model in
      ``models`` onto CUDA in float16, generate answers in batches, and score
      them by TVD.

    Args:
        models: Hugging Face model ids (local mode only).
        freqs: Mapping from outcome tuple to observed count. Its size is used
            as ``joint_size`` in batch file names.
        num_prompts: Number of prompts to evaluate.
        permute: Only recorded in the result dict.
        mode: See above.
        max_batch_size: Upper bound on prompts per ``generate`` call.
        context_length: ``max_new_tokens`` budget. DeepSeek-R1-Distill-Qwen-7B
            gets 1.5x this amount.

    Returns:
        In local mode, a dict holding the first prompt, ``num_prompts``,
        ``permute`` and one entry per model id. Each model entry holds:

        * ``num_mistakes``.
        * ``Average TVD``, which is None when no response could be parsed.
        * the per-prompt ``TVDs``.
        * the unparseable responses.
    """
    prompts, groundtruth = generate_oneshot_prompts2(num_prompts, freqs)
    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, joint_size)
        return
    elif mode == "api_eval_batch":
        evaluate_api_batch(prompts, groundtruth, joint_size)
        return

    batch_size = min(num_prompts, max_batch_size)

    out = {"prompt": prompts[0], "num_prompts": num_prompts, "permute": permute}
    print(f'prompt sample: \n{prompts[0]}\n\n\n')
    print(f'groundtruth sample: \n{groundtruth[0]}\n\n')

    for model_id in models:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Decoder-only models must be left-padded for batched generation;
        # otherwise shorter prompts are continued after trailing pad tokens.
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        print(f'Model: {model_id}')

        # max_new_tokens must be an int. The larger budget for this model is
        # presumably for its longer reasoning output.
        if model_id == "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B":
            max_token = int(1.5 * context_length)
        else:
            max_token = context_length

        responses = _generate_responses(model, tokenizer, prompts, batch_size, max_token)

        # Free GPU memory before the next model is loaded.
        del model
        torch.cuda.empty_cache()

        num_mistake = 0
        incorrect_responses = []
        tvds = []
        for res, expected in zip(responses, groundtruth):
            gt, pred = get_MLE_probs(res, expected)
            if pred is None:
                num_mistake += 1
                incorrect_responses.append(res)
                print(f"Error in response")
                print('-'*100)
                continue
            tvds.append(TVD(gt, pred))

        print(f'num mistakes: {num_mistake}')
        avg_tvds = sum(tvds) / len(tvds) if tvds else None
        out[model_id] = {"num_mistakes": num_mistake, "Average TVD": avg_tvds, "TVDs": tvds, "incorrect_responses": incorrect_responses}
        print(f'Average TVD: {avg_tvds}')

    return out


