from transformers import Qwen2PreTrainedModel, Qwen2Model
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.cache_utils import Cache
from typing import List, Optional, Tuple, Union
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy


class Qwen2ForQSharp(Qwen2PreTrainedModel):
    """Qwen2 backbone with a two-layer MLP head (``score``) on the hidden states.

    The head maps each hidden state to ``config.num_labels`` logits. ``forward``
    has three modes (training, KV-cache update, continuation scoring); see its
    docstring. ``inference_impl`` selects how continuations are scored:
    ``"naive"`` runs one backbone pass per candidate, ``"efficient"`` scores all
    candidates in a single pass using a custom 4D attention mask.
    """

    def __init__(self, config):
        super().__init__(config)
        num_labels = config.num_labels
        self.model = Qwen2Model(config)
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels),
        )
        # The scoring head reuses the config's attention dropout probability.
        self.p_dropout = config.attention_dropout
        print(f"GOT DROPOUT {self.p_dropout}")
        self.score_dropout = nn.Dropout(self.p_dropout)

        # "naive" or "efficient"; anything else raises in forward().
        self.inference_impl = "naive"

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_mask: Optional[torch.Tensor] = None,
        continuation_ids: Optional[torch.LongTensor] = None,
        continuation_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """Run one of three modes, chosen from which arguments are provided.

        Training (``labels`` is not None):
            - input_ids: [bs, seqlen]
            - labels: one binary target per sequence, shape [bs] or [bs, 1];
              it is cast to the logits dtype and broadcast over seqlen.
            - loss_mask: [bs, seqlen]; per-token BCE is averaged over the
              masked positions of each sequence, then over the batch.
            Returns ``loss`` and per-token ``logits`` of shape [bs, seqlen].

        KV-cache update (``labels`` is None, ``input_ids`` is not None):
            ``continuation_ids`` must be None. ``input_ids`` is either the
            prompt [bs, seqlen] (prefill) or the continuation chosen at the
            previous step [bs, c_len]. Only ``past_key_values`` is returned.

        Continuation scoring (``labels`` and ``input_ids`` are both None):
            - continuation_ids: [bs, N, c_len]
            - continuation_attention_mask: [bs, N, c_len]
            - attention_mask: [bs, kv_len], covering everything already in
              the cache (prompt plus continuations chosen so far).
            Returns ``logits`` of shape [bs, N], taken at the last attended
            token of each continuation. The cache is cropped back afterwards.

        Raises:
            AssertionError: if ``return_dict`` resolves to False, or
                ``continuation_ids`` is passed in KV-cache update mode.
            NotImplementedError: if ``self.inference_impl`` is unknown.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert return_dict, "Only return_dict=True is supported."
        is_training = labels is not None
        is_update_kv_cache = input_ids is not None

        if is_training:
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]  # [bs, seqlen, hidden_dim]
            logits = self.score(self.score_dropout(hidden_states)).float()  # [bs, seqlen, 1]

            # use BCE loss for math
            logits = logits.squeeze(-1)  # [bs, seqlen]
            # Accept labels as [bs] or [bs, 1]. BCE needs a floating-point target,
            # so cast to the logits dtype before broadcasting over seqlen.
            labels_expanded = labels.reshape(logits.shape[0], 1).to(logits.dtype).expand_as(logits)
            loss = F.binary_cross_entropy_with_logits(logits, labels_expanded, reduction="none")  # [bs, seqlen]
            # avg over seqlen and bs
            # NOTE(review): a row whose loss_mask sums to 0 yields NaN here.
            loss = (loss * loss_mask).sum(1) / loss_mask.sum(1)
            loss = loss.mean()
            return SequenceClassifierOutputWithPast(loss=loss, logits=logits)
        elif is_update_kv_cache:
            assert continuation_ids is None
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            return SequenceClassifierOutputWithPast(past_key_values=transformer_outputs.past_key_values)
        else:
            bs, N, clen = continuation_ids.shape
            if self.inference_impl == "naive":
                classifier_logits = []
                for i in range(N):
                    next_input_ids = continuation_ids[:, i]
                    cur_attention_mask = torch.cat([attention_mask, continuation_attention_mask[:, i]], dim=1)
                    transformer_outputs = self.model(next_input_ids, attention_mask=cur_attention_mask,
                                                     past_key_values=past_key_values, use_cache=use_cache)
                    hidden_states = transformer_outputs[0]  # [bs, clen, hidden_size]
                    # take the last index where attn_mask is true
                    last_valid_index = torch.sum(continuation_attention_mask[:, i], dim=1) - 1  # [bs]
                    hidden_states = hidden_states[torch.arange(bs), last_valid_index]  # [bs, hidden_size]
                    logits = self.score(hidden_states)  # [bs, 1]
                    classifier_logits.append(logits)

                    # Undo the cache growth caused by this candidate so the next
                    # one starts from the same prefix.
                    # NOTE(review): relies on crop() with a negative argument
                    # dropping that many trailing entries - confirm for the
                    # installed transformers version.
                    if past_key_values is not None:
                        past_key_values.crop(-clen)

                logits = torch.cat(classifier_logits, dim=1)  # [bs, N]
                return SequenceClassifierOutputWithPast(logits=logits)

            elif self.inference_impl == "efficient":
                device = continuation_ids.device
                # This branch is only reached with input_ids=None, so there are
                # no new query tokens: use an empty [bs, 0] placeholder.
                input_ids = torch.zeros((bs, 0), dtype=torch.long, device=device)
                q_len = input_ids.shape[1]
                _, kv_len = attention_mask.shape  # includes query and any previous kv cache

                total_clen = N * clen
                # [bs, 1, q_len + total_clen, kv_len + total_clen]
                attn_mask_4d = create_4d_attn_mask(bs, q_len, kv_len, attention_mask, clen, N, continuation_attention_mask, device)

                # done with causal mask, now create other args
                input_ids_with_cont = torch.cat([input_ids, torch.flatten(continuation_ids, start_dim=1)], dim=1)  # [bs, q_len + total_clen]
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                assert past_seen_tokens == kv_len - q_len, f"past_seen_tokens: {past_seen_tokens}, kv_len: {kv_len}, q_len: {q_len}"
                cache_position = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=device)
                # Every candidate starts at the same position right after the prefix.
                cont_position_ids = torch.arange(past_seen_tokens + q_len, past_seen_tokens + q_len + clen, device=device).repeat(N)  # [total_clen]
                position_ids = torch.cat([cache_position, cont_position_ids]).unsqueeze(0)  # [1, q_len + total_clen]
                transformer_outputs = self.model(input_ids_with_cont, attention_mask=attn_mask_4d, position_ids=position_ids,
                                    past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position)

                hidden_states = transformer_outputs[0]  # [bs, q_len + total_clen, hidden_size]
                hidden_states = hidden_states[:, -total_clen:].view(bs, N, clen, -1)  # [bs, N, clen, hidden_size]
                last_valid_index = torch.sum(continuation_attention_mask, dim=2) - 1  # [bs, N]
                hidden_states = hidden_states[torch.arange(bs).unsqueeze(1), torch.arange(N).unsqueeze(0), last_valid_index]  # [bs, N, hidden_size]
                logits = self.score(hidden_states).squeeze(-1)  # [bs, N]

                # slice kv_cache by ignoring the continuations
                if past_key_values is not None:
                    past_key_values.crop(-total_clen)
                return SequenceClassifierOutputWithPast(logits=logits)

            else:
                raise NotImplementedError(f"Unknown inference_impl: {self.inference_impl}")


def create_causal_mask(d, device, inverted=True):
    """Build a [d, d] causal (lower-triangular) attention mask.

    Args:
        d: number of positions.
        device: device on which the mask is created.
        inverted: if True, return an additive float mask with 0 where
            attention is allowed (key index <= query index) and -inf
            elsewhere; if False, return a bool mask that is True where
            attention is allowed.

    Returns:
        Tensor of shape [d, d]; float when ``inverted`` else bool.
    """
    positions = torch.arange(d, device=device)
    # visible[i, j] is True when query i may attend to key j (j <= i).
    visible = positions.unsqueeze(0) <= positions.unsqueeze(1)
    if not inverted:
        return visible
    additive = torch.zeros((d, d), dtype=torch.float, device=device)
    return additive.masked_fill(~visible, float("-inf"))  # [d, d]


def invert_mask(mask):
    """Convert a keep-mask into an additive float attention mask.

    Args:
        mask: tensor whose nonzero/True entries mark positions that may be
            attended to. Bool masks as well as 0/1 integer or float masks
            are accepted.

    Returns:
        Float tensor of the same shape with 0 where ``mask`` is set and
        -inf everywhere else. The input is not modified.
    """
    # masked_fill only accepts boolean masks, so coerce 0/1 masks first.
    keep = mask.to(torch.bool)
    inverted_mask = torch.full_like(mask, fill_value=float("-inf"), dtype=torch.float)
    # out-of-place fill keeps this functional
    return inverted_mask.masked_fill(keep, 0)


def create_4d_attn_mask(bs, q_len, kv_len, attn_mask_2d, clen, N, attn_mask_cont, device):
    """Build the additive 4D mask used to score N continuations in one pass.

    Row layout is ``q_len`` query tokens followed by ``N * clen`` continuation
    tokens; column layout is ``kv_len`` context keys (cache plus query tokens)
    followed by the same ``N * clen`` continuation tokens.

    Rules encoded in the mask (0 = may attend, -inf = blocked):
      * query rows attend causally to non-padded context keys and never to
        continuation tokens;
      * continuation rows attend to every non-padded context key, and
        causally to tokens of their own continuation only.

    Padded context keys (``attn_mask_2d == 0``) are blocked for continuation
    rows as well as query rows, so padded cache slots are never attended to.

    Args:
        bs: batch size.
        q_len: number of new query tokens (0 when everything is cached).
        kv_len: number of context keys; must be >= ``q_len``.
        attn_mask_2d: [bs, kv_len] keep-mask over the context keys (nonzero =
            real token).
        clen: length of each continuation.
        N: number of continuations per batch element.
        attn_mask_cont: [bs, N, clen] keep-mask of the continuations. Only its
            shape is checked: with right-padded continuations the causal
            structure already stops real tokens from seeing the padding, and
            the caller reads the logit at the last real token.
        device: device on which the mask is created.

    Returns:
        Float tensor of shape [bs, 1, q_len + N * clen, kv_len + N * clen].
    """
    assert attn_mask_2d.shape == (bs, kv_len)
    assert attn_mask_cont.shape == (bs, N, clen)
    assert kv_len >= q_len, f"kv_len: {kv_len}, q_len: {q_len}"

    total_clen = N * clen
    key_is_valid = attn_mask_2d.to(torch.bool)[:, None, None, :]  # [bs, 1, 1, kv_len]

    # Query rows: the i-th query token sits at absolute position
    # (kv_len - q_len + i) and may see every context key up to itself.
    query_pos = torch.arange(kv_len - q_len, kv_len, device=device)
    key_pos = torch.arange(kv_len, device=device)
    query_to_ctx = key_pos[None, :] <= query_pos[:, None]  # [q_len, kv_len]
    query_to_ctx = query_to_ctx[None, None, :, :] & key_is_valid  # [bs, 1, q_len, kv_len]
    # Query tokens never attend to the continuations.
    query_to_cont = torch.zeros((bs, 1, q_len, total_clen), dtype=torch.bool, device=device)
    query_rows = torch.cat([query_to_ctx, query_to_cont], dim=3)  # [bs, 1, q_len, kv_len + total_clen]

    # Continuation rows: every real context key is visible.
    cont_to_ctx = key_is_valid.expand(bs, 1, total_clen, kv_len)
    # Within the continuations, attention is block-diagonal (one block per
    # continuation) and causal inside each block.
    tok = torch.arange(clen, device=device)
    within_causal = tok[None, :] <= tok[:, None]  # [clen, clen]
    same_cont = torch.eye(N, dtype=torch.bool, device=device)  # [N, N]
    cont_to_cont = same_cont[:, None, :, None] & within_causal[None, :, None, :]  # [N, clen, N, clen]
    cont_to_cont = cont_to_cont.reshape(total_clen, total_clen)
    cont_to_cont = cont_to_cont[None, None, :, :].expand(bs, 1, -1, -1)
    cont_rows = torch.cat([cont_to_ctx, cont_to_cont], dim=3)  # [bs, 1, total_clen, kv_len + total_clen]

    # Combine and convert the boolean keep-mask to an additive float mask.
    allowed = torch.cat([query_rows, cont_rows], dim=2)  # [bs, 1, q_len + total_clen, kv_len + total_clen]
    causal_mask = torch.zeros(allowed.shape, dtype=torch.float, device=device)
    return causal_mask.masked_fill(~allowed, float("-inf"))


def convert_binary_linear(linear_layer):
    """
    Convert a nn.Linear layer with 2 outputs to a single output layer
    for binary classification with sigmoid activation.
    Preserves device and dtype of the original layer.

    The new layer computes ``logit[1] - logit[0]`` of the original layer, so
    ``sigmoid(new(x))`` equals ``softmax(old(x))[..., 1]``. The original layer
    is left unmodified. Layers created with ``bias=False`` are supported and
    yield a bias-free layer.

    Args:
        linear_layer (nn.Linear): Linear layer with in_features -> 2

    Returns:
        nn.Linear: New linear layer with in_features -> 1

    Raises:
        ValueError: if ``linear_layer`` does not have exactly 2 outputs.
    """
    weights = linear_layer.weight.data  # Shape: [2, in_features]
    if weights.shape[0] != 2:
        raise ValueError(f"Expected 2 output features, got {weights.shape[0]}")

    has_bias = linear_layer.bias is not None

    # Create new linear layer with 1 output on the correct device/dtype
    new_layer = nn.Linear(linear_layer.in_features, 1, bias=has_bias)
    new_layer = new_layer.to(device=weights.device, dtype=weights.dtype)

    with torch.no_grad():
        # w = w2 - w1; slicing keeps the [1, in_features] shape
        new_layer.weight.copy_(weights[1:2] - weights[0:1])
        if has_bias:
            biases = linear_layer.bias.data  # Shape: [2]
            # b = b2 - b1
            new_layer.bias.copy_(biases[1:2] - biases[0:1])

    return new_layer