import random
from torch.utils.data import Dataset
from video_utils import shuffle_except_first, get_video_frames
from rl_utils import get_query_content, get_gvl_content
import numpy as np


class OneVideoDataset(Dataset):
    """Random frame-query dataset built from a single video.

    The video's frames are loaded once with ``get_video_frames``. Each frame is
    tagged with its relative position in the (optionally trimmed) video as a
    percentage: the first kept frame is 0, the last is 100, and the rest are
    evenly spaced in between. Every ``__getitem__`` call draws
    ``num_frames_query`` distinct frames at random. The ``idx`` argument is
    ignored, and ``__len__`` only reports ``rollout_size``.

    Args:
        video_path: Source passed to ``get_video_frames``.
        num_frames_query: Number of frames drawn per item.
        reward_type: ``'Single'`` or ``'VOC'``; selects the query that
            ``__getitem__`` builds. Any other value raises
            ``NotImplementedError`` when an item is requested.
        rollout_size: Value returned by ``__len__``.
        offset: Optional frame index; frames before it are dropped before the
            percentages are assigned.
    """

    def __init__(self, video_path, num_frames_query=5, reward_type='VOC', rollout_size=32, offset=None) -> None:
        super().__init__()
        self.frames = get_video_frames(video_path)
        self.num_frames_query = num_frames_query
        self.max_steps = rollout_size
        self.reward_type = reward_type

        if offset is not None:
            # Trim before computing percentages so the first kept frame maps to 0%.
            self.frames = self.frames[offset:]

        self.percentages = np.linspace(0, 100, len(self.frames))
        # (frame, percentage) tuples: the population that queries sample from.
        self.pairs = list(zip(self.frames, self.percentages))

    def __len__(self):
        """Return ``rollout_size``; items are random draws, so this only bounds iteration."""
        return self.max_steps

    def get_uniform_frames(self, num_frames):
        """Return ``num_frames`` evenly spaced frames and fresh 0-100 percentages.

        Frame indices come from ``np.linspace`` truncated to ``int``. If
        ``num_frames`` exceeds the number of frames, some indices repeat. The
        returned percentages are re-spaced evenly over the selection rather
        than taken from ``self.percentages``.
        """
        indices = np.linspace(0, len(self.frames) - 1, num_frames, dtype=int)
        selected_frames = [self.frames[i] for i in indices]
        selected_percentages = np.linspace(0, 100, num_frames)
        return selected_frames, selected_percentages

    def _sample_sorted_pairs(self, min_frames=0):
        """Draw ``num_frames_query`` distinct (frame, percentage) pairs in temporal order.

        Raises:
            ValueError: if ``num_frames_query`` is below ``min_frames`` or exceeds
                the number of available frames.
        """
        if self.num_frames_query < min_frames:
            raise ValueError(
                f"reward_type {self.reward_type!r} needs num_frames_query >= {min_frames}, "
                f"got {self.num_frames_query}"
            )
        if self.num_frames_query > len(self.pairs):
            raise ValueError(
                f"cannot draw {self.num_frames_query} frames from a video with "
                f"{len(self.pairs)} frames"
            )
        # random.sample draws without replacement; sorting by percentage puts
        # the picked frames back into their order in the video.
        selected = random.sample(self.pairs, self.num_frames_query)
        selected.sort(key=lambda pair: pair[1])
        return selected

    def _single_query(self):
        """Hold out one non-first sampled frame and build a query from the rest."""
        selected = self._sample_sorted_pairs(min_frames=2)
        # The earliest sampled frame is never held out; pick from positions 1..n-1.
        idx_to_remove = random.randrange(1, len(selected))
        frame_to_predict = selected.pop(idx_to_remove)
        shuffled_selected, _ = shuffle_except_first(selected)
        msg, images = get_query_content(frame_to_predict, shuffled_selected)
        return {
            "ground_truth": frame_to_predict[1],  # only for initial reward
            "prompt": msg,
            "image": images,
        }

    def _voc_query(self):
        """Shuffle all sampled frames with ``shuffle_except_first`` and build a VOC-style query."""
        selected = self._sample_sorted_pairs()
        shuffled_pairs, shuffled_indices = shuffle_except_first(selected)
        shuffled_frames = [pair[0] for pair in shuffled_pairs]
        msg, images = get_gvl_content(shuffled_frames, [])
        return {
            "prompt": msg,
            "image": images,
            "shuffled_indices": shuffled_indices,
        }

    def __getitem__(self, idx):
        """Build one randomized query; ``idx`` is ignored.

        Returns:
            For ``'Single'``: a dict with ``ground_truth`` (the held-out frame's
            percentage), ``prompt`` and ``image``.
            For ``'VOC'``: a dict with ``prompt``, ``image`` and
            ``shuffled_indices`` as returned by ``shuffle_except_first``.

        Raises:
            ValueError: if the video has fewer frames than ``num_frames_query``,
                or ``'Single'`` is used with fewer than two query frames.
            NotImplementedError: for any other ``reward_type``.
        """
        if self.reward_type == 'Single':
            return self._single_query()
        if self.reward_type == 'VOC':
            return self._voc_query()
        raise NotImplementedError(
            f"Unsupported reward_type {self.reward_type!r}; expected 'Single' or 'VOC'"
        )


class MixedVideoDataset(Dataset):
    """Random frame-query dataset that mixes several videos.

    One ``OneVideoDataset`` is built per video. Every item, and every
    ``get_uniform_frames`` call, first picks one of them uniformly at random
    and delegates to it.

    Args:
        video_paths: Video sources, one sub-dataset per entry.
        num_frames_query: Forwarded unchanged to every ``OneVideoDataset``.
        reward_type: Forwarded unchanged to every ``OneVideoDataset``.
        rollout_size: Forwarded to every ``OneVideoDataset`` and also
            returned by this dataset's ``__len__``.
        offsets: Optional per-video frame offsets aligned with
            ``video_paths``. ``None`` (the default) means no offset for any
            video.

    Raises:
        ValueError: if ``offsets`` is given and its length differs from the
            number of video paths.
    """

    def __init__(self, video_paths, num_frames_query=5, reward_type='VOC', rollout_size=32, offsets=None) -> None:
        super().__init__()
        video_paths = list(video_paths)
        if offsets is None:
            # No offsets supplied: every video is used from its first frame.
            offsets = [None] * len(video_paths)
        else:
            offsets = list(offsets)
            # zip() would silently truncate; a short list would leave
            # num_videos larger than the number of datasets actually built.
            if len(offsets) != len(video_paths):
                raise ValueError(
                    f"got {len(offsets)} offsets for {len(video_paths)} video paths"
                )

        self.num_videos = len(video_paths)
        self.rollout_size = rollout_size
        self.datasets = [
            OneVideoDataset(
                video_path,
                num_frames_query=num_frames_query,
                reward_type=reward_type,
                rollout_size=rollout_size,
                offset=offset,
            )
            for video_path, offset in zip(video_paths, offsets)
        ]

    def __len__(self):
        """Return ``rollout_size``; items are random draws, so this only bounds iteration."""
        return self.rollout_size

    def get_uniform_frames(self, num_frames):
        """Return evenly spaced frames from one uniformly chosen video."""
        return random.choice(self.datasets).get_uniform_frames(num_frames)

    def __getitem__(self, idx):
        """Return an item from one uniformly chosen video; ``idx`` is forwarded to it."""
        return random.choice(self.datasets)[idx]
