from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.utils import subgraph

from shapely.geometry import Polygon
import cv2
import math

import random

import global_vars


class TemporalData(Data):
    """A ``torch_geometric`` ``Data`` object carrying per-scene trajectory and lane tensors.

    All constructor arguments are forwarded to ``Data`` as attributes, except
    ``edge_attrs``, which is unpacked into one attribute per time step
    (``edge_attr_0``, ``edge_attr_1``, ...).
    """

    def __init__(self,
                 x: Optional[torch.Tensor] = None,
                 positions: Optional[torch.Tensor] = None,
                 edge_index: Optional[torch.Tensor] = None,
                 edge_attrs: Optional[List[torch.Tensor]] = None,
                 y: Optional[torch.Tensor] = None,
                 num_nodes: Optional[int] = None,
                 padding_mask: Optional[torch.Tensor] = None,
                 bos_mask: Optional[torch.Tensor] = None,
                 rotate_angles: Optional[torch.Tensor] = None,
                 lane_vectors: Optional[torch.Tensor] = None,
                 is_intersections: Optional[torch.Tensor] = None,
                 turn_directions: Optional[torch.Tensor] = None,
                 traffic_controls: Optional[torch.Tensor] = None,
                 lane_actor_index: Optional[torch.Tensor] = None,
                 lane_actor_vectors: Optional[torch.Tensor] = None,
                 seq_id: Optional[int] = None,
                 **kwargs) -> None:
        """Build the data object.

        When ``x`` is ``None`` an empty ``Data`` object is created and every
        other argument is ignored. Otherwise all arguments except
        ``edge_attrs`` (plus any extra ``kwargs``) are stored via the parent
        constructor, and ``edge_attrs[t]`` is stored as ``edge_attr_{t}`` for
        each ``t`` in ``range(x.size(1))``.
        """
        if x is None:
            # Empty construction path: nothing is stored on the object.
            super().__init__()
            return

        super().__init__(
            x=x, positions=positions, edge_index=edge_index, y=y, num_nodes=num_nodes,
            padding_mask=padding_mask, bos_mask=bos_mask, rotate_angles=rotate_angles,
            lane_vectors=lane_vectors, is_intersections=is_intersections,
            turn_directions=turn_directions, traffic_controls=traffic_controls,
            lane_actor_index=lane_actor_index, lane_actor_vectors=lane_actor_vectors,
            seq_id=seq_id, **kwargs)

        if edge_attrs is None:
            return
        # One attribute per index along dim 1 of x
        # (presumably the time axis -- TODO confirm against the dataset code).
        num_steps = self.x.size(1)
        for step in range(num_steps):
            self[f'edge_attr_{step}'] = edge_attrs[step]

    def __inc__(self, key, value):
        """Return the per-sample index increment for ``key``.

        For ``lane_actor_index`` the increment is a 2x1 tensor holding the
        number of rows in ``lane_vectors`` and ``num_nodes``; every other key
        defers to the parent implementation.
        """
        if key != 'lane_actor_index':
            return super().__inc__(key, value)
        num_lanes = self['lane_vectors'].size(0)
        return torch.tensor([[num_lanes], [self.num_nodes]])


class DistanceDropEdge:
    """Edge filter that keeps only edges whose attribute vector has L2 norm below a limit."""

    def __init__(self, max_distance: Optional[float] = None) -> None:
        """Store the norm limit; ``None`` disables filtering."""
        self.max_distance = max_distance

    def __call__(self,
                 edge_index: torch.Tensor,
                 edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(edge_index, edge_attr)`` restricted to the kept edges.

        An edge is kept when the L2 norm of its ``edge_attr`` row (taken over
        the last dimension) is strictly less than ``max_distance``. When
        ``max_distance`` is ``None`` the inputs are returned untouched.
        """
        limit = self.max_distance
        if limit is None:
            return edge_index, edge_attr

        src, dst = edge_index
        keep = edge_attr.norm(p=2, dim=-1) < limit
        kept_index = torch.stack((src[keep], dst[keep]), dim=0)
        return kept_index, edge_attr[keep]

class MLPEmbedding(nn.Module):
    """Three-layer MLP (Linear + LayerNorm, ReLU between layers) mapping ``in_channel`` to ``out_channel``."""

    def __init__(self,
                 in_channel: int,
                 out_channel: int) -> None:
        """Build the layer stack and initialise it with ``init_weights``."""
        super().__init__()
        # Layer order (and therefore Sequential indices) is:
        # Linear, LayerNorm, ReLU, Linear, LayerNorm, ReLU, Linear, LayerNorm.
        layers = [
            nn.Linear(in_channel, out_channel),
            nn.LayerNorm(out_channel),
            nn.ReLU(inplace=True),
            nn.Linear(out_channel, out_channel),
            nn.LayerNorm(out_channel),
            nn.ReLU(inplace=True),
            nn.Linear(out_channel, out_channel),
            nn.LayerNorm(out_channel),
        ]
        self.embed = nn.Sequential(*layers)
        self.apply(init_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` and return the result."""
        embedded = self.embed(x)
        return embedded

class V2XDistanceDropEdge(object):
    """Edge filter that also reports which edges were kept.

    Works like an L2-norm threshold on ``edge_attr`` and additionally returns
    the boolean keep-mask so callers can filter other per-edge tensors the
    same way.
    """

    def __init__(self, max_distance: Optional[float] = None) -> None:
        """Store the norm limit; ``None`` disables filtering."""
        self.max_distance = max_distance

    def __call__(self,
                 edge_index: torch.Tensor,
                 edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(edge_index, edge_attr, mask)``.

        ``mask`` is a boolean tensor with one entry per input edge; ``True``
        marks an edge whose ``edge_attr`` row has L2 norm strictly below
        ``max_distance``. The returned ``edge_index`` / ``edge_attr`` contain
        only those edges.

        When ``max_distance`` is ``None`` nothing is dropped: the inputs are
        returned as-is together with an all-``True`` mask, so the result
        always has three elements regardless of configuration.
        """
        row, col = edge_index

        if self.max_distance is None:
            # Keep the return arity identical to the filtering path so that
            # callers can always unpack three values.
            keep_all = torch.ones_like(row, dtype=torch.bool)
            return edge_index, edge_attr, keep_all

        mask = torch.norm(edge_attr, p=2, dim=-1) < self.max_distance
        edge_index = torch.stack([row[mask], col[mask]], dim=0)
        edge_attr = edge_attr[mask]
        return edge_index, edge_attr, mask


class IOUDropEdge:
    """Edge filter thresholding the L2 norm of ``edge_attr`` against ``iou_threshold``.

    NOTE(review): despite the name, the test applied is ``norm(edge_attr) <
    iou_threshold`` -- the same form as a distance filter. Confirm with the
    callers whether ``edge_attr`` is expected to hold IoU values here.
    """

    def __init__(self, iou_threshold: Optional[float] = None) -> None:
        """Store the threshold; ``None`` disables filtering."""
        self.iou_threshold = iou_threshold

    def __call__(self,
                 edge_index: torch.Tensor,
                 edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(edge_index, edge_attr)`` restricted to the kept edges.

        An edge is kept when the L2 norm of its ``edge_attr`` row (over the
        last dimension) is strictly below ``iou_threshold``. With a ``None``
        threshold the inputs are returned untouched.
        """
        threshold = self.iou_threshold
        if threshold is None:
            return edge_index, edge_attr

        src, dst = edge_index
        keep = edge_attr.norm(p=2, dim=-1) < threshold
        kept_index = torch.stack((src[keep], dst[keep]), dim=0)
        return kept_index, edge_attr[keep]

def intersection(g, p):
    """Return the intersection-over-union of two quadrilaterals.

    ``g`` and ``p`` are array-likes; the first 8 values (or first 4 rows of
    an (N, 2) array) of each are reshaped to four (x, y) vertices.

    Returns 0 when either polygon is invalid or when the union area is zero,
    otherwise ``intersection_area / union_area``.
    """
    quad_g = Polygon(np.asarray(g)[:8].reshape((4, 2)))
    quad_p = Polygon(np.asarray(p)[:8].reshape((4, 2)))

    # Invalid polygons (per shapely's validity check) yield no overlap.
    if not (quad_g.is_valid and quad_p.is_valid):
        return 0

    inter = quad_g.intersection(quad_p).area
    union = quad_g.area + quad_p.area - inter
    return 0 if union == 0 else inter / union

def aabox_overlap(tra1, tra2):
    """Return 1 if the axis-aligned bounding boxes of two point sets overlap, else 0.

    ``tra1`` and ``tra2`` are indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
    Boxes that merely touch count as overlapping, because separation is
    tested with strict ``<``.
    """
    xs1, xs2 = tra1[:, 0], tra2[:, 0]
    lo_x1, hi_x1 = min(xs1), max(xs1)
    lo_x2, hi_x2 = min(xs2), max(xs2)

    # Separated along x: one box lies entirely to the left of the other.
    if (lo_x1 < lo_x2 and hi_x1 < lo_x2) or (lo_x2 < lo_x1 and hi_x2 < lo_x1):
        return 0

    ys1, ys2 = tra1[:, 1], tra2[:, 1]
    lo_y1, hi_y1 = min(ys1), max(ys1)
    lo_y2, hi_y2 = min(ys2), max(ys2)

    # Separated along y: one box lies entirely below the other.
    if (lo_y1 < lo_y2 and hi_y1 < lo_y2) or (lo_y2 < lo_y1 and hi_y2 < lo_y1):
        return 0

    return 1


def cal_iou(tra1, tra2, rotated_box_iou=False):
    """Estimate the overlap between two trajectories.

    Args:
        tra1, tra2: point sets indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        rotated_box_iou: when ``False`` (the default, matching the previous
            behaviour of this function) only the cheap axis-aligned
            bounding-box test is performed and the result is 0 or 1. When
            ``True`` the IoU of the two minimum-area rotated rectangles is
            computed and returned instead.

    Returns:
        0 if the axis-aligned bounding boxes do not overlap. Otherwise 1 in
        the default mode, or the rotated-rectangle IoU when
        ``rotated_box_iou`` is set.
    """
    # Cheap rejection test first; it also short-circuits the costly path.
    if aabox_overlap(tra1, tra2) == 0:
        return 0

    # Previously an unconditional `return 1` sat here and made the IoU code
    # below unreachable. The 0/1 result stays the default so existing callers
    # see no change; the full computation is now opt-in.
    if not rotated_box_iou:
        return 1

    box1 = cv2.boxPoints(cv2.minAreaRect(np.float32(tra1)))
    box2 = cv2.boxPoints(cv2.minAreaRect(np.float32(tra2)))

    return intersection(box1, box2)

def init_weights(m: nn.Module) -> None:
    """Initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``. Modules of a type not
    listed below are left untouched.

    * ``Linear``: Xavier-uniform weight, zero bias.
    * ``Conv1d/2d/3d``: uniform weight with a Xavier-style bound computed from
      the per-group channel counts only (kernel size is not included),
      zero bias.
    * ``Embedding``: normal(0, 0.02) weight.
    * ``BatchNorm1d/2d/3d`` and ``LayerNorm``: weight set to one and bias to
      zero, each only if the parameter exists (modules built without affine
      parameters, or without a bias, are skipped instead of raising).
    * ``MultiheadAttention``: uniform / Xavier-uniform projections, zero
      biases, normal(0, 0.02) for ``bias_k`` / ``bias_v``.
    * ``LSTM``: per-gate Xavier-uniform input weights, per-gate orthogonal
      hidden weights, Xavier-uniform projection weights, zero biases except
      the second of the four ``bias_hh`` chunks, which is set to one.
    * ``GRU``: per-gate Xavier-uniform input weights, per-gate orthogonal
      hidden weights, zero biases.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        fan_in = m.in_channels / m.groups
        fan_out = m.out_channels / m.groups
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        nn.init.uniform_(m.weight, -bound, bound)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
        # weight/bias are None when the norm layer has no learnable affine
        # parameters (or no bias); nn.init would raise on None.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        if m.in_proj_weight is not None:
            # Packed q/k/v projection: bound uses embed_dim for both fans
            # rather than the packed (3 * embed_dim) row count.
            fan_in = m.embed_dim
            fan_out = m.embed_dim
            bound = (6.0 / (fan_in + fan_out)) ** 0.5
            nn.init.uniform_(m.in_proj_weight, -bound, bound)
        else:
            nn.init.xavier_uniform_(m.q_proj_weight)
            nn.init.xavier_uniform_(m.k_proj_weight)
            nn.init.xavier_uniform_(m.v_proj_weight)
        if m.in_proj_bias is not None:
            nn.init.zeros_(m.in_proj_bias)
        nn.init.xavier_uniform_(m.out_proj.weight)
        if m.out_proj.bias is not None:
            nn.init.zeros_(m.out_proj.bias)
        if m.bias_k is not None:
            nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
        if m.bias_v is not None:
            nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                # Initialise each of the four stacked gate blocks separately.
                for ih in param.chunk(4, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                for hh in param.chunk(4, 0):
                    nn.init.orthogonal_(hh)
            elif 'weight_hr' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)
                # Second gate chunk gets a bias of one (the forget gate in
                # PyTorch's i, f, g, o parameter layout).
                nn.init.ones_(param.chunk(4, 0)[1])
    elif isinstance(m, nn.GRU):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                # Initialise each of the three stacked gate blocks separately.
                for ih in param.chunk(3, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                for hh in param.chunk(3, 0):
                    nn.init.orthogonal_(hh)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)

    
def rowwise_in(a,b):
    """Test, for every row of ``a``, whether an identical row exists in ``b``.

    a - tensor of size (a0, c)
    b - tensor of size (b0, c)
    returns - boolean tensor of size a0; entry i is True when row i of ``a``
              equals at least one row of ``b`` element-for-element
    """
    num_a = a.shape[0]
    num_b = b.shape[0]
    width = a.shape[1]
    assert width == b.shape[1] , "Tensors must have same number of columns"

    # View both operands as (a0, b0, c) so every row pair is compared once.
    rows_a = a.unsqueeze(1).expand(-1, num_b, width)
    rows_b = b.unsqueeze(0).expand(num_a, -1, width)

    # A pair matches when all columns agree; a row is "in b" if any pair matches.
    pair_matches = (rows_a == rows_b).all(-1)
    return pair_matches.any(-1)

def bool_based_on_probability(probability: float = 0.5) -> bool:
    """Return True with the given probability, using the ``random`` module's global generator.

    A single uniform sample from [0, 1) is drawn and compared with
    ``probability`` using strict ``<``.
    """
    sample = random.random()
    return sample < probability

def save_output_asso(input, y_hat, metrics, idx):
    """Record one sample's trajectories, forecasts and metrics in ``global_vars.output_dict``.

    The entry is stored under the key ``str(seq_id)`` and holds plain Python
    lists / numbers (tensors are moved to CPU and converted via NumPy).

    Args:
        input: data object providing ``agent_index``, ``seq_id``, ``asso_idx``,
            ``padding_mask``, ``positions``, ``theta``, ``rotate_angles`` and ``y``.
        y_hat: forecast tensor, indexed here as ``y_hat[:, agent_index, :, :2]``.
        metrics: stored unchanged under the ``'metrics'`` key.
        idx: not used by this function.
    """
    # Only the first `horizon` time steps of the stored positions are exported.
    horizon = 50

    agent_index = input.agent_index
    seq_id = input.seq_id.squeeze().cpu().numpy()
    asso_idx = input.asso_idx.item()

    # Keep only the agent's non-padded steps.
    padding_agt = input.padding_mask[agent_index, :horizon]
    agent_positions = input.positions[agent_index, :horizon, :]
    past_trajectory = agent_positions[~padding_agt, :].squeeze().cpu().numpy()

    # NOTE(review): only a strictly positive asso_idx selects another node's
    # trajectory; 0 and negative values fall back to the agent's own --
    # confirm whether index 0 can be a valid association.
    if asso_idx > 0:
        padding_asso = input.padding_mask[asso_idx, :horizon]
        asso_positions = input.positions[asso_idx, :horizon, :]
        traj_asso = asso_positions[~padding_asso, :].squeeze().cpu().numpy()
    else:
        traj_asso = past_trajectory

    # One list of (x, y) points per entry along the first axis of the forecasts.
    forecasts = y_hat[:, agent_index, :, :2].squeeze().cpu().numpy()
    forecast_lists = [
        np.asarray(forecasts[modal, :, :], np.float64).tolist()
        for modal in range(forecasts.shape[0])
    ]

    record = {
        'asso_idx': asso_idx,
        'traj_asso': np.asarray(traj_asso, np.float64).tolist(),
        'theta_v': input.theta.item(),
        'theta_agt': input.rotate_angles[agent_index].item(),
        'past_trajectory': np.asarray(past_trajectory, np.float64).tolist(),
        'groundtruth': np.asarray(input.y[agent_index, :, :].squeeze().cpu().numpy(), np.float64).tolist(),
        'forecasted_trajectories': forecast_lists,
        'metrics': metrics
    }
    global_vars.output_dict[str(seq_id)] = record
