from collections import defaultdict
from pathlib import Path

from lutils import openf, writef
from PIL import Image


def concat_h_imgs(im_list, resample=Image.Resampling.BICUBIC):
    """Concatenate images side by side into a single RGB image.

    Every image is rescaled (aspect ratio preserved) to the height of the
    shortest input image, then the results are pasted left to right.

    Args:
        im_list: Iterable of PIL images to concatenate.
        resample: PIL resampling filter used for the resize step.

    Returns:
        A new RGB ``Image`` whose height is the smallest input height and whose
        width is the sum of the resized widths.

    Raises:
        ValueError: If ``im_list`` contains no images.
    """
    # Materialize first: both min() and the resize pass iterate the input,
    # which would silently exhaust a generator.
    im_list = list(im_list)
    if not im_list:
        raise ValueError("concat_h_imgs() needs at least one image")

    min_height = min(im.height for im in im_list)
    im_list_resize = [
        im.resize(
            # Preserve aspect ratio; clamp to 1 px so a very narrow image
            # cannot truncate to a zero-width resize.
            (max(1, int(im.width * min_height / im.height)), min_height),
            resample=resample,
        )
        for im in im_list
    ]
    total_width = sum(im.width for im in im_list_resize)
    dst = Image.new("RGB", (total_width, min_height))
    pos_x = 0
    for im in im_list_resize:
        dst.paste(im, (pos_x, 0))
        pos_x += im.width
    return dst


class VisualizeVideo:
    """Render a video's extracted frames as one horizontal strip image.

    Frames are PNG files exactly one directory level below ``frames_dir``
    (``frames_dir/<subdir>/<frame>.png``). The video-id -> frame-paths index is
    cached in ``frames_dir / "videoid2pth.json"`` once it has been built.
    """

    def __init__(self, frames_dir):
        # Accept str or any os.PathLike; Path(Path(...)) is harmless.
        frames_dir = Path(frames_dir)
        self.videoid2pth = self.get_videoid2pth(frames_dir=frames_dir)

    def __getitem__(self, video_id):
        """Return the frames of ``video_id`` concatenated horizontally.

        Raises:
            KeyError: If ``video_id`` is not in the index.
        """
        frames_pths = self.videoid2pth[video_id]
        # NOTE(review): assumes lutils.openf loads a .png path as a PIL image.
        frames = [openf(pth) for pth in frames_pths]
        return concat_h_imgs(frames)

    @staticmethod
    def get_videoid2pth(frames_dir):
        """Build (or load from cache) a mapping of video id -> sorted frame paths.

        Each frame ``frames_dir/<dir>/<stem>.png`` is indexed under two keys:
        its directory name ``<dir>``, and its stem with the trailing
        4-character ``"_NNN"`` suffix removed. Paths are stored as strings.

        Args:
            frames_dir: ``pathlib.Path`` of the root frames directory.

        Returns:
            dict mapping video id to a sorted list of frame path strings.

        Raises:
            ValueError: If a frame stem does not end in a 4-character ``"_NNN"``
                suffix.
        """
        videoid2pth_pth = frames_dir / "videoid2pth.json"
        if videoid2pth_pth.exists():
            return openf(videoid2pth_pth)

        # Sets: when a frame's directory name equals its stem prefix, both keys
        # coincide and a list would record every frame twice.
        videoid2pth = defaultdict(set)
        for frame_pth in frames_dir.glob("*/*.png"):
            frame_str = str(frame_pth)
            videoid2pth[frame_pth.parent.name].add(frame_str)

            frame_id = frame_pth.stem
            # Slice (not index) so stems shorter than 4 chars fail cleanly.
            if frame_id[-4:-3] != "_":
                raise ValueError(
                    f"frame name {frame_id!r} does not end in a '_NNN' suffix"
                )
            videoid2pth[frame_id[:-4]].add(frame_str)

        videoid2pth = {k: sorted(v) for k, v in videoid2pth.items()}
        writef(videoid2pth_pth, videoid2pth)
        return videoid2pth


def sample_frames(frames_videos, vlen):
    """Pick up to ``frames_videos`` frame indices spread evenly over a video.

    ``[0, vlen)`` is split into ``min(frames_videos, vlen)`` nearly equal
    segments, and the floored midpoint of each segment is returned.

    Args:
        frames_videos: Requested number of frames.
        vlen: Total number of frames in the video.

    Returns:
        List of frame indices, one per segment, ordered first to last (empty
        when either argument is 0).
    """
    import numpy as np

    n_segments = min(frames_videos, vlen)
    bounds = np.linspace(start=0, stop=vlen, num=n_segments + 1).astype(int)
    # Segment k spans [bounds[k], bounds[k + 1] - 1]; take its midpoint.
    return [(lo + hi - 1) // 2 for lo, hi in zip(bounds[:-1], bounds[1:])]


def extract_frames(url, n_frames=10):
    """Download a video from ``url`` and decode up to ``n_frames`` sampled frames.

    The video is fetched once with ``urllib`` into a temporary file, which is
    then opened with OpenCV. Frame indices come from ``sample_frames``.

    Args:
        url: Location of the video; must be fetchable by ``urllib.request``.
        n_frames: Maximum number of frames to return.

    Returns:
        List of RGB PIL images in sampling order. Frames OpenCV fails to read
        are skipped, so the list may be shorter than ``n_frames`` or empty.
    """
    import os
    import shutil
    import tempfile
    import urllib.parse
    import urllib.request

    import cv2
    import numpy as np
    from PIL import Image

    # Keep the URL's extension (if any) on the temp file, in case the OpenCV
    # backend relies on it.
    suffix = os.path.splitext(urllib.parse.urlparse(url).path)[1]

    frames = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        video_pth = os.path.join(tmp_dir, "video" + suffix)
        # Download once to disk and decode the local copy, so the URL is
        # fetched a single time. Both handles are closed before OpenCV opens
        # the file.
        with urllib.request.urlopen(url) as resp, open(video_pth, "wb") as fh:
            shutil.copyfileobj(resp, fh)

        video = cv2.VideoCapture(video_pth)
        try:
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            for frame_number in sample_frames(n_frames, total_frames):
                video.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                ret, frame = video.read()
                if ret:
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(np.uint8(frame)))
        finally:
            # Release the capture before TemporaryDirectory deletes the file.
            video.release()
    return frames


def visualize_url_video(url, n_frames=10):
    """Render ``n_frames`` sampled frames of the video at ``url`` as one image.

    Args:
        url: Location of the video (passed to ``extract_frames``).
        n_frames: Number of frames to sample.

    Returns:
        The single frame when ``n_frames == 1``, otherwise the frames
        concatenated horizontally; ``None`` when no frame could be decoded,
        the same as ``visualize_path_video``.
    """
    frames = extract_frames(url, n_frames=n_frames)
    # Without this guard, frames[0] / concat_h_imgs would crash on an
    # undecodable video.
    if not frames:
        return None
    if n_frames == 1:
        return frames[0]
    return concat_h_imgs(frames)


def visualize_path_video(path, n_frames=10):
    """Render ``n_frames`` sampled frames of a local video file as one image.

    Args:
        path: Video file path (``str`` or ``os.PathLike``).
        n_frames: Number of frames to sample.

    Returns:
        The single frame when ``n_frames == 1``, otherwise the decoded frames
        concatenated horizontally; ``None`` when no frame could be decoded.
    """
    import os

    import cv2
    import numpy as np
    from PIL import Image

    # NOTE(review): some OpenCV Python bindings reject pathlib.Path filenames,
    # so normalize path-like objects to str.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)

    frames = []
    video = cv2.VideoCapture(path)
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        for frames_idx in sample_frames(n_frames, total_frames):
            video.set(cv2.CAP_PROP_POS_FRAMES, frames_idx)
            ret, frame = video.read()
            if ret:
                # OpenCV decodes to BGR; PIL expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(np.uint8(frame)))
    finally:
        # Free the underlying file handle / decoder even if reading fails.
        video.release()

    if not frames:
        return None

    if n_frames == 1:
        return frames[0]

    return concat_h_imgs(frames)
