from typing import List, Tuple

import torch
from e3nn.util.jit import compile_mode
from LieCG import so13


def tp_out_irreps_with_instructions(
    irreps1: so13.Lorentz_Irreps,
    irreps2: so13.Lorentz_Irreps,
    target_irreps: so13.Lorentz_Irreps,
) -> Tuple[so13.Lorentz_Irreps, List]:
    """Build the output irreps and 'uvu' instructions for a tensor product.

    Every pair ``(irreps1[i], irreps2[j])`` is multiplied with ``*``. Each
    resulting irrep that also appears in ``target_irreps`` becomes one output
    slot, carrying the multiplicity of ``irreps1[i]``. Each slot gets one
    trainable ``'uvu'`` instruction.

    Args:
        irreps1: First input irreps (e.g. node features).
        irreps2: Second input irreps (e.g. edge attributes).
        target_irreps: Irreps that are allowed to appear in the output.

    Returns:
        A pair ``(irreps_out, instructions)``:

        * ``irreps_out`` is the collected output irreps after ``.sort()``.
        * ``instructions`` is a list of
          ``(i_in1, i_in2, i_out, 'uvu', True)`` tuples. ``i_out`` has been
          remapped through the sort permutation, so it indexes into the
          sorted ``irreps_out``.
    """
    collected: List[Tuple[int, so13.Lorentz_Irreps]] = []
    raw_instructions = []

    for idx1, (multiplicity, ir1) in enumerate(irreps1):
        for idx2, (_, ir2) in enumerate(irreps2):
            # ir1 * ir2 yields every irrep allowed by the selection rules.
            for candidate in ir1 * ir2:
                if candidate not in target_irreps:
                    continue
                # The output slot index is the position the irrep is about
                # to occupy in `collected`.
                raw_instructions.append((idx1, idx2, len(collected), 'uvu', True))
                collected.append((multiplicity, candidate))

    # Sorting groups equal irreps together, so a following linear layer can
    # work on simplified irreps. The permutation maps each pre-sort slot
    # index to its post-sort index.
    sorted_irreps, permutation, _ = so13.Lorentz_Irreps(collected).sort()

    remapped = [
        (a, b, permutation[out_idx], path_mode, has_weight)
        for a, b, out_idx, path_mode, has_weight in raw_instructions
    ]
    return sorted_irreps, remapped


def linear_out_irreps(
        irreps: so13.Lorentz_Irreps,
        target_irreps: so13.Lorentz_Irreps) -> so13.Lorentz_Irreps:
    """Map each irrep of ``irreps`` onto its counterpart in ``target_irreps``.

    For every irrep in ``irreps`` (its multiplicity is ignored), this finds
    the first entry of ``target_irreps`` with an equal irrep. It then emits
    that entry's ``(multiplicity, irrep)``. The result keeps the order of
    ``irreps``. The caller is assumed to pass simplified irreps.

    Args:
        irreps: Irreps whose irrep types should be kept.
        target_irreps: Irreps that supply the output multiplicities.

    Returns:
        A new ``so13.Lorentz_Irreps`` built from the matched entries.

    Raises:
        RuntimeError: If an irrep of ``irreps`` has no equal entry in
            ``target_irreps``.
    """
    matched = []
    for _, wanted in irreps:
        for target_mul, target_ir in target_irreps:
            if wanted == target_ir:
                matched.append((target_mul, target_ir))
                break
        else:
            # The inner loop finished without `break`: there is no match.
            raise RuntimeError(f'{wanted} not in {target_irreps}')

    return so13.Lorentz_Irreps(matched)


@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    """Split a flat irreps feature tensor into ``[batch, mul, dim]`` blocks.

    The input is a 2-D tensor ``[batch, features]``. Its columns are laid out
    irrep by irrep as ``mul * dim`` contiguous entries, in the order of
    ``irreps``. Each chunk is reshaped to ``[batch, mul, dim]``, and the
    chunks are concatenated along the last axis. The output is therefore
    ``[batch, mul, sum(dim)]``.

    Note: ``torch.cat`` requires every non-concatenated dimension to match.
    All irreps must therefore share the same multiplicity.
    """

    def __init__(self, irreps: so13.Lorentz_Irreps) -> None:
        super().__init__()
        # Kept for callers that inspect the module's irreps.
        self.irreps = irreps
        # TorchScript cannot convert (or iterate) a Lorentz_Irreps object
        # inside forward(). Because this module is compiled in "script" mode,
        # cache the multiplicities and dimensions as plain int lists instead.
        self.muls: List[int] = []
        self.dims: List[int] = []
        for mul, ir in irreps:
            self.muls.append(mul)
            self.dims.append(ir.dim)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape ``tensor`` of shape ``[batch, features]`` per irrep.

        Returns a tensor of shape ``[batch, mul, sum(dim)]``.
        """
        ix = 0
        out = []
        batch, _ = tensor.shape
        for mul, d in zip(self.muls, self.dims):
            field = tensor[:, ix:ix + mul * d]  # [batch, mul * dim]
            ix += mul * d
            field = field.reshape(batch, mul, d)  # [batch, mul, dim]
            out.append(field)
        return torch.cat(out, dim=-1)
