from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import numpy as np
from fractions import Fraction
from collections import Counter
import random
import torch
import numpy as np
import scipy.stats as st
# Ask PyTorch to release any cached, currently unused GPU memory at import time.
torch.cuda.empty_cache()

# Seed for Python's `random` module. The same number is embedded in the batch
# input/output file names below, so changing it also changes which files are
# written and read. NOTE(review): only `random` is seeded in the visible code;
# NumPy and torch RNGs are not.
seed_number = 5
random.seed(seed_number)






def create_batch_file(prompts, joint_size):
    """Write one JSONL batch-request file per API model.

    Each line of a file is a chat-completion request whose ``custom_id`` is
    ``request-<k>`` with ``k`` counting the prompts from 1. The file name
    embeds ``joint_size``, the model name and the module-level ``seed_number``.

    Args:
        prompts: Iterable of prompt strings, one request per prompt.
        joint_size: Value embedded in the output file name.
    """
    target_models = ("llama-3.3-70b-versatile", "gpt-4o-mini", "gpt-4.1-mini")
    for model in target_models:
        out_path = f"batch_input/Rebuttal/batch_file_joint_samp_{joint_size}_mushroom_{model}_seed{seed_number}.jsonl"
        with open(out_path, "w") as handle:
            for request_no, content in enumerate(prompts, start=1):
                request_body = {
                    "model": model,
                    "messages": [{"role": "user", "content": content}],
                }
                request = {
                    "custom_id": f"request-{request_no}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": request_body,
                }
                # Compact separators keep every request on a single short line.
                handle.write(json.dumps(request, separators=(",", ":")) + "\n")



def evaluate_api_batch(prompts, rvs, new_freqs, joint_size, num_samples):
    """Score the responses of a finished API batch and save a JSON summary.

    Reads the batch output JSONL for the hard-coded model, matches every
    record to its prompt through the ``request-<k>`` custom id, parses the
    generated samples with ``extract_pairs`` and computes, per prompt, the
    TVD between the target distribution and (a) the model's samples and
    (b) a reference sample drawn with ``random`` from the target.

    Requests that are absent from the output file, or whose record carries no
    message content (e.g. a failed request), are counted as mistakes instead
    of crashing the evaluation.

    Args:
        prompts: Prompts in the order used when the batch file was created.
        rvs: Mapping of variable name to its set of outcome labels.
        new_freqs: Per-prompt frequency tables (same order as ``prompts``).
        joint_size: Value embedded in the input and output file names.
        num_samples: Number of samples each response must contain.

    Raises:
        AssertionError: If a parsed response does not hold ``num_samples``
            samples (internal consistency check on ``extract_pairs``).
    """
    model = "gpt-4.1-mini"
    res_file = f"../batch_output/Rebuttal/batch_output_joint_samp_mushroom_{joint_size}_{model}_seed{seed_number}.jsonl"

    responses = [None] * len(prompts)
    with open(res_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            record = json.loads(line)
            custom_id = record.get("custom_id", "")
            i = int(custom_id.split("-")[1])
            try:
                content = record["response"]["body"]["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                # Record without a usable completion; handled as missing below.
                content = None
            responses[i - 1] = content

    num_prompts = len(prompts)
    incorrect_responses = []
    num_mistake = 0
    num_missing = 0
    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    tvds = []
    gt_tvds = []
    for i, response in enumerate(responses):
        if response is None:
            # No completion available for this prompt: count it as a mistake.
            incorrect_responses.append(None)
            num_mistake += 1
            num_missing += 1
            continue

        res = response.replace(prompts[i], '')
        check = extract_pairs(res, rvs, num_samples)
        if check is None:
            incorrect_responses.append(res)
            num_mistake += 1
        else:
            assert len(check) == num_samples, f"Number of samples {len(check)} does not match the expected number {num_samples}"
            P_freq, P_gt_samp, Q = get_PQ(new_freqs[i].copy(), check)
            tvds.append(TVD(P_freq, Q))
            gt_tvds.append(TVD(P_freq, P_gt_samp))

    if num_missing:
        print(f'num missing responses: {num_missing}')

    # A t-based confidence interval needs at least two observations.
    if len(tvds) >= 2:
        try:
            tvd_array = np.array(tvds)
            mean = np.mean(tvd_array)
            standard_error = st.sem(tvd_array)  # standard error of the mean
            confidence_level = 0.95
            degrees_freedom = len(tvd_array) - 1
            confidence_interval = st.t.interval(confidence_level, degrees_freedom, loc=mean, scale=standard_error)
            print(f"Mean TVD: {mean:.4f}")
            print(f"95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f})")
            print(f"Error bar: ±{(confidence_interval[1] - mean):.4f}")
        except (ValueError, ArithmeticError) as exc:
            print(f"Error occurred: {exc}")
            print(tvds)
    else:
        print("Not enough valid responses to compute a confidence interval")
        print(tvds)

    # Averages stay None when no response could be parsed.
    tvd, gt_tvd = None, None
    if tvds:
        tvd = sum(tvds) / len(tvds)
        gt_tvd = sum(gt_tvds) / len(gt_tvds)

    print(f'num mistakes: {num_mistake}')
    print({"GT_TVD": gt_tvd, "TVD": tvd, "num_mistakes": num_mistake})
    out[model] = {"GT_Avg_TVD": gt_tvd, "Pred_Avg_TVD": tvd, "num_mistakes": num_mistake, "TVDs": tvds, "incorrect_responses": incorrect_responses}

    with open(f"../eval_results/Rebuttal/joint_sample_{joint_size}_mushroom_{model}.json", "w") as file:
        json.dump(out, file, indent=4)




def TVD(P, Q):
    """Return the total variation distance between two probability vectors.

    Both vectors are rounded to 3 decimals first; the distance is half the sum
    of absolute element-wise differences, rounded to 4 decimals. Extra entries
    in the longer vector are ignored.
    """
    p_rounded = [round(value, 3) for value in P]
    q_rounded = [round(value, 3) for value in Q]
    abs_gaps = [abs(a - b) for a, b in zip(p_rounded, q_rounded)]
    return round(0.5 * sum(abs_gaps), 4)

def KL(P, Q):
    """Return the KL divergence sum(P * log(P / Q)) of two probability vectors.

    Entries are rounded to 3 decimals and then clipped into [1e-10, 1.0] so
    that zero probabilities do not produce log(0) or a division by zero.
    """
    floor = 1e-10

    def _prepare(dist):
        # Round with Python's round(), then clip as a float64 array.
        rounded = np.array([round(value, 3) for value in dist], dtype=np.float64)
        return np.clip(rounded, floor, 1.0)

    p_arr = _prepare(P)
    q_arr = _prepare(Q)
    return np.sum(p_arr * np.log(p_arr / q_arr))


def get_gt_samples(freq, num_samples):
    """Draw ``num_samples`` events with replacement, weighted by ``freq``.

    Args:
        freq: Mapping of event to its (non-negative) weight.
        num_samples: Number of draws.

    Returns:
        A plain dict mapping each drawn event to how often it was drawn.
    """
    drawn = random.choices(list(freq.keys()), weights=list(freq.values()), k=num_samples)
    return dict(Counter(drawn))


def extract_params(freqs, num_permut=10):
    """Derive prompt ingredients from a joint frequency table.

    Args:
        freqs: Mapping of outcome tuple to observed count. The length of the
            first key fixes the number of variables (named X, Y, Z, T, U, V, W).
        num_permut: Number of shuffled frequency tables to produce.

    Returns:
        A tuple ``(num_vars, sample_size, rvs, description, freqs_text,
        new_freqs)`` where ``rvs`` maps each variable name to its set of
        outcomes, ``description`` lists the outcomes per variable,
        ``new_freqs`` holds the shuffled tables and ``freqs_text`` their
        bullet-list renderings.
    """
    rv_names = ["X", "Y", "Z", "T", "U", "V", "W"]
    keys = list(freqs.keys())
    values = list(freqs.values())
    sample_size = sum(values)

    # One outcome set per position of the key tuples.
    rvs = {rv_names[pos]: set() for pos in range(len(keys[0]))}
    for key in keys:
        for pos, name in enumerate(rvs):
            rvs[name].add(key[pos])

    num_vars = len(rvs)
    clauses = []
    for name, outcomes in rvs.items():
        listed = ', '.join(sorted(outcomes))
        clauses.append(f"{name} can take outcomes from " + "{" + listed + "}")
    description = ", ".join(clauses)

    freqs_text, new_freqs = [], []
    for _ in range(num_permut):
        # Shuffles in place, so every table permutes the previous one.
        random.shuffle(values)
        table = dict(zip(keys, values))
        new_freqs.append(table)
        printable = {str(key).replace("'", ""): count for key, count in table.items()}
        freqs_text.append("\n".join(f"- {key}: {count}" for key, count in printable.items()))

    return num_vars, sample_size, rvs, description, freqs_text, new_freqs



def get_PQ(freqs, pairs):
    """Build aligned probability vectors for one frequency table.

    Args:
        freqs: Mapping of outcome to its count (the target distribution).
        pairs: Samples produced by the model.

    Returns:
        ``(P_freq, P_gt_samp, Q)``, all ordered like ``freqs``: the target
        probabilities, the empirical probabilities of a reference sample of
        ``len(pairs)`` draws from ``get_gt_samples``, and the empirical
        probabilities of ``pairs``. Outcomes that were never drawn get 0.
    """
    freq_total = sum(freqs.values())
    n_pairs = len(pairs)
    model_probs = {outcome: hits / n_pairs for outcome, hits in Counter(pairs).items()}

    reference = get_gt_samples(freqs, n_pairs)
    reference_total = sum(reference.values())

    P_freq, P_gt_samp, Q = [], [], []
    for outcome, count in freqs.items():
        P_freq.append(count / freq_total)
        P_gt_samp.append(reference[outcome] / reference_total if outcome in reference else 0)
        Q.append(model_probs.get(outcome, 0))
    return P_freq, P_gt_samp, Q



def extract_pairs(text, rvs, num):
    """Parse the numbered sample list that follows "### Output" in ``text``.

    The response must contain "### Output" followed by a newline and exactly
    ``num`` consecutive lines shaped like ``<k>. (x, y, ...)`` where each
    component is one of the outcome labels of the corresponding variable.

    Labels are regex-escaped and matched as alternatives, so labels containing
    regex metacharacters or more than one character are handled. Validation
    and extraction share the same separator rule, so every validated sample is
    also extracted.

    Args:
        text: Model response.
        rvs: Mapping of variable name to its collection of outcome labels
            (strings), in the order the components appear in a sample.
        num: Required number of samples.

    Returns:
        ``None`` if the response does not match the format. Otherwise a list
        with one entry per sample: a tuple of labels when there are several
        variables, a plain string when there is only one.
    """
    label_alts = []
    for outcomes in rvs.values():
        # Alphabetical sort keeps the pattern deterministic; longest-first
        # lets a long label win over another label that is its prefix.
        ordered = sorted(sorted(outcomes), key=len, reverse=True)
        label_alts.append("(?:" + "|".join(re.escape(label) for label in ordered) + ")")

    separator = r"\s*,\s*"
    sample_patt = r"\(" + separator.join(label_alts) + r"\)"
    block_patt = r"### Output\n((?:\d+\.\s+" + sample_patt + r"\s*){" + str(num) + r"})"
    find_patt = r"\(" + separator.join(f"({alt})" for alt in label_alts) + r"\)"

    match = re.search(block_patt, text)
    if not match:
        return None
    return re.findall(find_patt, match.group(1))



def batch_prompts(prompts, batch_size):
    """Yield consecutive slices of ``prompts`` with at most ``batch_size`` items each."""
    starts = range(0, len(prompts), batch_size)
    yield from (prompts[start:start + batch_size] for start in starts)


def generate_prompts(n, freq, num_samples, permute=False):
    """Build ``n`` sampling prompts from a joint frequency table.

    The shuffled frequency tables returned by ``extract_params`` are reused
    cyclically, so prompt ``i`` shows table ``i % len(tables)``.

    Args:
        n: Number of prompts to build.
        freq: Mapping of outcome tuple to observed count.
        num_samples: Number of samples each prompt asks the model to generate.
        permute: Accepted but not used by this function.

    Returns:
        ``(prompts, rvs, out_freqs)``: the prompt strings, the mapping of
        variable name to outcome set, and the frequency table shown in each
        prompt (same order as ``prompts``).
    """
    num_vars, sample_size, rvs, s, freqs_text, new_freqs = extract_params(freq)
    var_names = ', '.join(rvs.keys())
    rv_pairs = f"({var_names})"

    prompts, out_freqs = [], []
    for idx in range(n):
        out_freqs.append(new_freqs[idx % len(new_freqs)])
        table_text = freqs_text[idx % len(freqs_text)]
        prompt = f"""Consider {num_vars} discrete random variables {var_names}, where {s}. In {sample_size} independent samples drawn from their joint distribution P({var_names}), the observed frequencies are:

{table_text}

Task: Generate EXACTLY {num_samples} random samples from the joint distribution P({var_names}).

Instructions:
1. OUTPUT MUST BEGIN with "### Output" on a new line.
2. List EXACTLY {num_samples} samples numbered 1 to {num_samples}.
3. Do not write any code or pseudocodes.
4. Strictly follow the output format below
5. You will be penalized for writing codes or not following the output format.

Output Format that you should strictly follow:

### Output
1. {rv_pairs}
2. {rv_pairs}
...
{num_samples}. {rv_pairs}

Where {s}.
"""
        prompts.append(prompt.strip())

    return prompts, rvs, out_freqs







def eval_joint_sampling(models, freqs, num_prompts, num_samples, permute=True, mode="local"):
    """Evaluate how well models sample from a given joint distribution.

    In ``"api_create_batch"`` mode only the batch request files are written;
    in ``"api_eval_batch"`` mode an existing batch output is scored. Both
    return ``None``. In any other mode each Hugging Face model in ``models``
    is loaded on CUDA, prompted in batches of at most 35, and its parsed
    samples are compared with the target distribution through TVD.

    Args:
        models: Hugging Face model ids to evaluate locally.
        freqs: Mapping of outcome tuple to observed count.
        num_prompts: Number of prompts to generate.
        num_samples: Number of samples requested per prompt.
        permute: Only recorded in the output dictionary.
        mode: ``"local"``, ``"api_create_batch"`` or ``"api_eval_batch"``.

    Returns:
        In local mode, a dict with the first prompt, the run settings and one
        entry per model id (average TVDs, number of unparsable responses, the
        per-prompt TVDs and the unparsable responses). Averages are ``None``
        when no response could be parsed. The dict is also written as JSON.

    Raises:
        AssertionError: If a parsed response does not hold ``num_samples``
            samples (internal consistency check on ``extract_pairs``).
    """
    prompts, rvs, new_freqs = generate_prompts(num_prompts, freqs, num_samples)
    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, joint_size)
        return
    if mode == "api_eval_batch":
        evaluate_api_batch(prompts, rvs, new_freqs, joint_size, num_samples)
        return

    batch_size = min(num_prompts, 35)

    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    out["permute"] = permute

    print(f'prompt sample: \n{prompts[0]}\n\n\n\n')

    # Generation budget grows with the number of requested samples; 512 also
    # covers num_samples <= 20, which previously left the budget undefined.
    if num_samples > 54:
        base_max_token = 2048
    elif num_samples > 40:
        base_max_token = 1024
    else:
        base_max_token = 512

    for model_id in models:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
        tokenizer.pad_token = tokenizer.eos_token
        # NOTE(review): batched generation with decoder-only models usually
        # expects left padding; padding_side is left at the tokenizer default
        # here -- confirm this is intended.

        print(f'Model: {model_id}')

        # Per-model settings are derived from the shared ones so that one
        # model's adjustments never leak into the next model.
        max_token = base_max_token
        model_prompts = prompts
        if model_id == "meta-llama/Llama-3.1-8B-Instruct":
            max_token = int(max_token // 1.5)  # max_new_tokens must be an int
        elif model_id == "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B":
            max_token = int(max_token * 3.5)
            model_prompts = [p + "\n**Do not explain, do not write code, and do not deviate from the format.**" for p in prompts]

        responses = []
        inputs = outputs = None
        for batch in batch_prompts(model_prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_token,
                attention_mask=inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id,
            )
            responses.extend(tokenizer.decode(output, skip_special_tokens=True) for output in outputs)

        # Drop the GPU references before the next model is loaded.
        del model
        inputs = outputs = None
        torch.cuda.empty_cache()

        num_mistake = 0
        incorrect_responses = []
        tvds, gt_tvds = [], []
        for i, response in enumerate(responses):
            # The decoded text echoes the prompt; strip it before parsing.
            res = response.replace(model_prompts[i], '')
            check = extract_pairs(res, rvs, num_samples)
            if check is None:
                incorrect_responses.append(res)
                num_mistake += 1
            else:
                assert len(check) == num_samples, f"Number of samples {len(check)} does not match the expected number {num_samples}"
                P_freq, P_gt_samp, Q = get_PQ(new_freqs[i].copy(), check)
                tvds.append(TVD(P_freq, Q))
                gt_tvds.append(TVD(P_freq, P_gt_samp))

        # A t-based confidence interval needs at least two observations.
        if len(tvds) >= 2:
            try:
                tvd_array = np.array(tvds)
                mean = np.mean(tvd_array)
                standard_error = st.sem(tvd_array)  # standard error of the mean
                confidence_level = 0.95
                degrees_freedom = len(tvd_array) - 1
                confidence_interval = st.t.interval(confidence_level, degrees_freedom, loc=mean, scale=standard_error)
                print(f"Mean TVD: {mean:.4f}")
                print(f"95% Confidence Interval: ({confidence_interval[0]:.4f}, {confidence_interval[1]:.4f})")
                print(f"Error bar: ±{(confidence_interval[1] - mean):.4f}")
            except (ValueError, ArithmeticError) as exc:
                print(f"Error occurred: {exc}")
                print(tvds)
        else:
            print("Not enough valid responses to compute a confidence interval")
            print(tvds)

        # Averages stay None when no response could be parsed.
        tvd, gt_tvd = None, None
        if tvds:
            tvd = sum(tvds) / len(tvds)
            gt_tvd = sum(gt_tvds) / len(gt_tvds)

        print(f'num mistakes: {num_mistake}')
        print({"GT_TVD": gt_tvd, "TVD": tvd, "num_mistakes": num_mistake})
        out[model_id] = {"GT_Avg_TVD": gt_tvd, "Pred_Avg_TVD": tvd, "num_mistakes": num_mistake, "TVDs": tvds, "incorrect_responses": incorrect_responses}

    # The file name carries the actual joint size (it used to be fixed to 16),
    # matching the naming used for the API batch results.
    with open(f"../eval_results/Rebuttal/joint_sample_{joint_size}_mushroom_local.json", "w") as file:
        json.dump(out, file, indent=4)

    return out






def eval_joint_sampling_one(models, freqs, num_prompts, num_samples, permute=True, mode="local"):
    """Evaluate one-sample-per-call generation against a single distribution.

    The same prompt (the first generated one) is sent 84 times, each call
    asking for a single sample; the samples parsed from all calls are pooled
    and compared with the prompt's frequency table through TVD.

    NOTE: the ``num_prompts`` and ``num_samples`` arguments are overridden
    with 84 and 1 respectively; this existing behaviour is kept so current
    callers see no change.

    In ``"api_create_batch"`` / ``"api_eval_batch"`` mode the work is handed
    to ``create_batch_file`` / ``evaluate_api_batch`` and ``None`` is returned.

    Args:
        models: Hugging Face model ids to evaluate locally.
        freqs: Mapping of outcome tuple to observed count.
        num_prompts: Ignored (see note above).
        num_samples: Ignored (see note above).
        permute: Only recorded in the output dictionary.
        mode: ``"local"``, ``"api_create_batch"`` or ``"api_eval_batch"``.

    Returns:
        In local mode, a dict with the prompt, the run settings and one entry
        per model id holding the pooled TVD, the number of unparsable
        responses, the number of pooled samples and the unparsable responses.
    """
    num_prompts, num_samples = 84, 1
    prompts, rvs, new_freqs = generate_prompts(num_prompts, freqs, num_samples)
    # Every call uses the first prompt and therefore the first frequency table.
    prompts = [prompts[0]] * num_prompts
    new_freqs = [new_freqs[0]] * len(new_freqs)

    joint_size = len(freqs.keys())

    if mode == "api_create_batch":
        create_batch_file(prompts, joint_size)
        return
    if mode == "api_eval_batch":
        evaluate_api_batch(prompts, rvs, new_freqs, joint_size, num_samples)
        return

    batch_size = min(num_prompts, 35)

    out = {}
    out["prompt"] = prompts[0]
    out["num_prompts"] = num_prompts
    out["permute"] = permute

    print(f'prompt sample: \n{prompts[0]}\n\n\n\n')

    if num_samples > 54:
        base_max_token = 2048
    elif num_samples > 40:
        base_max_token = 1024
    else:
        base_max_token = 512

    for model_id in models:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
        tokenizer.pad_token = tokenizer.eos_token
        # NOTE(review): batched generation with decoder-only models usually
        # expects left padding; padding_side is left at the tokenizer default
        # here -- confirm this is intended.

        print(f'Model: {model_id}')

        # Per-model settings are derived from the shared ones so that one
        # model's adjustments never leak into the next model.
        max_token = base_max_token
        model_prompts = prompts
        if model_id == "meta-llama/Llama-3.1-8B-Instruct":
            max_token = int(max_token // 1.5)  # max_new_tokens must be an int
        elif model_id == "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B":
            max_token = int(max_token * 3.5)
            model_prompts = [p + "\n**Do not explain, do not write code, and do not deviate from the format.**" for p in prompts]

        responses = []
        inputs = outputs = None
        for batch in batch_prompts(model_prompts, batch_size):
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_token,
                attention_mask=inputs["attention_mask"],
                pad_token_id=tokenizer.eos_token_id,
            )
            responses.extend(tokenizer.decode(output, skip_special_tokens=True) for output in outputs)

        # Drop the GPU references before the next model is loaded.
        del model
        inputs = outputs = None
        torch.cuda.empty_cache()

        print('Responses are generated!')

        num_mistake = 0
        incorrect_responses = []
        generated_samples = []
        for i, response in enumerate(responses):
            # The decoded text echoes the prompt; strip it before parsing.
            res = response.replace(model_prompts[i], '')
            print('-'*100)
            print(res)

            check = extract_pairs(res, rvs, num_samples)
            print(f'\n\ncheck: {check}')

            if check is None:
                incorrect_responses.append(res)
                num_mistake += 1
            else:
                generated_samples.extend(check)

        print(f'Generated samples with len {len(generated_samples)}: {generated_samples}')
        P_freq, P_gt_samp, Q = get_PQ(new_freqs[0].copy(), generated_samples)
        tvd = TVD(P_freq, Q)
        print(f'TVD: {tvd}')

        # Keep the per-model outcome in the returned dictionary; it used to be
        # printed only.
        out[model_id] = {
            "TVD": tvd,
            "num_mistakes": num_mistake,
            "num_generated_samples": len(generated_samples),
            "incorrect_responses": incorrect_responses,
        }

    return out


