import numpy as np
import torch
from datasets.maze_gen_new_simple import Maze_Gen
from torch.utils.data import Dataset
from tqdm import tqdm
import ipdb



class DatasetGridSimple(Dataset):
    """Pre-generated maze tasks served as (support, query) tensor tuples.

    All ``T`` tasks are built eagerly in ``__init__`` from
    ``Maze_Gen.get_maze()`` and cached in NumPy arrays:

    * support set -- ``(row, col)`` coordinates of grid cells, labelled
      ``1`` where ``grid == 0`` and ``0`` where ``grid == 1``, randomly
      subsampled to ``2 * grid_size`` cells per task;
    * query set -- 5 entries of ``optimal_states`` (drawn by index with
      ``rng.choice``) concatenated with the goal, paired with the matching
      entries of ``optimal_policy[0]``.
    """

    def __init__(self, T, support_size, seed, grid_size):
        """Generate and cache ``T`` maze tasks.

        Args:
            T: number of tasks to generate; also the dataset length.
            support_size: currently ignored -- the effective support size
                is always ``2 * grid_size`` (see below).
            seed: passed to ``Maze_Gen`` and used to seed the NumPy
                generator that drives the subsampling.
            grid_size: passed to ``Maze_Gen``; also fixes the support size.
        """
        self.generator = Maze_Gen(seed, grid_size)
        self.T = T

        # The caller-supplied support_size is deliberately not used here.
        #support_size = (grid_size + 2)**2
        n_support = 2 * grid_size
        n_query = 5

        rng = np.random.default_rng(seed)

        # Per-task storage, filled in by the loop below.
        self.support_states = np.zeros((T, n_support, 2))
        self.support_labels = np.zeros((T, n_support, 1))
        self.query_states = np.zeros((T, n_query, 4))
        self.query_actions = np.zeros((T, n_query, 1))

        for task in tqdm(range(T)):
            _start, goal, grid, optimal_states, optimal_policy = self.generator.get_maze()

            # Coordinates of the cells holding 0 (label 1) and 1 (label 0).
            zero_rows, zero_cols = np.where(grid == 0)
            zero_cells = np.column_stack((zero_rows, zero_cols))
            one_rows, one_cols = np.where(grid == 1)
            one_cells = np.column_stack((one_rows, one_cols))

            all_cells = np.concatenate((zero_cells, one_cells), axis=0)
            all_labels = np.concatenate(
                (np.ones((len(zero_cells), 1)), np.zeros((len(one_cells), 1))),
                axis=0,
            )

            # Random subset of cells: shuffle every index, keep a prefix.
            order = np.arange(len(all_cells))
            rng.shuffle(order)
            keep = order[:n_support]
            self.support_states[task] = all_cells[keep]
            self.support_labels[task] = all_labels[keep]

            # Query examples are drawn by index from the optimal trajectory.
            policy = optimal_policy[0]
            picks = rng.choice(range(len(policy)), n_query)
            picked_states = np.array([optimal_states[i] for i in picks])
            picked_actions = np.array([policy[i] for i in picks])

            # Each query state is followed by the goal (broadcast per row).
            goal_block = np.ones((n_query, 2)) * goal
            self.query_states[task] = np.concatenate((picked_states, goal_block), -1)
            self.query_actions[task, :, 0] = picked_actions

    def __len__(self):
        """Return the number of cached tasks."""
        return self.T

    def __getitem__(self, idx):
        """Return task ``idx`` as float tensors.

        Returns:
            ``(support_x, support_y, query_x, query_y, extra)`` where the
            label/action tensors have their trailing singleton axis
            squeezed away and ``extra`` is a one-element zero tensor.
        """
        def as_float(array):
            return torch.from_numpy(array).float()

        support_x = as_float(self.support_states[idx])
        support_y = as_float(self.support_labels[idx]).squeeze(-1)
        query_x = as_float(self.query_states[idx])
        query_y = as_float(self.query_actions[idx]).squeeze(-1)
        extra = torch.Tensor([0]).float()

        return support_x, support_y, query_x, query_y, extra
 
 
# class DatasetGrid(Dataset):
 
#     def __init__(self, T, support_size, seed):
 
#         self.generator = Maze_Gen(seed)
 
#         self.T = T

#         rng = np.random.default_rng(seed)
 
#         self.support_states = np.zeros((T, support_size, 4))
#         self.support_actions = np.zeros((T, support_size, 1))
#         self.support_success = np.zeros((T, support_size, 1))
 
#         query_size = 5
#         self.query_states = np.zeros((T, query_size, 4))
#         self.query_actions = np.zeros((T, query_size, 1))
 
#         for t in tqdm(range(T)):
#             start, goal, grid, optimal_states, optimal_policy = self.generator.get_maze()
#             idx_x, idx_y = np.where(grid == 0)
#             idx = np.concatenate((np.expand_dims(idx_x, -1), np.expand_dims(idx_y, -1)), -1)
#             support_idx = rng.choice(range(idx.shape[0]), support_size)
#             support_states = np.concatenate((idx[support_idx], np.ones((support_size, 2))*goal), -1)
 
#             support_actions = rng.choice(range(4), support_size)
 
#             action_movement = np.array([[1, 0], [-1, 0], [0, 1], [0, -1]])
#             support_success = (grid[[x[0]+action_movement[a, 0] for x, a in zip(idx[support_idx], support_actions)],
#                                    [x[1]+action_movement[a, 0] for x, a in zip(idx[support_idx], support_actions)]])*1
 
#             query_idx = rng.choice(range(len(optimal_policy[0])), query_size)
#             query_states = np.array([optimal_states[i] for i in query_idx])
#             query_actions = np.array([optimal_policy[0][i] for i in query_idx])
 
#             self.support_states[t] = support_states
#             self.support_actions[t,:,0] = support_actions
#             self.support_success[t,:,0] = support_success
#             self.query_states[t] = np.concatenate((query_states, np.ones((query_size, 2))*goal), -1)
#             self.query_actions[t,:,0] = query_actions
 
#     def __len__(self):
#         return self.T
 
#     def __getitem__(self, idx):

#         support_states = torch.from_numpy(self.support_states[idx])
#         support_actions = torch.from_numpy(self.support_actions[idx])
#         support_x = torch.cat((support_states, support_actions), -1).float()
#         support_y = torch.from_numpy(self.support_success[idx]).float().squeeze(-1)

#         query_x = torch.from_numpy(self.query_states[idx]).float()
#         query_y = torch.from_numpy(self.query_actions[idx]).float().squeeze(-1)

#         return support_x, support_y, query_x, query_y, torch.Tensor([0]).float()
 
 

 
 
 
if __name__ == '__main__':
    # Smoke test: build a dataset and fetch a single item.
    # The original call targeted DatasetGrid, which only exists as
    # commented-out code in this file and therefore raised NameError;
    # DatasetGridSimple is the class that is actually defined.
    # NOTE(review): grid_size=10 is an arbitrary value picked for this
    # check -- confirm it is a valid size for Maze_Gen.
    # support_size is accepted for interface compatibility but the class
    # overrides it with 2 * grid_size.
    dataset = DatasetGridSimple(T=10000, support_size=100, seed=42, grid_size=10)
    dataset[10]
    print()
 
    # sim = Simulator()
    # sim.reset()
    # a = 0
    # sim.step(a)
    # print()