import torch
import numpy as np
import torch.utils.data as Data

def index_assignment(index, row, col, pad_length):
    '''
    Map each flat pixel index to its patch-centre position in the padded image.

    Args:
        index: iterable of flat, row-major pixel indices (flat = r * col + c)
        row: image height; accepted for symmetry but not used
        col: image width, used to split a flat index into (r, c)
        pad_length: padding added on every side of the image
    Return:
        new_assign: a dict mapping the position k of each entry in ``index`` to
            [r + pad_length, c + pad_length], the row/column of that pixel in
            the padded image
    '''
    return {
        position: [flat // col + pad_length, flat % col + pad_length]
        for position, flat in enumerate(index)
    }

def select_patch(matrix, pos_row, pos_col, ex_len):
    '''
    Return the (2*ex_len+1) x (2*ex_len+1) window of ``matrix`` centred at (pos_row, pos_col).

    Args:
        matrix: NumPy array with at least two dimensions; the window is taken
            over the first two axes and any trailing axes (e.g. channels) are
            kept whole
        pos_row: row of the window centre
        pos_col: column of the window centre
        ex_len: window radius
    Return:
        a new array (not a view) equal to
        matrix[pos_row-ex_len : pos_row+ex_len+1, pos_col-ex_len : pos_col+ex_len+1]
    Raises:
        IndexError: if the window does not lie entirely inside ``matrix``.
    '''
    row_start, row_stop = pos_row - ex_len, pos_row + ex_len + 1
    col_start, col_stop = pos_col - ex_len, pos_col + ex_len + 1
    # Slicing would misread a negative start and silently truncate at the far
    # edge (the old range() indexing silently wrapped negatives), so check here.
    if row_start < 0 or col_start < 0 or row_stop > matrix.shape[0] or col_stop > matrix.shape[1]:
        raise IndexError(
            'patch [%d:%d, %d:%d] lies outside matrix of shape %s'
            % (row_start, row_stop, col_start, col_stop, tuple(matrix.shape))
        )
    # Slice both axes at once; indexing rows first copied a full-width strip
    # for every patch. copy() keeps the result independent of ``matrix``.
    return matrix[row_start:row_stop, col_start:col_stop].copy()


def select_small_cubic(data_size, data_indices, whole_data, patch_length, padded_data, dimension):
    '''
    Cut one (2P+1) x (2P+1) patch per requested pixel out of an already padded image.

    Args:
        data_size: N, number of output rows to allocate; rows beyond
            len(data_indices) are left as zeros
        data_indices: flat, row-major pixel indices into the H x W image
        whole_data: original image, shape (H, W, C); only H and W are read
        patch_length: the patch radius P
        padded_data: padded image, shape (H + 2P, W + 2P, C)
        dimension: C, channel dimension
    Return:
        small_cubic_data: float64 patch data, shape (N, 2P+1, 2P+1, C)
    '''
    side = 2 * patch_length + 1
    small_cubic_data = np.zeros((data_size, side, side, dimension))
    centers = index_assignment(data_indices, whole_data.shape[0], whole_data.shape[1], patch_length)
    # keys of ``centers`` are 0..len-1 in insertion order, one per requested pixel
    for k, (center_row, center_col) in centers.items():
        small_cubic_data[k] = select_patch(padded_data, center_row, center_col, patch_length)
    return small_cubic_data


def select_small_cubic_new(data_indices, whole_data, patch_length):
    '''
    Extract a square patch centred on each requested pixel, zero-padding the borders.

    Args:
        data_indices: sequence of N flat, row-major pixel indices into the
            H x W grid (index = row * W + col)
        whole_data: original image, shape (H, W, C)
        patch_length: the patch radius P
    Return:
        small_cubic_data: float64 patch data, shape (N, 2P+1, 2P+1, C)
    Raises:
        IndexError: if any index lies outside [0, H * W).
    '''
    H, W, C = whole_data.shape
    side = 2 * patch_length + 1
    small_cubic_data = np.zeros((len(data_indices), side, side, C))
    # np.pad is the public API; np.lib.pad is no longer exposed in NumPy 2.x.
    # padded_data: shape (H + 2P, W + 2P, C)
    padded_data = np.pad(
        whole_data,
        ((patch_length, patch_length), (patch_length, patch_length), (0, 0)),
        mode='constant',
        constant_values=0,
    )
    for i, flat_index in enumerate(data_indices):
        if not 0 <= flat_index < H * W:
            raise IndexError('pixel index %s outside image of %d x %d pixels' % (flat_index, H, W))
        row = flat_index // W
        col = flat_index % W
        # The centre (row, col) moves to (row + P, col + P) after padding, so
        # its window spans padded rows [row, row + 2P] and cols [col, col + 2P].
        small_cubic_data[i] = padded_data[row:row + side, col:col + side]
    return small_cubic_data

def get_dataset_loader(data_indices, whole_data, patch_length, labels, batch_size=32, do_shuffle=False):
    '''
    Wrap the patches centred on ``data_indices`` and their labels in a DataLoader.

    Args:
        data_indices: flat pixel indices; also used to index ``labels``
        whole_data: original image, shape (H, W, C)
        patch_length: the patch radius P
        labels: array indexed with the same flat pixel indices (values used as-is)
        batch_size: mini-batch size
        do_shuffle: whether the loader reshuffles the samples
    Return:
        (dataset, loader): a TensorDataset of float32 patches shaped
        (N, 1, 2P+1, 2P+1, C) paired with float32 labels[data_indices], and a
        single-process DataLoader over it.
    '''
    targets = labels[data_indices]
    # patches: (N, 2P+1, 2P+1, C); unsqueeze(1) below adds a singleton axis
    patches = select_small_cubic_new(data_indices, whole_data, patch_length)

    dataset = Data.TensorDataset(
        torch.from_numpy(patches).type(torch.FloatTensor).unsqueeze(1),
        torch.from_numpy(targets).type(torch.FloatTensor),
    )
    loader = Data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=do_shuffle, num_workers=0)
    return dataset, loader

def _split_tail(arr, n):
    '''
    Split ``arr`` into ``(head, tail)`` where ``tail`` holds the last ``n`` items.

    Same as ``(arr[:-n], arr[-n:])`` except that ``n == 0`` gives an empty tail
    (``arr[-0:]`` would be the whole array). If ``n`` exceeds ``len(arr)``, the
    head is empty and the tail is all of ``arr``.
    '''
    cut = max(len(arr) - n, 0)
    return arr[:cut], arr[cut:]


def _as_float_dataset(x, y):
    '''
    Wrap NumPy patches ``x`` and labels ``y`` as a float32 TensorDataset.

    A singleton axis is inserted at position 1 of ``x``, so patches of shape
    (N, 2P+1, 2P+1, C) become (N, 1, 2P+1, 2P+1, C).
    '''
    x_tensor = torch.from_numpy(x).type(torch.FloatTensor).unsqueeze(1)
    y_tensor = torch.from_numpy(y).type(torch.FloatTensor)
    return Data.TensorDataset(x_tensor, y_tensor)


def _as_loader(dataset, batch_size, shuffle):
    '''
    Build a single-process DataLoader over ``dataset``.

    Shuffling is only requested for a non-empty dataset, because the random
    sampler raises ValueError when it has zero samples.
    '''
    return Data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle and len(dataset) > 0,
        num_workers=0,
    )


def generate_iter(TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE, total_indices, VAL_SIZE,
                  whole_data, PATCH_LENGTH, padded_data, INPUT_DIMENSION, batch_size, gt):
    '''
    Build train / validation / test / all-pixel DataLoaders of image patches.

    The validation set is the last VAL_SIZE samples of the test split, and the
    test loader gets the remaining test samples. Labels are ``gt[indices] - 1``
    (presumably converting 1-based class ids to 0-based; confirm against callers).

    Args:
        TRAIN_SIZE, TEST_SIZE, TOTAL_SIZE: number of patches to allocate for each
            split (passed to select_small_cubic as ``data_size``)
        train_indices, test_indices, total_indices: flat pixel indices per split
        VAL_SIZE: how many test samples, taken from the end, go to validation
        whole_data: original image, shape (H, W, C)
        PATCH_LENGTH: the patch radius P
        padded_data: padded image, shape (H + 2P, W + 2P, C)
        INPUT_DIMENSION: C, channel dimension
        batch_size: mini-batch size for every loader
        gt: label array indexed with the same flat pixel indices
    Return:
        (train_iter, valid_iter, test_iter, all_iter). The train and validation
        loaders shuffle, and the test and all-pixel loaders keep order. Patches
        are float32 with shape (B, 1, 2P+1, 2P+1, C) and labels are float32.
    '''
    gt_all = gt[total_indices] - 1
    y_train = gt[train_indices] - 1
    y_test_all = gt[test_indices] - 1

    # Each patch array below has shape (num_samples, 2P+1, 2P+1, C).
    all_data = select_small_cubic(TOTAL_SIZE, total_indices, whole_data,
                                  PATCH_LENGTH, padded_data, INPUT_DIMENSION)
    x_train = select_small_cubic(TRAIN_SIZE, train_indices, whole_data,
                                 PATCH_LENGTH, padded_data, INPUT_DIMENSION)
    print(x_train.shape)
    x_test_all = select_small_cubic(TEST_SIZE, test_indices, whole_data,
                                    PATCH_LENGTH, padded_data, INPUT_DIMENSION)

    # A plain x[-VAL_SIZE:] / x[:-VAL_SIZE] split sends every sample to
    # validation and none to test when VAL_SIZE == 0; _split_tail avoids that.
    x_test, x_val = _split_tail(x_test_all, VAL_SIZE)
    y_test, y_val = _split_tail(y_test_all, VAL_SIZE)

    train_iter = _as_loader(_as_float_dataset(x_train, y_train), batch_size, shuffle=True)
    valid_iter = _as_loader(_as_float_dataset(x_val, y_val), batch_size, shuffle=True)
    test_iter = _as_loader(_as_float_dataset(x_test, y_test), batch_size, shuffle=False)
    all_iter = _as_loader(_as_float_dataset(all_data, gt_all), batch_size, shuffle=False)
    return train_iter, valid_iter, test_iter, all_iter


def generate_iter_new(TRAIN_SIZE, train_indices,
                      whole_data, PATCH_LENGTH, padded_data, INPUT_DIMENSION, batch_size, gt):
    '''
    Build the training TensorDataset of image patches and a shuffling DataLoader over it.

    Labels are taken from ``gt`` unchanged. Unlike generate_iter, no 1 is
    subtracted from them.

    Args:
        TRAIN_SIZE: number of patches to allocate (passed to select_small_cubic
            as ``data_size``)
        train_indices: flat pixel indices of the training samples
        whole_data: original image, shape (H, W, C)
        PATCH_LENGTH: the patch radius P
        padded_data: padded image, shape (H + 2P, W + 2P, C)
        INPUT_DIMENSION: C, channel dimension
        batch_size: mini-batch size
        gt: label array indexed with the same flat pixel indices
    Return:
        (torch_dataset_train, train_iter): a TensorDataset of float32 patches
        shaped (num_train, 1, 2P+1, 2P+1, C) with float32 labels, and a
        shuffling single-process DataLoader over it.
    '''
    y_train = gt[train_indices]
    # x_train: shape (num_train, 2P+1, 2P+1, C)
    x_train = select_small_cubic(TRAIN_SIZE, train_indices, whole_data,
                                 PATCH_LENGTH, padded_data, INPUT_DIMENSION)

    # x1_tensor_train: shape (num_train, 1, 2P+1, 2P+1, C)
    x1_tensor_train = torch.from_numpy(x_train).type(torch.FloatTensor).unsqueeze(1)
    y1_tensor_train = torch.from_numpy(y_train).type(torch.FloatTensor)
    torch_dataset_train = Data.TensorDataset(x1_tensor_train, y1_tensor_train)

    train_iter = Data.DataLoader(
        dataset=torch_dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )
    return torch_dataset_train, train_iter

