import av
import numpy as np


class VideoReaderAV:
    """Random-access RGB frame reader built on PyAV with a decord-like interface.

    Only the first video stream of the container is used. Frame indices are
    mapped to timestamps through the stream's average frame rate, so exact
    index -> frame correspondence assumes a (roughly) constant frame rate.

    Can be used as a context manager; the container is closed on exit.
    """

    def __init__(self, video_path, num_threads=8):
        """Open ``video_path`` and read metadata of its first video stream.

        Args:
            video_path: Anything accepted by ``av.open`` (typically a file path).
            num_threads: Number of decoder threads used with frame threading.
                Defaults to 8.

        Raises:
            ValueError: If the frame rate or the duration cannot be determined.
        """
        self.video_path = video_path
        self.container = av.open(video_path)
        try:
            stream = self.container.streams.video[0]
            stream.thread_type = "FRAME"
            stream.thread_count = num_threads

            # average_rate can be missing for some streams; guessed_rate is
            # PyAV's best estimate in that case (absent on old PyAV versions).
            self.video_fps = stream.average_rate or getattr(stream, "guessed_rate", None)
            if not self.video_fps:
                raise ValueError(f"Cannot determine the frame rate of {video_path!r}")
            self.time_base = stream.time_base
            self.width = stream.width
            self.height = stream.height
            # Stream start timestamp (time_base units). Frame indices are counted
            # from here rather than from pts 0, otherwise streams that do not
            # start at 0 return shifted frames.
            self.start_pts = stream.start_time or 0
            self.duration_in_seconds = self._duration_in_seconds(stream)
            self.video_num_frames = int(self.duration_in_seconds * self.video_fps)
        except Exception:
            # Do not leak the open container when metadata extraction fails.
            self.close()
            raise

    def _duration_in_seconds(self, stream):
        """Return the stream duration in seconds, falling back to the container's."""
        if stream.duration is not None:
            return stream.duration * self.time_base
        if self.container.duration is not None:
            from fractions import Fraction

            # Container durations are expressed in units of 1 / av.time_base seconds.
            return Fraction(self.container.duration, av.time_base)
        raise ValueError(f"Cannot determine the duration of {self.video_path!r}")

    def get_avg_fps(self):
        """Return the average frame rate used to map indices to timestamps."""
        return self.video_fps

    def __len__(self):
        """Return the estimated frame count (duration x average fps)."""
        return self.video_num_frames

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def close(self):
        """Release the video stream and the container.

        Best-effort and safe to call repeatedly, including from ``__del__`` on a
        partially constructed instance; cleanup errors are deliberately ignored.
        """
        container = getattr(self, "container", None)
        if container is None:
            return
        try:
            container.streams.video[0].close()
        except Exception:
            pass
        try:
            container.close()
        except Exception:
            pass

    def _decode_frame_at(self, target_pts, stream):
        """Return the first frame with ``pts >= target_pts`` as an RGB ndarray.

        Seeks backward to the nearest keyframe, then decodes forward. Returns
        ``None`` when the end of the stream is reached first.
        """
        self.container.seek(target_pts, any_frame=False, backward=True, stream=stream)
        for frame in self.container.decode(video=0):
            # Frames without a pts cannot be placed on the timeline; skip them.
            if frame.pts is not None and frame.pts >= target_pts:
                return frame.to_rgb().to_ndarray()
        return None

    def get_batch(self, sampled_indices):
        """Decode the frames at ``sampled_indices`` and stack them.

        Args:
            sampled_indices: Iterable of frame indices, counted from the start
                of the stream at the average frame rate.

        Returns:
            ``np.ndarray`` of stacked RGB frames, one per requested index. An
            empty request yields an empty array of shape ``(0, height, width, 3)``.
            Indices past the end of the stream are filled with the last decoded
            frame (consistent with decord).

        Raises:
            ValueError: If none of the requested frames could be decoded.
        """
        sampled_indices = list(sampled_indices)
        if not sampled_indices:
            return np.empty((0, self.height, self.width, 3), dtype=np.uint8)

        stream = self.container.streams.video[0]
        frames = []
        for index in sampled_indices:
            target_pts = self.start_pts + int(index / self.video_fps / self.time_base)
            frame = self._decode_frame_at(target_pts, stream)
            if frame is None:
                # Ran past the end of the stream; remaining slots are padded below.
                break
            frames.append(frame)

        if not frames:
            raise ValueError("No frames found")

        if len(frames) < len(sampled_indices):
            # This is consistent with decord
            frames.extend([frames[-1]] * (len(sampled_indices) - len(frames)))

        return np.array(frames)

