import os 
import sys
import torch

# Make the sibling ``utils`` directory importable. Joining "../utils" onto the
# absolute path of this file and normalising it yields "<dir of this file>/utils".
util_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../utils"))
# Fail at import time if the expected directory layout is missing.
assert os.path.exists(util_path)
sys.path.append(util_path)

# Same for the sibling ``visbackbone`` directory.
visbackbone_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../visbackbone"))
assert os.path.exists(visbackbone_path)
sys.path.append(visbackbone_path)

from lib import *
from violet_model import VIOLET_Base
from model_utils import gen_mask

class VIOLET_STSG(VIOLET_Base):
    def __init__(self, config_dir, tokzr=None, max_size_frame=8):
        """Build the base model and attach a scalar scoring head.

        Args:
            config_dir: forwarded unchanged to ``VIOLET_Base.__init__``.
            tokzr: tokenizer, forwarded unchanged to ``VIOLET_Base.__init__``.
            max_size_frame: forwarded unchanged to ``VIOLET_Base.__init__``.
        """
        super().__init__(config_dir=config_dir, tokzr=tokzr, max_size_frame=max_size_frame)

        # ``T`` comes from ``from lib import *`` (presumably an alias of torch
        # -- confirm in lib). ``hidden_size`` is expected to be set by the base
        # class constructor.
        width = self.hidden_size
        head_layers = [
            T.nn.Dropout(0.1),
            T.nn.Linear(width, width * 2),
            T.nn.ReLU(inplace=True),
            T.nn.Linear(width * 2, 1),
        ]
        # Dropout -> Linear -> ReLU -> Linear, producing one value per input row.
        self.fc = T.nn.Sequential(*head_layers)

        # Zero-element parameter; tokenizer_forward reads its ``.device`` to
        # decide where to create token tensors.
        self.dummy_param = T.nn.Parameter(torch.empty(0))
        

    def violet_forward(self, video, txt, txt_masks):
        """Run the image/text encoders and cross encoder, then score one token.

        Args:
            video: 5-D tensor; dimension 1 is read as the number of frames and
                dimensions 3 and 4 as frame height and width.
            txt: 2-D tensor of token ids.
            txt_masks: mask passed to ``go_feat`` as ``mask``.

        Returns:
            ``(out, feat_img, mask_img)`` where ``out`` is ``self.fc`` applied
            to one position of the cross-encoder output, and ``feat_img`` /
            ``mask_img`` are the image outputs of ``go_feat``.
        """
        # Read both shapes first, then unpack: the unpacking doubles as a rank
        # check (5-D video, 2-D text).
        video_shape, txt_shape = video.shape, txt.shape
        _, num_frames, _, frame_h, frame_w = video_shape
        _, _ = txt_shape

        # Each frame contributes 1 + (H // 32) * (W // 32) positions
        # (presumably one frame-level token plus one per 32x32 patch -- confirm
        # against the visual backbone).
        grid_h = frame_h // 32
        grid_w = frame_w // 32
        positions_per_frame = 1 + grid_h * grid_w
        # Index of the first position after all visual positions (presumably
        # the leading text token -- confirm against go_cross).
        scored_index = positions_per_frame * num_frames

        feat_img, mask_img, feat_txt, mask_txt = self.go_feat(video, txt, mask=txt_masks)
        att, _ = self.go_cross(feat_img, mask_img, feat_txt, mask_txt)

        out = self.fc(att[:, scored_index, :])
        return out, feat_img, mask_img

    def tokenizer_forward(self, raw_txt_ls, size_txt=8,):
        """Tokenize strings into fixed-length padded id rows plus padding masks.

        Args:
            raw_txt_ls: iterable of strings to encode with ``self.tokzr``.
            size_txt: length of every returned row. Each encoded string must be
                strictly shorter than this, so at least one pad token is
                always appended.

        Returns:
            ``(token_ids, masks)``: two int64 tensors of shape
            ``(len(raw_txt_ls), size_txt)`` on the module's device. ``masks``
            is 1 where the token id differs from ``self.pad_token_id`` and 0
            elsewhere. An empty ``raw_txt_ls`` yields two ``(0, size_txt)``
            tensors.

        Raises:
            AssertionError: if an encoded string has ``size_txt`` or more tokens.
        """
        device = self.dummy_param.device
        pad_id = self.pad_token_id

        padded_rows = []
        for raw_txt in raw_txt_ls:
            txt = self.tokzr.encode(raw_txt)
            assert len(txt) < size_txt
            # No-op while the assert above is active; keeps rows bounded when
            # asserts are stripped (python -O).
            txt = txt[:size_txt-1]
            padded_rows.append(txt + [pad_id] * (size_txt - len(txt)))

        if len(padded_rows) == 0:
            # Nothing to encode: return well-shaped empty tensors.
            token_ids = torch.empty((0, size_txt), dtype=torch.long, device=device)
        else:
            # Every row has exactly size_txt entries, so one 2-D tensor can be
            # built directly on the target device.
            token_ids = torch.tensor(padded_rows, dtype=torch.long, device=device)

        # One vectorised comparison instead of a Python loop over tensor
        # elements, which would cost a device round-trip per token.
        masks = token_ids.ne(pad_id).long()

        return token_ids, masks
        
    def forward(self,
                batched_videos,
                batched_bboxes,
                batched_names,
                batched_object_ids,
                batched_unary_kws,
                batched_binary_kws,
                batched_obj_pairs,
                batched_video_splits):
        """Score objects against category names and unary keywords, then object pairs.

        Args:
            batched_videos: frames of all videos concatenated along the first
                axis; every frame exposes a 3-element ``shape``.
            batched_bboxes: one mapping with keys ``'x1'``, ``'y1'``, ``'x2'``,
                ``'y2'`` per entry of ``batched_object_ids``.
            batched_names: per-video list of object category names.
            batched_object_ids: ``(video_id, frame_id, obj_id)`` triples, one
                per bounding box. Entries of one video must be contiguous, and
                videos are presumably listed in the same order as in
                ``batched_video_splits`` -- confirm with the data loader.
            batched_unary_kws: per-video list of unary keywords.
            batched_binary_kws: per-video list of binary keywords.
            batched_obj_pairs: ``(video_id, frame_id, (from_id, to_id))`` entries.
            batched_video_splits: exclusive end frame index of each video
                inside ``batched_videos``.

        Returns:
            ``(batched_image_unary_probs, batched_image_binary_probs,
            new_selected_pairs)``. The first holds, per video,
            ``[(cate_results, cate_labels), (unary_results, unary_labels)]``
            with labels ``(object_id, category_name)`` and
            ``(object_id, frame_id, unary_keyword)``. The second holds, per
            entry, either ``[]`` or ``(sigmoid logits, binary_kws)``. The third
            is a list of ``(video_id, frame_id, (from_id, to_id))``.

        NOTE(review): Step 4 references names that are not defined in this
        file (``self.pair_proposal_model``, ``pick_top_pairs``,
        ``self.num_top_pairs``, ``batched_cropped_obj_pairs``,
        ``batched_binary_nl_features``, ``self.siglip_processor``,
        ``self.siglip_model``, ``self.device``). Unless ``lib`` or the base
        class provides them, the method cannot run to completion.
        """

        def tokenize_keywords(keywords):
            # Map each keyword to its (token ids, mask) pair; {} for no keywords.
            if len(keywords) == 0:
                return {}
            tokens, masks = self.tokenizer_forward(keywords)
            return dict(zip(keywords, zip(tokens, masks)))

        # Separate the concatenated frames into one slice per video.
        new_batched_videos = []
        current_frame_id = 0
        for split_frame_id in batched_video_splits:
            new_batched_videos.append(batched_videos[current_frame_id: split_frame_id])
            current_frame_id = split_frame_id

        # NOTE(review): frames are read here as (height, width, _), whereas
        # violet_forward unpacks the trailing video dims as (_, H, W).
        # Confirm which layout gen_mask and the backbone expect.
        height, width, _ = batched_videos[0].shape

        # Step 1: tokenize the category names and keywords of every video.
        batched_obj_name_tokens = []
        batched_unary_tokens = []
        batched_binary_tokens = []
        for object_names, unary_kws, binary_kws in \
            zip(batched_names, batched_unary_kws, batched_binary_kws):
            batched_obj_name_tokens.append(tokenize_keywords(object_names))
            batched_unary_tokens.append(tokenize_keywords(unary_kws))
            # NOTE(review): binary tokens are computed but not consumed below.
            batched_binary_tokens.append(tokenize_keywords(binary_kws))

        # Step 2: mask every frame with its object's bounding box and group
        # the masked frames per video as {obj_id: {frame_id: masked_frame}}.
        batched_frame_bboxes = {}
        current_object_traces = {}
        batched_object_traces = []

        current_vid = -1
        # gen_mask comes from model_utils; with no boxes it presumably yields
        # an all-background mask -- confirm.
        empty_mask = torch.tensor(gen_mask(height, width, []))

        for (video_id, frame_id, obj_id), bbox in zip(batched_object_ids, batched_bboxes):

            bx1, by1, bx2, by2 = bbox['x1'], bbox['y1'], bbox['x2'], bbox['y2']
            assert by2 > by1
            assert bx2 > bx1

            if video_id != current_vid:
                # Close the previous video's traces BEFORE touching the new
                # video's entries, so objects never leak across videos.
                # Assumes no video contains 0 objects.
                if current_vid != -1:
                    batched_object_traces.append(current_object_traces)
                    current_object_traces = {}
                current_vid = video_id

            frame = new_batched_videos[video_id][frame_id]
            box_mask = torch.tensor(gen_mask(height, width, [(bx1, by1, bx2, by2)])).to(frame.device)

            current_object_traces.setdefault(obj_id, {})[frame_id] = box_mask * frame
            batched_frame_bboxes[video_id, frame_id, obj_id] = (bx1, by1, bx2, by2)

        if not len(current_object_traces) == 0:
            batched_object_traces.append(current_object_traces)

        # Step 3: get the similarity for single objects
        batched_image_unary_probs = []
        batched_obj_features = []

        for video, object_traces, unary_tokens, obj_name_tokens in \
            zip(new_batched_videos, batched_object_traces, batched_unary_tokens, batched_obj_name_tokens):
            current_obj_features = {}

            # TODO: Add padding for videos when batch size > 1
            video_len = len(video)

            # Full-length trace per object: frames where the object is absent
            # are filled with the empty-mask version of the frame.
            all_obj_masks = {}
            # Single-frame clips (and their labels) for frames where the
            # object is actually present.
            timed_obj_images = []
            timed_obj_labels = []

            for object_id, object_trace in object_traces.items():
                padded_trace = []
                for frame_id in range(video_len):
                    if frame_id in object_trace:
                        masked_frame = object_trace[frame_id]
                        timed_obj_images.append(masked_frame.unsqueeze(dim=0))
                        timed_obj_labels.append((object_id, frame_id))
                    else:
                        frame = video[frame_id]
                        # Move the mask to the frame's device, as Step 2 does.
                        masked_frame = empty_mask.to(frame.device) * frame
                    padded_trace.append(masked_frame)
                all_obj_masks[object_id] = torch.stack(padded_trace)

            # For object names: one row per (object, category name) pair.
            input_obj_trace_masks = []
            input_cate_txts = []
            input_txt_masks = []
            cate_labels = []

            for object_id, obj_trace_masks in all_obj_masks.items():
                for cate_name, (cate_txt, cate_mask) in obj_name_tokens.items():
                    input_obj_trace_masks.append(obj_trace_masks)
                    input_cate_txts.append(cate_txt)
                    input_txt_masks.append(cate_mask)
                    cate_labels.append((object_id, cate_name))

            if len(cate_labels) == 0:
                # Step 1 allows a video without category names; torch.stack
                # would raise on the empty lists, so mirror the unary branch.
                cate_results = torch.tensor([])
            else:
                cate_results, _, _ = self.violet_forward(
                    video=torch.stack(input_obj_trace_masks).to(torch.float16),
                    txt=torch.stack(input_cate_txts),
                    txt_masks=torch.stack(input_txt_masks))

            # For unary predicates: one row per (present object frame, unary
            # keyword) pair. The image is repeated per keyword so the video
            # and text batches have the same number of rows, as in the
            # category branch above.
            input_unary_images = []
            input_unary_txts = []
            input_unary_masks = []
            input_unary_labels = []
            for obj_image, (object_id, frame_id) in zip(timed_obj_images, timed_obj_labels):
                for unary_kw, (unary_txt, unary_mask) in unary_tokens.items():
                    input_unary_images.append(obj_image)
                    input_unary_txts.append(unary_txt)
                    input_unary_masks.append(unary_mask)
                    # Labelled by keyword string, parallel to cate_labels.
                    # NOTE(review): confirm consumers expect the keyword here.
                    input_unary_labels.append((object_id, frame_id, unary_kw))

            if len(input_unary_labels) == 0:
                unary_results = torch.tensor([])
            else:
                # NOTE(review): the category branch casts its video input to
                # float16 and this branch does not -- confirm the intended dtype.
                unary_results, feat_img, mask_img = self.violet_forward(
                    video=torch.stack(input_unary_images),
                    txt=torch.stack(input_unary_txts),
                    txt_masks=torch.stack(input_unary_masks))
                for object_id, frame_id, _ in input_unary_labels:
                    # NOTE(review): this stores the image features of the WHOLE
                    # unary batch under every (object, frame) key; a per-row
                    # slice is probably intended -- confirm the layout of
                    # feat_img before changing.
                    frame_features = current_obj_features.setdefault(object_id, {})
                    if frame_id not in frame_features:
                        frame_features[frame_id] = (feat_img, mask_img)

            batched_obj_features.append(current_obj_features)
            batched_image_unary_probs.append([(cate_results, cate_labels), (unary_results, input_unary_labels)])

        # Step 4: get the similarity for object pairs
        # A list, because masked pair frames are appended below.
        batched_obj_pairs_frames = []
        batched_obj_pairs_labels = []
        sub_feats = []
        obj_feats = []
        candidate_obj_pairs = []
        for (vid, fid, (from_id, to_id)) in batched_obj_pairs:

            # NOTE(review): the stored value is a (feat_img, mask_img) tuple, so
            # torch.stack over a list holding that tuple will likely fail; the
            # expected input of pair_proposal_model is not visible here.
            sub_feat = torch.stack([batched_obj_features[vid][from_id][fid]])
            obj_feat = torch.stack([batched_obj_features[vid][to_id][fid]])

            sub_feats.append(sub_feat)
            obj_feats.append(obj_feat)
            candidate_obj_pairs.append((from_id, to_id))

        # Forward pass through the Pair Proposal Network
        if len(batched_obj_pairs) == 0:
            selected_pairs = []
        else:
            pred_outputs = self.pair_proposal_model(torch.stack(sub_feats), torch.stack(obj_feats))

            # get top pairs
            # TODO: add this into the loss pipeline
            selected_pairs = pick_top_pairs(pred_outputs, self.num_top_pairs)

        new_selected_pairs = []
        for pair_id in selected_pairs:
            # NOTE(review): here the candidate ids are used as indices into
            # batched_object_ids, while the loop above used them as object ids
            # keying batched_obj_features -- confirm which one is intended.
            gl_from_id, gl_to_id = candidate_obj_pairs[pair_id]
            vid, fid, from_id = batched_object_ids[gl_from_id]
            tvid, tfid, to_id = batched_object_ids[gl_to_id]

            assert vid == tvid and fid == tfid
            new_selected_pairs.append((vid, fid, (from_id, to_id)))

            fbx1, fby1, fbx2, fby2 = batched_frame_bboxes[(vid, fid, from_id)]
            tbx1, tby1, tbx2, tby2 = batched_frame_bboxes[(vid, fid, to_id)]

            # Mask the frame this pair belongs to, not whichever frame/video an
            # earlier loop happened to visit last.
            frame = new_batched_videos[vid][fid]
            pair_mask = torch.tensor(gen_mask(height, width, [(fbx1, fby1, fbx2, fby2), (tbx1, tby1, tbx2, tby2)])).to(frame.device)

            batched_obj_pairs_frames.append(pair_mask * frame)
            batched_obj_pairs_labels.append((vid, fid, from_id, to_id))

        # NOTE(review): everything below relies on names that are not defined
        # in this file (see the docstring) and does not consume
        # batched_obj_pairs_frames or batched_binary_tokens built above.
        batched_image_binary_probs = []
        if len(batched_cropped_obj_pairs) == 0:
            batched_image_binary_probs.append([])
        else:
            for ct, binary_nl_features in enumerate(batched_binary_nl_features):

                if len(binary_nl_features) == 0:
                    batched_image_binary_probs.append([])
                    continue

                binary_kws = batched_binary_kws[ct]

                cropped_obj_pairs = batched_cropped_obj_pairs[ct]
                inputs = self.siglip_processor(images=cropped_obj_pairs, return_tensors="pt")
                inputs = inputs.to(self.device)

                # L2-normalise both sides so the matmul is a cosine similarity.
                obj_features = self.siglip_model.get_image_features(**inputs)
                obj_siglip_features = obj_features / obj_features.norm(p=2, dim=-1, keepdim=True)
                binary_nl_features = binary_nl_features / binary_nl_features.norm(p=2, dim=-1, keepdim=True)

                logit_scale = self.siglip_model.logit_scale
                logit_bias = self.siglip_model.logit_bias
                binary_logits_per_text = torch.matmul(binary_nl_features, obj_siglip_features.t()) * logit_scale.exp() + logit_bias

                batched_image_binary_probs.append((binary_logits_per_text.sigmoid(), binary_kws))

        return batched_image_unary_probs, batched_image_binary_probs, new_selected_pairs
        


        
    def fake_forward(self, batched_videos, batched_bboxes, batched_object_ids, batched_obj_pairs, batched_occured_objs,
                    batched_video_splits):
        """Return random stand-in probabilities with the shapes of real predictions.

        ``batched_videos``, ``batched_object_ids`` and ``batched_video_splits``
        are accepted but not read.

        Returns:
            ``(unary, binary, static)`` tensors. Row counts are the number of
            bounding boxes, the number of object pairs, and the total number
            of occurred objects. Column counts are the lengths of
            ``unary_preds``, ``binary_preds`` and ``static_preds``. Values are
            ``sigmoid`` of uniform samples from ``torch.rand``.
        """
        # NOTE(review): unary_preds / binary_preds / static_preds are not
        # defined in this file; presumably supplied by ``from lib import *``.
        num_unary_classes = len(unary_preds)
        num_binary_classes = len(binary_preds)
        num_static_classes = len(static_preds)

        num_boxes = len(batched_bboxes)
        num_pairs = len(batched_obj_pairs)
        num_occurred = sum(len(objs) for objs in batched_occured_objs)

        # Sampling order (unary, binary, static) is kept so a seeded RNG
        # produces the same numbers.
        unary_pred_prob = torch.sigmoid(torch.rand(num_boxes, num_unary_classes))
        binary_pred_prob = torch.sigmoid(torch.rand(num_pairs, num_binary_classes))
        static_pred_prob = torch.sigmoid(torch.rand(num_occurred, num_static_classes))

        return unary_pred_prob, binary_pred_prob, static_pred_prob
