import torch
import time
import copy
import torchvision
import numpy as np
import torch.nn as nn
import torch_geometric.nn as agf_module
import typing
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import (
    add_self_loops,
    is_torch_sparse_tensor,
    remove_self_loops,
    softmax,
)
from torch_geometric.utils.sparse import set_sparse_value

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload
    

class IKModule(nn.Module):
    """MLP that maps a 7-feature input row to 5 output values.

    The encoder is three ``Linear -> ReLU -> Dropout(0.)`` stages of width
    ``hidden_size``; the decoder is a single ``Linear`` producing 5 values.
    The zero-probability dropout layers are kept so that the positions of
    the linear layers inside ``self.encoder`` (0, 3, 6) and therefore the
    ``state_dict`` keys stay stable.

    Args:
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, hidden_size=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=7, out_features=hidden_size), nn.ReLU(), nn.Dropout(0.),
            nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.ReLU(), nn.Dropout(0.),
            nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.ReLU(), nn.Dropout(0.),
        )
        self.decoder = nn.Linear(in_features=hidden_size, out_features=5)

    def forward(self, x, mode="predict"):
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has 7 features.
            mode: ``"train"`` returns the raw decoder output; any other value
                returns its element-wise sigmoid.

        Returns:
            Tensor with 5 values per input row.
        """
        # NOTE(review): this switches a process-wide PyTorch flag on every
        # call; kept as in the original because the reason is not visible here.
        torch.use_deterministic_algorithms(False)
        pred = self.decoder(self.encoder(self.scale(x)))
        if mode != "train":
            pred = pred.sigmoid()
        return pred

    def scale(self, x):
        """Return a copy of ``x`` with feature 6 wrapped into ``[0, 2*pi)``.

        The input tensor is left untouched: writing into it in place would
        silently alter the caller's data and fails outright on leaf tensors
        that require gradients.
        """
        scaled = x.clone()
        scaled[..., 6] = x[..., 6] % (2 * np.pi)
        return scaled
    
class GOModule(nn.Module):
    """MLP that maps a 14-feature input row to 5 values clamped to ``[0, 1]``.

    The encoder is three ``Linear -> ReLU -> Dropout(0.)`` stages of width
    ``hidden_size``; the decoder is a single ``Linear`` producing 5 values.
    The zero-probability dropout layers are kept so that the positions of
    the linear layers inside ``self.encoder`` (0, 3, 6) and therefore the
    ``state_dict`` keys stay stable.

    Args:
        hidden_size: Width of every hidden layer.
    """

    def __init__(self, hidden_size=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=14, out_features=hidden_size), nn.ReLU(), nn.Dropout(0.),
            nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.ReLU(), nn.Dropout(0.),
            nn.Linear(in_features=hidden_size, out_features=hidden_size), nn.ReLU(), nn.Dropout(0.),
        )
        self.decoder = nn.Linear(in_features=hidden_size, out_features=5)

    def forward(self, x, mode="predict", mask=None):
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has 14 features.
            mode: In ``"train"`` mode ``mask`` is ignored.
            mask: Optional tensor multiplied element-wise into the output
                when ``mode`` is not ``"train"``.

        Returns:
            Tensor with 5 values per input row, each clamped to ``[0, 1]``
            by ``hardtanh`` before the optional masking.
        """
        # NOTE(review): this switches a process-wide PyTorch flag on every
        # call; kept as in the original because the reason is not visible here.
        torch.use_deterministic_algorithms(False)
        pred = self.decoder(self.encoder(self.scale(x)))
        pred = nn.functional.hardtanh(pred, min_val=0., max_val=1.)
        if mode != "train" and mask is not None:
            pred = pred * mask
        return pred

    def scale(self, x):
        """Return a copy of ``x`` with features 6 and 13 wrapped into ``[0, 2*pi)``.

        The input tensor is left untouched: writing into it in place would
        silently alter the caller's data and fails outright on leaf tensors
        that require gradients.
        """
        scaled = x.clone()
        for column in (6, 13):
            scaled[..., column] = x[..., column] % (2 * np.pi)
        return scaled

class AGFModule(nn.Module):
    """Graph head: encode nodes and edges, run one attention layer, decode.

    Node and edge features (7 values each) are lifted to 256 channels, passed
    through a 4-head ``EdgeGATv2Conv`` whose heads are concatenated
    (``256 * 4`` channels per node), and decoded to 6 values per node.
    """

    def __init__(self):
        super().__init__()
        self.node_encoder = nn.Sequential(nn.Linear(7, 256), nn.ReLU())
        self.edge_encoder = nn.Sequential(nn.Linear(7, 256), nn.ReLU())
        self.conv = EdgeGATv2Conv(
            in_channels=256,
            out_channels=256,
            heads=4,
            edge_dim=256,
            add_self_loops=False,
        )
        self.decoder = nn.Sequential(
            nn.Linear(256 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 6),
        )

    def forward(self, data, mode="predict"):
        """Compute per-node predictions for a graph.

        Args:
            data: Object exposing ``x``, ``edge_attr`` and ``edge_index``.
            mode: ``"train"`` returns raw decoder outputs; any other value
                returns their element-wise sigmoid.

        Returns:
            Tensor with 6 values per node.
        """
        node_feats = self.node_encoder(data.x)
        edge_feats = self.edge_encoder(data.edge_attr)
        aggregated = self.conv(node_feats, data.edge_index, edge_feats)
        logits = self.decoder(aggregated)
        if mode == "train":
            return logits
        return logits.sigmoid()
    
class GRN(nn.Module):
    """Combines the IK, GO and AGF sub-modules into one graph model.

    Args:
        go_hidden_size: Hidden width passed to ``GOModule``.
        ik_hidden_size: Hidden width passed to ``IKModule``.
        args: Configuration object. ``forward`` reads ``args.device`` and
            ``args.edge_features`` (tested with ``"IK" in ...`` and
            ``"GO" in ...``), so it must be provided before calling
            ``forward`` even though the constructor accepts ``None``.
    """

    def __init__(self, go_hidden_size=512, ik_hidden_size=512, args=None):
        super().__init__()
        self.args = args
        self.ik_module = IKModule(ik_hidden_size)
        self.go_module = GOModule(go_hidden_size)
        self.agf_module = AGFModule()

    def forward(self, data, mode="predict", return_data=False):
        """Augment ``data`` with IK/GO edge features and self-loops, then run AGF.

        NOTE: ``data`` is modified in place (``x``, ``edge_attr`` and
        ``edge_index`` are replaced), so passing the same object twice
        appends the extra columns and self-loops a second time.

        Args:
            data: Graph object exposing ``x``, ``edge_attr``, ``edge_index``,
                ``movable_mask`` and ``proximity_mask``.
            mode: Forwarded to the AGF module. When it is not ``"train"``,
                the returned IK predictions are passed through a sigmoid.
            return_data: If true, the modified ``data`` is returned as a
                fourth element.

        Returns:
            ``(F_preds, IK_preds, GO_preds)`` or, with ``return_data``,
            ``(F_preds, IK_preds, GO_preds, data)``.
        """
        device = self.args.device
        use_ik = "IK" in self.args.edge_features
        use_go = "GO" in self.args.edge_features
        use_aux = use_ik or use_go

        data.x = self.scale(data.x)
        num_nodes = data.x.shape[0]
        num_edges = data.edge_attr.shape[0]
        movable = data.movable_mask == 1

        # Defaults for rows that are never predicted: 1 for IK, 0 for GO.
        IK_preds = torch.ones((num_nodes, 5), device=device)
        IK_masks = torch.ones((num_nodes, 5), device=device)
        GO_preds = torch.zeros((num_edges, 5), device=device)

        if use_aux:
            # Reserve five extra edge-feature columns (filled below by GO).
            data.edge_attr = torch.cat(
                (data.edge_attr, torch.zeros(num_edges, 5, device=device)), dim=1)
            if use_ik:
                # IK is always queried for raw outputs; the sigmoid is applied
                # here for the masks and, outside training, at the very end.
                IK_preds[movable] = self.ik_module(data.x[movable], "train")
                IK_masks[movable] = IK_preds[movable].sigmoid()

            if use_go:
                # `== True` keeps working for non-boolean mask tensors.
                proximal = data.proximity_mask == True  # noqa: E712
                neighbors = data.edge_index[:, proximal]
                GO_input_features = torch.cat(
                    (data.x[neighbors[1]], data.x[neighbors[0]]), dim=1)
                GO_preds[proximal] = self.go_module(
                    GO_input_features, "predict", IK_masks[neighbors[1]])
                data.edge_attr[:, -5:] = GO_preds

        # One self-loop per movable node with base features [0, 1].
        mask = torch.where(movable)[0].to(device)
        data.edge_index = torch.cat(
            (data.edge_index, torch.stack((mask, mask), dim=0)), dim=1)
        # `repeat` yields shape (0, 2) when nothing is movable, so the
        # concatenations below stay well-formed for that case as well.
        loop_features = torch.tensor(
            [0.0, 1.0], dtype=data.edge_attr.dtype, device=device,
        ).repeat(mask.shape[0], 1)
        if use_aux:
            loop_features = torch.cat(
                (loop_features, 1.0 - IK_masks[mask]), dim=1)
        data.edge_attr = torch.cat((data.edge_attr, loop_features), dim=0)

        F_preds = self.agf_module(data, mode)
        if mode != "train":
            IK_preds = IK_preds.sigmoid()
        if return_data:
            return F_preds, IK_preds, GO_preds, data
        return F_preds, IK_preds, GO_preds

    def scale(self, x):
        """Return a copy of ``x`` with column 6 wrapped into ``[0, 2*pi)``.

        The input tensor is not modified.
        """
        scaled = x.clone()
        scaled[:, 6] = x[:, 6] % (2 * np.pi)
        return scaled

    def predict_from_scene(self, scene):
        """Build a graph from ``scene`` and run ``forward`` in predict mode.

        Returns:
            ``(feasibility_preds, IK_features, GO_features, data)`` where
            ``data`` is the graph after ``forward`` has augmented it.
        """
        data = self.to_graph(scene)
        feasibility_preds, IK_features, GO_features, data = self.forward(
            data.to(self.args.device), return_data=True)
        return feasibility_preds, IK_features, GO_features, data

    def compute_distance(self, pose1, pose2):
        """Return the Euclidean distance between two equally long sequences."""
        return np.linalg.norm(np.array(pose1) - np.array(pose2))

    def compute_threshold(self, dim1, dim2, margin=0.6):
        """Return half of ``max(dim1) + max(dim2) + margin``.

        Args:
            dim1: Non-empty sequence of extents of the first object.
            dim2: Non-empty sequence of extents of the second object.
            margin: Extra slack added before halving (default ``0.6``).
        """
        return (max(dim1) + max(dim2) + margin) / 2

    def is_neighbor(self, dim1, dim2, pose1, pose2):
        """Tell whether two objects count as neighbours.

        Only the first two entries of each pose and each dimension list are
        used. The objects are neighbours unless their distance exceeds the
        threshold from ``compute_threshold``.
        """
        distance = self.compute_distance(pose1[:2], pose2[:2])
        threshold = self.compute_threshold(dim1[:2], dim2[:2])
        return not distance > threshold

    def to_graph(self, scene):
        """Convert a scene description into a ``Data`` graph.

        ``scene["objects"]`` maps object names to dicts with the keys
        ``"frame_id"``, ``"fixed"``, ``"dimensions"`` and ``"pose"``. Each
        entry gains an ``"abs_pose"`` key as a side effect.

        Node features are ``dimensions + abs_pose[:3] + [abs_pose[-1]]``
        (7 values). Every non-fixed object receives an incoming edge from its
        frame object (unless the frame is ``"world"``/``"odom_combined"``)
        and from every neighbouring object; the object named ``"base"`` is
        only considered as a neighbour when it is the frame. All edges get
        the feature ``[1, 0]`` and ``proximity_mask`` set to ``True``.

        Raises:
            KeyError: If no object is named ``"base"`` or a ``frame_id``
                names an object that is not in the scene.
        """
        objects = list(scene["objects"].keys())
        indices = {obj: i for i, obj in enumerate(objects)}
        num_objects = len(objects)
        world_frames = ("world", "odom_combined")

        nodes = torch.zeros((num_objects, 7))
        movable_mask = torch.zeros(num_objects, dtype=torch.bool)
        frame_ids = torch.zeros(num_objects, dtype=torch.long)

        for obj in objects:
            scene["objects"] = self.compute_abs_poses(scene["objects"], obj)

        # Collect rows in Python lists and build each tensor once, instead of
        # re-allocating a tensor with torch.cat for every node and edge.
        pos_rows = []
        edge_list = []
        for i, obj in enumerate(objects):
            object_ = scene["objects"][obj]
            frame_id = object_["frame_id"]
            abs_pose = object_["abs_pose"]
            pose_row = list(abs_pose[:3]) + [abs_pose[-1]]

            # ===================== Nodes =====================
            movable_mask[i] = not object_["fixed"]
            nodes[i] = torch.tensor(list(object_["dimensions"]) + pose_row)
            frame_ids[i] = -1 if frame_id in world_frames else indices[frame_id]
            pos_rows.append(pose_row)

            # ===================== Edges =====================
            if object_["fixed"]:
                continue
            if frame_id not in world_frames:
                edge_list.append([indices[frame_id], i])
            for neighbor in objects:
                if (neighbor == obj or neighbor == frame_id
                        or (frame_id != "base" and neighbor == "base")):
                    continue
                other = scene["objects"][neighbor]
                if self.is_neighbor(object_["dimensions"], other["dimensions"],
                                    object_["abs_pose"], other["abs_pose"]):
                    edge_list.append([indices[neighbor], i])

        num_edges = len(edge_list)
        pos = torch.tensor(pos_rows, dtype=torch.float32).reshape(-1, 4)
        # reshape(-1, 2) keeps the (2, 0) layout when there are no edges.
        edges = torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()
        proximity_mask = torch.ones(num_edges, dtype=torch.bool)
        edge_features = torch.tensor([1.0, 0.0]).repeat(num_edges, 1)

        base_mask = torch.zeros(num_objects, dtype=torch.bool)
        base_mask[indices["base"]] = True

        data = Data(x=nodes, movable_mask=movable_mask, frame_ids=frame_ids, pos=pos, base_mask=base_mask,
                    edge_index=edges, proximity_mask=proximity_mask, edge_attr=edge_features)
        return data

    def compute_abs_pose(self, object_rel_pose, frame_abs_pose):
        """Express a pose given relative to a frame in that frame's parent.

        The first three entries of each pose are used as a translation and
        the last entry as a rotation angle about the z axis.

        Returns:
            ``[x, y, z, 0., 0., angle]`` with the translation rotated by the
            frame's angle and shifted by the frame's translation, and the two
            angles summed.
        """
        frame_yaw = frame_abs_pose[-1]
        cos_yaw, sin_yaw = np.cos(frame_yaw), np.sin(frame_yaw)
        support_rot = np.array([[cos_yaw, -1 * sin_yaw, 0],
                                [sin_yaw, cos_yaw, 0],
                                [0, 0, 1]])
        support_trans = np.array(frame_abs_pose[:3])
        object_trans = np.array(object_rel_pose[:3])
        abs_pose = np.matmul(support_rot, object_trans) + support_trans
        abs_yaw = frame_yaw + object_rel_pose[-1]
        return [abs_pose[0], abs_pose[1], abs_pose[2], 0., 0., abs_yaw]

    def compute_abs_poses(self, scene, object_id):
        """Ensure ``scene[object_id]["abs_pose"]`` exists and return ``scene``.

        Objects whose frame is ``"world"`` or ``"odom_combined"`` always get
        a fresh deep copy of their ``"pose"``. Other objects are resolved
        through their frame object (recursively) only when they have no
        ``"abs_pose"`` yet.

        NOTE(review): because existing ``"abs_pose"`` entries of non-world
        objects are reused, a scene whose ``"pose"`` values changed after an
        earlier call keeps the old absolute poses - confirm callers clear
        them. A cyclic ``frame_id`` chain recurses without bound.
        """
        entry = scene[object_id]
        parent_id = entry["frame_id"]
        if parent_id == "world" or parent_id == "odom_combined":
            entry["abs_pose"] = copy.deepcopy(entry["pose"])
        elif "abs_pose" not in entry:
            if "abs_pose" not in scene[parent_id]:
                scene = self.compute_abs_poses(scene, parent_id)
            entry["abs_pose"] = self.compute_abs_pose(
                entry["pose"], scene[parent_id]["abs_pose"])
        return scene
    
    
class EdgeGATv2Conv(MessagePassing):
    r"""GATv2-style attention layer whose messages also carry edge features.

    Per head, attention logits are computed from the concatenation of the
    projected target node, the projected source node and (when ``edge_dim``
    is given) the projected edge features. Messages are the source node
    features concatenated with the projected edge features, scaled by the
    attention weight. The aggregated messages are mapped back to
    ``out_channels`` per head by ``out_lin`` followed by a leaky ReLU.

    Args:
        in_channels: Input node feature size, or a ``(source, target)`` pair.
        out_channels: Output feature size per head.
        heads: Number of attention heads.
        concat: If true the heads are concatenated, otherwise averaged.
        negative_slope: Slope of the leaky ReLU activations.
        dropout: Dropout probability applied to the attention weights.
        add_self_loops: If true, existing self-loops are replaced by freshly
            added ones in ``forward``.
        edge_dim: Size of the edge features, or ``None`` for no edge features.
        fill_value: How edge features of added self-loops are generated.
        bias: Whether to learn an additive bias.
        share_weights: If true, source and target nodes share one projection.
        **kwargs: Forwarded to ``MessagePassing``.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        share_weights: bool = False,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value
        self.share_weights = share_weights

        if isinstance(in_channels, int):
            in_l = in_r = in_channels
        else:
            in_l, in_r = in_channels[0], in_channels[1]

        self.lin_l = Linear(in_l, heads * out_channels, bias=bias,
                            weight_initializer='glorot')
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_r, heads * out_channels, bias=bias,
                                weight_initializer='glorot')

        # The attention input is [x_i || x_j] plus the edge projection when
        # edge features are used, so its width depends on `edge_dim`.
        has_edge = edge_dim is not None
        att_channels = (3 if has_edge else 2) * out_channels
        self.att = Parameter(torch.empty(1, heads, att_channels))

        if has_edge:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
                                   weight_initializer='glorot')
        else:
            self.lin_edge = None

        if bias and concat:
            self.bias = Parameter(torch.empty(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Messages are [x_j || edge] with edge features and just x_j without.
        msg_channels = (2 if has_edge else 1) * out_channels
        self.out_lin = Linear(msg_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the layer."""
        super().reset_parameters()
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        # Without this, an explicit reset would leave `out_lin` untouched.
        self.out_lin.reset_parameters()
        glorot(self.att)
        zeros(self.bias)

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to any
                :obj:`bool`, will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)

        Returns:
            The node outputs of size ``heads * out_channels`` when the heads
            are concatenated and ``out_channels`` when they are averaged,
            optionally paired with the attention weights.
        """
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2
            x_l = self.lin_l(x).view(-1, H, C)
            if self.share_weights:
                x_r = x_l
            else:
                x_r = self.lin_r(x).view(-1, H, C)
        else:
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2
            x_l = self.lin_l(x_l).view(-1, H, C)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)

        assert x_l is not None
        assert x_r is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = torch_sparse.set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # `edge_update` also hands back the projected edge features so that
        # `message` can reuse them instead of projecting a second time.
        # edge_updater_type: (x: PairTensor, edge_attr: OptTensor)
        alpha, edge_attr = self.edge_updater(edge_index, x=(x_l, x_r),
                                  edge_attr=edge_attr)

        # propagate_type: (x: PairTensor, alpha: Tensor)
        out = self.propagate(edge_index, x=(x_l, x_r), edge_attr=edge_attr, alpha=alpha)

        # Map the aggregated messages back to `out_channels` per head.
        out = F.leaky_relu(self.out_lin(out), self.negative_slope)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    # TODO TorchScript requires to return a tuple
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
                    index: Tensor, ptr: OptTensor,
                    dim_size: Optional[int]) -> Tuple[Tensor, OptTensor]:
        """Compute attention weights and the projected edge features.

        Returns:
            ``(alpha, edge_attr)`` where ``alpha`` holds the softmax-normalised
            (and dropout-regularised) attention weight of each edge and head,
            and ``edge_attr`` is the edge projection reshaped to
            ``(-1, heads, out_channels)``, or ``None`` if no edge features
            were given.
        """
        x = torch.cat([x_i, x_j], dim=-1)

        if edge_attr is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            assert self.lin_edge is not None
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
            x = torch.cat([x, edge_attr], dim=-1)

        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, dim_size)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha, edge_attr

    def message(self, x_j: Tensor, edge_attr: OptTensor, alpha: Tensor) -> Tensor:
        """Return the source features (plus edge features) scaled by ``alpha``."""
        if edge_attr is not None:
            x = torch.cat([x_j, edge_attr], dim=-1)
        else:
            x = x_j
        return x * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
