import torch
import math
from torchvision.transforms import functional as F
from torchvision.transforms import ToPILImage
import os
import cv2
from PIL import Image, ImageDraw
import torch.multiprocessing as mp
import matplotlib.pyplot as plt
import subprocess
import warnings

def video_zeros(vid):
    """
    Return an all-zero tensor shaped like ``vid`` but with size 1 along dim 0.

    The result shares ``vid``'s dtype and device, so it can be concatenated
    with ``vid`` along dim 0 (e.g. to pad a partial row of videos) without a
    cross-device error or an implicit dtype promotion.
    """
    return vid.new_zeros((1,) + tuple(vid.shape[1:]))

def annotate_video(vid, text, position=(10,10), color=(255, 255, 0)):
    """
    Draw ``text`` on every frame of ``vid`` and return the annotated copy.

    Frames are taken as the slices of ``vid`` along dim 1. Each slice (with
    dim 1 squeezed) is moved to the CPU, converted with torchvision's
    ``to_pil_image``, stamped with ``text`` at ``position`` in ``color`` and
    converted back with ``to_tensor``. The frames are re-stacked along dim 1,
    so the result lives on the CPU in whatever dtype ``to_tensor`` produces.
    """
    def _stamp(frame):
        # One frame -> PIL -> draw -> tensor.
        canvas = F.to_pil_image(frame.squeeze(1).cpu())
        ImageDraw.Draw(canvas).text(position, text, fill=color)
        return F.to_tensor(canvas)

    stamped_frames = [_stamp(frame) for frame in vid.split(1, dim=1)]
    return torch.stack(stamped_frames, dim=1)

def collage(batch, stub, preds=None):
    """
    Tile a batch of videos into a near-square grid, save it and return it.

    The batch (videos along dim 0) is split into rows of
    ``ceil(sqrt(batch_size))`` videos. The last row is padded with blank
    videos from ``video_zeros``. Within a row the videos are concatenated
    along the last dim; rows are concatenated along dim 2. The grid is written
    to ``stub`` via ``save_tensor_to_video``.

    Args:
        batch: batch-first video tensor. NOTE(review): an older comment said
            (B, C, T, W, H); the helpers treat each video as (C, T, H, W) —
            confirm the intended layout.
        stub: output path without extension, forwarded to
            ``save_tensor_to_video``.
        preds: optional sequence of strings, one per video in ``batch``. When
            given, each real (non-padding) video is annotated with its string
            via ``annotate_video``.

    Returns:
        The grid tensor: per-video dims with dim 2 scaled by the number of
        rows and the last dim scaled by the row length.

    Raises:
        AssertionError: if ``preds`` is given and its length differs from the
            batch size.
    """
    assert preds is None or len(batch) == len(preds), \
        "Got {} videos but {} predictions".format(len(batch), len(preds))
    bs = batch.size(0)
    length = math.ceil(math.sqrt(bs))  # videos per row
    rows = []
    for row, mb in enumerate(torch.split(batch, length, dim=0)):
        if mb.size(0) < length:
            # Pad the last row with blank videos so every row has equal width.
            mb = torch.cat([mb] + [video_zeros(mb)] * (length - mb.size(0)), dim=0)
        tiles = []
        # unbind keeps every video's own dims intact, unlike squeeze(0), which
        # collapses channels or the batch dim whenever one of them has size 1.
        for col, vid in enumerate(mb.unbind(0)):
            idx = row * length + col
            if preds is not None and idx < len(preds):
                vid = annotate_video(vid, preds[idx])
            tiles.append(vid)
        rows.append(torch.cat(tiles, dim=-1))
    grid = torch.cat(rows, dim=2)
    save_tensor_to_video(grid, stub)
    return grid

def save_tensor_to_video(t, fname_stub, fmt="{:02d}", img_ext=".png", vid_ext=".mp4", fps=10):
    """
    Encode the frames of ``t`` into ``<fname_stub><vid_ext>`` with ffmpeg.

    Every slice of ``t`` along dim 1 is written as the image
    ``<fname_stub>_<fmt.format(i)><img_ext>`` (i starting at 0). ffmpeg then
    encodes that image sequence at ``fps`` frames per second, and the images
    written by this call are deleted afterwards.

    Args:
        t: tensor with frames along dim 1; each frame, with dim 1 squeezed,
            is converted by torchvision's ``to_pil_image``.
        fname_stub: output path without extension; also the prefix of the
            temporary frame images.
        fmt: format applied to the frame index. It is assumed to zero-pad an
            integer to a fixed width (e.g. "{:02d}", "{:04d}", "{:d}"), because
            the ffmpeg ``%0Nd`` input pattern is derived from it.
        img_ext: extension (and hence image format) of the temporary frames.
        vid_ext: extension of the output video.
        fps: input frame rate passed to ffmpeg's ``-framerate``.

    ffmpeg failures are reported with ``warnings.warn`` rather than raised.
    """
    frame_paths = []
    for i, frame in enumerate(torch.split(t, 1, dim=1)):
        path = "_".join([fname_stub, fmt.format(i)]) + img_ext
        F.to_pil_image(frame.squeeze(1).cpu()).save(path)
        frame_paths.append(path)

    # Derive the image2 pattern width from fmt so it names exactly the files
    # written above. A literal '%' in the stub is escaped because ffmpeg treats
    # '%' as a pattern character.
    width = len(fmt.format(0))
    pattern = "{}_%0{}d{}".format(fname_stub.replace("%", "%%"), width, img_ext)
    cmd = ["ffmpeg", "-y", "-loglevel", "quiet", "-framerate", str(fps),
           "-i", pattern, fname_stub + vid_ext]
    try:
        proc = subprocess.run(cmd)
    except OSError as err:
        warnings.warn("Could not run ffmpeg: {}".format(err))
    else:
        if proc.returncode != 0:
            warnings.warn(" ".join(cmd) + " failed with exit code " + str(proc.returncode))
    finally:
        # Delete only the frames written above; a prefix glob could also
        # match unrelated files.
        for path in frame_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

def display_video(vid, msg="", cmap='viridis'):
    """
    Plot every frame of ``vid`` in a 4-row matplotlib grid with ``msg`` on top.

    Frames are the slices of ``vid`` along dim 1. Each is converted with
    torchvision's ``ToPILImage``, upscaled 3x (aspect ratio preserved) and
    shown in its own subplot labelled "Frame i".

    Args:
        vid: video tensor with frames along dim 1 (presumably (C, T, H, W) —
            confirm against callers).
        msg: title text drawn at the top-centre of the figure.
        cmap: colormap forwarded to ``imshow`` (matplotlib ignores it for RGB
            images).
    """
    n_rows = 4
    # ceil so every frame gets a slot even when T is not a multiple of 4,
    # and at least one column when T < 4.
    n_cols = max(1, math.ceil(vid.size(1) / n_rows))
    fig = plt.figure(figsize=(18, 10), linewidth=4, edgecolor="#000000")
    for i, frame in enumerate(torch.split(vid.cpu(), 1, dim=1)):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("Frame {}".format(i))
        img = ToPILImage()(frame.squeeze(1))
        # PIL's resize takes (width, height); scale the image's own size so
        # non-square frames are not stretched.
        img = img.resize((img.width * 3, img.height * 3))
        ax.imshow(img, cmap=cmap)
    fig.text(0.5, 1, msg, fontsize='xx-large', ha='center')

def batch_transform(X, transform, meta, workers=64, batch_dim=0):
    """
    Apply ``transform`` to each item of ``X`` along ``batch_dim`` and restack.

    Args:
        X: tensor whose items are the slices along ``batch_dim``.
        transform: callable invoked as ``transform(item, meta=meta[i])``.
        meta: indexable with one entry per item, passed positionally by index.
        workers: unused; kept for backward compatibility of the signature.
        batch_dim: dimension of ``X`` that indexes the items.

    Returns:
        The tensor results stacked along a new dim 0. Items whose transform
        did not return a tensor are dropped. If none did, an empty tensor
        ``torch.Tensor(0)`` is returned, for callers that still expect a tensor.
    """
    outputs = []
    for i, item in enumerate(torch.split(X, 1, dim=batch_dim)):
        res = transform(item.squeeze(batch_dim), meta=meta[i])
        # For file corruptions, some transformations yield None for
        # unrecoverable cases; skip those items.
        if isinstance(res, torch.Tensor):
            outputs.append(res)

    if outputs:
        return torch.stack(outputs, 0)
    return torch.Tensor(0)

def img_to_vid(args, dirname, tempname):
    """
    Encode the JPEG frame sequence in ``dirname`` into an H.264 video.

    Frames are read as ``dirname/%06d.jpg`` starting at number 1. They are
    cropped to even width/height, resampled to 25 fps, encoded with libx264
    as yuv420p, and written to ``network_corruption_temp/<tempname>``
    (relative to the current working directory).

    Args:
        args: namespace providing ``ffmpeg_log_level`` (passed to
            ``-loglevel``) and ``limit`` (when not None, the ffmpeg command
            is printed).
        dirname: directory containing the numbered JPEG frames.
        tempname: file name of the output video inside the temp directory.

    Returns:
        Path of the encoded video. A non-zero ffmpeg exit status is reported
        with ``warnings.warn``.
    """
    temp_vidname = os.path.join('network_corruption_temp', tempname)
    # ffmpeg will not create a missing output directory.
    os.makedirs(os.path.dirname(temp_vidname), exist_ok=True)
    img_vid_cmd = ['ffmpeg', '-y', '-loglevel', args.ffmpeg_log_level, '-start_number', '1', '-i',
                   os.path.join(dirname, '%06d.jpg'), '-c:v', 'libx264',
                   '-vf', 'crop=trunc(iw/2)*2:trunc(ih/2)*2,fps=25', '-pix_fmt', 'yuv420p', temp_vidname]
    if args.limit is not None:
        print("Img-to-vid command:", img_vid_cmd)
    proc = subprocess.run(img_vid_cmd)
    if proc.returncode != 0:
        warnings.warn(' '.join(img_vid_cmd) + " failed with exit code " + str(proc.returncode))
    return temp_vidname

def vid_to_img(args, fname, dirname):
    """
    Decode the video ``fname`` into numbered JPEG frames inside ``dirname``.

    Frames are written as ``dirname/%06d.jpg`` starting at number 1, the
    same layout that ``img_to_vid`` reads.

    Args:
        args: namespace providing ``ffmpeg_log_level`` (passed to
            ``-loglevel``) and ``limit`` (when not None, the ffmpeg command
            is printed).
        fname: path of the input video.
        dirname: output directory; created if it does not exist.

    Returns:
        ``dirname``. A non-zero ffmpeg exit status is reported with
        ``warnings.warn``.
    """
    if dirname:
        # ffmpeg will not create a missing output directory.
        os.makedirs(dirname, exist_ok=True)
    vid_img_cmd = ['ffmpeg', '-y', '-loglevel', args.ffmpeg_log_level, '-i', fname,
                   '-start_number', '1', os.path.join(dirname, '%06d.jpg')]
    if args.limit is not None:
        print("Vid-to-img command:", vid_img_cmd)
    proc = subprocess.run(vid_img_cmd)
    if proc.returncode != 0:
        warnings.warn(' '.join(vid_img_cmd) + " failed with exit code " + str(proc.returncode))
    return dirname

def gen_temp(old_path, tmp, max_tries=10):
    """
    Re-encode ``old_path`` to ``tmp`` with libx264, retrying on failure.

    An attempt counts as successful only if ffmpeg exits with status 0 and
    ``tmp`` exists afterwards. Checking the file alone would accept a stale
    ``tmp`` left over from an earlier run (ffmpeg never opens the output
    when the input cannot be read). Each failing attempt emits a warning.

    Args:
        old_path: input video path.
        tmp: output video path (overwritten via ``-y``; a warning is emitted
            if it already exists).
        max_tries: maximum number of ffmpeg invocations.

    Raises:
        AssertionError: if no attempt succeeds. Kept as AssertionError for
            existing callers, but raised explicitly so ``python -O`` cannot
            strip the check.
    """
    if os.path.isfile(tmp):
        warnings.warn("File {} already exists".format(tmp))
    cmd = ['ffmpeg', '-y', '-loglevel', 'quiet', '-i', old_path, '-vcodec', 'libx264', '-strict', '-2', tmp]
    for _ in range(max_tries):
        proc = subprocess.run(cmd)
        if proc.returncode != 0:
            warnings.warn(' '.join(cmd) + " failed with exit code " + str(proc.returncode))
        elif os.path.isfile(tmp):
            return
    raise AssertionError("Failed after {} tries on command ".format(max_tries) + ' '.join(cmd))

def scrape_dirs(args):
    """
    List the per-sequence image folders of a MOT-style dataset.

    If you copied the MOT20 dataset to the directory at ``args.base_path``,
    then ``args.img_folder`` should be ``MOT20/images/train`` by default.

    Returns:
        ``<base_path>/<img_folder>/<seq>/img1`` for every sub-directory
        ``<seq>`` of ``<base_path>/<img_folder>``. The list is sorted by
        sequence name so the order does not depend on the filesystem. Plain
        files in that folder are skipped.
    """
    root = os.path.join(args.base_path, args.img_folder)
    return [os.path.join(root, seq, 'img1')
            for seq in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, seq))]

def safe_capture(file, max_attempts=5):
    """
    Open ``file`` with ``cv2.VideoCapture``, retrying up to ``max_attempts`` times.

    Returns:
        An opened ``cv2.VideoCapture``.

    Raises:
        RuntimeError: if the capture is still not opened after
            ``max_attempts`` attempts.
    """
    cap = cv2.VideoCapture(file)
    attempts = 1
    while not cap.isOpened() and attempts < max_attempts:
        cap.release()  # free the failed handle before retrying
        cap = cv2.VideoCapture(file)
        attempts += 1
    # Test the capture itself: the loop exits with attempts == max_attempts,
    # so a counter comparison like attempts > max_attempts could never fire.
    if not cap.isOpened():
        cap.release()
        raise RuntimeError("Failed to open {} with cv2.VideoCapture after {} tries".format(file, attempts))
    return cap

class stderr_suppress(object):
    '''
    A context manager that silences stderr at the file-descriptor level.

    File descriptor 2 is redirected to os.devnull, so output is suppressed
    even when it originates in compiled C/Fortran code, not only Python-level
    writes to sys.stderr. stdout (fd 1) is left untouched.

    ``__exit__`` returns None, so exceptions raised inside the block still
    propagate. Their traceback is printed after stderr has been restored.

    The descriptors are acquired in ``__enter__`` and released in
    ``__exit__``, so an instance that is never entered leaks nothing and can
    be entered again after exiting (but not nested with itself).
    '''
    def __init__(self):
        self.null_fd = None
        self.save_fd = None

    @staticmethod
    def _flush_stderr():
        # Push out Python-level buffered text so it lands on the descriptor
        # that was active when it was written.
        import sys
        if sys.stderr is not None:
            sys.stderr.flush()

    def __enter__(self):
        self._flush_stderr()
        self.null_fd = os.open(os.devnull, os.O_RDWR)
        # Keep a duplicate of the real stderr so it can be restored on exit.
        self.save_fd = os.dup(2)
        os.dup2(self.null_fd, 2)
        return self

    def __exit__(self, *_):
        self._flush_stderr()
        # Restore the real stderr, then release both helper descriptors.
        os.dup2(self.save_fd, 2)
        os.close(self.null_fd)
        os.close(self.save_fd)
        self.null_fd = None
        self.save_fd = None
