from re import X
import torch
# Allow TF32 for CUDA matmuls and cuDNN convolutions; this makes A100 training much faster.
# Note: the first flag (matmul) was False when this script was tested.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import ruamel.yaml as yaml
import numpy as np
from tqdm import tqdm 
from einops import rearrange

import os, shutil
import argparse

from utils.misc import load_model_state_dict
from run_utils import build_model

import warnings
warnings.filterwarnings('ignore')

from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HubDataset
from torchvision import transforms
from data.dataset import center_crop_tensor
import decord
import cv2

class VideoDataset(Dataset):
    """Map-style dataset that decodes the videos listed in a CSV file.

    Each item is a float tensor laid out as (c, f, h, w) (see the final
    rearrange in the transform), scaled from [0, 255] to [-1, 1] and
    center-cropped.

    Args:
        csv_file: CSV loaded with ``datasets.Dataset.from_csv``. Besides
            ``data_column`` it must provide ``height`` and ``width`` columns
            (read in ``__getitem__``).
        data_column: name of the column that holds the video file path.
        image_size: length the shorter side is resized to while decoding; also
            the default crop size.
        fps: frame rate each clip is resampled to.
        cache_dir: cache directory passed to ``from_csv``.
        crop_size: side length passed to ``center_crop_tensor``; defaults to
            ``image_size`` so the crop matches the decoded short side.
    """

    def __init__(self, csv_file, data_column, image_size, fps=24,
                 cache_dir='/group/cache/datasets', crop_size=None):
        self.data_frame = HubDataset.from_csv(csv_file, cache_dir=cache_dir)
        self.data_column = data_column

        self.image_size = image_size
        self.fps = fps
        # Crop to the working resolution instead of a fixed 256, which broke
        # (or silently zoomed) whenever image_size != 256.
        self.crop_size = image_size if crop_size is None else crop_size
        # Make decord return torch tensors from get_batch.
        decord.bridge.set_bridge("torch")

        crop = self.crop_size
        transform = [
            transforms.Lambda(lambda x: rearrange(x, 'f h w c -> f c h w')),
            transforms.Lambda(lambda x: x.float() / 255.),
            # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            transforms.Lambda(lambda x: center_crop_tensor(x, crop)),
            transforms.Lambda(lambda x: rearrange(x, 'f c h w -> c f h w')),
        ]

        self.transform = transforms.Compose(transform)

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Decode, resample and transform the video in row ``idx``."""
        row = self.data_frame[idx]
        video_path = row[self.data_column]
        h, w = self.get_read_size(row['height'], row['width'], self.image_size)
        frames = self.decord_read_video(video_path, h, w)
        frames = self.transform(frames)
        return frames

    def get_read_size(self, height, width, image_size):
        """Return (height, width) scaled so the shorter side equals ``image_size``.

        The longer side keeps the aspect ratio and is truncated to an int.
        Square inputs take the ``else`` branch and become (image_size, image_size).
        """
        if height > width:
            height = int(height / width * image_size)
            width = image_size
        else:
            width = int(width / height * image_size)
            height = image_size
        return height, width

    @staticmethod
    def _sample_indices(num_frames, src_fps, target_fps):
        """Return evenly spaced frame indices that resample a clip to ``target_fps``.

        The sample count is the clip duration (``num_frames / src_fps``) times
        ``target_fps``, truncated, with a minimum of one frame so short clips
        never yield an empty batch. If ``src_fps`` is not positive (unknown
        frame rate), every frame is kept. Indices span [0, num_frames - 1].

        Raises:
            ValueError: if ``num_frames`` is not positive.
        """
        if num_frames <= 0:
            raise ValueError("video contains no frames")
        if src_fps <= 0:
            num_samples = num_frames
        else:
            num_samples = max(1, int(num_frames / src_fps * target_fps))
        return np.linspace(0, num_frames - 1, num_samples).astype(int)

    def decord_read_video(self, video_path, height, width):
        """Decode ``video_path`` at ``height`` x ``width``, resampled to ``self.fps``.

        Returns whatever ``VideoReader.get_batch`` returns for the sampled
        indices (a torch tensor when the torch bridge set in ``__init__`` is
        active in this process).
        """
        vr = decord.VideoReader(video_path, ctx=decord.cpu(0), height=height, width=width)
        indices = self._sample_indices(len(vr), vr.get_avg_fps(), self.fps)
        return vr.get_batch(indices)

def get_num_frames(period, fps, t_patch_size):
    """Return the frame count for ``period`` seconds at ``fps``, snapped to whole temporal patches.

    The raw count ``period * fps`` is expressed in patches of ``t_patch_size``
    frames, rounded with Python's ``round`` (ties go to the even number of
    patches), and converted back to an integer frame count.
    """
    whole_patches = round(period * fps / t_patch_size)
    return int(whole_patches * t_patch_size)

def _to_uint8_frames(video):
    """Convert one decoded clip into a uint8 frame array for OpenCV.

    Args:
        video: tensor indexed as (c, f, h, w) (per the rearrange below); values
            are clamped to [-1, 1] and mapped linearly to [0, 255].

    Returns:
        C-contiguous ``np.uint8`` array of shape (f, h, w, c).
    """
    frames = (torch.clamp(video, min=-1, max=1) + 1) / 2
    frames = frames.cpu().float().numpy()
    frames = rearrange(frames, 'c f h w -> f h w c')
    # Truncating cast, as before. The rearrange is a transpose (a strided
    # view), so hand OpenCV a contiguous buffer.
    return np.ascontiguousarray((frames * 255).astype(np.uint8))


def _write_mp4(path, frames, fps):
    """Write RGB uint8 ``frames`` of shape (f, h, w, 3) to ``path`` using the mp4v codec.

    Raises:
        RuntimeError: if OpenCV cannot open a writer for ``path``.
    """
    # Size comes from the frames themselves: OpenCV silently drops frames whose
    # size does not match the writer.
    height, width = frames.shape[1:3]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(path, fourcc, fps, (width, height))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {path}")
    try:
        for frame in frames:
            # Frames are RGB; OpenCV expects BGR.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Release explicitly so the file is finalized now rather than whenever
        # the writer object happens to be garbage-collected.
        writer.release()


def main():
    """Reconstruct every video listed in --data_path and save the results as mp4.

    Each clip is encoded at --enc_fps and decoded at --dec_fps; decoded frames
    are written to ``<save_path>/<index>.mp4``. Entries of the YAML file given
    by --config become parser defaults, so explicit command-line flags take
    precedence. ``image_size``, ``period`` and ``t_patch_size`` have no
    command-line flag and must therefore come from that config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--enc_fps", type=float, default=24.0)
    parser.add_argument("--dec_fps", type=float, default=24.0)

    # First parse: only needed to locate the (required) config file.
    args = parser.parse_args()

    # Use config entries as defaults; an empty YAML file loads as None.
    with open(args.config, 'r', encoding='utf-8') as f:
        config_args = yaml.YAML().load(f) or {}
    parser.set_defaults(**config_args)

    # Re-parse so explicit command-line flags override config values.
    args = parser.parse_args()

    model = build_model(args)
    checkpoint = torch.load(args.ckpt, map_location="cpu")
    # strict=False: missing/unexpected keys are ignored silently.
    model.load_state_dict(load_model_state_dict(checkpoint, model), strict=False)
    model = model.eval().cuda()

    dataset = VideoDataset(args.data_path, 'video_path', args.image_size, fps=args.enc_fps)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Dataset contains {len(dataset):,} records ({args.data_path})")

    # Start from an empty output directory; previous results are deleted.
    if os.path.exists(args.save_path):
        shutil.rmtree(args.save_path)
    os.makedirs(args.save_path)

    n_enc = get_num_frames(args.period, args.enc_fps, args.t_patch_size)
    n_dec = get_num_frames(args.period, args.dec_fps, args.t_patch_size)

    for i, x in enumerate(tqdm(loader)):
        b, _, _, h, w = x.shape
        assert b == 1  # batch_size=1 above; decode/save below handle one clip
        # NOTE(review): clips shorter than n_enc frames are passed through with
        # fewer frames than num_frames — confirm the model accepts that.
        x = x[:, :, :n_enc].cuda(non_blocking=True)

        with torch.no_grad():
            z, _, _ = model.encode(x, num_frames=n_enc, fps=args.enc_fps)
            y = model.decode(z, x=None, h=h, w=w, num_frames=n_dec, fps=args.dec_fps)

        frames = _to_uint8_frames(y[0])
        _write_mp4(os.path.join(args.save_path, f'{i}.mp4'), frames, args.dec_fps)


if __name__ == "__main__":
    main()