import torch
import torch.nn as nn

class Contractable:
    """Batched tensor whose axes are labelled by the characters of a bond string.

    The first bond character is always ``'b'`` (batch axis). In ``__mul__``
    the ``'r'`` axis of the left operand is summed against the ``'l'`` axis of
    the right operand.

    The batch size is shared by all instances through the class attribute
    ``global_bs``. A tensor given without a batch axis is expanded to
    ``global_bs``. A tensor given with a batch axis overwrites ``global_bs``
    with its own batch size.
    """

    # Batch size shared by every Contractable; None until a batched tensor is seen.
    global_bs = None

    def __init__(self, tensor, bond_str):
        """Wrap ``tensor`` with axis labels ``bond_str``.

        Args:
            tensor: torch tensor, with or without a leading batch axis.
            bond_str: one lowercase label per axis. A missing leading ``'b'``
                is added when the tensor has no batch axis.

        Raises:
            RuntimeError: the tensor has no batch axis and no batch size has
                been recorded yet.
            ValueError: the bond string length does not match the tensor order.
        """
        shape = list(tensor.shape)
        num_dim = len(shape)
        str_len = len(bond_str)
        global_bs = Contractable.global_bs

        # No batch axis on the tensor: either no 'b' label at all, or a
        # leading 'b' with no tensor axis to match it.
        if ('b' not in bond_str and str_len == num_dim) or \
           (bond_str.startswith('b') and str_len == num_dim + 1):
            if global_bs is not None:
                tensor = tensor.unsqueeze(0).expand([global_bs] + shape)
            else:
                raise RuntimeError("No batch size given and no previous "
                                   "batch size set")
            if not bond_str.startswith('b'):
                bond_str = 'b' + bond_str

        elif not bond_str.startswith('b') or str_len != num_dim:
            raise ValueError(f"Length of bond string '{bond_str}' "
                             f"({len(bond_str)}) must match order of "
                             f"tensor ({len(shape)})")

        # Tensor has a batch axis: record it as the shared batch size. The
        # axis is read only here, so 0-d tensors can still take the branch above.
        elif global_bs != tensor.size(0):
            Contractable.global_bs = tensor.size(0)

        self.tensor = tensor
        self.bond_str = bond_str

    def __mul__(self, contractable, rmul=False):
        """Contract the left operand's 'r' axis with the right operand's 'l' axis.

        Args:
            contractable: the other operand.
            rmul: True when ``contractable`` is the left operand
                (``contractable * self``), as when called from ``__rmul__``.

        Returns:
            EdgeVec, SingleMat or OutputCore when the output bond string
            matches theirs, otherwise a generic Contractable. Returns
            NotImplemented for Scalar, MatRegion, or objects without a
            ``tensor`` attribute, so Python tries their reflected method.
        """
        if isinstance(contractable, Scalar) or \
           not hasattr(contractable, 'tensor') or \
           type(contractable) is MatRegion:
            return NotImplemented
        tensors = [self.tensor, contractable.tensor]
        bond_strs = [list(self.bond_str), list(contractable.bond_str)]
        lowercases = [chr(c) for c in range(ord('a'), ord('z') + 1)]

        if rmul:
            tensors = tensors[::-1]
            bond_strs = bond_strs[::-1]

        for i, bs in enumerate(bond_strs):
            assert bs[0] == 'b'
            assert len(set(bs)) == len(bs)
            assert all(c in lowercases for c in bs)
            assert (i == 0 and 'r' in bs) or (i == 1 and 'l' in bs)

        # Labels shared by both operands; they are never renamed apart.
        specials = ['b', 'w', 'l', 'r']
        used_chars = set(bond_strs[0]).union(bond_strs[1])
        # Never hand out a special label as a fresh index name.
        free_chars = [c for c in lowercases
                      if c not in used_chars and c not in specials]

        # Give clashing non-special labels of the right operand fresh names so
        # they stay separate output axes.
        for i, c in enumerate(bond_strs[1]):
            if c in bond_strs[0] and c not in specials:
                bond_strs[1][i] = free_chars.pop()

        # Left 'r' and right 'l' become the single summed index.
        sum_char = free_chars.pop()
        bond_strs[0][bond_strs[0].index('r')] = sum_char
        bond_strs[1][bond_strs[1].index('l')] = sum_char
        specials.append(sum_char)

        # Emit 'w' only if some operand carries it: einsum rejects output
        # labels that appear in no input operand.
        out_str = ['b']
        if 'w' in bond_strs[0] or 'w' in bond_strs[1]:
            out_str.append('w')
        for bs in bond_strs:
            out_str.extend(c for c in bs if c not in specials)
        out_str.append('l' if 'l' in bond_strs[0] else '')
        out_str.append('r' if 'r' in bond_strs[1] else '')

        bond_strs = [''.join(bs) for bs in bond_strs]
        out_str = ''.join(out_str)
        ein_str = f"{bond_strs[0]},{bond_strs[1]}->{out_str}"

        out_tensor = torch.einsum(ein_str, [tensors[0], tensors[1]])

        if out_str == 'br':
            return EdgeVec(out_tensor, is_left_vec=True)
        elif out_str == 'bl':
            return EdgeVec(out_tensor, is_left_vec=False)
        elif out_str == 'bwlr':
            return SingleMat(out_tensor)
        elif out_str == 'bolr':
            return OutputCore(out_tensor)
        else:
            return Contractable(out_tensor, out_str)

    def __rmul__(self, contractable):
        """Handle ``contractable * self`` by contracting with operands swapped."""
        return self.__mul__(contractable, rmul=True)

    def reduce(self):
        """Return self; subclasses holding several pieces override this."""
        return self

class ContractableList(Contractable):
    """Ordered sequence of Contractables representing their product c_0 * ... * c_n.

    This class does not call ``Contractable.__init__`` and has no ``tensor``
    attribute. ``Contractable.__mul__`` therefore returns NotImplemented for
    it, and Python falls back to this class's ``__rmul__``.
    """

    def __init__(self, contractable_list):
        """Store a non-empty list of Contractable instances.

        Raises:
            ValueError: input is not a non-empty list, or an item is not a
                Contractable.
        """
        if not isinstance(contractable_list, list) or not contractable_list:
            raise ValueError("Input to ContractableList must be nonempty list")
        for i, item in enumerate(contractable_list):
            if not isinstance(item, Contractable):
                raise ValueError("Input items to ContractableList must be "
                                 f"Contractable instances, but item {i} is not")
        self.contractable_list = contractable_list
        # Scalar parameters initialised to 0.5 (not used inside this class).
        # They are created directly on the device, so they remain leaf
        # Parameters (``Parameter(...).cuda()`` returns a non-leaf copy) and
        # construction does not fail on machines without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.WaTen = nn.Parameter(torch.full([1], 0.5, device=device))
        self.WbTen = nn.Parameter(torch.full([1], 0.5, device=device))

    def __mul__(self, contractable, rmul=False):
        """Multiply the whole list with ``contractable``, one item at a time.

        Args:
            contractable: object with a ``tensor`` attribute.
            rmul: True when ``contractable`` is the left operand
                (``contractable * self``).
        """
        assert hasattr(contractable, 'tensor')
        output = contractable

        if rmul:
            # contractable * c_0 * c_1 * ... : absorb items left to right.
            for item in self.contractable_list:
                output = output * item
        else:
            # ... * c_{n-1} * c_n * contractable: absorb items right to left.
            for item in reversed(self.contractable_list):
                output = item * output

        return output

    def __rmul__(self, contractable):
        """Handle ``contractable * self``."""
        return self.__mul__(contractable, rmul=True)

    def reduce(self, parallel_eval=False):
        """Multiply all items together and return the single result.

        Args:
            parallel_eval: if True, call ``reduce()`` on every item first.

        Multiplication proceeds from the right end of the list. When a
        product raises TypeError (unsupported pairing), the two leftmost items
        are multiplied instead.
        """
        # Work on a copy so reducing does not consume self.contractable_list.
        c_list = list(self.contractable_list)

        if parallel_eval:
            c_list = [item.reduce() for item in c_list]

        while len(c_list) > 1:
            try:
                c_list[-2] = c_list[-2] * c_list[-1]
                del c_list[-1]
            except TypeError:
                c_list[1] = c_list[0] * c_list[1]
                del c_list[0]

        return c_list[0]

class MatRegion(Contractable):
    """Stack of square matrices stored with bond string 'bwslr'.

    Axis order follows the bond string: 'b' batch, 'w', 's' (presumably the
    index of each matrix in the stack; TODO confirm), then the left/right
    bond axes 'l' and 'r'.
    """

    def __init__(self, mats):
        """Validate ``mats`` and register it with bond string 'bwslr'.

        Raises:
            ValueError: ``mats`` is not of order 4 or 5, or its last two axes
                differ in size.
        """
        # NOTE(review): the message below lists order-4/order-3 shapes, but
        # the check accepts order 5/4 (to match 'bwslr'). It seems to omit the
        # 'w' axis.
        shape = list(mats.shape)
        if len(shape) not in [4, 5] or shape[-2] != shape[-1]:
            raise ValueError("MatRegion tensors must have shape "
                             "[batch_size, num_mats, D, D], or [num_mats,"
                             " D, D] if batch size has already been set")

        super().__init__(mats, bond_str='bwslr')

    def __mul__(self, edge_vec, rmul=False):
        """Push an EdgeVec through the stacked matrices and return an EdgeVec.

        With ``rmul=True`` (``edge_vec * self``) the vector is a row vector
        right-multiplied by each matrix in order. Otherwise it is a column
        vector left-multiplied by the matrices from last to first. Returns
        NotImplemented for anything that is not an EdgeVec.
        """
        # NOTE(review): num_mats is read from axis 2 ('s'), but torch.chunk
        # splits along axis 1 ('w'), and squeeze(2) is then applied. The
        # stored tensor is 5-D, so each chunk stays at least 4-D, while
        # torch.bmm requires 3-D operands. This path looks unable to run on a
        # 'bwslr' tensor; confirm the intended handling of the 'w' axis.
        # batch_size, log_norm and the loop index i are unused.
        if not isinstance(edge_vec, EdgeVec):
            return NotImplemented

        mats = self.tensor
        num_mats = mats.size(2)
        batch_size = mats.size(0)

        # Row vector [b, 1, D] for rmul, column vector [b, D, 1] otherwise.
        dummy_ind = 1 if rmul else 2
        vec = edge_vec.tensor.unsqueeze(dummy_ind)
        mat_list = [mat.squeeze(2) for mat in torch.chunk(mats, num_mats, 1)]

        log_norm = 0
        for i, mat in enumerate(mat_list[::(1 if rmul else -1)], 1):
            if rmul:
                vec = torch.bmm(vec, mat)
            else:
                vec = torch.bmm(mat, vec)
        return EdgeVec(vec.squeeze(dummy_ind), is_left_vec=rmul)

    def __rmul__(self, edge_vec):
        """Handle ``edge_vec * self``."""
        return self.__mul__(edge_vec, rmul=True)

    def reduce(self):
        """Multiply the stacked matrices together and return a SingleMat.

        Each pass multiplies adjacent pairs (0,1), (2,3), ... along axis 2 in
        one einsum, roughly halving the stack. An odd leftover matrix is
        appended unchanged. Passes repeat until one matrix remains, then axis 2
        is squeezed away. Matrix order is preserved.
        """
        # batch_size and D are unused.
        mats = self.tensor
        shape = list(mats.shape)
        batch_size = mats.size(0)
        size, D = shape[2:4]
        while size > 1:
            odd_size = (size % 2 == 1)
            half_size = size // 2
            nice_size = 2 * half_size
        
            even_mats = mats[:,:, 0:nice_size:2]
            odd_mats = mats[:,:, 1:nice_size:2]
            leftover = mats[:,:, nice_size:]

            # Batched matrix product even_k @ odd_k for every pair k.
            mats = torch.einsum('bwslu,bwsur->bwslr', [even_mats, odd_mats])
            mats = torch.cat([mats, leftover], 2)

            size = half_size + int(odd_size)
        return SingleMat(mats.squeeze(2))

class OutputCore(Contractable):
    """Core tensor carrying an output axis, labelled with bond string 'bolr'.

    Accepts order 4 ``[batch_size, output_dim, D_l, D_r]``, or order 3
    ``[output_dim, D_l, D_r]``. In the order-3 case ``Contractable`` adds
    the batch axis from the shared batch size.
    """

    def __init__(self, tensor):
        order = len(tensor.shape)
        if order in (3, 4):
            super().__init__(tensor, bond_str='bolr')
            return
        raise ValueError("OutputCore tensors must have shape [batch_size, "
                         "output_dim, D_l, D_r], or else [output_dim, D_l,"
                         " D_r] if batch size has already been set")

class SingleMat(Contractable):
    """One matrix per batch entry, labelled with bond string 'bwlr'.

    Accepts tensors of order 3 or 4. The order-3 form has its batch axis
    added by ``Contractable``.
    """

    def __init__(self, mat):
        # NOTE(review): the message below describes order-3/order-2 shapes,
        # while the check accepts order 4/3 ('bwlr'). It seems to omit 'w'.
        order = len(mat.shape)
        if order in (3, 4):
            super().__init__(mat, bond_str='bwlr')
            return
        raise ValueError("SingleMat tensors must have shape [batch_size, "
                         "D_l, D_r], or else [D_l, D_r] if batch size "
                         "has already been set")

class OutputMat(Contractable):
    """Boundary matrix carrying an output axis.

    The bond string is 'bro' when ``is_left_mat`` is true and 'blo' otherwise.
    Only multiplication with an EdgeVec is supported.
    """

    def __init__(self, mat, is_left_mat):
        """Validate ``mat`` (order 2 or 3) and label it 'bro' or 'blo'.

        Raises:
            ValueError: ``mat`` is not of order 2 or 3.
        """
        if len(mat.shape) not in [2, 3]:
            raise ValueError("OutputMat tensors must have shape [batch_size, "
                             "D, output_dim], or else [D, output_dim] if "
                             "batch size has already been set")

        bond_str = 'b' + ('r' if is_left_mat else 'l') + 'o'
        super().__init__(mat, bond_str=bond_str)

    def __mul__(self, edge_vec, rmul=False):
        """Contract with an EdgeVec; return NotImplemented for anything else."""
        # Return (not raise) NotImplemented so Python can try the other
        # operand's reflected method, and raises a proper TypeError if that
        # fails too. ``raise NotImplemented`` itself errors with "exceptions
        # must derive from BaseException".
        if not isinstance(edge_vec, EdgeVec):
            return NotImplemented
        return super().__mul__(edge_vec, rmul)

    def __rmul__(self, edge_vec):
        """Handle ``edge_vec * self``."""
        return self.__mul__(edge_vec, rmul=True)

class EdgeVec(Contractable):
    """Boundary vector per batch entry.

    The bond string is 'br' when ``is_left_vec`` is true and 'bl' otherwise.
    """

    def __init__(self, vec, is_left_vec):
        """Validate ``vec`` (order 1 or 2) and label it 'br' or 'bl'.

        Raises:
            ValueError: ``vec`` is not of order 1 or 2.
        """
        if len(vec.shape) not in [1, 2]:
            raise ValueError("EdgeVec tensors must have shape "
                             "[batch_size, D], or else [D] if batch size "
                             "has already been set")

        bond_str = 'b' + ('r' if is_left_vec else 'l')
        super().__init__(vec, bond_str=bond_str)

    def __mul__(self, right_vec, rmul=False):
        """Batched inner product of two EdgeVecs, returned as a Scalar.

        Args:
            right_vec: the other EdgeVec.
            rmul: True when ``right_vec`` is the left operand. This is how the
                inherited ``Contractable.__rmul__`` calls this method.

        Returns:
            Scalar of shape [batch_size], or NotImplemented if the operand is
            not an EdgeVec.
        """
        if not isinstance(right_vec, EdgeVec):
            return NotImplemented

        left, right = (right_vec, self) if rmul else (self, right_vec)
        left_vec = left.tensor.unsqueeze(1)    # [b, 1, D]
        right_vec = right.tensor.unsqueeze(2)  # [b, D, 1]
        batch_size = left_vec.size(0)

        scalar = torch.bmm(left_vec, right_vec).view([batch_size])

        return Scalar(scalar)

class Scalar(Contractable):
    """One scalar per batch entry, labelled with bond string 'b'."""

    def __init__(self, scalar):
        """Wrap a 1-D tensor, promoting a 0-d tensor to shape [1] first.

        Raises:
            ValueError: the (promoted) tensor is not 1-D.
        """
        shape = list(scalar.shape)
        # ``shape is []`` was always False, so this promotion never ran.
        # NOTE(review): a [1] tensor passes through Contractable.__init__ as a
        # batched tensor and so overwrites the shared batch size with 1;
        # confirm that is intended.
        if not shape:
            scalar = scalar.view([1])
            shape = [1]

        if len(shape) != 1:
            raise ValueError("input scalar must be a torch tensor with shape "
                             "[batch_size], or [] or [1] if batch size has "
                             "been set")

        super().__init__(scalar, bond_str='b')

    def __mul__(self, contractable):
        """Scale ``contractable`` per batch entry, keeping its class and bond string."""
        scalar = self.tensor
        tensor = contractable.tensor
        bond_str = contractable.bond_str

        ein_string = f"{bond_str},b->{bond_str}"
        out_tensor = torch.einsum(ein_string, [tensor, scalar])

        contract_class = type(contractable)
        # EdgeVec and OutputMat need an orientation flag, which is recovered
        # from the bond string ('br'/'bro' are the left variants).
        if contract_class is EdgeVec:
            return EdgeVec(out_tensor, is_left_vec=(bond_str == 'br'))
        if contract_class is OutputMat:
            return OutputMat(out_tensor, is_left_mat=('r' in bond_str))
        if contract_class is not Contractable:
            return contract_class(out_tensor)
        return Contractable(out_tensor, bond_str)

    def __rmul__(self, contractable):
        """Scalar multiplication is commutative; delegate to ``__mul__``."""
        return self.__mul__(contractable)
