from glob import glob
import os
import numpy as np
import random
import torch.utils.data as data
import json
import cv2
import video_transforms as vtransforms
import torchvision.transforms as transforms
import torch


def load_balls(root, n_frames, is_train=True):
    """Construct a BouncingBalls dataset configured for 64-pixel frames.

    Args:
        root: Forwarded unchanged to BouncingBalls as its ``root`` argument.
        n_frames: Used as both the input and the output frame count.
        is_train: Forwarded to BouncingBalls to pick the train or eval split.

    Returns:
        The newly created BouncingBalls instance.
    """
    side = 64
    # Scale-then-tensorize pipeline handed to the dataset as its transform.
    pipeline = transforms.Compose([
        vtransforms.Scale((side, side)),
        vtransforms.ToTensor(),
    ])
    return BouncingBalls(root, is_train, n_frames, n_frames, side, pipeline)

# def make_dataset(root, is_train):
#   if is_train:
#     folder = 'balls_n4_t60_ex50000'
#   else:
#     folder = 'balls_n4_t60_ex2000'

#   dataset = np.load(os.path.join(root, folder, 'dataset_info.npy'))
#   # print("TYPE: ", type(dataset))
#   print("TYPE: ", (dataset.shape))
#   return dataset

def make_dataset(root, is_train):
  """Generate every clip up front and return them stacked in one tensor.

  2048 clips are produced when ``is_train`` is true, otherwise 205. Each clip
  is 20 frames from ``generate_bw_animation_frames`` converted with
  ``ToTensor``. ``root`` is accepted for interface compatibility but is not
  read here.

  Returns:
    A tensor whose first axis indexes clips and whose second axis indexes
    frames (presumably (instances, 20, H, W) for single-channel frames).
  """
  if is_train:
    instances = 2048
  else:
    instances = 205

  to_tensor = ToTensor()
  clips = []
  for _ in range(instances):
    frames = generate_bw_animation_frames(num_frames=20)
    clip = torch.stack([to_tensor(frame) for frame in frames], dim=0)
    # Drop axis 1 (the per-frame channel axis) when it has size one.
    clips.append(clip.squeeze(1))

  dataset = torch.stack(clips, dim=0)
  print("TYPE: ", (dataset.shape))
  return dataset

class BouncingBalls(data.Dataset):
  '''
  Ball-clip dataset backed by a tensor that make_dataset builds once.

  Indexing returns a pair: the first 10 frames of a clip and the frames
  after them. The stored transform and return_positions flag are kept as
  attributes but are not used by __getitem__.
  '''
  def __init__(self, root, is_train, n_frames_input, n_frames_output, image_size,
               transform=None, return_positions=False):
    super().__init__()
    # Provisional total; replaced further down by the doubled output count.
    self.n_frames = n_frames_input + n_frames_output
    # All clips are generated eagerly and held in memory.
    self.dataset = make_dataset(root, is_train)

    # Plain bookkeeping of constructor arguments.
    self.root = root
    self.is_train = is_train
    self.transform = transform
    self.return_positions = return_positions

    # Geometry values derived from the requested image size.
    self.size = image_size
    self.scale = self.size / 800
    self.radius = int(60 * self.scale)

    # Frame counts are stored doubled; n_frames ends up equal to the
    # doubled output count.
    self.n_frames_input = n_frames_input * 2
    self.n_frames_output = n_frames_output * 2
    self.n_frames = n_frames_output * 2

    # FIXME: the attributes below are not read by the active methods.
    self.with_target = False
    self.digit_size_ = 28
    self.step_length_ = 0.1
    self.num_digits = 1
    self.image_size_ = image_size
    self.mnist = self.dataset

  def __len__(self):
    # Number of clips, i.e. the length of the first axis.
    return self.dataset.size(0)

  def __getitem__(self, idx):
    # Tensor indices are converted to plain Python values before indexing.
    index = idx.tolist() if torch.is_tensor(idx) else idx

    split = 10  # frames before this position are the input, the rest the label
    observed = self.dataset[index, :split, :, :]
    target = self.dataset[index, split:, :, :]
    return observed, target


  # def __getitem__(self, idx):
  #   # traj size: (n_frames, n_balls, 4)
  #   traj = self.dataset[idx]
  #   vid_len, n_balls = traj.shape[:2]
  #   if self.is_train:
  #     start = random.randint(0, vid_len - self.n_frames)
  #   else:
  #     start = 0

  #   n_channels = 1
  #   images = np.zeros([self.n_frames, self.size, self.size, n_channels], np.uint8)
  #   positions = []
  #   for fid in range(self.n_frames):
  #     xy = []
  #     for bid in range(n_balls):
  #       # each ball:
  #       ball = traj[start + fid, bid]
  #       x, y = int(round(self.scale * ball[0])), int(round(self.scale * ball[1]))
  #       images[fid] = cv2.circle(images[fid], (x, y), int(self.radius * ball[3]),
  #                                255, -1)
  #       xy.append([x / self.size, y / self.size])
  #     positions.append(xy)

  #   if self.transform is not None:
  #     images = self.transform(images)

  #   input = images[:self.n_frames_input]
  #   if self.n_frames_output > 0:
  #     output = images[self.n_frames_input:]
  #   else:
  #     output = []

  #   if not self.return_positions:
  #     return input, output
  #   else:
  #     positions = np.array(positions)
  #     return input, output, positions

  # def get_random_trajectory(self, seq_length):
  #       ''' Generate a random sequence of a MNIST digit '''
  #       canvas_size = self.image_size_ - self.digit_size_
  #       x = random.random()
  #       y = random.random()
  #       theta = random.random() * 2 * np.pi
  #       v_y = np.sin(theta)
  #       v_x = np.cos(theta)

  #       start_y = np.zeros(seq_length)
  #       start_x = np.zeros(seq_length)
  #       for i in range(seq_length):
  #           # Take a step along velocity.
  #           y += v_y * self.step_length_
  #           x += v_x * self.step_length_

  #           # Bounce off edges.
  #           if x <= 0:
  #               x = 0
  #               v_x = -v_x
  #           if x >= 1.0:
  #               x = 1.0
  #               v_x = -v_x
  #           if y <= 0:
  #               y = 0
  #               v_y = -v_y
  #           if y >= 1.0:
  #               y = 1.0
  #               v_y = -v_y
  #           start_y[i] = y
  #           start_x[i] = x

  #       # Scale to the size of the canvas.
  #       start_y = (canvas_size * start_y).astype(np.int32)
  #       start_x = (canvas_size * start_x).astype(np.int32)
  #       return start_y, start_x

  # def generate_moving_mnist(self, num_digits=2):
  #       '''
  #       Get random trajectories for the digits and generate a video.
  #       '''
  #       data = np.zeros((self.n_frames, self.image_size_, self.image_size_), dtype=np.float32)
  #       for n in range(num_digits):
  #           # Trajectory
  #           start_y, start_x = self.get_random_trajectory(self.n_frames)
  #           ind = random.randint(0, self.mnist.shape[0] - 1)
  #           digit_image = self.mnist[ind]
  #           for i in range(self.n_frames):
  #               top    = start_y[i]
  #               left   = start_x[i]
  #               bottom = top + self.digit_size_
  #               right  = left + self.digit_size_
  #               # Draw digit
  #               data[i, top:bottom, left:right] = np.maximum(data[i, top:bottom, left:right], digit_image)

  #       data = data[..., np.newaxis]

  #       return data

  # def __getitem__(self, idx):
  #       if self.is_train or self.num_digits != 2:
  #           # Generate data on the fly
  #           images = self.generate_moving_mnist(self.num_digits)
  #       else:
  #           images = self.dataset[:, idx, ...]

  #       if self.with_target:
  #           targets = np.array(images > 127, dtype=float) * 255.0

  #       if self.transform is not None:
  #           images = self.transform(images)
  #           if self.with_target:
  #               targets = self.transform(targets)

  #       if self.with_target:
  #           return images, targets
  #       else:
  #           return images





from torchvision.transforms import ToTensor
import random
from PIL import Image, ImageDraw



def create_background_with_gray_lines(frame_size=(64, 64), num_lines=50):
    """Create a white grayscale background overlaid with a regular gray grid.

    Half of ``num_lines`` are drawn vertically and half horizontally, at
    evenly spaced offsets starting one pixel in from the top-left corner.
    The lines are deterministic, not random.

    Args:
        frame_size: (width, height) of the image in pixels.
        num_lines: Total number of grid lines across both axes.

    Returns:
        A new mode-"L" PIL image containing the grid.
    """
    width, height = frame_size
    background = Image.new("L", frame_size, "white")
    draw = ImageDraw.Draw(background)

    gray_shade = 127
    for idx in range(int(num_lines / 2)):
        # Space each axis by its own extent so non-square frames get an even
        # grid; for the default 64x64 frame both offsets are the same value.
        x = int((width / (num_lines / 2)) * idx) + 1
        y = int((height / (num_lines / 2)) * idx) + 1

        # Run each line to the last pixel of the frame instead of a fixed 63.
        draw.line([(x, 0), (x, height - 1)], fill=gray_shade)
        draw.line([(0, y), (width - 1, y)], fill=gray_shade)

    return background

def create_bw_frame(background, ball_y, ball_x, ball_radius=5):
    """Return a copy of ``background`` with a black ball drawn on it.

    Args:
        background: Image to copy; the original is left untouched.
        ball_y: Vertical coordinate of the ball centre.
        ball_x: Horizontal coordinate of the ball centre.
        ball_radius: Half the side of the ball's bounding box.

    Returns:
        The new frame containing the ball.
    """
    canvas = background.copy()
    # Bounding box of the ball: top-left corner, then bottom-right corner.
    bounding_box = [
        (ball_x - ball_radius, ball_y - ball_radius),
        (ball_x + ball_radius, ball_y + ball_radius),
    ]
    ImageDraw.Draw(canvas).ellipse(bounding_box, fill="black")
    return canvas

def generate_bw_animation_frames(num_frames=10, frame_size=(64, 64)):
    """Generate frames of a ball falling under constant acceleration.

    The ball starts at x = 32, y = 64 - 56 and its vertical offset at frame
    ``i`` is ``int(position_change(9.8, i) / 2)``. All frames share one grid
    background built for ``frame_size``.

    Args:
        num_frames: Number of frames to produce; 0 yields an empty list and
            1 yields a single frame.
        frame_size: (width, height) passed to the background generator.

    Returns:
        A list of ``num_frames`` images.
    """
    # Fixed start position; randomised values were previously tried here.
    # NOTE(review): 32, 56 and 64 assume the default 64x64 frame - confirm
    # before using another frame_size.
    ball_x = 32
    max_height = 56
    start_y = 64 - max_height
    background = create_background_with_gray_lines(frame_size)

    # NOTE(review): with these constants ball_y is 69 at i = 5, so for a
    # 64-pixel-high frame the ball is off-canvas from the sixth frame on -
    # confirm this is intended.
    frames = []
    for i in range(num_frames):
        # The former normalised time i / (num_frames - 1) was never used and
        # divided by zero for num_frames == 1, so it is no longer computed.
        ball_y = start_y + int(position_change(9.8, i) / 2)
        frames.append(create_bw_frame(background, ball_y, ball_x, ball_radius=5))
    return frames

def position_change(acceleration, time):
    """Return the displacement after ``time`` under constant acceleration.

    Uses 0.5 * acceleration * time ** 2; the initial-velocity term of the
    general formula is omitted, i.e. the object is taken to start at rest.
    """
    half_acceleration = 0.5 * acceleration
    return half_acceleration * time ** 2

# Generate black and white frames
# bw_frames = generate_bw_animation_frames()

# # Convert frames to PyTorch tensors
# tensor_frames = [ToTensor()(frame) for frame in bw_frames]

# # Example of one frame as a tensor
# tensor_frames[0].shape  # Should be [1, 64, 64] as it's a single channel image

# bw_frames[0].save(
#     './bounce_animation.gif',
#     save_all=True,
#     append_images=bw_frames[1:],
#     duration=100,  # duration for each frame in milliseconds
#     loop=0  # loop forever
# )


