import torch
from collections import defaultdict


def compute_ktae_outcome_advantage_and_keytokens(token_level_rewards: torch.Tensor,
                                                    responses: torch.Tensor,
                                                    eos_mask: torch.Tensor,
                                                    index: torch.Tensor,
                                                    epsilon: float = 1e-6,
                                                    config=None):
    """Compute group-normalized outcome advantages plus per-token key-token weights.

    Rows sharing the same ``index`` entry form a group. Each row's summed reward
    is normalized with its group's statistics, ``(score - mean) / (std + epsilon)``.
    A group with a single response uses mean 0 and std 1.

    Each group also gets a 1-D weight table indexed by token id. For groups with
    more than one response the table comes from ``ComputeKeyTokens``; singleton
    groups get an all-zero table. Each response's token ids are looked up in its
    group's table. The result is
    ``normalized_score * eos_mask + token_weight * eos_mask``.

    Args:
        token_level_rewards: rewards summed over the last dim to get one score
            per row; the last dim also defines ``response_length``.
        responses: integer token ids, one row per sample. They are used both as
            indices into the weight table and as ``responses_ids`` for
            ``ComputeKeyTokens``. Expected shape is (bs, response_length) —
            confirm against the caller.
        eos_mask: mask multiplied into the output (and handed per group to
            ``ComputeKeyTokens``).
        index: per-row group identifier (e.g. a prompt uid). Torch scalar
            entries are converted with ``.item()`` so equal ids group together.
        epsilon: added to the group std to avoid division by zero.
        config: unused; kept for interface compatibility.

    Returns:
        Tuple ``(advantages, returns)``. Both elements are the same tensor.
    """
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)

    with torch.no_grad():
        bsz = token_level_rewards.shape[0]

        # Use plain Python keys: 0-d tensors hash by identity, so equal ids
        # would otherwise never land in the same group.
        keys = [index[i].item() if isinstance(index[i], torch.Tensor) else index[i]
                for i in range(bsz)]

        # Row indices per group, in first-seen order. Gathering once by index
        # avoids repeated torch.cat (quadratic) and keeps the same row order.
        rows_per_group = defaultdict(list)
        for row, key in enumerate(keys):
            rows_per_group[key].append(row)

        max_token_id = responses.max().item()

        if any(len(rows) > 1 for rows in rows_per_group.values()):
            # Project-local dependency; only needed when some group has >1 response.
            from compute_key_tokens import ComputeKeyTokens

        id2mean = {}
        id2std = {}
        id2key_token = {}
        for key, rows in rows_per_group.items():
            if len(rows) == 1:
                id2mean[key] = torch.tensor(0.0)
                id2std[key] = torch.tensor(1.0)
                # Token ids range over 0..max_token_id inclusive, so the lookup
                # table needs max_token_id + 1 entries to be indexable by every id.
                id2key_token[key] = torch.zeros([max_token_id + 1], device=responses.device)
                continue

            group_scores = torch.tensor([scores[r] for r in rows])
            id2mean[key] = torch.mean(group_scores)
            id2std[key] = torch.std(group_scores)
            group_responses = responses[rows]
            computer = ComputeKeyTokens(alpha=1.0,
                                        beta_ig=2.0,
                                        gamma_tf=1.0,
                                        top=1.0,
                                        bottom=-1.0,
                                        responses_ids=group_responses,
                                        mask=eos_mask[rows],
                                        rewards=group_scores,
                                        # NOTE(review): passed unchanged from the original;
                                        # confirm ComputeKeyTokens returns a table indexable
                                        # by max_token_id itself.
                                        max_token_num=max_token_id)
            id2key_token[key] = computer.get_key_tokens().to(group_responses.device)

        means = torch.tensor([id2mean[k] for k in keys], device=scores.device)
        stds = torch.tensor([id2std[k] for k in keys], device=scores.device)
        scores = (scores - means) / (stds + epsilon)

        # Look up each response's token ids in its group's weight table.
        all_weight = torch.stack([id2key_token[k][responses[i]] for i, k in enumerate(keys)])
        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + all_weight * eos_mask
    return scores, scores