import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import ipdb


class Symbols():
    """Library of 8x8 binary glyphs used to paint symbol-grid images.

    Args:
        img_size: side length (in cells) of the grid the glyphs will be
            placed on; it bounds how many distinct symbols a task may use
            (see ``get_number_symbols``).
    """

    def __init__(self, img_size):

        self.img_size = img_size
        # Each entry is an (8, 8) int array of 0/1 pixels; the list index is
        # the symbol id used throughout the dataset code.
        self.symbols_list = []
        # 0: irregular blob
        self.symbols_list.append(np.array([[0, 0, 0, 0, 0, 0, 0, 0],  # 0
                                           [0, 0, 0, 0, 1, 0, 0, 0],  # 1
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 2
                                           [0, 1, 1, 1, 1, 1, 0, 0],  # 3
                                           [0, 0, 1, 1, 1, 1, 1, 0],  # 4
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 5
                                           [0, 0, 0, 1, 0, 0, 0, 0],  # 6
                                           [0, 0, 0, 0, 0, 0, 0, 0]]))  # 7

        # 1: vertical bar
        self.symbols_list.append(np.array([[0, 0, 0, 1, 1, 0, 0, 0],  # 0
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 1
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 2
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 3
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 4
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 5
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 6
                                           [0, 0, 0, 1, 1, 0, 0, 0]]))  # 7

        # 2: horizontal bar
        self.symbols_list.append(np.array([[0, 0, 0, 0, 0, 0, 0, 0],  # 0
                                           [0, 0, 0, 0, 0, 0, 0, 0],  # 1
                                           [0, 0, 0, 0, 0, 0, 0, 0],  # 2
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 3
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 4
                                           [0, 0, 0, 0, 0, 0, 0, 0],  # 5
                                           [0, 0, 0, 0, 0, 0, 0, 0],  # 6
                                           [0, 0, 0, 0, 0, 0, 0, 0]]))  # 7

        # 3: plus / cross
        self.symbols_list.append(np.array([[0, 0, 0, 1, 1, 0, 0, 0],  # 0
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 1
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 2
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 3
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 4
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 5
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 6
                                           [0, 0, 0, 1, 1, 0, 0, 0]]))  # 7

        # 4: diamond outline
        self.symbols_list.append(np.array([[0, 0, 0, 1, 1, 0, 0, 0],  # 0
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 1
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 2
                                           [1, 0, 0, 0, 0, 0, 0, 1],  # 3
                                           [1, 0, 0, 0, 0, 0, 0, 1],  # 4
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 5
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 6
                                           [0, 0, 0, 1, 1, 0, 0, 0]]))  # 7

        # 5: triangle outline
        self.symbols_list.append(np.array([[0, 0, 0, 1, 1, 0, 0, 0],  # 0
                                           [0, 0, 0, 1, 1, 0, 0, 0],  # 1
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 2
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 3
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 4
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 5
                                           [1, 0, 0, 0, 0, 0, 0, 1],  # 6
                                           [1, 1, 1, 1, 1, 1, 1, 1]]))  # 7

        # 6: letter H
        self.symbols_list.append(np.array([[1, 1, 0, 0, 0, 0, 1, 1],  # 0
                                           [1, 1, 0, 0, 0, 0, 1, 1],  # 1
                                           [1, 1, 0, 0, 0, 0, 1, 1],  # 2
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 3
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 4
                                           [1, 1, 0, 0, 0, 0, 1, 1],  # 5
                                           [1, 1, 0, 0, 0, 0, 1, 1],  # 6
                                           [1, 1, 0, 0, 0, 0, 1, 1]]))  # 7

        # 7: hash sign
        self.symbols_list.append(np.array([[0, 0, 1, 0, 0, 1, 0, 0],  # 0
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 1
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 2
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 3
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 4
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 5
                                           [0, 0, 1, 0, 0, 1, 0, 0],  # 6
                                           [0, 0, 1, 0, 0, 1, 0, 0]]))  # 7

        # 8: hollow square
        self.symbols_list.append(np.array([[0, 0, 0, 0, 0, 0, 0, 0],  # 0
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 1
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 2
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 3
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 4
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 5
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 6
                                           [0, 0, 0, 0, 0, 0, 0, 0]]))  # 7

        # 9: square split by a horizontal band
        self.symbols_list.append(np.array([[0, 0, 0, 0, 0, 0, 0, 0],  # 0
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 1
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 2
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 3
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 4
                                           [0, 1, 0, 0, 0, 0, 1, 0],  # 5
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 6
                                           [0, 0, 0, 0, 0, 0, 0, 0]]))  # 7

        # 10: square split by a vertical band
        self.symbols_list.append(np.array([[0, 0, 0, 0, 0, 0, 0, 0],  # 0
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 1
                                           [0, 1, 0, 1, 1, 0, 1, 0],  # 2
                                           [0, 1, 0, 1, 1, 0, 1, 0],  # 3
                                           [0, 1, 0, 1, 1, 0, 1, 0],  # 4
                                           [0, 1, 0, 1, 1, 0, 1, 0],  # 5
                                           [0, 1, 1, 1, 1, 1, 1, 0],  # 6
                                           [0, 0, 0, 0, 0, 0, 0, 0]]))  # 7

        # 11: solid block
        self.symbols_list.append(np.array([[1, 1, 1, 1, 1, 1, 1, 1],  # 0
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 1
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 2
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 3
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 4
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 5
                                           [1, 1, 1, 1, 1, 1, 1, 1],  # 6
                                           [1, 1, 1, 1, 1, 1, 1, 1]]))  # 7

    def get_number_symbols(self):
        """Return how many distinct symbols a task on this grid may use.

        The dataset places up to 3 copies of each symbol in a grid, so at most
        ``img_size**2 // 3`` symbols fit. The result is capped at the size of
        the glyph library. This gives ``len(symbols_list)`` (12) for
        ``img_size >= 6``; previously larger grids returned ``None``.

        NOTE(review): ``img_size < 3`` yields 0 or 1 here, which the task
        generator cannot use (it needs at least 2 symbols).
        """
        return min(len(self.symbols_list), self.img_size ** 2 // 3)

    def get_symbol(self, idx):
        """Return the (8, 8) glyph array for symbol id ``idx``."""
        return self.symbols_list[idx]


class Symb_Image_Dataset(Dataset):
    """Few-shot "count the target symbol" dataset of generated grid images.

    Each item is a freshly sampled task:
      - one target symbol is drawn;
      - ``n_support + n_query`` grid images are generated, each holding the
        target symbol 1-3 times among distractor symbols;
      - an image's label is the number of target-symbol occurrences in it.

    Args:
        n_support: number of support examples per task.
        n_query: number of query examples per task.
        n_tasks: length reported by ``__len__``. Items are generated on the
            fly, so the index passed to ``__getitem__`` is ignored.
        img_size: side length of the symbol grid in cells. Each cell is
            rendered as 10x10 pixels.
        device: passed to ``Tensor.to`` for every returned tensor.
    """

    def __init__(self, n_support, n_query, n_tasks, img_size=6, device=None):

        self.symbols_generator = Symbols(img_size)

        self.ns = n_support
        self.nq = n_query
        self.n_tasks = n_tasks

        self.img_size = img_size

        # NOTE(review): moving tensors to a CUDA device inside __getitem__ is
        # discouraged with multi-process DataLoader workers (num_workers > 0).
        self.device = device

    def make_img(self, symbols):
        """Render one flattened grid of symbol ids as an image.

        Args:
            symbols: 1-D sequence of length ``img_size**2`` in row-major cell
                order. Each entry is a symbol id, or -1 for an empty cell.

        Returns:
            Float ndarray of shape ``(1, 1, img_size*10, img_size*10)``. Each
            occupied cell gets its 8x8 glyph inset by a 1-pixel border.
        """
        side = self.img_size * 10
        img = np.zeros((1, 1, side, side))

        for cell_n, symbol in enumerate(symbols):
            if symbol == -1:
                continue
            row, col = divmod(cell_n, self.img_size)
            r0 = row * 10 + 1
            c0 = col * 10 + 1
            # Slice assignment copies the glyph into img; no explicit copy needed.
            img[0, 0, r0:r0 + 8, c0:c0 + 8] = self.symbols_generator.symbols_list[int(symbol)]
        return img

    def generate_task(self):
        """Sample a target symbol and ``ns + nq`` labelled symbol grids.

        Returns:
            xs: float ndarray of shape ``(ns + nq, img_size**2)``. Each entry is
                the symbol id in that cell, or -1 for an empty cell.
            ys: float ndarray of shape ``(ns + nq,)``. Each entry is the number
                of target-symbol occurrences in the corresponding grid.

        Raises:
            ValueError: if a sampled grid holds more symbols than it has cells
                (the symbol budget from ``get_number_symbols`` rules this out).
        """
        n_samples = self.ns + self.nq
        n_cells = self.img_size * self.img_size
        n_symbols = self.symbols_generator.get_number_symbols()

        xs = np.full((n_samples, n_cells), -1.0)
        ys = np.zeros(n_samples)

        task = np.random.randint(0, n_symbols)

        # Distractors are drawn from the whole glyph library, excluding the target.
        total_sym_id = np.array([x for x in np.arange(len(self.symbols_generator.symbols_list)) if x != task])

        for j in range(n_samples):
            num_task = np.random.randint(1, 4)

            # Build n_symbols - 2 random distractor counts in [0, 3], plus one
            # distractor forced to the target's count. Presumably this stops
            # the count alone from identifying the target.
            other_symbols_num = np.random.randint(0, 4, n_symbols - 2)
            other_symbols_num = np.append(other_symbols_num, num_task)
            np.random.shuffle(other_symbols_num)

            # Counts line up with sym_id below; the last slot is the target.
            symbols_sampled = np.append(other_symbols_num, num_task)

            other_symbols = np.random.choice(total_sym_id, n_symbols - 1, replace=False)
            sym_id = np.append(other_symbols, task)

            # Repeat each symbol id by its sampled count.
            symbols = np.repeat(sym_id, symbols_sampled)

            cell_index = np.arange(n_cells)
            np.random.shuffle(cell_index)

            if symbols.size > n_cells:
                raise ValueError(
                    f"{symbols.size} symbols do not fit in a {self.img_size}x{self.img_size} grid"
                )
            xs[j, cell_index[:symbols.size]] = symbols
            ys[j] = np.count_nonzero(symbols == task)

        return xs, ys

    def __len__(self):
        return self.n_tasks

    def __getitem__(self, idx):
        """Generate a new task (``idx`` is ignored) and split it into support and query sets.

        Returns:
            xs: float tensor ``(ns, 1, H, W)`` of support images.
            ys: long tensor ``(ns,)`` of support labels (target counts).
            xq: float tensor ``(nq, 1, H, W)`` of query images.
            yq: long tensor ``(nq,)`` of query labels.
        """
        x_task, y_task = self.generate_task()

        indices = np.arange(self.ns + self.nq)
        np.random.shuffle(indices)

        # Images and labels are both taken in the shuffled order, so they stay aligned.
        imgs = np.concatenate([self.make_img(x_task[i]) for i in indices], 0)
        labels = y_task[indices]

        xs = torch.from_numpy(imgs[:self.ns]).float().to(self.device)
        ys = torch.from_numpy(labels[:self.ns]).long().to(self.device)
        xq = torch.from_numpy(imgs[self.ns:]).float().to(self.device)
        yq = torch.from_numpy(labels[self.ns:]).long().to(self.device)

        return xs, ys, xq, yq


if __name__ == '__main__':
    # Smoke test: build a small dataset and draw one task from it.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Pass device by keyword: the 4th positional parameter is img_size.
    dataset = Symb_Image_Dataset(5, 50, 100, device=device)
    xs, ys, xq, yq = dataset[0]
    print(xs.shape, ys.shape, xq.shape, yq.shape)