import random
import torch
import string
import numpy
from sympy.combinatorics import Permutation, PermutationGroup
from sympy.combinatorics import CyclicGroup, DihedralGroup

class GroupAddition:
    """
    A generator for long sequences of group addition problems
    where the group is randomly chosen for the run, and the
    the assignment of symbols to group elements is also selected
    to be a fixed random map used for the whole run. Intended to
    mimic an in-context learning scenario.
    """
    def __init__(self, num_symbols: int = 30, max_order: int = 20,
            holdout_zero: bool = False, form: str = 'chain', seed: int = 42) -> None:
        assert(max_order <= num_symbols)
        assert form in ['chain', 'facts']
        self.task_name = self._task_name()
        self.num_symbols = num_symbols
        self.max_order = max_order
        self.form = form
        self.holdout_zero = holdout_zero
        self.prng = random.Random(seed)

        # Setup the vocabulary
        digitletters = string.digits + string.ascii_letters
        self.vocab = [digitletters[i] for i in range(num_symbols)]
        self.predict_token_id = len(self.vocab)
        self.vocab.append('=')
        self.start_token_id = len(self.vocab)
        self.vocab.append('^')
        self.pad_token_id = self.start_token_id
        if form != 'chain':
            self.sep_token_id = len(self.vocab)
            self.vocab.append(',')
        self.vocab_size = len(self.vocab)
        self.numfor = { v: i for i, v in enumerate(self.vocab) }

    def _task_name(self):
        return 'group'

    def sample_batch(self, batch_size: int,
            k_shots: int = 341, hold_out: int | list = 0,
            commute_out: bool = False,
            multi_runs: bool = False,
            max_length: int = 1024,
            unshuffled: bool = False,
            fixed_group: PermutationGroup = None):
        '''
        Returns a batch of batch_size examples as tensors, each demonstrating
        a run of k_shots of group composition facts drawn from a group.

        If hold_out is set, then in each run the given number of facts is
        held out of each run (always holding out commutative inverses if
        commute_out is set), and each run ends with a held out sample.
        If hold_out is set to a number, then the specified number of
        held out examples will be randomly chosen. If hold_out is a
        list of pairs of integers, then the corresponding entries in the
        cayley tables will be held out, and the first entry of the list
        will be used as the final test question.

        If fixed_group is provided, the group used will be the one given;
        otherwise the group in each run is chosen randomly.

        If unshuffled is True, the vocabulary used is 01234... with 0 given to
        the identity element of the group; otherwise if it
        is false, the vocabulary is assigned randomly in each run.

        The return structure provides the group for each run, its order,
        and the vocabulary for each run, as well as a mask showing where
        all the "=" signs are.

        Calls sample_run to do the work of sampling individual runs.
        '''
        assert multi_runs == False

        expressions, g, o, v = zip(*[
            self.sample_run(k_shots, hold_out, commute_out, unshuffled, fixed_group)
                for _ in range(batch_size)])
        tensor = self.tensor_from_expression(expressions)

        # Every token is a goal
        return {
            "inputs": tensor[:,:-1],
            "targets": tensor[:,1:],
            "group": g,
            "order": o,
            "vocabulary": [''.join(voc) for voc in v],
            "prediction_mask": (tensor[:,:-1] == self.predict_token_id)
        }

    def tensor_from_expression(self, expressions):
        def recursive_numfor(e):
            if isinstance(e, (list, tuple)):
                return [recursive_numfor(el) for el in e]
            return [self.numfor[el] for el in e]
        return torch.tensor(recursive_numfor(expressions), dtype=torch.long)

    def sample_run(self, k_shots: int, hold_out: int | list = 0, commute_out: bool = False,
            unshuffled: bool = False, fixed_group = None):

        # Sample a random group
        if fixed_group is not None:
            G = fixed_group
        else:
            G = self._sample_group()
        order = G.order()
        assert order <= self.max_order or fixed_group is not None

        # Select a random vocabulary
        elems = list(G.generate())
        if unshuffled:
            vocab = self.vocab[:order]
        else:
            while True:
                vocab = self.prng.sample(self.vocab[:self.num_symbols], order)
                # Hold out the 1/num_symbols assignment of vocabularies where 0 is '0'.
                if not self.holdout_zero or vocab[0] != '0':
                    break
        wordfor = { g: vocab[i] for i, g in enumerate(elems) }

        # Create Cayley table
        facts = [(a, b) for a in elems for b in elems]

        # Hold out some facts
        held_out = []
        if isinstance(hold_out, int):
            while len(held_out) < hold_out:
                (a, b) = facts.pop(self.prng.randrange(0, len(facts)))
                held_out.append((a, b))
                if commute_out:
                    if (b, a) in facts:
                        facts.remove((b, a))
                        held_out.append((b, a))
        elif isinstance(hold_out, list):
            for (ai, bi) in hold_out:
                a, b = (elems[ai], elems[bi])
                if (a, b) in facts:
                    held_out.append((a, b))
                    facts.remove((a, b))
                if commute_out:
                    if (b, a) in facts:
                        facts.remove((b, a))
                        held_out.append((b, a))

        if self.form == 'chain':
            return self.sample_chain(k_shots, elems, wordfor, facts, held_out), G, order, vocab
        elif self.form == 'facts':
            return self.sample_facts(k_shots, elems, wordfor, facts, held_out), G, order, vocab

    def sample_chain(self, k_shots: int, elems: list, wordfor: dict, facts: list, held_out: list):
        '''
        Output pattern: ^ab=cd=fe=gh=i for a+b=c, c+d=f, f+e=g, g+h=i.
        '''

        # Random walk around the fact table, excepting held-out facts until the end
        current = elems[self.prng.randrange(0, len(elems))]
        sequence = ['^', wordfor[current]]
        for _ in range(k_shots - (2 if held_out else 0)):
            while True:
                inc = self.prng.choice(elems)
                if (current, inc) in facts:
                    break
            jumpto = current * inc
            sequence.extend([wordfor[inc], '=', wordfor[jumpto]])
            current = jumpto
        # If holding out some facts, end with a held-out fact
        if held_out:
            try:
                # If the algebra has inverses, use it
                inc = ~current * held_out[0][0]
            except:
                # If the algebra has left-division use it
                inc = current // held_out[0][0]
                # If none of the above, then chaining makes no sense
            jumpto = current * inc
            sequence.extend([wordfor[inc], '=', wordfor[jumpto]])
            current = jumpto
            assert current == held_out[0][0], (current, held_out[0][0])
            inc = held_out[0][1]
            jumpto = current * inc
            sequence.extend([wordfor[inc], '=', wordfor[jumpto]])
        return ''.join(sequence)

    def sample_facts(self, k_shots: int, elems: list, wordfor: dict, facts: list, held_out: list):
        '''
        Output pattern: ,ab=c,de=f,gh=i,jk=l for a+b=c, d+e=f, g+h=i, j+k=l
        '''
        sequence = []
        # Random samples around the fact table, excepting held-out facts until the end
        for _ in range(k_shots - (1 if held_out else 0)):
            a, b = self.prng.choice(facts)
            c = a * b
            sequence.extend([',', wordfor[a], wordfor[b], '=', wordfor[c]])
        # If holding out some facts, end with a held-out fact
        if held_out:
            a, b = held_out[0]
            c = a * b
            sequence.extend([',', wordfor[a], wordfor[b], '=', wordfor[c]])
        return ''.join(sequence)

    def stringify(self, seq):
        if isinstance(seq, dict):
            return ''.join([f'\n{k}:\n{self.stringify(v)}' for k, v in seq.items()])
        if isinstance(seq, str):
            return f'"{seq}"'
        if isinstance(seq, PermutationGroup):
            if seq.is_cyclic:
                return f'CyclicGroup({seq.order()})'
            elif seq.is_dihedral:
                return f'DihedralGroup({seq.order() // 2})'
            return str(seq)
        if isinstance(seq, int):
            return str(seq)
        if isinstance(seq, (list, tuple)) and len(seq) and (
                isinstance(seq[0], (PermutationGroup, int))):
            return ' '.join([self.stringify(i) for i in seq])
        if isinstance(seq, (list, tuple)):
            return '\n'.join([self.stringify(i) for i in seq])
        if numpy.ndim(seq) > 1:
            return '\n'.join([self.stringify(d) for d in seq])
        if numpy.ndim(seq) == 1:
            return ''.join([
                self.vocab[int(i)] if 0 <= int(i) < len(self.vocab) else '%'
                for i in seq])
        return str(seq)

    def summarize(self, batches, predictions, accuracy, length=72):
        # Just summarize predictions within the first batch
        batch = batches[0]
        predictions = predictions[0]
        large_order = max(batch['order'])
        small_order = min(batch['order'])
        large_index = batch['order'].index(large_order)
        small_index = batch['order'].index(small_order)
        def charfor(a):
            return self.vocab[int(a)] if 0 <= int(a) < len(self.vocab) else '%'
        summary = ''
        # Summarize 2 or 12 samples of the test cases
        for index in ([large_index, small_index] if accuracy < 0.5
                      else range(min(12, len(batch['inputs'])))):
            inputs = batch['inputs'][index][-length:]
            targets = batch['targets'][index][-length:]
            pred = predictions[index][-length:]
            # include lines of raw output
            # summary += ''.join(['\n' + self.stringify(d) for d in [inputs, targets, pred]])
            summary_chars = []
            for i, a in enumerate(inputs):
                summary_chars.append(charfor(a))
                if int(a) == self.predict_token_id:
                    summary_chars.append(f'[{charfor(pred[i])}]')
            summary += '\n' + (''.join(summary_chars) + charfor(targets[-1]))[-length:]
        return summary

class CyclicGroupAddition(GroupAddition):
    """
    Group addition task over cyclic groups only: every run draws a modulus
    uniformly from 3 up to and including max_order.
    """
    def _sample_group(self):
        smallest, largest = 3, self.max_order
        # randrange excludes its upper bound, hence the + 1
        return CyclicGroup(self.prng.randrange(smallest, largest + 1))

    def _task_name(self):
        return 'cyclic'

class DihedralGroupAddition(GroupAddition):
    """
    Group addition task over dihedral groups only: every run draws n
    uniformly from 2 up to and including max_order // 2 and uses
    DihedralGroup(n).
    """
    def _sample_group(self):
        largest_n = self.max_order // 2
        # randrange excludes its upper bound, hence the + 1
        n = self.prng.randrange(2, largest_n + 1)
        return DihedralGroup(n)

    def _task_name(self):
        return 'dihedral'

class RosetteGroupAddition(GroupAddition):
    """
    Group addition task over cyclic and dihedral groups together: one
    group is drawn uniformly from the combined list of CyclicGroup(3)
    through CyclicGroup(max_order) and DihedralGroup(2) through
    DihedralGroup(max_order // 2).
    """
    def _sample_group(self):
        cyclic_choices = self.max_order - 2          # moduli 3 .. max_order
        dihedral_choices = self.max_order // 2 - 1   # n = 2 .. max_order // 2
        pick = self.prng.randrange(cyclic_choices + dihedral_choices)
        # Low indices select a cyclic group, the remainder a dihedral one
        if pick >= cyclic_choices:
            return DihedralGroup(pick - cyclic_choices + 2)
        return CyclicGroup(pick + 3)

    def _task_name(self):
        return 'rosette'

def _unit_test():
    """
    Regression test: with the default seed, the generators must reproduce
    the exact batches recorded below. Raises AssertionError on the first
    mismatch, after printing the produced text and where it diverges.
    """
    import re
    def eqstring(a, b):
        """Compare a and b ignoring leading spaces on each line; print a diagnosis if they differ."""
        # Remove space at start of lines
        [a, _], [b, _] = [re.subn(r'((?<=\n)|^) *', '', s) for s in [a, b]]
        if len(a) != len(b):
            print(a)
            print(f'Difference in length {len(a)} vs {len(b)}')
        # Scan only the common prefix, so a shorter `a` cannot be indexed
        # past its end while looking for the first differing character
        for i in range(min(len(a), len(b))):
            if a[i] != b[i]:
                print(a)
                print(f'Difference at index: {i}: "{a[i:i+3]}" vs "{b[i:i+3]}"')
                break
        return a == b
    a = RosetteGroupAddition(max_order=4, num_symbols=5, holdout_zero=True)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=20, unshuffled=True, hold_out=1)), '''
        inputs:
        ^02=21=31=21=30=30=33=00=00=00=01=11=00=01=13=21=33=02=22=03=
        ^21=01=11=20=20=21=00=00=01=10=11=21=01=10=12=01=12=00=00=02=
        ^22=01=10=10=11=22=00=01=10=13=03=32=11=22=01=12=30=31=00=02=
        targets:
        02=21=31=21=30=30=33=00=00=00=01=11=00=01=13=21=33=02=22=03=3
        21=01=11=20=20=21=00=00=01=10=11=21=01=10=12=01=12=00=00=02=2
        22=01=10=10=11=22=00=01=10=13=03=32=11=22=01=12=30=31=00=02=2
        group:
        DihedralGroup(2) CyclicGroup(3) CyclicGroup(4)
        order:
        4 3 4
        vocabulary:
        "0123"
        "012"
        "0123"
        prediction_mask:
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001''')
    # Skip forward for a case where holdout_zero makes a difference
    a.stringify(a.sample_batch(batch_size=3, k_shots=20))
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=20)), '''
        inputs:
        ^42=33=14=44=12=23=42=33=13=31=34=24=31=32=41=44=14=41=41=41=
        ^32=22=12=31=13=11=21=33=33=33=32=21=31=11=21=31=12=33=32=22=
        ^30=11=43=34=31=04=04=04=04=03=13=01=31=01=33=41=14=13=01=34=
        targets:
        42=33=14=44=12=23=42=33=13=31=34=24=31=32=41=44=14=41=41=41=4
        32=22=12=31=13=11=21=33=33=33=32=21=31=11=21=31=12=33=32=22=1
        30=11=43=34=31=04=04=04=04=03=13=01=31=01=33=41=14=13=01=34=3
        group:
        DihedralGroup(2) CyclicGroup(3) DihedralGroup(2)
        order:
        4 3 4
        vocabulary:
        "1423"
        "312"
        "4301"
        prediction_mask:
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001''')
    a = DihedralGroupAddition(max_order=14)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=20)), '''
        inputs:
        ^30=08=i8=3l=lh=03=0l=t8=2i=t2=8i=3l=l4=d8=0t=4h=87=ld=41=p3=
        ^am=am=a8=dm=da=4a=d4=8a=6m=66=d8=a4=6m=68=4m=44=ma=a6=44=ma=
        ^rc=77=62=27=37=21=8c=1l=r8=9c=lc=om=33=6c=c2=6o=o8=r1=l2=9r=
        targets:
        30=08=i8=3l=lh=03=0l=t8=2i=t2=8i=3l=l4=d8=0t=4h=87=ld=41=p3=p
        am=am=a8=dm=da=4a=d4=8a=6m=66=d8=a4=6m=68=4m=44=ma=a6=44=ma=a
        rc=77=62=27=37=21=8c=1l=r8=9c=lc=om=33=6c=c2=6o=o8=r1=l2=9r=m
        group:
        DihedralGroup(7) DihedralGroup(3) DihedralGroup(6)
        order:
        14 6 12
        vocabulary:
        "30n87p4tlh2id1"
        "mda846"
        "6m21l79ro3c8"
        prediction_mask:
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001''')
    a = CyclicGroupAddition(max_order=13)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=20, unshuffled=True)), '''
        inputs:
        ^10=1b=c4=33=63=92=bb=91=aa=7b=58=01=19=a6=30=30=31=43=73=a8=
        ^08=83=bb=aa=8b=78=36=93=07=79=44=80=82=ab=96=35=84=02=23=55=
        ^03=30=32=12=32=10=13=00=03=30=32=12=31=00=00=01=12=30=31=00=
        targets:
        10=1b=c4=33=63=92=bb=91=aa=7b=58=01=19=a6=30=30=31=43=73=a8=5
        08=83=bb=aa=8b=78=36=93=07=79=44=80=82=ab=96=35=84=02=23=55=a
        03=30=32=12=32=10=13=00=03=30=32=12=31=00=00=01=12=30=31=00=0
        group:
        CyclicGroup(13) CyclicGroup(12) CyclicGroup(4)
        order:
        13 12 4
        vocabulary:
        "0123456789abc"
        "0123456789ab"
        "0123"
        prediction_mask:
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001
        0001001001001001001001001001001001001001001001001001001001001''')
    a = DihedralGroupAddition(max_order=10, form='facts', holdout_zero=True)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12)), '''
        inputs:
        ,7n=n,34=n,74=4,77=7,7n=n,4n=3,43=n,77=7,4n=3,34=n,43=n,3n=
        ,5o=0,to=d,0o=5,oo=p,5o=0,r5=t,ro=m,mp=m,rt=o,55=p,5t=d,tr=
        ,3n=3,nh=h,he=3,h3=e,eh=3,nh=h,ne=e,e3=h,he=3,nh=h,e3=h,n3=
        targets:
        7n=n,34=n,74=4,77=7,7n=n,4n=3,43=n,77=7,4n=3,34=n,43=n,3n=4
        5o=0,to=d,0o=5,oo=p,5o=0,r5=t,ro=m,mp=m,rt=o,55=p,5t=d,tr=5
        3n=3,nh=h,he=3,h3=e,eh=3,nh=h,ne=e,e3=h,he=3,nh=h,e3=h,n3=3
        group:
        DihedralGroup(2) DihedralGroup(4) DihedralGroup(2)
        order:
        4 8 4
        vocabulary:
        "74n3"
        "pr0ot5md"
        "neh3"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')
    a = CyclicGroupAddition(max_order=10, form='facts', holdout_zero=True)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12, unshuffled=True, hold_out=1)), '''
        inputs:
        ,30=3,11=2,10=1,10=1,03=3,30=3,02=2,23=1,30=3,33=2,21=3,00=
        ,00=0,00=0,01=1,03=3,03=3,20=2,21=3,00=0,20=2,03=3,23=1,31=
        ,64=1,84=3,40=4,00=0,22=4,61=7,48=3,40=4,21=3,30=3,48=3,31=
        targets:
        30=3,11=2,10=1,10=1,03=3,30=3,02=2,23=1,30=3,33=2,21=3,00=0
        00=0,00=0,01=1,03=3,03=3,20=2,21=3,00=0,20=2,03=3,23=1,31=0
        64=1,84=3,40=4,00=0,22=4,61=7,48=3,40=4,21=3,30=3,48=3,31=4
        group:
        CyclicGroup(4) CyclicGroup(4) CyclicGroup(9)
        order:
        4 4 9
        vocabulary:
        "0123"
        "0123"
        "012345678"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')

# Running this file as a script executes the regression checks in _unit_test().
if __name__ == '__main__':
    _unit_test()
