import numpy as np
import torch
from torch.utils.data import Dataset


class DatasetSim2D(Dataset):
    """Synthetic 2-D gripper / object interaction dataset.

    Every item is a freshly sampled transition rendered as two ``(3, N, N)``
    images (before / after).  Channel 0 holds the gripper sprite, channel 1
    object 1 (a diamond of L1-radius ``r1``) and channel 2 object 2 (a
    ``(2 * r2 + 1)``-pixel square, drawn only when ``n_obj > 1``).  A position
    is an integer pair indexing the image's last two axes.

    Samples come from NumPy's global RNG on every ``__getitem__`` call; the
    index is ignored and ``__len__`` only fixes a nominal epoch size.
    NOTE(review): confirm per-worker NumPy seeding if this is used with a
    multi-worker DataLoader.

    Args:
        frq: probability that a sample is an interaction, i.e. one object is
            placed next to the gripper and displaced in the next state.
        N: image side length in pixels.
        n_obj: number of objects; object 2 is only drawn/pushed when > 1.
        obst: 1 to paint vertical occluding bands (value 0.7) over every image.
        radius_dyn: maximum per-axis displacement of a moved object;
            0 means ``int(N / 2)``.
        move_bg: 1 to use a checkerboard background (20-px cells) that
            scrolls with the gripper; otherwise the background is zero.
        segment: 0 places the pushed object around the gripper's start
            position; non-zero places it near the gripper's straight-line
            path from its start to its end position.
        batch_size: only used by ``__len__`` (``10 * batch_size``).
    """

    def __init__(self, frq=0.5, N=100, n_obj=1, obst=1, radius_dyn=20, move_bg=0, segment=0, batch_size=32):

        self.frq = frq
        self.N = N
        self.n_obj = n_obj
        self.obst = obst == 1
        self.radius_dyn = radius_dyn if radius_dyn != 0 else int(N / 2)
        self.batch_size = batch_size
        self.segment = segment

        # The background is 3N x 3N so that the N x N window taken in get_img
        # at offset 2 * grip_pos fits for any in-image gripper position.
        self.move_bg = move_bg
        if self.move_bg == 1:
            r_bg = 20  # checkerboard cell size in pixels
            on = (np.arange(3 * N) // r_bg) % 2 == 0
            self.bg = np.outer(on, on).astype(float)
        else:
            self.bg = np.zeros((3 * N, 3 * N))

        # Half-sizes of the sprites: each template is (2r + 1) x (2r + 1).
        self.r_g = 4
        self.r1 = 8
        self.r2 = 4

        # 9x9 hand-drawn gripper sprite.
        self.img_grip = np.zeros((9, 9))
        self.img_grip[0, 3:6] = 1
        self.img_grip[1, 3:6] = 1
        self.img_grip[2:4, :] = 1
        self.img_grip[4, 1:8] = 1
        self.img_grip[5, 2:7] = 1
        self.img_grip[6, 1:8] = 1
        self.img_grip[7, :] = 1
        self.img_grip[7, 4] = 0
        self.img_grip[8, :3] = 0
        self.img_grip[8, 6:] = 0

        # Object 1: filled diamond |i| + |j| <= r1.
        offs = np.abs(np.arange(-self.r1, self.r1 + 1))
        self.img_obj1 = (offs[:, None] + offs[None, :] <= self.r1).astype(float)

        # Object 2: filled square.
        self.img_obj2 = np.ones((self.r2 * 2 + 1, self.r2 * 2 + 1))

    def __len__(self):
        """Nominal epoch size: ten batches."""
        return self.batch_size * 10

    def __getitem__(self, idx):
        """Return one random transition as float tensors.

        ``idx`` is ignored.  Returns ``(img, next_img, a, real_pos,
        next_real_pos)``: the two ``(3, N, N)`` images, the gripper action of
        shape ``(2,)`` and the ``(3, 2)`` gripper / object 1 / object 2
        positions at t and t + 1, with all coordinates divided by ``N``.
        """
        state_t, state_t1, a = self.gen_interaction()

        img = torch.from_numpy(self.get_img(state_t[0], state_t[1], state_t[2])).float()
        next_img = torch.from_numpy(self.get_img(state_t1[0], state_t1[1], state_t1[2])).float()
        a = torch.from_numpy(a).float() / self.N
        real_pos = torch.from_numpy(state_t).float() / self.N
        next_real_pos = torch.from_numpy(state_t1).float() / self.N

        return img, next_img, a, real_pos, next_real_pos

    def get_ladder(self):
        """Render the gripper sliding along the main diagonal.

        Returns a float tensor ``(N - 2 * r_g, 3, N, N)`` whose image ``k``
        has the gripper at ``(r_g + k, r_g + k)`` and both objects at fresh
        random positions.
        """
        imgs = np.zeros((self.N - 2 * self.r_g, 3, self.N, self.N))
        for i in range(self.r_g, self.N - self.r_g):
            pos_grip = np.ones(2) * i
            pos_obj1 = self._rand_pos(self.r1)
            pos_obj2 = self._rand_pos(self.r2)
            imgs[i - self.r_g] = self.get_img(pos_grip, pos_obj1, pos_obj2)
        return torch.from_numpy(imgs).float()

    def get_grid(self, type):
        """Render a grid sweep of one element while the others stay fixed.

        Args:
            type: 0 sweeps the gripper over a 5-px grid on ``[r_g, N - r_g)``
                with object 1 at (10, 10); any other value sweeps object 1
                over ``[r1, N - r1)`` with the gripper at (10, 10).  (The
                name shadows the builtin but is kept for compatibility.)

        Object 2 is parked at (75, 75), clamped to ``N - r2 - 1`` so that it
        still fits in images smaller than 80 px.

        Returns:
            Float tensor ``(len(grid) ** 2, 3, N, N)``; image
            ``i * len(grid) + j`` has the swept element at
            ``(grid[i], grid[j])``.
        """
        if type == 0:
            pos_list = range(self.r_g, self.N - self.r_g, 5)
        else:
            pos_list = range(self.r1, self.N - self.r1, 5)
        fixed = np.array([10., 10.])
        park = min(75, self.N - self.r2 - 1)
        pos_obj2 = np.array([park, park])

        n = len(pos_list)
        imgs = np.zeros((n ** 2, 3, self.N, self.N))
        for i, x in enumerate(pos_list):
            for j, y in enumerate(pos_list):
                swept = np.array([float(x), float(y)])
                if type == 0:
                    imgs[i * n + j] = self.get_img(swept, fixed, pos_obj2)
                else:
                    imgs[i * n + j] = self.get_img(fixed, swept, pos_obj2)
        return torch.from_numpy(imgs).float()

    def gen_interaction(self):
        """Sample one random transition.

        With probability ``frq`` the sample is an interaction: one object
        (object 1 if ``n_obj == 1``, otherwise either object with equal
        probability) is placed next to the gripper and jittered by up to
        ``radius_dyn`` per axis until it no longer touches the gripper's start
        position.  When ``n_obj > 1``, the other object is moved only if it
        touches the gripper.  Otherwise both objects stay still.

        Returns:
            state_t: int array ``(3, 2)``, rows gripper / object 1 / object 2.
            state_t1: the same positions at t + 1.
            a: the gripper displacement ``state_t1[0] - state_t[0]``.
        """
        pos_grip = self._rand_pos(self.r_g)
        pos_grip_1 = self._rand_pos(self.r_g)

        if np.random.random_sample() < self.frq:
            if self.n_obj == 1:
                pos_obj1 = self._place_near_grip(pos_grip, pos_grip_1, self.r1)
                pos_obj1_1 = self._displace(pos_grip, pos_obj1, self.r1, force=True)
                pos_obj2 = self._rand_pos(self.r2)
                pos_obj2_1 = pos_obj2.copy()
            elif np.random.random_sample() < 0.5:
                # Object 1 is pushed; object 2 only moves if it touches the gripper.
                pos_obj1 = self._place_near_grip(pos_grip, pos_grip_1, self.r1)
                pos_obj2 = self._rand_pos(self.r2)
                pos_obj1_1 = self._displace(pos_grip, pos_obj1, self.r1, force=True)
                pos_obj2_1 = self._displace(pos_grip, pos_obj2, self.r2, force=False)
            else:
                # Object 2 is pushed; object 1 only moves if it touches the gripper.
                pos_obj1 = self._rand_pos(self.r1)
                pos_obj2 = self._place_near_grip(pos_grip, pos_grip_1, self.r2)
                pos_obj1_1 = self._displace(pos_grip, pos_obj1, self.r1, force=False)
                pos_obj2_1 = self._displace(pos_grip, pos_obj2, self.r2, force=True)
        else:
            pos_obj1 = self._rand_pos(self.r1)
            pos_obj2 = self._rand_pos(self.r2)
            # TODO: if touching should always move an object, re-sample the
            # positions here while check_touch(...) holds.
            pos_obj1_1 = pos_obj1.copy()
            pos_obj2_1 = pos_obj2.copy()

        a = pos_grip_1 - pos_grip
        state_t = np.stack((pos_grip, pos_obj1, pos_obj2))
        state_t1 = np.stack((pos_grip_1, pos_obj1_1, pos_obj2_1))
        return state_t, state_t1, a

    def _clip_pos(self, pos, r):
        """Clip pos so a template of half-size r stays inside the image."""
        return np.clip(pos, a_min=r, a_max=self.N - r - 1)

    def _rand_pos(self, r):
        """Uniform random integer position in [r, N - r - 2] per axis."""
        return np.random.randint(r, high=self.N - r - 1, size=2)

    def _jitter(self, pos, r):
        """pos plus a uniform offset in [-radius_dyn, radius_dyn]^2, clipped."""
        offset = np.random.randint(-self.radius_dyn, high=self.radius_dyn + 1, size=2)
        return self._clip_pos(pos + offset, r)

    def _displace(self, pos_grip, pos, r, force):
        """Next-step position of an object of half-size r starting at pos.

        If ``force`` the object is always jittered once; in either case it is
        re-jittered (from ``pos``) until it no longer touches the gripper at
        ``pos_grip``.  NOTE: this loops forever if no clipped position within
        ``radius_dyn`` of ``pos`` is clear of the gripper (e.g. a very small
        ``radius_dyn``).
        """
        new_pos = self._jitter(pos, r) if force else pos.copy()
        while self.check_touch(pos_grip, new_pos, r):
            new_pos = self._jitter(pos, r)
        return new_pos

    def _place_near_grip(self, pos_grip, pos_grip_1, r):
        """Initial position for the object the gripper is going to push."""
        if self.segment == 0:
            offset = np.random.randint(-r, high=r + 1, size=2)
            return self._clip_pos(pos_grip + offset, r)
        return self._sample_near_segment(pos_grip, pos_grip_1, r)

    def _sample_near_segment(self, pos_grip, pos_grip_1, r):
        """Sample a position within ``r_g + r`` of the gripper's path.

        A point is drawn uniformly on the segment ``pos_grip -> pos_grip_1``,
        then offset along the segment's unit normal by a uniform amount in
        ``[-(r_g + r), r_g + r]``.  If the gripper does not move, the normal is
        undefined and the second image axis is used instead.
        """
        delta = pos_grip_1 - pos_grip
        segment_point = (np.random.rand() * delta + pos_grip).astype(int)
        length = np.linalg.norm(delta)
        if length > 0:
            normal = np.array([-delta[1], delta[0]]) / length
        else:
            normal = np.array([0., 1.])
        reach = normal * (self.r_g + r)
        v1 = segment_point + reach
        v2 = segment_point - reach
        return self._clip_pos((np.random.rand() * (v1 - v2) + v2).astype(int), r)

    def check_touch(self, pos_grip, pos_obj, r):
        """Return True iff the object (half-size r) touches the gripper.

        "Touching" means the L1 distance between the two centres is at most
        ``r_g + r``.
        """
        dist = np.abs(pos_obj[0] - pos_grip[0]) + np.abs(pos_obj[1] - pos_grip[1])
        return bool(dist <= self.r_g + r)

    def get_img(self, grip_pos, obj1_pos, obj2_pos):
        """Render a ``(3, N, N)`` float64 image of the scene.

        Positions may be int or integral float arrays.  Layers are painted in
        order (background, gripper, object 1, object 2 if ``n_obj > 1``,
        obstacle bands); each layer occludes the ones below it.
        """
        # Background window moves at twice the gripper's pixel position.
        # int() is required: float positions (get_ladder / get_grid) are not
        # valid slice indices.
        gx = int(grip_pos[0]) * 2
        gy = int(grip_pos[1]) * 2
        img = np.zeros((3, self.N, self.N)) + self.bg[gx:gx + self.N, gy:gy + self.N]

        img = self.plot_element(img, grip_pos, self.r_g, self.img_grip, 0)
        img = self.plot_element(img, obj1_pos, self.r1, self.img_obj1, 1)
        if self.n_obj > 1:
            img = self.plot_element(img, obj2_pos, self.r2, self.img_obj2, 2)
        if self.obst:
            # Vertical bands 5 px wide every 10 px, painted over everything.
            n_bands = int(self.N / 10.)
            for i in range(n_bands):
                img[:, :, (i * 10):(i * 10 + 5)] = 0.7
        return img

    def plot_element(self, img, pos, r, template, channel):
        """Paint a ``(2r + 1)``-square template centred at pos into img.

        Covered pixels are first scaled in every channel by ``1 - template``
        (occluding what is underneath), then ``template`` is added to
        ``channel``.  ``img`` is modified in place and returned.

        Raises:
            IndexError: if the template does not fit inside the image.
        """
        x, y = int(pos[0]), int(pos[1])
        if x - r < 0 or y - r < 0 or x + r >= img.shape[1] or y + r >= img.shape[2]:
            raise IndexError(
                f"element at {(x, y)} with radius {r} does not fit in a "
                f"{img.shape[1]}x{img.shape[2]} image"
            )
        rows = slice(x - r, x + r + 1)
        cols = slice(y - r, y + r + 1)
        img[:, rows, cols] *= 1 - template
        img[channel, rows, cols] += template
        return img


if __name__ == '__main__':
    # Smoke test: build a small two-object dataset and draw a single sample.
    dataset = DatasetSim2D(frq=0.5, N=64, n_obj=2, obst=1, radius_dyn=0, batch_size=32)
    batch = dataset[0]
