import os
import torch
import cv2
import pdb
import time
import sys

import torchvision
from PIL import Image
import numpy as np
import tqdm
from torchvision.io import read_video , write_video
from dataset.transforms import ToTensorVideo , Resize

import torch.nn.functional as F


import argparse

# Command-line interface of the interpolation script.
parser = argparse.ArgumentParser()

# --- model selection and weights ---------------------------------------
# NOTE: --model is required, so its default is never actually applied.
parser.add_argument(
    "--model",
    type=str,
    required=True,
    choices=["SCubA", "SGuTA"],
    default="SCubA",
)
# --- input source ------------------------------------------------------
parser.add_argument(
    "--input_video",
    type=str,
    required=True,
    help="Path/WebURL to input video",
)
parser.add_argument(
    "--youtube-dl",
    type=str,
    help="Path to youtube_dl",
    default=".local/bin/youtube-dl",
)
parser.add_argument(
    "--factor",
    type=int,
    required=True,
    choices=[2, 4, 8],
    help="How much interpolation needed. 2x/4x/8x.",
)
# --codec and --up_mode are parsed but not read by the visible script.
parser.add_argument(
    "--codec",
    type=str,
    help="video codec",
    default="mpeg4",
)
parser.add_argument(
    "--load_model",
    required=True,
    type=str,
    help="path for stored model",
)
parser.add_argument(
    "--up_mode",
    type=str,
    help="Upsample Mode",
    default="transpose",
)
# --- container formats, resolution and frame rate ----------------------
parser.add_argument(
    "--output_ext",
    type=str,
    help="Output video format",
    default=".avi",
)
parser.add_argument(
    "--input_ext",
    type=str,
    help="Input video format",
    default=".mp4",
)
parser.add_argument(
    "--downscale",
    type=float,
    help="Downscale input res. for memory",
    default=1,
)
parser.add_argument(
    "--output_fps",
    type=int,
    help="Target FPS",
    default=30,
)
# --- inference settings ------------------------------------------------
# Number of consecutive frames fed to the model per sliding window.
parser.add_argument("--nbr_frame", type=int, default=6)
# When set, --input_video is treated as a directory of image files.
parser.add_argument("--is_folder", action="store_true")
parser.add_argument("--device", type=str, default="cuda:0")

# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()

input_video = args.input_video
input_ext = args.input_ext

import shlex
import subprocess
from os import path

# Web URLs are downloaded further down, so only local inputs are checked on
# disk. Running path.exists() on a URL always fails, which used to make the
# download branch unreachable.
is_url = input_video.startswith("http")

if not is_url and not path.exists(input_video):
    if args.is_folder:
        print("Invalid input directory path!")
    else:
        print("Invalid input file path!")
    # Non-zero status so calling shells/scripts can detect the failure.
    sys.exit(1)

if args.output_ext != ".avi":
    print("Currently supporting only writing to avi. Try using ffmpeg for conversion to mp4 etc.")

# Base name of the output: last path component of the input (the one before a
# trailing "/" for directories), cut at the first occurrence of `input_ext`.
if input_video.endswith("/"):
    video_name = input_video.split("/")[-2].split(input_ext)[0]
else:
    video_name = input_video.split("/")[-1].split(input_ext)[0]

# Bare file name (no directory part), e.g. "clip_2x.avi".
output_video = video_name + f"_{args.factor}x" + str(args.output_ext)

# Number of frames to synthesise between two consecutive input frames.
n_outputs = args.factor - 1

if is_url:
    if not args.youtube_dl:
        print("A youtube-dl executable is required to download web videos (--youtube-dl).")
        sys.exit(1)
    # Argument list instead of a shell string: the URL is never interpreted
    # by a shell. shlex.split keeps "--youtube-dl 'python -m youtube_dl'"
    # style values working.
    download_cmd = shlex.split(args.youtube_dl) + ["-i", "-o", "video.mp4", input_video]
    try:
        # The exit status is deliberately not enforced (youtube-dl runs with
        # -i, "ignore errors"); success is judged by the file showing up.
        subprocess.run(download_cmd)
    except OSError as err:
        print(f"Could not run youtube-dl: {err}")
        sys.exit(1)
    if not path.exists("video.mp4"):
        print("Download failed: video.mp4 was not created.")
        sys.exit(1)
    input_video = "video.mp4"
    output_video = "video" + str(args.output_ext)

def loadModel(model, checkpoint, map_location="cpu"):
    """Load the weights stored in a checkpoint file into ``model`` in place.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        checkpoint: path to a file written with ``torch.save`` that holds a
            mapping with a ``'state_dict'`` entry.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved on a GPU that is absent on this machine can
            still be read; ``load_state_dict`` copies the values into the
            model's own parameters, so the model's device is unaffected.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: if the keys/shapes do not match ``model``
            (raised by ``load_state_dict``).
    """
    prefix = "module."
    saved_state_dict = torch.load(checkpoint, map_location=map_location)['state_dict']
    # Strip only a *leading* "module." (e.g. as added by nn.DataParallel).
    # Keys without the prefix are kept unchanged; str.partition would have
    # mapped every such key to the empty string.
    saved_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in saved_state_dict.items()
    }
    model.load_state_dict(saved_state_dict)

# Build the architecture selected with --model. The model modules are
# imported lazily so only the chosen one has to be importable.
shared_kwargs = dict(in_channels=3, out_channels=3, n_feat=64, patch_size=(1, 4, 4))
if args.model == 'SCubA':
    from model.SCubA import SCubA
    model = SCubA(cube_size=(2, 4, 4), stage=3, **shared_kwargs)
    model = model.to(args.device)
elif args.model == 'SGuTA':
    from model.SGuTA import SGuTA
    model = SGuTA(stage=1, num_frm=8, **shared_kwargs)
    model = model.to(args.device)

# Restore the trained weights from the checkpoint given with --load_model.
loadModel(model, args.load_model)

# NOTE(review): .cuda() without arguments moves the model to the current
# CUDA device, which overrides whatever --device selected above. Kept as is
# because the inference loop also sends its inputs with .cuda().
model = model.cuda()

def write_video_cv2(frames , video_name , fps , sizes):
    """Write ``frames`` to ``video_name`` as an MJPG-encoded video via OpenCV.

    Args:
        frames: iterable of images handed to ``cv2.VideoWriter.write`` one by
            one, in order.
        video_name: path of the file to create.
        fps: frame rate passed to the writer.
        sizes: frame size passed to the writer; the caller in this script
            supplies it as (width, height).
    """
    out = cv2.VideoWriter(video_name,cv2.CAP_OPENCV_MJPEG,cv2.VideoWriter_fourcc('M','J','P','G'), fps, sizes)

    try:
        for frame in frames:
            out.write(frame)
    finally:
        # Release explicitly so the file is finalised and closed before the
        # caller hands it to ffmpeg, instead of relying on garbage collection.
        out.release()


def make_image(img):
    """Convert a tensor frame into a uint8 BGR numpy image for OpenCV.

    The values are scaled by 255, clamped to [0, 255] and rounded; the first
    axis is moved last; the result is converted from RGB to BGR channel order.
    Presumably ``img`` is a (C, H, W) float tensor with values in [0, 1] --
    TODO confirm against the model output.
    """
    quantized = img.data.mul(255.).clamp(0, 255).round()
    channels_last = quantized.permute(1, 2, 0)
    rgb = channels_last.cpu().numpy().astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

def files_to_videoTensor(path , downscale=1.):
    """Load every file in directory ``path`` as an image and stack them.

    Files are read in sorted file-name order. The number of files and the
    shape of the first image are printed.

    Args:
        path: directory containing the frame images.
        downscale: accepted for call compatibility; not used by this function.

    Returns:
        A uint8 tensor with one leading entry per file (``torch.stack`` of
        the individual images, so all images must share the same shape).

    Raises:
        ValueError: if the directory contains no files.
    """
    files = sorted(os.listdir(path))
    print(len(files))
    if not files:
        raise ValueError(f"No frames found in directory: {path}")

    from PIL import Image

    images = []
    for f in files:
        # Join with the `path` argument (not the global input path) and close
        # each image file as soon as its pixels have been copied out.
        with Image.open(os.path.join(path, f)) as img:
            # np.array copies the pixel data; going straight to uint8 avoids
            # a temporary float32 copy of every frame.
            images.append(torch.from_numpy(np.array(img, dtype=np.uint8)))
    print(images[0].shape)
    videoTensor = torch.stack(images)
    return videoTensor

def video_to_tensor(video):
    """Decode ``video`` with ``read_video`` and return its frame tensor.

    The ``"video_fps"`` entry of the returned metadata is printed; the second
    return value of ``read_video`` and the metadata itself are discarded.
    """
    clip, _unused, info = read_video(video)
    print(info["video_fps"])
    return clip

def video_transform(videoTensor , downscale=1):
    """Convert a raw video tensor and resize it to a multiple-of-8 resolution.

    The target size is ``8 * (H // d), 8 * (W // d)`` with
    ``d = int(downscale * 8)``, where H and W are dimensions 1 and 2 of
    ``videoTensor``. With ``downscale=1`` this rounds each side down to a
    multiple of 8.

    Args:
        videoTensor: tensor whose dimensions 1 and 2 are read as height and
            width (presumably (T, H, W, C) as produced by the loaders in this
            script -- confirm against ToTensorVideo).
        downscale: resolution divisor; must be at least 0.125.

    Returns:
        ``(transformed_tensor, (height, width))`` where the tensor is the
        result of ``ToTensorVideo`` followed by ``Resize``.

    Raises:
        ValueError: if ``downscale`` is too small (previously a
            ZeroDivisionError) or the resulting size would be empty.
    """
    H , W = videoTensor.size(1) , videoTensor.size(2)
    divisor = int(downscale * 8)
    if divisor < 1:
        raise ValueError(
            "downscale must be >= 0.125 (got %r)" % (downscale,)
        )
    resizes = 8*(H//divisor) , 8*(W//divisor)
    if resizes[0] < 1 or resizes[1] < 1:
        raise ValueError(
            "downscale=%r leaves no pixels for a %dx%d input" % (downscale, H, W)
        )
    transforms = torchvision.transforms.Compose([ToTensorVideo() , Resize(resizes)])
    videoTensor = transforms(videoTensor)

    print("Resizing to %dx%d"%(resizes[0] , resizes[1]) )
    return videoTensor , resizes

import subprocess

# ---- load the input frames ------------------------------------------------
if args.is_folder:
    videoTensor = files_to_videoTensor(input_video , args.downscale)
else:
    videoTensor = video_to_tensor(input_video)

num_frames = len(videoTensor)
if num_frames < args.nbr_frame:
    sys.exit(
        f"Need at least {args.nbr_frame} frames for one input window, got {num_frames}."
    )

# Sliding windows of `nbr_frame` consecutive frame indices with stride 1:
# row k is [k, k+1, ..., k+nbr_frame-1].
idxs = torch.arange(num_frames).unfold(0, args.nbr_frame, 1).tolist()
videoTensor , resizes = video_transform(videoTensor , args.downscale)
print("Video tensor shape is , " , videoTensor.shape)

# One tensor per index along dimension 1 (presumably the time axis after
# ToTensorVideo -- confirm).
frames = torch.unbind(videoTensor , 1)

# ---- run the model over every window ---------------------------------------
outputs = [] ## store the input and interpolated frames
first_window = idxs[0]
outputs.append(frames[first_window[0]])
outputs.append(frames[first_window[1]])
outputs.append(frames[first_window[2]])
model = model.eval()

# Send the inputs to wherever the model actually lives instead of assuming
# the default CUDA device.
device = next(model.parameters()).device

for idxSet in tqdm.tqdm(idxs):
    inputs = [frames[idx_].to(device).unsqueeze(0) for idx_ in idxSet]
    with torch.no_grad():
        outputFrame = model(inputs)
    outputs.extend(of.squeeze(0).cpu() for of in outputFrame)
    # Same frame the original appended as inputs[2], taken from the CPU copy.
    # NOTE(review): this input frame is appended *after* the frames predicted
    # from its window; verify the intended temporal order against the model.
    outputs.append(frames[idxSet[2]])

# Trailing input frames. With stride-1 windows idxs[-2][-2] == idxs[-1][-3]
# and idxs[-2][-1] == idxs[-1][-2], so reading everything from the last
# window yields the same frames as before while also working when there is
# only a single window (idxs[-2] used to raise IndexError in that case).
# NOTE(review): the same frame is appended twice and the very last frame is
# never written -- looks unintentional, confirm before changing.
last_window = idxs[-1]
outputs.append(frames[last_window[-3]])
outputs.append(frames[last_window[-3]])
outputs.append(frames[last_window[-2]])
new_video = [make_image(im_) for im_ in outputs]

# ---- write the result -------------------------------------------------------
# OpenCV takes the frame size as (width, height); `resizes` is (height, width).
write_video_cv2(new_video , output_video , args.output_fps , (resizes[1] , resizes[0]))

# splitext only drops the final extension, so names containing other dots
# (e.g. "clip.v2_2x.avi") keep their full stem.
mp4_video = os.path.splitext(output_video)[0] + ".mp4"
print("Writing to " , mp4_video)
ffmpeg_cmd = ["ffmpeg", "-hide_banner", "-loglevel", "warning", "-i", output_video, mp4_video]
try:
    # Argument list: file names with spaces or shell metacharacters are safe.
    converted = subprocess.run(ffmpeg_cmd).returncode == 0
except OSError as err:
    print(f"Could not run ffmpeg: {err}")
    converted = False

if converted:
    os.remove(output_video)
else:
    # Keep the intermediate file: it is the only copy of the result.
    print(f"ffmpeg conversion failed; keeping intermediate file {output_video}")
