import torch
import torch.nn.functional as F
from dataloaders.sequential_dataset import SequentialDataSet
from torch.utils.data import DataLoader, Dataset
import os
import glob
import h5py
import numpy as np
import einops
from einops import rearrange
import os

import numpy as np
import scipy.signal as signal
from scipy.interpolate import interp1d


def interpolate(u, new_res, axis):
    '''Resample ``u`` along ``axis`` to ``new_res`` points using cubic interpolation.

    The source and target samples are both treated as equispaced points on [0, 1)
    with the right endpoint excluded. When upsampling, the last target points lie
    past the last source point; they are extrapolated from the cubic fit instead of
    raising, so both down- and up-sampling work.

    :param u: array-like (numpy array or CPU torch tensor)
    :param new_res: int, number of samples along ``axis`` in the result
    :param axis: int, axis to resample
    :return: float32 torch tensor, same shape as ``u`` except ``new_res`` along ``axis``
    '''
    x_src = np.linspace(0, 1, u.shape[axis], endpoint=False)
    x_dst = np.linspace(0, 1, new_res, endpoint=False)
    # With endpoint=False the target grid's last point (1 - 1/new_res) exceeds the
    # source grid's last point (1 - 1/n_src) whenever new_res > n_src.
    interpolator = interp1d(x_src, u, kind='cubic', axis=axis, fill_value='extrapolate')
    return torch.tensor(interpolator(x_dst), dtype=torch.float)


def read_lpsda(filename, key='train'):
    '''Load one group of an LPSDA-generated HDF5 file as float32 torch tensors.

    :param filename: str, path of the HDF5 file
    :param key: str, top-level group to read (e.g. ``'train'`` or ``'valid'``)
    :return: (dt, x, t, pde), read from the datasets ``dt``, ``x``, ``t`` and the
        first dataset whose name starts with ``"pde"``
    :raises ValueError: if no dataset name in the group starts with ``"pde"``
    '''
    def _as_float_tensor(dataset):
        return torch.tensor(np.array(dataset), dtype=torch.float)

    with h5py.File(filename, 'r') as handle:
        group = handle[key]
        # The solution dataset name varies between files (sometimes "pde-140-256"),
        # so pick the first entry carrying the "pde" prefix.
        candidates = list(group.keys())
        pde_key = next((name for name in candidates if str(name).startswith("pde")), None)
        if pde_key is None:
            raise ValueError("No pde key found")
        # Shapes as annotated by the original author:
        dt = _as_float_tensor(group['dt'])       # (B, )
        x = _as_float_tensor(group['x'])         # (B, S)
        t = _as_float_tensor(group['t'])         # (B, T)
        pde = _as_float_tensor(group[pde_key])   # (B, T, S)
    return dt, x, t, pde


class LPSDALoader1D(SequentialDataSet):
    '''Dataset over LPSDA-generated 1D PDE trajectories.

    ``__getitem__(idx)`` returns ``(u, x)``; for an integer ``idx``:
    :return u: (S, T, 1) solution of one trajectory (or one time chunk of it)
    :return x: (S, 1) spatial grid belonging to that same row
    S = spatial dimension length (after sub-sampling / interpolation)
    T = number of time steps (``train_timesteps`` when chunking the training split)

    Attributes built by the constructor (N = number of rows):
    ``u``    (N, S, T, 1) solution, standardized with ``mean``/``std`` when ``scale``
    ``x``    (N, S, 1) spatial grid, row-aligned with ``u``
    ``t``    (N, T) time stamps, row-aligned with ``u``
    ``dt``   (N,) time step as stored in the file, row-aligned with ``u``
    ``u_hr`` (B, S_full, T) solution without spatial sub-sampling (after optional
             anti-aliasing); never interpolated, scaled or chunked
    '''
    def __init__(self,
                filename,
                test_filename=None,
                saved_folder='../data/',
                reduced_resolution=1,
                reduced_resolution_t=1,
                reduced_batch=1,
                if_test=False,
                num_samples_max = -1,
                t_train = None,
                t_test = None,
                sort_by_dt = False,
                chunk_train = False,
                train_timesteps = None,
                unfold = False,
                scale = False, 
                mean = 0.0,
                std = 1.0,
                antialiasing = False,
                reduced_fc = 2,
                init_step = 0,
                interpolate_resolution = -1,
                reduced_resolution_test = 1,
                interpolate_resolution_test = -1,
                normalize_grid = False,
                ):
        '''Load, sub-sample and reshape one split of an LPSDA HDF5 file.

        :param filename: str, training file name inside ``saved_folder``
        :param test_filename: str, file name used instead of ``filename`` when ``if_test``
        :param saved_folder: str, directory containing the files
        :param reduced_resolution: int, spatial stride (train split)
        :param reduced_resolution_t: int, temporal stride
        :param reduced_batch: int, stride over trajectories
        :param if_test: bool, read group ``"valid"`` of ``test_filename`` instead of
            group ``"train"`` of ``filename``
        :param num_samples_max: int, keep at most this many trajectories (<= 0 keeps all)
        :param t_train: int, exclusive end of the time window for the train split
            (<= 0 keeps all steps); required
        :param t_test: int, same as ``t_train`` for the test split; required
        :param sort_by_dt: bool, reorder rows by increasing ``dt``
        :param chunk_train: bool, split training trajectories into pieces of
            ``train_timesteps`` steps
        :param train_timesteps: int, chunk length (required with ``chunk_train``)
        :param unfold: bool, use every overlapping window (stride 1) instead of
            disjoint chunks
        :param scale: bool, standardize ``u`` as ``(u - mean) / std``
        :param mean: float, offset used by ``scale`` and ``unscale_data``
        :param std: float, divisor used by ``scale`` and ``unscale_data``
        :param antialiasing: bool, FIR low-pass filter in space before spatial sub-sampling
        :param reduced_fc: int, divisor used to derive the filter cutoff
        :param init_step: int, first time step kept (applied only when the window end is > 0)
        :param interpolate_resolution: int, cubic resampling of space to this many
            points when > 0 (requires ``reduced_resolution == 1``)
        :param reduced_resolution_test: int, replaces ``reduced_resolution`` when ``if_test``
        :param interpolate_resolution_test: int, replaces ``interpolate_resolution`` when ``if_test``
        :param normalize_grid: bool, subtract each grid row's mean and divide by its std
        :raises ValueError: if ``t_train`` or ``t_test`` is None
        '''
        # Validate before any (potentially expensive) file I/O.
        if t_train is None or t_test is None:
            raise ValueError("t_train and t_test must be specified")
        if if_test:
            reduced_resolution = reduced_resolution_test
            interpolate_resolution = interpolate_resolution_test
        self.mean = mean
        self.std = std
        path = os.path.join(saved_folder, filename if not if_test else test_filename)
        dt, x, t, pde = read_lpsda(path, key = "train" if not if_test else "valid")
        # pde (B, T, S)

        if antialiasing and reduced_resolution > 1:
            # NOTE(review): ``fs`` is the *target* resolution, yet the filter runs on the
            # full-resolution signal, so the effective normalized cutoff is close to the
            # input's Nyquist for the default ``reduced_fc``. Confirm this is intended.
            fs = pde.shape[2] // reduced_resolution
            fc = fs // reduced_fc - 1
            num_taps = 31
            taps = signal.firwin(num_taps, cutoff=fc, fs=fs, window='hamming')
            filtered = signal.filtfilt(taps, 1.0, np.array(pde), axis=-1).astype(np.float32)
            # Back to a (contiguous) torch tensor so tensor-only ops such as ``unfold``
            # keep working and ``u``/``u_hr`` have the same type on every path.
            pde = torch.from_numpy(np.ascontiguousarray(filtered))

        self.u_hr = pde[::reduced_batch, ::reduced_resolution_t, :]
        if interpolate_resolution > 0: 
            assert reduced_resolution == 1, "Interpolation only works with reduced_resolution = 1, got {}".format(reduced_resolution)
            pde = interpolate(pde, interpolate_resolution, axis=-1)
            x = interpolate(x, interpolate_resolution, axis=-1)

        self.u = pde[::reduced_batch, ::reduced_resolution_t, ::reduced_resolution] # (B, T, S)
        self.x = x[::reduced_batch, ::reduced_resolution] # (B, S)
        self.t = t[::reduced_batch, ::reduced_resolution_t] # (B, T)
        # Keep one dt per trajectory so ``sort_by_dt`` has something to sort by;
        # broadcast if the file stores a single scalar.
        dt = dt.reshape(-1)
        if dt.numel() == 1:
            dt = dt.expand(pde.shape[0]).clone()
        self.dt = dt[::reduced_batch] # (B, )

        if num_samples_max > 0:
            num_samples_max = min(num_samples_max, self.x.shape[0])
            self.x = self.x[:num_samples_max]
            self.t = self.t[:num_samples_max]
            self.u = self.u[:num_samples_max]
            self.u_hr = self.u_hr[:num_samples_max]
            self.dt = self.dt[:num_samples_max]

        self.x = self.x.unsqueeze(-1) # (B, S, 1)
        self.u = einops.rearrange(self.u, "b t s -> b s t 1") # (B, S, T, 1)
        self.u_hr = einops.rearrange(self.u_hr, "b t s -> b s t") # (B, S_full, T)

        if scale: 
            self.u = (self.u - self.mean) / self.std

        n_time_steps = t_train if not if_test else t_test
        # NOTE(review): ``init_step`` is ignored when ``n_time_steps <= 0``.
        if n_time_steps > 0:
            self.u = self.u[:, :, init_step:n_time_steps, :]
            self.t = self.t[:, init_step:n_time_steps]
            self.u_hr = self.u_hr[:, :, init_step:n_time_steps]

        self.chunk_train = chunk_train
        if not if_test and self.chunk_train:
            assert train_timesteps is not None, "train_timesteps must be specified"
            n_steps = self.u.shape[2]
            if unfold:
                # Every window of ``train_timesteps`` consecutive steps (stride 1).
                chunks_per_trajectory = n_steps - train_timesteps + 1
                self.u = self.u.unfold(-2, train_timesteps, 1)
                self.u = einops.rearrange(self.u, "b s t 1 nt -> (b t) s nt 1")
                self.t = self.t.unfold(-1, train_timesteps, 1)
                self.t = einops.rearrange(self.t, "b t nt -> (b t) nt")
            else:
                # Disjoint chunks; einops raises if n_steps is not a multiple of t2.
                chunks_per_trajectory = n_steps // train_timesteps
                self.u = einops.rearrange(self.u, "b s (t1 t2) 1 -> (b t1) s t2 1", t2 = train_timesteps)
                self.t = einops.rearrange(self.t, "b (t1 t2) -> (b t1) t2", t2 = train_timesteps)

            # Rows are ordered trajectory-major ``(b, chunk)``, so each per-trajectory
            # quantity must be repeated consecutively once per chunk (not tiled).
            self.x = self.x.repeat_interleave(chunks_per_trajectory, dim=0)
            self.dt = self.dt.repeat_interleave(chunks_per_trajectory, dim=0)

        if normalize_grid:
            self.x = (self.x - self.x.mean(dim=1, keepdim=True)) / self.x.std(dim=1, keepdim=True)

        if sort_by_dt:
            print("Sorting dataset by dt")
            idx = torch.argsort(self.dt)
            self.x = self.x[idx]
            self.t = self.t[idx]
            self.u = self.u[idx]
            self.dt = self.dt[idx]
            # ``u_hr`` is never chunked, so it is row-aligned only when lengths agree.
            if len(self.u_hr) == len(idx):
                self.u_hr = self.u_hr[idx]

    def input_shape(self):
        '''Return the per-sample shape of ``u``: (S, T, V) for this 1D dataset.

        S = spatial dimension length
        T = number of timesteps
        V = number of state variables (1 here)
        :return: torch.Size (a tuple subclass)
        '''
        return self.u.shape[1:]

    def __len__(self):
        '''Number of rows (trajectories, or chunks when chunking is active).'''
        return len(self.u)

    def __getitem__(self, idx):
        '''Return ``(self.u[idx], self.x[idx])``.

        For an integer ``idx``:
        :return u: (S, T, 1) solution
        :return x: (S, 1) spatial grid of the same row
        S = spatial dimension length
        T = number of time steps
        '''
        return self.u[idx], self.x[idx]

    def unscale_data(self, u):
        '''Invert the ``scale`` standardization: return ``u * std + mean``.'''
        return u * self.std + self.mean



        



