import os
import sys
import torch
import datetime
import torch.utils.data
import numpy as np
import time
import matplotlib.pyplot as plt
from torch.utils.data import Sampler
import random
from tqdm import tqdm
import torch.distributed as dist

class UniformDownsampleDataset(torch.utils.data.Dataset):
    """View of every ``downsample_factor``-th item of a dataset, with a rotating offset.

    Each call to ``resample()`` (including the one made by ``__init__``) uses the
    offset ``epoch_counter % downsample_factor`` and then advances
    ``epoch_counter``. Successive calls therefore step through the
    ``downsample_factor`` disjoint strided subsets in turn.
    """

    def __init__(self, original_dataset, downsample_factor):
        self.original_length = len(original_dataset)
        self.original_dataset = original_dataset
        self.downsample_factor = downsample_factor
        # Number of resample() calls so far; drives the rotating offset.
        self.epoch_counter = 0
        self.resample()

    def __len__(self):
        """Number of items in the current strided subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th item of the current subset from the wrapped dataset."""
        source_idx = self.indices[idx]
        return self.original_dataset[source_idx]

    def resample(self):
        """Switch to the next strided subset (intended to be called once per epoch)."""
        offset = self.epoch_counter % self.downsample_factor
        stride = self.downsample_factor
        self.indices = [*range(offset, self.original_length, stride)]
        # Advance only after the new indices were built successfully.
        self.epoch_counter = self.epoch_counter + 1


class RandomDownsampleDataset(torch.utils.data.Dataset):
    """Random subset of a dataset holding ``len(original_dataset) // downsample_factor`` items.

    Indices are drawn without replacement via the module-level ``random``
    generator, so seeding ``random`` makes the draw reproducible. They are
    stored in the order they were drawn. Call ``resample()`` to draw a new
    subset.
    """

    def __init__(self, original_dataset, downsample_factor):
        self.original_length = len(original_dataset)
        self.original_dataset = original_dataset
        self.downsample_factor = downsample_factor
        self.resample()

    def __len__(self):
        """Number of items in the current random subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th sampled item from the wrapped dataset."""
        source_idx = self.indices[idx]
        return self.original_dataset[source_idx]

    def resample(self):
        """Draw a fresh random subset of indices (without replacement)."""
        subset_size = self.original_length // self.downsample_factor
        population = range(self.original_length)
        self.indices = random.sample(population, subset_size)


class ClipDataset(torch.utils.data.Dataset):
    """Re-index a frame dataset as fixed-length clips that never cross sequences.

    ``original_dataset`` must expose ``data_len_sequence`` (frame count of each
    sequence, in global frame order) and ``data_list``. Only the first
    ``len(data_list)`` sequences are used. It must also support indexing by
    global frame index.

    Items are still single frames: item ``idx`` is frame ``idx % clip_length``
    of clip ``idx // clip_length``.

    Raises:
        ValueError: if ``clip_length`` is smaller than 1.
    """

    def __init__(self, original_dataset, clip_length=12):
        if clip_length < 1:
            # A non-positive length would make _create_clips loop forever.
            raise ValueError(f"clip_length must be >= 1, got {clip_length}")
        self.original_dataset = original_dataset
        self.clip_length = clip_length
        self.clips = self._create_clips()

    def _create_clips(self):
        """Return a list of ``(seq_idx, start, end)`` clips in global frame indices.

        Each sequence is tiled with non-overlapping clips of ``clip_length``
        frames. If frames are left over, one extra clip ending on the
        sequence's last frame is added; it overlaps the previous clip.
        Sequences shorter than ``clip_length`` are skipped, because any clip
        would have to borrow frames from the neighbouring sequence, or start
        at a negative index for the first sequence.
        """
        clips = []
        seq_begin = 0  # global index of the current sequence's first frame
        num_sequences = len(self.original_dataset.data_list)
        for seq_idx, sequence_length in enumerate(self.original_dataset.data_len_sequence):
            if seq_idx >= num_sequences:
                break
            seq_end = seq_begin + sequence_length  # exclusive

            if sequence_length >= self.clip_length:
                clip_start = seq_begin
                while clip_start + self.clip_length <= seq_end:
                    clips.append((seq_idx, clip_start, clip_start + self.clip_length))
                    clip_start += self.clip_length

                # Leftover frames: add a final clip aligned to the sequence end.
                if clip_start < seq_end:
                    clips.append((seq_idx, seq_end - self.clip_length, seq_end))

            # Advance even for skipped sequences so global indices stay aligned.
            seq_begin = seq_end

        return clips

    def __len__(self):
        """Total number of frame items; every clip contributes ``clip_length`` items."""
        return len(self.clips) * self.clip_length

    def __getitem__(self, idx):
        """Return frame ``idx % clip_length`` of clip ``idx // clip_length``."""
        clip_idx = idx // self.clip_length
        frame_idx = idx % self.clip_length

        _, start, _ = self.clips[clip_idx]
        return self.original_dataset[start + frame_idx]


class SequentialClipSampler(Sampler):
    """Yield ClipDataset frame indices grouped so each batch shares one frame position.

    Every epoch the clips are shuffled and split into groups of ``batch_size``
    clips. Each group is emitted frame by frame: frame 0 of every clip in the
    group, then frame 1, and so on up to ``clip_length - 1``. When consumed in
    chunks of ``batch_size`` (NOTE(review): assumes the DataLoader uses the same
    batch_size — confirm at the call site), each chunk is one frame position of
    one fixed group of clips.

    If the clip count is not a multiple of ``batch_size``, the last group is
    filled by reusing clips from the start of the shuffled order. It cycles if
    there are fewer clips than needed, so every group is full. ``__len__``
    reports this padded count.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_clips = len(dataset.clips)
        self.clip_length = dataset.clip_length

    def _num_padded_clips(self):
        """Clip count rounded up to a multiple of ``batch_size``."""
        return -(-self.num_clips // self.batch_size) * self.batch_size

    def __iter__(self):
        clip_order = list(range(self.num_clips))
        random.shuffle(clip_order)

        # Pad by cycling through the shuffled order. This also works when
        # fewer clips exist than the padding needs. With num_clips == 0 the
        # shortfall is 0, so the modulo below never runs.
        shortfall = self._num_padded_clips() - self.num_clips
        clip_order += [clip_order[i % self.num_clips] for i in range(shortfall)]

        indices = []
        for group_start in range(0, len(clip_order), self.batch_size):
            group = clip_order[group_start:group_start + self.batch_size]
            for frame_idx in range(self.clip_length):
                # One batch: the same frame position from every clip in the group.
                indices.extend(clip_idx * self.clip_length + frame_idx for clip_idx in group)

        return iter(indices)

    def __len__(self):
        """Number of indices yielded per epoch, including padding."""
        return self._num_padded_clips() * self.clip_length
