

import os

import torch
import torch.distributed as dist

try:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
except ImportError:

    import sys

    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
finally:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

import verl.utils.torch_functional as verl_F

# torch.compile'd wrapper around verl's entropy_from_logits, compiled with dynamic shapes enabled.
# NOTE(review): not referenced anywhere else in this file — confirm whether it is still needed.
compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

def _env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean switch from the environment.

    Returns ``default`` when ``name`` is unset. Otherwise any value except an explicit false
    spelling ("", "0", "false", "no", "off"; case-insensitive, surrounding whitespace ignored)
    enables the switch.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to ``default`` when unset.

    Raises:
        ValueError: if the variable is set but is not a valid integer literal.
    """
    raw = os.environ.get(name)
    return default if raw is None else int(raw)


# Test knobs, overridable from the environment. os.environ values are always strings, so they are
# converted here: a raw "3" would break range()/comparisons, and a raw "0" would be truthy.
MAX_TEST_CASES = _env_int("MAX_TEST_CASES", 5)
VERIFY_TORCH_SELF = _env_flag("VERIFY_TORCH_SELF")
LOW_MEMORY = _env_flag("LOW_MEMORY")
LOW_MEMORY_DIV_FACTOR = _env_int("LOW_MEMORY_DIV_FACTOR", 16)

def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Single-device fp32 reference for a linear projection followed by log-probs and entropy.

    Args:
        hidden: Activations; all leading dims are flattened so the last dim is the hidden size.
        weight: Projection matrix, ``(hidden, vocab)`` or ``(vocab, hidden)``. It is used as-is when
            ``weight.size(0)`` equals the hidden size and transposed otherwise, so a square weight is
            interpreted as ``(hidden, vocab)``.
        labels: Target ids; flattened to 1-D.
        temperature: Logits are divided by this value before the softmax.
        reduction: Forwarded to ``cross_entropy`` and applied to the log-probs only; the entropy is
            always returned per token.

    Returns:
        ``(logprobs, entropy)`` where ``logprobs = -cross_entropy(..., reduction=reduction)`` (the
        per-token log-probability of the label when ``reduction="none"``) and ``entropy`` holds one
        fp32 value per flattened token.
    """
    # reshape() instead of view(): also accepts non-contiguous inputs such as transposed activations.
    if hidden.dim() > 2:
        hidden = hidden.reshape(-1, hidden.shape[-1])
    if labels.dim() > 1:
        labels = labels.reshape(-1)

    # Infer the weight layout from its shape (see docstring) and compute in fp32.
    weight_fp32 = weight.to(torch.float32)
    if weight.size(0) != hidden.size(1):
        weight_fp32 = weight_fp32.T
    logits = torch.matmul(hidden.to(torch.float32), weight_fp32) / temperature

    pd = torch.nn.functional.softmax(logits, dim=-1)
    # H = logsumexp(z) - sum_i p_i * z_i for temperature-scaled logits z.
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
    # cross_entropy yields -log p[label]; negate to obtain log-probabilities.
    logprobs = torch.neg(torch.nn.functional.cross_entropy(logits, labels, reduction=reduction))
    return logprobs, entropy

class TorchEntropyTP(torch.autograd.Function):
    """Pure-PyTorch tensor-parallel reference for log-probs and entropy of ``hidden @ weight.T``.

    Each rank holds a vocab shard ``weight`` of shape ``(local_vocab, hidden)``. The forward pass
    all-gathers every rank's fp32 logits (concatenated along the vocab axis in rank order) and
    computes full-vocab results. The backward pass uses a hand-written gradient and returns only
    this rank's contribution.
    """

    @staticmethod
    def forward(
        ctx,
        hidden: torch.Tensor,
        weight: torch.Tensor,
        labels: torch.Tensor,
        temperature: float,
        dist_process_group: torch.distributed.ProcessGroup,
    ):
        """Return ``(logprobs, entropy)``, one fp32 value per flattened token.

        Args:
            hidden: Activations; leading dims are flattened with ``view`` (so they must be viewable).
            weight: This rank's vocab shard, ``(local_vocab, hidden)``.
            labels: Target ids indexing the gathered (full) vocab; flattened to 1-D.
            temperature: Logits are divided by this value before the softmax.
            dist_process_group: Group whose ranks each hold one vocab shard.
        """

        ctx.original_hidden_shape = hidden.shape
        if len(hidden.shape) > 2:
            hidden = hidden.view(-1, hidden.shape[-1])
        if len(labels.shape) > 1:
            labels = labels.view(-1)

        # Local fp32 logits for this rank's vocab shard, temperature-scaled.
        logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T)
        logits /= temperature
        # Full-vocab buffer; rank i's logits land in columns [i * local_vocab, (i + 1) * local_vocab).
        whole_logits = torch.empty(
            (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),
            dtype=logits.dtype,
            device=logits.device,
        )
        whole_logits_ref = [
            whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]
            for i in range(dist.get_world_size(dist_process_group))
        ]
        # all_gather writes each rank's logits into its column slice of whole_logits.
        dist.all_gather(whole_logits_ref, logits, group=dist_process_group)

        # Entropy H = logsumexp(z) - sum_i p_i * z_i; entropy_b (= sum_i p_i * z_i) is reused in backward.
        pd = torch.nn.functional.softmax(whole_logits, dim=-1)
        entropy_a = torch.logsumexp(whole_logits, dim=-1)
        entropy_b = torch.sum(pd * whole_logits, dim=-1)
        entropy = entropy_a - entropy_b

        # cross_entropy yields -log p[label]; negate to obtain log-probabilities.
        logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none")
        logprobs = torch.neg(logprobs)

        # Keeps the full (tokens, full_vocab) fp32 logits alive until backward.
        ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)
        ctx.dist_process_group = dist_process_group
        ctx.temperature = temperature
        return logprobs, entropy

    @staticmethod
    def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):
        """Return gradients for ``(hidden, weight)``; labels, temperature and group get ``None``.

        ``d_hidden`` only sums over this rank's vocab shard, so it is a partial result that must be
        all-reduced across the group (the tests in this file do that after ``autograd.grad``).
        ``d_weight`` is the gradient of the local shard. Both are computed in fp32.
        """
        hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors
        dist_process_group = ctx.dist_process_group
        temperature = ctx.temperature
        # batch_size / hidden_size are unpacked but unused; vocab_size is the LOCAL shard size.
        batch_size, hidden_size = hidden.shape
        vocab_size, hidden_size = weight.shape
        rank = dist.get_rank(dist_process_group)

        # Recompute the softmax with max-subtraction for numerical stability.
        maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)
        exp_logits = torch.exp(whole_logits - maximum)
        accumulate = exp_logits.sum(dim=-1, keepdim=True)
        pd = exp_logits / accumulate

        # dH/dz_j = -p_j * (z_j - sum_k p_k z_k)
        d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))

        # d log p[label] / dz_j = onehot_j - p_j, written here as (-g) * (p - onehot).
        one_hot = torch.zeros_like(whole_logits)
        one_hot.scatter_(1, labels.unsqueeze(1), 1)
        g_logprobs = torch.neg(g_logprobs)
        d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)

        # Chain rule through z = logits / temperature.
        d_logits = d_logits_entropy + d_logits_logprobs
        d_logits /= temperature

        # Keep only this rank's vocab columns (same rank ordering as the forward all_gather).
        local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size]

        d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))
        d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))
        d_hidden = d_hidden.view(ctx.original_hidden_shape)

        return d_hidden, d_weight, None, None, None

# Functional entry point: run_torch_entropy_tp(hidden, weight, labels, temperature, group) -> (logprobs, entropy).
run_torch_entropy_tp = TorchEntropyTP.apply

class TestLinearCrossEntropy_TensorParallel:
    """Tensor-parallel correctness / performance harness for ``linear_cross_entropy``.

    Intended to run under torchrun with the NCCL backend, one GPU per process. The vocab dimension is
    sharded: every rank generates its own ``(vocab_size, hidden_size)`` weight shard, while
    ``hidden``, ``labels`` and the upstream gradients are broadcast from rank 0 so all ranks agree on
    them. Results are compared against the ``run_torch_entropy_tp`` reference.
    """

    # (batch_size, num_tokens, hidden_size, vocab_size) for each test case index.
    _CASES = (
        (1, 1937, 3584, 152064),
        (1, 2169, 896, 151936),
        (1, 1530, 2048, 32256),
        (1, 1388, 4096, 102400),
        (1, 8192, 4096, 102400),
    )

    def __init__(self):
        # torchrun exports LOCAL_RANK (the GPU index on this node). Bind the device before creating the
        # NCCL group; the global rank is only a valid device index on single-node runs.
        env_local_rank = os.environ.get("LOCAL_RANK")
        if env_local_rank is not None:
            torch.cuda.set_device(torch.device(f"cuda:{int(env_local_rank)}"))

        dist.init_process_group(backend="nccl")
        self.group = dist.group.WORLD

        # NOTE: despite its name, this is the rank within self.group; it selects this rank's vocab
        # shard when slicing full-vocab tensors and gates rank-0 printing.
        self.local_rank = dist.get_rank(self.group)
        self.world_size = dist.get_world_size(self.group)
        if env_local_rank is None:
            # Fallback for launchers that do not export LOCAL_RANK (single-node only).
            torch.cuda.set_device(torch.device(f"cuda:{self.local_rank}"))
        print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}")

    def initialize(self, test_case_idx: int, temperature: float = 1.5):
        """Select the test case (index into ``_CASES``) and the softmax temperature."""
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def shutdown(self):
        """Destroy the default process group created in ``__init__``."""
        dist.destroy_process_group()

    def cleanup(self):
        """Release Python garbage and cached CUDA blocks, then reset the peak-memory counter."""
        import gc

        # Collect first so tensors held only by unreachable objects are freed before the allocator
        # cache is emptied and the peak statistics are reset.
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    def generate_hyper(self):
        """Populate dtype and problem sizes for ``self.test_case_idx``.

        Raises:
            ValueError: if ``self.test_case_idx`` is not a valid index into ``_CASES``.
        """
        self.dtype = torch.bfloat16
        if not 0 <= self.test_case_idx < len(self._CASES):
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        self.batch_size, self.num_tokens, self.hidden_size, self.vocab_size = self._CASES[self.test_case_idx]
        # int() tolerates the module-level knobs arriving as raw environment strings.
        if LOW_MEMORY:
            self.vocab_size //= int(LOW_MEMORY_DIV_FACTOR)
        assert int(MAX_TEST_CASES) <= len(self._CASES), "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        """Return random ``(hidden, weight, labels)`` on the current CUDA device.

        ``hidden`` is ``(batch, tokens, hidden)`` and ``weight`` is this rank's
        ``(vocab_size, hidden)`` shard; both require grad. ``labels`` lie in ``[0, vocab_size)``.
        """
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        """Return random upstream gradients ``(g_entropy, g_logprobs)``, one value per flattened token."""
        num_flat_tokens = self.batch_size * self.num_tokens
        g_entropy = torch.empty((num_flat_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((num_flat_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    def verify_torch_itself(self, iterations: int = 5):
        """Check the TP reference against a single-device computation on the all-gathered weight."""
        self.cleanup()
        self.generate_hyper()

        for _ in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

            # Shared inputs come from rank 0; each rank keeps its own weight shard.
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            # Reassemble the full weight: shard r occupies rows [r * vocab_size, (r + 1) * vocab_size).
            whole_weight = torch.empty(
                (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device
            )
            whole_weight_views = [
                whole_weight[r * self.vocab_size : (r + 1) * self.vocab_size] for r in range(self.world_size)
            ]
            dist.all_gather(whole_weight_views, weight, group=self.group)
            whole_weight.requires_grad_()

            (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)
            (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)

            torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)

            g_entropy, g_logprobs = self.generate_backward_inputs()
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            (single_d_hidden, single_d_weight) = torch.autograd.grad(
                (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False
            )
            (tp_d_hidden, tp_d_weight) = torch.autograd.grad(
                (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )

            # The TP backward only returns this rank's partial d_hidden; sum the shards.
            dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(
                tp_d_weight,
                single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],
                atol=1e-2,
                rtol=1e-4,
            )

        if self.local_rank == 0:
            print("[PASS] torch TP correctness is verified")

    def check_torch_storage(self):
        """Report peak CUDA memory (MB, rank 0) of the torch TP reference's forward and backward."""
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
        torch.cuda.synchronize()
        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_tp_hidden, d_tp_weight) = torch.autograd.grad(
            (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB")
            print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB")

    @staticmethod
    def _average_ms(latencies: list[float]) -> float:
        """Mean latency in ms, skipping the first (warm-up) sample when more than one exists.

        Returns NaN for an empty list instead of dividing by zero.
        """
        samples = latencies[1:] if len(latencies) > 1 else latencies
        return sum(samples) / len(samples) if samples else float("nan")

    def verify_kernel_correctness(self, iterations: int = 5):
        """Compare the fused kernel with the torch TP reference and report average latencies (rank 0)."""
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = []
        torch_backward_latency = []
        kernel_forward_latency = []
        kernel_backward_latency = []

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        for _ in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
                hidden, weight, labels, self.temperature, "none", self.group
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)

            g_entropy, g_logprobs = self.generate_backward_inputs()
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            start_event.record()
            (torch_d_hidden, torch_d_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))

            dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            start_event.record()
            (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))

            dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)

        if self.local_rank == 0:
            print("\n[PASS]: Verified kernel forward & backward correctness.")

            print(
                f"[INFO]: Forward pass: Torch implementation average time: "
                f"{self._average_ms(torch_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: torch implementation average time: "
                f"{self._average_ms(torch_backward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Forward pass: Kernel implementation average time: "
                f"{self._average_ms(kernel_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: kernel implementation average time: "
                f"{self._average_ms(kernel_backward_latency):.2f} ms"
            )

    def check_kernel_storage(self):
        """Report peak CUDA memory (MB, rank 0) of the fused kernel's forward and backward."""
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
            hidden, weight, labels, self.temperature, "none", self.group
        )
        torch.cuda.synchronize()
        kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
            (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB")
            print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB")

if __name__ == "__main__":
    # Launch with torchrun, which exports WORLD_SIZE / RANK / LOCAL_RANK for every process. Checked
    # explicitly (no bare KeyError, and not stripped by `python -O`) so a plain launch gets this message.
    if int(os.environ.get("WORLD_SIZE", "1")) <= 1:
        raise SystemExit(
            "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to "
            "execute this script."
        )
    # Offset the seed by rank so each rank draws a different weight shard; shared tensors are
    # broadcast from rank 0 inside the test methods.
    torch.manual_seed(233376 + int(os.environ.get("RANK", 0)))

    test = TestLinearCrossEntropy_TensorParallel()
    try:
        # int() tolerates MAX_TEST_CASES arriving as a raw environment string.
        for test_case_idx in range(int(MAX_TEST_CASES)):
            print(f"[INFO] Running test case {test_case_idx}")
            test.initialize(test_case_idx)
            if VERIFY_TORCH_SELF:
                test.verify_torch_itself()
            test.check_torch_storage()
            test.verify_kernel_correctness()
            test.check_kernel_storage()
    finally:
        # Tear down the process group even when a check fails.
        test.shutdown()
