# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import torch
import torch.distributed as dist
from utils.synchronize import synchronize
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput


class MyHFConfig(PretrainedConfig):
    """Minimal Hugging Face configuration for the wrapped model.

    Only the vocabulary size and the three special token ids are handled
    explicitly; every other keyword argument is forwarded untouched to
    ``PretrainedConfig``.

    Args:
        vocab_size: Size of the model vocabulary. Defaults to ``None`` so the
            config can be instantiated without arguments, which
            ``PretrainedConfig`` does internally when computing the diff
            against default values (e.g. for ``repr`` / JSON export).
        pad_token_id: Id of the padding token, forwarded to the base class.
        bos_token_id: Id of the beginning-of-sequence token, forwarded to the
            base class.
        eos_token_id: Id of the end-of-sequence token, forwarded to the base
            class.
        **kwargs: Any further ``PretrainedConfig`` options.
    """

    # Declared on the class rather than on the instance: the base class reads
    # the class-level ``model_type`` when it serializes the configuration.
    model_type = "my_gpt"

    def __init__(self, vocab_size=None, pad_token_id=None, bos_token_id=None, eos_token_id=None, **kwargs):
        # A ``model_type`` arriving through kwargs (e.g. from a previously
        # serialized config) must not shadow the class attribute; the model
        # type of this config is always "my_gpt".
        kwargs.pop("model_type", None)
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size


class MyHFWrapper(PreTrainedModel, GenerationMixin):
    """Expose a project model through the Hugging Face generation interface.

    ``forward`` is a collective operation: rank 0 pads its batch to
    ``world_size * batch_size_fwd`` rows, scatters one equally sized chunk to
    every rank, each rank runs the wrapped model without gradients, and the
    logits are gathered back on rank 0. All ranks therefore have to call
    ``forward`` together, with ``input_ids`` of identical shape and dtype.

    NOTE(review): only rank 0 returns real logits; all other ranks return
    zeros. Callers that sample from the logits on every rank (for example
    ``generate`` with per-rank stopping) should confirm that the ranks cannot
    diverge.
    """

    # Declared on the class, as the Hugging Face base classes expect, so that
    # class-level machinery can see them as well as instances.
    config_class = MyHFConfig
    main_input_name = "input_ids"

    def __init__(self, model, batch_size_fwd, num_class, max_allowed_num_token, hf_config):
        """Wrap ``model`` for distributed, gradient-free inference.

        The default process group is expected to be initialized already,
        because the rank and world size are queried here.

        Args:
            model: Module called as ``model(x=..., y=None)``; presumably
                returns a logits tensor of shape
                ``(batch_size_fwd, num_token, num_class)`` - TODO confirm.
            batch_size_fwd: Number of rows each rank processes per forward.
            num_class: Size of the last logits dimension; used to build the
                zero placeholder returned on ranks other than 0.
            max_allowed_num_token: Largest accepted sequence length.
            hf_config: Configuration passed to ``PreTrainedModel``.
        """
        super().__init__(hf_config)
        # NOTE(review): the same module is registered under two names
        # (``model`` and ``inner_model``). Both are kept because outside code
        # may rely on either attribute; only ``inner_model`` is used here.
        self.model = model
        self.batch_size_fwd = batch_size_fwd
        self.num_class = num_class
        self.max_allowed_num_token = max_allowed_num_token
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.inner_model = model

    def forward(self, input_ids, **kwargs):
        """Compute logits for ``input_ids`` by splitting the batch over ranks.

        Args:
            input_ids: Tensor of shape ``(num_sample, num_token)`` with
                ``1 <= num_sample <= world_size * batch_size_fwd`` and
                ``1 <= num_token <= max_allowed_num_token``. Only the values
                held by rank 0 are used; the other ranks contribute shape,
                dtype and device only.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``CausalLMOutput`` whose ``logits`` have shape
            ``(num_sample, num_token, num_class)``: the gathered model output
            on rank 0 and an all-zero tensor on every other rank.

        Raises:
            AssertionError: If ``input_ids`` violates the size bounds above.
        """
        # In:  (num_sample, num_token); int64; contiguous; detached
        # Out: (num_sample, num_token, num_class); float32; contiguous; detached
        num_sample, num_token = input_ids.shape
        max_num_sample = self.world_size * self.batch_size_fwd
        assert 1 <= num_sample <= max_num_sample, (
            f"num_sample={num_sample} must be in [1, {max_num_sample}]"
        )
        assert 1 <= num_token <= self.max_allowed_num_token, (
            f"num_token={num_token} must be in [1, {self.max_allowed_num_token}]"
        )

        # Step 1: Synchronize across ranks
        synchronize()

        # Step 2: On rank 0, zero-pad the batch to a multiple of the world
        # size and split it into one chunk per rank.
        if self.rank == 0:
            # (world_size * batch_size_fwd, num_token)
            input_ids_padded = torch.zeros(
                size=(max_num_sample, num_token),
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            input_ids_padded[:num_sample] = input_ids
            scatter_list = list(input_ids_padded.chunk(self.world_size, dim=0))
        else:
            scatter_list = None

        # Every rank receives its (batch_size_fwd, num_token) chunk.
        local_input = torch.empty(
            size=(self.batch_size_fwd, num_token),
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        dist.scatter(local_input, scatter_list=scatter_list, src=0)

        # Step 3: Compute on each rank
        # (batch_size_fwd, num_token, num_class)
        with torch.no_grad():
            local_logits = self.inner_model(x=local_input, y=None)
        # Collectives need contiguous buffers; this is a no-op when the model
        # output is contiguous already.
        local_logits = local_logits.contiguous()

        # Step 4: Gather to rank 0
        if self.rank == 0:
            gather_list = [torch.empty_like(local_logits) for _ in range(self.world_size)]
        else:
            gather_list = None
        dist.gather(local_logits, gather_list=gather_list, dst=0)

        # Step 5: Construct the output
        if self.rank == 0:
            # Drop the rows that only existed as padding.
            # (num_sample, num_token, num_class)
            logits = torch.cat(gather_list, dim=0)[:num_sample].contiguous()
        else:
            # Other ranks return a zero placeholder of the same nominal shape.
            logits = torch.zeros(
                num_sample, num_token, self.num_class,
                dtype=local_logits.dtype, device=local_logits.device
            )
        out = CausalLMOutput(logits=logits)

        # Step 6: Synchronize across ranks
        synchronize()
        return out

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is forwarded; ``past_key_values`` and any other
        keyword arguments are ignored, so the full sequence is passed to
        ``forward`` each time.
        """
        return {"input_ids": input_ids}

    def _init_weights(self, *args, **kwargs):
        """Do nothing; this wrapper does not initialize any weights itself."""
        pass
