import os, sys
import torch
import logging
from utils import gather_log_probs, mask_hf_labels, masked_mean

LOG = logging.getLogger(__name__)
def es_sentiment(pre_logits, post_logits, raw_targets, same_sent_mask, q_mask=None, NULL_TOKEN=0):
    """Compute the sentiment edit-success score and its components.

    The score is the product of two factors:

    * ``z_sent``: sigmoid of (mean post-edit log-prob over positive samples
      minus mean post-edit log-prob over negative samples).
    * ``z_topic``: exp of (mean post-edit minus mean pre-edit log-prob over
      positive samples), capped at 1.

    Args:
        pre_logits: Model logits before the edit.
        post_logits: Model logits after the edit.
        raw_targets: Labels passed to ``mask_hf_labels`` to obtain the token
            mask and the gather targets.
        same_sent_mask: Boolean tensor selecting the "positive" samples; it is
            inverted with ``~`` to select the negatives and broadcast against
            the token mask via ``unsqueeze(-1)``.
            NOTE(review): presumably one entry per sequence -- confirm.
        q_mask: Optional token mask used in place of the one derived from
            ``raw_targets``.
        NULL_TOKEN: Unused; kept so existing callers keep working.

    Returns:
        Dict with keys ``acc_sent``, ``z_sent``, ``z_topic``, ``z_topic_raw``,
        ``correct_probs`` and ``wrong_probs``. Every value is a tensor.
    """
    with torch.no_grad():
        mask, targ = mask_hf_labels(raw_targets)

        # An explicit q_mask overrides the mask derived from the labels.
        token_mask = mask if q_mask is None else q_mask
        pos_mask = same_sent_mask.unsqueeze(-1) * token_mask
        neg_mask = (~same_sent_mask).unsqueeze(-1) * token_mask

        # Compute log likelihoods of pos/neg samples
        pre_edit_token_log_probs = gather_log_probs(pre_logits, targ)
        post_edit_token_log_probs = gather_log_probs(post_logits, targ)

        mean_pos_pre = masked_mean(pre_edit_token_log_probs, pos_mask)
        mean_pos_post = masked_mean(post_edit_token_log_probs, pos_mask)
        mean_neg_post = masked_mean(post_edit_token_log_probs, neg_mask)

        z_sent = (mean_pos_post - mean_neg_post).sigmoid()
        z_topic_raw = (mean_pos_post - mean_pos_pre).exp()
        # clamp (rather than the builtin min) keeps the result a tensor on the
        # same device even when the cap is hit; builtin min would return the
        # Python int 1 in that case.
        z_topic = z_topic_raw.clamp(max=1)

        es_sent = z_sent * z_topic
        return {
            "acc_sent": es_sent,
            "z_sent": z_sent,
            "z_topic": z_topic,
            "z_topic_raw": z_topic_raw,
            "correct_probs": mean_pos_post,
            "wrong_probs": mean_neg_post,
        }


# DEPRECATED
def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
    """Deprecated sentiment edit-success metric, retained for reference.

    The assertion on the first line fails on every call, so the computation
    below is only reachable when assertions are disabled.

    Args:
        pre_edit_probs: Per-sample probabilities before the edit.
        post_edit_probs: Per-sample probabilities after the edit.
        pos_mask: Boolean mask selecting the positive samples.
        eps: Small constant added to denominators.
        batch_size: Row width used when reshaping for the ranking score.

    Returns:
        Dict of the weight-based and rank-based success scores together with
        the intermediate statistics they are built from.
    """
    assert False, "No longer used"

    neg_mask = ~pos_mask
    pos_root = 1 / pos_mask.sum()
    neg_root = 1 / neg_mask.sum()

    # Geometric means of the probabilities within each group.
    post_pos_avg = post_edit_probs[pos_mask].prod() ** pos_root
    pre_pos_avg = pre_edit_probs[pos_mask].prod() ** pos_root
    post_neg_avg = post_edit_probs[neg_mask].prod() ** neg_root

    # Content score: post/pre ratio on the positives, capped at 1.
    content_score = post_pos_avg / (pre_pos_avg + eps)
    z_content = min(1., content_score)

    # Sentiment score, weighting form: 1 / (1 + neg/pos).
    z_sent_weight = 1 / (1 + post_neg_avg / (eps + post_pos_avg))

    # Sentiment score, ranking form: reorder each row's labels by descending
    # post-edit probability and compare against the labels sorted directly.
    labels = pos_mask.view(-1, batch_size).long()
    _, order = torch.sort(post_edit_probs.view(-1, batch_size), dim=-1, descending=True)
    achieved = torch.gather(labels, 1, order)
    ideal, _ = torch.sort(labels, dim=-1, descending=True)
    z_sent_rank = (achieved == ideal).float().mean()

    # Final success scores: geometric mean of content and sentiment terms.
    stats = {
        "acc_weight": (z_content * z_sent_weight) ** 0.5,
        "acc_rank": (z_content * z_sent_rank) ** 0.5,
        "rank_score": z_sent_rank,
        "weight_score": z_sent_weight,
        "content_score": content_score,
        "post_edit_probs": post_edit_probs.sum(),
        "pre_edit_probs": pre_edit_probs.sum(),
        "correct_probs": post_edit_probs[pos_mask].mean(),
        "wrong_probs": post_edit_probs[neg_mask].mean(),
    }
    return stats


# def sent_retain(pre_logits, post_logits, sent_mask, batch_size=20, eps=torch.finfo(torch.float32).eps):
#     pre_log_probs = pre_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
#     post_log_probs = post_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)

#     pre_batch = pre_probs.view(-1, batch_size)
#     post_batch = post_probs.view(-1, batch_size)
#     mask_batch = sent_mask.view(-1, batch_size)

#     stats = []
#     for pre, post, mask in zip(pre_batch, post_batch, mask_batch):
#         avg_pre = pre.prod() ** (1 / pre.numel())
#         avg_post = post.prod() ** (1 / post.numel())
#         z_avg = min(avg_pre / avg_post, avg_post / avg_pre)

#         post_neg_avg = post[~mask].prod() ** (1 / (~mask).sum())
#         post_pos_avg = post[mask].prod() ** (1 / mask.sum())

#         pre_neg_avg = pre[~mask].prod() ** (1 / (~mask).sum())
#         pre_pos_avg = pre[mask].prod() ** (1 / mask.sum())

#         post_neg_over_pos = post_neg_avg / (eps + post_pos_avg)
#         pre_neg_over_pos = pre_neg_avg / (eps + pre_pos_avg)
#         z_post = 1 / (1 + post_neg_over_pos)
#         z_pre = 1 / (1 + pre_neg_over_pos)

#         z_sent = min(z_post / z_pre, z_pre / z_post)

#         stats.append((z_avg * z_sent) ** 0.5)

#     return sum(stats) / len(stats)


# For zsRE and F-NLI
def retain_rate(pre_logits, post_logits, mask=None):
    """Return the fraction of examples whose prediction the edit left unchanged.

    A trailing singleton dimension on either logits tensor is squeezed away
    first. Two layouts are then supported:

    * 1-D logits (binary classification): the prediction is ``logit > 0``.
    * 3-D logits (sequence modeling): the prediction is the argmax over the
      last dimension, and an example counts as retained only if every
      position selected by ``mask`` is unchanged.

    Args:
        pre_logits: Logits before the edit.
        post_logits: Logits after the edit; same shape as ``pre_logits``.
        mask: Optional mask over sequence positions. Only used for 3-D
            logits; when omitted, every position is taken into account.

    Returns:
        The retain rate as a Python float.

    Raises:
        NotImplementedError: If the logits are neither 1-D nor 3-D after
            squeezing.
    """
    if pre_logits.shape[-1] == 1:
        pre_logits = pre_logits.squeeze(-1)
    if post_logits.shape[-1] == 1:
        post_logits = post_logits.squeeze(-1)

    assert pre_logits.shape == post_logits.shape
    # ``mask`` is optional, so only validate it when one was supplied.
    if mask is not None:
        assert pre_logits.shape[0] == mask.shape[0]

    if pre_logits.dim() == 1:
        # binary classification
        retain = ((pre_logits > 0) == (post_logits > 0)).float().mean()
    elif pre_logits.dim() == 3:
        # sequence modeling
        same_pred = pre_logits.argmax(-1) == post_logits.argmax(-1)
        if mask is None:
            # No mask: every position must be unchanged.
            retain = same_pred.all(-1).float().mean()
        else:
            match = same_pred * mask
            retain = (match.sum(-1) == mask.sum(-1)).float().mean()
    else:
        raise NotImplementedError(
            f"retain_rate supports 1-D or 3-D logits, got {pre_logits.dim()}-D"
        )

    return retain.item()
