import dataclasses
import enum
import typing
import unittest

import torch

from sglang.srt.sampling.penaltylib.orchestrator import (
    BatchedPenalizerOrchestrator,
    _BatchedPenalizer,
    _BatchLike,
)


@dataclasses.dataclass
class MockSamplingParams:
    frequency_penalty: float = 0.0
    min_new_tokens: int = 0
    stop_token_ids: typing.List[int] = None
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.0


@dataclasses.dataclass
class MockTokenizer:
    eos_token_id: int


@dataclasses.dataclass
class MockReq:
    origin_input_ids: typing.List[int]
    sampling_params: MockSamplingParams
    tokenizer: MockTokenizer


class StepType(enum.Enum):
    INPUT = "input"
    OUTPUT = "output"


@dataclasses.dataclass
class Step:
    type: StepType
    token_ids: typing.List[int]
    expected_tensors: typing.Dict[str, torch.Tensor]
    # assume initial logits are all 1
    expected_logits: torch.Tensor


@dataclasses.dataclass
class Subject:
    sampling_params: MockSamplingParams
    # first step must be input, which will be converted to Req
    steps: typing.List[Step]
    eos_token_id: int = -1

    def __post_init__(self):
        if self.steps[0].type != StepType.INPUT:
            raise ValueError("First step must be input")

        # each steps should have the same expected_tensors.keys()
        for i in range(1, len(self.steps)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
                )

    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
        return set(self.steps[i].expected_tensors.keys())

    def to_req(self) -> MockReq:
        return MockReq(
            origin_input_ids=self.steps[0].token_ids,
            sampling_params=self.sampling_params,
            tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
        )


@dataclasses.dataclass
class Case:
    enabled: bool
    test_subjects: typing.List[Subject]

    def __post_init__(self):
        # each test_subjects.steps should have the same expected_tensors.keys()
        for i in range(1, len(self.test_subjects)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
                )

    def tensor_keys(self, i: int = 0) -> typing.List[str]:
        return set(self.test_subjects[i].tensor_keys())


class BaseBatchedPenalizerTest(unittest.TestCase):
    Penalizer: typing.Type[_BatchedPenalizer]
    device = "cuda"
    vocab_size = 5

    enabled: Subject = None
    disabled: Subject = None

    def setUp(self):
        if self.__class__ == BaseBatchedPenalizerTest:
            self.skipTest("Base class for penalizer tests")

        self.create_test_subjects()
        self.create_test_cases()

    def tensor(self, data, **kwargs) -> torch.Tensor:
        """
        Shortcut to create a tensor with device=self.device.
        """
        return torch.tensor(data, **kwargs, device=self.device)

    def create_test_subjects(self) -> typing.List[Subject]:
        raise NotImplementedError()

    def create_test_cases(self):
        self.test_cases = [
            Case(enabled=True, test_subjects=[self.enabled]),
            Case(enabled=False, test_subjects=[self.disabled]),
            Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
        ]

    def _create_penalizer(
        self, case: Case
    ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
        orchestrator = BatchedPenalizerOrchestrator(
            vocab_size=self.vocab_size,
            batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
            device=self.device,
            Penalizers={self.Penalizer},
        )

        return orchestrator, orchestrator.penalizers[self.Penalizer]

    def test_is_required(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_required())

    def test_prepare(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_prepared())

                if case.enabled:
                    for key, tensor in {
                        key: torch.cat(
                            tensors=[
                                subject.steps[0].expected_tensors[key]
                                for subject in case.test_subjects
                            ],
                        )
                        for key in case.tensor_keys()
                    }.items():
                        torch.testing.assert_close(
                            actual=getattr(penalizer, key),
                            expected=tensor,
                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                        )

                actual = orchestrator.apply(
                    torch.ones(
                        size=(len(case.test_subjects), self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                expected = torch.cat(
                    tensors=[
                        subject.steps[0].expected_logits
                        for subject in case.test_subjects
                    ],
                )
                torch.testing.assert_close(
                    actual=actual,
                    expected=expected,
                    msg=f"logits\nactual={actual}\nexpected={expected}",
                )

    def test_teardown(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                penalizer.teardown()

                for key in case.test_subjects[0].steps[0].expected_tensors.keys():
                    self.assertIsNone(getattr(penalizer, key, None))

    def test_filter(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                indices_to_keep = [0]
                orchestrator.filter(indices_to_keep=indices_to_keep)

                filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]

                if penalizer.is_required():
                    self.assertTrue(penalizer.is_prepared())
                    for key, tensor in {
                        key: torch.cat(
                            tensors=[
                                subject.steps[0].expected_tensors[key]
                                for subject in filtered_subjects
                            ],
                        )
                        for key in case.tensor_keys()
                    }.items():
                        torch.testing.assert_close(
                            actual=getattr(penalizer, key),
                            expected=tensor,
                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                        )

                actual_logits = orchestrator.apply(
                    torch.ones(
                        size=(len(filtered_subjects), self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                filtered_expected_logits = torch.cat(
                    tensors=[
                        subject.steps[0].expected_logits
                        for subject in filtered_subjects
                    ],
                )
                torch.testing.assert_close(
                    actual=actual_logits,
                    expected=filtered_expected_logits,
                    msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
                )

    def test_merge_enabled_with_disabled(self):
        enabled_test_case = self.test_cases[0]
        disabled_test_case = self.test_cases[1]

        orchestrator, penalizer = self._create_penalizer(enabled_test_case)
        theirs, _ = self._create_penalizer(disabled_test_case)

        orchestrator.merge(theirs)

        for key, tensor in {
            key: torch.cat(
                tensors=[
                    enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
                    disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
                ],
            )
            for key in enabled_test_case.tensor_keys()
        }.items():
            torch.testing.assert_close(
                actual=getattr(penalizer, key),
                expected=tensor,
                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
            )

    def test_cumulate_apply_repeat(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                max_step = max(len(subject.steps) for subject in case.test_subjects)
                for i in range(1, max_step):
                    orchestrator.filter(
                        indices_to_keep=[
                            j
                            for j, subject in enumerate(case.test_subjects)
                            if i < len(subject.steps)
                        ]
                    )

                    filtered_subjects = [
                        subject
                        for subject in case.test_subjects
                        if i < len(subject.steps)
                    ]

                    inputs: typing.List[typing.List[int]] = []
                    outputs: typing.List[typing.List[int]] = []
                    for subject in filtered_subjects:
                        step = subject.steps[i]
                        if step.type == StepType.INPUT:
                            inputs.append(step.token_ids)
                            outputs.append([])
                        else:
                            inputs.append([])
                            outputs.append(step.token_ids)

                    if any(inputs):
                        orchestrator.cumulate_input_tokens(inputs)

                    if any(outputs):
                        orchestrator.cumulate_output_tokens(outputs)

                    if penalizer.is_required():
                        self.assertTrue(penalizer.is_prepared())
                        for key, tensor in {
                            key: torch.cat(
                                tensors=[
                                    subject.steps[i].expected_tensors[key]
                                    for subject in filtered_subjects
                                ],
                            )
                            for key in case.tensor_keys()
                        }.items():
                            torch.testing.assert_close(
                                actual=getattr(penalizer, key),
                                expected=tensor,
                                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                            )

                    actual_logits = orchestrator.apply(
                        torch.ones(
                            size=(len(filtered_subjects), self.vocab_size),
                            dtype=torch.float32,
                            device=self.device,
                        )
                    )
                    filtered_expected_logits = torch.cat(
                        tensors=[
                            subject.steps[i].expected_logits
                            for subject in filtered_subjects
                        ],
                    )
                    torch.testing.assert_close(
                        actual=actual_logits,
                        expected=filtered_expected_logits,
                        msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
                    )
