
import json
import os

import fire
import pandas as pd
import torch
import torch.nn.functional as F
from binoculars import Binoculars
from sentence_transformers import SentenceTransformer
from termcolor import colored
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from tqdm import tqdm

from mtd.fast_detect_gpt import get_sampling_discrepancy_analytic
from nicks_dpo.create_preference_data import (
    get_luar_embeddings,
)

@torch.no_grad()
def get_rank_forward(text, base_model, base_tokenizer, log=False):
    """Return the mean (log-)rank of the observed next tokens under ``base_model``.

    From: https://github.com/eric-mitchell/detect-gpt/blob/main/run.py#L298C1-L320C43

    The text is tokenized (truncated to 1024 tokens), run through the model,
    and for every position the 1-indexed rank of the actual next token in the
    model's descending-logit ordering is looked up.

    Args:
        text: Text to score. The per-timestep check below only holds for a
            single sequence, so pass one string at a time.
        base_model: Model whose output exposes ``.logits`` and that has a
            ``.device`` attribute.
        base_tokenizer: Tokenizer matching ``base_model``.
        log: When true, average ``log(rank)`` instead of the raw rank.

    Returns:
        The mean rank (or mean log-rank) as a Python float.
    """
    encoded = base_tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    ).to(base_model.device)

    # Position t predicts token t + 1: drop the last logit and the first label.
    next_token_logits = base_model(**encoded).logits[:, :-1]
    targets = encoded["input_ids"][:, 1:]

    # get rank of each label token in the model's likelihood ordering
    ordering = next_token_logits.argsort(-1, descending=True)
    matches = (ordering == targets.unsqueeze(-1)).nonzero()

    assert matches.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {matches.shape}"

    # Columns of `matches` are (batch index, timestep, rank).
    positions = matches[:, -2]
    ranks = matches[:, -1]

    # make sure we got exactly one match for each timestep in the sequence
    expected_positions = torch.arange(len(positions), device=positions.device)
    assert (positions == expected_positions).all(), "Expected one match per timestep"

    ranks = ranks.float() + 1  # convert to 1-indexed rank
    if log:
        ranks = torch.log(ranks)

    return ranks.float().mean().item()

@torch.no_grad()
def supervised_forward(
    detector: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
    text: list[str],
    batch_size: int,
    pbar: tqdm,
    is_remo: bool = False,
) -> list[float]:
    """Score ``text`` with a sequence-classification detector, batch by batch.

    Args:
        detector: Classifier whose output exposes ``.logits``; inputs are
            moved to ``detector.device`` before the forward pass.
        tokenizer: Tokenizer matching ``detector``. Inputs are padded and
            truncated to 512 tokens.
        text: Texts to score.
        batch_size: Number of texts per forward pass.
        pbar: Progress bar, advanced by one after every batch.
        is_remo: When true, the raw logit at index 0 is used as the score;
            otherwise the softmax probability of class index 0 is used.
            (Which label index 0 denotes depends on the checkpoint; confirm
            against the model card.)

    Returns:
        One score per input text, in input order.
    """
    probabilities = []
    for start in range(0, len(text), batch_size):
        # Slice by `batch_size`; a fixed window of 128 would make consecutive
        # batches overlap whenever batch_size != 128 and duplicate scores.
        batch = text[start:start + batch_size]
        batch = tokenizer(
            batch,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch = {k: v.to(detector.device) for k, v in batch.items()}
        logits = detector(**batch).logits
        if is_remo:
            output_probs = logits[:, 0].tolist()
        else:
            output_probs = F.log_softmax(logits, dim=-1)[:, 0].exp().tolist()
        probabilities.extend(output_probs)
        pbar.update(1)
    return probabilities

@torch.no_grad()
def compute_supervised_MTD_scores(
    text: list[str],
    model_name: str,
) -> list[float]:
    """Load a supervised detector checkpoint and score every text with it.

    Args:
        text: Texts to score.
        model_name: Checkpoint name or path passed to ``from_pretrained``.
            If it contains "remo" (case-insensitive), the raw logit at index 0
            is used as the score instead of a softmax probability.

    Returns:
        One score per input text, as produced by ``supervised_forward``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    detector = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    detector.to(device)
    detector.eval()

    batch_size = 128
    # The bar advances once per batch, so the total is the number of batches
    # rounded up; floor division would miss a final partial batch.
    num_batches = (len(text) + batch_size - 1) // batch_size
    is_remo = "remo" in model_name.lower()
    with tqdm(total=num_batches) as pbar:
        scores = supervised_forward(detector, tokenizer, text, batch_size, pbar, is_remo=is_remo)

    return scores

@torch.no_grad()
def bino_forward(
    bino: Binoculars,
    text: list[str],
    batch_size: int,
    pbar: tqdm,
) -> list[float]:
    """Run ``bino.compute_score`` over ``text`` in chunks of ``batch_size``.

    Args:
        bino: Object exposing ``compute_score(list_of_texts)``.
        text: Texts to score.
        batch_size: Number of texts handed to ``compute_score`` per call.
        pbar: Progress bar, advanced by one after every chunk.

    Returns:
        A new list with all chunk results concatenated in input order.
    """
    collected = []
    for offset in range(0, len(text), batch_size):
        chunk = text[offset:offset + batch_size]
        collected.extend(bino.compute_score(chunk))
        pbar.update(1)
    return list(collected)

@torch.no_grad()
def compute_binoculars_score(
    text: list[str],
) -> list[float]:
    """Score every text with Binoculars (falcon-7b observer / falcon-7b-instruct performer).

    Args:
        text: Texts to score.

    Returns:
        One Binoculars score per input text, as produced by ``bino_forward``.
    """
    bino = Binoculars(
        observer_name_or_path="tiiuae/falcon-7b",
        performer_name_or_path="tiiuae/falcon-7b-instruct",
    )

    batch_size = 16
    # The bar advances once per batch, not once per text, so the total must be
    # the batch count (rounded up) for it to ever reach completion.
    num_batches = (len(text) + batch_size - 1) // batch_size
    with tqdm(total=num_batches) as pbar:
        scores = bino_forward(bino, text, batch_size, pbar)

    return scores

@torch.no_grad()
def fastdetectgpt_forward(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    text: list[str],
    pbar: tqdm,
) -> list[float]:
    """Compute the analytic sampling discrepancy for each text, one at a time.

    The same model supplies both the reference and the scoring logits.

    Args:
        model: Causal LM whose output exposes ``.logits``; inputs are moved to
            ``model.device``.
        tokenizer: Tokenizer matching ``model``.
        text: Texts to score.
        pbar: Progress bar, advanced by one after every text.

    Returns:
        The value returned by ``get_sampling_discrepancy_analytic`` for each
        text, in input order.
    """
    results = []
    for document in text:
        encoded = tokenizer(
            document,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        # Position t predicts token t + 1: drop the last logit and the first label.
        scoring_logits = model(**encoded).logits[:, :-1]
        targets = encoded["input_ids"][:, 1:]
        # Reference and scoring distributions are the very same tensor here.
        results.append(
            get_sampling_discrepancy_analytic(scoring_logits, scoring_logits, targets)
        )
        pbar.update(1)
    return results

@torch.no_grad()
def compute_fastdetectgpt_score(
    text: list[str],
) -> list[float]:
    """Load GPT-Neo 2.7B and score every text with ``fastdetectgpt_forward``.

    Args:
        text: Texts to score.

    Returns:
        One score per input text.
    """
    # model used by the original FastDetectGPT paper:
    checkpoint = "EleutherAI/gpt-neo-2.7B"
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    scoring_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    scoring_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Reuse the EOS token for padding so `padding=True` works.
    scoring_tokenizer.pad_token = scoring_tokenizer.eos_token
    scoring_model.to(target_device)

    progress = tqdm(total=len(text))
    return fastdetectgpt_forward(scoring_model, scoring_tokenizer, text, progress)

@torch.no_grad()
def compute_rank_score(
    text: list[str],
    log: bool = False,
) -> list[float]:
    """Load GPT-2 XL and compute the mean token (log-)rank of every text.

    Args:
        text: Texts to score.
        log: Forwarded to ``get_rank_forward``; when true the mean log-rank is
            returned instead of the mean rank.

    Returns:
        One score per input text, in input order.
    """
    checkpoint = "gpt2-xl"
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    rank_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    rank_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    rank_tokenizer.pad_token = rank_tokenizer.eos_token
    rank_model.to(target_device)

    per_text_scores = []
    for document in tqdm(text):
        per_text_scores.append(get_rank_forward(document, rank_model, rank_tokenizer, log))
    return per_text_scores

def compute_styledetect_scores(
    text: list[str],
    background: list[str],
    model_name: str = "rrivera1849/LUAR-MUD",
):
    """Score texts by cosine similarity to an embedding of ``background``.

    Two embedding back-ends are supported, selected by ``model_name``:

    * names containing "LUAR" are loaded with ``AutoModel`` and embedded via
      ``get_luar_embeddings`` (the background is embedded with ``single=True``;
      presumably this yields one aggregate embedding -- confirm in that helper);
    * anything else is loaded as a ``SentenceTransformer``; the background
      embeddings are averaged and L2-normalized into a single vector.

    Args:
        text: Texts to score.
        background: Few-shot example texts that define the reference embedding.
        model_name: Embedding checkpoint name or path.

    Returns:
        A list with one cosine-similarity score per entry in ``text``.
    """
    # Fall back to CPU when no GPU is present, like the other detectors in
    # this file, instead of failing on an unconditional `.cuda()`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if "LUAR" in model_name:
        # NOTE: trust_remote_code=True runs model code shipped with the checkpoint.
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True).eval().to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        background_emb = get_luar_embeddings(background, model, tokenizer, single=True)
        emb = get_luar_embeddings(text, model, tokenizer)
    else:
        model = SentenceTransformer(model_name).eval().to(device)
        background_emb = model.encode(background, show_progress_bar=False, normalize_embeddings=False, convert_to_tensor=True)
        # Collapse the background examples into one normalized centroid.
        background_emb = background_emb.mean(dim=0, keepdim=True)
        background_emb = F.normalize(background_emb, p=2, dim=-1)
        emb = model.encode(text, show_progress_bar=False, normalize_embeddings=True, convert_to_tensor=True)

    # Tile the background vector so it is compared against every text embedding.
    scores = F.cosine_similarity(background_emb.repeat(emb.size(0), 1), emb).cpu().tolist()
    return scores

def main(
    mtd_data_path: str = "./neurips/shards/MTD_reddit_12000_correct_Mistral-7B-Instruct-v0.3_N=5.jsonlshard-4-4.iter=3",
    eval_key: str = "content_text",
    K: int = 100, # number of few-shot examples for StyleDetect
    debug: bool = False,
):
    """Run every detector on one column of a JSON-lines file and cache the scores.

    The first ``K`` entries of the machine-text column ("respond_reddit" or
    "generation") serve as few-shot background examples for the embedding
    detectors. If ``eval_key`` is that same column, only the remaining entries
    are evaluated; otherwise every row of ``df[eval_key]`` is evaluated.

    Results are stored in ``./mtd_scores/<basename>_eval-<eval_key>[.debug].json``.
    If that file already exists, detectors whose key is present are skipped.

    Args:
        mtd_data_path: Path to the JSON-lines input file.
        eval_key: Column whose text is scored.
        K: Number of few-shot examples taken from the machine-text column.
        debug: When true, only the first 100 texts are scored and the output
            file name gets a ``.debug`` suffix.

    Returns:
        0 on success.

    Raises:
        ValueError: If neither machine-text column is present in the file.
    """
    df = pd.read_json(mtd_data_path, lines=True)

    if "respond_reddit" in df.columns:
        guess_mkey = "respond_reddit"
    elif "generation" in df.columns:
        guess_mkey = "generation"
    else:
        # Raise explicitly: an `assert` would vanish under `python -O`.
        raise ValueError("Unable to find column name with machine text")

    machine = df[guess_mkey].tolist()
    fewshot_examples = machine[:K]
    machine = machine[K:]

    if eval_key == guess_mkey:
        text = machine
    else:
        text = df[eval_key].tolist()
    # Some columns hold a list of candidates per row; keep the first one.
    # The emptiness guards avoid an IndexError on empty inputs.
    if text and isinstance(text[0], list):
        text = [t[0] for t in text]
    if fewshot_examples and isinstance(fewshot_examples[0], list):
        fewshot_examples = [t[0] for t in fewshot_examples]
    text = [x for x in text if len(x) > 0]

    if debug:
        text = text[:100]

    base = os.path.basename(mtd_data_path)
    suffix = ".debug" if debug else ""
    savename = f"./mtd_scores/{base}_eval-{eval_key}{suffix}.json"
    # Create the output directory up front so the final write cannot fail on a
    # missing folder after all detectors have already run.
    os.makedirs(os.path.dirname(savename), exist_ok=True)

    if os.path.isfile(savename):
        with open(savename, "r") as fin:
            mtd_scores = json.load(fin)
        print("Found a file with these keys:", mtd_scores.keys())
    else:
        mtd_scores = {}

    # Ordered (result key, scorer) pairs; the order determines both execution
    # order and the key order of newly added entries in the output file.
    detectors = [
        ("Binoculars", lambda: compute_binoculars_score(text)),
        ("FastDetectGPT", lambda: compute_fastdetectgpt_score(text)),
        # Disabled detector, kept for reference:
        # ("OpenAI", lambda: compute_supervised_MTD_scores(
        #     text,
        #     "openai-community/roberta-base-openai-detector",
        # )),
        ("RADAR", lambda: compute_supervised_MTD_scores(
            text,
            "TrustSafeAI/RADAR-Vicuna-7B",
        )),
        ("Rank", lambda: compute_rank_score(text, log=False)),
        ("LogRank", lambda: compute_rank_score(text, log=True)),
        ("StyleDetect", lambda: compute_styledetect_scores(text, fewshot_examples)),
        ("StyleDetect-CISR", lambda: compute_styledetect_scores(
            text, fewshot_examples, model_name="AnnaWegmann/Style-Embedding",
        )),
        ("StyleDetect-SD", lambda: compute_styledetect_scores(
            text, fewshot_examples, model_name="StyleDistance/styledistance",
        )),
        ("SemDetect", lambda: compute_styledetect_scores(
            text, fewshot_examples, model_name="sentence-transformers/all-mpnet-base-v2",
        )),
        ("ReMoDetect", lambda: compute_supervised_MTD_scores(
            text,
            "hyunseoki/ReMoDetect-deberta",
        )),
    ]

    for name, run_detector in detectors:
        # Skip anything already present in the cached results.
        if name not in mtd_scores:
            print(f"Running {name}")
            mtd_scores[name] = run_detector()

    print(colored("Saving to:", "green"), savename)
    with open(savename, "w+") as fout:
        fout.write(json.dumps(mtd_scores))

    return 0

# Script entry point: hand `main` to python-fire, which builds the
# command-line interface from its signature.
if __name__ == "__main__":
    fire.Fire(main)
