import os
import torch
import random

import torchvision.transforms.v2.functional as TF

def default(a, b):
    '''
    Return `a` unless it is None, in which case return the fallback `b`.

    Only None triggers the fallback; other falsy values such as 0 or False
    are returned unchanged.
    '''
    if a is None:
        return b
    return a

def shift_img(img, shift_x, shift_y):
    '''
    Circularly shift an image along its last two axes.

    The shift wraps around: content rolled past one border re-enters at the
    opposite border.

    Args:
        img (Tensor): image tensor with shape (..., C, H, W)
        shift_x (int): number of pixels to roll along dim -2 (H, vertical)
        shift_y (int): number of pixels to roll along dim -1 (W, horizontal)
    Returns:
        the rolled tensor, or `img` itself (same object) when both shifts are 0
    '''
    if shift_x != 0 or shift_y != 0:
        offsets = (shift_x, shift_y)
        spatial_dims = (-2, -1)
        return img.roll(shifts=offsets, dims=spatial_dims)
    # nothing to do: hand back the input untouched
    return img


class CommonDataset(torch.utils.data.Dataset):
    '''
    Base dataset yielding fixed-length frame sequences with optional augmentation.

    Subclasses must override `load_data`; its return value is stored in
    `self.dataset`. This class indexes it as `dataset[video][frame]` and reads
    `dataset.shape[1]` as the number of frames per video.

    Each item is a tuple `(seq, dummy_actions, index)` where `seq` is the
    selected frames cast to float, `dummy_actions` is an all-zero long tensor
    with one entry per frame, and `index` is the requested video index.
    '''
    # Exclusive upper bound of the random per-axis offset drawn when aug_shift is on.
    # The drawn offset is multiplied by `shift_size` before the (wrapping) shift.
    shift_max = 5000

    def __init__(self, split, seq_len, root, variant=None, transform=None, start_frame=None, random=False,
                 aug_flip=False, aug_rotate90=False, aug_shift=False, shift_size=1, circular_shift=False):
        '''
        Args:
            split: forwarded to `load_data`
            seq_len (int): number of frames per returned sequence; -1 means `max_len`
            root: forwarded to `load_data`
            variant: forwarded to `load_data`
            transform (callable, optional): applied to the float sequence after augmentation
            start_frame (int, optional): fixed first frame of the contiguous window;
                when None the start is drawn at random for every item
            random (bool): if True, pick `seq_len` distinct frames in random order
                instead of a contiguous window
            aug_flip (bool): horizontally flip the sequence with probability 0.5
            aug_rotate90 (bool): rotate the sequence by a random choice of 0/90/180/270 degrees
            aug_shift (bool): circularly shift the sequence by random offsets
            shift_size (int): multiplier applied to the random shift offsets
            circular_shift (bool): if True a random window may start at any frame and
                wrap around the end of the video
        Raises:
            AssertionError: if the resolved `seq_len` exceeds `max_len`
            NotImplementedError: if `load_data` has not been overridden
        '''
        self.dataset = self.load_data(root, variant, split)
        self.seq_len = self.max_len if seq_len == -1 else seq_len
        assert self.seq_len <= self.max_len, (
            f"seq_len ({self.seq_len}) must be less than or equal to max_len ({self.max_len})"
        )

        self.transform = transform
        self.random = random
        self.start_frame = start_frame
        self.aug_flip = aug_flip
        self.aug_rotate90 = aug_rotate90
        self.aug_shift = aug_shift
        self.shift_size = shift_size
        self.circular_shift = circular_shift

    @property
    def max_len(self):
        '''
        Number of frames per video, read from dim 1 of the loaded data
        '''
        return self.dataset.shape[1]

    def load_data(self, root, variant, split) -> torch.Tensor:
        '''
        Load the data for the given root / variant / split.

        Must be overridden by subclasses. Failing loudly here avoids storing
        None in `self.dataset` and crashing later with an unrelated error.
        '''
        raise NotImplementedError("Subclasses of CommonDataset must implement load_data()")

    def __getitem__(self, index):
        '''
        Build one sample from video `index`.

        Returns:
            (seq, dummy_actions, index): the float frame sequence after
            augmentation and `transform`, an all-zero long tensor of length
            len(seq), and the video index that was requested
        '''
        # Draw all augmentation parameters up front; the order of these draws is kept
        # fixed so that seeded runs stay reproducible.
        flip = random.random() > 0.5 if self.aug_flip else False
        # randomly sample rotation angle from 4 choices: 0, 90, 180, 270.
        # if self.aug_rotate90 is False, then angle = 0
        angle = random.choice([0, 90, 180, 270]) if self.aug_rotate90 else 0
        shift_x = random.randrange(self.shift_max) if self.aug_shift else 0
        shift_y = random.randrange(self.shift_max) if self.aug_shift else 0
        video = self.dataset[index]
        max_len = len(video)

        if self.random:
            # seq_len distinct frame indices, in random (unsorted) order
            view_ids = random.sample(range(self.max_len), self.seq_len)
            seq = video[view_ids]
        else:
            # The random start is drawn before default() chooses, so the RNG advances
            # even when a fixed start_frame is configured.
            if self.circular_shift:
                start = default(self.start_frame, random.randrange(max_len))
            else:
                start = default(self.start_frame, random.randrange(max_len - self.seq_len + 1))
            # Rolling brings frame `start` to position 0; a window that runs past the
            # last frame therefore continues with the first frames of the video.
            seq = video.roll(-start, dims=0)[:self.seq_len]
        seq = seq.float()

        # apply augmentation
        if flip:
            seq = TF.horizontal_flip(seq)
        if angle:
            seq = TF.rotate(seq, angle)
        if shift_x or shift_y:
            seq = shift_img(seq, shift_x * self.shift_size, shift_y * self.shift_size)

        # apply user defined transform (e.g. normalization)
        if self.transform:
            seq = self.transform(seq)

        dummy_actions = torch.zeros(len(seq)).long()

        return seq, dummy_actions, index

    def __len__(self):
        '''
        Number of videos (not frames)
        '''
        return len(self.dataset)