import sys
import os

# Append this file's directory and its parent directory to sys.path so the
# imports below can be resolved from either location.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(current_dir)
sys.path.append(parent_dir)

from utils_hyfluid_psnr import *
from tqdm import tqdm, trange
import cv2
import numpy as np
from run_nerf_helpers import NeRFSmall, to8b, img2mse, mse2psnr, get_rays_np, get_rays, get_rays_np_continuous, sample_bilinear
import lpips
import torch
from pdb import set_trace as bp
from einops import rearrange
from utils.YParams import YParams
from skimage.metrics import structural_similarity

from PIL import Image
from swin_transformer.swin_transformer import build_vmae  # models.
# Module-level default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _blank_outside_content(image, threshold=40):
    """Zero, in place, every pixel outside the detected content box of a frame.

    Rows/columns whose summed intensity exceeds ``threshold`` count as content.
    The box grows from a fixed centre (``H - 120`` or ``H - 60`` rows down,
    ``W / 2`` columns across) by the number of content rows/columns found on
    each side of that centre.

    Raises:
        ValueError: if the frame height is neither 1920 nor 960.
    """
    H, W, _ = image.shape
    if H == 1920:
        C_H = H - 120
    elif H == 960:
        C_H = H - 60
    else:
        raise ValueError(
            f"Unsupported frame height {H}; expected 1920 or 960.")
    C_W = int(W / 2)

    intensity = image.sum(axis=-1)  # [h, w]
    col_has_content = np.where(intensity.sum(axis=0) > threshold, 1, 0)  # len w; 40/255=0.15686
    row_has_content = np.where(intensity.sum(axis=1) > threshold, 1, 0)  # len h; 40/255=0.15686

    y1 = C_H - row_has_content[:C_H].sum()
    y2 = C_H + row_has_content[C_H:].sum()
    x1 = C_W - col_has_content[:C_W].sum()
    x2 = C_W + col_has_content[C_W:].sum()

    image[:y1] = 0.
    image[y2:] = 0.
    image[:, :x1] = 0.
    image[:, x2:] = 0.
    return image


def load_frames(args, start_num=0, end_num=9, gt=False):
    """Load PNG frames ``start_num .. end_num - 1`` as a float32 tensor in [0, 1].

    Args:
        args: object providing ``gt_dir`` (read when ``gt`` is true) or
            ``hyfluid_dir`` (read otherwise).
        start_num: index of the first frame to load.
        end_num: index one past the last frame to load.
        gt: when true, read ``gt_XXX.png`` files and zero everything outside
            the detected content box; otherwise read ``rgb_XXX.png`` files
            unchanged.

    Returns:
        Tensor of the stacked frames divided by 255 ([t, h, w, 3] according to
        the original comments). An empty index range yields an empty tensor.

    Raises:
        FileNotFoundError: if a frame file is missing or cannot be decoded.
        ValueError: if a ground-truth frame has an unsupported height.
    """
    imgs = []
    for idx in range(start_num, end_num):
        if gt:
            path = os.path.join(args.gt_dir, f"gt_{str(idx).zfill(3)}.png")
        else:
            path = os.path.join(args.hyfluid_dir, f"rgb_{str(idx).zfill(3)}.png")

        # NOTE(review): cv2.COLOR_BGR2RGB sits in imread's *flags* position. It
        # is a colour-conversion code, so it acts as a read flag here and no
        # BGR->RGB conversion takes place. Kept unchanged so the loaded data
        # stays the same; confirm the intended read mode.
        image = cv2.imread(path, cv2.COLOR_BGR2RGB)
        if image is None:
            # imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {path}")

        if gt:
            image = _blank_outside_content(image)
        imgs.append(image / 255.)

    images = np.float32(imgs)  # [t, h, w, 3] numpy [0,1]
    return torch.from_numpy(images)  # [t, h, w, 3] tensor [0,1]




def load_frames_from_videos(basedir, frame_num_cutoff, half_res, split= 'all'):
    """Decode the videos listed in ``<basedir>/info.json`` into one tensor.

    Each decoded frame is optionally downscaled by two, values below 20 are
    set to 0, and everything outside the detected content box is zeroed.

    Args:
        basedir: directory holding ``info.json`` and the video files it names.
        frame_num_cutoff: number of frames to read per video; when ``<= 0`` the
            per-video ``frame_num`` entry from the metadata is used instead.
        half_res: when true, resize every frame to ``(W // 2, H // 2)``.
        split: ``'all'`` concatenates ``train_videos`` and ``test_videos``;
            any other value selects ``<split>_videos``, falling back to the
            first training video when that key is absent.

    Returns:
        ``(all_imgs, f_names)`` where ``all_imgs`` is a float32 tensor
        ``[V, T, H, W, 3]`` scaled to [0, 1] and ``f_names`` lists the video
        paths in the same order.

    Raises:
        ValueError: if no video is selected, or the videos cannot be stacked
            (for example, differing frame counts or sizes).
    """
    # Only the metadata is needed from the file; close it before decoding.
    with open(os.path.join(basedir, 'info.json'), 'r') as fp:
        meta = json.load(fp)

    if split == 'all':
        video_list = [*meta['train_videos'], *meta['test_videos']]
    else:
        split_key = split + '_videos'
        video_list = meta[split_key] if split_key in meta else meta['train_videos'][0:1]

    all_imgs = []
    f_names = []
    for video in video_list:
        f_name = os.path.join(basedir, video['file_name'])
        frame_num = frame_num_cutoff if frame_num_cutoff > 0 else video['frame_num']

        imgs = []
        reader = imageio.get_reader(f_name, "ffmpeg")
        try:
            for frame_i in range(frame_num):
                # NOTE(review): the index passed is frame_i - 1 (so -1 on the
                # first iteration). Check imageio's set_image_index semantics
                # to confirm this yields the intended frame and does not
                # repeat the first one.
                reader.set_image_index(frame_i - 1)
                frame = reader.get_next_data()
                H, W = frame.shape[:2]
                # Centre of the content box, expressed in output resolution.
                C_H, C_W = (H // 2 - 60, int(W // 2 / 2)) if half_res else (H - 120, int(W / 2))

                if half_res:
                    frame = cv2.resize(frame, (W // 2, H // 2), interpolation=cv2.INTER_AREA)

                # Suppress low-intensity values before locating the content.
                frame = np.where(frame < 20, 0, frame)
                intensity = frame.sum(axis=-1)  # [h, w]
                col_has_content = np.where(intensity.sum(axis=0) > 40, 1, 0)  # len w
                row_has_content = np.where(intensity.sum(axis=1) > 40, 1, 0)  # len h
                y1 = C_H - row_has_content[:C_H].sum()
                y2 = C_H + row_has_content[C_H:].sum()
                x1 = C_W - col_has_content[:C_W].sum()
                x2 = C_W + col_has_content[C_W:].sum()
                frame[:y1] = 0
                frame[y2:] = 0
                frame[:, :x1] = 0
                frame[:, x2:] = 0
                imgs.append(frame)
        finally:
            # Release the reader even when decoding fails part-way.
            reader.close()

        all_imgs.append(np.float32(imgs) / 255.)
        f_names.append(f_name)

    if not all_imgs:
        raise ValueError(
            f"No videos selected for split {split!r} in "
            f"{os.path.join(basedir, 'info.json')}")

    all_imgs = np.stack(all_imgs, 0)    # [V, T, H, W, 3]
    all_imgs = torch.from_numpy(all_imgs)
    return all_imgs, f_names


# Compute one crop box that covers the content of every video and frame.
def calculate_crop_box(images):
    """Compute one bounding box that covers the content of all videos/frames.

    Values below 0.0784 (20/255) are treated as background, and the bottom
    100 rows (height 1920) or 50 rows (height 960) are ignored. Rows/columns
    whose summed intensity exceeds 0.1569 (40/255) count as content. The box
    grows from a fixed centre by the largest number of content rows/columns
    found on each side of it, over all frames of all videos.

    Args:
        images: tensor ``[V, T, H, W, C]`` with values in [0, 1]. It is not
            modified.

    Returns:
        ``(bbox, size, center)`` where
        ``bbox = (top, bottom, left, right)``,
        ``size = ((box_h, box_w), (pad_top, pad_bottom, pad_left, pad_right))``
        with the pads being the margins between the box and the frame border,
        and ``center = (C_H, C_W)``. All entries are Python ints.

    Raises:
        ValueError: if ``H`` is neither 1920 nor 960.
    """
    _, T, H, W, _ = images.shape
    if H == 1920:
        center_offset, blank_rows = 120, 100
    elif H == 960:
        center_offset, blank_rows = 60, 50
    else:
        # Fail loudly instead of dropping into a debugger and then hitting
        # unbound locals.
        raise ValueError(
            f'Unrecognized size of images! Got height {H}, expected 1920 or 960.')
    C_H, C_W = H - center_offset, int(W / 2)
    top, bottom, left, right = C_H, C_H, C_W, C_W

    # masked_fill returns a new tensor, so the caller's images stay untouched.
    cleaned = images.masked_fill(images < 0.0784, 0).detach()  # 20/255=0.0784
    cleaned[:, :, -blank_rows:] = 0  # ignore the bottom rows of every frame

    for video in cleaned:  # [T, H, W, C]
        intensity = video.sum(dim=-1)                    # [T, H, W]
        col_has_content = intensity.sum(dim=1) > 0.1569  # [T, W]
        row_has_content = intensity.sum(dim=2) > 0.1569  # [T, H]
        y1 = C_H - int(row_has_content[:, :C_H].sum(dim=1).max())
        y2 = C_H + int(row_has_content[:, C_H:].sum(dim=1).max())
        x1 = C_W - int(col_has_content[:, :C_W].sum(dim=1).max())
        x2 = C_W + int(col_has_content[:, C_W:].sum(dim=1).max())
        top, bottom = min(top, y1), max(bottom, y2)
        left, right = min(left, x1), max(right, x2)

    print('nf:', T,  'bbox:', top, bottom, left, right)
    bbox = top, bottom, left, right
    size = (bottom - top, right - left), (top, H - bottom, left, W - right)
    center = C_H, C_W
    return (bbox, size, center)


# Crop and resize frames when before_inf=True; otherwise restore them to the original frame size.
def resize_and_clip(images, size_info, target_shape=224, before_inf=True, use_fea=False):
    """Crop and resize frames for the model, or undo that on model results.

    Args:
        images: 4-D tensor. With ``before_inf=True`` it is indexed as
            ``[N, H, W, C]``; otherwise it is treated as ``[N, C, H, W]``.
        size_info: ``(bbox, size_new, center)`` with
            ``bbox = (top, bottom, left, right)`` and
            ``size_new = ((box_h, box_w), (pad_top, pad_bottom, pad_left, pad_right))``.
        target_shape: side length of the square produced when ``before_inf``
            is true.
        before_inf: true to crop to ``bbox`` and resize to
            ``target_shape x target_shape``; false to restore.
        use_fea: only used when restoring. If false, the tensor is resized to
            the box size and zero-padded by the pads. If true, padding
            replicates the border values and the result is resized back to the
            spatial size the input had.

    Returns:
        The transformed ``[N, C, h, w]`` tensor.
    """
    _, _, in_h, in_w = images.shape  # input must be 4-D
    bbox, size_new, _ = size_info

    if not before_inf:
        (box_h, box_w), (pad_top, pad_bottom, pad_left, pad_right) = size_new
        # F.pad takes the last dimension first: (left, right, top, bottom).
        pad_spec = (pad_left, pad_right, pad_top, pad_bottom)
        restored = F.interpolate(images, size=(box_h, box_w), mode='bilinear', align_corners=True)
        if not use_fea:
            return F.pad(restored, pad_spec)
        restored = F.pad(restored, pad_spec, mode='replicate')
        return F.interpolate(restored, size=(in_h, in_w), mode='bilinear', align_corners=True)

    # Crop in channels-last layout, then move channels first for interpolate.
    cropped = images[:, bbox[0]:bbox[1], bbox[2]:bbox[3]]
    channels_first = cropped.permute(0, 3, 1, 2)
    return F.interpolate(channels_first, size=(target_shape, target_shape), mode='bilinear', align_corners=True)


# Run inference with the fm model.
def inference_fm(model, all_images, size_info, params, num_inf=15):
    """Extract per-frame features and roll the model forward for each video.

    For every video the frames are cropped/resized with ``resize_and_clip``,
    reduced to their first channel and zero-padded to ``params.in_chans``
    channels. A feature is then produced for each context frame from the
    frames that precede it, and ``num_inf`` further frames are predicted
    autoregressively.

    Args:
        model: callable used as ``model(x, return_feature=True)`` and expected
            to return ``(output, feature)``; presumably ``x`` is
            ``[n_steps, B, C, H, W]`` - confirm against the model.
        all_images: per-video frames, iterated along the first axis; each item
            is passed to ``resize_and_clip(..., before_inf=True)``.
        size_info: ``(bbox, size, center)`` as returned by
            ``calculate_crop_box``.
        params: provides ``input_size``, ``in_chans`` and ``n_steps``.
        num_inf: number of future frames to predict per video. With
            ``num_inf <= 0`` only features are produced and the returned
            outputs have zero frames.

    Returns:
        ``(all_outputs, all_features)`` concatenated over videos:
        ``all_outputs`` is ``[V, num_inf, h, w, 3]`` restored to frame size
        and clipped to [0, 1]; ``all_features`` holds one feature map per
        context frame followed by one per predicted frame.

    Raises:
        ValueError: if ``num_inf > 0`` but a video has too few context frames
            (2 or fewer) to build an input window for the rollout.
    """
    all_features = []
    all_outputs = []
    for vi_images in all_images:
        # Crop to the box, resize, and add a batch axis: [t, 1, c, s, s].
        images = resize_and_clip(vi_images, size_info, target_shape=params.input_size, before_inf=True).unsqueeze(1)
        n_frames, batch, _, H1, W1 = images.shape
        # Keep only the first channel; fill the remaining model channels with zeros.
        padding = torch.zeros((n_frames, batch, params.in_chans - 1, H1, W1)).to(images.device)
        images = torch.cat((images[:, :, :1], padding), dim=2)

        features = []
        outputs = []
        xx = None  # most recent model input window for this video
        with torch.no_grad():
            for i in range(len(images)):
                if i >= params.n_steps:
                    # Enough history: use the n_steps frames preceding frame i.
                    xx = images[i - params.n_steps:i]
                    _, feature = model(xx, return_feature=True)
                elif i == 0 or i == 1:
                    # Too little history to form a window: placeholder feature.
                    # The shape is hard-coded; it is assumed to match the
                    # model's feature shape - TODO confirm.
                    feature = torch.zeros([1, 96, 56, 56]).to(images.device)
                else:
                    # Fewer than n_steps frames: stretch frames [0, i) along
                    # the time axis to n_steps by trilinear interpolation.
                    xx = images[:i]
                    xx = rearrange(xx, 't b c h w -> b c t h w')
                    xx = F.interpolate(xx, size=(params.n_steps, *xx.shape[3:]), mode='trilinear', align_corners=True)
                    xx = rearrange(xx, 'b c t h w -> t b c h w')[:params.n_steps]
                    _, feature = model(xx, return_feature=True)
                features.append(feature.detach())  # (1, 96, 56, 56)

            if num_inf > 0:
                if xx is None:
                    raise ValueError(
                        f"Need more than 2 context frames to predict future "
                        f"frames, got {len(images)}.")
                # NOTE(review): the rollout starts from the last window built
                # above, which ends one frame before the final context frame;
                # that final frame is never fed to the model. Confirm this
                # alignment is intended.
                for _ in range(num_inf):
                    output, feature = model(xx, return_feature=True)
                    # Slide the window: drop the oldest frame, append the prediction.
                    xx = torch.cat((xx[1:], output.unsqueeze(0)), dim=0)
                    outputs.append(output.detach())
                    features.append(feature.detach())

            features = torch.cat(features, dim=0)
            features = resize_and_clip(features, size_info, target_shape=params.input_size, before_inf=False, use_fea=True)

            if outputs:
                outputs = torch.clip(torch.cat(outputs, dim=0), 0, 1)
                # Restore predictions to the full frame size.
                outputs = resize_and_clip(outputs, size_info, target_shape=params.input_size, before_inf=False)
                n_out, _, h, w = outputs.shape
                # Replicate the first channel to 3 channels, channels last.
                outputs = outputs[:, :1].expand(n_out, 3, h, w).permute(0, 2, 3, 1).detach()  # [num_inf, h, w, 3]
            else:
                # No rollout requested: return an empty stack with the same
                # frame size a restored prediction would have.
                (flow_h, flow_w), (pad_top, pad_bottom, pad_left, pad_right) = size_info[1]
                outputs = images.new_zeros(
                    (0, flow_h + pad_top + pad_bottom, flow_w + pad_left + pad_right, 3))

        all_features.append(features.unsqueeze(0))
        all_outputs.append(outputs.unsqueeze(0))
    all_features = torch.cat(all_features, dim=0)
    all_outputs = torch.cat(all_outputs, dim=0)
    return all_outputs, all_features


def config_parser():
    """Build the command-line parser for this script.

    Returns:
        A ``configargparse.ArgumentParser`` defining ``--gt_dir``,
        ``--fm_yaml``, ``--fm_load``, ``--num_initial_frames``,
        ``--num_tested_frames`` and ``--out_dir``.
    """
    import configargparse
    parser = configargparse.ArgumentParser()
    # Help texts describe how main() uses each option.
    parser.add_argument("--gt_dir", type=str,
                        help='dataset directory containing info.json and the ground-truth videos it lists.')
    parser.add_argument("--fm_yaml", type=str, help='config of fm')
    parser.add_argument("--fm_load", type=str, default=None,
                        help='path of the pretrained fm weights; overrides vmae_pretrained from the yaml config')
    parser.add_argument("--num_initial_frames", type=int, default=20,
                        help='how many ground-truth frames are given to the model as context')
    parser.add_argument("--num_tested_frames", type=int, default=15,
                        help='how many future frames to predict after the initial frames')
    parser.add_argument("--out_dir", type=str, default='logs', help='dir to save frames')
    return parser


def main():
    """Run inference on the ground-truth videos and save frames and features.

    Steps: parse arguments, build the model, load
    ``num_initial_frames + num_tested_frames`` frames per video, compute one
    crop box from all of them, run ``inference_fm`` on the first
    ``num_initial_frames`` frames, then (when ``--out_dir`` is set) write the
    predicted frames as PNG and the features as ``.npy`` files under
    ``<out_dir>/<run_name>/``.

    Returns:
        ``(outputs, features)`` as NumPy arrays.

    Raises:
        OSError: if a predicted frame cannot be written.
    """
    split = 'all'
    half_res = True
    args = config_parser().parse_args()
    num_tested_frames = args.num_tested_frames
    num_initial_frames = args.num_initial_frames

    # Model configuration; --fm_load overrides the weights path from the yaml.
    params = YParams(os.path.abspath(args.fm_yaml), config_name='basic_config')
    if args.fm_load:
        params['vmae_pretrained'] = args.fm_load
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = build_vmae(params).to(device).eval()

    # Load context + future frames so the crop box covers the whole sequence.
    gt_frames, f_names = load_frames_from_videos(
        args.gt_dir, num_initial_frames + num_tested_frames, half_res, split=split)
    gt_frames = gt_frames.to(device)
    size_info = calculate_crop_box(gt_frames)

    # Only the leading frames are given to the model as context.
    gt_frames = gt_frames[:, :num_initial_frames]
    outputs, features = inference_fm(model, gt_frames, size_info, params, num_inf=num_tested_frames)
    outputs = outputs.cpu().numpy()
    features = features.cpu().numpy()

    if args.out_dir:
        # The run name is the third-from-last component of the weights path.
        # Fall back to a fixed name when --fm_load is missing or too short,
        # instead of crashing after inference has already finished.
        path_parts = args.fm_load.split('/') if args.fm_load else []
        run_name = path_parts[-3] if len(path_parts) >= 3 else 'default_run'
        out_dir = os.path.join(args.out_dir, run_name)
        os.makedirs(out_dir, exist_ok=True)

        for f_i, f_name in enumerate(f_names):
            # File name up to its first dot.
            f_name = os.path.basename(f_name).split('.')[0]
            frame_dir = os.path.join(out_dir, 'frames', f_name)
            feature_dir = os.path.join(out_dir, 'features', f_name)
            os.makedirs(frame_dir, exist_ok=True)
            os.makedirs(feature_dir, exist_ok=True)

            for feature_id, feature in enumerate(features[f_i]):
                np.save(os.path.join(feature_dir, f'{f_name}_{feature_id}.npy'), feature)

            for i, frame in enumerate(outputs[f_i]):
                # Predicted frames are numbered after the context frames.
                frame_id = i + num_initial_frames
                frame_path = os.path.join(frame_dir, f'{f_name}_{frame_id}.png')
                # imwrite reports failure through its return value.
                if not cv2.imwrite(frame_path, (frame * 255).astype('uint8')):
                    raise OSError(f'Failed to write frame: {frame_path}')

        np.save(os.path.join(out_dir, 'all_features.npy'), features)
    return outputs, features

      

# ipdb is used only by the post-mortem hook below.
import ipdb
if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        # Print the error message, then open an ipdb post-mortem session.
        # post_mortem() is called without a traceback argument, so it
        # presumably relies on the exception currently being handled; keep the
        # call inside this except block.
        print(e)
        ipdb.post_mortem()
   