import random
import numpy as np
import os
import torch
from glob import glob
import pfio
import logging
import torch.nn as nn
import torchvision.transforms as T
from torchvision.transforms import functional as F


class MyRandomRotation:
    """Rotate the input by an angle picked at random from a fixed set.

    Args:
        angles: candidate angles handed to ``F.rotate``; one of them is
            chosen with ``random.choice`` on every call. When omitted, the
            four right-angle rotations ``[0, 90, 180, 270]`` are used.
    """

    def __init__(self, angles: "list[int] | None" = None):
        # Build the default list per instance: a list literal used as the
        # default argument would be shared (and mutable) across all
        # instances created without an explicit ``angles``.
        self.angles = [0, 90, 180, 270] if angles is None else angles

    def __call__(self, x):
        """Return ``x`` rotated by one randomly chosen candidate angle."""
        angle = random.choice(self.angles)
        return F.rotate(x, angle)


class RandomShift:
    '''
    Circularly shift an image along its last two axes by random offsets.

    Two offsets are drawn independently from ``range(shift_max)`` and applied
    with ``Tensor.roll``, so content pushed past one edge re-enters at the
    opposite edge.

    Args:
        img (Tensor): image tensor with shape (..., C, H, W)
    Returns:
        shifted image tensor
    '''
    # Exclusive upper bound for each random offset.
    shift_max = 5000

    def __call__(self, img):
        # The first draw is applied to dim -2 and the second to dim -1;
        # the draw order is kept so seeded runs stay reproducible.
        shift_rows = random.randrange(self.shift_max)
        shift_cols = random.randrange(self.shift_max)
        return img.roll(shifts=(shift_rows, shift_cols), dims=(-2, -1))


def fill_boundary(video, fill):
    """Overwrite the outermost rows and columns of ``video`` with ``fill``.

    The last two axes are treated as the image plane. The input is modified
    in place and the same object is returned.
    """
    edges = (0, -1)
    # Rows first, then columns, so the corners end up with the column write.
    for row in edges:
        video[..., row, :] = fill
    for col in edges:
        video[..., :, col] = fill
    return video


class Phyre128(torch.utils.data.Dataset):
    """Dataset of fixed-length clips cut from ``.npz`` video files.

    Files are collected from ``<root>/<name>_fps<fps>/<split>/*.npz``. Each
    file is read through ``self.fs`` and must store its frames under the key
    ``'img'``, with frames indexed along axis 0.

    Each item is a tuple ``(video, dummy_actions, index)`` where ``video`` is
    a float tensor holding ``seq_len`` consecutive frames, ``dummy_actions``
    is an all-zero ``long`` tensor with one entry per frame, and ``index`` is
    the index that was requested.

    Note:
        When a file cannot be loaded, a randomly chosen other file is
        returned in its place, while the returned ``index`` is still the one
        that was requested.
    """
    name = 'phyre128'
    # Value written to the border pixels when ``boundary`` is enabled.
    max_ch_val = 255
    # Upper bound on replacement attempts after a file fails to load.
    max_load_retries = 10

    def __init__(self, split, seq_len, root, aug=False, boundary=False,
                 transform=None, scs_url=None, fps=1, fix_start=False):
        """Collect the file list and set up the augmentation pipeline.

        Args:
            split: name of the sub-directory of the dataset root to read.
            seq_len: number of consecutive frames returned per item.
            root: parent directory; ``<name>_fps<fps>`` is appended to it.
            aug: if true, apply random horizontal flip, random rotation and
                random circular shift to every clip.
            boundary: if true, overwrite the outermost rows and columns of
                every frame with ``max_ch_val``.
            transform: optional callable applied last (e.g. normalization).
            scs_url: forwarded to ``pfio.v2.from_url`` as ``http_cache``.
            fps: frame-rate suffix of the dataset directory name.
            fix_start: if true, clips always start at frame 0; otherwise the
                start frame is drawn at random.
        """
        self.seq_len = seq_len

        root = os.path.join(root, f'{self.name}_fps{fps}')
        self.fs = pfio.v2.from_url(root, http_cache=scs_url)

        # The train/test split is given by the directory layout. ``glob``
        # does not guarantee an order, so sort to keep the index -> file
        # mapping stable across runs and machines.
        self.all_files = sorted(glob(os.path.join(root, split, '*.npz')))

        self.aug = T.Compose([
            T.RandomHorizontalFlip(),
            MyRandomRotation(),
            RandomShift()
        ]) if aug else nn.Identity()

        self.fix_start = fix_start
        self.boundary = boundary
        self.transform = transform

    def __len__(self):
        """Return the number of video files in the selected split."""
        return len(self.all_files)

    def _load_video(self, file_path):
        """Read the ``'img'`` array of one ``.npz`` file; errors propagate."""
        # ``.npz`` is a binary archive, so request binary mode explicitly.
        with self.fs.open(file_path, mode='rb') as fp:
            arrays = np.load(fp)
            # Members of an ``.npz`` archive are read lazily on access, so
            # the array must be fetched while the file is still open.
            return arrays['img']

    def get_video(self, index):
        """Return the frames of file ``index``, or of a replacement file.

        If the file cannot be loaded the failure is logged and
        ``handle_corrupted_file`` supplies another video instead.

        Raises:
            IndexError: if ``index`` is outside the file list.
            RuntimeError: if no replacement could be loaded either.
        """
        file_path = self.all_files[index]
        try:
            return self._load_video(file_path)
        except EOFError:
            logging.error(f"EOFError: No data left in file {file_path}")
        except Exception as e:
            logging.error(f"Error loading file {file_path}: {e}")
        return self.handle_corrupted_file(file_path)

    def __getitem__(self, index):
        """Return ``(video, dummy_actions, index)`` for one clip.

        Raises:
            ValueError: if the start frame is drawn at random and the loaded
                video has fewer than ``seq_len`` frames.
        """
        video = self.get_video(index)
        video = torch.from_numpy(video).float()

        max_seq_len = video.shape[0]
        if self.fix_start:
            start = 0
        else:
            num_starts = max_seq_len - self.seq_len + 1
            if num_starts <= 0:
                raise ValueError(
                    f"Video for index {index} has {max_seq_len} frames, "
                    f"fewer than seq_len={self.seq_len}"
                )
            start = random.randrange(num_starts)
        end = start + self.seq_len
        video = video[start:end]

        if self.boundary:
            video = fill_boundary(video, fill=self.max_ch_val)

        video = self.aug(video)

        # apply user defined transform (e.g. normalization)
        if self.transform:
            video = self.transform(video)

        dummy_actions = torch.zeros(len(video)).long()

        return video, dummy_actions, index

    def handle_corrupted_file(self, file_path):
        """Return the frames of a randomly chosen file as a replacement.

        Up to ``max_load_retries`` random files are tried. The attempts are
        bounded so that a dataset in which every file fails to load ends
        with a clear error instead of unbounded recursion.

        Args:
            file_path: path of the file that failed to load (used only for
                logging and error reporting).

        Raises:
            RuntimeError: if every attempt failed; the last loading error is
                attached as the cause.
        """
        logging.warning(f"Handling corrupted file: {file_path}")
        last_error = None
        for _ in range(self.max_load_retries):
            index = random.randint(0, len(self.all_files) - 1)
            candidate = self.all_files[index]
            try:
                return self._load_video(candidate)
            except Exception as e:
                logging.error(f"Error loading file {candidate}: {e}")
                last_error = e
        raise RuntimeError(
            f"Could not load a replacement for {file_path} after "
            f"{self.max_load_retries} attempts"
        ) from last_error