import torch
import torch.nn as nn

from configs.paths import project_path
from modules.object_tracking import FeatureExtractionResNet101
from modules.sthoip_transformer.sttran_gaze import STTranGaze
from modules.sthoip_transformer.sttran_gaze import STTranGazeCrossAttention
from modules.gaze_following import GazeFollowing
from common.dataset.image_processing import union_roi,mask_union_roi
import numpy as np

class STTRAN_Wrapper(nn.Module):
    """Feature backbone plus spatial-temporal transformer (STTran) bundled as one module.

    ``forward`` fills in visual features for a batch (unless they are already
    present) using a ResNet101 feature backbone, then runs the STTran variant
    selected by ``gaze_usage``.
    """

    # Size argument passed to ``mask_union_roi`` when building ``spatial_masks``.
    # NOTE(review): presumably the mask resolution — confirm in image_processing.
    spatial_mask_size = 27

    def __init__(self, cfg, modelname, gaze_usage):
        """
        :param cfg: nested config dict; reads ``cfg["MODEL"]["device"]`` plus the
            keys used by ``load_backbone`` and ``load_sttran``.
        :param modelname: name printed by ``info_model``.
        :param gaze_usage: ``"cross"`` / ``"cross_all"`` select
            ``STTranGazeCrossAttention``; any other value selects ``STTranGaze``
            (with gaze disabled when it is ``"no"``).
        """
        super().__init__()
        self.cfg = cfg
        self.device = cfg["MODEL"]["device"]
        self.modelname = modelname
        self.gaze_usage = gaze_usage
        self.feature_extractor = self.load_backbone(cfg)
        self.model = self.load_sttran(cfg, gaze_usage)

    def extract_features(self, batch):
        """
        Compute visual features for ``batch`` with the feature backbone, in place.

        Does nothing when ``batch`` already contains ``"features"``. Otherwise it
        reads ``"frames"``, ``"bboxes"``, ``"pair_idxes"`` and ``"im_idxes"`` and adds:

        - ``"features"``: backbone features for each box in ``bboxes``
        - ``"union_feats"``: backbone features for the union box of each pair
        - ``"spatial_masks"``: output of ``mask_union_roi`` for each pair, minus 0.5

        :param batch: dict of inputs; updated in place.
        :return: the same ``batch`` dict.
        """
        if "features" in batch:
            return batch

        frames = batch["frames"]
        bboxes = batch["bboxes"]
        pair_idxes = batch["pair_idxes"]
        im_idxes = batch["im_idxes"]

        # The backbone trunk runs once; its feature maps are reused for both the
        # per-box and the union-box feature extraction.
        base_feature_maps = self.feature_extractor.backbone_base(frames)
        bboxes_features, _ = self.feature_extractor(frames, bboxes.to(self.device), base_feature_maps)
        union_bboxes = union_roi(bboxes, pair_idxes, im_idxes)
        union_features, _ = self.feature_extractor(frames, union_bboxes.to(self.device), base_feature_maps)

        # Each pair row is the box of pair_idxes[:, 0] followed by the box of
        # pair_idxes[:, 1]; column 0 of a bbox row is dropped (presumably the
        # frame index — TODO confirm).
        pair_rois = (
            torch.cat((bboxes[pair_idxes[:, 0], 1:], bboxes[pair_idxes[:, 1], 1:]), 1)
            .detach()
            .cpu()
            .numpy()
        )
        # Same dtype the legacy ``torch.Tensor(...)`` constructor produced.
        masked_bboxes = torch.as_tensor(
            mask_union_roi(pair_rois, self.spatial_mask_size) - 0.5,
            dtype=torch.get_default_dtype(),
            device=self.device,
        )

        batch.update({"union_feats": union_features, "features": bboxes_features, "spatial_masks": masked_bboxes})
        return batch

    def forward(self, batch):
        """Fill in features for ``batch`` if needed, then return the STTran model output."""
        return self.model(self.extract_features(batch))

    @staticmethod
    def _separate_head_spec(cfg):
        """
        Return ``(separate_head, separate_head_name)`` lists for the model heads.

        With ``cfg["TRAINER"]["ENGINE"]["separate_head"]`` truthy: a
        ``"spatial_head"`` sized ``num_spatial_classes`` plus an ``"action_head"``;
        otherwise a single ``"interaction_head"``. ``-1`` entries are passed to the
        model unchanged (their meaning is defined by the model classes).
        """
        if cfg["TRAINER"]["ENGINE"]["separate_head"]:
            num_spatial_classes = cfg["TRAINER"]["DATASET"]["num_spatial_classes"]
            return [num_spatial_classes, -1], ["spatial_head", "action_head"]
        return [-1], ["interaction_head"]

    @staticmethod
    def _cross_attention_layout(gaze_usage, dec_layer_num):
        """
        Return ``(cross_layer_num, temporal_layer_num, log_label)`` for a cross mode.

        ``"cross"``: one cross layer, ``dec_layer_num - 1`` temporal layers.
        ``"cross_all"``: ``dec_layer_num`` cross layers, no temporal layers.

        :raises ValueError: for any other ``gaze_usage``.
        """
        if gaze_usage == "cross":
            return 1, dec_layer_num - 1, "gaze cross first layer"
        if gaze_usage == "cross_all":
            return dec_layer_num, 0, "gaze cross all layers"
        raise ValueError(f"not a cross-attention gaze_usage: {gaze_usage!r}")

    def load_sttran(self, cfg, gaze_usage):
        """
        Build the spatial-temporal transformer for ``gaze_usage`` on ``self.device``.

        - ``"cross"`` / ``"cross_all"``: ``STTranGazeCrossAttention``, with the
          cross/temporal layer split given by ``_cross_attention_layout``.
        - anything else: ``STTranGaze``, with ``no_gaze=True`` iff ``gaze_usage == "no"``.

        :return: the constructed model, moved to ``self.device``.
        """
        model_cfg = cfg["MODEL"]
        dataset_cfg = cfg["TRAINER"]["DATASET"]
        separate_head_num, separate_head_name = self._separate_head_spec(cfg)

        if gaze_usage in ("cross", "cross_all"):
            cross_layer_num, temporal_layer_num, layout_label = self._cross_attention_layout(
                gaze_usage, model_cfg["sttran_dec_layer_num"]
            )
            sttran_gaze_model = STTranGazeCrossAttention(
                num_interaction_classes=dataset_cfg["num_interaction_classes"],
                obj_class_names=dataset_cfg["object_classes"],
                spatial_layer_num=model_cfg["sttran_enc_layer_num"],
                cross_layer_num=cross_layer_num,
                temporal_layer_num=temporal_layer_num,
                dim_transformer_ffn=model_cfg["dim_transformer_ffn"],
                d_gaze=512,
                cross_sa=True,
                cross_ffn=False,
                global_token=model_cfg["global_token"],
                mlp_projection=model_cfg["mlp_projection"],
                sinusoidal_encoding=model_cfg["sinusoidal_encoding"],
                dropout=model_cfg["dropout"],
                word_vector_dir=model_cfg["sttran_word_vector_dir"],
                sliding_window=model_cfg["sttran_sliding_window"],
                separate_head=separate_head_num,
                separate_head_name=separate_head_name,
            )
            print(
                f"Spatial-temporal Transformer loaded. d_model={sttran_gaze_model.d_model}, "
                f"{layout_label}={sttran_gaze_model.d_gaze}, separate_head={sttran_gaze_model.separate_head}"
            )
        else:
            sttran_gaze_model = STTranGaze(
                num_interaction_classes=dataset_cfg["num_interaction_classes"],
                obj_class_names=dataset_cfg["object_classes"],
                enc_layer_num=model_cfg["sttran_enc_layer_num"],
                dec_layer_num=model_cfg["sttran_dec_layer_num"],
                dim_transformer_ffn=model_cfg["dim_transformer_ffn"],
                no_gaze=gaze_usage == "no",
                sinusoidal_encoding=model_cfg["sinusoidal_encoding"],
                # NOTE(review): the cross-attention branch reads this path from
                # cfg["MODEL"]; confirm which config section is intended.
                word_vector_dir=cfg["PATHS"]["sttran_word_vector_dir"],
                sliding_window=model_cfg["sttran_sliding_window"],
                separate_head=separate_head_num,
                separate_head_name=separate_head_name,
            )
            # "gaze concat" is true when gaze is used, i.e. the negation of no_gaze.
            print(
                f"Spatial-temporal Transformer loaded. d_model={sttran_gaze_model.d_model}, "
                f"gaze concat={not sttran_gaze_model.no_gaze}, separate_head={sttran_gaze_model.separate_head}"
            )
        return sttran_gaze_model.to(self.device)

    def load_backbone(self, cfg):
        """
        Load the ResNet101 feature backbone from ``cfg["PATHS"]["backbone_model_path"]``.

        The backbone is built with ``finetune=False`` and moved to ``self.device``;
        names of any parameters that still require grad are printed.

        :return: the ``FeatureExtractionResNet101`` instance.
        """
        backbone_model_path = cfg["PATHS"]["backbone_model_path"]
        feature_backbone = FeatureExtractionResNet101(
            backbone_model_path, download=True, finetune=False, finetune_layers=[]
        ).to(self.device)
        trainable_backbone_names = [
            name for name, param in feature_backbone.named_parameters() if param.requires_grad
        ]
        print(
            f"ResNet101 feature backbone loaded from {backbone_model_path}. Finetuning weights: {trainable_backbone_names}"
        )
        return feature_backbone

    def load_gaze(self, cfg):
        """
        Build the ``GazeFollowing`` module from fixed paths under ``project_path``.

        ``cfg`` is accepted but not read. Not invoked by ``__init__``.
        """
        gaze_following_module = GazeFollowing(
            weight_path=f"{project_path}/weights/detecting_attended/model_videoatttarget.pt",
            config_path=f"{project_path}/configs/gaze_following.yaml",
            device=self.device,
        )
        return gaze_following_module

    def info_model(self):
        """Print the model name and the parameter counts (millions) of backbone and STTran."""
        print(f"MODEL {self.modelname}")
        for label, module in (("FEATURE BACKBONE", self.feature_extractor), ("STTRAN", self.model)):
            total_params = sum(p.numel() for p in module.parameters())
            print("[{}] - {:.2f}M".format(label, total_params / 10 ** 6))
