from typing import Dict, List, Tuple

from math import sqrt
from LieCG import so13
import numpy as np
import torch
import torch.nn
import torch.utils.data
from LorentzMACE.tools.torch_tools import to_numpy
from LorentzMACE.tools.scatter import scatter_sum


def compute_forces(energy: torch.Tensor,
                   positions: torch.Tensor,
                   training=True) -> torch.Tensor:
    """Return forces as the negative gradient of ``energy`` w.r.t. ``positions``.

    Args:
        energy: tensor to differentiate (presumably per-graph energies of
            shape [n_graphs] -- confirm against callers). Every element is
            weighted by one, so the gradient of ``energy.sum()`` is taken.
        positions: tensor that ``energy`` was computed from; it must be part
            of the autograd graph. ``allow_unused=False`` makes autograd raise
            if ``energy`` does not depend on it.
        training: when True the graph is retained and a differentiable
            gradient is built, so a loss on the forces can itself be
            back-propagated (second derivative).

    Returns:
        ``-d(sum energy)/d(positions)``, with the same shape as ``positions``.
    """
    (grad_positions, ) = torch.autograd.grad(
        outputs=[energy],
        inputs=[positions],
        grad_outputs=[torch.ones_like(energy)],
        # Keep the graph alive and differentiable while training so the
        # force loss can be back-propagated through this gradient.
        retain_graph=training,
        create_graph=training,
        only_inputs=True,
        allow_unused=False,
    )
    return grad_positions.neg()


def get_edge_vectors_and_lengths(
    positions: torch.Tensor,  # [n_nodes, 4]
    edge_index: torch.Tensor,  # [2, n_edges]
    normalize: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute per-edge difference four-vectors and their signed squared interval.

    Args:
        positions: node four-vectors, component 0 treated as the time-like one
            (see the metric below).
        edge_index: ``[sender; receiver]`` node indices per edge.
        normalize: if True, divide each edge vector by ``sqrt(|interval^2|)``.
            Edges whose squared interval is exactly zero (lightlike or
            zero-length) are returned unnormalized.

    Returns:
        vectors: ``positions[receiver] - positions[sender]`` (normalized when
            requested), shape [n_edges, 4].
        lengths_sq: ``t^2 - x^2 - y^2 - z^2`` per edge, shape [n_edges, 1].
            NOTE: despite the function name this is the *squared*, signed
            interval, not its square root.
    """
    sender, receiver = edge_index
    vectors = positions[receiver] - positions[sender]  # [n_edges, 4]
    vectors_squared = torch.pow(vectors, 2)
    # 2*t^2 - (t^2 + x^2 + y^2 + z^2) == t^2 - x^2 - y^2 - z^2
    lengths_sq = (2 * vectors_squared[..., 0] -
                  vectors_squared.sum(dim=-1)).unsqueeze(-1)  # [n_edges, 1]
    if normalize:
        nonzero = lengths_sq != 0
        # Substitute 1 for zero intervals *before* sqrt/divide. Masking only
        # the result (torch.where on vectors / lengths) still back-propagates
        # 0 * inf = NaN through sqrt(0) and x / 0 into the positions' grad.
        # Dividing by 1 leaves those edges unchanged, as before.
        safe_sq = torch.where(nonzero, lengths_sq, torch.ones_like(lengths_sq))
        vectors = vectors / safe_sq.abs().sqrt()
    return vectors, lengths_sq


def vectors_to_rep(
        vector: torch.Tensor,  #[batch,4]
) -> torch.Tensor:
    """Convert real Cartesian four-vectors (t, x, y, z) to complex components.

    The change of basis is a fixed complex 4x4 matrix (real and imaginary
    parts listed below). Per input vector it produces::

        out[0] = t
        out[1] = (x - i*y) / sqrt(2)
        out[2] = z
        out[3] = (-x - i*y) / sqrt(2)

    Input: real tensor of shape [..., 4]. Its dtype must be one that
    ``torch.view_as_complex`` accepts (presumably float32/float64).
    Output: complex tensor of shape [..., 4].
    """
    inv_sqrt2 = 1 / sqrt(2.)
    change_of_basis = torch.tensor(
        [
            # real part
            [[1, 0, 0, 0], [0, inv_sqrt2, 0, 0],
             [0, 0, 0, 1], [0, -inv_sqrt2, 0, 0]],
            # imaginary part
            [[0, 0, 0, 0], [0, 0, -inv_sqrt2, 0],
             [0, 0, 0, 0], [0, 0, -inv_sqrt2, 0]],
        ],
        device=vector.device,
        dtype=vector.dtype,
    )
    column = vector.unsqueeze(-1)  # [..., 4, 1]
    # Trailing dim of size 2 holds (real, imag) as view_as_complex expects.
    real_imag = torch.stack([part @ column for part in change_of_basis],
                            dim=-1)
    return torch.view_as_complex(real_imag).squeeze(-1)


def minkowski_norm(
        vector: torch.Tensor,  #[batch,4]
) -> torch.Tensor:
    """Return the signed squared Minkowski interval of each vector.

    For a 4-vector (t, x, y, z) this is ``t^2 - x^2 - y^2 - z^2``, evaluated
    as ``2*t^2 - sum(all squares)``. Despite the name, no square root is
    taken. The last dimension is reduced, so a [batch, 4] input gives [batch].
    """
    squares = vector.pow(2)
    time_term = squares[..., 0] * 2
    return time_term - squares.sum(-1)


def metric(irrep_in: so13.Lorentz_Irreps):
    """Build a (k+1)(n+1) square float matrix indexed by (l, m) basis labels.

    ``k`` and ``n`` are read from ``irrep_in.k`` and ``irrep_in.l``. Basis
    labels are ordered with l ascending from |k - n|/2 to (k + n)/2 in unit
    steps, and within each l with m from -l to l. The entry at
    ((l, m), (l', m')) is ``(-1)**(l + m)`` when ``l == l'`` and ``m == -m'``,
    otherwise 0. Each l-block is therefore a signed anti-diagonal.
    (Presumably the invariant bilinear form of the (k, n) irrep -- confirm.)
    """
    k = irrep_in.k
    n = irrep_in.l
    dim = (k + 1) * (n + 1)
    spins = np.arange(abs(k - n) / 2, (k + n) / 2 + 1, 1)
    basis = [(l, m) for l in spins for m in np.arange(-l, l + 1, 1)]
    rows = [[(-1)**(int(l + m)) if (l == ll and m + mm == 0) else 0
             for (ll, mm) in basis]
            for (l, m) in basis]
    return torch.tensor(rows, dtype=torch.float).view([dim, dim])


def compute_avg_num_neighbors(
        data_loader: torch.utils.data.DataLoader) -> float:
    """Return the mean number of incoming edges per receiving node.

    Each batch must expose ``edge_index`` whose second row holds the receiver
    node indices. Counts come from ``torch.unique``, which only reports
    indices that actually occur. Nodes with no incoming edge therefore do not
    contribute to the average.

    Returns:
        The average as a Python float, computed in the default torch dtype.
        If the loader yields batches but none contain edges, the mean of an
        empty tensor (NaN) is returned.

    Raises:
        ValueError: if ``data_loader`` yields no batches at all.
    """
    num_neighbors = []

    for batch in data_loader:
        _, receivers = batch.edge_index
        _, counts = torch.unique(receivers, return_counts=True)
        num_neighbors.append(counts)

    if not num_neighbors:
        # torch.cat([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError(
            "compute_avg_num_neighbors: data_loader yielded no batches")

    all_counts = torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
    return torch.mean(all_counts).item()