from torch import nn
import torch
import torch.nn.functional as F
from detectron2.modeling.poolers import ROIPooler
from detectron2.structures.boxes import Boxes
from common import static_preds, mutable_preds, unary_preds, binary_preds, norm_x, norm_y
# Combined predicate vocabulary: the static predicates followed by the mutable ones.
# NOTE(review): assumes both names imported from `common` are lists (so `+`
# concatenates them in this order); not referenced by the classes defined below.
all_preds = static_preds + mutable_preds

class MLPClassifier(nn.Module):
  """Multi-layer perceptron that outputs an independent probability per class.

  The network is ``max(n_layers, 1)`` hidden blocks of
  ``Linear -> ReLU -> BatchNorm1d -> Dropout`` followed by a final
  ``Linear -> Sigmoid``; the sigmoid is applied element-wise, so every output
  lies in ``[0, 1]`` and classes are scored independently of each other.

  Args:
    input_dim: size of the last dimension of the input.
    latent_dim: width of every hidden block.
    output_dim: number of output classes.
    n_layers: number of hidden blocks; values below 1 still build one block.
    dropout_rate: dropout probability applied at the end of each hidden block.
  """

  def __init__(self, input_dim, latent_dim, output_dim, n_layers, dropout_rate):
    super().__init__()

    layers = []
    # The first block maps input_dim -> latent_dim, later blocks latent -> latent.
    # Layers are appended in the same order as a flat construction would, so the
    # ``net.<i>`` keys of previously saved state_dicts stay valid.
    in_features = input_dim
    for _ in range(max(n_layers, 1)):
      layers += [
          nn.Linear(in_features, latent_dim),
          nn.ReLU(),
          nn.BatchNorm1d(latent_dim),
          nn.Dropout(dropout_rate),
      ]
      in_features = latent_dim
    layers += [nn.Linear(latent_dim, output_dim), nn.Sigmoid()]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Return class probabilities for ``x``.

    Args:
      x: tensor whose last dimension is ``input_dim`` — presumably
        ``(batch, input_dim)``, since ``BatchNorm1d`` normalises dim 1.

    Returns:
      Sigmoid outputs (probabilities, not logits) with ``output_dim`` columns.
    """
    probs = self.net(x)
    return probs

class ConvInputModel(nn.Module):
    """Four-stage convolutional frame encoder.

    Every stage is ``Conv2d(kernel_size=4, stride=2, padding=1) -> ReLU ->
    BatchNorm2d``. The first stage maps 3 input channels to 16 and the others
    keep 16 channels. Each stage maps a spatial size ``S`` to ``floor(S / 2)``,
    so an ``H x W`` input becomes a 16-channel map of size about ``H/16 x W/16``.
    """

    def __init__(self):
        super().__init__()
        # Attributes are registered conv/norm alternately, stage by stage, which
        # fixes the state_dict key order and parameter order.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1)
        self.batchNorm1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=4, stride=2, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=4, stride=2, padding=1)
        self.batchNorm3 = nn.BatchNorm2d(num_features=16)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=4, stride=2, padding=1)
        self.batchNorm4 = nn.BatchNorm2d(num_features=16)

    def forward(self, img):
        """Encode ``img`` by applying conv -> ReLU -> batch-norm four times."""
        stages = (
            (self.conv1, self.batchNorm1),
            (self.conv2, self.batchNorm2),
            (self.conv3, self.batchNorm3),
            (self.conv4, self.batchNorm4),
        )
        features = img
        for conv, norm in stages:
            features = norm(F.relu(conv(features)))
        return features


class PredicateModel(nn.Module):
  """Score unary, binary and static predicates for objects in a batch of videos.

  Pipeline (see ``forward``):

  1. ``ConvInputModel`` encodes every frame in ``batched_videos``.
  2. ``ROIPooler`` crops an ``(input_dim, input_dim)`` feature patch per
     detected box. The flattened patch, concatenated with the box divided by
     ``norm_x`` / ``norm_y``, is the object feature.
  3. Three ``MLPClassifier`` heads return sigmoid probabilities:
     unary (per detection: frame feature + object feature),
     binary (per object pair: frame feature + both object features), and
     static (per (video, object) track: the final hidden states of a
     bidirectional LSTM run over the track's per-detection features).

  Args:
    latent_dim: hidden size of the LSTM and of every classifier head.
    model_layer: unused; kept for call-site compatibility.
    device: device that the per-frame box tensors are moved to.
    lstm_num_layers: number of stacked LSTM layers in the static branch.
    object_feat_dim: size of one object feature. With the defaults this is
      16 * input_dim * input_dim + 4 = 788 (16 ``ConvInputModel`` channels
      per pooled cell plus 4 normalised box coordinates).
    frame_feat_dim: size of one flattened ``ConvInputModel`` frame output;
      depends on the frame resolution.
    input_dim: side length of the square ROI-pooled patch.
    lstm_input_size: must equal ``frame_feat_dim + object_feat_dim``
      (7028 = 6240 + 788 with the defaults); also the unary head's input size.
    hidden_dim: unused; kept for call-site compatibility.
    scales: feature-map scales for ``ROIPooler``. Defaults to ``[1/16]``,
      matching the four stride-2 stages of ``ConvInputModel``.
    sampling_ratio: forwarded to ``ROIPooler``.
    pooler_type: forwarded to ``ROIPooler``.
    dropout_rate: dropout used inside every classifier head.
    num_heads: unused; kept for call-site compatibility.
  """

  def __init__(self, latent_dim, model_layer, device, lstm_num_layers=2, object_feat_dim=788,
               frame_feat_dim=6240, input_dim=7, lstm_input_size=7028, hidden_dim=24, scales=None,
               sampling_ratio=1, pooler_type="ROIAlignV2", dropout_rate=0.2, num_heads=8):
    super().__init__()
    self.device = device

    # A fresh list per instance instead of a shared mutable default argument.
    if scales is None:
      scales = [1 / 16]

    self.context_feature_extract = ConvInputModel()
    self.pooler = ROIPooler((input_dim, input_dim),
                            scales=scales,
                            sampling_ratio=sampling_ratio,
                            pooler_type=pooler_type)

    self.static_model_feature_extractor = nn.LSTM(lstm_input_size, latent_dim,
                                                  num_layers=lstm_num_layers,
                                                  bidirectional=True, batch_first=True)

    self.unary_clf_model = MLPClassifier(input_dim=lstm_input_size,
                                         output_dim=len(unary_preds),
                                         latent_dim=latent_dim,
                                         n_layers=2,
                                         dropout_rate=dropout_rate)

    self.binary_clf_model = MLPClassifier(input_dim=2 * object_feat_dim + frame_feat_dim,
                                          output_dim=len(binary_preds),
                                          latent_dim=latent_dim,
                                          n_layers=3,
                                          dropout_rate=dropout_rate)

    # forward() flattens h_n of the bidirectional LSTM, which holds
    # 2 directions * lstm_num_layers vectors of size latent_dim. Previously this
    # was hard-coded as latent_dim * 4, which only matched lstm_num_layers == 2.
    self.static_clf_model = MLPClassifier(input_dim=2 * lstm_num_layers * latent_dim,
                                          output_dim=len(static_preds),
                                          latent_dim=latent_dim,
                                          n_layers=2,
                                          dropout_rate=dropout_rate)

  def fake_forward(self, batched_videos, batched_bboxes, batched_object_ids, batched_obj_pairs,
                   batched_occured_objs, batched_video_splits):
    """Return random tensors shaped like ``forward``'s outputs, without running the model.

    Only ``batched_bboxes``, ``batched_obj_pairs`` and ``batched_occured_objs``
    are read, to size the outputs.

    Returns:
      ``(unary, binary, static)`` with shapes
      ``(len(batched_bboxes), len(unary_preds))``,
      ``(len(batched_obj_pairs), len(binary_preds))`` and
      ``(total objects in batched_occured_objs, len(static_preds))``.
    """
    unary_shape = (len(batched_bboxes), len(unary_preds))
    binary_shape = (len(batched_obj_pairs), len(binary_preds))
    static_shape = (sum(len(objs) for objs in batched_occured_objs), len(static_preds))

    # NOTE(review): sigmoid of U[0, 1) lies in [0.5, 0.731), so every fake
    # probability is >= 0.5. These tensors are also created on the default
    # device rather than self.device. Kept as-is in case callers rely on it.
    unary_pred_prob = torch.rand(unary_shape).sigmoid()
    binary_pred_prob = torch.rand(binary_shape).sigmoid()
    static_pred_prob = torch.rand(static_shape).sigmoid()

    return unary_pred_prob, binary_pred_prob, static_pred_prob

  def _pool_object_features(self, frame_features, frame_offsets, batched_object_ids, batched_bboxes):
    """Return one row per detection: flattened ROI-pooled patch + normalised box."""
    pooler_maps = []   # one frame feature map per run of same-frame detections
    pooler_boxes = []  # the matching Boxes for each entry of pooler_maps
    norm_boxes = []
    run_key, run_boxes = None, []

    for (video_id, frame_id, _obj_id), bbox in zip(batched_object_ids, batched_bboxes):
      box = (bbox['x1'], bbox['y1'], bbox['x2'], bbox['y2'])
      norm_boxes.append((bbox['x1'] / norm_x, bbox['y1'] / norm_y,
                         bbox['x2'] / norm_x, bbox['y2'] / norm_y))

      key = (video_id, frame_id)
      if key == run_key:
        run_boxes.append(box)
        continue

      # Detections of a new frame start: flush the previous run, open a new one.
      # Same-frame detections are grouped only when adjacent; a frame that
      # reappears later simply gets its feature map passed to the pooler again.
      if run_boxes:
        pooler_boxes.append(Boxes(torch.tensor(run_boxes).to(self.device)))
      pooler_maps.append(frame_features[frame_offsets[video_id] + frame_id])
      run_key, run_boxes = key, [box]

    if run_boxes:
      pooler_boxes.append(Boxes(torch.tensor(run_boxes).to(self.device)))

    # The rest of the model relies on the pooler returning rows in the order of
    # the boxes in pooler_boxes, i.e. detection order.
    pooled = self.pooler([torch.stack(pooler_maps)], pooler_boxes)
    pooled = pooled.reshape(pooled.shape[0], -1)
    norm_box_tensor = torch.tensor(norm_boxes, dtype=pooled.dtype, device=pooled.device)
    return torch.cat((pooled, norm_box_tensor), dim=1)

  def forward(self, batched_videos, batched_bboxes, batched_object_ids, batched_obj_pairs,
              batched_occured_objs, batched_video_splits):
    """Predict unary, binary and static predicate probabilities.

    Args:
      batched_videos: frames of all videos, fed directly to ``ConvInputModel``
        (presumably ``(total_frames, 3, H, W)`` — TODO confirm).
      batched_bboxes: one dict per detection with keys ``'x1'``, ``'y1'``,
        ``'x2'``, ``'y2'``.
      batched_object_ids: one ``(video_id, frame_id, obj_id)`` per detection,
        aligned with ``batched_bboxes``.
      batched_obj_pairs: ``(video_id, frame_id, (from_id, to_id))`` entries;
        both objects must be detected in that frame, otherwise ``KeyError``.
      batched_occured_objs: not used here (only ``fake_forward`` reads it).
      batched_video_splits: sequence such that ``([0] + splits)[v] + frame_id``
        indexes frame ``frame_id`` of video ``v`` in ``batched_videos``
        (presumably cumulative frame counts).

    Returns:
      ``(unary, binary, static)``:
      unary has one row per detection, in ``batched_object_ids`` order;
      binary has one row per pair in ``batched_obj_pairs`` order, or is the
      empty list ``[]`` when there are no pairs;
      static has one row per distinct ``(video_id, obj_id)``, in first-seen order.
    """
    # list() so that tuples/tensors are prepended to, not concatenated/added.
    frame_offsets = [0] + list(batched_video_splits)
    frame_features = self.context_feature_extract(batched_videos)

    def flat_frame_feature(video_id, frame_id):
      return frame_features[frame_offsets[video_id] + frame_id].reshape(-1)

    object_features = self._pool_object_features(
        frame_features, frame_offsets, batched_object_ids, batched_bboxes)

    # Per-detection input = flattened frame feature + object feature. The same
    # vectors are grouped per (video, object) track for the static branch.
    detection_features = []
    obj_feature_lookup = {}
    track_features = {}
    for row, ((video_id, frame_id, obj_id), _bbox) in enumerate(zip(batched_object_ids, batched_bboxes)):
      object_feature = object_features[row]
      detection_feature = torch.cat((flat_frame_feature(video_id, frame_id), object_feature), dim=0)
      detection_features.append(detection_feature)
      obj_feature_lookup[(video_id, frame_id, obj_id)] = object_feature
      track_features.setdefault(video_id, {}).setdefault(obj_id, []).append(detection_feature)

    # Heads run in the order static -> unary -> binary, as before, so dropout
    # draws from the RNG in the same sequence during training.

    # Static predicates: one LSTM pass per track. The stacked input is 2-D
    # (num_detections, lstm_input_size), i.e. a single unbatched sequence.
    track_states = []
    for objects in track_features.values():
      for sequence in objects.values():
        _, (h_n, _) = self.static_model_feature_extractor(torch.stack(sequence))
        track_states.append(h_n.reshape(-1))
    batched_static_preds = self.static_clf_model(torch.stack(track_states))

    # Unary predicates: one row per detection.
    batched_unary_preds = self.unary_clf_model(torch.stack(detection_features))

    # Binary predicates: frame feature + subject feature + object feature.
    binary_features = [
        torch.cat((flat_frame_feature(vid, fid),
                   obj_feature_lookup[(vid, fid, from_id)],
                   obj_feature_lookup[(vid, fid, to_id)]), dim=0)
        for vid, fid, (from_id, to_id) in batched_obj_pairs
    ]
    if not binary_features:
      # NOTE(review): a plain list (not an empty tensor) is returned here;
      # kept because callers may depend on it.
      batched_binary_preds = []
    else:
      batched_binary_preds = self.binary_clf_model(torch.stack(binary_features))

    return batched_unary_preds, batched_binary_preds, batched_static_preds
