from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from groq_req import judge
import random
import torch
import numpy as np
import scipy.stats as st
# Module-level side effect: runs once at import time, before any model in this
# file is loaded. See the PyTorch documentation for what empty_cache() frees.
torch.cuda.empty_cache()

# Seed for Python's `random` module; the same value is embedded in the batch
# input/output file names built further down in this file.
# NOTE(review): only `random` is seeded here; NumPy and torch RNGs are not
# seeded in the code shown, so sampling-based generation may not be
# reproducible.
seed = 2
random.seed(seed)



def create_batch_file(prompts, queries, joint_size):
    """Write a batch-request JSONL file and an expected-answer JSON file per model.

    For each model name listed below, every prompt becomes one
    ``/v1/chat/completions`` request line keyed ``request-<k>`` (1-based).
    The expected answer text for the same key is rendered from the ``"answer"``
    mapping of the query at the same position and saved to a sibling JSON file.

    Args:
        prompts: Prompt strings, one request per prompt.
        queries: Sequence aligned with ``prompts``; each item provides an
            ``"answer"`` mapping of outcome label to probability.
        joint_size: Value embedded in both output file names.
    """
    # model = "llama-3.3-70b-versatile"
    # model = "gpt-4o-mini"
    # model = "gpt-4.1-mini"
    models = ["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]
    for model in models:
        batch_path = f"batch_input/Rebuttal/batch_file_cond_mle_{joint_size}_oneshot_{model}_seed{seed}.jsonl"
        expected_path = f"batch_input/Rebuttal/cond_mle_{joint_size}_oneshot_expected_{model}_seed{seed}.json"

        expected_outputs = {}
        with open(batch_path, "w") as batch_file:
            for position, content in enumerate(prompts):
                request_id = f"request-{position + 1}"
                body = {
                    "model": model,
                    "messages": [
                        {"role": "user", "content": content}
                    ]
                }
                record = {
                    "custom_id": request_id,
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": body
                }
                # The request line is written before the expected answer is
                # looked up, one compact JSON object per line.
                batch_file.write(json.dumps(record, separators=(",", ":")) + "\n")

                answer = queries[position]['answer']
                expected = '\n'.join(f'{k} = {v}' for k, v in answer.items())
                expected_outputs[request_id] = {"prompt": content, "expected": expected}

        with open(expected_path, "w") as expected_file:
            json.dump(expected_outputs, expected_file, indent=4)



def evaluate_api_batch(prompts, queries, joint_size):
    """Score a completed API batch run against its stored expected answers.

    Loads the expected-answer JSON written for the hard-coded ``model`` and the
    matching batch output JSONL, extracts the predicted probabilities from each
    response with ``get_MLE_probs`` and prints the number of unusable responses
    together with the mean total variation distance (TVD).

    A response counts as a mistake when it is absent from the output file, has
    no message content, or when ``get_MLE_probs`` cannot parse every expected
    outcome from it.

    Args:
        prompts: Ignored; prompts are reloaded from the expected-answer file.
            Kept so the call signature stays unchanged.
        queries: Ignored, for the same reason.
        joint_size: Value embedded in the input and output file names.

    Raises:
        KeyError: If the expected-answer file lacks one of the consecutive
            ``request-<k>`` keys.
    """
    # model = "llama-3.3-70b-versatile"
    # model = "gpt-4o-mini"
    model = "gpt-4.1-mini"
    with open(f"batch_input/Rebuttal/cond_mle_{joint_size}_oneshot_expected_{model}_seed{seed}.json", "r") as f:
        data = json.load(f)

    loaded_prompts = []
    expected = []
    for i in range(1, len(data) + 1):
        entry = data[f"request-{i}"]
        loaded_prompts.append(entry["prompt"])
        expected.append(entry["expected"])

    res_file = f"../batch_output/Rebuttal/batch_output_cond_mle_oneshot_{joint_size}_{model}_seed{seed}.jsonl"
    # Slots stay None for requests that never show up in the output file.
    responses = [None] * len(loaded_prompts)
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            record = json.loads(line)
            custom_id = record.get("custom_id", "")
            i = int(custom_id.split("-")[1])
            try:
                content = record["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                # A record without a usable response body is scored as a
                # mistake below instead of aborting the whole evaluation.
                content = None
            responses[i-1] = content

    num_prompts = len(loaded_prompts)
    incorrect_responses = []
    num_mistake = 0
    out = {}
    out["prompt"] = loaded_prompts[0]
    out["num_prompts"] = num_prompts
    tvds = []

    for prompt, raw, target in zip(loaded_prompts, responses, expected):
        if raw is None:
            num_mistake += 1
            incorrect_responses.append(None)
            continue

        res = raw.replace(prompt, '')
        # Raw strings are required here: in a normal literal '\a' is the BEL
        # character, so '\approx' would never match the LaTeX command.
        res = res.replace(r'\mid', '|')
        res = res.replace(r'\approx', '=')
        gt, pred = get_MLE_probs(res, target)
        if pred is None:
            num_mistake += 1
            incorrect_responses.append(res)
            continue

        tvds.append(TVD(gt, pred))

    # try:
    #     tvd_array = np.array(tvds)
    #     mean = np.mean(tvd_array)
    #     standard_error = st.sem(tvd_array)  # Standard error of the mean
    #     # 95% confidence interval using Student's t-distribution
    #     confidence_level = 0.95
    #     degrees_freedom = len(tvd_array) - 1
    #     confidence_interval = st.t.interval(confidence_level, degrees_freedom, loc=mean, scale=standard_error)
    #     print(f"Mean TVD: {mean:.4f}")
    #     print(f"95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f})")
    #     print(f"Error bar: ±{(confidence_interval[1] - mean):.4f}")
    # except Exception as e:
    #     print(f"Error occurred: {e}")
    #     print(tvds)

    # No parsable response means there is no average to report; use None
    # rather than dividing by zero.
    average_tvd = sum(tvds) / len(tvds) if tvds else None

    print(f'num mistakes: {num_mistake}')
    out[model] = {"num_mistakes": num_mistake, "Average TVD": average_tvd, "incorrect_responses": incorrect_responses}
    print({"num_mistakes": num_mistake, "Average TVD": average_tvd})

    # with open(f"../eval_results/Rebuttal/cond_mle_jointsize_{joint_size}_mushroom_{model}_seed{seed}.json", "w") as file:
    #     json.dump(out, file, indent=4)




def get_cond_queries(rv, labels, freqs):
    """Build every single-variable conditional MLE query for a joint table.

    For each ordered pair of distinct variables (query variable, conditioning
    variable) and each label of the conditioning variable, the empirical
    conditional distribution of the query variable is computed from ``freqs``.
    Conditioning labels with zero total count are skipped.

    Args:
        rv: Variable names, in the same order as the positions of an outcome.
        labels: One collection of labels per variable, aligned with ``rv``.
        freqs: Mapping from an outcome (indexable per variable position) to
            its observed count.

    Returns:
        A list of dicts with keys ``"Q_rv"``, ``"cond_rv"``, ``"cond_label"``
        and ``"answer"`` (``'P(Q=k | C=c)'`` -> probability), shuffled in place
        with the module-level ``random`` state.
    """
    queries = []
    labels = [sorted(label) for label in labels]
    for i, q in enumerate(rv):
        for j, cond in enumerate(rv):
            if i == j:
                continue
            for cond_label in labels[j]:
                counts = {label_q: 0 for label_q in labels[i]}
                total = 0
                for outcome, count in freqs.items():
                    # Compare the conditioning variable's own position only.
                    # A plain membership test would also match the label when
                    # it belongs to a different variable that shares it.
                    if outcome[j] == cond_label:
                        total += count
                        counts[outcome[i]] += count
                if total == 0:
                    continue
                answer = {f'P({q}={k} | {cond}={cond_label})': v / total for k, v in counts.items()}
                queries.append({"Q_rv": q, "cond_rv": cond, "cond_label": cond_label, "answer": answer})
    random.shuffle(queries)
    return queries


def extract_params(freqs, num_permut=10):
    """Derive prompt ingredients from a joint frequency table.

    Variables are named X, Y, Z, T, U, V, W by position in the outcome keys.
    For each of ``num_permut`` rounds the counts are shuffled in place across
    the outcomes, one conditional query is taken from ``get_cond_queries`` for
    the shuffled table, and the table is rendered as a bullet list.

    Args:
        freqs: Mapping from an outcome (one label per variable) to its count.
        num_permut: Number of shuffled tables / queries to produce.

    Returns:
        Tuple ``(num_vars, sample_size, rvs, description, freqs_text,
        cond_queries)`` where ``rvs`` maps variable name to its set of labels,
        ``description`` lists the outcomes of each variable in sorted order,
        and the last two are lists of length ``num_permut``.
    """
    rv_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    keys = list(freqs.keys())
    values = list(freqs.values())
    sample_size = sum(values)

    # One label set per variable position of the first outcome.
    rvs = {rv_names[pos]: set() for pos in range(len(keys[0]))}
    for outcome in keys:
        for pos, name in enumerate(rvs):
            rvs[name].add(outcome[pos])

    num_vars = len(rvs)
    rv = list(rvs.keys())
    labels = list(rvs.values())

    clauses = [
        f"{rv[pos]} can take outcomes from " + "{" + ', '.join(sorted(labels[pos])) + "}"
        for pos in range(num_vars)
    ]
    description = ", ".join(clauses)

    freqs_text = []
    cond_queries = []
    for _ in range(num_permut):
        random.shuffle(values)
        permuted = dict(zip(keys, values))
        # get_cond_queries returns its queries shuffled; keep the first one.
        cond_queries.append(get_cond_queries(rv, labels, permuted)[0])
        printable = {str(key).replace("'", ""): value for key, value in permuted.items()}
        freqs_text.append("\n".join(f"- {key}: {value}" for key, value in printable.items()))

    return num_vars, sample_size, rvs, description, freqs_text, cond_queries
    


def extract_probability(text: str, outcome: str) -> "float | None":
    """Return the first probability written after ``outcome`` on a line of ``text``.

    Lines are scanned in order. On the first line that contains ``outcome``
    followed by a decimal of the form ``0.<digits>`` or ``1.0...``, that number
    is returned. Accepting ``1.0...`` matters for degenerate distributions,
    where a correctly stated probability of exactly one would otherwise be
    reported as missing.

    Args:
        text: Model response, possibly spanning several lines.
        outcome: Literal label to look for, e.g. ``"P(X=A | Y=D)"``.

    Returns:
        The parsed probability, or ``None`` when no line matches.
    """
    # The guards around the 1.0 alternative keep it from matching inside a
    # larger number such as 21.0 or 1.0005.
    pattern = re.compile(
        rf"{re.escape(outcome)}.*?(0\.[0-9]+|(?<![0-9.])1\.0+(?![0-9]))"
    )
    for line in text.splitlines():
        match = pattern.search(line)
        if match:
            return float(match.group(1))
    return None


def get_MLE_probs(response, expected):
    """Pair each expected probability with the value found in ``response``.

    ``expected`` holds one ``"<outcome> = <probability>"`` entry per line. Each
    outcome is looked up in the response as written and, failing that, with
    spaces inserted around every ``=`` sign.

    Returns:
        ``(ground_truth, predictions)`` as two aligned lists of floats, or
        ``(None, None)`` as soon as one outcome cannot be found.
    """
    truths = []
    estimates = []
    for entry in expected.split('\n'):
        outcome, prob_text = entry.split(" = ")
        estimate = extract_probability(response, outcome)
        if estimate is None:
            # Fall back to the spaced spelling, e.g. "P(X = A | Y = D)".
            estimate = extract_probability(response, outcome.replace('=', ' = '))
            if estimate is None:
                return None, None
        truths.append(float(prob_text))
        estimates.append(estimate)
    return truths, estimates


def TVD(P, Q):
    """Return the total variation distance between two probability lists.

    Every probability is rounded to three decimals first; the result is half
    the summed absolute differences of paired entries, rounded to five
    decimals.
    """
    p_rounded = [round(value, 3) for value in P]
    q_rounded = [round(value, 3) for value in Q]
    gap = 0
    for p, q in zip(p_rounded, q_rounded):
        gap += abs(p - q)
    return round(0.5 * gap, 5)



def batch_prompts(prompts, batch_size):
    """Yield consecutive slices of ``prompts`` with at most ``batch_size`` items each."""
    starts = range(0, len(prompts), batch_size)
    yield from (prompts[start:start + batch_size] for start in starts)


def generate_prompts(n, freqs, permute=False):
    """Build ``n`` zero-shot conditional-MLE prompts and their ground-truth queries.

    The frequency tables and queries come from ``extract_params(freqs)``;
    prompt ``k`` uses table and query ``k`` modulo the number available.

    Args:
        n: Number of prompts to produce.
        freqs: Mapping from an outcome to its observed count.
        permute: Not used by this function.

    Returns:
        ``(prompts, groundtruth)``: the stripped prompt strings and the query
        dict each prompt asks about.
    """
    num_vars, sample_size, rvs, s, freqs_text, cond_queries = extract_params(freqs)
    var_names = ', '.join(rvs.keys())

    prompts = []
    groundtruth = []
    for idx in range(n):
        query = cond_queries[idx % len(cond_queries)]
        table = freqs_text[idx % len(freqs_text)]
        groundtruth.append(query)

        # One "<outcome> = [value]" placeholder line per outcome of the query.
        q_string = '\n'.join(f'{outcome} = [value]' for outcome in query['answer'].keys())
        target = f"P({query['Q_rv']} | {query['cond_rv']}={query['cond_label']})"
        prompt = f"""Consider {num_vars} discrete random variables {var_names}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({var_names}), the observed frequencies are:

{table}

Task: Predict the maximum likelihood estimation (MLE) of the conditional distribution {target}.

Instructions:
1. Think step by step and solve this mathematically using probability theory - do not write any code or pseudocodes as you get negative penalty for that.
2. Clearly state the final estimated probabilities as float numbers with up to four decimal points.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. 

Output Format:
- Final probabilities as float numbers should be listed as:

{q_string}

"""
        prompts.append(prompt.strip())

    return prompts, groundtruth





def generate_oneshot_prompts2(n, freqs, permute=False):  
    """Return ``n`` copies of a fixed one-shot prompt plus ``n`` ground-truth queries.

    The prompt is a single literal: task definition, one solved example, and a
    "Main Task" asking for P(X | Y = D) over a 4-variable, 270-sample table.
    The ground truth is the first conditional query returned by
    ``extract_params(freqs, 1)``.

    NOTE(review): the "Main Task" table and query are hard-coded in the literal
    and are not built from ``freqs``, while the returned ground truth is
    derived from ``freqs`` (after a random shuffle inside ``extract_params``).
    Presumably the literal was produced from one specific ``freqs``/seed
    combination - confirm that the two still match before trusting scores.

    Args:
        n: Number of (identical) prompts and queries to return.
        freqs: Mapping from an outcome to its observed count; only used to
            derive the ground-truth query.
        permute: Not used by this function.

    Returns:
        ``(prompts, out_queries)``; every element of ``out_queries`` is the
        same dict object.
    """
    prompts, out_queries = [], []
    # Only cond_queries is used below; the other returned values are ignored.
    num_vars, sample_size, rvs, s, freqs_text, cond_queries = extract_params(freqs, 1)
    out_queries = [cond_queries[0]]*n
    prompt = """Task Definition:
Consider the joint probability distribution of a set of discrete random variables. You will be provided with the frequencies of each outcome of the joint distribution in a set of independently drawn samples.  
Your task is to identify the predict the maximum likelihood estimation (MLE) of a specified conditional distribution.  

Instructions:
1. Think step by step and solve this mathematically using probability theory - do not write any code or pseudocodes as you get negative penalty for that.
2. Clearly state the final estimated probabilities as float numbers with up to four decimal points.
3. You may explain your reasoning, but the final answer should be explicitly summarized at the end. 
4. Final probabilities as float numbers should be listed at the end as shown in the example.

--- EXAMPLE with Solution---
Consider discrete random variables X, Y, Z with joint distribution P(X, Y, Z) where X can take outcomes from {A, B}, Y can take outcomes from {C, D}, Z can take outcomes from {E, F, G}. 
In 60 independent samples drawn from the joint distribution P(X, Y, Z), the observed frequencies are:
- (A, C, E): 5  
- (A, C, F): 7  
- (A, C, G): 3  
- (A, D, E): 10  
- (A, D, F): 2  
- (A, D, G): 3  
- (B, C, E): 6  
- (B, C, F): 8  
- (B, C, G): 2  
- (B, D, E): 3  
- (B, D, F): 3  
- (B, D, G): 8 

Example Task (solved):  Predict the maximum likelihood estimation (MLE) of the conditional distribution P(Z | X = A).

Step-by-step solution:
1. Extract all rows where X = A:
   - (A, C, E): 5
   - (A, C, F): 7
   - (A, C, G): 3
   - (A, D, E): 10
   - (A, D, F): 2
   - (A, D, G): 3
2. Aggregate counts by Z among those rows:
   - Count(Z = E | X = A) = 5 + 10 = 15
   - Count(Z = F | X = A) = 7 + 2 = 9
   - Count(Z = G | X = A) = 3 + 3 = 6
3. Compute total conditioned samples:
   - TotalCount(X = A) = 15 + 9 + 6 = 30
4. Compute empirical conditional probabilities (MLE) and round to four decimals:
   - P(Z = E | X = A) = 15 / 30 = 0.5000
   - P(Z = F | X = A) = 9 / 30 = 0.3000
   - P(Z = G | X = A) = 6 / 30 = 0.2000

Final Answer:
P(Z=E | X=A) = 0.5000  
P(Z=F | X=A) = 0.3000  
P(Z=G | X=A) = 0.2000
--- END EXAMPLE ---

Main Task:
Consider 4 discrete random variables X, Y, Z, T, where X can take outcomes from {A, B, C}, Y can take outcomes from {D, E}, Z can take outcomes from {F, G, H}, T can take outcomes from {I, J, K}. 
In 270 independent samples drawn from their joint distribution P(X, Y, Z, T), the observed frequencies are:
- (A, D, F, I): 6
- (A, D, F, J): 3
- (A, D, F, K): 1
- (A, D, G, I): 7
- (A, D, G, J): 10
- (A, D, G, K): 3
- (A, D, H, I): 1
- (A, D, H, J): 7
- (A, D, H, K): 2
- (A, E, F, I): 8
- (A, E, F, J): 5
- (A, E, F, K): 12
- (A, E, G, I): 15
- (A, E, G, J): 10
- (A, E, G, K): 9
- (A, E, H, I): 4
- (A, E, H, J): 5
- (A, E, H, K): 2
- (B, D, F, I): 1
- (B, D, F, J): 5
- (B, D, F, K): 7
- (B, D, G, I): 1
- (B, D, G, J): 1
- (B, D, G, K): 5
- (B, D, H, I): 1
- (B, D, H, J): 2
- (B, D, H, K): 5
- (B, E, F, I): 1
- (B, E, F, J): 11
- (B, E, F, K): 5
- (B, E, G, I): 2
- (B, E, G, J): 2
- (B, E, G, K): 12
- (B, E, H, I): 2
- (B, E, H, J): 4
- (B, E, H, K): 3
- (C, D, F, I): 13
- (C, D, F, J): 2
- (C, D, F, K): 2
- (C, D, G, I): 13
- (C, D, G, J): 1
- (C, D, G, K): 9
- (C, D, H, I): 4
- (C, D, H, J): 12
- (C, D, H, K): 3
- (C, E, F, I): 4
- (C, E, F, J): 2
- (C, E, F, K): 11
- (C, E, G, I): 3
- (C, E, G, J): 1
- (C, E, G, K): 3
- (C, E, H, I): 4
- (C, E, H, J): 2
- (C, E, H, K): 1

Main Task (to solve): Predict the maximum likelihood estimation (MLE) of the conditional distribution P(X | Y = D).

Provide step-by-step solution similar to the example and then present the final probabilities.

Output Format:
- Final probabilities as float numbers should be listed as:
P(X=A | Y=D) = [value]
P(X=B | Y=D) = [value]
P(X=C | Y=D) = [value]
"""
    # Leading/trailing whitespace of the literal is dropped; the same string
    # is repeated n times.
    prompts = [prompt.strip()]*n

    return prompts, out_queries










def eval_cond_mle(models, freqs, num_prompts, permute, mode="local"):
    """Run the conditional-MLE benchmark locally or hand it off to the batch helpers.

    Prompts and ground-truth queries come from ``generate_oneshot_prompts2``.
    With ``mode="api_create_batch"`` the batch input files are written, with
    ``mode="api_eval_batch"`` a finished batch is scored; both return ``None``.
    Any other mode loads each model in ``models`` onto CUDA, generates a
    completion per prompt, parses the probabilities and scores them by total
    variation distance (TVD).

    Args:
        models: Model identifiers passed to ``from_pretrained``.
        freqs: Mapping from an outcome to its observed count.
        num_prompts: Number of prompts to evaluate.
        permute: Only recorded in the returned summary.
        mode: ``"api_create_batch"``, ``"api_eval_batch"`` or anything else
            for local evaluation.

    Returns:
        ``None`` for the two API modes; otherwise a dict with the first prompt,
        the run settings and, per model id, ``"num_mistakes"``,
        ``"Average TVD"`` (``None`` when no response could be parsed) and
        ``"incorrect_responses"``.
    """
    # prompts, queries = generate_prompts(num_prompts, freqs)
    prompts, queries = generate_oneshot_prompts2(num_prompts, freqs)

    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, queries, joint_size)
        return

    elif mode == "api_eval_batch":
        evaluate_api_batch(prompts, queries, joint_size)
        return

    batch_size = num_prompts if num_prompts < 40 else 40

    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    out["permute"] = permute

    # Generation budget per prompt (the size-dependent value that used to be
    # computed first was always overwritten by this constant).
    max_tokens = 3000
    for model_id in models:

        # Drop references to the previous iteration's objects before loading
        # the next model - presumably so its memory can be reclaimed.
        model, tokenizer = None, None
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
        # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.pad_token = tokenizer.eos_token

        print(f'Model: {model_id}')

        responses = []
        num_mistake = 0
        incorrect_responses = []
        tvds = []
        for batch in batch_prompts(prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            attention_mask = inputs["attention_mask"]

            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_tokens,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.eos_token_id
            )

            responses.extend(
                tokenizer.decode(output, skip_special_tokens=True) for output in outputs
            )

        for raw, prompt, query in zip(responses, prompts, queries):
            # The decoded output starts with the prompt; strip it so only the
            # completion is parsed.
            res = raw.replace(prompt, '')
            print(res)
            print('-'*50)
            answer = '\n'.join([f'{k} = {v}' for k, v in query['answer'].items()])
            print(f'expected: {answer}')
            print('-'*50)
            # Raw strings are required here: in a normal literal '\a' is the
            # BEL character, so '\approx' would never match the LaTeX command.
            res = res.replace(r'\mid', '|')
            res = res.replace(r'\approx', '=')
            gt, pred = get_MLE_probs(res, answer)
            print(f'gt: {gt}')
            if pred is None:
                num_mistake += 1
                incorrect_responses.append(res)
                print(f"Error in response")
                continue
            tvds.append(TVD(gt, pred))

        # try:
        #     tvd_array = np.array(tvds)
        #     mean = np.mean(tvd_array)
        #     standard_error = st.sem(tvd_array)  # Standard error of the mean
        #     # 95% confidence interval using Student's t-distribution
        #     confidence_level = 0.95
        #     degrees_freedom = len(tvd_array) - 1
        #     confidence_interval = st.t.interval(confidence_level, degrees_freedom, loc=mean, scale=standard_error)
        #     print(f"Mean TVD: {mean:.4f}")
        #     print(f"95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f})")
        #     print(f"Error bar: ±{(confidence_interval[1] - mean):.4f}")
        # except Exception as e:
        #     print(f"Error occurred: {e}")
        #     print(tvds)

        # If no response could be parsed there is no average to report; use
        # None rather than dividing by zero.
        average_tvd = sum(tvds) / len(tvds) if tvds else None

        print(f'num mistakes: {num_mistake}')

        out[model_id] = {"num_mistakes": num_mistake, "Average TVD": average_tvd, "incorrect_responses": incorrect_responses}
        print(f'Average TVD: {average_tvd}')

    # with open(f"../eval_results/Rebuttal/cond_mle_16_mushroom_local.json", "w") as file:
    #     json.dump(out, file, indent=4)  
    # with open(f"../eval_results/MLE/cond_permu_jointsize_{len(freqs.keys())}_qwen.json", "w") as file:
    #     json.dump(out, file, indent=4)

    return out






