import numpy as np
import random
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from torchvision import datasets,transforms

import copy
import csv
import os

import matplotlib.pyplot as plt
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import transforms as trans





##########################################################################################################################
# Data preprocessing and loaders for the CIFAR-10 and CelebA datasets
##########################################################################################################################
def cifar10_metadataset(dir_root='./data', b_size=16, val_ratio=0.1, seed=42):
    """Build CIFAR-10 train / validation / evaluation data loaders.

    The CIFAR-10 training split is downloaded into ``dir_root`` if needed and
    randomly partitioned into a training subset and a validation subset. The
    CIFAR-10 test split (``train=False``) is used as the evaluation set. Every
    image is converted with ``transforms.ToTensor()``; no other transform is
    applied.

    Args:
        dir_root: directory torchvision uses to store / look up CIFAR-10.
        b_size: batch size shared by all three loaders.
        val_ratio: fraction of the training split held out for validation;
            ``int(len(train_split) * val_ratio)`` samples go to validation.
        seed: seed of the generator that drives the train/val partition, so
            the split is reproducible across calls. Defaults to 42, the value
            that was previously hard-coded.

    Returns:
        ``(train_loader, val_loader, eval_loader)``. Only ``train_loader``
        shuffles.

    Raises:
        ValueError: if ``val_ratio`` is outside ``[0, 1]``.
    """
    # Validate before any download so a bad argument fails fast and clearly
    # (otherwise random_split would reject the negative/oversized lengths later).
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    # Pinned host memory only matters for host-to-GPU copies; requesting it
    # on a CPU-only machine merely produces a warning.
    kwargs = {'num_workers': 0, 'pin_memory': torch.cuda.is_available()}
    transform = transforms.ToTensor()

    full_train_dataset = datasets.CIFAR10(
        dir_root, train=True, download=True, transform=transform
    )

    total_size = len(full_train_dataset)
    val_size = int(total_size * val_ratio)
    train_size = total_size - val_size

    # Dedicated, seeded generator: the same partition on every call, without
    # touching the global RNG state.
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        full_train_dataset,
        [train_size, val_size],
        generator=generator,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=b_size,
        shuffle=True,
        **kwargs,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=b_size,
        shuffle=False,
        **kwargs,
    )

    # No download=True here; presumably the download of the training split
    # above also makes the test split available on disk — confirm if the
    # dataset root is ever pre-populated differently.
    eval_loader = DataLoader(
        datasets.CIFAR10(dir_root, train=False, transform=transform),
        batch_size=b_size,
        shuffle=False,
        **kwargs,
    )

    return train_loader, val_loader, eval_loader
    
def get_context_idx(N, dim, order_pixels=False, device='cuda'):
    """Return the flat indices of ``N`` context points in a flattened image.

    Args:
        N: number of context points to select.
        dim: number of positions in the flattened image; indices are drawn
            from ``range(dim)``.
        order_pixels: if True, take the first ``N`` positions ``0 .. N-1`` in
            order. Otherwise sample ``N`` distinct positions uniformly with
            Python's ``random`` module, so ``random.seed`` controls the draw.
        device: device the returned tensor is placed on. Defaults to
            ``'cuda'``, matching the previously hard-coded ``.cuda()`` call.

    Returns:
        1-D ``torch.long`` tensor of length ``N``.

    Raises:
        ValueError: if ``N`` is negative or larger than ``dim``.
    """
    # Without this check, order_pixels=True with N > dim would silently
    # produce out-of-range indices.
    if not 0 <= N <= dim:
        raise ValueError(f"N must be in [0, dim={dim}], got {N}")

    if order_pixels:
        idx = list(range(N))
    else:
        idx = random.sample(range(dim), N)

    # Explicit long dtype: for N == 0, torch.tensor([]) would default to
    # float32, which torch.index_select does not accept as an index.
    return torch.tensor(idx, dtype=torch.long, device=device)


def generate_grid(h, w, device='cuda'):
    """Build a ``(1, h * w, 2)`` tensor of coordinates normalized to [0, 1].

    With ``rows = linspace(0, 1, h)`` and ``cols = linspace(0, 1, w)``, entry
    ``k`` of the grid is ``(cols[k // h], rows[k % h])``. For a square image
    (``h == w``), this is the normalized (row, col) of pixel ``k`` in a
    row-major flattening.

    NOTE(review): when ``h != w``, the first coordinate is indexed by
    ``k // h``. That matches a column-major flattening rather than a
    row-major one. Callers visible in this file pass only square grids;
    confirm before using non-square shapes.

    Args:
        h: number of points along the first image axis.
        w: number of points along the second image axis.
        device: device for the returned tensor. Defaults to ``'cuda'``,
            matching the previously hard-coded ``.cuda()`` calls.

    Returns:
        Float tensor of shape ``(1, h * w, 2)``.
    """
    # Build on CPU, then move, so the values are bit-identical to the
    # original linspace(...).cuda() construction.
    rows = torch.linspace(0, 1, h).to(device)
    cols = torch.linspace(0, 1, w).to(device)

    # cols.repeat(h, 1) is (h, w); transposing and flattening repeats each
    # cols value h times in a row -> slow-varying first coordinate.
    first = cols.repeat(h, 1).t().contiguous().view(-1)
    # rows.repeat(w) tiles the whole rows vector w times -> fast-varying
    # second coordinate.
    second = rows.repeat(w)

    grid = torch.stack([first, second], dim=1)
    return grid.unsqueeze(0)


def idx_to_y(idx, data):
    """Gather the entries of ``data`` at positions ``idx`` along dimension 1.

    This is a pure selection: no rescaling happens here. Values keep whatever
    range ``data`` already has (presumably pixel intensities in [0, 1]; confirm
    with the caller).

    Args:
        idx: 1-D integer index tensor, on the same device as ``data``.
        data: tensor with at least two dimensions.

    Returns:
        ``data`` with dimension 1 replaced by the ``len(idx)`` selected entries.
    """
    return data.index_select(1, idx)

def batch_to_row(batch_img):
    """Tile a 4-D batch side by side along the last axis.

    With the input read as ``(B, C, H, W)`` (axis meaning presumed; only the
    rank is fixed by the code), the result has shape ``(C, H, B * W)``. Row
    ``h`` of channel ``c`` is the concatenation of row ``h`` of every image in
    the batch, in batch order.
    """
    # One permute equivalent to transpose(0, 1) followed by permute(0, 2, 1, 3):
    # (B, C, H, W) -> (C, H, B, W).
    tiled = batch_img.permute(1, 2, 0, 3).contiguous()
    n_channels, n_rows = tiled.size(0), tiled.size(1)
    return tiled.view(n_channels, n_rows, -1)

def idx_to_x(idx, batch_size,dim):
    """Map flat pixel indices of a ``dim x dim`` image to normalized 2-D coordinates.

    Works like ``np.unravel_index`` followed by scaling each coordinate to
    [0, 1]: the indices are looked up in ``generate_grid(dim, dim)``, a
    ``(1, dim * dim, 2)`` coordinate table. For example, with ``dim = 28``,
    index 35 maps to the normalized form of (1, 7).

    ``generate_grid`` builds its table with ``.cuda()``, so ``idx`` must be a
    1-D integer tensor on that CUDA device.

    Returns:
        Tensor of shape ``(batch_size, len(idx), 2)``. This is a broadcast view
        (``expand``), so all batch entries share one underlying buffer.
    """
    coords = generate_grid(dim, dim).index_select(1, idx)
    return coords.expand(batch_size, -1, -1)
