from video_chatgpt.video_conversation import conv_templates, SeparatorStyle
from video_chatgpt.model.utils import KeywordsStoppingCriteria
import torch
from einops import rearrange
import numpy as np
import torch.nn.functional as F

import os               
import matplotlib.pyplot as plt
from video_chatgpt.utils import save_as_gif
from PIL import Image
import torchvision.transforms as transforms
to_pil = transforms.ToPILImage()
from scipy.ndimage import binary_erosion

# Define constants
# Special tokens and prompt fragments used when building the text prompt in
# video_chatgpt_infer_batch.
# DEFAULT_VIDEO_TOKEN is not referenced in the visible part of this file.
DEFAULT_VIDEO_TOKEN = "<video>"
# Placeholder repeated `video_token_len` times inside the prompt.
DEFAULT_VIDEO_PATCH_TOKEN = "<vid_patch>"
# Markers wrapped around the patch placeholders when
# `vision_config.use_vid_start_end` is set.
DEFAULT_VID_START_TOKEN = "<vid_start>"
DEFAULT_VID_END_TOKEN = "<vid_end>"
# Prefix inserted before the (optional) audio transcript in the prompt.
DEFAULT_TRANSCRIPT_START = "The noisy audio transcript of this video is:"

# Shared resize transform applied to frames before they are rendered as GIFs
# in save_superpixel_videos.
downsize_to_288 = transforms.Resize(288)

def get_spatio_temporal_features_torch_batch(features, 
                                            temp_pool_type=None, 
                                            semantic_tokens=None, 
                                            frame_length=100, 
                                            semantic_replace=True):
    """
    Build the spatio-temporal token sequence for a batch of videos.

    Parameters:
    features (torch.Tensor): Frame features of shape (b, t, s, c), unpacked
        below as batch, frames, spatial tokens and channels.
    temp_pool_type (str or None): How temporal tokens are sub-sampled when
        semantic tokens replace part of them. 'random' picks a random sorted
        subset of the frame_length slots; any other value picks evenly spaced
        slots.
    semantic_tokens (torch.Tensor or None): Optional extra tokens of shape
        (b, k, c) inserted between the temporal and the spatial tokens.
    frame_length (int): Number of temporal slots. Temporal tokens are
        zero-padded up to this length when t < frame_length. Nothing is
        truncated when t > frame_length, but the sub-sampling only ever
        indexes the first frame_length slots.
    semantic_replace (bool): If True, only frame_length - k temporal tokens
        are kept so that temporal + semantic tokens fill frame_length slots.
        If False, all temporal tokens are kept and the semantic tokens are
        simply inserted.

    Returns:
    torch.Tensor: Half-precision tensor of shape (b, n, c) holding
    [temporal tokens, (semantic tokens), spatial tokens] along dim 1.
    """

    # Extract the dimensions of the features
    b, t, s, c = features.shape

    # One token per frame: average over the spatial axis (dim=2) -> (b, t, c)
    temporal_tokens = torch.mean(features, dim=2)

    # Zero-pad the temporal tokens up to frame_length slots if necessary.
    # The padding uses the tokens' own dtype so that the concatenation does
    # not have to rely on implicit type promotion.
    padding_size = frame_length - t
    if padding_size > 0:
        padding = torch.zeros(b, padding_size, c,
                              dtype=temporal_tokens.dtype,
                              device=features.device)
        temporal_tokens = torch.cat((temporal_tokens, padding), dim=1)

    # One token per spatial location: average over the frame axis (dim=1) -> (b, s, c)
    spatial_tokens = torch.mean(features, dim=1)

    if semantic_tokens is None:
        # Concatenate temporal and spatial tokens and cast to half precision
        return torch.cat([temporal_tokens, spatial_tokens], dim=1).half()

    if not semantic_replace:
        # Keep every temporal token and insert the semantic tokens as-is
        return torch.cat([temporal_tokens, semantic_tokens, spatial_tokens], dim=1).half()

    # Number of temporal slots left once the semantic tokens take their share.
    # Clamped at zero: a negative count would crash torch.linspace and would
    # make perm[:temp_length] silently keep an arbitrary number of tokens.
    temp_length = max(int(frame_length - semantic_tokens.shape[1]), 0)

    if temp_pool_type == 'random':
        perm = torch.randperm(frame_length)
        # Random subset of slots, restored to temporal order
        sorted_selected = torch.sort(perm[:temp_length]).values
    else:
        # Uniformly spaced slots over [0, frame_length - 1]
        sorted_selected = torch.linspace(0, frame_length-1, steps=temp_length, dtype=torch.long)

    return torch.cat([temporal_tokens[:, sorted_selected], semantic_tokens, spatial_tokens], dim=1).half()



def video_chatgpt_infer_batch(video_frames, question, conv_mode, model, vision_tower, tokenizer, 
                                image_processor, video_token_len, transcript=None, 
                                spixel_encoder=None, 
                                save_spixel_visualization=False,
                                temp_pool_type=None,
                                output_name='tav',
                                debug=False,
                                rewrite=False,
                                end_marks=False,
                                clip_text_package=None,
                                rewrite_per_sentence=False,
                                sam_package=None,
                                batch_video_name='',
                                save_masks_and_features=False
                                ):
    """
    Run batched inference using the Video-ChatGPT model.

    Every video in the batch is answered against the same prompt: only
    ``question[0]`` (and the single ``transcript``) is used, and the tokenized
    prompt is repeated once per video.

    Parameters:
    video_frames: Iterable of videos. Each item is passed to
        ``image_processor.preprocess``; the results are concatenated, so all
        videos must produce tensors of identical shape.
    question: Sequence of question strings; only the first entry is used.
    conv_mode: Key into ``conv_templates`` selecting the conversation template.
    model: The pretrained Video-ChatGPT model.
    vision_tower: Vision model used to extract per-frame features.
    tokenizer: Tokenizer for the model.
    image_processor: Image processor used to preprocess video frames.
    video_token_len (int): Number of video patch placeholders in the prompt.
    transcript (str or None): Optional transcript appended to the prompt.
    spixel_encoder: Optional encoder providing ``get_k_means_feature`` and
        ``max_spread_scale``. When given, pooled superpixel features are
        passed on as semantic tokens. NOTE: the encoder is converted to half
        precision and moved to the feature device in place.
    save_spixel_visualization (bool): If True (and ``spixel_encoder`` is
        given), GIF visualizations are written via ``save_superpixel_videos``;
        this requires ``sam_package['sam_type']``.
    temp_pool_type: Forwarded to ``get_spatio_temporal_features_torch_batch``.
    output_name (str): Name prefix forwarded to ``save_superpixel_videos``.
    debug, rewrite, end_marks, clip_text_package, rewrite_per_sentence:
        Not used by this function; kept for interface compatibility.
    sam_package (dict or None): Forwarded to
        ``spixel_encoder.get_k_means_feature``; ``None`` means an empty dict.
    batch_video_name: Sequence of video names, indexed per batch item when
        ``save_masks_and_features`` is set.
    save_masks_and_features (bool): If True (and ``spixel_encoder`` is given),
        the masks and superpixel features of each video are collected.

    Returns:
    tuple: ``(outputs, save_dict)`` where ``outputs`` is a list with one
    cleaned answer string per video and ``save_dict`` is a list of dicts with
    keys 'video_name', 'masks' and 'spixel_features' (empty unless
    ``save_masks_and_features`` is set and ``spixel_encoder`` is given).
    """
    # Avoid sharing one mutable default dict across calls.
    if sam_package is None:
        sam_package = {}

    save_dict = []
    # Prepare question string for the model
    if model.get_model().vision_config.use_vid_start_end:
        qs = question[0] + '\n' + DEFAULT_VID_START_TOKEN + DEFAULT_VIDEO_PATCH_TOKEN * video_token_len + DEFAULT_VID_END_TOKEN
    else:
        qs = question[0] + '\n' + DEFAULT_VIDEO_PATCH_TOKEN * video_token_len
    
    # Append transcript text to the question
    if transcript:
        qs = f'{qs}\n{DEFAULT_TRANSCRIPT_START}\n\"{transcript}\"'

    # Prepare conversation prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the prompt
    inputs = tokenizer([prompt])

    # Preprocess video frames and get image tensor
    batch_image_tensor = []
    for vid in video_frames:
        image_tensor = image_processor.preprocess(vid, return_tensors='pt')['pixel_values']
        batch_image_tensor.append(image_tensor.unsqueeze(0))

    # Move image tensor to GPU and reduce precision to half
    video_tensor = torch.cat(batch_image_tensor, dim=0)
    video_tensor = video_tensor.half().cuda()
    
    # The batch size is taken from the stacked videos (b, t, c, h, w)
    bs, _, _, _, _ = video_tensor.shape    
    image_tensor = rearrange(video_tensor, 'b t c h w -> (b t) c h w')

    # Generate video spatio-temporal features
    with torch.no_grad():
        image_forward_outs = vision_tower(image_tensor, output_hidden_states=True)
        # Second-to-last hidden layer, without the first token
        # (presumably the class token -- TODO confirm against the vision tower)
        frame_features = image_forward_outs.hidden_states[-2][:, 1:]        
        # Assumes the remaining tokens form a square grid
        token_h = token_w = int(np.sqrt(frame_features.shape[1]))
        
        if spixel_encoder is not None:
            frame_features_ = rearrange(frame_features, '(b t) (h w) d -> b t h w d', b=bs, h=token_h)
            spixel_encoder.half()
            spixel_encoder.to(frame_features.device)
            
            video_bt = rearrange(video_tensor, 'b t c h w -> (b t) c h w').to(frame_features.device)
            
            sp_indices, new_n_keys = spixel_encoder.get_k_means_feature(video_bt, len(video_frames), sam_package=sam_package)
            
            # Integer ratio between mask resolution and token grid resolution
            intp_scale_h = sp_indices.shape[3] // token_h
            intp_scale_w = sp_indices.shape[4] // token_w
                
            spixel_features_list = []
            overlapped_sp_indices_list = []
            for bid, batch_sp_indices in enumerate(sp_indices):
                n_key = new_n_keys[bid]
                # n_key should be larger than 1                
                one_hot_indices = torch.arange(n_key, device=frame_features.device).unsqueeze(1)
                one_hot_indices = one_hot_indices.expand(n_key, spixel_encoder.max_spread_scale)[:, :, None, None, None]
                
                # overlapping indices: 1 where a key appears in any of the spread slots
                overlapped_one_hot_indices = torch.where((one_hot_indices == batch_sp_indices).int().sum(dim=1) > 0, 1, 0)
                overlapped_sp_indices_list.append(overlapped_one_hot_indices)
                # Fraction of each token cell covered by every key
                pooled_prob = F.avg_pool2d(overlapped_one_hot_indices.float(), kernel_size=(intp_scale_h, intp_scale_w), stride=(intp_scale_h, intp_scale_w))
                
                # [1, t, h, w, d] * [k, t, h, w, 1] --> [k, t, h, w, d]                                
                masked_encoded_tokens = frame_features_[bid].unsqueeze(0) * pooled_prob.unsqueeze(-1)
                
                # Coverage-weighted average; the clamp guards against dividing
                # by a (near) zero coverage
                spixel_features = torch.sum(masked_encoded_tokens, dim=[1,2,3])
                spixel_scale = torch.sum(pooled_prob, dim=[1,2,3])
                spixel_features_list.append(spixel_features / spixel_scale.unsqueeze(1).clamp(min=1))

            # Stacking requires the same number of keys for every video
            spixel_features_tensor = torch.stack(spixel_features_list, dim=0)
            overlapped_sp_indices = torch.stack(overlapped_sp_indices_list, dim=0)
            
            # Masks are downsampled by 2 and then upsampled by 7 spatially
            sv_masks = F.avg_pool3d(overlapped_sp_indices.float(), kernel_size=(1,2,2), stride=(1,2,2))                        
            sv_masks = sv_masks.repeat_interleave(7, dim=3).repeat_interleave(7, dim=4)
            sv_feats = spixel_features_tensor
            
            if save_masks_and_features:
                for sv_i in range(bs):
                    save_dict.append(
                            {
                            'video_name': batch_video_name[sv_i],
                            'masks': sv_masks[sv_i].detach().cpu(),
                            'spixel_features': sv_feats[sv_i].detach().cpu()
                            }
                        )
            if save_spixel_visualization:
                # batch_video_name
                save_superpixel_videos(frames=video_tensor, 
                                        spixel_masks=sp_indices, 
                                        n_keys=new_n_keys, 
                                        ovlp_spixel_masks=overlapped_sp_indices, 
                                        output_name=output_name,
                                        sam_info=sam_package['sam_type'])            
    
    frame_features = rearrange(frame_features, '(b t) hw d -> b t hw d', b=bs)
    if spixel_encoder is not None:        
        video_spatio_temporal_features = get_spatio_temporal_features_torch_batch(frame_features, temp_pool_type, spixel_features_tensor)
    else:
        video_spatio_temporal_features = get_spatio_temporal_features_torch_batch(frame_features)
    
    # Move inputs to GPU and repeat the single prompt once per video
    input_ids = torch.as_tensor(inputs.input_ids).cuda()
    input_ids = torch.repeat_interleave(input_ids, bs, dim=0)
    
    # Define stopping criteria for generation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    
    # Run model inference
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            video_spatio_temporal_features=video_spatio_temporal_features,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            stopping_criteria=[stopping_criteria])

    # Check if output is the same as input
    n_diff_input_output = (input_ids != output_ids[:, :input_ids.shape[1]]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    # Decode output tokens
    outputs = tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)
    outputs_ = [] 

    # Clean output string.
    # str.rstrip(stop_str) would strip any trailing *characters* contained in
    # stop_str (e.g. a final "s" for "</s>"), so the stop string is removed
    # only when it is an actual suffix.
    for o in outputs:
        text = o.strip()
        if stop_str and text.endswith(stop_str):
            text = text[:-len(stop_str)]
        outputs_.append(text.strip())

    return outputs_, save_dict
    

def show_anns(anns):
    """
    Draw all annotation masks on the current matplotlib axes.

    Parameters:
    anns: Sequence of dicts, each with an 'area' value and a 'segmentation'
        array used to index the overlay image (a boolean mask).

    Each mask is filled with a random colour at alpha 0.35; pixels outside all
    masks stay fully transparent. Nothing is drawn for an empty sequence.
    """
    if not len(anns):
        return

    # Largest masks are painted first so smaller ones drawn later overwrite them
    by_area = list(anns)
    by_area.sort(key=lambda ann: ann['area'], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    # RGBA canvas: white colour channels, alpha zero everywhere
    seg_shape = by_area[0]['segmentation'].shape
    canvas = np.ones((seg_shape[0], seg_shape[1], 4))
    canvas[:, :, 3] = 0

    for ann in by_area:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        canvas[ann['segmentation']] = rgba

    axes.imshow(canvas)
    
        

def save_superpixel_videos(frames, spixel_masks, n_keys, video_names=None, ovlp_spixel_masks=None, output_name=None, sam_info=None):
    """
    Render superpixel masks over video frames and save the results as GIFs.

    For every video in the batch two GIFs are written to
    ``save_gifs/<sam_info>/davis_conceptfig``: one with the coloured
    segmentation overlay ("_seg_") and one with the plain frames ("_org_").
    When ``ovlp_spixel_masks`` is given, one additional GIF per key
    ("_sem<key_id>") is written.

    Parameters:
    frames (torch.Tensor): Video batch of shape (b, t, c, h, w).
        NOTE(review): frames are mapped to [0, 1] with ``+ 0.5`` and a clamp,
        which presumably approximates undoing the input normalization --
        confirm against the image processor.
    spixel_masks (torch.Tensor): Label masks; index 0 along dim 1 is
        visualized and upsampled by a factor of 3 spatially.
    n_keys: Per-video number of keys, used for colouring and in file names.
    video_names: Optional per-video names used in file names.
    ovlp_spixel_masks: Optional per-video masks, iterated per key after
        ``[bid]`` indexing.
    output_name: Fallback name prefix for the seg/org GIFs when
        ``video_names`` is None.
    sam_info (str): Sub-directory name; must be a string because it is
        joined into the output path.
    """
    B, T, _, _, _= frames.shape
    
    frames = rearrange(frames, 'b t c h w -> (b t) c h w')
    frames = downsize_to_288(frames)
    
    spixel_masks_top = spixel_masks[:, 0, ...] 
    intp_spixel_masks = spixel_masks_top.repeat_interleave(3, dim=2).repeat_interleave(3, dim=3)

    frames = (frames + 0.5).clamp(0, 1)
    frames = rearrange(frames, '(b t) c h w -> b t h w c', b=B).cpu().numpy() # (B T) C H W -> B T H W C
    frames = (frames * 255).astype('uint8')
    # The uint8 cast makes a label of -1 wrap to 255; labelVisualize's +1
    # shift then wraps it back to 0 (the "Unlabelled" colour).
    intp_spixel_masks = intp_spixel_masks.cpu().numpy().astype('uint8')
    
    Sky = [128,128,128]
    Building = [128,0,0]
    Pole = [192,192,128]
    Road = [128,64,128]
    Pavement = [60,40,222]
    Tree = [128,128,0]
    SignSymbol = [192,128,128]
    Fence = [64,64,128]
    Car = [64,0,128]
    Pedestrian = [64,64,0]
    Bicyclist = [0,128,192]
    Unlabelled = [0,0,0]
    palette = [Sky, Building, Pole, Road, Pavement, Tree,
               SignSymbol, Fence, Car, Pedestrian, Bicyclist]
    # Index 0 is "Unlabelled", followed by the 11-colour palette repeated four times
    color_dict = np.array([Unlabelled] + palette * 4)
    # Append a constant alpha of 128 to every colour
    color_dict_rgba = np.c_[color_dict, np.full(len(color_dict), 128, dtype=np.uint8)]
    
    video_duration = 4000
    video_length = frames.shape[1]

    # Loop-invariant pieces of the output paths
    base_path = os.path.join('save_gifs', sam_info, 'davis_conceptfig')
    os.makedirs(base_path, exist_ok=True)    
    dur_tag = str(video_duration).replace('000', 'k')
    frame_duration = video_duration // video_length
    
    for bid, (frame, spixel_mask, n_key) in enumerate(zip(frames, intp_spixel_masks, n_keys)):        
        blended_frames = [
            labelVisualize(n_key, color_dict_rgba, f_, sm_)
            for f_, sm_ in zip(frame, spixel_mask)
        ]
        # Index of the last rendered frame; part of the file names
        fid = len(blended_frames) - 1
            
        batch_frames = frames[bid]
        batch_frames = torch.Tensor(batch_frames)[None,:,:,:,:]
        # Stack into one array first instead of converting a list of ndarrays
        blended_frames = torch.Tensor(np.array(blended_frames))[None,:,:,:,:]            
        org_pil = to_pil_images(batch_frames/255, output_type="pil")
        seg_pil = to_pil_images(blended_frames/255, output_type="pil")
        
        video_name = video_names[bid] if video_names is not None else f'{output_name}_video_idx{bid}'
        name_tail = f"dur{dur_tag}_fid{fid}_nkey{n_key}"

        # Both file names are built explicitly: deriving the "org" path with
        # str.replace('seg', 'org') would also rewrite any other "seg"
        # occurring in the directory or video name.
        seg_path = os.path.join(base_path, f"{video_name}_seg_{name_tail}.gif")
        org_path = os.path.join(base_path, f"{video_name}_org_{name_tail}.gif")
        save_as_gif(seg_pil, seg_path, duration=frame_duration)            
        save_as_gif(org_pil, org_path, duration=frame_duration)
    

        if ovlp_spixel_masks is not None:
            # NOTE(review): unlike above, the fallback name here does not
            # include output_name -- kept as-is to preserve file names.
            video_name = video_names[bid] if video_names is not None else f'video_idx{bid}'
            intp_batch_ovlp_spixel_masks = ovlp_spixel_masks[bid].repeat_interleave(3, dim=2).repeat_interleave(3, dim=3)
            
            for key_id, ov_sm_ in enumerate(intp_batch_ovlp_spixel_masks):
                indv_blended_image_array = IndvlabelVisualize(key_id, color_dict_rgba[1:], frames[bid], ov_sm_)
            
                indv_blended_frames = torch.Tensor(np.array(indv_blended_image_array))
                seg_pil = to_pil_images(indv_blended_frames.unsqueeze(0)/255, output_type="pil")
            
                output_dir = f"{video_name}_seg_dur{dur_tag}_fid{fid}_nkey{n_key}_sem{key_id}.gif"
                output_dir = os.path.join(base_path, output_dir)
                save_as_gif(seg_pil, output_dir, duration=frame_duration)            
    
    
def labelVisualize(num_classes, color_dict, img, label_matrix):        
    """
    Blend a coloured label overlay with black region boundaries onto an image.

    Parameters:
    num_classes: Highest (pre-shift) label count; boundaries are drawn for the
        shifted labels 0..num_classes.
    color_dict (np.ndarray): RGBA colour table indexed by the shifted label.
    img (np.ndarray): Image of shape (H, W, 3); an opaque alpha channel is
        appended before blending.
    label_matrix (np.ndarray): Integer label map of shape (H, W). It is not
        modified; a shifted copy (labels + 1) is used internally.

    Returns:
    np.ndarray: The blended RGBA image as an array.
    """
    # Step 4: Apply the color map to the label matrix
    # Shift on a copy so the caller's array is left untouched.
    label_matrix = label_matrix + 1
    adjusted_labels = np.where(label_matrix > 0, label_matrix, 0)
    overlay_image = color_dict[adjusted_labels]
                
    # Boundary of each label region = region XOR its erosion.
    # Kept as a boolean mask so it can be used directly for mask indexing
    # (an integer 0/1 array would index rows 0 and 1 instead).
    edges = np.zeros(label_matrix.shape, dtype=bool)
    for j in range(num_classes+1):
        one_hot_lab = label_matrix == j
        edges |= np.logical_xor(one_hot_lab, binary_erosion(one_hot_lab))
    
    edge_image = np.zeros_like(img)
    edge_color = [0, 0, 0]  # Black edges
    edge_image[edges] = edge_color

    # Blend edges with the original image to highlight boundaries
    original_with_edges = np.where(edges[..., None], edge_image, img)

    # Convert the original image to RGBA for blending; the alpha plane follows
    # the image size instead of assuming 288x288.
    height, width = img.shape[:2]
    alpha = np.full((height, width, 1), 255, dtype=np.uint8)
    original_image_rgba = np.concatenate([original_with_edges, alpha], axis=-1)
    
    # Blend the original image and the overlay
    original_pil = Image.fromarray(original_image_rgba.astype(np.uint8))
    overlay_pil = Image.fromarray(overlay_image.astype(np.uint8))
    blended_image = Image.alpha_composite(original_pil, overlay_pil)

    # Convert back to array to display using matplotlib
    blended_image_array = np.array(blended_image)

    return blended_image_array


def IndvlabelVisualize(key_id, color_dict, frames, label_matrix):        
    """
    Blend a single key's mask, with black boundaries, onto every frame.

    Parameters:
    key_id: Not used by this function; kept for interface compatibility.
    color_dict (np.ndarray): RGBA colour table indexed by the mask values.
    frames: Iterable of images of shape (H, W, 3).
    label_matrix (torch.Tensor or array-like): Per-frame integer masks,
        iterated in lockstep with ``frames``.

    Returns:
    list: One blended RGBA array per frame (empty if there are no frames).
    """
    # Step 4: Apply the color map to the label matrix    
    # Convert once to a NumPy array so the NumPy/SciPy calls below do not have
    # to rely on implicit tensor conversion.
    if isinstance(label_matrix, torch.Tensor):
        label_matrix = label_matrix.detach().cpu().numpy()
    else:
        label_matrix = np.asarray(label_matrix)
    
    blended_image_array_list = []
    for frame, local_label_matrix in zip(frames, label_matrix):
        # use one color
        overlay_image = color_dict[local_label_matrix]
    
        # Boundary of the mask = mask XOR its erosion
        edges_ = np.logical_xor(local_label_matrix, binary_erosion(local_label_matrix))
    
        edge_image = np.zeros_like(frame)
        edge_color = [0, 0, 0]  # black edges
        edge_image[edges_] = edge_color

        # Blend edges with the original image to highlight boundaries
        original_with_edges = np.where(edges_[..., None], edge_image, frame)

        # Convert the original image to RGBA for blending; the alpha plane
        # follows the frame size instead of assuming 288x288.
        height, width = frame.shape[:2]
        alpha = np.full((height, width, 1), 255, dtype=np.uint8)
        original_image_rgba = np.concatenate([original_with_edges, alpha], axis=-1)
        
        # Blend the original image and the overlay
        original_pil = Image.fromarray(original_image_rgba.astype(np.uint8))
        overlay_pil = Image.fromarray(overlay_image.astype(np.uint8))
        blended_image = Image.alpha_composite(original_pil, overlay_pil)

        # Convert back to array to display using matplotlib
        blended_image_array = np.array(blended_image)
        blended_image_array_list.append(blended_image_array)

    return blended_image_array_list

def to_pil_images(video_frames: torch.Tensor, output_type='pil'):
    """
    Flatten a batch of channel-last videos into a list of frames.

    Parameters:
    video_frames (torch.Tensor): Batch with five axes, channels last; it is
        rearranged to channels-first before the frames are extracted.
    output_type (str): "pil" converts every frame with the module-level
        ``to_pil`` transform; any other value returns the raw frames.

    Returns:
    list: All frames of all videos, video by video, in frame order.
    """
    channels_first = rearrange(video_frames, "b f w h c -> b f c w h")
    as_pil = output_type == "pil"
    return [
        to_pil(frame) if as_pil else frame
        for clip in channels_first
        for frame in clip
    ]

