"""Custom trainer for binary token classification."""

import torch.nn.functional as F
from transformers import Trainer
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch
from .sampler import GroupedRandomSampler

class BinaryTokenRewardTrainer(Trainer):
    """Trainer for binary token classification with a grouped training sampler.

    Expects labels with values in {0, 1} and -100 for ignored positions.

    Two loss modes are supported, selected by ``group_softmax``:

    * ``group_softmax=False``: element-wise binary cross entropy with logits
      over every position whose label is not -100.
    * ``group_softmax=True``: a softmax cross entropy over the distinct
      "bins" of each group of ``group_size`` batch rows, plus one extra
      "abstain" class per group (see ``_group_softmax_loss``).

    Training batches are drawn through a ``GroupedRandomSampler`` built in
    the constructor.
    """

    def __init__(
        self,
        *trainer_args,
        config_args=None,
        num_completions_per_prompt: int = None,
        sampler_group_size: int = None,
        group_softmax: bool = False,
        sum_group_softmax: bool = False,
        group_size: int = None,
        **kwargs,
    ):
        """Initialise the underlying ``Trainer`` and build the grouped sampler.

        Args:
            *trainer_args: Positional arguments forwarded to ``Trainer``.
            config_args: Forwarded to ``GroupedRandomSampler`` as ``args``.
            num_completions_per_prompt: Forwarded to ``GroupedRandomSampler``.
            sampler_group_size: Required positive group size for the sampler.
                Must divide the training-set length and must not exceed
                ``per_device_train_batch_size``.
            group_softmax: Use the grouped softmax loss instead of BCE.
            sum_group_softmax: In grouped mode, sum (rather than average) the
                logits of rows that share a bin.
            group_size: Number of batch rows per group in grouped mode.
            **kwargs: Keyword arguments forwarded to ``Trainer``.

        Raises:
            ValueError: If there is no training dataset or
                ``sampler_group_size`` violates the constraints above.
        """
        super().__init__(*trainer_args, **kwargs)

        # Explicit checks instead of ``assert``: asserts vanish under ``-O``
        # and the default ``sampler_group_size=None`` used to surface as an
        # unhelpful TypeError from the modulo operator.
        if self.train_dataset is None:
            raise ValueError("BinaryTokenRewardTrainer requires a train_dataset")
        if sampler_group_size is None or sampler_group_size <= 0:
            raise ValueError(
                f"sampler_group_size must be a positive integer, got {sampler_group_size!r}"
            )
        num_examples = len(self.train_dataset)
        if num_examples % sampler_group_size != 0:
            raise ValueError(
                f"len(train_dataset)={num_examples} is not divisible by "
                f"sampler_group_size={sampler_group_size}"
            )
        if self.args.per_device_train_batch_size < sampler_group_size:
            raise ValueError(
                f"per_device_train_batch_size={self.args.per_device_train_batch_size} "
                f"is smaller than sampler_group_size={sampler_group_size}"
            )

        self.config_args = config_args
        self.num_completions_per_prompt = num_completions_per_prompt
        self.sampler_group_size = sampler_group_size
        self.group_softmax = group_softmax
        self.sum_group_softmax = sum_group_softmax
        self.group_size = group_size

        self.train_sampler = GroupedRandomSampler(
            n=num_examples,
            args=config_args,
            num_completions_per_prompt=num_completions_per_prompt,
            sampler_group_size=sampler_group_size,
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss for one batch.

        Dispatches to the BCE-with-logits loss or, when ``self.group_softmax``
        is set, to the grouped softmax loss. ``"labels"`` is removed from
        ``inputs`` before the forward pass.

        Args:
            model: Callable model; its output must expose ``.get("logits")``.
            inputs: Batch mapping; must contain ``"labels"`` (and, in grouped
                mode, ``"inputs_embeds"`` and ``"bin_idx"``).
            return_outputs: If true, return ``(loss, outputs)``.
            num_items_in_batch: Forwarded to the model when not ``None``.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # pop so that the model doesn't compute the loss inside forward pass
        labels = inputs.pop("labels")
        if self.group_softmax:
            loss, outputs = self._group_softmax_loss(model, inputs, labels, num_items_in_batch)
        else:
            loss, outputs = self._bce_token_loss(model, inputs, labels, num_items_in_batch)
        return (loss, outputs) if return_outputs else loss

    @staticmethod
    def _run_forward(model, inputs, num_items_in_batch):
        """Run the model and return ``(outputs, logits)`` with the last logit dim squeezed."""
        extra = {}
        if num_items_in_batch is not None:
            extra["num_items_in_batch"] = num_items_in_batch
        outputs = model(**inputs, **extra)
        # Remove last dimension for single logit binary classification
        logits = outputs.get("logits").squeeze(-1)
        return outputs, logits

    def _bce_token_loss(self, model, inputs, labels, num_items_in_batch):
        """Mean binary cross entropy with logits over positions whose label is not -100."""
        outputs, logits = self._run_forward(model, inputs, num_items_in_batch)
        loss_mask = labels != -100
        loss = F.binary_cross_entropy_with_logits(
            logits[loss_mask],
            labels[loss_mask].float(),
            reduction='mean',
        )
        return loss, outputs

    def _group_softmax_loss(self, model, inputs, labels, num_items_in_batch):
        """Softmax cross entropy over the bins of each group plus an abstain class.

        With ``B`` batch rows, ``G = self.group_size`` and ``N = B // G``, the
        model output is reshaped to ``[N, G + 1, T]``: for every group the
        first ``G`` rows are treated as class rows and row ``G`` as the
        abstain row (only its position-0 logit is used). Within a group,
        exactly ``G`` positions must have ``bin_idx != -100``; logits of
        positions sharing a bin are averaged (or summed when
        ``self.sum_group_softmax``), one logit per distinct bin is kept in
        first-occurrence order, and the abstain class is labelled positive
        iff no bin label is non-zero. The per-group loss is the mean cross
        entropy over all positive targets; group losses are summed, not
        averaged.

        Raises:
            ValueError: On a missing/invalid ``group_size``, a batch size not
                divisible by it, unexpected logits shape, or a group without
                exactly ``G`` valid positions.
        """
        G = self.group_size
        if G is None or G <= 0:
            raise ValueError(f"group_size must be a positive integer when group_softmax=True, got {G!r}")
        B, T, _ = inputs["inputs_embeds"].shape
        if B % G != 0:
            raise ValueError(f"batch size {B} is not divisible by group_size={G}")
        N = B // G

        outputs, logits = self._run_forward(model, inputs, num_items_in_batch)  # [B+N, T]
        if logits.shape[0] != B + N:
            raise ValueError(f"expected {B + N} logit rows (B + N), got {logits.shape[0]}")
        if logits.shape[1] != T:
            raise ValueError(f"expected logit sequence length {T}, got {logits.shape[1]}")

        logits = logits.reshape(N, G + 1, T)
        logits_class = logits[:, :G].reshape(N, G * T)  # [N, G*T]
        logits_abstain = logits[:, G]  # [N, T]
        bin_idx = inputs["bin_idx"].reshape(N, G * T)  # [N, G*T]
        labels = labels.reshape(N, G * T)  # [N, G*T]

        total_loss = 0.0
        for n in range(N):
            valid_pos = bin_idx[n] != -100  # [G*T]
            bin_idx_valid = bin_idx[n][valid_pos]  # [valid]
            logits_valid = logits_class[n][valid_pos]  # [valid]
            labels_valid = labels[n][valid_pos]  # [valid]
            if logits_valid.shape[0] != G:
                raise ValueError(
                    f"group {n} has {logits_valid.shape[0]} valid positions, expected {G}"
                )

            # same_bin[a, b] == 1 when positions a and b belong to the same bin.
            same_bin = (bin_idx_valid[:, None] == bin_idx_valid[None, :]).to(logits_valid.dtype)  # [valid, valid]
            if not self.sum_group_softmax:
                # Row-normalise so each position receives the mean of its bin.
                same_bin = same_bin / same_bin.sum(dim=-1, keepdim=True)
            pooled_logits = torch.einsum("ab,b->a", same_bin, logits_valid)  # [valid]

            # Index of the first occurrence of every bin, in order of
            # appearance. A single ``tolist()`` replaces one ``.item()``
            # host sync per element; dicts preserve insertion order.
            first_index = {}
            for i, bin_id in enumerate(bin_idx_valid.tolist()):
                first_index.setdefault(bin_id, i)

            bin_logits = [pooled_logits[i] for i in first_index.values()]
            bin_labels = [labels_valid[i] for i in first_index.values()]

            # Abstain class: logit from position 0 of the abstain row,
            # positive only when no bin carries a non-zero label.
            bin_logits.append(logits_abstain[n][0])
            bin_labels.append(~torch.any(torch.stack(bin_labels)))

            bin_logits = torch.stack(bin_logits)  # [num_unique_bins+1]
            bin_labels = torch.stack(bin_labels)  # [num_unique_bins+1]
            target_ids = torch.nonzero(bin_labels == 1)[:, 0]  # [num_1_labels]
            losses = [
                F.cross_entropy(bin_logits, target, reduction='mean')
                for target in target_ids
            ]
            total_loss = total_loss + torch.mean(torch.stack(losses))
        return total_loss, outputs

    # ---------------------- use standard dataloader ----------------------
    def get_train_dataloader(self):
        """Build a plain ``DataLoader`` over the training set driven by ``self.train_sampler``.

        The returned loader exposes ``set_epoch`` and forwards it to the
        sampler. Before returning, the sampler's epoch is set to
        ``self.state.epoch`` (or 0 when that is ``None``).
        """

        class SetEpochDataloader(DataLoader):
            def set_epoch(self, epoch: int):
                # Delegate so epoch changes reach the custom sampler.
                self.sampler.set_epoch(epoch)

        # Use the custom sampler
        train_dataloader = SetEpochDataloader(
            self.train_dataset,
            sampler=self.train_sampler,
            batch_size=self._train_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

        self.train_sampler.set_epoch(self.state.epoch if self.state.epoch is not None else 0)

        return train_dataloader

def create_optimizer(model, args):
    """
    Build an AdamW optimizer with per-group learning rates.

    Trainable parameters (``requires_grad`` is true) are split by name:
    names containing ``gating_lambdas``, names containing
    ``agent_embeddings``, and everything else. Non-empty buckets become
    parameter groups in the order: everything else, gating lambdas, agent
    embeddings. The default bucket uses ``args.lr``; the other two use
    ``args.gating_lr`` / ``args.agent_lr`` and fall back to ``args.lr``
    when those are ``None``.

    Raises:
        ValueError: If the model has no trainable parameters.
    """
    default_bucket, gating_bucket, agent_bucket = [], [], []

    # Sort every trainable parameter into exactly one bucket.
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        if 'gating_lambdas' in param_name:
            gating_bucket.append(tensor)
        elif 'agent_embeddings' in param_name:
            agent_bucket.append(tensor)
        else:
            default_bucket.append(tensor)

    # Only non-empty buckets become groups; learning-rate attributes are
    # read from ``args`` only for the buckets that are actually used.
    groups = []
    if default_bucket:
        groups.append({'params': default_bucket, 'lr': args.lr})
    if gating_bucket:
        gating_rate = args.lr if args.gating_lr is None else args.gating_lr
        groups.append({'params': gating_bucket, 'lr': gating_rate})
    if agent_bucket:
        agent_rate = args.lr if args.agent_lr is None else args.agent_lr
        groups.append({'params': agent_bucket, 'lr': agent_rate})

    # Ensure we have at least one parameter group
    if len(groups) == 0:
        raise ValueError("No trainable parameters found in the model")

    return AdamW(groups)

    