import torchvision.io
import videoseal
import torch
import numpy as np
import os
import sys
import os.path
import re
from numpy import savetxt
from videoseal.evals.metrics import bit_accuracy
import inference_streaming
import detect_streaming

def main():
    args = sys.argv[1:]
    #opts = [homepath, outputwkpath, isHD, streaming, detectonly, videoID]
    opts = ["assets/videos/1.mp4","output/1seal.mp4", 1, 0,0,22,0 ]
    if len(args) > 0:
        for index, arg in enumerate(args):
            opts[index] = arg
    else:
        print("Configuration filled by default, [ input, outdir, dostreaming (0/1), detectonly (0/1),  crf (int), setpowertozero(0/1)]")

    video_path = opts[0]
    video_path_wk= opts[1]
    dostreaming = int(opts[2])
    detectonly = int(opts[3])
    crf = int(opts[4])
    setpowertozero = int(opts[5])
    typewk = "seal"
    filenamein = os.path.splitext(os.path.basename(video_path))[0]
    filenameout = os.path.splitext(os.path.basename(video_path_wk))[0]
    pfafileblindtxt = os.path.dirname(video_path_wk)  + "/blinddetect_" + typewk + "_" + filenameout + ".txt"
    keypath = 'message.csv'

    if not os.path.exists(keypath):
        msgInput = np.random.randint(0,2,96)
        with open(keypath, "w") as f:
            f.write("".join([str(msg) for msg in msgInput]))

    if dostreaming:
        script =  os.getcwd() + "/../videoseal/inference_streaming.py "
        output = ""
        if detectonly:
            script = os.getcwd() + "/../videoseal/detect_streaming.py "
            cmd = "python " + script +" --input " + video_path_wk + " --key  "  + keypath  + output + " 1> "  + pfafileblindtxt
        else:
            output = " --output_dir output/ "
            cmd = "python " + script +" --input " + video_path + " --key  "  + keypath  + " --crf " + str(crf) + " --setpowertozero " + str(setpowertozero) + output + " 1> "  + pfafileblindtxt

        print(cmd)
        os.system(cmd)
        if not detectonly:
            cmd = "mv output/" +filenamein + ".mp4" + " " + video_path_wk
            print(cmd)
            os.system(cmd)

    else:
        # Load video and normalize to [0, 1]
        video = torchvision.io.read_video(video_path, output_format="TCHW") #
        if (len(video[2]) > 0):
            videofps =  float(video[2]['video_fps'])
            audiofps =  float(video[2]['audio_fps'])
        else:
            print("Could not load video")
            exit(1)
        numpyvideo = video[0].permute(0,2,3,1).numpy() #format  THWC, and conversion in array.

        # Load the model
        model = videoseal.load("videoseal")

        # Video Watermarking
        video0 = video[0].float() / 255.0
        msg_Input = model.get_random_msg()  # 1 x k
        outputs = model.embed(video0, msg_Input, is_video=True) # this will embed a random msg
        video_w = outputs["imgs_w"] # the watermarked video
        msgs = outputs["msgs"] # the embedded message

        # Extract the watermark message
        msg_extracted = model.extract_message(video_w, aggregation="avg")

        #write images of wk video
        numpyvideo = video_w.permute(0,2,3,1).numpy() #format  THWC, and conversion in array.
        width = np.size(numpyvideo,1)
        height = np.size(numpyvideo,2)

        #probably due to a bug in torchvision implementation, videofps and audiofps needs to be casted in int in write_video function.
        torchvision.io.write_video(video_path_wk, numpyvideo*255.0, int(round(videofps,0)), "h264", options = {"crf": "1"},audio_array=video[1], audio_fps=int(round(audiofps,0)),audio_codec="aac") #

        #Compare msgs and msg_Input
        ratiodet = torch.sum(torch.eq(msg_Input, msg_extracted)).item()/msg_Input.nelement()
        print(ratiodet)
        with open(pfafileblindtxt, 'w') as f:
           print(ratiodet, file=f)

        # VideoSeal can do Image Watermarking
        #img = video[0:1] # 1 x C x H x W
        #outputs = model.embed(img, is_video=False)
        #img_w = outputs["imgs_w"] # the watermarked image
        #msg_extracted = model.extract_message(imgs_w, aggregation="avg", is_video=False)

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()