import os
import cv2
import numpy as np
from tqdm import tqdm

def video_frames(clip_path, frame_len, img_size):
    """Yield up to ``frame_len`` evenly spaced, resized frames from a video.

    Frames are sampled at a fixed stride of
    ``max(total_frames // frame_len, 1)``, resized to
    ``img_size`` x ``img_size`` and divided by 255.

    Args:
        clip_path: Path of the video file, passed to ``cv2.VideoCapture``.
        frame_len: Maximum number of frames to yield.
        img_size: Width and height, in pixels, of each yielded frame.

    Yields:
        One floating-point array per sampled frame (values in [0, 1] for
        8-bit input frames).

    Note:
        Fewer than ``frame_len`` frames (possibly none) are yielded when a
        read fails, e.g. the file cannot be opened or the clip is shorter
        than ``frame_len`` frames. Callers that need fixed-length clips must
        check the number of frames they received.
    """
    video = cv2.VideoCapture(clip_path)
    try:
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # The stride is loop-invariant. The guard keeps frame_len == 0 a
        # harmless "yield nothing" case instead of a ZeroDivisionError.
        step = max(frame_count // frame_len, 1) if frame_len > 0 else 1

        for count in range(frame_len):
            video.set(cv2.CAP_PROP_POS_FRAMES, count * step)
            flag, frame = video.read()
            if not flag:
                break
            frame = cv2.resize(frame, (img_size, img_size))
            yield frame / 255.0  # Normalize frame to [0, 1]
    finally:
        # Runs even if the consumer abandons the generator early or a frame
        # operation raises, so the capture handle is never leaked.
        video.release()

def create_video_dataset(video_base_directory, img_size=128, frame_len=16,
                         output_dir="dataset"):
    """Build, save and return a fixed-length frame dataset from ``.avi`` clips.

    For every class name in the built-in label list, each ``.avi`` file in
    ``video_base_directory/<class name>`` is sampled with ``video_frames``.
    The stacked clips and their integer class indices are written to
    ``<output_dir>/dataset.npy`` and ``<output_dir>/dataset_label.npy``.

    Args:
        video_base_directory: Directory containing one sub-directory per class.
        img_size: Width and height, in pixels, of every stored frame.
        frame_len: Number of frames stored per clip.
        output_dir: Directory the ``.npy`` files are written to; created if
            it does not exist.

    Returns:
        A ``(dataset, dataset_label)`` tuple: a ``float16`` array with one
        entry per kept clip and an ``int8`` array holding, for each clip, the
        index of its class in the label list.

    Raises:
        OSError: If a class sub-directory cannot be listed.
    """
    import warnings  # local import: used only to report skipped clips

    labels = [
        "Basketball", "BasketballDunk", "Biking", "CliffDiving", "CricketBowling",
        "Diving", "Fencing", "FloorGymnastics", "GolfSwing", "HorseRiding",
        "IceDancing", "LongJump", "PoleVault", "RopeClimbing", "SalsaSpin",
        "SkateBoarding", "Skiing", "Skijet", "SoccerJuggling", "Surfing",
        "TennisSwing", "TrampolineJumping", "VolleyballSpiking", "WalkingWithDog",
        "SoccerPenalty", "FieldHockeyPenalty", "FrontCrawl", "BreastStroke"
    ]

    imgs = []
    clip_labels = []
    skipped = 0

    for idx, label in enumerate(tqdm(labels)):
        video_directory = os.path.join(video_base_directory, label)

        # os.listdir order is arbitrary; sorting keeps the dataset reproducible.
        video_files = sorted(
            f for f in os.listdir(video_directory) if f.endswith('.avi')
        )
        for video_file in video_files:
            video_path = os.path.join(video_directory, video_file)
            frames = list(video_frames(video_path, frame_len, img_size))
            # A short or unreadable clip would make the final array ragged,
            # which np.array() rejects, so such clips are left out.
            if len(frames) != frame_len:
                skipped += 1
                continue
            imgs.append(np.array(frames, dtype='float16'))
            clip_labels.append(idx)

    if skipped:
        warnings.warn(
            f"Skipped {skipped} clip(s) that yielded fewer than "
            f"{frame_len} frames."
        )

    dataset = np.array(imgs, dtype='float16')
    dataset_label = np.array(clip_labels, dtype='int8')

    # np.save does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, "dataset.npy"), dataset)
    np.save(os.path.join(output_dir, "dataset_label.npy"), dataset_label)

    return dataset, dataset_label