import json
import os
from typing import Union

import fire
import numpy as np
import random
import torch
import torch.nn.functional as F
from datasets import Dataset
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from tqdm import tqdm

from nicks_dpo.prompts import *

def get_sampling_discrepancy_analytic(logits_ref, logits_score, labels):
    """Return the analytic sampling discrepancy of one token sequence.

    The observed log-likelihood of ``labels`` under the scoring distribution is
    standardised with the mean and variance that the scoring log-probabilities
    have under the reference distribution. Per-position terms are summed before
    the ratio is taken.

    Args:
        logits_ref: Reference-model logits. The leading dimension must be 1
            and the last dimension is the vocabulary.
        logits_score: Scoring-model logits with the same layout.
        labels: Token ids to score. A trailing singleton axis is appended
            when ``labels`` has one dimension fewer than ``logits_score``.

    Returns:
        The discrepancy as a Python float.

    Raises:
        AssertionError: If any input has a leading dimension other than 1.
    """
    # Only a single sequence per call is supported.
    for tensor in (logits_ref, logits_score, labels):
        assert tensor.shape[0] == 1

    # When the two vocabularies differ in size, keep only the shared prefix.
    ref_vocab = logits_ref.size(-1)
    score_vocab = logits_score.size(-1)
    if ref_vocab != score_vocab:
        shared_vocab = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared_vocab]
        logits_score = logits_score[:, :, :shared_vocab]

    # gather() needs an index with the same number of dimensions as its input.
    if labels.ndim == logits_score.ndim - 1:
        labels = labels.unsqueeze(-1)

    score_log_probs = torch.log_softmax(logits_score, dim=-1)
    ref_probs = torch.softmax(logits_ref, dim=-1)

    # Log-probability the scoring model assigns to each observed token.
    token_log_likelihood = score_log_probs.gather(dim=-1, index=labels).squeeze(-1)

    # First moment and variance of the scoring log-probability under the
    # reference distribution, per position.
    expected_log_prob = (ref_probs * score_log_probs).sum(dim=-1)
    second_moment = (ref_probs * torch.square(score_log_probs)).sum(dim=-1)
    log_prob_variance = second_moment - torch.square(expected_log_prob)

    numerator = token_log_likelihood.sum(dim=-1) - expected_log_prob.sum(dim=-1)
    denominator = log_prob_variance.sum(dim=-1).sqrt()
    return (numerator / denominator).mean().item()

@torch.no_grad()
def get_fast_detect_gpt_scores(
    text: list[str],
):
    """Score each text with the analytic Fast-DetectGPT discrepancy.

    ``gpt2-xl`` is loaded once per call and used as both the reference and the
    scoring model.

    Args:
        text: Texts to score.

    Returns:
        A list of floats with exactly one entry per element of ``text``, in
        the same order. Empty strings cannot be scored and get ``nan`` so that
        ``scores[i]`` always belongs to ``text[i]``; callers that pair scores
        with inputs by index rely on this. (Texts that tokenize to fewer than
        two tokens presumably also evaluate to ``nan`` because there is nothing
        to predict -- confirm if that case matters.)
    """
    # Nothing to score: avoid loading the model at all.
    if not text:
        return []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
    base_model.eval()
    base_model.to(device)
    base_tok = AutoTokenizer.from_pretrained("gpt2-xl")
    # The tokenizer is called with padding enabled, which needs a pad token.
    base_tok.pad_token = base_tok.eos_token

    scores = []
    for sample in tqdm(text):
        if not sample:
            # Keep the output aligned with the input instead of dropping the
            # sample, which would shift every later score by one position.
            scores.append(float("nan"))
            continue
        tok = base_tok(
            sample,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        # Position t predicts token t + 1: drop the last logit and the first label.
        base_logits = base_model(**tok).logits[:, :-1]
        labels = tok["input_ids"][:, 1:]
        # The same model provides both the reference and the scoring logits.
        scores.append(get_sampling_discrepancy_analytic(base_logits, base_logits, labels))
    return scores

@torch.inference_mode()
def get_luar_embeddings(
    text: Union[list[str], list[list[str]]],
    model: AutoModel,
    tokenizer: AutoTokenizer,
    batch_size: int = 32,
    single: bool = False,
    normalize: bool = True,
):
    """Embed texts with a LUAR-style model.

    Args:
        text: Either a flat list of strings, or a list of lists of strings.
            For a nested list, every inner list is embedded as one group
            (as if called with ``single=True``) and the results are
            concatenated along dimension 0.
        model: Callable model exposing ``.device``. It is called with the
            tokenizer's outputs as keyword arguments and is expected to
            return a tensor (the result is passed to ``torch.cat`` /
            ``F.normalize``).
        tokenizer: Tokenizer matching ``model``. Inputs are padded/truncated
            to 512 tokens.
        batch_size: Number of texts per forward pass when ``single`` is False.
        single: If True, all texts are sent in one forward pass with a new
            leading axis added (``unsqueeze(0)``), i.e. as a single group.
            If False, texts are processed in batches and each text gets its
            own axis-1 slot (``unsqueeze(1)``), i.e. one group per text.
        normalize: If True, L2-normalize the result along the last dimension.

    Returns:
        The model outputs, optionally normalized.

    Raises:
        IndexError: If ``text`` is empty.
    """
    if isinstance(text[0], list):
        # Forward `normalize` explicitly: the recursive call would otherwise
        # fall back to its default and normalize even when the caller asked
        # for raw embeddings.
        group_embeddings = [
            get_luar_embeddings(group, model, tokenizer, single=True, normalize=normalize)
            for group in text
        ]
        return torch.cat(group_embeddings, dim=0)

    device = model.device

    inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    if single:
        # Add a leading axis so the whole list is treated as one group.
        inputs["input_ids"] = inputs["input_ids"].unsqueeze(0)
        inputs["attention_mask"] = inputs["attention_mask"].unsqueeze(0)
        inputs = inputs.to(device)
        outputs = model(**inputs)
    else:
        outputs = []
        for batch_idx in tqdm(range(0, len(text), batch_size)):
            # Add an axis at position 1 so each text forms its own group.
            batch_inputs = {
                k: v[batch_idx:batch_idx + batch_size].unsqueeze(1).to(device)
                for k, v in inputs.items()
            }
            outputs.append(model(**batch_inputs))
        outputs = torch.cat(outputs, dim=0)

    if normalize:
        outputs = F.normalize(outputs, dim=-1, p=2)
    return outputs

def _load_jsonl(path: str) -> list[dict]:
    """Read a JSON-lines file into a list, one parsed object per non-blank line."""
    with open(path, "r", encoding="utf-8") as fin:
        return [json.loads(line) for line in fin if line.strip()]


def _detect_dataset_name(data_path: str) -> str:
    """Return the first of "reddit", "amazon", "blogs" contained in *data_path*."""
    for name in ("reddit", "amazon", "blogs"):
        if name in data_path:
            return name
    raise ValueError(
        f"Cannot infer the dataset from {data_path!r}; "
        "the path must contain 'reddit', 'amazon' or 'blogs'."
    )


def _select_pair_records(records: list[dict]) -> list[dict]:
    """Keep only records that hold exactly two non-empty candidate responses."""
    kept = []
    for record in records:
        candidates = record["respond_reddit"]
        if len(candidates) == 2 and all(candidates):
            kept.append(record)
    return kept


def _collect_few_shot_texts(records: list[dict], key: str) -> list[str]:
    """Gather one reference text per record from *key* (first entry of list values)."""
    texts = []
    for record in records:
        value = record[key]
        if isinstance(value, list):
            if not value:
                continue
            value = value[0]
        texts.append(value)
    return texts


def _build_prompts(records: list[dict], template: str, chat_tokenizer=None) -> list[str]:
    """Format one prompt per record, optionally wrapped in the tokenizer's chat template."""
    prompts = []
    for record in records:
        content = record["content_text"]
        prompt = template.format(content, len(content.split(" ")))
        if chat_tokenizer is not None:
            prompt = chat_tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=False,
                # Extra keyword forwarded to the chat template; passed for
                # every method so all prompts are rendered the same way.
                enable_thinking=False,
            )
        prompts.append(prompt)
    return prompts


def _build_preference_records(prompts: list[str], responses: list[str], scores: list[float]):
    """Turn consecutive (response, score) pairs into chosen/rejected records.

    Returns ``(records, chosen_scores, rejected_scores)``. Within each pair the
    response with the lower score is chosen. Pairs with a non-finite score are
    skipped because they cannot be ranked.
    """
    records, chosen_scores, rejected_scores = [], [], []
    # Bound the loop explicitly instead of relying on IndexError to stop it.
    num_pairs = min(len(prompts), len(responses) // 2, len(scores) // 2)
    num_unranked = 0
    for pair_idx in range(num_pairs):
        first, second = 2 * pair_idx, 2 * pair_idx + 1
        if not (np.isfinite(scores[first]) and np.isfinite(scores[second])):
            num_unranked += 1
            continue
        if scores[first] < scores[second]:
            chosen_idx, rejected_idx = first, second
        else:
            chosen_idx, rejected_idx = second, first
        chosen_scores.append(scores[chosen_idx])
        rejected_scores.append(scores[rejected_idx])
        records.append({
            "prompt": prompts[pair_idx],
            "chosen": responses[chosen_idx],
            "rejected": responses[rejected_idx],
        })
    if num_unranked:
        print(f"WARNING: skipped {num_unranked} pair(s) with non-finite scores.")
    return records, chosen_scores, rejected_scores


def main(
    data_path: str,
    suffix: str,
    debug: bool = False,
    method: str = "FastDetectGPT",
    K: int = 100, # number of few-shot examples for LUAR
    print_scores: bool = False,
):
    """Build a chosen/rejected preference dataset from pairs of responses.

    Each line of ``data_path`` is a JSON object with a ``"content_text"``
    string and a ``"respond_reddit"`` list of candidate responses. Records
    that do not hold exactly two non-empty responses are dropped. Both
    responses of every remaining record are scored with ``method`` and the
    lower-scoring one becomes ``chosen``.

    Args:
        data_path: JSON-lines input. The path must contain "reddit", "amazon"
            or "blogs" (selects the prompt template). If the file name
            contains "qwen", prompts are wrapped in the Qwen2.5 chat template.
        suffix: Suffix used in the output file / directory names.
        debug: Only use the first 10 shuffled records (plus ``K`` for LUAR
            methods) and do not save the dataset.
        method: "FastDetectGPT", "LUAR_against" or "LUAR_for".
        K: For LUAR methods, the first ``K`` shuffled records provide the
            reference texts and are excluded from the output pairs.
            "LUAR_for" uses their ``content_text`` and negates the cosine
            similarity; "LUAR_against" uses their first response.
        print_scores: Print mean/std of chosen and rejected scores and write
            them to ``scores_{method}_{suffix}.md`` instead of saving the
            dataset.

    Returns:
        0 on completion.

    Raises:
        ValueError: For an unknown ``method``, an unrecognised dataset path,
            or (LUAR methods) when no reference texts or no records to score
            are available.
    """
    if method not in ("FastDetectGPT", "LUAR_against", "LUAR_for"):
        raise ValueError(
            f"Unknown method {method!r}; expected 'FastDetectGPT', 'LUAR_against' or 'LUAR_for'."
        )
    uses_luar = "LUAR" in method
    dataset_name = _detect_dataset_name(data_path)

    data = _load_jsonl(data_path)
    random.shuffle(data)
    if debug:
        data = data[:10 + K if uses_luar else 10]

    pdict = {
        "reddit": RESPOND_REDDIT_PROMPT,
        "amazon": RESPOND_AMAZON_PROMPT,
        "blogs": RESPOND_BLOG_PROMPT,
    }
    chat_tokenizer = None
    if "qwen" in os.path.basename(data_path).lower():
        chat_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

    # LUAR methods reserve the first K records as reference texts.
    candidate_data = data[K:] if uses_luar else data
    eval_data = _select_pair_records(candidate_data)
    num_dropped = len(candidate_data) - len(eval_data)
    if num_dropped:
        print(f"WARNING: dropped {num_dropped} record(s) without exactly two non-empty responses.")

    prompts = _build_prompts(eval_data, pdict[dataset_name], chat_tokenizer)
    # Flattened so that responses[2 * i] and responses[2 * i + 1] belong to prompts[i].
    responses = [response for d in eval_data for response in d["respond_reddit"]]

    if uses_luar:
        fs_key = "content_text" if method == "LUAR_for" else "respond_reddit"
        fs = _collect_few_shot_texts(data[:K], fs_key)
        if not fs:
            raise ValueError("No few-shot reference texts available; K must be at least 1.")
        if not responses:
            raise ValueError(
                f"No records left to score after reserving K={K} few-shot examples."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE: trust_remote_code=True executes code shipped with the model repository.
        model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True).eval().to(device)
        tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

        # All reference texts are embedded together as a single group.
        fs_emb = get_luar_embeddings(fs, model, tokenizer, single=True)
        responses_emb = get_luar_embeddings(responses, model, tokenizer)
        scores = F.cosine_similarity(fs_emb.repeat(responses_emb.shape[0], 1), responses_emb, dim=-1)
        if method == "LUAR_for":
            # Lower score wins, so negate to prefer the response closest to the references.
            scores = -scores
        scores = scores.cpu().tolist()
    else:
        scores = get_fast_detect_gpt_scores(responses)

    records, chosen_scores_list, rejected_scores_list = _build_preference_records(
        prompts, responses, scores
    )

    if print_scores:
        chosen_mean = np.mean(chosen_scores_list)
        chosen_std = np.std(chosen_scores_list)
        rejected_mean = np.mean(rejected_scores_list)
        rejected_std = np.std(rejected_scores_list)

        print(f"Chosen: {chosen_mean:.4f} ({chosen_std:.4f})")
        print(f"Rejected: {rejected_mean:.4f} ({rejected_std:.4f})")

        with open(f"scores_{method}_{suffix}.md", "w", encoding="utf-8") as f:
            f.write("| Type | Mean | Std |\n")
            f.write("| --- | --- | --- |\n")
            f.write(f"| Chosen | {chosen_mean:.4f} | {chosen_std:.4f} |\n")
            f.write(f"| Rejected | {rejected_mean:.4f} | {rejected_std:.4f} |\n")

    elif not debug:
        dataset = Dataset.from_list(records)
        dataset.save_to_disk("./nicks_dpo/preference-{}-{}".format(method, suffix))

    return 0


if __name__ == "__main__":
    # Seed Python's RNG first so the random.shuffle call inside main() orders
    # the data identically on every run.
    SHUFFLE_SEED = 43
    random.seed(SHUFFLE_SEED)
    # Hand main() to python-fire, which runs it as the command-line entry point.
    fire.Fire(main)