import random
import torch
import string
import numpy

class ShuffledModularAddition:
    """
    A generator for long sequences of modular addition problems
    where the modulus is a random constant for the run, and where
    the assignment of symbols to numbers is also selected to be
    a fixed random map used for the whole run. Intended to
    mimic an in-context learning scenario.

    To test this class:

    from tasks.modular_addition import ShuffledModularAddition
    a = ShuffledModularAddition()
    print(a.stringify(a.sample_batch(batch_size=3, k_shots=10)))

    The output has one section per key of the batch ("inputs",
    "targets", "modulus" and "vocabulary").  Each row of "inputs"
    is rendered like this (illustrative):
    ^H+i=i,n+M=i,H+i=i,r+H=r,M+M=n,H+M=M,n+i=H,r+i=n,r+i=n,i+r=
    and the matching row of "targets" is the same run shifted one
    token to the left, so it ends with the final answer.
    """
    def __init__(self, max_number: int = 50, max_modulus: int = 10, zipf: bool = False,
            unshuffled_range: bool = False, start: bool = True, plus: bool = True,
            equals: bool = True, space: bool = True, seed: int = 42) -> None:
        """
        max_number: how many number symbols exist (token ids 0..max_number-1).
        max_modulus: largest modulus a run may draw.
        zipf: weight facts by 1 / (b + 3) on the second operand instead of uniformly.
        unshuffled_range: draw shuffled symbols only from ids >= max_modulus.
        start, plus, equals, space: whether '^', '+', '=' and ',' tokens are emitted.
        seed: seed shared by the Python and NumPy generators owned by this object.
        """
        # Shuffled symbols are sampled without replacement from
        # range(max_modulus or 0, max_number), so there must be room for them.
        assert max_modulus * (2 if unshuffled_range else 1) < max_number
        self.task_name = "shuffled_modular_addition"
        self.max_number = max_number
        self.max_modulus = max_modulus
        self.zipf = zipf
        self.unshuffled_range = unshuffled_range
        self.start = start
        self.plus = plus
        self.equals = equals
        self.space = space
        self.prng = random.Random(seed)
        self.nprng = numpy.random.RandomState(seed)

        # Setup the vocabulary: number symbols first, then the special tokens.
        digitletters = string.digits + string.ascii_letters
        self.vocab = {i: digitletters[i] for i in range(max_number)}
        self.predict_token_id = max_number
        self.plus_token_id = max_number + 1
        self.space_token_id = max_number + 2
        self.start_token_id = max_number + 3
        self.pad_token_id = max_number + 4
        self.vocab[self.predict_token_id] = "="
        self.vocab[self.plus_token_id] = "+"
        self.vocab[self.space_token_id] = ","
        self.vocab[self.start_token_id] = "^"
        self.vocab[self.pad_token_id] = " "
        self.vocab_size = len(self.vocab)

    def sample_batch(self, batch_size: int,
            k_shots: int = 600, hold_out: int = 0,
            multi_runs: bool = False,
            max_length: int = None,
            unshuffled: bool = False,
            fixed_modulus: int = None):
        """
        Sample batch_size token sequences and return them as a dict with
        "inputs" (all tokens but the last), "targets" (all tokens but the
        first), "modulus" and "vocabulary" (one entry per row).

        With multi_runs, each row packs several runs back to back and keeps
        the last max_length tokens; "modulus" and "vocabulary" then describe
        only the last run of each row.  Raises ValueError if multi_runs is
        requested without a positive max_length.
        """
        if multi_runs:
            if max_length is None or max_length < 1:
                raise ValueError("multi_runs=True requires a positive max_length")
            results = []
            for _ in range(batch_size):
                # A little random space before each run.
                cat_runs = [self.space_token_id] * self.prng.randint(0, 7)
                # Always sample at least one run so that m and v are defined
                # even when the padding alone already fills max_length.
                while True:
                    # Vary the number of shots according to a geometric
                    # distribution with mean k_shots.
                    k = int(self.nprng.geometric(1.0 / k_shots))
                    expression, m, v = self._sample_run(k, hold_out, unshuffled, fixed_modulus)
                    cat_runs.extend(expression)
                    if len(cat_runs) >= max_length:
                        break
                results.append((cat_runs[-max_length:], m, v))
            expressions, m, v = zip(*results)
        else:
            expressions, m, v = zip(*[self._sample_run(k_shots, hold_out, unshuffled, fixed_modulus)
                for _ in range(batch_size)])

        expressions = torch.tensor(expressions, dtype=torch.long)

        # Every token is a goal
        return {
            "inputs": expressions[:,:-1],
            "targets": expressions[:,1:],
            "modulus": m,
            "vocabulary": v
        }

    def _extend(self, sequence, *symbols):
        """
        Append symbols to sequence in place.  The strings '^', '+', '=' and
        ',' become their token ids, and are skipped when the matching
        constructor flag is off; anything else is appended unchanged.
        """
        for a in symbols:
            if a in ['^', '+', '=', ',']:
                if a == '^' and self.start: sequence.append(self.start_token_id)
                if a == '+' and self.plus: sequence.append(self.plus_token_id)
                if a == '=' and self.equals: sequence.append(self.predict_token_id)
                if a == ',' and self.space: sequence.append(self.space_token_id)
            else:
                sequence.append(a)

    def _sample_run(self, k_shots: int, hold_out: int = 0, unshuffled: bool = False, fixed_modulus: int = None):
        """
        Build one run of k_shots addition facts that share a modulus and a
        symbol assignment.  Returns (token list, modulus, vocab), where
        vocab[n] is the token id standing for the number n.

        When hold_out is positive, that many facts are removed from the
        pool and the last one removed is appended as the final exemplar.
        """
        # Sample a modulus
        if fixed_modulus is not None:
            modulus = fixed_modulus
        else:
            modulus = self.prng.randint(2, self.max_modulus)

        # Select a random vocabulary
        if unshuffled:
            vocab = list(range(modulus))
        else:
            minrange = self.max_modulus if self.unshuffled_range else 0
            vocab = self.prng.sample(range(minrange, self.max_number), modulus)

        # Create fact table; each fact is (a, b, (a + b) % modulus, weight).
        def weight(a, b):
            return 1.0 / (b + 3) if self.zipf else 1.0
        facts = [(a, b, (a + b) % modulus, weight(a, b))
                for a in range(modulus) for b in range(modulus)]

        # Hold out some facts
        held_out = None
        for _ in range(hold_out):
            held_out = facts.pop(self.prng.randrange(0, len(facts)))

        # Sample random facts, ensuring held out example is last
        exemplars = self.prng.choices(facts, k=k_shots, weights=[f[-1] for f in facts])
        if held_out is not None:
            exemplars.append(held_out)

        # Convert numbers to sequence
        sequence = []
        self._extend(sequence, '^')
        for a, b, c, _weight in exemplars:
            self._extend(sequence, vocab[a], '+', vocab[b], '=', vocab[c], ',')
        # Remove the trailing separator so the last token is the final
        # answer.  Only pop when a separator was actually emitted.
        if self.space and exemplars:
            sequence.pop()
        return sequence, modulus, vocab

    def _char(self, token):
        """Return the printable character for a token id, or '%' if unknown."""
        token = int(token)
        return self.vocab[token] if 0 <= token < len(self.vocab) else '%'

    @staticmethod
    def _is_nested(seq):
        """
        Tell whether seq holds sub-sequences rather than tokens.  Rows are
        inspected one by one so that ragged nesting (rows of different
        lengths) is accepted.
        """
        if hasattr(seq, 'ndim'):
            return seq.ndim > 1
        return any(isinstance(item, (list, tuple)) or numpy.ndim(item) > 0 for item in seq)

    def stringify(self, seq):
        """
        Render token ids as text.  A dict is rendered as one "key:" section
        per entry, nested sequences as one line per row, and a flat
        sequence as one character per token ('%' for unknown ids).
        """
        if isinstance(seq, dict):
            return ''.join([f'\n{k}:\n{self.stringify(v)}' for k, v in seq.items()])
        if self._is_nested(seq):
            return '\n'.join([self.stringify(d) for d in seq])
        return ''.join([self._char(i) for i in seq])

    def summarize(self, batches, predictions, accuracy, length=72):
        """
        Render the tail of a few rows of the first batch, inserting the
        predicted character in brackets after every '=' token.  Two rows
        (largest and smallest modulus) are shown when accuracy < 0.5,
        otherwise up to twelve.  Each row is cut to its last length
        characters and starts on a new line.
        """
        # Just summarize predictions within the first batch
        batch = batches[0]
        predictions = predictions[0]
        moduli = batch['modulus']
        large_index = moduli.index(max(moduli))
        small_index = moduli.index(min(moduli))
        # Summarize 2 or 12 samples of the test cases
        if accuracy < 0.5:
            indices = [large_index, small_index]
        else:
            indices = range(min(12, len(batch['inputs'])))
        lines = []
        for index in indices:
            inputs = batch['inputs'][index][-length:]
            targets = batch['targets'][index][-length:]
            pred = predictions[index][-length:]
            summary_chars = []
            for i, a in enumerate(inputs):
                summary_chars.append(self._char(a))
                if int(a) == self.predict_token_id:
                    summary_chars.append(f'[{self._char(pred[i])}]')
            lines.append('\n' + (''.join(summary_chars) + self._char(targets[-1]))[-length:])
        return ''.join(lines)

class ChainedModularAddition(ShuffledModularAddition):
    """
    Variant of ShuffledModularAddition in which a run is a single chained
    sum: a start symbol followed by repeated "+ increment = result" steps,
    each step starting from the previous result.
    """
    def _sample_run(self, k_shots: int, hold_out: int = 0, unshuffled: bool = False, fixed_modulus: int = None):
        """
        Build one chained run.  Returns (token list, modulus, vocab), where
        vocab[n] is the token id standing for the number n.

        A fact is a (position, increment) pair.  When hold_out is positive,
        that many facts are kept out of the random walk; the run then ends
        with one bridging step to the first operand of the first held-out
        fact, followed by that held-out fact itself.  Raises ValueError
        when the held-out facts leave the walk with no permitted step.
        """
        # Sample a modulus
        if fixed_modulus is not None:
            modulus = fixed_modulus
        else:
            modulus = self.prng.randint(3, self.max_modulus)
        weights = [(1.0 / (i + 3) if self.zipf else 1.0) for i in range(modulus)]

        # Select a random vocabulary
        if unshuffled:
            vocab = list(range(modulus))
        else:
            minrange = self.max_modulus if self.unshuffled_range else 0
            vocab = self.prng.sample(range(minrange, self.max_number), modulus)

        # Create fact table
        facts = [(a, b) for a in range(modulus) for b in range(modulus)]

        # Hold out some facts
        held_out = []
        for _ in range(hold_out):
            held_out.append(facts.pop(self.prng.randrange(0, len(facts))))
        held_set = set(held_out)
        final_operand = held_out[0][0] if held_out else None

        def bridge_is_held_out(position):
            # The step that leads from position to the first operand of the
            # final held-out fact must not itself be a held-out fact.
            if not held_out:
                return False
            return (position, (final_operand - position) % modulus) in held_set

        def permitted(position, inc, last_step):
            # A walk step may not use a held-out fact, and the last walk
            # step may not land where the bridging step would be held out.
            if (position, inc) in held_set:
                return False
            return not (last_step and bridge_is_held_out((position + inc) % modulus))

        walk_steps = k_shots - (2 if held_out else 0)

        current = self.prng.randrange(0, modulus)
        if walk_steps <= 0 and bridge_is_held_out(current):
            # No walk precedes the bridge, so the start itself must be safe.
            if all(bridge_is_held_out(p) for p in range(modulus)):
                raise ValueError("held-out facts leave no valid start position")
            while bridge_is_held_out(current):
                current = self.prng.randrange(0, modulus)

        # Random walk around the fact table, applying the zipf distribution
        # excepting held-out facts until the end
        sequence = []
        self._extend(sequence, '^', vocab[current])
        for step in range(walk_steps):
            last_step = step == walk_steps - 1
            # Check that rejection sampling below can terminate.
            if held_set and not any(permitted(current, i, last_step) for i in range(modulus)):
                raise ValueError("held-out facts leave no permitted step in the walk")
            while True:
                inc = self.prng.choices(range(modulus), weights=weights, k=1)[0]
                if permitted(current, inc, last_step):
                    break
            jumpto = (current + inc) % modulus
            self._extend(sequence, '+', vocab[inc], '=', vocab[jumpto])
            current = jumpto

        # If holding out some facts, end with a held-out fact
        if held_out:
            first_operand, held_inc = held_out[0]
            # Bridging step: move to the first operand of the held-out fact.
            inc = (first_operand - current) % modulus
            jumpto = (current + inc) % modulus
            self._extend(sequence, '+', vocab[inc], '=', vocab[jumpto])
            current = jumpto
            assert current == first_operand
            # Final step: the held-out fact itself.
            jumpto = (current + held_inc) % modulus
            self._extend(sequence, '+', vocab[held_inc], '=', vocab[jumpto])
        return sequence, modulus, vocab

class ModularAddition:
    """
    Single-problem modular addition: each example is two operands followed
    by a pad token, and the target is their sum modulo max_number.
    Number tokens are 0..max_number-1; token id max_number is the pad "_".
    """

    def __init__(self, max_number: int = 100) -> None:
        """max_number: count of number tokens, and the modulus of the sums."""
        self.task_name = "modular_addition"
        self.max_number = max_number

        # Setup the vocabulary
        self.vocab = {i: str(i) for i in range(max_number)}
        self.vocab[max_number] = "_"
        self.pad_token_id = max_number
        self.vocab_size = len(self.vocab)

    def sample_batch(self, batch_size: int):
        """
        Return {"inputs": (batch_size, 3) long tensor of [n, m, pad],
        "targets": (batch_size, 1) long tensor of [(n + m) % max_number]}.
        """
        expressions, results = self._sample_expressions_and_results(batch_size)

        # Extend the expressions with a single empty token to make space for the prediction
        expressions = [expr + [self.pad_token_id] for expr in expressions]

        return {
            "inputs": torch.tensor(expressions, dtype=torch.long),
            "targets": torch.tensor(results, dtype=torch.long)
        }

    def _sample_expressions_and_results(self, batch_size: int):
        """
        Draw batch_size operand pairs with the module-level random generator
        and return (expressions, results) as lists of [n, m] and [sum].
        """
        # Sample operands from [0, max_number).  The upper bound is exclusive
        # because token id max_number is the pad token, not a number.
        integer_n = [random.randrange(self.max_number) for _ in range(batch_size)]
        integer_m = [random.randrange(self.max_number) for _ in range(batch_size)]

        # Convert numbers to sequence
        expressions = [[n, m] for n, m in zip(integer_n, integer_m)]

        # Compute modular sums; reducing modulo max_number keeps every
        # answer inside the number tokens as well.
        results = [[(n + m) % self.max_number] for n, m in zip(integer_n, integer_m)]

        return expressions, results


class FewShotModularAddition:
    """
    Few-shot modular addition: each prompt holds k demonstrations
    "n m result ," that share one modulus, followed by a query "n m".
    Number tokens are 0..max_number-1; token id max_number is the pad "_"
    and max_number + 1 is the separator ",".
    """

    def __init__(self, max_number: int = 100) -> None:
        """max_number: count of number tokens, and the largest modulus drawn."""
        self.task_name = "few_shot_modular_addition"
        self.max_number = max_number

        # Setup the vocabulary
        self.vocab = {i: str(i) for i in range(max_number)}
        self.vocab[max_number] = "_"
        self.vocab[max_number + 1] = ","
        self.pad_token_id = max_number
        self.sep_token_id = max_number + 1
        self.vocab_size = len(self.vocab)

        # Candidate moduli for _sample_expressions_and_results; the same
        # range that sample_batch draws from.
        self.moduli = list(range(2, max_number + 1))

    def sample_batch(self, batch_size: int, k_shots: int = 3, use_autoregressive: bool = False):
        """
        Return {"inputs", "targets", "moduli"}.  Each row of "inputs" has
        4 * k_shots + 2 tokens; the matching row of "targets" is the same
        sequence followed by the answer, so it is one token longer.
        use_autoregressive is accepted but currently has no effect.
        """
        batch_inputs = []
        batch_targets = []

        # Sample a modulus for each prompt in the batch
        batch_moduli = [random.randint(2, self.max_number) for _ in range(batch_size)]

        # For each prompt in the batch
        for modulus in batch_moduli:
            input_sequence = []

            # Generate k_shot demonstrations all using the same modulus.
            # Operands come from [0, max_number): token id max_number is the
            # pad token, not a number.
            for _ in range(k_shots):
                n = random.randrange(self.max_number)
                m = random.randrange(self.max_number)
                input_sequence.extend([n, m, (n + m) % modulus, self.sep_token_id])

            # Generate target question using the same modulus
            target_n = random.randrange(self.max_number)
            target_m = random.randrange(self.max_number)
            input_sequence.extend([target_n, target_m])

            # The target sequence is the prompt followed by the answer
            target_sequence = input_sequence + [(target_n + target_m) % modulus]

            batch_inputs.append(input_sequence)
            batch_targets.append(target_sequence)

        return {
            "inputs": torch.tensor(batch_inputs, dtype=torch.long),
            "targets": torch.tensor(batch_targets, dtype=torch.long),
            "moduli": batch_moduli
        }

    def _sample_expressions_and_results(self, batch_size: int):
        """
        Draw batch_size independent problems, each with its own modulus
        taken from self.moduli (with replacement).  Returns
        (expressions, results, moduli) as lists of [n, m], [sum] and ints.
        """
        # Sample operands from [0, max_number)
        integer_n = [random.randrange(self.max_number) for _ in range(batch_size)]
        integer_m = [random.randrange(self.max_number) for _ in range(batch_size)]

        # Sample a modulus for each example; with replacement so that any
        # batch size works.
        moduli = random.choices(self.moduli, k=batch_size)

        # Convert numbers to sequence
        expressions = [[n, m] for n, m in zip(integer_n, integer_m)]

        # Compute modular sums using the corresponding modulus for each example
        results = [[(n + m) % modulus]
                for n, m, modulus in zip(integer_n, integer_m, moduli)]

        return expressions, results, moduli

