from tokenize import Double
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import ipdb



class DoublePendulumTransfer(torch.utils.data.Dataset):
    """Meta-learning dataset of double-pendulum tasks loaded from a pickle file.

    Each entry of the pickled list is a 5-tuple
    ``(support_x, support_y, query_x, query_y, task_params)``. Judging by the
    indexing below, every ``*_x`` record is ``(image, action, state_action)``
    and every ``*_y`` record is ``(image, state)``, with images stored as
    ``(H, W, C)`` arrays. ``task_params`` holds two values; the second one is
    divided by 10 (the original author identifies it as gravity).

    Args:
        path: Pickle file holding the list of task tuples.
        traj_length: Number of consecutive frames stacked into one input
            (image mode only).
        use_img: If truthy, return image trajectories; otherwise return the
            state-action / state vectors.
        noise_std: Std of Gaussian noise added to the support inputs and
            targets (state mode only).
        support_size: If not None, randomly subsample this many support pairs
            (state mode only).
        dataset_size: If not None, keep only the first ``dataset_size`` tasks.
    """

    # Divisors applied to task_params so both entries have comparable
    # magnitude (the second entry, gravity, is divided by 10).
    _TASK_PARAM_SCALE = np.array([1.0, 10.0])

    def __init__(self, path, traj_length, use_img, noise_std=0.0, support_size=None, dataset_size=None):
        super().__init__()

        self.traj_length = traj_length
        self.use_img = use_img
        self.noise_std = noise_std
        self.support_size = support_size

        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            self.data = pickle.load(f)[:dataset_size]

    def __len__(self):
        """Return the number of tasks kept from the pickle file."""
        return len(self.data)

    @staticmethod
    def _column(records, field):
        """Stack field ``field`` of every record in ``records`` into one array.

        Indexing each record directly (instead of going through a
        ``dtype=object`` array) avoids NumPy building a deeper object array
        when all fields of a record happen to share a shape.
        """
        return np.stack([record[field] for record in records])

    @staticmethod
    def _to_channels_first(images):
        """Reorder a stack of (H, W, C) images to (N, C, H, W) and divide by 255."""
        return images.transpose(0, 3, 1, 2) / 255.

    def create_traj(self, support_x, support_y, traj_length):
        """Cut a sequence into overlapping windows of ``traj_length`` frames.

        Window ``i`` is ``support_x[i : i + traj_length]`` and is paired with
        ``support_y[i + traj_length - 1]``, the target aligned with the
        window's last frame. Every window that fits is returned, including
        the one ending on the final frame.

        Args:
            support_x: Sequence of per-frame inputs, shape (T, ...).
            support_y: Sequence of per-frame targets, shape (T, ...).
            traj_length: Window length; must be >= 1.

        Returns:
            ``(windows, targets)`` with shapes ``(W, traj_length, ...)`` and
            ``(W, ...)``, where ``W = max(T - traj_length + 1, 0)``.

        Raises:
            ValueError: If ``traj_length < 1``.
        """
        if traj_length < 1:
            raise ValueError(f"traj_length must be >= 1, got {traj_length}")
        support_x = np.asarray(support_x)
        support_y = np.asarray(support_y)

        num_windows = max(len(support_x) - traj_length + 1, 0)
        if num_windows == 0:
            # Too few frames for a single window: return correctly shaped
            # empty arrays rather than a shapeless np.array([]).
            return (np.empty((0, traj_length) + support_x.shape[1:], dtype=support_x.dtype),
                    np.empty((0,) + support_y.shape[1:], dtype=support_y.dtype))

        traj_support_x = np.stack([support_x[i:i + traj_length] for i in range(num_windows)])
        first_target = traj_length - 1
        traj_support_y = np.array(support_y[first_target:first_target + num_windows])
        return traj_support_x, traj_support_y

    def create_imgs_actions(self, state_img, state_actions):
        """Append each frame's action vector to its image as constant channels.

        Args:
            state_img: Array of shape (N, C, H, W).
            state_actions: Array with N rows; each row is flattened to A values.

        Returns:
            Array of shape (N, C + A, H, W) whose channel ``C + a`` of frame
            ``n`` is filled with ``state_actions[n, a]``.
        """
        num_frames, _, height, width = state_img.shape
        actions = np.asarray(state_actions).reshape(num_frames, -1)
        # Broadcast to the image's own spatial size instead of a hard-coded
        # 64x64 canvas, so any resolution and any action width work.
        action_planes = np.broadcast_to(
            actions[:, :, None, None], (num_frames, actions.shape[1], height, width))
        return np.concatenate((state_img, action_planes), axis=1)

    def _image_trajectories(self, x_records, y_records):
        """Build (image+action trajectory, target image) pairs for one split."""
        x_imgs = self._to_channels_first(self._column(x_records, 0))
        x_actions = self._column(x_records, 1)
        x_inputs = self.create_imgs_actions(x_imgs, x_actions)
        y_imgs = self._to_channels_first(self._column(y_records, 0))
        return self.create_traj(x_inputs, y_imgs, self.traj_length)

    def __getitem__(self, idx):
        """Return ``(support_x, support_y, query_x, query_y, task_params)`` as float32 tensors.

        In image mode the inputs are windows of ``traj_length`` frames with
        the action appended as extra channels. In state mode they are the
        ``state_action`` vectors, paired with the ``state`` targets.
        NOTE(review): ``support_size`` and ``noise_std`` are only applied in
        state mode; confirm that image mode is meant to ignore them.
        """
        support_x, support_y, query_x, query_y, task_params = self.data[idx]

        if self.use_img:
            support_x, support_y = self._image_trajectories(support_x, support_y)
            query_x, query_y = self._image_trajectories(query_x, query_y)
        else:
            support_x = self._column(support_x, 2)  # x_state_action
            support_y = self._column(support_y, 1)
            query_x = self._column(query_x, 2)
            query_y = self._column(query_y, 1)

            if self.support_size is not None:
                # Subsample support pairs, using the same indices for x and y
                # so the pairs stay aligned.
                indices = np.arange(len(support_x))
                np.random.shuffle(indices)
                indices = indices[:self.support_size]
                support_x = support_x[indices]
                support_y = support_y[indices]

            # Noise on the inputs implies noise on the outputs; a larger
            # support set helps compensate.
            support_x = support_x + np.random.randn(*support_x.shape) * self.noise_std
            support_y = support_y + np.random.randn(*support_y.shape) * self.noise_std

        support_x, support_y, query_x, query_y = (
            torch.from_numpy(arr).float()
            for arr in (support_x, support_y, query_x, query_y)
        )
        task_params = torch.from_numpy(np.array(task_params) / self._TASK_PARAM_SCALE).float()

        return support_x, support_y, query_x, query_y, task_params




if __name__ == '__main__':
    # Smoke test: load the pendulum pickle and print the shapes of one item.
    path = 'data/single_double_pendulum.pkl'
    traj_length = 4
    use_img = True
    noise_std = 0.0      # no added noise
    support_size = None  # no support subsampling
    dataset_size = None  # keep every task in the file

    # Pass every constructor argument explicitly so the call works whether or
    # not the trailing parameters have defaults.
    dset = DoublePendulumTransfer(path, traj_length, use_img, noise_std, support_size, dataset_size)

    idx = 0
    x_s, y_s, x_q, y_q, params = dset[idx]
    print(f"support x {tuple(x_s.shape)}, support y {tuple(y_s.shape)}, "
          f"query x {tuple(x_q.shape)}, query y {tuple(y_q.shape)}, params {params.tolist()}")
