from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from fractions import Fraction
from collections import Counter
import random
import torch
import numpy as np
import scipy.stats as st
# Ask PyTorch to release unoccupied cached CUDA memory at import time,
# before any model is loaded by the functions below.
torch.cuda.empty_cache()

# Module-wide seed. It seeds Python's `random` module (used for shuffling and
# sampling in this file) and is embedded in the batch/result file names.
# NOTE(review): only `random` is seeded here; NumPy and torch are not seeded
# in the visible part of this file.
seed = 4
random.seed(seed)

import re


def create_batch_file(prompts, joint_size, out_queries, models=None):
    """Write a batch request file and a query index file for every model.

    For each model two files are written under ``batch_input/Rebuttal``:

    * a ``.jsonl`` file with one chat-completion request per prompt
      (fields ``custom_id``, ``method``, ``url``, ``body``), and
    * a ``.json`` index mapping each ``custom_id`` to its prompt and the
      conditional query it encodes, so results can be scored later.

    Args:
        prompts: Prompt strings; request ``i`` (1-based) carries ``prompts[i-1]``.
        joint_size: Size of the joint distribution; used only in file names.
        out_queries: Query descriptors aligned with ``prompts``.
        models: Model names to emit files for. Defaults to the three models
            this script was originally hard-coded for.

    Raises:
        ValueError: If ``out_queries`` has fewer entries than ``prompts``.
    """
    import os  # local import: `os` is not among the file-level imports

    if models is None:
        models = ["llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini"]

    # Fail before touching the disk instead of leaving half-written files.
    if len(out_queries) < len(prompts):
        raise ValueError(
            f"Got {len(prompts)} prompts but only {len(out_queries)} queries"
        )

    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.join("batch_input", "Rebuttal"), exist_ok=True)

    # The index does not depend on the model, so build it once.
    queries = {
        f"request-{idx}": {"prompt": content, "query": out_queries[idx - 1]}
        for idx, content in enumerate(prompts, start=1)
    }

    for model in models:
        with open(f"batch_input/Rebuttal/batch_file_cond_samp_{joint_size}_mushroom_{model}_seed{seed}.jsonl", "w", encoding="utf-8") as f:
            for idx, content in enumerate(prompts, start=1):
                record = {
                    "custom_id": f"request-{idx}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": [
                            {"role": "user", "content": content}
                        ]
                    }
                }
                f.write(json.dumps(record, separators=(",", ":")) + "\n")

        with open(f"batch_input/Rebuttal/cond_samp_{joint_size}_mushroom_queries_{model}_seed{seed}.json", "w", encoding="utf-8") as f:
            json.dump(queries, f, indent=4)



def evaluate_api_batch(prompts, rvs, freqs, joint_size, num_samples, model="gpt-4.1-mini"):
    """Score the responses of a finished API batch and write a JSON summary.

    The prompts and queries are re-read from the index file written when the
    batch was created; the ``prompts`` argument is accepted for interface
    compatibility but is replaced by the stored prompts.

    A request counts as a mistake when its response is missing from the output
    file, cannot be parsed by ``extract_pairs``, or does not contain exactly
    ``num_samples`` samples. For every other request the total variation
    distance (TVD) between the true conditional distribution and the model's
    empirical sample distribution is recorded, alongside the TVD of a
    same-sized reference sample drawn from the true distribution.

    NOTE(review): the index is read from ``batch_input/...`` while results are
    read from ``../batch_output/...`` -- the two paths use different relative
    bases; confirm this matches the directory layout.

    Args:
        prompts: Ignored; see above.
        rvs: Mapping from variable name to its set of labels.
        freqs: Per-request joint frequency tables, aligned with request order.
        joint_size: Size of the joint distribution; used only in file names.
        num_samples: Number of samples each response must contain.
        model: Model name whose batch files are evaluated.
    """
    with open(f"batch_input/Rebuttal/cond_samp_{joint_size}_mushroom_queries_{model}_seed{seed}.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    prompts, queries = [], []
    for i in range(1, len(data) + 1):
        # Index directly so a gap in the numbering raises a clear KeyError.
        entry = data[f"request-{i}"]
        prompts.append(entry["prompt"])
        queries.append(entry["query"])

    res_file = f"../batch_output/Rebuttal/batch_output_cond_samp_mushroom_{joint_size}_{model}_seed{seed}.jsonl"

    # Slots stay None for requests that are absent from the output file or
    # whose record carries no completion (e.g. a failed request).
    responses = [None] * len(prompts)
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            try:
                i = int(record.get("custom_id", "").split("-")[1])
                content = record["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError, ValueError):
                continue
            if 1 <= i <= len(responses):
                responses[i - 1] = content

    num_prompts = len(prompts)
    incorrect_responses = []
    num_mistake = 0
    out = {}
    out["prompt"] = prompts[0] if prompts else None
    out["num_prompts"] = num_prompts
    tvds = []
    gt_tvds = []
    for i, response in enumerate(responses):
        if response is None:
            # Missing/failed response: recorded as null in the output JSON.
            incorrect_responses.append(None)
            num_mistake += 1
            continue

        res = response.replace(prompts[i], '')
        pairs = extract_pairs(res, rvs, queries[i]['Q_rv'], num_samples)
        # A wrong sample count is a model mistake, not a reason to abort the
        # whole evaluation (same policy as the local evaluation path).
        if pairs is None or len(pairs) != num_samples:
            incorrect_responses.append(res)
            num_mistake += 1
            continue

        P_freq, P_gt_samp, Q = get_PQ(freqs[i], queries[i], rvs, pairs)
        tvds.append(TVD(P_freq, Q))
        gt_tvds.append(TVD(P_freq, P_gt_samp))

    # Best-effort summary statistics; a failure here must not prevent the
    # results file from being written.
    try:
        tvd_array = np.array(tvds)
        mean = np.mean(tvd_array)
        standard_error = st.sem(tvd_array)  # Standard error of the mean
        # 95% confidence interval using Student's t-distribution
        confidence_level = 0.95
        degrees_freedom = len(tvd_array) - 1
        confidence_interval = st.t.interval(confidence_level, degrees_freedom, loc=mean, scale=standard_error)
        print(f"Mean TVD: {mean:.4f}")
        print(f"95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f})")
        print(f"Error bar: ±{(confidence_interval[1] - mean):.4f}")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(tvds)

    tvd, gt_tvd = None, None
    if tvds:
        tvd = sum(tvds) / len(tvds)
        gt_tvd = sum(gt_tvds) / len(gt_tvds)

    print(f'num mistakes: {num_mistake}')
    print({'joint-size': joint_size, 'seed': seed, "GT_TVD": gt_tvd, "TVD": tvd, "num_mistakes": num_mistake})
    out[model] = {"GT_Avg_TVD": gt_tvd, "Pred_Avg_TVD": tvd, "num_mistakes": num_mistake, "TVDs": tvds, "incorrect_responses": incorrect_responses}

    with open(f"../eval_results/Rebuttal/cond_sample_js{joint_size}_mushroom_{model}_seed_{seed}.json", "w", encoding="utf-8") as file:
        json.dump(out, file, indent=4)



def TVD(P, Q):
    """Return the total variation distance between two probability vectors.

    Both vectors are rounded to five decimals entry by entry, half the sum of
    the absolute pairwise differences is taken, and the result is rounded to
    five decimals again. Entries are paired with ``zip``, so surplus entries
    of the longer vector are ignored.
    """
    precision = 5
    p_rounded = [round(p, precision) for p in P]
    q_rounded = [round(q, precision) for q in Q]
    gaps = (abs(a - b) for a, b in zip(p_rounded, q_rounded))
    return round(0.5 * sum(gaps), precision)



def get_cond_queries(rv, labels, freqs):
    """Enumerate conditional queries P(Q | C = c) that have a unique mode.

    A query is kept when, among the joint outcomes whose conditioning
    variable equals ``c``, the most frequent label of the query variable is
    strictly more frequent than the runner-up. Query variables need at least
    2 labels (3 when the joint table has 9 or more entries).

    Args:
        rv: Variable names; position ``i`` corresponds to position ``i`` of
            every key in ``freqs``.
        labels: ``labels[i]`` is the collection of labels of ``rv[i]``.
        freqs: Mapping from joint outcome (indexable by variable position)
            to its observed count.

    Returns:
        A list of dicts with keys ``"Q_rv"``, ``"cond_rv"`` and
        ``"cond_label"``, shuffled with the module-level ``random`` state.
    """
    limit = 3 if len(freqs) >= 9 else 2
    queries = []
    for i, q in enumerate(rv):
        if len(labels[i]) < limit:
            continue
        for j, cond in enumerate(rv):
            if i == j:
                continue
            # Sorted so the pre-shuffle order (and therefore the shuffled
            # result for a given seed) does not depend on set hash order.
            for cond_label in sorted(labels[j]):
                t = dict.fromkeys(labels[i], 0)
                for k, v in freqs.items():
                    # Compare at the conditioning variable's own position: a
                    # plain `cond_label in k` would also match outcomes where
                    # a *different* variable happens to share that label.
                    if k[j] == cond_label:
                        t[k[i]] += v
                cond_freqs = sorted(t.values(), reverse=True)
                if len(cond_freqs) < 2 or cond_freqs[0] != cond_freqs[1]:
                    queries.append({"Q_rv": q, "cond_rv": cond, "cond_label": cond_label})
    random.shuffle(queries)
    return queries



def extract_params(freqs, num_permut=10):
    """Derive prompt ingredients from a joint frequency table.

    The counts of ``freqs`` are reshuffled ``num_permut`` times over the same
    outcomes; for every reshuffled table one conditional query is picked (the
    first entry returned by ``get_cond_queries``) and a bullet-list rendering
    of the table is produced.

    Returns:
        A 7-tuple ``(num_vars, sample_size, rvs, description, freqs_text,
        cond_queries, new_freqs)`` where ``rvs`` maps variable names
        (``X``, ``Y``, ``Z``, ``T``, ``U``, ``V``, ``W``, by key position) to
        their label sets, ``description`` lists the sorted outcomes of every
        variable, and the last three are lists with one item per permutation.

    Raises:
        IndexError: If ``freqs`` is empty, its keys have more than seven
            positions, or ``get_cond_queries`` yields no query for a
            permutation.
    """
    rv_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    keys = list(freqs.keys())
    values = list(freqs.values())
    sample_size = sum(values)

    # One label set per key position, named after that position.
    rvs = {rv_names[pos]: set() for pos in range(len(keys[0]))}
    for key in keys:
        for pos in range(len(rvs)):
            rvs[rv_names[pos]].add(key[pos])

    num_vars = len(rvs)
    rv = list(rvs.keys())
    labels = list(rvs.values())

    descriptions = [
        f"{rv[pos]} can take outcomes from " + "{" + ", ".join(sorted(labels[pos])) + "}"
        for pos in range(num_vars)
    ]
    s = ", ".join(descriptions)

    freqs_text, cond_queries, new_freqs = [], [], []
    for _ in range(num_permut):
        # Shuffling is cumulative: each round permutes the previous round's order.
        random.shuffle(values)
        permuted = dict(zip(keys, values))
        new_freqs.append(permuted)
        cond_queries.append(get_cond_queries(rv, labels, permuted)[0])
        # Render keys without quote characters, e.g. ('a', 'b') -> (a, b).
        printable = {str(key).replace("'", ""): count for key, count in permuted.items()}
        freqs_text.append("\n".join(f"- {key}: {count}" for key, count in printable.items()))

    return num_vars, sample_size, rvs, s, freqs_text, cond_queries, new_freqs
    


def extract_pairs(res, rvs, query_rv, num_samples):
    """Extract the sampled labels from a model response.

    The response must contain a run of exactly ``num_samples`` numbered items
    of the form ``<number>. <label>``, optionally preceded by a
    ``### Output`` line. Items may be separated by any whitespace.

    Labels are matched as escaped alternatives rather than being spliced into
    a character class, so labels containing regex metacharacters or more than
    one character are handled. Labels are read from the matched items
    themselves instead of by splitting on newlines, so whitespace-only lines
    inside or after the list do not produce spurious empty entries.

    Args:
        res: Model response text.
        rvs: Mapping from variable name to its collection of string labels.
        query_rv: Name of the variable being sampled.
        num_samples: Required number of samples.

    Returns:
        The list of ``num_samples`` labels in order of appearance, or ``None``
        if the response does not contain such a run.
    """
    # Longest labels first so a label is never shadowed by one of its prefixes.
    labels = sorted((lab for lab in rvs[query_rv] if lab), key=lambda lab: (-len(lab), lab))
    if not labels:
        return None

    label_alt = "|".join(re.escape(lab) for lab in labels)
    item = r"\d+\.\s*(?:" + label_alt + r")\s*"
    patt = r"(?:### Output\n)?(?:" + item + r"){" + f"{num_samples}" + r"}"
    match = re.search(patt, res)
    if match is None:
        return None

    pairs = re.findall(r"\d+\.\s*(" + label_alt + r")", match.group(0))
    if len(pairs) != num_samples:
        return None
    return pairs


def get_PQ(freqs, query, rvs, pairs):
    """Build the true, reference-sampled and predicted conditional distributions.

    The conditional counts of ``query['Q_rv']`` given
    ``query['cond_rv'] == query['cond_label']`` are read from ``freqs`` by
    variable *position*, so variables that share label names are not
    confused with one another.

    Assumes the insertion order of ``rvs`` matches the positions inside the
    keys of ``freqs`` (this is how ``extract_params`` builds ``rvs``).

    Args:
        freqs: Mapping from joint outcome (indexable by variable position)
            to its observed count.
        query: Dict with keys ``"Q_rv"``, ``"cond_rv"`` and ``"cond_label"``.
        rvs: Mapping from variable name to its collection of labels.
        pairs: Labels sampled by the model for the query variable.

    Returns:
        ``(P_freq, P_gt_samp, Q)``: three lists aligned on the labels of the
        query variable -- the true conditional probabilities, the empirical
        distribution of ``len(pairs)`` reference draws from the true
        conditional (uses the module-level ``random`` state), and the
        empirical distribution of ``pairs``.

    Raises:
        ValueError: If no observation satisfies the condition.
    """
    rv_order = list(rvs)
    q_idx = rv_order.index(query['Q_rv'])
    c_idx = rv_order.index(query['cond_rv'])
    cond_label = query['cond_label']

    cond_keys = {label: 0 for label in rvs[query['Q_rv']]}
    for key, count in freqs.items():
        if key[c_idx] != cond_label:
            continue
        target = key[q_idx]
        if target in cond_keys:
            cond_keys[target] += count

    total_cond = sum(cond_keys.values())
    if total_cond == 0:
        raise ValueError(
            f"No observations with {query['cond_rv']}={cond_label}; "
            "the conditional distribution is undefined"
        )

    n_samples = len(pairs)
    sampled_freq = random.choices(list(cond_keys.keys()), weights=list(cond_keys.values()), k=n_samples)
    gt_counts = Counter(sampled_freq)
    pred_counts = Counter(pairs)

    P_freq, P_gt_samp, Q = [], [], []
    for label, count in cond_keys.items():
        P_freq.append(count / total_cond)
        Q.append(pred_counts[label] / n_samples if label in pred_counts else 0)
        P_gt_samp.append(gt_counts[label] / n_samples if label in gt_counts else 0)

    return P_freq, P_gt_samp, Q



def batch_prompts(prompts, batch_size):
    """Yield consecutive slices of ``prompts``, each at most ``batch_size`` long."""
    starts = range(0, len(prompts), batch_size)
    yield from (prompts[lo:lo + batch_size] for lo in starts)



def generate_prompts(n, freqs, num_samples, permute=True):
    """Build ``n`` conditional-sampling prompts from a joint frequency table.

    The permuted tables, their text renderings and their queries come from
    ``extract_params`` and are reused cyclically when ``n`` exceeds the
    number of permutations it produced. ``permute`` is accepted but not used
    by this function.

    Returns:
        ``(prompts, rvs, out_freqs, out_queries)`` where the three lists are
        aligned: ``prompts[i]`` asks for ``num_samples`` samples of the query
        ``out_queries[i]`` under the frequency table ``out_freqs[i]``, and
        ``rvs`` maps variable names to their label sets.
    """
    num_vars, sample_size, rvs, s, freqs_text, cond_queries, new_freqs = extract_params(freqs)

    prompts, out_freqs, out_queries = [], [], []
    for i in range(n):
        query = cond_queries[i % len(cond_queries)]
        # Closing sentence listing the labels the query variable may take.
        outcome_set = "{" + ', '.join(rvs[query['Q_rv']]) + "}."

        prompt = f"""Consider {num_vars} discrete random variables {', '.join(rvs.keys())}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({', '.join(rvs.keys())}), the observed frequencies are:

{freqs_text[i%len(freqs_text)]}

Task: Generate EXACTLY {num_samples} random samples from the conditional distribution P({query['Q_rv']} | {query['cond_rv']}={query['cond_label']}).

Instructions:
1. OUTPUT MUST BEGIN with "### Output" on a new line.
2. List EXACTLY {num_samples} samples numbered 1 to {num_samples}.
3. Do not write any code or pseudocodes.
4. Strictly follow the output format below
5. You will be penalized for writing codes or not following the output format.

Output Format that you should strictly follow:

### Output
1. {query['Q_rv']}
2. {query['Q_rv']}
...
{num_samples}. {query['Q_rv']}

Where {query['Q_rv']} can take outcomes from """ + outcome_set

        out_queries.append(query)
        out_freqs.append(new_freqs[i % len(new_freqs)])
        prompts.append(prompt.strip())

    return prompts, rvs, out_freqs, out_queries






def eval_cond_sampling(models, freqs, num_prompts, num_samples, permute=True, mode="local"):
    """Evaluate how well models sample from a conditional distribution.

    Prompts are generated from ``freqs``; what happens next depends on ``mode``:

    * ``"api_create_batch"``: write batch request files and return.
    * ``"api_eval_batch"``: score a finished API batch and return.
    * anything else (default ``"local"``): load every model in ``models``
      with Hugging Face ``transformers`` on CUDA, generate responses, score
      them, and rewrite the results JSON after each model.

    A response counts as a mistake when ``extract_pairs`` cannot parse it or
    it does not contain exactly ``num_samples`` samples. Otherwise the total
    variation distance (TVD) to the true conditional distribution is recorded,
    together with the TVD of a same-sized reference sample.

    Args:
        models: Hugging Face model identifiers (local mode only).
        freqs: Joint frequency table.
        num_prompts: Number of prompts to generate.
        num_samples: Number of samples requested per prompt.
        permute: Recorded in the output file; not otherwise used here.
        mode: Execution mode, see above.
    """
    prompts, rvs, out_freqs, out_queries = generate_prompts(num_prompts, freqs, num_samples)

    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, joint_size, out_queries)
        return

    if mode == "api_eval_batch":
        evaluate_api_batch(prompts, rvs, out_freqs, joint_size, num_samples)
        return

    batch_size = min(num_prompts, 35)

    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    out["permute"] = permute

    print(f'prompt sample: \n{prompts[0]}\n\n\n\n')
    print(f'joint size: {joint_size}')

    for model_id in models:

        # Drop the references to the previous model before loading the next
        # one so its memory can be reclaimed.
        model, tokenizer = None, None
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
        tokenizer.pad_token = tokenizer.eos_token

        print(f'Model: {model_id}')

        responses = []
        num_mistake = 0
        incorrect_responses = []
        tvds, gt_tvds = [], []

        # Generation budget grows with the number of requested samples.
        if num_samples > 54:
            max_token = 2048
        elif num_samples > 40:
            max_token = 1024
        else:
            max_token = 512

        # Model-specific prompt tweaks live in a per-model list so they never
        # leak into the prompts of the models evaluated afterwards.
        model_prompts = prompts
        if model_id == "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B":
            max_token = max_token * 4
            model_prompts = [p + "\n**Do not explain, do not write code, and do not deviate from the format.**" for p in prompts]

        for batch in batch_prompts(model_prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            attention_mask = inputs["attention_mask"]

            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_token,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.eos_token_id
            )

            respon = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            responses.extend(respon)

        print('Responses are generated!')

        for i, response in enumerate(responses):
            # The decoded sequence still contains the prompt; strip it.
            res = response.replace(model_prompts[i], '')
            pairs = extract_pairs(res, rvs, out_queries[i]['Q_rv'], num_samples)
            if pairs is None or len(pairs) != num_samples:
                incorrect_responses.append(res)
                num_mistake += 1
                continue
            P_freq, P_gt_samp, Q = get_PQ(out_freqs[i], out_queries[i], rvs, pairs)
            tvds.append(TVD(P_freq, Q))
            gt_tvds.append(TVD(P_freq, P_gt_samp))

        # Best-effort summary statistics. `confidence` stays None when they
        # cannot be computed, instead of failing with a NameError or silently
        # reusing the previous model's numbers.
        confidence = None
        try:
            tvd_array = np.array(tvds)
            mean = np.mean(tvd_array)
            standard_error = st.sem(tvd_array)  # Standard error of the mean
            # 95% confidence interval using Student's t-distribution
            confidence_level = 0.95
            degrees_freedom = len(tvd_array) - 1
            confidence_interval = st.t.interval(confidence_level, degrees_freedom, loc=mean, scale=standard_error)
            confidence = f"Mean TVD: {mean:.4f} \n95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f}) \nError bar: ±{(confidence_interval[1] - mean):.4f}"
            print(f"Mean TVD: {mean:.4f}")
            print(f"95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f})")
            print(f"Error bar: ±{(confidence_interval[1] - mean):.4f}")
        except Exception as e:
            print(f"Error occurred: {e}")
            print(tvds)

        tvd, gt_tvd = None, None
        if tvds:
            tvd = sum(tvds) / len(tvds)
            gt_tvd = sum(gt_tvds) / len(gt_tvds)

        print(f'num mistakes: {num_mistake}')
        print({"GT_TVD": gt_tvd, "TVD": tvd, "num_mistakes": num_mistake})
        out[model_id] = {"GT_Avg_TVD": gt_tvd, "Pred_Avg_TVD": tvd, "num_mistakes": num_mistake, "TVDs": tvds,
                        "confidence": confidence,
                        "incorrect_responses": incorrect_responses}
        print('\n\n\n\n')

        # Rewritten after every model so finished results survive a later crash.
        with open(f"../eval_results/Rebuttal/cond_sample_{joint_size}_mushroom_local.json", "w", encoding="utf-8") as file:
            json.dump(out, file, indent=4)



