import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import logging
from typing import Optional

logger = logging.getLogger(__name__)  # Use getLogger (camelCase)


class MyGPT2(GPT2LMHeadModel):
    """GPT-2 LM-head model that mixes target-token embeddings with a learned matrix.

    Each sequence is treated as ``[input | target | last]``: the first
    ``input_length`` positions and the final position are passed through
    untouched, while the positions in between are left-multiplied by a
    Sinkhorn-normalised block of the learnable matrix ``P`` before the
    transformer stack and by the transpose of that block afterwards.

    NOTE(review): the transpose undoes the mixing exactly only when the
    normalised matrix is a true permutation matrix; ``sinkhorn`` does not
    guarantee that.
    """

    def __init__(self, config):
        """Build the underlying GPT-2 model and the learnable mixing logits.

        Args:
            config: Model configuration, forwarded to ``GPT2LMHeadModel``;
                ``config.n_positions`` is stored on the instance.
        """
        super().__init__(config)
        # Stored for reference only; it does not determine the size of ``P``.
        self.n_positions = config.n_positions
        # Number of leading positions that are never mixed.
        self.input_length = 20
        # Randomly initialised logits. ``nn.Parameter`` already requires grad,
        # so no explicit ``requires_grad_`` call is needed. The 20x20 size
        # caps the target segment at 20 positions.
        self.P = nn.Parameter(torch.randn(20, 20))
        self.relu = nn.ReLU()  # Not used by forward().
        # Matrix computed by the most recent forward() call.
        self.permutation = None

    @staticmethod
    def _mix_target_segment(hidden_states, mixing_matrix, input_length):
        """Left-multiply the target positions by ``mixing_matrix``.

        The first ``input_length`` positions and the last position of
        ``hidden_states`` are copied through unchanged.
        """
        prefix = hidden_states[:, :input_length, :]
        target = hidden_states[:, input_length:-1, :]
        last = hidden_states[:, -1:, :]
        # bmm needs one copy of the matrix per batch element.
        expanded = mixing_matrix.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
        return torch.cat([prefix, torch.bmm(expanded, target), last], dim=1)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        stage=None,
        gumbel_tau: float = 0.1,  # Get tau from args ideally
        gumbel_iters: int = 20,  # Get iters from args ideally
    ):
        """Run GPT-2 on embeddings whose target segment has been mixed by ``P``.

        Args:
            input_ids: Token ids; embedded with ``self.transformer.wte`` when
                ``inputs_embeds`` is not given.
            inputs_embeds: Pre-computed embeddings, unpacked as
                ``(batch, seq_len, hidden_dim)``.
            labels: Optional targets; when given, a shifted cross-entropy loss
                is computed from the final logits.
            output_attentions: Ignored; attentions are always requested.
            stage: Accepted for interface compatibility; not used.
            gumbel_tau: Accepted for interface compatibility; not used.
            gumbel_iters: Number of Sinkhorn normalisation iterations.
            Remaining arguments are forwarded to ``self.transformer`` as-is.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` resolves
            to true, otherwise a tuple ``(loss?, logits, *transformer_extras)``.

        Raises:
            ValueError: If the target segment is longer than ``P`` allows.
        """
        # Always requested; the original code notes callers need the
        # attentions for an entropy calculation on layer 1.
        output_attentions = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.transformer.wte(input_ids)
        _, seq_len, _ = inputs_embeds.shape

        input_length = self.input_length
        # Everything between the fixed input prefix and the final position.
        target_len = seq_len - input_length - 1
        permute_target = target_len > 0

        if permute_target:
            max_target_len = self.P.shape[0]
            if target_len > max_target_len:
                raise ValueError(
                    f"Target length ({target_len}) exceeds the size of P ({max_target_len}); "
                    f"sequences may be at most {input_length + max_target_len + 1} tokens long."
                )
            # Normalise exactly the block that matches the target segment so
            # the forward and inverse steps use the same square matrix.
            self.permutation = sinkhorn(self.P[:target_len, :target_len], n_iters=gumbel_iters)
            hidden_states_permuted = self._mix_target_segment(inputs_embeds, self.permutation, input_length)
        else:
            logger.warning(
                f"Sequence length ({seq_len}) is too short to separate input/target/EOS. Skipping permutation."
            )
            # Still refreshed (as before) so code reading ``self.permutation``
            # after a forward pass keeps finding a tensor.
            self.permutation = sinkhorn(self.P[:seq_len, :seq_len], n_iters=gumbel_iters)
            hidden_states_permuted = inputs_embeds

        transformer_outputs = self.transformer(
            input_ids=None,  # Embeddings are supplied instead
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=hidden_states_permuted,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states_from_transformer = transformer_outputs[0]

        # Map the target positions back with the transpose of the matrix used
        # on the way in; the logits are computed from the re-ordered states.
        hidden_states_final_order = hidden_states_from_transformer
        if permute_target:
            try:
                hidden_states_final_order = self._mix_target_segment(
                    hidden_states_from_transformer, self.permutation.T, input_length
                )
            except Exception as e:
                # Best-effort fallback kept from the original implementation.
                logger.error(
                    f"Error during inverse permutation of hidden states: {e}. Using permuted hidden states.",
                    exc_info=True,
                )
                hidden_states_final_order = hidden_states_from_transformer

        lm_logits_final = self.lm_head(hidden_states_final_order)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = lm_logits_final[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits_final,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits_final,
            past_key_values=transformer_outputs.past_key_values,
            # Hidden states exactly as the transformer returned them, i.e.
            # still in the mixed order when mixing was applied.
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )


class NselectGPT2(GPT2LMHeadModel):
    """GPT-2 LM-head model that selects one of a fixed family of permutations.

    Each sequence is treated as ``[input | target | last]``. The target
    positions are permuted by the candidate matrix picked by
    ``sample_perm_index`` from the learnable ``perm_logits`` before the
    transformer stack, and permuted back with the transpose afterwards.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        """Build the GPT-2 model, the candidate permutations and their logits.

        Args:
            config: Model configuration, forwarded to ``GPT2LMHeadModel``.
            **kwargs: Extra arguments forwarded to ``GPT2LMHeadModel``.
        """
        super().__init__(config, **kwargs)
        self.input_length = 20
        self.n_exp = 1
        # Candidate matrices are input_length x input_length, so the target
        # segment must have exactly that many positions to be permuted.
        self.Ps = make_perm_family(self.input_length, self.n_exp)
        # forward() moves the candidates to the embedding device again, so
        # this placement is only an initial default.
        self.Ps = self.Ps.to(self.transformer.wte.weight.device)
        for i, perm in enumerate(self.Ps):
            self.register_buffer(f"P{i}", perm)

        # Zero-initialised: every candidate starts with the same logit.
        self.perm_logits = nn.Parameter(torch.zeros(2**self.n_exp))
        self.permutation = None

        # Temperature settings; not read by forward() at the moment.
        self.initial_temperature = 1.0
        self.final_temperature = 0.05
        # Keep the final temperature strictly positive.
        self.final_temperature = max(self.final_temperature, 1e-6)
        if self.initial_temperature < self.final_temperature:
            logger.warning(
                f"Initial temperature ({self.initial_temperature}) is less than final temperature ({self.final_temperature}). Temperature will increase or stay at final_temperature."
            )

    @staticmethod
    def _permute_target_segment(hidden_states, perm, input_length):
        """Left-multiply the target positions by ``perm``.

        The first ``input_length`` positions and the last position of
        ``hidden_states`` are copied through unchanged.
        """
        prefix = hidden_states[:, :input_length, :]
        target = hidden_states[:, input_length:-1, :]
        last = hidden_states[:, -1:, :]
        # bmm needs one copy of the matrix per batch element.
        expanded = perm.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
        return torch.cat([prefix, torch.bmm(expanded, target), last], dim=1)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        global_step: Optional[int] = None,  # Learning step
        max_steps: Optional[int] = None,  # Total learning steps
    ):
        """Run GPT-2 on embeddings whose target segment has been permuted.

        Args:
            input_ids: Token ids; embedded with ``self.transformer.wte`` when
                ``inputs_embeds`` is not given.
            inputs_embeds: Pre-computed embeddings, unpacked as
                ``(batch, seq_len, hidden_dim)``.
            labels: Optional targets; when given, a shifted cross-entropy loss
                is computed from the final logits.
            output_attentions: Ignored; attentions are always requested.
            global_step: Accepted for interface compatibility; not used.
            max_steps: Accepted for interface compatibility; not used.
            Remaining arguments are forwarded to ``self.transformer`` as-is.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` when ``return_dict`` resolves
            to true, otherwise a tuple ``(loss?, logits, *transformer_extras)``.

        Raises:
            ValueError: If the target segment is non-empty but its length
                differs from the size of the candidate permutations.
        """
        # Always requested; the original code notes callers need the
        # attentions for an entropy calculation on layer 1.
        output_attentions = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.transformer.wte(input_ids)
        _, seq_len, _ = inputs_embeds.shape
        device = inputs_embeds.device

        input_len = self.input_length
        # Everything between the fixed input prefix and the final position.
        target_len = seq_len - input_len - 1
        permute_target = target_len > 0

        if permute_target:
            perm_size = self.Ps.shape[-1]
            if target_len != perm_size:
                raise ValueError(
                    f"Target length ({target_len}) does not match the permutation size ({perm_size}); "
                    f"expected a sequence of {input_len + perm_size + 1} tokens."
                )
            # Straight-through selection: the forward value is (numerically)
            # one-hot while gradients reach ``perm_logits`` through the softmax.
            selection, _ = sample_perm_index(self.perm_logits, tau=1.0, training=True)
            selected = torch.zeros((target_len, target_len), device=device)
            for weight, candidate in zip(selection, self.Ps.to(device)):
                selected += weight * candidate
            # Match the embedding dtype so bmm also works in reduced precision.
            self.permutation = selected.to(inputs_embeds.dtype)
            hidden_states_permuted = self._permute_target_segment(inputs_embeds, self.permutation, input_len)
        else:
            # Nothing to permute: skip building the matrix altogether so short
            # sequences reach this branch instead of failing earlier.
            logger.warning(
                f"Sequence length ({seq_len}) is too short to separate input/target/EOS. Skipping permutation."
            )
            self.permutation = None
            hidden_states_permuted = inputs_embeds

        transformer_outputs = self.transformer(
            input_ids=None,  # Embeddings are supplied instead
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=hidden_states_permuted,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states_from_transformer = transformer_outputs[0]

        # Undo the permutation on the target positions with the transpose of
        # the matrix used on the way in; logits use the re-ordered states.
        hidden_states_final_order = hidden_states_from_transformer
        if permute_target:
            try:
                hidden_states_final_order = self._permute_target_segment(
                    hidden_states_from_transformer, self.permutation.T, input_len
                )
            except Exception as e:
                # Best-effort fallback kept from the original implementation.
                logger.error(
                    f"Error during inverse permutation of hidden states: {e}. Using permuted hidden states.",
                    exc_info=True,
                )
                hidden_states_final_order = hidden_states_from_transformer

        lm_logits_final = self.lm_head(hidden_states_final_order)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = lm_logits_final[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits_final,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits_final,
            past_key_values=transformer_outputs.past_key_values,
            # Hidden states exactly as the transformer returned them, i.e.
            # still in the permuted order when a permutation was applied.
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )


# Sinkhorn-style row/column normalization (called from MyGPT2.forward)
def sinkhorn(P, n_iters=10):
    """Alternately rescale the rows and columns of ``P`` by their sums.

    Each iteration divides every row by its sum and then every column by its
    sum, adding a small epsilon to each sum to avoid dividing by zero. The
    input is used as-is (no exponentiation is applied), the arithmetic runs
    in float32, and the result is cast back to the dtype of ``P``.

    Args:
        P: 2-D tensor to normalise.
        n_iters: Number of row-then-column passes.

    Returns:
        The rescaled tensor, in the same dtype as ``P``.
    """
    eps = 1e-9
    balanced = P.float()
    for _ in range(n_iters):
        # dim=1 sums across each row, dim=0 down each column.
        for axis in (1, 0):
            totals = balanced.sum(dim=axis, keepdim=True)
            balanced = balanced / (totals + eps)
    return balanced.to(P.dtype)


def make_perm_family(L: int, n: int) -> torch.Tensor:
    """Build a stack of ``2**n`` permutation matrices of size ``L x L``.

    The first half of the stack holds cyclic shifts (shift 0 being the
    identity); the second half holds the same shifts followed by a reversal
    (shift 0 being the plain reversal, at index ``2**(n-1)``).

    Args:
        L: Side length of each permutation matrix.
        n: Exponent controlling the family size; requires ``2**(n-1) <= L``.

    Returns:
        Tensor of shape ``(2**n, L, L)``.
    """
    half = 2 ** (n - 1)
    assert half <= L, "n is too large (shifts overlap)."

    stride = max(1, L // half)  # Distance between consecutive shifts
    identity = torch.eye(L)
    family = []

    for reverse in (False, True):
        for s in range(half):
            offset = (s * stride) % L
            order = [(i + offset) % L for i in range(L)]
            if reverse:
                order = [L - 1 - j for j in order]
            # Selecting rows of the identity yields the one-hot permutation.
            family.append(identity[order])
    return torch.stack(family)


def sample_perm_index(logits, tau=1.0, training=True):
    """Pick the highest-scoring permutation candidate as a one-hot vector.

    No noise is added, so the choice is the deterministic argmax. In training
    mode the returned vector uses a straight-through estimator: its forward
    value is the hard one-hot, while gradients flow through
    ``softmax(logits / tau)``.

    Args:
        logits: ``(K,)`` learnable scores for K permutation candidates.
        tau: Temperature applied to the logits before the softmax
            (training mode only).
        training: Selects the straight-through path when true, the plain
            one-hot path otherwise.

    Returns:
        Tuple ``(one_hot, idx)`` of the selection vector and the chosen index.
    """
    if not training:
        idx = logits.argmax(dim=-1)
        one_hot = torch.zeros_like(logits).scatter_(0, idx, 1.0)
        return one_hot, idx

    probs = torch.softmax(logits / tau, dim=-1)  # Carries the gradient
    idx = probs.argmax(dim=-1)  # Candidate actually used
    one_hot = torch.zeros_like(probs)
    one_hot[idx] = 1.0
    # Straight-through: hard value in the forward pass, soft gradient backward.
    return (one_hot - probs).detach() + probs, idx