import torch
from e3nn import o3
# import pytorch_block_sparse as pbs
# import ipdb
from typing import List, Tuple, Dict, Any, Optional, Union, NamedTuple, Callable
from collections import OrderedDict
import time
from e3nn.util.codegen import CodeGenMixin
from e3nn.util.jit import compile_mode
# import pytorch_block_sparse as pbs
# from pytorch_block_sparse import BlockSparseLinear
from torch import fx

def linear_init(module):
    """Xavier-uniform initialise every ``torch.nn.Linear`` found in ``module``.

    Walks ``module.modules()`` (which includes ``module`` itself). For each
    ``torch.nn.Linear`` it fills the weight with ``xavier_uniform_`` and zeroes
    the bias. Linears built with ``bias=False`` have ``bias is None``; their
    bias step is skipped instead of raising.

    Args:
        module: any ``torch.nn.Module``; its parameters are modified in place.
    """
    for m in module.modules():
        if not isinstance(m, torch.nn.Linear):
            continue
        torch.nn.init.xavier_uniform_(m.weight)
        # bias=False Linears store None here; zeros_(None) would raise.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    
    
# Adapted from e3nn/o3/_tensor_product/_codegen.py
def codegen_get_invariant(
    irreps_in1: o3.Irreps,
    irreps_in2: o3.Irreps,
    irreps_out: o3.Irreps,
) -> fx.GraphModule:
    """Build an fx graph that extracts scalar features and per-irrep fused bases.

    The generated ``forward(x1, x2)`` performs four steps.

    1. It broadcasts ``x1`` and ``x2`` against each other over all but the last
       axis, then flattens both to ``(batch_numel, irreps.dim)``.
    2. It splits each input into one ``(batch_numel, mul, ir.dim)`` block per
       irreps entry. Blocks of the same irrep are concatenated along the
       multiplicity axis, ``x1``'s blocks before ``x2``'s.
    3. For every irrep of dimension 1, it uses the squeezed
       ``(batch_numel, n)`` block directly as features. For every other irrep
       with ``n`` fused vectors, it takes the dot product of each pair
       ``(i, j)`` with ``i <= j`` (``torch.triu_indices``, diagonal included),
       giving ``n * (n + 1) // 2`` features.
    4. It returns ``(features, basis_dict)``. ``features`` concatenates all
       step-3 features along axis 1. ``basis_dict`` maps ``str(ir)`` to the
       fused ``(batch_numel, n, ir.dim)`` basis.

    Args:
        irreps_in1: irreps of the last axis of ``x1``.
        irreps_in2: irreps of the last axis of ``x2``.
        irreps_out: currently unused; kept for interface compatibility.

    Returns:
        An ``fx.GraphModule`` whose buffers ``_indices_{n}`` hold the
        ``torch.triu_indices(n, n)`` pairs used in step 3.

    NOTE(review): every dim-1 irrep (``0o`` as well as ``0e``) is passed through
    as a feature. ``0o`` is not parity-invariant, so confirm this is intended.
    NOTE(review): outputs keep the flattened batch axis; the broadcast batch
    shape is not restored here.
    """
    graph = fx.Graph()
    tracer = fx.proxy.GraphAppendingTracer(graph)
    constants = OrderedDict()

    # Sorted + simplified, so each irrep appears once with its total multiplicity.
    irreps_in12 = (irreps_in1 + irreps_in2).sort().irreps.simplify()

    x1s = fx.Proxy(graph.placeholder("x1", torch.Tensor), tracer=tracer)
    x2s = fx.Proxy(graph.placeholder("x2", torch.Tensor), tracer=tracer)

    # Compute the broadcast batch shape symbolically (same trick as e3nn's
    # codegen), so x1 and x2 may have different broadcast-compatible leading dims.
    empty = fx.Proxy(graph.call_function(torch.empty, ((),), dict(device="cpu")), tracer=tracer)
    output_shape = torch.broadcast_tensors(empty.expand(x1s.shape[:-1]), empty.expand(x2s.shape[:-1]))[0].shape
    del empty

    x1s, x2s = x1s.broadcast_to(output_shape + (-1,)), x2s.broadcast_to(output_shape + (-1,))

    x1s = x1s.reshape(-1, irreps_in1.dim)
    x2s = x2s.reshape(-1, irreps_in2.dim)
    batch_numel = x1s.shape[0]

    x1_list = _split_by_irreps(x1s, irreps_in1, batch_numel)
    x2_list = _split_by_irreps(x2s, irreps_in2, batch_numel)

    # Group blocks by irrep in first-seen order (x1 entries first, then x2).
    # str keys avoid an `invalid decimal literal` error in TorchScript.
    grouped = OrderedDict()
    for x_i, (_, ir_i) in zip(x1_list + x2_list, list(irreps_in1) + list(irreps_in2)):
        grouped.setdefault(str(ir_i), []).append(x_i)

    all_scalar = []
    l_wize_basis_dict = OrderedDict()
    for key, blocks in grouped.items():
        ir_i = o3.Irrep(key)
        # (batch_numel, mul1 + mul2 + ..., ir.dim)
        basis_i = blocks[0] if len(blocks) == 1 else torch.cat(blocks, dim=-2)
        l_wize_basis_dict[key] = basis_i
        if ir_i.dim == 1:
            all_scalar.append(basis_i.squeeze(-1))  # (batch_numel, MUL)
            continue

        n_basis_i = irreps_in12.count(ir_i)
        # Dot products of every pair (i, j), i <= j, of the fused vectors.
        indices_name = f"_indices_{n_basis_i}"
        indices = fx.Proxy(graph.get_attr(indices_name, torch.Tensor), tracer=tracer)
        # Must stay a direct two-name unpacking: fx Proxy iteration inspects
        # the caller's bytecode for UNPACK_SEQUENCE.
        indices_1, indices_2 = indices
        # Equivalent to basis_i[:, indices_1, :] * basis_i[:, indices_2, :]
        indices_1 = indices_1[..., None].expand(batch_numel, -1, ir_i.dim)
        indices_2 = indices_2[..., None].expand(batch_numel, -1, ir_i.dim)
        f1 = basis_i.gather(1, indices_1)
        f2 = basis_i.gather(1, indices_2)
        all_scalar.append((f1 * f2).sum(dim=-1))

        # Drop the get_attr node if nothing used it; otherwise make sure the
        # buffer it refers to gets registered (once per size).
        if len(indices.node.users) == 0:
            graph.erase_node(indices.node)
        elif indices_name not in constants:
            constants[indices_name] = torch.triu_indices(n_basis_i, n_basis_i, offset=0)

    all_scalar = all_scalar[0] if len(all_scalar) == 1 else torch.cat(all_scalar, dim=1)

    l_wize_basis_dict_return = {key: value.node for key, value in l_wize_basis_dict.items()}
    graph.output((all_scalar.node, l_wize_basis_dict_return))

    graph.lint()
    # Make GraphModules
    constants_root = torch.nn.Module()
    for key, value in constants.items():
        constants_root.register_buffer(key, value)

    return fx.GraphModule(constants_root, graph)


def _split_by_irreps(xs, irreps: o3.Irreps, batch_numel):
    """Split a flattened ``(batch_numel, irreps.dim)`` proxy into one ``(batch_numel, mul, ir.dim)`` view per irreps entry."""
    if len(irreps) == 1:
        # Single entry: reshape directly, no slicing needed.
        return [xs.reshape(batch_numel, irreps[0].mul, irreps[0].ir.dim)]
    return [
        xs[:, sl].reshape(batch_numel, mul_ir.mul, mul_ir.ir.dim)
        for sl, mul_ir in zip(irreps.slices(), irreps)
    ]
    
    
def codegen_combine_basis_coeff(
    irreps_out: o3.Irreps,
) -> fx.GraphModule:
    """Build an fx graph that mixes per-irrep bases with predicted coefficients.

    The generated ``forward(coeffs, l_wize_basis_dict)`` walks ``irreps_out``
    in order and consumes the columns of ``coeffs`` from left to right.

    * For a ``0e`` entry, the next ``mul`` coefficient columns are emitted as is.
    * For any other irrep ``ir``, it takes the basis
      ``l_wize_basis_dict[str(ir)]``, indexed as ``(batch, n, ir.dim)``. It
      reads the next ``mul * n`` coefficients, reshapes them to
      ``(batch, n, mul)`` and divides by ``n``. It then contracts over the
      ``n`` axis, giving a flattened ``(batch, mul * ir.dim)`` block.

    The blocks are concatenated along the last axis.

    Args:
        irreps_out: output irreps; fixes the order in which ``coeffs`` is read.

    Returns:
        An ``fx.GraphModule`` with no parameters or buffers.
    """
    graph = fx.Graph()
    tracer = fx.proxy.GraphAppendingTracer(graph)

    coeffs = fx.Proxy(graph.placeholder("coeffs", torch.Tensor), tracer=tracer)
    basis_by_irrep = fx.Proxy(
        graph.placeholder("l_wize_basis_dict", Dict[str, torch.Tensor]),
        tracer=tracer,
    )
    scalar_irrep = o3.Irrep('0e')

    pieces = []
    offset = 0
    for mul, ir in irreps_out:
        if ir != scalar_irrep:
            basis = basis_by_irrep[str(ir)]
            n_batch = basis.shape[0]
            n_fused = basis.shape[1]
            n_coeff = mul * n_fused
            # (batch, n_fused, mul), averaged over the fused basis vectors
            weights = coeffs[:, offset:offset + n_coeff].reshape(-1, n_fused, mul) / n_fused
            # (batch, n_fused, 1, dim) * (batch, n_fused, mul, 1), summed over
            # n_fused: einsum('bmd,bmn->bnd') spelled without einsum.
            mixed = torch.sum(basis.unsqueeze(2) * weights.unsqueeze(3), dim=1)
            pieces.append(mixed.reshape(n_batch, -1))
            offset = offset + n_coeff
        else:
            pieces.append(coeffs[..., offset:offset + mul])
            offset += mul

    result = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=-1)
    graph.output(result.node, torch.Tensor)

    # Sanity-check the constructed graph before wrapping it.
    graph.lint()

    return fx.GraphModule(torch.nn.Module(), graph)
    
    
@compile_mode('script')
class Sinet2(CodeGenMixin, torch.nn.Module):
    """Layer that mixes input basis vectors with MLP-predicted coefficients.

    ``forward(x1, x2, weights=None)`` runs four stages.

    1. ``_compiled_sinet`` (built by ``codegen_get_invariant``) turns ``x1`` and
       ``x2`` into scalar features plus a dict of per-irrep fused bases.
    2. ``mlp_encode`` (Linear -> LayerNorm -> activation) maps the features to
       ``hidden_dim``. If ``weights`` is given, it is added to the result.
    3. ``mlp_decode`` (Linear) predicts the coefficients. Each ``0e`` output
       entry gets ``mul`` of them. Each other output irrep gets ``mul * n``,
       where ``n`` is that irrep's total multiplicity over both inputs.
    4. ``_compiled_combine_basis_coeff`` (built by ``codegen_combine_basis_coeff``)
       combines the coefficients with the bases into the output.

    Args:
        irreps_in1: irreps of ``x1``.
        irreps_in2: irreps of ``x2``.
        irreps_out: irreps of the output. Every entry other than ``0e`` must
            also occur in ``irreps_in1 + irreps_in2``.
        hidden_dim: width of the MLP hidden layer (also stored as ``weight_numel``).
        activation: module applied after the LayerNorm in ``mlp_encode``.
            Defaults to a fresh ``torch.nn.ReLU()`` per instance. It is passed
            through ``torch.jit.script`` and therefore must be scriptable.

    Raises:
        ValueError: if ``irreps_out`` contains a non-``0e`` irrep that is absent
            from the inputs; the forward pass has no basis to build it from.
    """

    def __init__(self,
                 irreps_in1: o3.Irreps,
                 irreps_in2: o3.Irreps,
                 irreps_out: o3.Irreps,
                 hidden_dim: int = 128,
                 activation: Optional[Callable] = None
    ):
        super().__init__()
        if activation is None:
            # Fresh module per instance instead of one shared default instance.
            activation = torch.nn.ReLU()
        irreps_in12 = (irreps_in1 + irreps_in2).sort().irreps.simplify()
        scalar_irrep = o3.Irrep("0e")
        self.weight_numel = hidden_dim

        # Must match the feature count produced by codegen_get_invariant:
        # dim-1 irreps contribute `mul` features, others the upper-triangular
        # pair count mul * (mul + 1) // 2.
        num_invariant_in = 0
        for mul, ir in irreps_in12:
            if ir.dim == 1:
                num_invariant_in += mul
            else:
                num_invariant_in += mul * (mul + 1) // 2

        # Must match how codegen_combine_basis_coeff consumes coefficients:
        # only 0e is taken directly. Every other irrep, including 0o, needs
        # mul * (fused basis size) coefficients.
        num_invariant_out = 0
        for mul, ir in irreps_out:
            if ir == scalar_irrep:
                num_invariant_out += mul
                continue
            n_fused = irreps_in12.count(ir)
            if n_fused == 0:
                raise ValueError(
                    f"irreps_out contains {ir}, which does not occur in "
                    f"irreps_in1 + irreps_in2 ({irreps_in12}); only 0e outputs "
                    f"can be produced without a matching input irrep"
                )
            num_invariant_out += mul * n_fused

        self.mlp_encode = torch.nn.Sequential(
            torch.nn.Linear(num_invariant_in, hidden_dim),
            torch.nn.LayerNorm(hidden_dim),
            activation
        )
        self.mlp_decode = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, num_invariant_out)
        )
        linear_init(self.mlp_encode)
        linear_init(self.mlp_decode)
        self.mlp_encode = torch.jit.script(self.mlp_encode)
        self.mlp_decode = torch.jit.script(self.mlp_decode)

        graphmodule_get_invariant = codegen_get_invariant(irreps_in1, irreps_in2, irreps_out)
        graphmodule_combine_basis_coeff = codegen_combine_basis_coeff(irreps_out)

        self._codegen_register({"_compiled_sinet": graphmodule_get_invariant,
                                "_compiled_combine_basis_coeff": graphmodule_combine_basis_coeff})

    def forward(self, x1, x2, weights=None):
        """Compute the output from ``x1`` and ``x2``.

        ``weights``, if not None, is added element-wise to the
        ``hidden_dim``-wide encoded features before decoding.
        """
        invariant_features, l_wize_basis_dict = self._compiled_sinet(x1, x2)
        invariant_features = self.mlp_encode(invariant_features)
        if weights is not None:
            invariant_features = invariant_features + weights
        coeffs = self.mlp_decode(invariant_features)
        out = self._compiled_combine_basis_coeff(coeffs, l_wize_basis_dict)
        return out
