import os
import torch
from decord import VideoReader, cpu
import numpy as np
from PIL import Image
import random
import cv2
from utils.utils import *


def motion_condition_generator_train(video_path, random_mask_ratio):
    '''
    Build the motion conditions for one training video.

    Besides the video itself, two side files are read from the directory that
    contains `video_path`: 'DOT_optical_flow.npy' (optical flow sequence) and the
    'DOT_visible_mask' folder (one visibility mask image per frame after the first,
    consumed in sorted filename order). Both come from Dense Optical Tracking (DOT).

    Input args:
        video_path:  training_video path
        random_mask_ratio: [r_min, 1.0]; the block masking ratio is sampled
            uniformly between the min and max of this sequence

    Return:
        dict with the following entries
        video: training_video tensor scaled to [-1, 1], shape: [L,3,H,W]
        sparse_traj: region-wise sparse trajectories, shape: [L,2,H,W]
        motion_mask: motion mask, shape: [L,1,H,W]
        motion_bucket_id: average optical flow strength
        video_name: second-to-last '/'-separated component of video_path

    Note:
        If no visible block survives the random masking after several attempts,
        a warning is printed and sparse_traj is returned as all zeros instead of
        retrying forever.
    '''

    video_name = video_path.split('/')[-2]

    # The optical flow and visible mask extracted from Dense Optical Tracking (DOT)
    video_dir = os.path.dirname(video_path)
    flow_seq_path = os.path.join(video_dir, 'DOT_optical_flow.npy')
    visible_mask_folder = os.path.join(video_dir, 'DOT_visible_mask')
    resolution = [320, 512]

    # Local region size k
    block_size = 8

    # Number of random block masks to try before giving up on this video
    max_attempts = 4

    # 1. Read videos
    video_reader = VideoReader(video_path, ctx=cpu(0))
    frame_indices = list(range(0, len(video_reader)))
    frames = video_reader.get_batch(frame_indices)
    frames = torch.tensor(frames.asnumpy()).permute(0, 3, 1, 2).float()  # [t,h,w,c] -> [t,c,h,w]
    t, _, h, w = frames.shape

    assert (h == resolution[0] and w == resolution[1]), f'frames={frames.shape}, self.resolution={resolution}'
    video = frames / 255.0 * 2.0 - 1.0

    # 2. Read visible mask, optical flow sequence and motion_bucket_id
    mask_file_path_list = [os.path.join(visible_mask_folder, img) for img in sorted(os.listdir(visible_mask_folder))]

    # A pixel stays visible only if it is visible in every frame; the first frame
    # has no mask file and is treated as fully visible.
    visible_mask_final = np.ones((h, w), dtype=np.float32)
    for i in range(1, t):
        with Image.open(mask_file_path_list[i - 1]) as mask_img:
            ori_mask_numpy = np.array(mask_img).astype(np.float32)
        visible_mask_final = np.logical_and(visible_mask_final, ori_mask_numpy > 125.0).astype(np.float32)

    flow_sq = torch.from_numpy(np.load(flow_seq_path))
    flow_magnitude_sq = torch.sum(flow_sq ** 2, dim=-1)  # per-pixel squared flow magnitude
    motion_bucket_id = torch.mean(flow_magnitude_sq.sqrt())

    # 3. Generate motion mask: pixels whose mean squared flow over time exceeds 1.0
    flow_sum_square = torch.zeros(h, w)
    for i in range(t):
        flow_sum_square += flow_magnitude_sq[i]

    motion_mask = flow_sum_square / t > 1.0
    motion_mask = motion_mask[None, :, :, None].repeat(t, 1, 1, 1).to(torch.float32)

    # 4. Generate sparse region-wise trajectories
    grid_h, grid_w = h // block_size, w // block_size
    # The block-level visibility does not depend on the random draw, so compute it once.
    visible_mask_final_resized = cv2.resize(visible_mask_final, (grid_w, grid_h), interpolation=cv2.INTER_NEAREST)

    for _ in range(max_attempts):
        mask_ratio = random.uniform(min(random_mask_ratio), max(random_mask_ratio))
        block_mask = np.random.rand(grid_h, grid_w) > mask_ratio
        block_mask_new = np.logical_and(block_mask, visible_mask_final_resized).astype(np.float32)
        if block_mask_new.any():
            break
    else:
        # Every attempt produced an empty mask (e.g. nothing is visible across all
        # frames, or the sampled ratio masks everything). Give up instead of
        # looping forever; the returned trajectories are all zeros.
        print(f'this video has a bad trajectory..{video_path}')

    # Expand each kept block to block_size x block_size pixels and broadcast over time.
    full_mask = np.kron(block_mask_new, np.ones((block_size, block_size), dtype=np.uint8))
    full_mask = torch.from_numpy(full_mask).to(torch.float32)[None, :, :, None].repeat(t, 1, 1, 1)
    masked_flow_sq = flow_sq * full_mask

    data = {
        'video': video, 'sparse_traj': masked_flow_sq.permute(0, 3, 1, 2), 'motion_bucket_id': motion_bucket_id,
        "video_name": video_name, "motion_mask": motion_mask.permute(0, 3, 1, 2),
    }
    return data


def motion_condition_generator_eval(reference_image_path, resized_all_points, motion_mask_path, img_ratio, image_label_name):
    '''
    Build the motion conditions for inference from user-provided trajectories
    and a brushed motion mask.

    Input args:
        reference_image_path:  input first frame path (not read by this function)
        resized_all_points: trajectory keypoints list; each entry is a sequence
            of (x, y) points
        motion_mask_path: motion mask path (must not be None)
        img_ratio: (width, height) of the reference image
        image_label_name: case ID name

    Return:
        dict with the following entries
        sparse_traj: region-wise sparse trajectories, shape: [L,2,H,W]
        motion_mask:  motion mask, shape: [L,1,H,W]
        motion_bucket_id: specifies the average optical flow strength
        video_name: the given image_label_name
    '''

    model_length = 16
    motion_bucket_id = 17
    # Half side length of the square region painted around each start point
    half_region = 4
    width, height = img_ratio

    # 1. Convert input trajectories into region-wise trajectories
    input_drag = torch.zeros(model_length, height, width, 2)
    for splited_track in resized_all_points:
        if len(splited_track) == 0:
            # An empty track has no start point, so it contributes no motion.
            continue
        if len(splited_track) == 1:
            # A single click is turned into a one-pixel diagonal displacement.
            displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
            splited_track = tuple([splited_track[0], displacement_point])

        # trajectory interpolation
        splited_track = interpolate_trajectory(splited_track, model_length)
        splited_track = splited_track[:model_length]
        if len(splited_track) < model_length:
            splited_track = splited_track + [splited_track[-1]] * (model_length - len(splited_track))

        start_point = splited_track[0]
        start_y, start_x = int(start_point[1]), int(start_point[0])
        # Clamp the region bounds at 0: a negative slice bound would wrap around
        # and make the region empty (or misplaced) for points near the top/left
        # border. Upper bounds past the image size are already clipped by slicing.
        y_lo, y_hi = max(start_y - half_region, 0), max(start_y + half_region, 0)
        x_lo, x_hi = max(start_x - half_region, 0), max(start_x + half_region, 0)

        for i in range(model_length):
            end_point = splited_track[i]
            # Displacement of frame i relative to the first frame, written at the start region.
            input_drag[i, y_lo:y_hi, x_lo:x_hi, 0] = end_point[0] - start_point[0]
            input_drag[i, y_lo:y_hi, x_lo:x_hi, 1] = end_point[1] - start_point[1]

    # Pixels carrying a non-zero displacement, computed once after all tracks are written.
    traj_mask = (torch.sqrt(torch.sum(input_drag ** 2, dim=-1)) > 0.0).to(torch.float32).unsqueeze(-1)

    # 2. Generate motion mask
    assert motion_mask_path is not None, 'impossible..'
    mask_brush = image2arr(motion_mask_path)
    mask_brush = torch.from_numpy(mask_brush)
    assert torch.equal(mask_brush[:,:,0], mask_brush[:,:,2]), 'Error: this mask should be the same across different channels'
    # Any non-zero channel marks the pixel as brushed; repeat the mask for every frame.
    mask_brush = torch.any(mask_brush != 0, dim=-1)[None, :, :, None].repeat(model_length, 1, 1, 1)

    # The final motion mask is the union of the brushed area and the trajectory regions.
    motion_mask = torch.logical_or(traj_mask, mask_brush).to(torch.float32)

    data = {
        'sparse_traj': input_drag.permute(0, 3, 1, 2), 'motion_bucket_id': motion_bucket_id,
        "video_name": image_label_name, "motion_mask": motion_mask.permute(0, 3, 1, 2),
    }

    return data


def test_motion_condition_eval():
    '''
    Build the evaluation motion conditions for the sample inputs stored in the
    'assets/input_condition' folder.

    The folder is expected to contain 'reference_image.png', 'motion_mask.png'
    and 'trajectory.json' (trajectories stored under the 'traj_key_points' key).

    Return:
        the dict produced by motion_condition_generator_eval
    '''

    # Input condition path: 'assets/input_condition' folder.
    data_path = 'assets/input_condition'
    image_label_name = os.path.basename(data_path)
    reference_image_path = os.path.join(data_path, 'reference_image.png')
    motion_mask_path = os.path.join(data_path, 'motion_mask.png')

    # Read trajectories
    resized_all_points = load_json(os.path.join(data_path, 'trajectory.json'), key='traj_key_points')

    # Only the (width, height) of the reference image is needed; close the file right away.
    with Image.open(reference_image_path) as img_pil:
        img_ratio = img_pil.size

    return motion_condition_generator_eval(reference_image_path, resized_all_points, motion_mask_path, img_ratio, image_label_name)
