import torch
import numpy as np
import typing
import dataclasses as dc
import itertools
import random
from typing import Any, TypeVar, Sequence, Literal, cast

from core.reasoning import CoT, Reasoner, RewardModel
from core.reasoning.evaluator import CumulatedReward
from core.utils.th import NestedTensorDict
from core.utils import iterate
from core.utils.buf import TokenBuffer
from core.reasoning.reflection.self_verify import KwReflCumrewEvaluator


@typing.overload
def randint_of_digit(d: int) -> int: ...
@typing.overload
def randint_of_digit(d: int, size: int) -> list[int]: ...
def randint_of_digit(d: int, size: int | None = None) -> int | list[int]:
    """sample a random integer with `d` digits"""
    if d == 1:
        low = 0
        high = 10
    elif d > 1:
        low = 10**(d - 1)
        high = 10**d
    else:
        raise ValueError("digit `d` must be no less than 1")
    
    out = np.random.randint(low, high, size=size)
    if size is None:
        out = np.random.randint(low, high)
        assert isinstance(out, int)
    else:
        out = np.random.randint(low, high, size=size).tolist()
        assert isinstance(out, list) and (not out or isinstance(out[0], int))
    return out


# Input of a multiplication problem: the pair of integer operands `(x, y)`.
type In = tuple[int, int]


@dc.dataclass
class MultCoT[Thought](CoT[In, Thought, int]):

    digits: In | None = None
    correct: bool | None = None

    def __post_init__(self):
        
        if not isinstance(self.input, tuple):
            self.input = tuple(self.input)

        x, y = self.input

        if self.digits is None:
            self.digits = len(str(x)), len(str(y))
        elif not isinstance(self.digits, tuple):
            self.digits = tuple(self.digits)
        if self.correct is None:
            self.correct = (self.outcome == x * y)


# Type variable standing for the thought type of a multiplication reasoner.
_MultThought = TypeVar('_MultThought')
# Generic alias for a `Reasoner` specialised to `In` operand pairs and `int`
# results (type-argument order presumably mirrors `CoT` — confirm against
# core.reasoning).
Multiplier = Reasoner[In, _MultThought, int]


def random_input(dxrange: Sequence[int] | int, dyrange: Sequence[int] | int | None = None):
    """Sample a random pair of operands `(x, y)`.

    Args:
        dxrange: Digit count of `x`, or a sequence of candidate digit counts
            from which one is picked with `random.choice`.
        dyrange: Same as `dxrange` but for `y`; when `None`, `dxrange` is
            reused for `y`.

    Returns:
        A tuple `(x, y)` where each operand is drawn by `randint_of_digit`
        with its chosen digit count.
    """

    def pick_digits(spec: Sequence[int] | int) -> int:
        # A bare int fixes the digit count; a sequence offers candidates.
        if isinstance(spec, int):
            return spec
        return random.choice(spec)

    # Both digit counts are chosen before any operand is sampled.
    x_digits = pick_digits(dxrange)
    y_digits = pick_digits(dxrange if dyrange is None else dyrange)
    return randint_of_digit(x_digits), randint_of_digit(y_digits)


def generate_instances(
    multiplier: Multiplier,
    d1: int,
    d2: int,
    n: int,
    verbose = 0,
) -> list[MultCoT]:
    """Run `multiplier` on `n` random `d1`-digit by `d2`-digit problems.

    Args:
        multiplier: Callable invoked with each `(x, y)` pair; its result is
            unpacked as the remaining positional arguments of `MultCoT`.
        d1: Digit count of the first operand.
        d2: Digit count of the second operand.
        n: Number of samples to generate.
        verbose: When at least 1, print a progress line (rewritten in place
            via a carriage return) plus a final summary line.

    Returns:
        The generated `MultCoT` samples, in generation order.
    """
    progress_fmt = "Generated %d / %d samples of %dx%d multiplication, with %d correct answers from the reasoner."
    show_progress = verbose >= 1

    # All first operands are drawn before any second operand.
    first_operands = randint_of_digit(d1, size=n)
    second_operands = randint_of_digit(d2, size=n)

    instances: list[MultCoT] = []
    hits = 0
    for done, pair in enumerate(zip(first_operands, second_operands, strict=True), start=1):
        cot = MultCoT(pair, *multiplier(pair))
        assert isinstance(cot, MultCoT)
        instances.append(cot)
        hits += 1 if cot.correct else 0
        if show_progress:
            print(progress_fmt % (done, n, d1, d2, hits), end='\r')

    if show_progress:
        print(progress_fmt % (n, n, d1, d2, hits))

    return instances


class MultEvaluator(KwReflCumrewEvaluator):
    """Evaluator that keys multiplication samples by their operand digit counts."""

    def _kwmap(self, **references):
        """Return the keys derived from the `digits` reference.

        The result is a pair: the digit counts formatted as "<dx>x<dy>", and
        the constant "all". How the base class consumes these keys is not
        visible here — presumably it aggregates per key; confirm against
        `KwReflCumrewEvaluator`.
        """
        digit_pair = references["digits"]
        assert isinstance(digit_pair, tuple)
        size_key = "%dx%d" % digit_pair
        return (size_key, "all")
