import os 
import sys
import torch
from PIL import Image
import torchvision

# Make the sibling ``utils`` and ``visbackbone`` directories importable by
# appending them to ``sys.path``; both must exist next to this file.
_this_dir = os.path.dirname(os.path.abspath(__file__))

util_path = os.path.join(_this_dir, "utils")
assert os.path.exists(util_path)
sys.path.append(util_path)

visbackbone_path = os.path.join(_this_dir, "visbackbone")
assert os.path.exists(visbackbone_path)
sys.path.append(visbackbone_path)

from lib import *
from violet_model import VIOLET_Base
from model_utils import gen_mask, mask_image

def save_tensor_as_img(img, file_name="test0.png"):
    """Write an image tensor to disk for debugging.

    The tensor is cast to ``torch.uint8``, converted to a PIL image with
    torchvision's ``to_pil_image`` and saved under ``file_name``. The debug
    directory prefix is the empty string, so a relative ``file_name`` is
    resolved against the current working directory.
    """
    to_pil = torchvision.transforms.functional.to_pil_image
    target = os.path.join("", file_name)
    as_bytes = img.to(torch.uint8)
    to_pil(as_bytes, mode=None).save(target)
    
class VIOLET_STSG(VIOLET_Base):
    def __init__(self, config_dir, tokzr=None, max_size_frame=8):
        """Build the base model and attach a scalar scoring head.

        All three arguments are forwarded unchanged to ``VIOLET_Base``.
        ``self.fc`` maps a ``hidden_size`` feature to a single score through
        Dropout(0.1) -> Linear -> ReLU -> Linear.
        """
        super().__init__(config_dir=config_dir, tokzr=tokzr, max_size_frame=max_size_frame)

        head_layers = [
            T.nn.Dropout(0.1),
            T.nn.Linear(self.hidden_size, self.hidden_size*2),
            T.nn.ReLU(inplace=True),
            T.nn.Linear(self.hidden_size*2, 1),
        ]
        self.fc = T.nn.Sequential(*head_layers)

        # Zero-element parameter; other methods read its ``.device`` to find
        # out where the module currently lives.
        self.dummy_param = T.nn.Parameter(torch.empty(0))
        

    def violet_encode(self, video, txt, txt_masks):
        """Encode the video and text inputs independently.

        Either input may be empty (``len(...) == 0``); its encoder is then
        skipped and an empty list is returned in its place.

        Returns:
            ``(feat_img, mask_img, feat_txt, txt_masks, _h, _w, _T)``.
            ``_h`` / ``_w`` are the frame height / width integer-divided by 32
            and ``_T`` is the second dimension of ``video``; all three are
            ``None`` when ``video`` is empty. ``txt_masks`` is passed through
            untouched.
        """
        if len(video) == 0:
            feat_img, mask_img = [], []
            _h = _w = _T = None
        else:
            # The 5-way unpack also rejects any video tensor that is not 5-D.
            _, _T, _, frame_h, frame_w = video.shape
            _h = frame_h//32
            _w = frame_w//32
            feat_img, mask_img = self.enc_img(video)

        if len(txt) == 0:
            feat_txt = []
        else:
            # The 2-way unpack is kept only as a rank check on the token tensor.
            _, _ = txt.shape
            feat_txt = self.enc_txt(txt, mask_txt=txt_masks)

        return feat_img, mask_img, feat_txt, txt_masks, _h, _w, _T
    
    def violet_decode(self, feat_img, mask_img, feat_txt, mask_txt, _h, _w, _T):
        """Fuse image and text features and score every (image, text) row.

        The features go through ``self.go_cross``; the fused sequence is then
        read at position ``(1 + _h*_w) * _T`` and passed to ``self.fc``.
        NOTE(review): that position is presumably the first text token that
        follows the ``_T`` frames of ``1 + _h*_w`` image tokens - confirm
        against ``go_cross``.
        """
        fused, _ = self.go_cross(feat_img, mask_img, feat_txt, mask_txt)
        tokens_per_frame = 1 + _h*_w
        read_pos = tokens_per_frame*_T
        return self.fc(fused[:, read_pos, :])
    
    def tokenizer_forward(self, raw_txt_ls, size_txt=20,):
        """Tokenize a list of strings into fixed-length id and mask tensors.

        Each string is encoded with ``self.tokzr.encode`` and right-padded
        with ``self.pad_token_id`` to exactly ``size_txt`` entries.

        Args:
            raw_txt_ls: non-empty iterable of strings (``torch.stack`` raises
                on an empty list).
            size_txt: padded length; every encoding must be strictly shorter.

        Returns:
            ``(tokens, masks)``, both of shape ``(len(raw_txt_ls), size_txt)``
            on the module's device. ``masks`` is 1 where the token differs
            from ``self.pad_token_id`` and 0 elsewhere.

        Raises:
            AssertionError: if an encoding has ``size_txt`` or more tokens.
        """
        device = self.dummy_param.device
        pad_id = self.pad_token_id

        token_rows = []
        for raw_txt in raw_txt_ls:
            txt = self.tokzr.encode(raw_txt)
            assert len(txt) < size_txt
            # Only matters when asserts are disabled (-O): keeps at least one
            # padding slot at the end of the row.
            txt = txt[:size_txt-1]
            padded = txt + [pad_id]*(size_txt-len(txt))
            token_rows.append(torch.tensor(padded, device=device))

        tokens = torch.stack(token_rows)
        # One vectorized comparison instead of a Python loop over the tensor's
        # elements, which converts a tensor to bool once per token.
        masks = (tokens != pad_id).long()
        return tokens, masks
        
    def forward(self,
                batched_videos,
                batched_bboxes,
                batched_names,
                batched_object_ids,
                batched_unary_kws,
                batched_binary_kws,
                batched_obj_pairs,
                batched_video_splits):
        """Score object categories, unary keywords and binary keywords.

        Args:
            batched_videos: flat sequence of frames of all videos; each frame
                has a 3-element shape whose last two entries are height and
                width.
            batched_bboxes: one dict with keys ``'x1'``, ``'y1'``, ``'x2'``,
                ``'y2'`` per entry of ``batched_object_ids``.
            batched_names: per video, the candidate object category names.
            batched_object_ids: ``(video_id, frame_id, obj_id)`` triples;
                ``frame_id`` indexes into that video's own frames. Entries of
                one video must be contiguous and every video must contribute
                at least one object, otherwise the per-video lists zipped
                below are misaligned.
            batched_unary_kws: per video, the unary keywords.
            batched_binary_kws: per video, the binary keywords.
            batched_obj_pairs: per video, per frame, an iterable of
                ``(from_id, to_id, rel)``; ``rel`` is ignored.
            batched_video_splits: end offset (exclusive) of every video inside
                ``batched_videos``.

        Returns:
            ``(batched_image_unary_probs, batched_image_binary_probs,
            selected_pairs)``:

            * ``batched_image_unary_probs`` - per processed video
              ``[(cate_results, cate_labels), (unary_results, unary_labels)]``.
              ``cate_results`` is a list of 0-dim tensors: per object, the
              softmax over category names of the frame-averaged score, with
              ``cate_labels`` the matching ``(obj_id, cate_name)``.
              ``unary_results`` is a sigmoid tensor (or ``[]``) with labels
              ``(unary_name, object_id, frame_id)``.
            * ``batched_image_binary_probs`` - per processed video
              ``(binary_results, binary_labels)`` with labels
              ``(vid, fid, from_id, to_id, binary_name)``. When there are no
              selected pairs an extra ``[]`` placeholder precedes each entry.
            * ``selected_pairs`` - list of unique ``(vid, fid, (from_id, to_id))``.

        Raises:
            AssertionError: on a bounding box with non-positive width or
                height, or when the encoders disagree on feature-map size.
        """

        def tokenize_keywords(keywords):
            # Map every keyword to its (token_ids, mask) pair; nothing to
            # tokenize yields an empty dict.
            if len(keywords) == 0:
                return {}
            tokens, masks = self.tokenizer_forward(keywords)
            return {kw: (tok, msk) for kw, tok, msk in zip(keywords, tokens, masks)}

        # Separate the flat frame sequence into one slice per video.
        new_batched_videos = []
        start = 0
        for end in batched_video_splits:
            new_batched_videos.append(batched_videos[start:end])
            start = end

        _, height, width = batched_videos[0].shape

        # Step 1: tokenize the object names and keywords of every video.
        batched_obj_name_tokens = []
        batched_unary_tokens = []
        batched_binary_tokens = []
        for object_names, unary_kws, binary_kws in \
                zip(batched_names, batched_unary_kws, batched_binary_kws):
            batched_obj_name_tokens.append(tokenize_keywords(object_names))
            batched_unary_tokens.append(tokenize_keywords(unary_kws))
            batched_binary_tokens.append(tokenize_keywords(binary_kws))

        # Step 2: mask every frame down to each object's bounding box and
        # group the masked frames as video -> object -> frame.
        batched_frame_bboxes = {}
        batched_object_traces = []
        current_object_traces = {}
        current_vid = -1

        for (video_id, frame_id, obj_id), bbox in zip(batched_object_ids, batched_bboxes):

            bx1, by1, bx2, by2 = bbox['x1'], bbox['y1'], bbox['x2'], bbox['y2']
            assert by2 > by1
            assert bx2 > bx1

            if video_id != current_vid:
                # A new video starts: flush the finished one *before* touching
                # the per-object dict, so this entry lands in the fresh dict.
                if current_object_traces:
                    batched_object_traces.append(current_object_traces)
                    current_object_traces = {}
                current_vid = video_id

            img = new_batched_videos[video_id][frame_id]
            box_mask = torch.tensor(gen_mask(height, width, [(bx1, by1, bx2, by2)])).to(img.device)
            current_object_traces.setdefault(obj_id, {})[frame_id] = box_mask * img
            batched_frame_bboxes[video_id, frame_id, obj_id] = (bx1, by1, bx2, by2)

        if current_object_traces:
            batched_object_traces.append(current_object_traces)

        # Pair proposals: every distinct ground-truth (from, to) pair per frame.
        selected_pairs = set()
        for vid, gt_obj_pairs in enumerate(batched_obj_pairs):
            for fid, frame_rels in enumerate(gt_obj_pairs):
                for (from_id, to_id, _rel) in frame_rels:
                    selected_pairs.add((vid, fid, (from_id, to_id)))
        selected_pairs = list(selected_pairs)

        # Step 3: score every (image, text) combination, one video at a time.
        batched_image_unary_probs = []
        # Created once for the whole batch: re-creating it per video dropped
        # the earlier videos and left it undefined when no video is processed.
        batched_image_binary_probs = []

        for video, object_traces, unary_tokens, obj_name_tokens, binary_tokens in \
                zip(new_batched_videos, batched_object_traces, batched_unary_tokens,
                    batched_obj_name_tokens, batched_binary_tokens):

            combined_video_features = []
            combined_video_mask_features = []
            combined_text_features = []
            combined_text_mask_features = []
            combined_labels = []

            # TODO: Add padding for videos when batch size > 1
            # Collect, object by object and in frame order, the frames in
            # which the object actually appears. Each image gets a leading
            # axis of size 1, so the stacked encoder input is 5-D.
            video_len = len(video)
            all_timed_obj_images = []
            all_timed_obj_images_labels = []
            for object_id, object_trace in object_traces.items():
                for frame_id in range(video_len):
                    if frame_id in object_trace:
                        all_timed_obj_images.append(object_trace[frame_id].unsqueeze(dim=0))
                        all_timed_obj_images_labels.append((object_id, frame_id))
            all_timed_obj_images = torch.stack(all_timed_obj_images)

            # Object category names against every object image.
            cate_names = list(obj_name_tokens)
            input_cate_txts = torch.stack([obj_name_tokens[name][0] for name in cate_names])
            input_txt_masks = torch.stack([obj_name_tokens[name][1] for name in cate_names])

            cate_feat_img, cate_mask_img, cate_feat_txt, cate_mask_txt, \
                cate_h, cate_w, cate_T = self.violet_encode(video=all_timed_obj_images,
                                                            txt=input_cate_txts,
                                                            txt_masks=input_txt_masks)

            for cate_name, txt_feat, txt_mask in zip(cate_names, cate_feat_txt, cate_mask_txt):
                for (object_id, frame_id), img_feat, img_mask in \
                        zip(all_timed_obj_images_labels, cate_feat_img, cate_mask_img):
                    combined_video_features.append(img_feat)
                    combined_video_mask_features.append(img_mask)
                    combined_text_features.append(txt_feat)
                    combined_text_mask_features.append(txt_mask)
                    combined_labels.append(('cate', cate_name, object_id, frame_id))

            # Unary keywords reuse the object image features; only the text
            # side is encoded here.
            unary_names = list(unary_tokens)
            input_unary_txts = [unary_tokens[name][0] for name in unary_names]
            input_unary_masks = [unary_tokens[name][1] for name in unary_names]
            if len(input_unary_txts) != 0:
                input_unary_txts = torch.stack(input_unary_txts)
                input_unary_masks = torch.stack(input_unary_masks)

            _, _, unary_feat_txt, unary_mask_txt, \
                unary_h, unary_w, unary_T = self.violet_encode(video=[],
                                                               txt=input_unary_txts,
                                                               txt_masks=input_unary_masks)

            for unary_name, txt_feat, txt_mask in zip(unary_names, unary_feat_txt, unary_mask_txt):
                for (object_id, frame_id), img_feat, img_mask in \
                        zip(all_timed_obj_images_labels, cate_feat_img, cate_mask_img):
                    combined_video_features.append(img_feat)
                    combined_video_mask_features.append(img_mask)
                    combined_text_features.append(txt_feat)
                    combined_text_mask_features.append(txt_mask)
                    combined_labels.append(('unary', unary_name, object_id, frame_id))

            # Binary keywords against frames masked to the union of two boxes.
            if len(selected_pairs) == 0:
                # Placeholder kept so the returned list has the same layout as
                # before for callers that index into it.
                # NOTE(review): this yields two entries for the video - confirm
                # that callers really expect the leading [].
                batched_image_binary_probs.append([])
                binary_h, binary_w, binary_T = None, None, None
            else:
                # NOTE(review): pairs of *all* videos are scored in every
                # iteration; restrict to the current video if that is the
                # intent for batch size > 1.
                batched_obj_pairs_frames = []
                for vid, fid, (from_id, to_id) in selected_pairs:
                    # Index with the pair's own video id, not the variable
                    # left over from the bounding-box loop above.
                    img = new_batched_videos[vid][fid]
                    from_box = batched_frame_bboxes[(vid, fid, from_id)]
                    to_box = batched_frame_bboxes[(vid, fid, to_id)]
                    pair_mask = torch.tensor(gen_mask(height, width, [from_box, to_box])).to(img.device)
                    batched_obj_pairs_frames.append((pair_mask * img).unsqueeze(dim=0))

                binary_names = list(binary_tokens)
                input_binary_txts = [binary_tokens[name][0] for name in binary_names]
                input_binary_txt_masks = [binary_tokens[name][1] for name in binary_names]
                if len(input_binary_txts) != 0:
                    input_binary_txts = torch.stack(input_binary_txts)
                    input_binary_txt_masks = torch.stack(input_binary_txt_masks)

                binary_feat_img, binary_mask_img, binary_feat_txt, binary_mask_txt, \
                    binary_h, binary_w, binary_T = self.violet_encode(video=torch.stack(batched_obj_pairs_frames),
                                                                      txt=input_binary_txts,
                                                                      txt_masks=input_binary_txt_masks)

                for binary_name, txt_feat, txt_mask in zip(binary_names, binary_feat_txt, binary_mask_txt):
                    for (vid, fid, (from_id, to_id)), img_feat, img_mask in \
                            zip(selected_pairs, binary_feat_img, binary_mask_img):
                        combined_video_features.append(img_feat)
                        combined_video_mask_features.append(img_mask)
                        combined_text_features.append(txt_feat)
                        combined_text_mask_features.append(txt_mask)
                        combined_labels.append(('binary', vid, fid, from_id, to_id, binary_name))

            # A single decode call reads one fixed sequence position, so all
            # encoders must agree on the feature-map size.
            assert binary_h is None or binary_h == cate_h
            assert binary_w is None or binary_w == cate_w
            assert binary_T is None or binary_T == cate_T

            assert unary_h is None or cate_h == unary_h
            assert unary_w is None or cate_w == unary_w
            assert unary_T is None or cate_T == unary_T

            out = self.violet_decode(torch.stack(combined_video_features),
                                     torch.stack(combined_video_mask_features),
                                     torch.stack(combined_text_features),
                                     torch.stack(combined_text_mask_features), cate_h, cate_w, cate_T)
            out = out.reshape(-1)

            # Route every score back to its group using the recorded labels.
            cate_scores = {}
            unary_results, unary_labels = [], []
            binary_results, binary_labels = [], []
            for label, prob in zip(combined_labels, out):
                kind = label[0]
                if kind == 'cate':
                    cate_name, object_id, frame_id = label[1:]
                    cate_scores.setdefault(object_id, {}).setdefault(cate_name, []).append(prob)
                elif kind == 'unary':
                    unary_results.append(prob)
                    unary_labels.append(label[1:])
                elif kind == 'binary':
                    binary_results.append(prob)
                    binary_labels.append(label[1:])

            # Per object: average each category over its frames, then softmax
            # across the categories.
            cate_results = []
            cate_labels = []
            for obj_id, cate_info in cate_scores.items():
                current_obj_preds = []
                for cate_name, cat_prob_ls in cate_info.items():
                    current_obj_preds.append(torch.mean(torch.stack(cat_prob_ls)))
                    cate_labels.append((obj_id, cate_name))
                # Extending with a 1-D tensor adds its 0-dim elements one by one.
                cate_results.extend(torch.softmax(torch.stack(current_obj_preds), dim=0))

            if len(unary_results) != 0:
                unary_results = torch.sigmoid(torch.stack(unary_results))
            if len(binary_results) != 0:
                binary_results = torch.sigmoid(torch.stack(binary_results))

            batched_image_unary_probs.append([(cate_results, cate_labels), (unary_results, unary_labels)])
            batched_image_binary_probs.append((binary_results, binary_labels))

        return batched_image_unary_probs, batched_image_binary_probs, selected_pairs

        
    def fake_forward(self, batched_videos, batched_bboxes, batched_object_ids, batched_obj_pairs, batched_occured_objs,
                     batched_video_splits):
        """Return random placeholder probabilities instead of running the model.

        Three tensors are produced by applying a sigmoid to uniform random
        numbers: one row per bounding box, one row per entry of
        ``batched_obj_pairs`` and one row per occurred object. The remaining
        arguments are accepted but not used.

        NOTE(review): ``unary_preds``, ``binary_preds`` and ``static_preds``
        are not defined in this file; presumably they are supplied by
        ``from lib import *`` - confirm.
        """
        n_unary_classes = len(unary_preds)
        n_binary_classes = len(binary_preds)
        n_static_classes = len(static_preds)

        n_unary_rows = len(batched_bboxes)
        n_binary_rows = len(batched_obj_pairs)
        n_static_rows = sum(len(objs) for objs in batched_occured_objs)

        # Draw in the same order as the tensors are returned.
        unary_pred_prob = torch.sigmoid(torch.rand((n_unary_rows, n_unary_classes)))
        binary_pred_prob = torch.sigmoid(torch.rand((n_binary_rows, n_binary_classes)))
        static_pred_prob = torch.sigmoid(torch.rand((n_static_rows, n_static_classes)))

        return unary_pred_prob, binary_pred_prob, static_pred_prob
