# This file has been taken from PIDM
from einops import rearrange
import numpy as np
import pandas as pd
from pathlib import Path
import torch
from torch.utils import data
import torch.nn.functional as F

def generalized_image_to_b_xy_c(tensor):
    """
    Transpose the tensor from [batch, channels, ..., pixel_x, pixel_y] to [batch, pixel_x*pixel_y, channels, ...]. We assume two pixel dimensions.
    """
    num_dims = len(tensor.shape) - 3  # Subtracting batch and pixel dimensions
    pattern = 'b ' + ' '.join([f'c{i}' for i in range(num_dims)]) + ' x y -> b (x y) ' + ' '.join([f'c{i}' for i in range(num_dims)])
    return rearrange(tensor, pattern)

def generalized_b_xy_c_to_image(tensor, pixels_x=None, pixels_y=None):
    """
    Transpose the tensor from [batch, pixel_x*pixel_y, channels, ...] to [batch, channels, ..., pixel_x, pixel_y] using einops.
    """
    if pixels_x is None or pixels_y is None:
        pixels_x = pixels_y = int(np.sqrt(tensor.shape[1]))
    num_dims = len(tensor.shape) - 2  # Subtracting batch and pixel dimensions (NOTE that we assume two pixel dimensions that are FLATTENED into one dimension)
    pattern = 'b (x y) ' + ' '.join([f'c{i}' for i in range(num_dims)]) + f' -> b ' + ' '.join([f'c{i}' for i in range(num_dims)]) + ' x y'
    return rearrange(tensor, pattern, x=pixels_x, y=pixels_y)

def cycle(dl):
    while True:
        for data in dl:
            yield data

class Dataset(data.Dataset):
    def __init__(
        self,
        data_directories,
        use_double = False,
        return_img = True,
        gaussian_prior = False,
    ):
        super().__init__()

        # Assuming data_directories is a tuple of file paths
        self.data_paths = list(data_directories)
        channels = len(self.data_paths)            

        # load data
        for i in range(channels):
            if i == 0:
                self.data = pd.read_csv(self.data_paths[i], header=None)
            else:
                self.data = np.stack((self.data, pd.read_csv(self.data_paths[i], header=None)), axis=-1)

        # convert to torch tensor
        dtype = torch.float64 if use_double else torch.float32
        self.data = torch.tensor(self.data, dtype=dtype)
        self.num_datapoints = len(self.data)

        if return_img:
            assert len(self.data.shape) == 3, "Data must be of shape (num_datapoints, pixels_x*pixels_y, channels)"
            self.data = generalized_b_xy_c_to_image(self.data)

        if gaussian_prior:
            # instead consider no information at all
            self.data = torch.randn_like(self.data)

    def normalize(self, arr, min_val, max_val):
        return (arr - min_val) / (max_val - min_val)

    def unnorm(self, arr, min_val, max_val):
        return arr * (max_val - min_val) + min_val

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        if index >= self.num_datapoints:
            raise IndexError('index out of range')
        return self.data[index]

class Dataset_Paths(data.Dataset):
    def __init__(
        self,
        data_directories,
        use_double = False,
        return_img = True,
        gaussian_prior = False,
        exts = ['npy'],
    ):
        super().__init__()
        
        # load topo data
        self.paths = [p for ext in exts for p in Path(f'{data_directories}').glob(f'**/*.{ext}')]
        # sort paths by number of name
        self.paths = sorted(self.paths, key=lambda x: int(x.name.split('.')[0]))
        self.num_datapoints = len(self.paths)

        # convert to torch tensor in correct dtype
        self.dtype = torch.float64 if use_double else torch.float32

        self.return_img = return_img
        self.gaussian_prior = gaussian_prior

    def normalize(self, arr, min_val, max_val):
        return (arr - min_val) / (max_val - min_val)

    def unnorm(self, arr, min_val, max_val):
        return arr * (max_val - min_val) + min_val

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):

        if index >= self.num_datapoints:
            raise IndexError('index out of range')
                             
        data_np = np.load(self.paths[index], allow_pickle = True, encoding = 'latin1')
        data = torch.tensor(data_np.transpose(2, 0, 1), dtype = self.dtype) # vf_arr, strain_energy_density_fem, von_mises_stress, disp_x, disp_y, E_field, BC_node_x, BC_node_y, load_x_img, load_y_img x pixels x pixels
        return data

def sample_images_with_squares(no_points, pixels_per_dim, dim, frame_dim = False, use_double = False):

    dtype = np.float64 if use_double else np.float32

    # Define the size of the square (e.g., a quarter of the image dimension)
    square_size = pixels_per_dim // 4

    # initialize an array to store the images
    # shape: (no_points, pixels_per_dim, pixels_per_dim, dim)
    if frame_dim:
        images = np.zeros((no_points, dim, 1, pixels_per_dim, pixels_per_dim), dtype=dtype)
    else:
        images = np.zeros((no_points, dim, pixels_per_dim, pixels_per_dim), dtype=dtype)

    for i in range(no_points):
        # randomly choose the top-left corner of the square
        x_start = np.random.randint(0, pixels_per_dim - square_size)
        y_start = np.random.randint(0, pixels_per_dim - square_size)

        for j in range(dim):
            # draw the square in each channel of the image
            if frame_dim:
                images[i, j, :, x_start:x_start + square_size, y_start:y_start + square_size] = 1.
            else:
                images[i, j, x_start:x_start + square_size, y_start:y_start + square_size] = 1.
                
    return images

class SquareImagesDataset(Dataset):
    def __init__(self, no_points, pixels_per_dim, dim, frame_dim=False, use_double=False):
        """
        Args:
            no_points (int): Number of images to generate.
            pixels_per_dim (int): The size of each image dimension.
            dim (int): Number of channels in the image.
            frame_dim (bool): Whether to include an additional frame dimension.
            use_double (bool): Whether to use double precision.
        """
        self.data = sample_images_with_squares(no_points, pixels_per_dim, dim, frame_dim, use_double)
        self.use_double = use_double

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError('Index out of range')

        # Convert numpy array to PyTorch tensor
        dtype = torch.float64 if self.use_double else torch.float32
        image = torch.tensor(self.data[index], dtype=dtype)

        return image

class Normalization:
    def __init__(self,data,dataType,strategy):
        self.mu = torch.mean(data,dim=0)
        self.std = torch.std(data,dim=0)
        self.min = torch.min(data,dim=0)[0]
        self.max = torch.max(data,dim=0)[0]
        self.globalmin = torch.min(data)
        self.globalmax = torch.max(data)
        self.dataType = dataType
        self.cols = data.size()[1]
        self.strategy = strategy
    
    def normalize(self, data):
        list_index_cat = []       
        temp = torch.zeros(data.shape,device=data.device)
        for i in range(0, self.cols):
            if self.dataType[i] == 'continuous':

                if(self.strategy == 'min-max-1'):
                    #scale to [0,1]
                    temp[:,i] = torch.div(data[:,i]-self.min[i], self.max[i]-self.min[i])

                elif(self.strategy == 'global-min-max-1'):
                    #scale to [-1,1] based on min max of full dataset
                    temp[:,i] = torch.div(data[:,i]-self.globalmin, self.globalmax-self.globalmin)

                elif(self.strategy == 'min-max-2'):
                    #scale to [-1,1]
                    temp[:,i] = 2.*torch.div(data[:,i]-self.min[i], self.max[i]-self.min[i])-1.

                elif(self.strategy == 'global-min-max-2'):
                    #scale to [-1,1] based on min max of full dataset
                    temp[:,i] = 2.*torch.div(data[:,i]-self.globalmin, self.globalmax-self.globalmin)-1.

                elif(self.strategy == 'mean-std'):
                    #scale s.t. mean=0, std=1
                    temp[:,i] = torch.div(data[:,i]-self.mu[i], self.std[i])

                elif (self.strategy == 'none'):
                    temp[:,i] = data[:,i]

                else:
                    raise ValueError('Incorrect normalization strategy')

            elif self.dataType[i] == 'categorical':
                #convert categorical features into binaries and append at the end of feature tensor
                temp = torch.cat((temp,F.one_hot(data[:,i].to(torch.int64))),dim=1)
                list_index_cat = np.append(list_index_cat,i)
                                   
            else:
                raise ValueError("Data type must be either continuous or categorical")

        # delete original (not one-hot encoded) categorical features
        j = 0
        for i in np.array(list_index_cat, dtype=np.int64):          
            temp = torch.cat([temp[:,0:i+j], temp[:,i+1+j:]],dim=1)
            j -= 1

        return temp

    def unnormalize(self, data):
        temp = torch.zeros(data.shape,device=data.device)
        for i in range(0, self.cols):
            if self.dataType[i] == 'continuous':
                
                if(self.strategy == 'min-max-1'):
                    temp[:,i] = torch.mul(data[:,i], self.max[i]-self.min[i]) +self.min[i]

                elif(self.strategy == 'global-min-max-1'):
                    temp[:,i] = torch.mul(data[:,i], self.globalmax-self.globalmin) +self.globalmin

                elif(self.strategy == 'min-max-2'):
                    temp[:,i] = torch.mul(0.5*data[:,i]+0.5, self.max[i]-self.min[i]) +self.min[i]

                elif(self.strategy == 'global-min-max-2'):
                    temp[:,i] = torch.mul(0.5*data[:,i]+0.5, self.globalmax-self.globalmin) +self.globalmin
            
                elif(self.strategy == 'mean-std'):
                    temp[:,i] = torch.mul(data[:,i], self.std[i]) + self.mu[i]

                elif (self.strategy == 'none'):
                    temp[:,i] = data[:,i]

                else:
                    raise ValueError('Incorrect normalization strategy')
                
            elif self.dataType[i] == 'categorical':
                temp[:,i] = data[:,i]

            else:
                raise ValueError("Data type must be either continuous or categorical")
        return temp
