import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_score, recall_score, accuracy_score, f1_score
import tqdm as tqdm
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

# Shared cross-entropy criterion. reduction="none" keeps one loss value per
# target position so callers (get_perplexity, get_entropy_binoculars) can apply
# their own masking and averaging.
ce_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
# Softmax over the last dimension; get_entropy_binoculars uses it to turn
# logits into probabilities.
softmax_fn = torch.nn.Softmax(dim=-1)

def load_model(args):
    """Load the base model and, optionally, a distinct reference model.

    Args:
        args: Namespace-like object providing
            ``method`` ('RM' selects ``AutoModelForSequenceClassification``,
            anything else ``AutoModelForCausalLM``),
            ``device`` ('cpu', one CUDA index such as '0', or two
            comma-separated CUDA indices such as '0,1' for base/reference),
            ``base_model`` and ``ref_model`` (identifiers handed to
            ``from_pretrained``; ``ref_model`` may be ``None``).

    Returns:
        Tuple ``(base_model, ref_model, base_tokenizer, ref_tokenizer)``.
        ``ref_model`` and ``ref_tokenizer`` are ``None`` when ``args.ref_model``
        is ``None`` or equal to ``args.base_model``.

    Raises:
        ValueError: If ``args.device`` lists more than two CUDA indices.
    """
    if args.method == 'RM':
        model_cls = AutoModelForSequenceClassification
    else:
        model_cls = AutoModelForCausalLM

    if args.device == 'cpu':
        dm1, dm2 = 'cpu', 'cpu'
    else:
        device_ids = [part.strip() for part in args.device.split(',')]
        if len(device_ids) == 1:
            # A single index used to crash on tuple unpacking; host both
            # models on that one device instead.
            device_ids = device_ids * 2
        dm1, dm2 = [f'cuda:{idx}' for idx in device_ids]

    # Checkpoints whose name contains '9b' are loaded in bfloat16, others in float32.
    torch_dtype = torch.bfloat16 if '9b' in args.base_model else torch.float32

    base_model = model_cls.from_pretrained(args.base_model, device_map=dm1, torch_dtype=torch_dtype)
    base_model.eval()
    base_tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if base_tokenizer.pad_token_id is None:
        # Fall back to EOS so that padding is always possible.
        base_tokenizer.pad_token_id = base_tokenizer.eos_token_id

    ref_model, ref_tokenizer = None, None
    if args.ref_model is not None and args.ref_model != args.base_model:
        ref_model = model_cls.from_pretrained(args.ref_model, device_map=dm2, torch_dtype=torch_dtype)
        ref_model.eval()
        ref_tokenizer = AutoTokenizer.from_pretrained(args.ref_model)
        if ref_tokenizer.pad_token_id is None:
            ref_tokenizer.pad_token_id = ref_tokenizer.eos_token_id

    return base_model, ref_model, base_tokenizer, ref_tokenizer

def evaluate_detectrl(data, get_score, skip_fail=False):
    """Score every item in ``data`` and group the scores by label.

    Args:
        data: Iterable of dict-like items with a ``"text"`` and a ``"label"``
            entry; the label must be ``"human"`` or ``"llm"``.
        get_score: Callable mapping a text to a numeric criterion.
        skip_fail: When true, items for which ``get_score`` raises an
            ``Exception`` are counted, skipped, and reported via ``print``;
            when false the exception propagates to the caller.

    Returns:
        ``(predictions, eval_results)`` where ``predictions`` maps ``"human"``
        and ``"llm"`` to lists of finite scores, and ``eval_results`` lists
        every successfully scored item (each mutated in place to carry its
        score under ``"crit"``, including non-finite ones).

    Raises:
        ValueError: If an item carries a label other than "human" or "llm".
    """
    predictions = {'human': [], 'llm': []}
    eval_results = []
    num_fail = 0
    for item in tqdm.tqdm(data):
        text = item["text"]
        label = item["label"]
        try:
            item['crit'] = get_score(text)
        except Exception:
            # Only ordinary errors are skippable; KeyboardInterrupt/SystemExit
            # are not Exception subclasses and always propagate.
            if not skip_fail:
                raise
            num_fail += 1
            continue
        # result
        if label == "human":
            predictions['human'].append(item["crit"])
        elif label == "llm":
            predictions['llm'].append(item["crit"])
        else:
            raise ValueError(f"Unknown label {label}")
        eval_results.append(item)
    # Drop NaN/inf scores so downstream ROC computation does not choke on them.
    predictions['human'] = [i for i in predictions['human'] if np.isfinite(i)]
    predictions['llm'] = [i for i in predictions['llm'] if np.isfinite(i)]
    if skip_fail:  # lastde fails when text length is small
        print(f"number fail: {num_fail}")
    return predictions, eval_results

def get_roc_metrics(real_preds, sample_preds):
    """Compute ROC-based metrics, choosing the threshold via Youden's J.

    Args:
        real_preds: List of scores for the negative class (label 0).
        sample_preds: List of scores for the positive class (label 1).
            Both arguments are concatenated with ``+``.

    Returns:
        Tuple ``(roc_auc, optimal_threshold, confusion_matrix, precision,
        recall, f1, accuracy, tpr_at_low_fpr)``. The confusion matrix is a
        nested list, every other entry a Python float. ``tpr_at_low_fpr`` is
        the ROC curve linearly interpolated at FPR = 0.01 / 100.
    """
    y_true = [0] * len(real_preds) + [1] * len(sample_preds)
    scores = real_preds + sample_preds

    fpr, tpr, thresholds = roc_curve(y_true, scores)
    roc_auc = auc(fpr, tpr)

    # Youden's J statistic: the threshold maximising TPR - FPR.
    best_threshold = thresholds[np.argmax(tpr - fpr)]

    y_pred = [int(score >= best_threshold) for score in scores]
    tpr_at_fpr_0_01 = np.interp(0.01 / 100, fpr, tpr)

    return (
        float(roc_auc),
        float(best_threshold),
        confusion_matrix(y_true, y_pred).tolist(),
        float(precision_score(y_true, y_pred)),
        float(recall_score(y_true, y_pred)),
        float(f1_score(y_true, y_pred)),
        float(accuracy_score(y_true, y_pred)),
        float(tpr_at_fpr_0_01),
    )

def get_metrics(real_preds, sample_preds, optimal_threshold):
    """Compute classification metrics at a caller-supplied threshold.

    Args:
        real_preds: List of scores for the negative class (label 0).
        sample_preds: List of scores for the positive class (label 1).
        optimal_threshold: Scores greater than or equal to this value are
            predicted as class 1.

    Returns:
        Tuple ``(optimal_threshold, confusion_matrix, precision, recall, f1,
        accuracy)`` with the confusion matrix as a nested list and every other
        entry as a Python float.
    """
    y_true = [0] * len(real_preds) + [1] * len(sample_preds)
    scores = real_preds + sample_preds
    y_pred = [int(score >= optimal_threshold) for score in scores]

    return (
        float(optimal_threshold),
        confusion_matrix(y_true, y_pred).tolist(),
        float(precision_score(y_true, y_pred)),
        float(recall_score(y_true, y_pred)),
        float(f1_score(y_true, y_pred)),
        float(accuracy_score(y_true, y_pred)),
    )

def get_likelihood(logits, labels, return_sum=False):
    """Return the log-likelihood of ``labels`` under ``logits`` as a float.

    Args:
        logits: Tensor whose first dimension is 1 and whose last dimension
            indexes the vocabulary.
        labels: Tensor of token ids whose first dimension is 1; it is
            flattened and must supply one id per row of the flattened logits.
        return_sum: If true, sum the per-token log-probabilities instead of
            averaging them.

    Returns:
        Python float: mean (default) or sum of the per-token log-probabilities.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    vocab_size = logits.shape[-1]
    flat_logits = logits.view(-1, vocab_size)
    # Labels may live on another device than the logits; move them over.
    flat_labels = labels.view(-1).to(flat_logits.device)

    token_log_probs = F.log_softmax(flat_logits, dim=-1)
    picked = token_log_probs.gather(dim=-1, index=flat_labels.unsqueeze(-1)).squeeze(-1)

    reduced = picked.sum() if return_sum else picked.mean()
    return reduced.item()


def get_rank(logits, labels):
    """Return the negated mean 1-indexed rank of the label tokens.

    Each label's rank is its position in the descending ordering of the
    logits at the same timestep (1 = highest logit).

    Args:
        logits: Tensor with a leading dimension of 1 and the vocabulary last.
        labels: Tensor of token ids with a leading dimension of 1.

    Returns:
        Python float equal to ``-mean(rank)``.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    labels = labels.to(logits.device)
    # Locate each label inside the descending ordering of its logits row.
    ordering = logits.argsort(-1, descending=True)
    hits = (ordering == labels.unsqueeze(-1)).nonzero()
    assert hits.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {hits.shape}"

    timesteps = hits[:, -2]
    zero_based_ranks = hits[:, -1]

    # Exactly one hit per timestep, in order, is required.
    expected_steps = torch.arange(len(timesteps)).to(timesteps.device)
    assert (timesteps == expected_steps).all(), "Expected one match per timestep"

    one_based_ranks = zero_based_ranks.float() + 1
    return -one_based_ranks.mean().item()

def get_logrank(logits, labels):
    """Return the negated mean natural-log rank of the label tokens.

    Each label's rank is its 1-indexed position in the descending ordering of
    the logits at the same timestep.

    Args:
        logits: Tensor with a leading dimension of 1 and the vocabulary last.
        labels: Tensor of token ids with a leading dimension of 1.

    Returns:
        Python float equal to ``-mean(log(rank))``.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    labels = labels.to(logits.device)
    # Locate each label inside the descending ordering of its logits row.
    ordering = logits.argsort(-1, descending=True)
    hits = (ordering == labels.unsqueeze(-1)).nonzero()
    assert hits.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {hits.shape}"

    timesteps = hits[:, -2]
    zero_based_ranks = hits[:, -1]

    # Exactly one hit per timestep, in order, is required.
    expected_steps = torch.arange(len(timesteps)).to(timesteps.device)
    assert (timesteps == expected_steps).all(), "Expected one match per timestep"

    log_ranks = torch.log(zero_based_ranks.float() + 1)
    return -log_ranks.mean().item()

def get_entropy(logits, labels):
    """Return the mean entropy of the predictive distributions.

    Args:
        logits: Tensor with a leading dimension of 1 and the vocabulary last.
        labels: Tensor with a leading dimension of 1; only its shape is
            checked, its values are not used.

    Returns:
        Python float: entropy (in nats) of ``softmax(logits)`` averaged over
        all positions.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    per_position = -(probs * log_probs).sum(-1)
    return per_position.mean().item()

def get_perplexity(encoding, logits):
    """Return the attention-masked mean next-token cross-entropy per sequence.

    Note that no exponentiation is applied: despite the name, the value is the
    mean cross-entropy rather than its exponential.

    Args:
        encoding: Object exposing ``input_ids`` and ``attention_mask`` tensors.
        logits: Model logits aligned with ``encoding.input_ids``; the
            vocabulary is the last dimension.

    Returns:
        NumPy float array with one value per sequence.
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    pred_logits = logits[..., :-1, :].contiguous()
    target_ids = encoding.input_ids[..., 1:].contiguous()
    target_mask = encoding.attention_mask[..., 1:].contiguous()

    # CrossEntropyLoss expects the class dimension second, hence the transpose.
    token_ce = ce_loss_fn(pred_logits.transpose(1, 2), target_ids)
    mean_ce = (token_ce * target_mask).sum(1) / target_mask.sum(1)
    return mean_ce.to("cpu").float().numpy()

def get_entropy_binoculars(ref_logits, base_logits, encoding, pad_token_id):
    """Return the per-sequence mean cross-entropy between two models.

    The soft targets are ``softmax(ref_logits)`` and the predictions are
    ``base_logits``; positions whose input id equals ``pad_token_id`` are
    excluded from the average.

    Args:
        ref_logits: Logits providing the target distribution.
        base_logits: Logits being scored against that distribution.
        encoding: Object exposing an ``input_ids`` tensor.
        pad_token_id: Token id treated as padding.

    Returns:
        NumPy float array with one value per sequence.
    """
    vocab_size = ref_logits.shape[-1]
    seq_len = base_logits.shape[-2]

    # Flatten to (positions, vocab) so every position is one classification.
    target_probs = softmax_fn(ref_logits).view(-1, vocab_size)
    flat_scores = base_logits.view(-1, vocab_size)

    cross_entropy = ce_loss_fn(input=flat_scores, target=target_probs).view(-1, seq_len)
    keep_mask = (encoding.input_ids != pad_token_id).type(torch.uint8)

    masked_mean = (cross_entropy * keep_mask).sum(1) / keep_mask.sum(1)
    return masked_mean.to("cpu").float().numpy()

def get_samples(logits, labels, nsamples=10000):
    """Sample alternative tokens at every position from ``softmax(logits)``.

    Args:
        logits: Tensor of shape ``[1, seq_len, vocab]``.
        labels: Tensor with a leading dimension of 1; only its shape is
            checked, its values are not used.
        nsamples: Number of independent samples drawn per position.

    Returns:
        Integer tensor of shape ``[1, seq_len, nsamples]`` holding token ids.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1
    lprobs = torch.log_softmax(logits, dim=-1)
    distrib = torch.distributions.categorical.Categorical(logits=lprobs)
    # sample() yields [nsamples, 1, seq_len]; move the sample axis to the end.
    samples = distrib.sample([nsamples]).permute([1, 2, 0])
    return samples


def get_sampling_discrepancy_fast_detect_gpt(logits_ref, logits_score, labels):
    """Return the sampling-based Fast-DetectGPT discrepancy of a text.

    The mean log-likelihood of the observed tokens under the scoring model is
    standardised by the mean and standard deviation of the mean
    log-likelihoods of token sequences sampled from the reference model.

    Args:
        logits_ref: Reference-model logits, shape ``[1, seq_len, vocab_ref]``.
        logits_score: Scoring-model logits, shape ``[1, seq_len, vocab_score]``.
            If the vocabulary sizes differ, both are cut to the smaller one.
        labels: Observed token ids, shape ``[1, seq_len]`` (or already
            ``[1, seq_len, 1]``).

    Returns:
        Python float discrepancy (stochastic: depends on the drawn samples).
    """
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1
    if logits_ref.size(-1) != logits_score.size(-1):
        # print(f"WARNING: vocabulary size mismatch {logits_ref.size(-1)} vs {logits_score.size(-1)}.")
        vocab_size = min(logits_ref.size(-1), logits_score.size(-1))
        logits_ref = logits_ref[:, :, :vocab_size]
        logits_score = logits_score[:, :, :vocab_size]

    lprobs_score = torch.log_softmax(logits_score, dim=-1)
    samples = get_samples(logits_ref, labels).to(lprobs_score.device)
    labels = labels.to(lprobs_score.device)
    labels = labels.unsqueeze(-1) if labels.ndim == logits_score.ndim - 1 else labels

    # Gather per-token log-probs here and keep them as tensors: the statistics
    # below need one mean log-likelihood per sampled sequence, which the
    # scalar-returning get_likelihood helper cannot provide.
    log_likelihood_x = lprobs_score.gather(dim=-1, index=labels).mean(dim=1)  # [1, 1]
    log_likelihood_x_tilde = lprobs_score.gather(dim=-1, index=samples).mean(dim=1)  # [1, nsamples]

    miu_tilde = log_likelihood_x_tilde.mean(dim=-1)
    sigma_tilde = log_likelihood_x_tilde.std(dim=-1)
    discrepancy = (log_likelihood_x.squeeze(-1) - miu_tilde) / sigma_tilde
    return discrepancy.item()

def get_sampling_discrepancy_analytic_fast_detect_gpt(logits_ref, logits_score, labels):
    """Return the analytic Fast-DetectGPT discrepancy of a text.

    The summed log-likelihood of the observed tokens under the scoring model
    is standardised by its expectation and variance under the reference
    distribution, both computed in closed form rather than by sampling.

    Args:
        logits_ref: Reference-model logits with a leading dimension of 1.
        logits_score: Scoring-model logits with a leading dimension of 1.
            If the vocabulary sizes differ, both are cut to the smaller one.
        labels: Observed token ids with a leading dimension of 1; a trailing
            axis is added when it has one dimension fewer than the logits.

    Returns:
        Python float discrepancy.
    """
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1

    ref_vocab, score_vocab = logits_ref.size(-1), logits_score.size(-1)
    if ref_vocab != score_vocab:
        # Vocabulary sizes differ: compare only the shared leading slice.
        shared_vocab = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared_vocab]
        logits_score = logits_score[:, :, :shared_vocab]

    if labels.ndim == logits_score.ndim - 1:
        labels = labels.unsqueeze(-1)

    score_logp = torch.log_softmax(logits_score, dim=-1)
    ref_p = torch.softmax(logits_ref, dim=-1)

    observed = score_logp.gather(dim=-1, index=labels).squeeze(-1)
    # Per-position mean and variance of the scoring log-prob under ref_p.
    expected = (ref_p * score_logp).sum(dim=-1)
    second_moment = (ref_p * torch.square(score_logp)).sum(dim=-1)
    variance = second_moment - torch.square(expected)

    numerator = observed.sum(dim=-1) - expected.sum(dim=-1)
    denominator = variance.sum(dim=-1).sqrt()
    return (numerator / denominator).mean().item()

def get_lrr(logits, labels):
    """Return the ratio of ``get_likelihood`` to ``get_logrank``.

    Args:
        logits: Tensor with a leading dimension of 1.
        labels: Tensor of token ids with a leading dimension of 1.

    Returns:
        Python float ``get_likelihood(...) / get_logrank(...)``.

    Raises:
        ZeroDivisionError: If ``get_logrank`` returns zero (both operands are
            plain Python floats).
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    return get_likelihood(logits, labels) / get_logrank(logits, labels)

# ============== lastde begin ==============
# https://github.com/TrustMedia-zju/Lastde_Detector/blob/main/py_scripts/baselines/lastde_doubleplus.py
def get_likelihood_lastde_doubleplus(logits, labels):
    """Return the per-token log-probabilities of ``labels`` as a tensor.

    Args:
        logits: Tensor with a leading dimension of 1 and the vocabulary last.
        labels: Token ids with a leading dimension of 1. A trailing axis is
            added when it has one dimension fewer than ``logits``; otherwise
            it is used as the gather index unchanged.

    Returns:
        Tensor of log-probabilities with the same shape as the gather index.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    needs_axis = labels.ndim == logits.ndim - 1
    index = labels.unsqueeze(-1) if needs_axis else labels
    return torch.log_softmax(logits, dim=-1).gather(dim=-1, index=index)

def get_sampling_discrepancy_lastde_doubleplus(logits_ref, logits_score, labels):
    """Return the Lastde++ discrepancy of a text.

    The Lastde statistic of the observed tokens is standardised by the mean
    and standard deviation of the same statistic over 100 token sequences
    sampled from the reference model.

    Args:
        logits_ref: Reference-model logits with a leading dimension of 1.
        logits_score: Scoring-model logits with a leading dimension of 1.
            If the vocabulary sizes differ, both are cut to the smaller one.
        labels: Observed token ids with a leading dimension of 1.

    Returns:
        Python float discrepancy (stochastic: depends on the drawn samples).
    """
    assert logits_ref.shape[0] == 1
    assert logits_score.shape[0] == 1
    assert labels.shape[0] == 1

    ref_vocab, score_vocab = logits_ref.size(-1), logits_score.size(-1)
    if ref_vocab != score_vocab:
        # Vocabulary sizes differ: compare only the shared leading slice.
        shared_vocab = min(ref_vocab, score_vocab)
        logits_ref = logits_ref[:, :, :shared_vocab]
        logits_score = logits_score[:, :, :shared_vocab]

    samples = get_samples(logits_ref, labels, nsamples=100)
    observed_ll = get_likelihood_lastde_doubleplus(logits_score, labels)
    sampled_ll = get_likelihood_lastde_doubleplus(logits_score, samples)

    # lastde
    observed_stat = get_lastde_doubleplus(observed_ll)
    sampled_stats = get_lastde_doubleplus(sampled_ll)

    z_score = (observed_stat - sampled_stats.mean()) / sampled_stats.std()
    return z_score.cpu().item()

def get_lastde_doubleplus(log_likelihood):
    """Return the Lastde statistic used by the Lastde++ detector.

    The statistic is the mean log-likelihood along dimension 1 divided by the
    multiscale diversity entropy from ``get_tau_multiscale_DE``, computed with
    ``embed_size=4``, ``epsilon=8 * length`` and ``tau_prime=15``.

    Args:
        log_likelihood: Tensor of per-token log-probabilities whose dimension
            1 is the token axis.

    Returns:
        Tensor ``mean(log_likelihood, dim=1) / aggregated_entropy``.
    """
    token_count = log_likelihood.shape[1]
    mean_ll = log_likelihood.mean(dim=1)

    aggregated_entropy = get_tau_multiscale_DE(
        ori_data=log_likelihood,
        embed_size=4,
        epsilon=int(8 * token_count),
        tau_prime=15,
    )
    return mean_ll / aggregated_entropy

def get_lastde(logits, labels):
    """Return the Lastde score of a text as a float.

    The score is the mean token log-likelihood divided by the multiscale
    diversity entropy of the token log-likelihood sequence.

    Args:
        logits: Tensor with a leading dimension of 1 and the vocabulary last.
        labels: Token ids with a leading dimension of 1; moved to the device
            of the log-probabilities before gathering.

    Returns:
        Python float Lastde score.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    if labels.ndim == logits.ndim - 1:
        labels = labels.unsqueeze(-1)
    token_logp = torch.log_softmax(logits, dim=-1)
    log_likelihood = token_logp.gather(dim=-1, index=labels.to(token_logp.device))
    mean_ll = log_likelihood.mean(dim=1)
    token_count = log_likelihood.shape[1]

    # Settings labelled "open-source" by the original authors. Their
    # "closed-source" alternative was embed_size=3, epsilon=1 * token_count,
    # tau_prime=15.
    settings = {"embed_size": 3, "epsilon": 10 * token_count, "tau_prime": 5}

    aggregated_entropy = get_tau_multiscale_DE(ori_data=log_likelihood, **settings)
    return (mean_ll / aggregated_entropy).cpu().item()

def histcounts(data, epsilon, min_=-1, max_=1):
    """Histogram ``data`` into ``epsilon`` equal-width bins over [min_, max_].

    Example (epsilon=10, range [-1, 1]):
        data  = [0.6055, 0.6987, 0.9244, 0.9308]
        hist  = [0, 0, 0, 0, 0, 0, 0, 0, 2, 2]
        probs = [0, 0, 0, 0, 0, 0, 0, 0, 0.5, 0.5]

    Args:
        data: Tensor of values to bin (a sequence of orbit cosine
            similarities in this module); cast to float32 first.
        epsilon: Number of bins.
        min_: Lower edge of the histogram range.
        max_: Upper edge of the histogram range.

    Returns:
        ``(hist, probs)``: the per-bin counts and the counts divided by their
        total.
    """
    values = data.float()
    counts = torch.histc(values, bins=epsilon, min=min_, max=max_)
    frequencies = counts / counts.sum()
    return counts, frequencies

def DE(statistical_probabilities_sequence, epsilon):
    """Return the Shannon entropy of bin probabilities, normalised by log(epsilon).

    Example (epsilon=10):
        probabilities [0, 0, 0, 0, 0, 0, 0, 0, 0.5, 0.5] give roughly 0.301.

    Args:
        statistical_probabilities_sequence: Tensor of bin probabilities; the
            sum runs over dimension 0.
        epsilon: Number of bins; its logarithm is the normaliser.

    Returns:
        Tensor holding ``-sum(p * log(p)) / log(epsilon)``.
    """
    normaliser = torch.log(torch.tensor(epsilon))
    p = statistical_probabilities_sequence
    # Empty bins yield 0 * log(0) = nan; nansum skips them so they contribute 0.
    plogp_total = torch.nansum(p * torch.log(p), dim=0)
    return -1 / normaliser * plogp_total

def calculate_DE(ori_data, embed_size, epsilon):
    """Return the diversity entropy (DE) of each series in ``ori_data``.

    Steps: slide a window of ``embed_size`` along dimension 1 to build
    "orbits", take the cosine similarity of each pair of consecutive orbits,
    histogram those similarities into ``epsilon`` bins over [-1, 1] separately
    for every index of the last dimension, then take the normalised entropy
    of each histogram.

    Example (one series, embed_size=3, epsilon=10):
        [1, 2, 13, 7, 9, 5, 4] has consecutive-orbit cosine similarities
        [0.6055, 0.6987, 0.9244, 0.9308], which fill two bins equally and
        give a DE of about 0.301.

    Args:
        ori_data: Tensor laid out as [1, token_length, samples_size] according
            to the original author's comments.
        embed_size: Orbit (window) length.
        epsilon: Number of histogram bins.

    Returns:
        Tensor with one DE value per index of the last dimension.
    """
    # [1, token_length, samples] -> [1, token_length - embed_size + 1, samples, embed_size]
    orbits = ori_data.unfold(1, embed_size, 1)
    earlier, later = orbits[:, :-1], orbits[:, 1:]
    # -> [1, token_length - embed_size, samples]
    similarities = F.cosine_similarity(earlier, later, dim=-1)

    # Histogram each sample's similarity sequence independently.
    hist_per_sample = torch.vmap(histcounts, in_dims=-1, out_dims=1)
    _, bin_probs = hist_per_sample(similarities, epsilon=epsilon)

    return DE(bin_probs, epsilon)

def get_tau_scale_DE(ori_data, embed_size, epsilon, tau):
    """Return the diversity entropy of ``ori_data`` coarse-grained at scale ``tau``.

    The series is replaced by the means of its sliding windows of length
    ``tau`` (stride 1) along dimension 1, and ``calculate_DE`` is applied to
    the result.

    Example (tau=2):
        [1, 2, 13, 7, 9, 5, 4] becomes [1.5, 7.5, 10.0, 8.0, 7.0, 4.5].

    Args:
        ori_data: 3-D tensor whose dimension 1 is the sequence axis.
        embed_size: Orbit length forwarded to ``calculate_DE``.
        epsilon: Number of histogram bins forwarded to ``calculate_DE``.
        tau: Sliding-window length of the coarse-graining.

    Returns:
        The value returned by ``calculate_DE`` for the coarse-grained series.
    """
    # unfold appends the window axis last, i.e. as dimension 3 for 3-D input.
    coarse_grained = ori_data.unfold(1, tau, 1).mean(dim=3)
    return calculate_DE(coarse_grained, embed_size, epsilon)

def get_tau_multiscale_DE(ori_data, embed_size, epsilon, tau_prime):
    """Return the standard deviation of diversity entropy across scales.

    ``get_tau_scale_DE`` is evaluated for every scale ``tau`` from 1 to
    ``tau_prime`` inclusive, and the standard deviation of those values is
    taken across the scales.

    Args:
        ori_data: Sequence tensor forwarded to ``get_tau_scale_DE``.
        embed_size: Orbit length.
        epsilon: Number of histogram bins over [-1, 1].
        tau_prime: Largest scale to evaluate.

    Returns:
        Tensor: ``torch.std`` over the scale axis of the stacked per-scale
        entropies.
    """
    per_scale = [
        get_tau_scale_DE(ori_data, embed_size, epsilon, tau)
        for tau in range(1, tau_prime + 1)
    ]
    stacked = torch.stack(per_scale, dim=0)
    # NOTE: the original authors also suggest trying torch.exp of this value.
    return torch.std(stacked, dim=0)

# ============== lastde end ==============