import torch
import numpy as np

from torch.utils.data import Dataset
from einops import rearrange

class ZscoreStandardizer(object):
    """
    Z-score transformation: subtract a mean, divide by a standard deviation.

    The statistics are computed once, in the constructor, from the tensor
    passed in, and are then reused by every call to do() / undo().
    if reduce_dim = [0]: statistics are reduced over dimension 0 only.
    if reduce_dim = []:  intended to reduce over all data points (this relies
                         on how torch.mean / torch.std treat an empty list).
    """
    def __init__(self, x, reduce_dim=[0]):
        kept_mean = torch.mean(x, reduce_dim, keepdim=True)
        kept_std = torch.std(x, reduce_dim, keepdim=True)

        # squeeze() removes every size-1 dimension, which includes the
        # reduced dimensions that keepdim=True left in place.
        self.mean = kept_mean.squeeze()
        self.std = kept_std.squeeze()

        # Added to the standard deviation so the division never hits zero.
        self.epsilon = 1e-10
        assert self.mean.shape == self.std.shape

    def do(self, x: torch.Tensor) -> torch.Tensor:
        """Return x standardized with the stored mean and standard deviation."""
        scale = self.std + self.epsilon
        return (x - self.mean) / scale

    def undo(self, x: torch.Tensor) -> torch.Tensor:
        """Invert do(): map standardized values back to the original scale."""
        scale = self.std + self.epsilon
        return x * scale + self.mean


class MinMaxStandardizer(object):
    """
    Min-Max transformation: maps values to (x - min) / (max - min) + epsilon.

    The min/max statistics are computed once, in the constructor, from the
    tensor passed in, and are then reused by every call to do() / undo().
    if reduce_dim = [0]: The min/max is computed over dimension 0.
    if reduce_dim = []:  The min/max is computed over all data points.
    Lists with several dimensions reduce over all of them.
    """
    def __init__(self, x, reduce_dim=[0]):
        if reduce_dim:
            # amin/amax accept several dimensions at once, so every entry of
            # reduce_dim is honoured (not only the first one).
            dims = tuple(reduce_dim)
            self.minVal = torch.amin(x, dim=dims, keepdim=True).squeeze()
            self.maxVal = torch.amax(x, dim=dims, keepdim=True).squeeze()
        else:
            self.minVal = torch.min(x)
            self.maxVal = torch.max(x)

        # Offset added to the scaled values by do() and removed by undo().
        self.epsilon = 1e-10

        assert self.minVal.shape == self.maxVal.shape

    def _safe_range(self) -> torch.Tensor:
        """Return max - min with zero entries replaced by 1 to avoid 0/0."""
        span = self.maxVal - self.minVal
        return torch.where(span == 0, torch.ones_like(span), span)

    def do(self, x: torch.Tensor) -> torch.Tensor:
        """Scale x with the stored min/max; constant features map to epsilon."""
        return (x - self.minVal) / self._safe_range() + self.epsilon

    def undo(self, x: torch.Tensor) -> torch.Tensor:
        """Invert do(): map scaled values back to the original range."""
        # With a zero range this collapses to minVal, which is the only value
        # a constant feature had, so the round trip stays exact.
        return (x - self.epsilon) * (self.maxVal - self.minVal) + self.minVal
    

def _tile_and_pad(selected, repeats, target_len):
    """Repeat `selected` `repeats` times, cut to `target_len`, pad any shortfall with -1."""
    tiled = torch.cat((selected,) * repeats)[:target_len]
    shortfall = target_len - tiled.numel()
    if shortfall > 0:
        filler = torch.full((shortfall,), -1, dtype=torch.long)
        tiled = torch.cat((tiled, filler))
    return tiled


def get_KHINR(data_dir, n_data, preprocessing, reduce_dim, task, sampling_rate):
    """
    Load a gridded field from an .npz file, normalize it and draw sample indices.

    Args:
        data_dir (string): dataset file path (an .npz archive holding "lats",
            "lons" and one of "temperature", "ssh" or "chl"; if several of
            the field names are present the last one in that order is used)
        n_data (int): number of data samples (leading slice taken from the field)
        preprocessing (list): pair (lat_lon_method, field_method); each is
            "zscore", "minmax" or "None"
        reduce_dim (list): pair (lat_lon_reduce_dim, field_reduce_dim) handed
            to the standardizers
        task (string):
            task1: fixed n_points,  fixed positions
            task2: random n_points, fixed positions
            task3: fixed n_points,  random positions
            task4: random n_points, random positions
        sampling_rate (list): sampling rate lower and upper bound; task1 and
            task3 use only the upper bound
    Returns:
        n_points (int or list of int): points per sample; an int for
            task1/task3, one entry per sample for task2/task4
        grid_lat (torch.tensor): normalized latitude grid, (lat, lon)
        grid_lon (torch.tensor): normalized longitude grid, (lat, lon)
        t        (torch.tensor): (n_data,) values from 1e-10 to 1.0
        gst      (torch.tensor): the (optionally normalized) field,
            presumably (n_data, lat, lon) -- assumes the file holds at least
            n_data samples, TODO confirm
        indices  (torch.tensor): (n_data, n_points) flat grid indices; for
            task2/task4 the width is the largest n_points and rows may be
            padded with -1
        normalizer (dict): entries "lat", "lon" and the field name
    Raises:
        ValueError: unknown task or unknown lat/lon preprocessing method
        KeyError: the file contains none of the supported field names
    """
    valid_tasks = ("task1", "task2", "task3", "task4")
    if task not in valid_tasks:
        # Fail before any I/O instead of hitting an unbound local at return.
        raise ValueError(f"unknown task {task!r}; expected one of {valid_tasks}")

    # load data; the context manager closes the archive once arrays are read
    with np.load(data_dir) as dat:
        lat = torch.from_numpy(dat["lats"])
        lon = torch.from_numpy(dat["lons"])

        available = list(dat.keys())
        key = None
        # Later names win when several are present (same priority as before).
        for candidate in ("temperature", "ssh", "chl"):
            if candidate in available:
                key = candidate
        if key is None:
            raise KeyError(
                f"{data_dir!r} contains none of 'temperature', 'ssh', 'chl' "
                f"(found: {available})"
            )

        gst = torch.from_numpy(dat[key][:n_data, :, :])
    t = torch.linspace(1e-10, 1.0, n_data)

    # pre-process data
    lat_lon_pre, gst_pre = preprocessing
    lat_lon_reduce_dim, gst_reduce_dim = reduce_dim
    normalizer = {}

    if lat_lon_pre == "zscore":
        coord_standardizer = ZscoreStandardizer
    elif lat_lon_pre in ("minmax", "None"):
        # NOTE(review): "None" has always min-max scaled the coordinates;
        # kept as is -- confirm this is intended.
        coord_standardizer = MinMaxStandardizer
    else:
        raise ValueError(
            f"unknown lat/lon preprocessing {lat_lon_pre!r}; "
            "expected 'zscore', 'minmax' or 'None'"
        )
    normalizer["lat"] = coord_standardizer(lat, lat_lon_reduce_dim)
    normalizer["lon"] = coord_standardizer(lon, lat_lon_reduce_dim)

    if gst_pre == "zscore":
        normalizer[key] = ZscoreStandardizer(gst, gst_reduce_dim)
        gst = normalizer[key].do(gst)
    elif gst_pre == "minmax":
        normalizer[key] = MinMaxStandardizer(gst, gst_reduce_dim)
        gst = normalizer[key].do(gst)
    elif gst_pre == "None":
        # No scaling: the raw tensor itself is stored under the field name.
        normalizer[key] = gst
        gst = gst.float()
    # Any other value leaves gst untouched and registers no field normalizer.

    lat = normalizer["lat"].do(lat)
    lon = normalizer["lon"].do(lon)
    grid_lat, grid_lon = torch.meshgrid(lat, lon, indexing="ij")
    n_cells = grid_lat.numel()

    if task == "task1":
        # One random subset shared by every sample.
        n_points = int(np.round(sampling_rate[1] * n_cells))
        indices = torch.randperm(n_cells)[:n_points].repeat(n_data, 1)
    elif task == "task3":
        # A fresh random subset of fixed size for every sample.
        n_points = int(np.round(sampling_rate[1] * n_cells))
        indices = torch.cat(
            [torch.randperm(n_cells)[:n_points][None, :] for _ in range(n_data)],
            dim=0,
        )
    else:
        # task2 / task4: per-sample point counts drawn from the rate interval.
        rates = np.random.uniform(sampling_rate[0], sampling_rate[1], n_data)
        n_points = [int(np.round(rate * n_cells)) for rate in rates]
        max_n_points = max(n_points)
        # A lower bound of 0.01 repeats the subset five times (otherwise
        # twice) to fill the row; whatever is still missing is padded.
        repeats = 5 if sampling_rate[0] == 0.01 else 2
        # task2 draws one permutation for all samples, task4 one per sample.
        shared_perm = torch.randperm(n_cells) if task == "task2" else None

        rows = []
        for count in n_points:
            perm = shared_perm if shared_perm is not None else torch.randperm(n_cells)
            # NOTE(review): -1 padding addresses the last element when used
            # directly as a tensor index -- confirm consumers mask it.
            rows.append(_tile_and_pad(perm[:count], repeats, max_n_points))
        indices = torch.stack(rows, dim=0)

    return n_points, grid_lat, grid_lon, t, gst, indices, normalizer

 
class KHINR(Dataset):
    """
    Dataset of sparsely sampled grid points, one item per data sample.

    Args:
        n_points: accepted for signature compatibility; not used by this class.
        temperature: field values; the first dimension indexes the samples and
            each sample is flattened before points are picked from it.
        grid_lat: latitude grid, flattened on construction.
        grid_lon: longitude grid, flattened on construction.
        indices: per-sample flat grid indices; row idx selects the points
            returned for sample idx.
        time: stored as-is on the instance; item construction does not read it.
    """
    def __init__(self,
        n_points: int,
        temperature: torch.Tensor,
        grid_lat: torch.Tensor,
        grid_lon: torch.Tensor,
        indices: torch.Tensor,
        time: torch.Tensor,
    ):
        self.temperature = temperature
        self.grid_lat = grid_lat.flatten()
        self.grid_lon = grid_lon.flatten()
        self.indices = indices
        self.n_samples = temperature.shape[0]
        self.time = time

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """
        Return (z, x_loc, y, idx) for sample idx.

        x_loc stacks the selected (lat, lon) pairs on the last dimension,
        y holds the matching field values with a trailing dimension of 1,
        and z is x_loc and y concatenated on the last dimension.
        """
        # NOTE(review): negative entries in indices (e.g. padding values)
        # select from the end of the flattened grid -- confirm callers mask them.
        ind = self.indices[idx]

        x_loc = torch.stack([self.grid_lat[ind], self.grid_lon[ind]], dim=-1)
        y = self.temperature[idx].reshape(-1)[ind].unsqueeze(-1)
        z = torch.cat([x_loc, y], dim=-1)
        return z, x_loc, y, idx
 
class vizKHINR(Dataset):
    """
    Dataset pairing a sparse observation set with the full grid, per sample.

    Each item provides the sparse (lat, lon, value) points produced by
    assemble_xy for one sample, together with every grid coordinate and the
    complete field of that sample.
    """
    def __init__(self,
        n_points: int,
        temperature: torch.Tensor,
        grid_lat: torch.Tensor,
        grid_lon: torch.Tensor,
        indices: torch.Tensor,
        time_t: torch.Tensor,
    ):
        # Note the attribute name: the n_points argument is stored as n_data.
        self.n_data = n_points
        self.temperature = temperature
        self.grid_lat = grid_lat
        self.grid_lon = grid_lon
        self.indices = indices
        self.time = time_t

        # Unpacking raises unless grid_lat has exactly two dimensions.
        _n_rows, _n_cols = grid_lat.shape

        # Every grid coordinate as one (lat, lon) row: (H*W, 2)
        flat_lat = grid_lat.flatten()
        flat_lon = grid_lon.flatten()
        self.full_coords = torch.stack([flat_lat, flat_lon], dim=-1)

    def __len__(self):
        return len(self.temperature)

    def __getitem__(self, idx):
        """
        Return (z, full_x, full_temp, idx) for sample idx.

        z holds the sparse points as (lat, lon, value) rows, full_x is the
        (H*W, 2) coordinate table and full_temp the flattened field of the
        sample with a trailing dimension of 1.
        """
        full_temp = self.temperature[idx].flatten().unsqueeze(-1)  # (H*W, 1)
        full_x = self.full_coords

        # Length-1 slices keep the leading batch dimension assemble_xy expects.
        one = slice(idx, idx + 1)
        sparse_lat, sparse_lon, _sparse_t, sparse_val = assemble_xy(
            self.grid_lat,
            self.grid_lon,
            self.time[one],
            self.temperature[one],
            self.indices[one],
        )

        # Drop the batch dimension again and line the columns up.
        sparse_coords = torch.stack([sparse_lat[0], sparse_lon[0]], dim=-1)
        sparse_values = sparse_val[0].unsqueeze(-1)
        z = torch.cat([sparse_coords, sparse_values], dim=-1)

        return z, full_x, full_temp, idx
 






def assemble_xy(grid_lat, grid_lon, t, gst, indices):
    """
    Gather shuffled (lat, lon, time, value) samples from the grid.

    All entries of indices are flattened and shuffled together, and every
    resulting point is paired with a time step drawn uniformly at random from
    range(t.numel()). The field value is read from gst at that drawn time
    step, not from the sample whose index row the point came from.

    Args:
        grid_lat: latitude grid; flattened before indexing.
        grid_lon: longitude grid; flattened before indexing.
        t: time values, indexed by the drawn time steps.
        gst: field of shape (n, h, w); its first dimension sets n_data.
            Assumes t.numel() does not exceed gst.shape[0] -- TODO confirm.
        indices: flat grid indices; any shape, flattened internally.
    Returns:
        (grid_lat, grid_lon, t, gst), each reshaped to (n_data, -1).
    """
    n_data = gst.shape[0]

    # One joint random order over every index entry.
    flat_indices = indices.reshape(-1)
    order = torch.randperm(flat_indices.numel())
    picked = flat_indices[order]

    lat_points = grid_lat.reshape(-1)[picked]
    lon_points = grid_lon.reshape(-1)[picked]

    # An independent random time step for each picked point.
    time_ids = torch.randint(0, t.numel(), (lat_points.numel(),))
    time_points = t[time_ids]

    flat_field = rearrange(gst, 'n h w -> n (h w)')
    values = flat_field[time_ids, picked]

    return (
        lat_points.reshape(n_data, -1),
        lon_points.reshape(n_data, -1),
        time_points.reshape(n_data, -1),
        values.reshape(n_data, -1),
    )