import os
# Default to GPU 1, but respect a CUDA_VISIBLE_DEVICES the caller already exported.
# Kept above the torch import so it is in place before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch
import os
import argparse
import json
from model import load_tokenizer, load_model
from fast_detect_gpt import get_sampling_discrepancy_analytic
from scipy.stats import norm
from tqdm import tqdm

# ---- Run configuration (set here; not exposed on the command line) ----
# Prepend the paper's full text as context when scoring a review.
context: bool = False
# Score data/humanised_reviews_subset.json (takes precedence over new_data).
humanised: bool = True
# Skip items whose "set" is "train".
test_only: bool = False
# Read data/all_conferences_new_data.json; also selects the "new_data" output prefix.
new_data: bool = False
# Skip items whose "category" is not "human".
human_only: bool = False
# Conference group: "o" (acl/conll/iclr), "n" (neurips 2013-2017) or "b" (both).
conf: str = "b"

def compute_prob_norm(x, mu0, sigma0, mu1, sigma1):
    """Return the posterior probability that ``x`` came from N(mu1, sigma1).

    With equal priors this is ``pdf1(x) / (pdf0(x) + pdf1(x))``. Here it is
    evaluated as ``sigmoid(logpdf1(x) - logpdf0(x))``, which is algebraically
    identical but does not become ``0 / 0 = nan`` when both densities underflow
    far in the tails.

    Args:
        x: criterion value (scalar or array-like).
        mu0, sigma0: mean / std of the first (human) normal distribution.
        mu1, sigma1: mean / std of the second (AI) normal distribution.

    Returns:
        Probability in [0, 1] of belonging to the second distribution.
    """
    from scipy.special import expit  # numerically stable logistic sigmoid

    log_pdf0 = norm.logpdf(x, loc=mu0, scale=sigma0)
    log_pdf1 = norm.logpdf(x, loc=mu1, scale=sigma1)
    return expit(log_pdf1 - log_pdf0)

class FastDetectGPT:
    """Fast-DetectGPT detector that scores a text and maps the score to a probability.

    ``compute_crit`` / ``compute_crit_cxt`` pass the sampling model's logits
    (reference), the scoring model's logits and the observed token ids to
    ``get_sampling_discrepancy_analytic``. ``compute_prob`` converts the
    resulting criterion into a probability with ``compute_prob_norm``, using
    normal-distribution parameters registered per (sampling, scoring) model pair.
    """

    def __init__(self, args):
        """Validate the model pair, then load the tokenizer(s) and model(s).

        Args:
            args: namespace with ``sampling_model_name``, ``scoring_model_name``,
                ``device`` and ``cache_dir`` attributes.

        Raises:
            KeyError: if no distribution parameters are registered for the
                ``"<sampling_model_name>_<scoring_model_name>"`` pair.
        """
        # To obtain probabilities that are easy to interpret, the criterion is
        # assumed to be normally distributed, with parameters estimated on a group
        # of dev samples: mu0/sigma0 for human texts and mu1/sigma1 for AI texts.
        # sigma1 is set to 2 * sigma0 to give wider coverage of potential AI texts.
        distrib_params = {
            'gpt-j-6B_gpt-j-6B': {'mu0': 0.2713, 'sigma0': 0.9366, 'mu1': 2.2334, 'sigma1': 1.8731}, # PROXY: same values as gpt-j-6B_gpt-neo-2.7B
            'gpt-j-6B_gpt-neo-2.7B': {'mu0': 0.2713, 'sigma0': 0.9366, 'mu1': 2.2334, 'sigma1': 1.8731},
            'gpt-neo-2.7B_gpt-neo-2.7B': {'mu0': -0.2489, 'sigma0': 0.9968, 'mu1': 1.8983, 'sigma1': 1.9935},
            'falcon-7b_falcon-7b-instruct': {'mu0': -0.0707, 'sigma0': 0.9520, 'mu1': 2.9306, 'sigma1': 1.9039},
            # NOTE(review): these round values may not have been fitted on dev data -- confirm.
            'llama3-8b_llama3-8b-instruct': {'mu0': -0.1500, 'sigma0': 0.9800, 'mu1': 2.5000, 'sigma1': 1.9600},
        }
        key = f'{args.sampling_model_name}_{args.scoring_model_name}'
        # Check the pair before loading any model so an unsupported pair fails fast.
        if key not in distrib_params:
            raise KeyError(
                f"No distribution parameters for model pair '{key}'; "
                f"known pairs: {sorted(distrib_params)}"
            )
        self.classifier = distrib_params[key]

        self.args = args
        self.criterion_fn = get_sampling_discrepancy_analytic
        self.scoring_tokenizer = load_tokenizer(args.scoring_model_name, args.cache_dir)
        self.scoring_model = load_model(args.scoring_model_name, args.device, args.cache_dir)
        # Longest input (in tokens) that every loaded model accepts.
        self.max_len = self.scoring_model.config.max_position_embeddings
        self.scoring_model.eval()
        if args.sampling_model_name != args.scoring_model_name:
            self.sampling_tokenizer = load_tokenizer(args.sampling_model_name, args.cache_dir)
            self.sampling_model = load_model(args.sampling_model_name, args.device, args.cache_dir)
            self.max_len = min(self.max_len, self.sampling_model.config.max_position_embeddings)
            self.sampling_model.eval()

    # compute conditional probability curvature
    def compute_crit(self, text):
        """Compute the criterion for ``text`` scored on its own (no context).

        Args:
            text: the text to score.

        Returns:
            ``(crit, ntokens, log_likelihood)``. ``crit`` and ``log_likelihood``
            are the ``'discrepancy'`` and ``'log_likelihood'`` entries returned by
            ``self.criterion_fn``. ``ntokens`` is the number of scored tokens
            (every token except the first).

        Raises:
            AssertionError: if the sampling tokenizer produces different token ids
                from the scoring tokenizer.
        """
        device = self.args.device
        # Cap at self.max_len so the input never exceeds the models' position limit.
        # truncation=True alone relies on the tokenizer's own limit, which need not
        # match the model config.
        tokenized = self.scoring_tokenizer(text, truncation=True, max_length=self.max_len, return_tensors="pt", padding=False, return_token_type_ids=False).to(device)
        # Logit position i predicts token i + 1, so the first token is never scored.
        labels = tokenized.input_ids[:, 1:]
        with torch.no_grad():
            logits_score = self.scoring_model(**tokenized).logits[:, :-1]
            if self.args.sampling_model_name == self.args.scoring_model_name:
                logits_ref = logits_score
            else:
                tokenized_ref = self.sampling_tokenizer(text, truncation=True, max_length=self.max_len, return_tensors="pt", padding=False, return_token_type_ids=False).to(device)
                # torch.equal also handles differing lengths (returns False) instead of
                # failing inside the element-wise comparison.
                assert torch.equal(tokenized_ref.input_ids[:, 1:], labels), "Tokenizer is mismatch."
                logits_ref = self.sampling_model(**tokenized_ref).logits[:, :-1]
            res = self.criterion_fn(logits_ref, logits_score, labels)
        return res['discrepancy'], labels.size(1), res['log_likelihood']

    @staticmethod
    def _context_logits(model, context_ids, review_ids):
        """Run ``model`` on ``context_ids`` followed by ``review_ids``.

        Returns the logits aligned with ``review_ids[:, 1:]``. Logit position i
        predicts token i + 1, and the first review token sits at index
        ``context_ids.size(1)``, so the slice below covers review tokens 1..n-1.
        """
        input_ids = torch.cat([context_ids, review_ids], dim=1)
        attention_mask = torch.ones_like(input_ids)
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        return logits[:, context_ids.size(1):-1]

    # compute conditional probability curvature with added context
    def compute_crit_cxt(self, review, abstract):
        """Compute the criterion for ``review`` conditioned on a preceding ``abstract``.

        The abstract is truncated so that abstract + review fit in
        ``self.max_len`` tokens, and only the review tokens are scored. Falls back
        to :meth:`compute_crit` on the review alone when the review leaves no room
        for context, or when the logits and labels lengths disagree.

        Returns:
            ``(crit, ntokens, log_likelihood)``, as in :meth:`compute_crit`.

        Raises:
            AssertionError: if the sampling tokenizer tokenizes the review
                differently from the scoring tokenizer.
        """
        device = self.args.device
        # =========== Tokenise with SCORING MODEL ============
        rev_tok = self.scoring_tokenizer(review, truncation=True, return_tensors="pt", add_special_tokens=False).to(device)
        remaining_len = self.max_len - rev_tok.input_ids.size(1)
        if remaining_len <= 0:
            # The review alone fills the window: there is no room for context.
            return self.compute_crit(review)
        abs_tok = self.scoring_tokenizer(abstract, truncation=True, max_length=remaining_len, return_tensors="pt", add_special_tokens=False).to(device)
        labels = rev_tok.input_ids[:, 1:]
        with torch.no_grad():
            logits_score = self._context_logits(self.scoring_model, abs_tok.input_ids, rev_tok.input_ids)
            if logits_score.size(1) != labels.size(1):
                print(f"Warning: logits_score shape {logits_score.shape} doesn't match labels shape {labels.shape}")
                return self.compute_crit(review)
            if self.args.sampling_model_name == self.args.scoring_model_name:
                logits_ref = logits_score
            else:
                # =========== Tokenise with SAMPLING MODEL ============
                rev_tok_ref = self.sampling_tokenizer(review, truncation=True, return_tensors="pt", add_special_tokens=False).to(device)
                # Same requirement as compute_crit: both models must see the same tokens.
                # This also guarantees the review has the same length, so the same
                # context budget (remaining_len) applies.
                assert torch.equal(rev_tok_ref.input_ids[:, 1:], labels), "Tokenizer is mismatch."
                abs_tok_ref = self.sampling_tokenizer(abstract, truncation=True, max_length=remaining_len, return_tensors="pt", add_special_tokens=False).to(device)
                logits_ref = self._context_logits(self.sampling_model, abs_tok_ref.input_ids, rev_tok_ref.input_ids)
                if logits_ref.size(1) != labels.size(1):
                    print(f"Warning: logits_ref shape {logits_ref.shape} doesn't match labels shape {labels.shape}")
                    return self.compute_crit(review)
            # ============ SCORE COMPUTATION ============
            res = self.criterion_fn(logits_ref, logits_score, labels)
        return res['discrepancy'], labels.size(1), res['log_likelihood']

    # compute probability
    def compute_prob(self, review, abstract=None):
        """Score ``review`` and return ``(prob, crit, ntokens, log_likelihood)``.

        The abstract is used as context only when the module-level ``context`` flag
        is set and ``abstract`` contains non-whitespace text. Otherwise the review
        is scored alone. ``prob`` comes from ``compute_prob_norm`` with this
        detector's distribution parameters.
        """
        if context and abstract is not None and abstract.strip():
            crit, ntoken, log_likelihood = self.compute_crit_cxt(review, abstract)
        else:
            crit, ntoken, log_likelihood = self.compute_crit(review)
        params = self.classifier
        prob = compute_prob_norm(crit, params['mu0'], params['sigma0'], params['mu1'], params['sigma1'])
        return prob, crit, ntoken, log_likelihood
    
    
def run(args):
    """Score every selected review with Fast-DetectGPT and write the results to JSON.

    Behaviour is driven by the module-level flags ``context``, ``humanised``,
    ``test_only``, ``new_data``, ``human_only`` and ``conf``. The accumulated
    results are rewritten after each conference, so partial progress is kept on
    disk.

    Args:
        args: parsed CLI namespace (``sampling_model_name``,
            ``scoring_model_name``, ``device``, ``cache_dir``).

    Raises:
        ValueError: if ``conf`` is not ``'n'``, ``'o'`` or ``'b'``.
    """
    print("Loading FastDetectGPT detector...", flush=True)
    if context:
        print("Using context-aware FastDetectGPT detector", flush=True)
    if humanised:
        print("Using humanised reviews", flush=True)
    if test_only:
        print("Scoring only test set reviews", flush=True)
    if human_only:
        print("Scoring only human reviews", flush=True)

    conference_groups = {
        "o": ["acl_2017", "conll_2016", "iclr_2017"],
        "n": ["neurips_2013", "neurips_2014", "neurips_2015", "neurips_2016", "neurips_2017"],
    }
    conference_groups["b"] = conference_groups["o"] + conference_groups["n"]
    # Validate before loading the models so a bad flag fails immediately.
    if conf not in conference_groups:
        raise ValueError("conf argument must be 'n', 'o' or 'b'")
    conferences = conference_groups[conf]

    detector = FastDetectGPT(args)

    # Paper texts are only needed as scoring context.
    all_paper_text = {}
    if context:
        with open("data/all_paper_texts_intro_conclusion.json", "r", encoding="utf-8") as f:
            all_paper_text = json.load(f)

    # Load all data (humanised takes precedence over new_data).
    if humanised:
        data_path = "data/humanised_reviews_subset.json"
    elif new_data:
        data_path = "data/all_conferences_new_data.json"
    else:
        data_path = "data/all_conferences_final_data.json"
    with open(data_path, "r", encoding="utf-8") as f:
        all_data = json.load(f)

    # Determine output path
    suffix = ""
    suffix += "_cxt" if context else ""
    suffix += "_humanised_subset" if humanised else ""
    suffix += "_testonly" if test_only else ""
    suffix += "_humanonly" if human_only else ""
    prefix = "new_data" if new_data else "old_data"
    output_path = f"data/detector_scores/{prefix}_with_fast_detect_gpt{suffix}_{conf}_{args.sampling_model_name}_{args.scoring_model_name}.json"
    # Create the directory now rather than failing after a conference has been scored.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    all_result_data = {}
    for conference in conferences:
        data = all_data.get(conference, [])
        if not data:
            print(f"No data found for {conference}, skipping...", flush=True)
            continue
        print(f"Processing {conference}, {len(data)} items", flush=True)

        result_data = []
        for item in tqdm(data, desc="Scoring items"):
            text = item["text"]
            if test_only and item["set"] == "train":
                continue
            if human_only and item["category"] != "human":
                continue
            if context:
                # JSON object keys are always strings, so normalise the paper number
                # before the lookup (an int id would otherwise never match).
                context_text = (
                    all_paper_text.get(conference, {})
                    .get(item["set"], {})
                    .get(str(item["paper_number"]), {})
                    .get("full_text", "")
                )
                if context_text == "":
                    print(f"Context not found for {item['paper_number']} in {conference}, using only text", flush=True)
                prob, crit, ntokens, log_likelihood = detector.compute_prob(text, context_text)
            else:
                prob, crit, ntokens, log_likelihood = detector.compute_prob(text)
            result_data.append({
                "id": item["id"],
                "paper_number": item["paper_number"],
                "model": item["model"],
                "set": item["set"],
                "key": item["key"],
                "category": item["category"],
                "fast_detect_gpt": {
                    "prob": prob,
                    "crit": crit,
                    "ntokens": ntokens,
                    "model": f"{args.sampling_model_name}_{args.scoring_model_name}",
                },
                "log_likelihood": {
                    "value": log_likelihood,
                    "ntokens": ntokens,
                    # A text with no scored tokens has no per-token average.
                    "normalised": log_likelihood / ntokens if ntokens else None,
                    "model": args.scoring_model_name,
                },
            })
        all_result_data[conference] = result_data

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(all_result_data, f, indent=2)
        print(f"Saved results upto {conference} to {output_path}")


if __name__ == "__main__":
    # Command-line options: each one is a plain string with the default below.
    parser = argparse.ArgumentParser()
    cli_options = (
        ("--sampling_model_name", "falcon-7b"),
        ("--scoring_model_name", "falcon-7b-instruct"),
        ("--device", "cuda"),
        ("--cache_dir", "../cache"),
    )
    for flag, default_value in cli_options:
        parser.add_argument(flag, type=str, default=default_value)
    args = parser.parse_args()
    run(args)
