
from __future__ import annotations

from typing import Tuple, Callable, Iterable, List, Any, Dict

import package.group

from package.group import Group, GroupElement, IrreducibleRepresentation, DirectProductGroup
from package.group.irrep import restrict_irrep

import numpy as np
import itertools
import re

# Public API of this module: the DoubleGroup class and its cached factory function.
__all__ = [
    'DoubleGroup',
    'double_group'
]


class DoubleGroup(DirectProductGroup):
    r"""
        Direct product :math:`G \times G` of a group with itself.

        Both factors are built from the same group class ``G`` and the same ``G_*`` keyword arguments, so
        ``self.G1`` and ``self.G2`` describe the same group. In addition to the subgroups handled by
        :class:`DirectProductGroup`, this class exposes the ``'diagonal'`` subgroup
        :math:`\{(g, g) \mid g \in G\}`.

        Instances are normally obtained through :func:`double_group` / :meth:`DoubleGroup._generator`, which
        cache a single instance per configuration.
    """

    def __init__(self, G: str, name: str = None, **group_keys):
        r"""
            Args:
                G (str): name of the group class used for both factors (:func:`double_group` passes
                    ``G.__class__.__name__``).
                name (str, optional): name of the resulting group; forwarded to :class:`DirectProductGroup`.
                **group_keys: keyword arguments of the factor group, each prefixed with ``'G_'``.
                    They are re-prefixed as ``'G1_'`` and ``'G2_'`` so both factors receive identical keys.

            Raises:
                AssertionError: if a keyword argument does not start with ``'G_'``.
        """
        prefix = 'G_'
        bad_keys = [k for k in group_keys if not k.startswith(prefix)]
        assert not bad_keys, f"DoubleGroup keyword arguments must start with '{prefix}', got {bad_keys}"

        # Strip the shared 'G_' prefix and re-attach one prefix per factor.
        keys1 = {'G1_' + k[len(prefix):]: v for k, v in group_keys.items()}
        keys2 = {'G2_' + k[len(prefix):]: v for k, v in group_keys.items()}
        super(DoubleGroup, self).__init__(G, G, name, **keys1, **keys2)

    @property
    def _keys(self) -> Dict[str, Any]:
        r"""
            Keyword arguments describing this group, in the format accepted by :meth:`_generator`:
            ``'G'`` (factor class name), ``'name'`` (only when the name is not the default one) and the
            factor's own keys prefixed with ``'G_'``.
        """
        keys = {'G': self.G1.__class__.__name__}

        # `_defaulf_name` (sic) is presumably set by the parent DirectProductGroup.__init__ (not visible here);
        # the spelling must match the parent's attribute.
        if not self._defaulf_name:
            keys['name'] = self.name

        keys.update({'G_' + k: v for k, v in self.G1._keys.items()})
        return keys

    @property
    def subgroup_diagonal_id(self):
        r"""
            The subgroup id associated with the diagonal group.
            This is the subgroup containing elements in the form :math:`(g, g)` for :math:`g \in G` and is isomorphic
            to :math:`G` itself.
            The id can be used in the method :meth:`~package.group.Group.subgroup` to generate the subgroup.
        """
        return 'diagonal'

    def _process_subgroup_id(self, id):
        r"""
            Normalize a subgroup id: ``'diagonal'`` is returned unchanged, any other id is delegated to
            :class:`DirectProductGroup`.
        """
        if id == 'diagonal':
            return id
        return super(DoubleGroup, self)._process_subgroup_id(id)

    def _combine_subgroups(self, sg_id1, sg_id2):
        r"""
            Not supported for double groups.

            Raises:
                NotImplementedError: always.
        """
        raise NotImplementedError(
            f'Combining subgroups {sg_id1!r} and {sg_id2!r} is not implemented for DoubleGroup'
        )

    def _subgroup(self, id) -> Tuple[
        package.group.Group,
        Callable[[package.group.GroupElement], package.group.GroupElement],
        Callable[[package.group.GroupElement], package.group.GroupElement]
    ]:
        r"""
            Build the subgroup identified by ``id``.

            For ``'diagonal'`` the subgroup is ``self.G1`` (isomorphic to the diagonal :math:`\{(g, g)\}`),
            returned together with the embedding ``G1 -> self`` built by :func:`inclusion` and the map
            ``self -> G1`` built by :func:`restriction` (which yields ``None`` for off-diagonal elements).
            Any other id is delegated to :class:`DirectProductGroup`.
        """
        if id == 'diagonal':
            return self.G1, inclusion(self), restriction(self)
        return super(DoubleGroup, self)._subgroup(id)

    # Class-level cache shared by every call to `_generator`: maps a normalized, sorted tuple of the
    # constructor arguments to the single DoubleGroup instance built from them.
    _cached_group_instance = dict()

    @classmethod
    def _generator(cls, G: str, **group_keys) -> 'DirectProductGroup':
        r"""
            Return the cached :class:`DoubleGroup` for the given arguments, building it on first use.

            Args:
                G (str): name of the factor group class.
                **group_keys: optional ``name`` plus the ``'G_'``-prefixed keys of the factor group, as produced
                    by :attr:`_keys` or :func:`double_group`.

            Returns:
                the shared :class:`DoubleGroup` instance for this configuration.
        """
        # `name=None` is the constructor default, so `_generator(G, name=None, ...)` and `_generator(G, ...)`
        # build the same group. Drop it so both spellings share one cache entry: `double_group` always forwards
        # `name`, while `_keys` omits a default name.
        if group_keys.get('name') is None:
            group_keys.pop('name', None)

        key = tuple(sorted({'G': G, **group_keys}.items()))

        if key not in cls._cached_group_instance:
            cls._cached_group_instance[key] = DoubleGroup(G, **group_keys)

        group = cls._cached_group_instance[key]
        # Invoked on every lookup, including cache hits; presumably idempotent (defined in a parent class not
        # visible here — TODO confirm).
        group._build_representations()
        return group


def restriction(G: DoubleGroup):
    r"""
        Build the map from elements of the double group ``G`` to its diagonal subgroup (``G.G1``).

        The returned callable splits an element :math:`(g_1, g_2)` of ``G`` into its two components and
        returns :math:`g_1` when both components are equal (the element lies on the diagonal); otherwise it
        returns ``None``.
    """
    def _map(e: GroupElement, G=G):
        assert e.group == G

        left, right = G.split_element(e)
        return left if left == right else None

    return _map


def inclusion(G: DoubleGroup):
    r"""
        Build the embedding of the factor group ``G.G1`` into the double group ``G`` as its diagonal.

        The returned callable maps :math:`g \mapsto (g, g)`: it builds the element of ``G`` from the pair
        ``(g.value, g.value)`` and the combined parameterization string ``'[<param> | <param>]'``.
    """
    def _map(e: GroupElement, G=G):
        assert e.group == G.G1

        value = e.value
        param = e.param
        return G.element((value, value), param=f'[{param} | {param}]')

    return _map


def double_group(G: Group, name: str = None):
    r"""
        Build (or fetch from the cache) the double group :math:`G \times G`.

        Args:
            G (Group): the group instance to pair with itself.
            name (str, optional): custom name for the resulting group.

        Returns:
            the :class:`DoubleGroup` instance produced by :meth:`DoubleGroup._generator`.
    """
    # Re-expose the factor's own keys under the 'G_' prefix expected by DoubleGroup.
    prefixed_keys = {}
    for key, value in G._keys.items():
        prefixed_keys['G_' + key] = value

    group_class = G.__class__.__name__
    return DoubleGroup._generator(group_class, name=name, **prefixed_keys)

