import random
import torch
import string
import numpy
from sympy.combinatorics import Permutation, PermutationGroup
from sympy.combinatorics import CyclicGroup, DihedralGroup

class MultiGroupAddition:
    """
    A generator for long sequences of group addition problems
    where the one or more groups are randomly chosen for the run,
    and the the assignment of symbols to group elements is also
    selected to be a fixed random map used for the whole run.
    Unlike GroupAddition, a single run can include facts from
    more than one group.  Intended to mimic an in-context
    learning scenario.
    """
    def __init__(self, num_symbols: int = 16, max_order: int = 10,
            holdout_zero: bool = False, seed: int = 42) -> None:
        assert(max_order <= num_symbols)
        self.task_name = self._task_name()
        self.num_symbols = num_symbols
        self.max_order = max_order
        self.holdout_zero = holdout_zero
        self.prng = random.Random(seed)

        # Setup the vocabulary
        digitletters = string.digits + string.ascii_letters
        self.vocab = [digitletters[i] for i in range(num_symbols)]
        self.predict_token_id = len(self.vocab)
        self.vocab.append('=')
        self.start_token_id = len(self.vocab)
        self.vocab.append('^')
        self.pad_token_id = self.start_token_id
        self.sep_token_id = len(self.vocab)
        self.vocab.append(',')
        self.vocab_size = len(self.vocab)
        self.numfor = { v: i for i, v in enumerate(self.vocab) }

    def _task_name(self):
        return 'multigroup'

    def sample_batch(self, batch_size: int,
            k_shots: int = 200, hold_out: int | list = 0,
            commute_out: bool = True,
            multi_runs: bool = False,
            max_length: int = 1024,
            unshuffled: bool = False,
            fixed_groups: list = None):
        assert multi_runs == False

        expressions, g, o, v = zip(*[
            self.sample_run(k_shots, hold_out, commute_out, unshuffled, fixed_groups)
                for _ in range(batch_size)])
        tensor = self.tensor_from_expression(expressions)

        # Every token is a goal
        return {
            "inputs": tensor[:,:-1],
            "targets": tensor[:,1:],
            "groups": g,
            "orders": o,
            "vocabulary": [''.join(voc) for voc in v],
            "prediction_mask": (tensor[:,:-1] == self.predict_token_id)
        }

    def tensor_from_expression(self, expressions):
        def recursive_numfor(e):
            if isinstance(e, (list, tuple)):
                return [recursive_numfor(el) for el in e]
            return [self.numfor[el] for el in e]
        return torch.tensor(recursive_numfor(expressions), dtype=torch.long)

    def sample_run(self, k_shots: int, hold_out: int | list = 0, commute_out: bool = True,
            unshuffled: bool = False, fixed_groups = None):

        # Sample random groups 
        if fixed_groups is not None:
            Glist = fixed_groups
        else:
            Glist = self.sample_groups()
        orders = [G.order() for G in Glist]
        total_order = sum(orders)
        assert (fixed_groups is not None or
                (max(orders) <= self.max_order and total_order <= self.num_symbols) )

        # Select a random vocabulary
        elems = [list(G.generate()) for G in Glist]
        all_elems = sum(elems, [])
        if unshuffled:
            vocab = self.vocab[:total_order]
            wordfor = { g: vocab[i] for i, g in enumerate(all_elems) }
        else:
            while True:
                vocab = self.prng.sample(self.vocab[:self.num_symbols], total_order)
                wordfor = { g: vocab[i] for i, g in enumerate(all_elems) }
                # Hold out the 1/num_symbols assignment of vocabularies where 0 is '0'.
                if not self.holdout_zero or not any(wordfor[E[0]] == '0' for E in elems):
                    break

        # Create Cayley table
        facts = [(a, b) for E in elems for a in E for b in E]

        # Hold out some facts
        held_out = []
        if isinstance(hold_out, int):
            while len(held_out) < hold_out:
                (a, b) = facts.pop(self.prng.randrange(0, len(facts)))
                held_out.append((a, b))
                if commute_out:
                    if (b, a) in facts:
                        facts.remove((b, a))
                        held_out.append((b, a))
        elif isinstance(hold_out, list):
            for (ai, bi) in hold_out:
                a, b = (all_elems[ai], all_elems[bi])
                if (a, b) in facts:
                    facts.remove((a, b))
                    held_out.append((a, b))
                if commute_out:
                    if (b, a) in facts:
                        facts.remove((b, a))
                        held_out.append((b, a))

        return self.sample_facts(k_shots, elems, wordfor, facts, held_out), Glist, orders, vocab

    def sample_groups(self):
        total_order = 0
        Glist = []
        # Sample groups until the total order exceeds the number of
        # available symbols, then pop the last one.
        while total_order <= self.num_symbols:
            G = self._sample_group()
            Glist.append(G)
            total_order += G.order()
        total_order -= Glist.pop().order()
        return Glist

    def sample_facts(self, k_shots: int, elems: list, wordfor: dict, facts: list, held_out: list):
        '''
        Output pattern: ,ab=c,de=f,gh=i,jk=l for a+b=c, d+e=f, g+h=i, j+k=l
        '''
        sequence = []
        # Random samples around the fact table, excepting held-out facts until the end
        for _ in range(k_shots - (1 if held_out else 0)):
            a, b = self.prng.choice(facts)
            c = a * b
            sequence.extend([',', wordfor[a], wordfor[b], '=', wordfor[c]])
        # If holding out some facts, end with a held-out fact
        if held_out:
            a, b = held_out[0]
            c = a * b
            sequence.extend([',', wordfor[a], wordfor[b], '=', wordfor[c]])
        return ''.join(sequence)

    def stringify(self, seq):
        if isinstance(seq, dict):
            return ''.join([f'\n{k}:\n{self.stringify(v)}' for k, v in seq.items()])
        if isinstance(seq, str):
            return f'"{seq}"'
        if isinstance(seq, PermutationGroup):
            if seq.is_cyclic:
                return f'CyclicGroup({seq.order()})'
            elif seq.is_dihedral:
                return f'DihedralGroup({seq.order() // 2})'
            return str(seq)
        if isinstance(seq, int):
            return str(seq)
        if isinstance(seq, (list, tuple)) and len(seq) and (
                isinstance(seq[0], (PermutationGroup, int))):
            return ' '.join([self.stringify(i) for i in seq])
        if isinstance(seq, (list, tuple)):
            return '\n'.join([self.stringify(i) for i in seq])
        if numpy.ndim(seq) > 1:
            return '\n'.join([self.stringify(d) for d in seq])
        if numpy.ndim(seq) == 1:
            return ''.join([
                self.vocab[int(i)] if 0 <= int(i) < len(self.vocab) else '%'
                for i in seq])
        return str(seq)

    def summarize(self, batches, predictions, accuracy, length=72):
        # Just summarize predictions within the first batch
        batch = batches[0]
        predictions = predictions[0]
        def charfor(a):
            return self.vocab[int(a)] if 0 <= int(a) < len(self.vocab) else '%'
        summary = ''
        # Summarize 2 or 12 samples of the test cases
        for index in range(min(12, len(batch['inputs']))):
            inputs = batch['inputs'][index][-length:]
            targets = batch['targets'][index][-length:]
            pred = predictions[index][-length:]
            # include lines of raw output
            # summary += ''.join(['\n' + self.stringify(d) for d in [inputs, targets, pred]])
            summary_chars = []
            for i, a in enumerate(inputs):
                summary_chars.append(charfor(a))
                if int(a) == self.predict_token_id:
                    summary_chars.append(f'[{charfor(pred[i])}]')
            summary += '\n' + (''.join(summary_chars) + charfor(targets[-1]))[-length:]
        return summary

class MultiCyclicGroupAddition(MultiGroupAddition):
    """
    Multi-group task in which every group is cyclic.

    Each sampled group is CyclicGroup(n), with n drawn uniformly from
    3 through max_order inclusive.
    """
    def _task_name(self):
        name = 'multicyclic'
        return name

    def _sample_group(self):
        # randrange excludes its upper bound, hence the + 1.
        smallest, largest = 3, self.max_order
        return CyclicGroup(self.prng.randrange(smallest, largest + 1))

class MultiDihedralGroupAddition(MultiGroupAddition):
    """
    Multi-group task in which every group is dihedral.

    Each sampled group is DihedralGroup(n), with n drawn uniformly from
    2 through max_order // 2 inclusive, so group orders (2n) run from 4
    up to max_order.
    """
    def _task_name(self):
        name = 'multidihedral'
        return name

    def _sample_group(self):
        # DihedralGroup(n) has order 2n, so cap n at half of max_order.
        largest_n = self.max_order // 2
        n = self.prng.randrange(2, largest_n + 1)
        return DihedralGroup(n)

class MultiRosetteGroupAddition(MultiGroupAddition):
    """
    Multi-group task mixing cyclic and dihedral groups.

    One uniform draw is made over the candidate list: C3 through
    C(max_order), followed by D2 through D(max_order // 2).  Every group
    therefore has order at least 3 and at most max_order.
    """
    def _task_name(self):
        name = 'multirosette'
        return name

    def _sample_group(self):
        # Candidates are indexed cyclic-first, then dihedral.
        cyclic_count = self.max_order - 2
        dihedral_count = self.max_order // 2 - 1
        pick = self.prng.randrange(cyclic_count + dihedral_count)
        if pick >= cyclic_count:
            return DihedralGroup(pick - cyclic_count + 2)
        return CyclicGroup(pick + 3)

def _unit_test():
    """
    Regression test: compare stringified batches against stored snapshots.

    Every task below is built with the default seed (42), and symbols are
    assigned in the order the groups generate their elements.  Each
    expected string therefore pins the exact PRNG call sequence.  Any
    change to the number or order of PRNG calls in the task classes, or to
    element generation order, changes these snapshots.
    """
    import re
    def eqstring(a, b):
        # Compare two strings after stripping leading spaces on every line.
        # On mismatch, report the length difference and/or the first
        # differing index, and print the first (actual) string.
        # Remove space at start of lines
        [a, _], [b, _] = [re.subn(r'((?<=\n)|^) *', '', s) for s in [a, b]]
        if len(a) != len(b):
            print(f'Difference in length {len(a)} vs {len(b)}')
        for i in range(min(len(a), len(b))):
            if a[i] != b[i]:
                print(f'Difference at index: {i}: "{a[i:i+3]}" vs "{b[i:i+3]}"')
                print(a)
                break
        return a == b
    # NOTE(review): in the first snapshot below, symbols '0'-'2' never
    # appear in row 1 even though its vocabulary lists them.  This looks
    # like the two identical CyclicGroup(3)s colliding in sample_run's
    # element->symbol dict (keyed by element); confirm before regenerating.
    a = MultiRosetteGroupAddition(max_order=5, num_symbols=12, holdout_zero=True)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=2, k_shots=12, unshuffled=True, hold_out=1)), '''
        inputs:
        ,33=3,53=5,99=7,45=3,a7=6,87=9,35=5,34=4,45=3,53=5,54=3,45=
        ,a9=8,97=9,00=0,74=4,79=9,67=6,97=9,77=7,47=4,67=6,20=2,7a=
        targets:
        33=3,53=5,99=7,45=3,a7=6,87=9,35=5,34=4,45=3,53=5,54=3,45=3
        a9=8,97=9,00=0,74=4,79=9,67=6,97=9,77=7,47=4,67=6,20=2,7a=a
        groups:
        CyclicGroup(3) CyclicGroup(3) CyclicGroup(5)
        CyclicGroup(3) CyclicGroup(4) DihedralGroup(2)
        orders:
        3 3 5
        3 4 4
        vocabulary:
        "0123456789a"
        "0123456789a"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')
    # Skip forward for a case where holdout_zero makes a difference
    # (this batch is drawn only to advance the task's PRNG and is discarded).
    a.stringify(a.sample_batch(batch_size=3, k_shots=12))
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12)), '''
        inputs:
        ,73=3,70=0,77=7,03=a,67=6,99=3,70=0,0a=3,69=7,79=9,aa=7,7a=
        ,91=9,3a=3,aa=a,03=a,3a=3,00=3,3a=3,89=1,99=2,92=8,81=8,30=
        ,76=7,87=5,00=a,85=7,85=7,77=8,78=5,77=8,67=7,65=5,01=0,01=
        targets:
        73=3,70=0,77=7,03=a,67=6,99=3,70=0,0a=3,69=7,79=9,aa=7,7a=a
        91=9,3a=3,aa=a,03=a,3a=3,00=3,3a=3,89=1,99=2,92=8,81=8,30=a
        76=7,87=5,00=a,85=7,85=7,77=8,78=5,77=8,67=7,65=5,01=0,01=0
        groups:
        DihedralGroup(2) CyclicGroup(4) DihedralGroup(2)
        CyclicGroup(4) CyclicGroup(3) CyclicGroup(3)
        CyclicGroup(3) CyclicGroup(4) CyclicGroup(4)
        orders:
        4 4 4
        4 3 3
        3 4 4
        vocabulary:
        "b4821659703a"
        "182974ba30"
        "1a03b926785"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')
    # Dihedral-only task with shuffled vocabularies.
    a = MultiDihedralGroupAddition(max_order=14)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12)), '''
        inputs:
        ,35=4,c3=8,d3=2,f0=f,01=1,e3=1,39=2,5b=0,f6=4,6f=4,d5=9,8d=
        ,eb=a,14=7,c6=8,34=9,6c=8,f5=f,0d=6,19=a,ef=4,41=a,d6=2,c6=
        ,ea=b,a8=d,a4=5,88=1,1e=2,81=a,85=e,31=8,b2=1,fe=6,e4=d,df=
        targets:
        35=4,c3=8,d3=2,f0=f,01=1,e3=1,39=2,5b=0,f6=4,6f=4,d5=9,8d=4
        eb=a,14=7,c6=8,34=9,6c=8,f5=f,0d=6,19=a,ef=4,41=a,d6=2,c6=8
        ea=b,a8=d,a4=5,88=1,1e=2,81=a,85=e,31=8,b2=1,fe=6,e4=d,df=1
        groups:
        DihedralGroup(7)
        DihedralGroup(3) DihedralGroup(5)
        DihedralGroup(7)
        orders:
        14
        6 10
        14
        vocabulary:
        "0b43c2189def65"
        "8cd02654b17efa39"
        "91d3e6475b28af"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')
    # Cyclic-only task with symbols assigned in vocabulary order.
    a = MultiCyclicGroupAddition(max_order=13)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12, unshuffled=True)), '''
        inputs:
        ,06=6,55=a,4a=1,45=9,29=b,20=2,a9=6,19=a,b8=6,84=c,08=8,07=
        ,eb=f,03=3,db=e,bd=e,fb=a,cf=b,ab=b,ca=c,af=f,df=c,db=e,00=
        ,66=7,34=2,57=7,d5=d,75=7,23=0,21=3,7a=c,22=4,77=9,76=8,ac=
        targets:
        06=6,55=a,4a=1,45=9,29=b,20=2,a9=6,19=a,b8=6,84=c,08=8,07=7
        eb=f,03=3,db=e,bd=e,fb=a,cf=b,ab=b,ca=c,af=f,df=c,db=e,00=0
        66=7,34=2,57=7,d5=d,75=7,23=0,21=3,7a=c,22=4,77=9,76=8,ac=8
        groups:
        CyclicGroup(13)
        CyclicGroup(4) CyclicGroup(6) CyclicGroup(6)
        CyclicGroup(5) CyclicGroup(9)
        orders:
        13
        4 6 6
        5 9
        vocabulary:
        "0123456789abc"
        "0123456789abcdef"
        "0123456789abcd"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')
    # Dihedral-only task with holdout_zero enabled.
    a = MultiDihedralGroupAddition(max_order=10, holdout_zero=True)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12)), '''
        inputs:
        ,f5=f,ca=8,c8=f,aa=5,5c=c,aa=5,ba=e,f1=a,88=5,ff=5,15=1,11=
        ,68=8,99=6,57=9,eb=7,5f=b,bb=6,8f=5,65=5,e6=e,97=5,fb=5,e9=
        ,d5=e,b5=5,5e=b,ce=a,c6=d,d5=e,ed=5,6c=d,ed=5,67=e,7c=e,56=
        targets:
        f5=f,ca=8,c8=f,aa=5,5c=c,aa=5,ba=e,f1=a,88=5,ff=5,15=1,11=5
        68=8,99=6,57=9,eb=7,5f=b,bb=6,8f=5,65=5,e6=e,97=5,fb=5,e9=f
        d5=e,b5=5,5e=b,ce=a,c6=d,d5=e,ed=5,6c=d,ed=5,67=e,7c=e,56=7
        groups:
        DihedralGroup(2) DihedralGroup(2) DihedralGroup(4)
        DihedralGroup(4) DihedralGroup(4)
        DihedralGroup(2) DihedralGroup(4)
        orders:
        4 4 8
        8 8
        4 8
        vocabulary:
        "d37940265c1f8aeb"
        "20a341cd6f9785be"
        "8139b5d76ace"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')
    # Cyclic-only task, unshuffled, with one randomly held-out fact
    # (the final fact of each row).
    a = MultiCyclicGroupAddition(max_order=10, holdout_zero=True)
    assert eqstring(
        a.stringify(a.sample_batch(batch_size=3, k_shots=12, unshuffled=True, hold_out=1)), '''
        inputs:
        ,45=5,31=0,db=a,23=1,ba=7,10=1,03=3,23=1,79=9,7c=c,cd=b,7a=
        ,65=8,78=6,46=7,00=0,34=4,64=7,54=6,46=7,33=3,37=7,54=6,38=
        ,75=4,75=4,75=4,55=6,66=4,67=5,45=5,65=7,44=4,47=7,47=7,56=
        targets:
        45=5,31=0,db=a,23=1,ba=7,10=1,03=3,23=1,79=9,7c=c,cd=b,7a=a
        65=8,78=6,46=7,00=0,34=4,64=7,54=6,46=7,33=3,37=7,54=6,38=8
        75=4,75=4,75=4,55=6,66=4,67=5,45=5,65=7,44=4,47=7,47=7,56=7
        groups:
        CyclicGroup(4) CyclicGroup(3) CyclicGroup(7)
        CyclicGroup(3) CyclicGroup(6)
        CyclicGroup(4) CyclicGroup(4)
        orders:
        4 3 7
        3 6
        4 4
        vocabulary:
        "0123456789abcd"
        "012345678"
        "01234567"
        prediction_mask:
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001
        00010000100001000010000100001000010000100001000010000100001''')

if __name__ == "__main__":
    # Executing the module directly runs its built-in snapshot regression test.
    _unit_test()
