from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import os
import cv2
import matplotlib
import numpy as np
import torch
import torch.nn.functional as F

from torch import Tensor
from torch.utils.data import DataLoader, Dataset

def get_continual_tetrominoes(
        name: str,
        dataset_f,
        width: int,
        height: int,
        max_num_objects: int,
        num_background_objects: int,
        input_channels: int,
        train_dataset_size: int,
        test_dataset_size: int,
        dataset_root: str,
        dataset_name: str,
        shape_list=None,
        color_list=None,
        datas=None,
        positions=None,
        generate_dataset=True,
        ):
    """Sample random multi-object scenes and wrap them in train/test datasets.

    Every ``(shape, color)`` pair of ``shape_list`` x ``color_list`` is tiled
    until there are enough objects for all scenes, shuffled, and cut into
    consecutive scenes of ``max_num_objects`` objects. Each scene also gets a
    global ``(dx, dy)`` offset drawn from ``np.linspace(-2, 2, 5)`` and a
    placement-template id drawn from ``range(4)``. All randomness comes from
    the global ``np.random`` state.

    Args:
        name, width, height, max_num_objects, input_channels, dataset_root,
        dataset_name: Forwarded unchanged to ``dataset_f``.
        dataset_f: Dataset class/factory, called with keyword arguments only.
        num_background_objects: Not used; ``0`` is always forwarded.
        train_dataset_size: Number of scenes in the first (train) dataset.
        test_dataset_size: Number of scenes in the second (test) dataset.
        shape_list, color_list: Shape ids / colour values to combine.
        datas, positions: Not used; scenes are always sampled here.
        generate_dataset: Only ``True`` is implemented.

    Returns:
        ``(train_dataset, test_dataset)`` built by ``dataset_f``.

    Raises:
        AssertionError: if ``generate_dataset`` is false and a split
            directory under ``dataset_root`` is missing.
        ValueError: if ``generate_dataset`` is false (loading is a TODO).
    """
    if not generate_dataset:
        # Loading pre-generated splits is unfinished: only check they exist.
        split_dirs = (
            os.path.join(dataset_root, f'{dataset_name}_train'),
            os.path.join(dataset_root, f'{dataset_name}_test'),
        )
        for split_dir in split_dirs:
            assert os.path.exists(split_dir)
        raise ValueError("TBU")

    n_scenes = train_dataset_size + test_dataset_size
    per_scene = max_num_objects

    # One row per (shape, color) combination.
    combos = np.array(np.meshgrid(shape_list, color_list)).T.reshape(-1, 2)
    # Enough copies that the shuffled pool covers n_scenes * per_scene objects.
    n_copies = n_scenes * per_scene // combos.shape[0] + 1
    pool = np.random.permutation(np.repeat(combos, n_copies, axis=0))
    scenes = [pool[i * per_scene:(i + 1) * per_scene] for i in range(n_scenes)]

    offsets = np.random.choice(np.linspace(-2, 2, 5), (n_scenes, 2))
    template_ids = np.random.choice(4, n_scenes)

    def build(part):
        # `part` is a slice selecting this split's scenes.
        return dataset_f(
            name=name,
            width=width,
            height=height,
            max_num_objects=max_num_objects,
            num_background_objects=0,
            input_channels=input_channels,
            dataset_root=dataset_root,
            dataset_name=dataset_name,
            datas=scenes[part],
            positions=offsets[part],
            templates=template_ids[part],
            shape_list=shape_list,
            color_list=color_list,
        )

    train_dataset = build(slice(None, train_dataset_size))
    test_dataset = build(slice(train_dataset_size, None))
    return train_dataset, test_dataset



class ContinualTetrominoes(Dataset):
    """Multi-object scenes of coloured polyomino glyphs, rendered on the fly.

    Sample ``i`` is described by three parallel sequences:

    * ``datas[i]``: rows of ``(shape, color)``, one row per object;
    * ``positions[i]``: an ``(dx, dy)`` offset shared by every object;
    * ``templates[i]``: index into :meth:`get_position_template` giving each
      object's offset relative to the others.

    Glyph pixels come from :meth:`get_object_shape`; subclasses override it
    to switch glyph families.
    """

    def __init__(
        self,
        name: str,
        width: int,
        height: int,
        max_num_objects: int,
        num_background_objects: int,
        input_channels: int,
        dataset_root: str,
        dataset_name: str,
        datas: List,
        positions: List,
        templates: List,
        shape_list: List,
        color_list: List,
        ):
        super().__init__()

        self.name = name
        self.width = width
        self.height = height
        self.max_num_objects = max_num_objects
        self.num_background_objects = num_background_objects
        self.input_channels = input_channels
        self.dataset_root = dataset_root
        self.dataset_name = dataset_name
        self.shape_list = shape_list
        self.color_list = color_list

        if datas is None or positions is None:
            # Loading scenes from dataset_root is not implemented yet.
            raise ValueError('T.B.U')
        assert len(datas) == len(positions)
        self.datas = datas
        self.positions = positions
        self.templates = templates

        self.logits = None
        # (shape value + 1) -> 1-based position in shape_list; 0 is the background.
        self.shape2id = {x + 1: idx + 1 for idx, x in enumerate(self.shape_list)}
        self.shape2id[0] = 0

        # len() accepts plain lists as well as 1-D numpy arrays.
        self.len_shape = len(self.shape_list)
        self.len_color = len(self.color_list)
        # Raw value -> 0-based position, used for the one-hot parts of set labels.
        self.shape2id_ = {x: idx for idx, x in enumerate(self.shape_list)}
        self.color2id_ = {float(x): idx for idx, x in enumerate(self.color_list)}
        # Set-label width: 2 position entries + one-hot colour + one-hot shape + presence flag.
        self.len_labels = 2 + self.len_shape + self.len_color + 1

    def __len__(self,):
        """Number of scenes."""
        return len(self.datas)

    def __getitem__(self, index):
        """Render scene ``index`` and return image, masks and per-object attributes.

        Per-object lists/tensors hold a background entry at position 0,
        followed by one entry per row of ``datas[index]``.
        """
        data = self.datas[index]
        position = self.positions[index]
        template = self.get_position_template(self.templates[index])

        out = {}
        masks = []
        colors = [0]
        shapes = [0]
        shape_ids = [0]
        xs = [0]
        ys = [0]
        is_foreground = [0]
        # Same (width, height, C) axis order as the canvas allocated originally;
        # `image += obj_img` below requires input_channels == 3 (RGB objects).
        image = torch.zeros((self.width, self.height, self.input_channels))

        # Set-prediction targets: [pos_x, pos_y] + one-hot colour + one-hot
        # shape + [is_object]; row 0 is the background.
        # NOTE(review): built but not yet returned in `out`.
        set_label = [[0] * (2 + self.len_color + self.len_shape) + [0]]
        grid = self.width // 2

        for idx, (shape, color) in enumerate(data):
            offset_x, offset_y = template[idx]
            x = (self.width // 2) // 2 - 1 + offset_x + position[0]
            y = (self.height // 2) // 2 - 1 + offset_y + position[1]
            obj_img = self.get_data_by_label(
                angle=0.0,
                color=color,
                scale=3.5,
                x=x,
                y=y,
                shape=shape,
                height=self.height,
                width=self.width,
            )
            image += obj_img
            # The object mask covers every pixel where any channel is non-zero.
            mask = (obj_img != 0).to(torch.float32).max(dim=-1).values
            masks.append(mask.unsqueeze(0))
            colors.append(color + 1)
            shapes.append(shape + 1)
            xs.append(x)
            ys.append(y)
            is_foreground.append(1)
            shape_ids.append(self.shape2id[shape + 1])

            p = [int(offset_x // grid) * 6 - 3, int(offset_y // grid) * 6 - 3]
            c = [0] * self.len_color
            # Index by position in color_list (as for shapes); int(color * 10)
            # overflows for colour subsets such as [0.2, 0.3].
            c[self.color2id_[float(color)]] = 1
            s = [0] * self.len_shape
            s[self.shape2id_[int(shape)]] = 1
            set_label.append(p + c + s + [1])

        out['image'] = image.permute(2, 0, 1)
        if masks:
            masks = torch.stack(masks)
        else:
            # Empty scene: no foreground masks, so the background covers everything.
            masks = torch.zeros((0, 1) + tuple(image.shape[:2]))
        bg = (masks.sum(dim=0) == 0).to(torch.float32)
        out['mask'] = torch.cat([bg.unsqueeze(0), masks])
        out['color'] = colors
        out['shape'] = shapes
        out['x'] = torch.tensor(xs, dtype=torch.float32).view(-1, 1)
        out['y'] = torch.tensor(ys, dtype=torch.float32).view(-1, 1)
        out['is_foreground'] = torch.tensor(is_foreground, dtype=torch.float32).view(-1, 1)
        out['is_modified'] = torch.zeros_like(out['is_foreground'])
        out['index'] = index
        out['shape_ids'] = shape_ids

        return out

    def get_set_prediction_lable(self, ):
        """Placeholder; not implemented (returns None)."""
        pass

    def get_data_by_label(self, angle=0, color=0, scale=1, x=16, y=16, shape=0, height=32, width=32, value=1.0,
                          flag_affine=cv2.INTER_AREA, flag_resize=cv2.INTER_AREA):
        """Rasterise one coloured glyph as an RGB float tensor.

        Args:
            angle: Rotation in degrees (the glyph builder may override it).
            color: Hue fraction; the HSV hue is set to ``color * 360`` with
                full saturation.
            scale: The glyph image is scaled by ``scale / height``.
            x, y: Pixel position the glyph centre is moved to (plus 0.5).
            shape: Glyph id forwarded to :meth:`get_object_shape`.
            height, width: Output canvas size.
            value: Multiplier on the glyph intensity (HSV value channel).
            flag_affine, flag_resize: OpenCV interpolation flags.

        Returns:
            float32 tensor with 3 trailing RGB channels clipped to [0, 1]. Every
            non-zero entry is replaced by its channel's maximum, so partially
            covered (interpolated) pixels take the object's full colour.
        """
        int_final_ratio = 1
        # NOTE(review): OpenCV's dsize is (cols, rows); (height, width) here
        # only matches the conventional order when height == width.
        final_shape = (height, width)
        intermediate_shape = (height * int_final_ratio, width * int_final_ratio)

        object_shape, shape, height, width, angle = self.get_object_shape(shape, height, width, angle)

        scale_ = scale / height * int_final_ratio
        t1 = np.eye(3)  # First translation moves center of shape to origin
        t1[0, 2] = -object_shape.shape[1] / 2
        t1[1, 2] = -object_shape.shape[0] / 2
        r = np.eye(3)  # Scale + rotation
        r[0, 0] = scale_ * np.cos(angle * np.pi / 180)
        r[0, 1] = scale_ * np.sin(angle * np.pi / 180)
        r[1, 0] = -r[0, 1]
        r[1, 1] = r[0, 0]
        t2 = np.eye(3)  # Second translation moves rotated shape to x, y
        t2[0, 2] = int_final_ratio * (x + 0.5)
        t2[1, 2] = int_final_ratio * (y + 0.5)
        affine_mat = (t2 @ r @ t1)[:-1]

        dst = cv2.warpAffine(object_shape, affine_mat, intermediate_shape, flags=flag_affine)
        dst = cv2.resize(dst, final_shape, interpolation=flag_resize)
        # Glyph intensity becomes the HSV value channel; hue encodes the colour.
        dst = value * np.repeat(dst[..., np.newaxis], 3, axis=2)
        dst[..., 1] = 1
        dst[..., 0] = color * 360
        dst = cv2.cvtColor(dst, cv2.COLOR_HSV2RGB)
        dst[dst > 1] = 1
        dst[dst < 0] = 0

        dst = torch.Tensor(dst)

        max_values = dst.view(-1, 3).max(dim=0, keepdim=True).values.unsqueeze(0)
        mask = dst != 0
        dst[mask] = max_values.expand_as(dst)[mask]
        return dst

    def get_object_shape(self, shape, height, width, angle):
        """Return ``(glyph, shape, height, width, angle)`` from the tetromino table."""
        return get_tetrominos_shape(shape, height, width, angle)

    def get_position_template(self, index):
        """Return the ``index``-th list of three ``[dx, dy]`` object offsets.

        Offsets are multiples of ``grid // 2`` with ``grid = width // 2``; the
        canvas must satisfy ``width // 2 == height // 2``.
        """
        grid = self.width // 2
        assert grid == self.height // 2
        self.position_template = [
            [[0, 0], [0, grid], [grid, grid//2]],
            [[0, 0], [grid, 0], [grid//2, grid]],
            [[0, grid//2], [grid, 0], [grid, grid]],
            [[grid//2, 0], [0, grid], [grid, grid]],
        ]
        return self.position_template[index]
    

class ContinualPentominos(ContinualTetrominoes):
    """ContinualTetrominoes variant whose glyphs come from the pentomino table."""

    def get_object_shape(self, shape, height, width, angle):
        """Build pentomino ``shape``; returns ``(glyph, shape, height, width, angle)``."""
        return get_pentominos_shape(shape=shape, height=height, width=width, angle=angle)
    

# Glyph layouts on an 8x8 grid of cells ('#' = filled, '.' = empty). Each cell
# is expanded to a (height x width) block of pixels, so one polyomino square
# occupies 2x2 cells.
_TETROMINO_LAYOUTS = {
    # Tetrominoes (ids 0-18).
    'I': (
        '........',
        '........',
        '........',
        '########',
        '########',
        '........',
        '........',
        '........',
    ),
    'L': (
        '........',
        '........',
        '.######.',
        '.######.',
        '.##.....',
        '.##.....',
        '........',
        '........',
    ),
    'J': (
        '........',
        '..####..',
        '..####..',
        '....##..',
        '....##..',
        '....##..',
        '....##..',
        '........',
    ),
    'S': (
        '........',
        '........',
        '.####...',
        '.####...',
        '...####.',
        '...####.',
        '........',
        '........',
    ),
    'T': (
        '........',
        '........',
        '...##...',
        '...##...',
        '.######.',
        '.######.',
        '........',
        '........',
    ),
    'O': (
        '........',
        '........',
        '..####..',
        '..####..',
        '..####..',
        '..####..',
        '........',
        '........',
    ),
    # Pentominoes (ids 19-30).
    'F5': (
        '........',
        '...####.',
        '...####.',
        '.####...',
        '.####...',
        '...##...',
        '...##...',
        '........',
    ),
    'L5': (
        '..##....',
        '..##....',
        '..##....',
        '..##....',
        '..##....',
        '..##....',
        '..####..',
        '..####..',
    ),
    'N5': (
        '....##..',
        '....##..',
        '....##..',
        '....##..',
        '..####..',
        '..####..',
        '..##....',
        '..##....',
    ),
    'P5': (
        '........',
        '..####..',
        '..####..',
        '..####..',
        '..####..',
        '..##....',
        '..##....',
        '........',
    ),
    'T5': (
        '........',
        '.######.',
        '.######.',
        '...##...',
        '...##...',
        '...##...',
        '...##...',
        '........',
    ),
    'U5': (
        '........',
        '........',
        '.##..##.',
        '.##..##.',
        '.######.',
        '.######.',
        '........',
        '........',
    ),
    'V5': (
        '........',
        '.##.....',
        '.##.....',
        '.##.....',
        '.##.....',
        '.######.',
        '.######.',
        '........',
    ),
    'W5': (
        '........',
        '.##.....',
        '.##.....',
        '.####...',
        '.####...',
        '...####.',
        '...####.',
        '........',
    ),
    'X5': (
        '........',
        '...##...',
        '...##...',
        '.######.',
        '.######.',
        '...##...',
        '...##...',
        '........',
    ),
    'Y5': (
        '....##..',
        '....##..',
        '..####..',
        '..####..',
        '....##..',
        '....##..',
        '....##..',
        '....##..',
    ),
    'Z5': (
        '........',
        '.####...',
        '.####...',
        '...##...',
        '...##...',
        '...####.',
        '...####.',
        '........',
    ),
}

# Glyph id -> (layout key, transforms applied in order to the expanded image,
# fixed rotation angle in degrees or None to keep the caller's angle).
# Transforms are ('rot90', k) -> np.rot90(img, k) or ('flip', axis) -> np.flip(img, axis).
_TETROMINO_VARIANTS = {
    0: ('I', (), None),
    1: ('I', (('rot90', 1),), None),
    2: ('L', (), None),
    3: ('L', (('flip', 0),), None),
    4: ('L', (('rot90', 2),), None),
    5: ('L', (('flip', 1),), None),
    6: ('J', (), None),
    7: ('J', (('flip', 0),), None),
    8: ('J', (('flip', 1),), None),
    9: ('J', (('flip', 1), ('flip', 0)), None),
    10: ('S', (), None),
    11: ('S', (('flip', 0),), None),
    12: ('S', (('rot90', 1),), None),
    13: ('S', (('rot90', 1), ('flip', 1)), None),
    14: ('T', (), None),
    15: ('T', (('flip', 0),), None),
    16: ('T', (('rot90', 1),), None),
    17: ('T', (('rot90', 1), ('flip', 1)), None),
    18: ('O', (), None),
    19: ('F5', (), None),
    20: ('L5', (), None),
    21: ('N5', (), -45.0),
    22: ('P5', (), None),
    23: ('T5', (), None),
    24: ('U5', (), None),
    25: ('V5', (), 45.0),
    26: ('W5', (), 45.0),
    27: ('X5', (), -45.0),
    28: ('X5', (), 45.0),
    29: ('Y5', (), -45.0),
    30: ('Z5', (), None),
}


def get_tetrominos_shape(shape, height, width, angle):
    """Rasterise glyph ``shape`` as a binary float32 image.

    Ids 0-18 are tetromino orientations (0-1 I, 2-5 L, 6-9 J, 10-13 S/Z,
    14-17 T, 18 O); ids 19-30 are pentominoes (F, L, N, P, T, U, V, W,
    X at -45 deg, X at +45 deg, Y, Z).

    Args:
        shape: Glyph id; numeric values equal to an id (e.g. ``3.0``) work.
        height, width: Pixel size of one layout cell; the result has shape
            ``(8 * height, 8 * width)`` before any 90-degree rotation.
        angle: Rotation in degrees, returned unchanged unless the glyph has a
            fixed angle (ids 21, 25-29).

    Returns:
        ``(glyph, shape, height, width, angle)``.

    Raises:
        ValueError: if ``shape`` is not a known glyph id.
    """
    if shape not in _TETROMINO_VARIANTS:
        raise ValueError("invalid shape: {}".format(shape))
    layout_key, steps, angle_override = _TETROMINO_VARIANTS[shape]

    cells = np.array(
        [[ch == '#' for ch in row] for row in _TETROMINO_LAYOUTS[layout_key]],
        dtype=np.float32,
    )
    # Expand every cell to a (height x width) pixel block.
    tetromino = np.repeat(np.repeat(cells, height, axis=0), width, axis=1)
    for op, arg in steps:
        tetromino = np.rot90(tetromino, arg) if op == 'rot90' else np.flip(tetromino, arg)

    if angle_override is not None:
        angle = angle_override
    return tetromino, shape, height, width, angle

# Pentomino layouts on an 8x8 grid of cells ('#' = filled, '.' = empty). Each
# cell is expanded to a (height x width) block of pixels, so one pentomino
# square occupies 2x2 cells. The I pentomino is intentionally not included.
_PENTOMINO_LAYOUTS = {
    'F': (
        '........',
        '...####.',
        '...####.',
        '.####...',
        '.####...',
        '...##...',
        '...##...',
        '........',
    ),
    'L': (
        '..##....',
        '..##....',
        '..##....',
        '..##....',
        '..##....',
        '..##....',
        '..####..',
        '..####..',
    ),
    'N': (
        '....##..',
        '....##..',
        '....##..',
        '....##..',
        '..####..',
        '..####..',
        '..##....',
        '..##....',
    ),
    'P': (
        '........',
        '..####..',
        '..####..',
        '..####..',
        '..####..',
        '..##....',
        '..##....',
        '........',
    ),
    'T': (
        '........',
        '.######.',
        '.######.',
        '...##...',
        '...##...',
        '...##...',
        '...##...',
        '........',
    ),
    'U': (
        '........',
        '........',
        '.##..##.',
        '.##..##.',
        '.######.',
        '.######.',
        '........',
        '........',
    ),
    'V': (
        '........',
        '.##.....',
        '.##.....',
        '.##.....',
        '.##.....',
        '.######.',
        '.######.',
        '........',
    ),
    'W': (
        '........',
        '.##.....',
        '.##.....',
        '.####...',
        '.####...',
        '...####.',
        '...####.',
        '........',
    ),
    'X': (
        '........',
        '...##...',
        '...##...',
        '.######.',
        '.######.',
        '...##...',
        '...##...',
        '........',
    ),
    'Y': (
        '....##..',
        '....##..',
        '..####..',
        '..####..',
        '....##..',
        '....##..',
        '....##..',
        '....##..',
    ),
    'Z': (
        '........',
        '.####...',
        '.####...',
        '...##...',
        '...##...',
        '...####.',
        '...####.',
        '........',
    ),
}

# Pentomino id -> (layout key, fixed rotation angle in degrees or None to keep
# the caller's angle).
_PENTOMINO_IDS = {
    0: ('F', None),
    1: ('L', None),
    2: ('N', -45.0),
    3: ('P', None),
    4: ('T', None),
    5: ('U', None),
    6: ('V', 45.0),
    7: ('W', 45.0),
    8: ('X', -45.0),
    9: ('Y', -45.0),
    10: ('Z', None),
}


def get_pentominos_shape(shape, height, width, angle):
    """Rasterise pentomino ``shape`` (ids 0-10: F, L, N, P, T, U, V, W, X, Y, Z).

    Args:
        shape: Pentomino id; numeric values equal to an id (e.g. ``3.0``) work.
        height, width: Pixel size of one layout cell; the result has shape
            ``(8 * height, 8 * width)``.
        angle: Rotation in degrees, returned unchanged unless the glyph has a
            fixed angle (ids 2, 6-9).

    Returns:
        ``(glyph, shape, height, width, angle)`` where ``glyph`` is a binary
        float32 array.

    Raises:
        ValueError: if ``shape`` is not a known pentomino id.
    """
    if shape not in _PENTOMINO_IDS:
        raise ValueError("invalid shape: {}".format(shape))
    layout_key, angle_override = _PENTOMINO_IDS[shape]

    cells = np.array(
        [[ch == '#' for ch in row] for row in _PENTOMINO_LAYOUTS[layout_key]],
        dtype=np.float32,
    )
    # Expand every cell to a (height x width) pixel block.
    tetromino = np.repeat(np.repeat(cells, height, axis=0), width, axis=1)

    if angle_override is not None:
        angle = angle_override
    return tetromino, shape, height, width, angle