import torch
import torch.nn.functional as F
from dataloaders.sequential_dataset import SequentialDataSet
from torch.utils.data import DataLoader, Dataset
import os
import glob
import h5py
import numpy as np
import einops
from einops import rearrange
import os


class Advection1D(SequentialDataSet):
    r'''Generates analytic trajectories for 1D linear advection.

    Every trajectory is the exact solution u(x, t) = u0(x - beta t) of

        \partial_t u + \beta \partial_x u = 0,

    sampled on x in [0, 1) and t in [0, 1) (both grids exclude the endpoint 1).
    The initial condition is a random superposition of sine waves

        u0(x) = sum_m A_m sin(2 pi n_m k_base x),   m = 1..Nk,

    where the amplitudes A_m are drawn uniformly from [0, 1) and normalised to
    sum to 1 per sample, and the integer multipliers n_m are drawn uniformly
    with 1 <= n_m < nk_max (``np.random.randint`` excludes its upper bound).
    Randomness comes from the global NumPy RNG (``np.random``); seed it for
    reproducible datasets. The same trajectories are also evaluated on a 4x
    finer spatial grid.

    Attributes set by ``__init__`` (N = number of samples,
    Sx = spatial_resolution, T = timesteps):
        u:    float32 array (N, Sx, T, 1), solution on the base grid
        u_hr: float32 array (N, 4 * Sx, T, 1), solution on the 4x finer grid
        x:    float32 array (N, Sx, 1), base spatial grid repeated per sample
        t:    float32 array (T,), time grid

    :param num_samples: int, number of trajectories to generate
    :param spatial_resolution: int, number of base-grid points Sx
    :param timesteps: int, number of time points T
    :param k_base: float, base wavenumber
    :param nk_max: int, exclusive upper bound on the multipliers n_m; must be >= 2
    :param Nk: int, number of superimposed waves per sample
    :param beta: float, advection speed
    :param if_test: bool, if True only int(num_samples * test_fraction) samples are generated
    :param test_fraction: float, fraction of num_samples used when if_test is True
    :param transformation: None or one of "abs", "square", "cube", "sigmoid", "exp";
        pointwise map applied to u and u_hr after the analytic solution is computed
    :raises ValueError: if nk_max < 2 or the transformation is not recognised
    '''
    def __init__(self,
                 num_samples,
                 spatial_resolution,
                 timesteps,
                 k_base,
                 nk_max,
                 Nk,
                 beta=0.07,
                 if_test=False,
                 test_fraction=0.1,
                 transformation=None,
                 ):
        if nk_max < 2:
            # randint(1, nk_max) below needs the non-empty range [1, nk_max).
            raise ValueError(f"nk_max must be >= 2, got {nk_max}")
        if if_test:
            num_samples = int(num_samples * test_fraction)

        x = np.linspace(0, 1, spatial_resolution, endpoint=False, dtype=np.float32)
        x_hr = np.linspace(0, 1, spatial_resolution * 4, endpoint=False, dtype=np.float32)
        t = np.linspace(0, 1, timesteps, endpoint=False, dtype=np.float32)

        # Nk random amplitudes per sample, normalised to sum to 1.
        A = np.random.rand(num_samples, Nk).astype(np.float32)
        A = A / A.sum(axis=1, keepdims=True)

        # Wavenumbers n * k_base with integer n in [1, nk_max).
        k = np.random.randint(1, nk_max, size=(num_samples, Nk)) * k_base
        k = k[:, :, np.newaxis, np.newaxis]  # (N, Nk, 1, 1), broadcasts over the (T, S) grid

        u = self._advected_superposition(A, k, x, t, beta)        # (N, Sx, T, 1)
        u_hr = self._advected_superposition(A, k, x_hr, t, beta)  # (N, 4*Sx, T, 1)

        u = self._apply_transformation(u, transformation)
        u_hr = self._apply_transformation(u_hr, transformation)

        self.u = u.astype(np.float32)
        self.x = np.repeat(x[np.newaxis, :, np.newaxis], num_samples, axis=0).astype(np.float32)
        self.t = t.astype(np.float32)
        self.u_hr = u_hr.astype(np.float32)

    @staticmethod
    def _advected_superposition(A, k, x, t, beta):
        '''Evaluate sum_m A[n, m] * sin(2 pi k[n, m] (x - beta t)) on the (x, t) grid.

        :param A: (N, Nk) amplitudes
        :param k: (N, Nk, 1, 1) wavenumbers
        :param x: (S,) spatial grid
        :param t: (T,) time grid
        :param beta: float, advection speed
        :return: (N, S, T, 1) array
        '''
        xx, tt = np.meshgrid(x, t)  # both (T, S)
        waves = np.sin(2 * np.pi * k * (xx - beta * tt))  # (N, Nk, T, S)
        u = np.einsum("nm,nmts->nts", A, waves)  # (N, T, S)
        return u.transpose(0, 2, 1)[..., np.newaxis]  # (N, S, T, 1)

    @staticmethod
    def _apply_transformation(u, transformation):
        '''Apply the named pointwise transformation to ``u`` (None leaves it unchanged).

        :raises ValueError: if ``transformation`` is not one of the supported names
        '''
        if transformation is None:
            return u
        if transformation == "abs":
            return np.abs(u)
        if transformation == "square":
            return u**2
        if transformation == "cube":
            return u**3
        if transformation == "sigmoid":
            return 1/(1+np.exp(-u))
        if transformation == "exp":
            return np.exp(u)
        raise ValueError(f"Transformation not recognized {transformation}")

    def input_shape(self):
        '''Return the per-sample input shape (Sx, T, V), where:
        Sx = spatial dimension length
        T = number of timesteps
        V = number of state variables (1 for this dataset)
        :return: tuple
        '''
        return self.u.shape[1:]

    def __len__(self):
        '''Return the number of generated trajectories.'''
        return len(self.u)

    def __getitem__(self, idx):
        '''Return the sample(s) at ``idx`` as ``(u, grid)``.

        :return u: self.u[idx]; for an integer idx, shape (Sx, T, 1)
        :return grid: self.x[idx]; for an integer idx, shape (Sx, 1)
        Sx = spatial dimension length, T = number of timesteps.
        '''
        return self.u[idx], self.x[idx]




class Diffusion1D(SequentialDataSet):
    r'''Generates analytic trajectories for the 1D linear diffusion (heat) equation.

    Every trajectory is a superposition of separable modes

        u(x, t) = sum_m A_m sin(2 pi k_m x) exp(dc (2 pi k_m)^2 t),   m = 1..Nk,

    which is the exact solution of

        \partial_t u = -dc \partial_{xx} u,

    so the diffusion coefficient is ``-dc``. The default dc = -1.0 gives
    \partial_t u = \partial_{xx} u, and each mode decays in time. A positive dc
    makes each mode grow exponentially instead.
    Grids are x in [0, 1) and t in [0, 1) (both exclude the endpoint 1).
    The amplitudes A_m are drawn uniformly from [0, 1) and normalised to sum to 1
    per sample. The wavenumbers are k_m = n_m * k_base, with integer n_m drawn
    uniformly from [max(1, nk_max - dnk), nk_max); ``np.random.randint``
    excludes its upper bound. Randomness comes from the global NumPy RNG
    (``np.random``). The same trajectories are also evaluated on a 4x finer
    spatial grid.

    Attributes set by ``__init__`` (N = number of samples,
    Sx = spatial_resolution, T = timesteps):
        u:    float32 array (N, Sx, T, 1), solution on the base grid
        u_hr: float32 array (N, 4 * Sx, T, 1), solution on the 4x finer grid
        x:    float32 array (N, Sx, 1), base spatial grid repeated per sample
        t:    float32 array (T,), time grid

    :param num_samples: int, number of trajectories to generate
    :param spatial_resolution: int, number of base-grid points Sx
    :param timesteps: int, number of time points T
    :param k_base: float, base wavenumber
    :param nk_max: int, exclusive upper bound on the multipliers n_m
    :param Nk: int, number of superimposed waves per sample
    :param dnk: int, width of the multiplier range below nk_max; -1 means nk_max
        (i.e. the full range 1 <= n_m < nk_max)
    :param dc: float, minus the diffusion coefficient (see equation above)
    :param if_test: bool, if True only int(num_samples * test_fraction) samples are generated
    :param test_fraction: float, fraction of num_samples used when if_test is True
    :param transformation: None or one of "abs", "square", "cube", "sigmoid", "exp";
        pointwise map applied to u and u_hr after the analytic solution is computed
    :raises ValueError: if the multiplier range is empty (nk_max < 2 or dnk < 1)
        or the transformation is not recognised
    '''
    def __init__(self,
                 num_samples,
                 spatial_resolution,
                 timesteps,
                 k_base,
                 nk_max,
                 Nk,
                 dnk=-1,  # width of the range of wavenumber multipliers; -1 -> nk_max
                 dc=-1.0,
                 if_test=False,
                 test_fraction=0.1,
                 transformation=None,
                 ):
        if if_test:
            num_samples = int(num_samples * test_fraction)
        if dnk == -1:
            dnk = nk_max
        # Clamp the lowest multiplier at 1. An n = 0 mode is sin(0) == 0 everywhere:
        # it silently absorbs part of the unit amplitude budget and can make a whole
        # sample identically zero.
        n_low = max(1, nk_max - dnk)
        if n_low >= nk_max:
            raise ValueError(
                f"empty wavenumber multiplier range [{n_low}, {nk_max}); "
                f"need nk_max >= 2 and dnk >= 1"
            )

        x = np.linspace(0, 1, spatial_resolution, endpoint=False)
        x_hr = np.linspace(0, 1, spatial_resolution * 4, endpoint=False)
        t = np.linspace(0, 1, timesteps, endpoint=False)

        # Nk random amplitudes per sample, normalised to sum to 1.
        A = np.random.rand(num_samples, Nk)
        A = A / A.sum(axis=1, keepdims=True)

        k = np.random.randint(n_low, nk_max, size=(num_samples, Nk)) * k_base
        k = k[:, :, np.newaxis, np.newaxis]  # (N, Nk, 1, 1), broadcasts over the (T, S) grid

        u = self._decaying_superposition(A, k, dc, x, t)        # (N, Sx, T, 1)
        u_hr = self._decaying_superposition(A, k, dc, x_hr, t)  # (N, 4*Sx, T, 1)

        u = self._apply_transformation(u, transformation)
        u_hr = self._apply_transformation(u_hr, transformation)

        self.u = u.astype(np.float32)
        self.u_hr = u_hr.astype(np.float32)
        self.x = np.repeat(x[np.newaxis, :, np.newaxis], num_samples, axis=0).astype(np.float32)
        self.t = t.astype(np.float32)

    @staticmethod
    def _decaying_superposition(A, k, dc, x, t):
        '''Evaluate sum_m A[n, m] sin(2 pi k[n, m] x) exp(dc (2 pi k[n, m])^2 t) on the grid.

        :param A: (N, Nk) amplitudes
        :param k: (N, Nk, 1, 1) wavenumbers
        :param dc: float, exponent coefficient (minus the diffusion coefficient)
        :param x: (S,) spatial grid
        :param t: (T,) time grid
        :return: (N, S, T, 1) array
        '''
        xx, tt = np.meshgrid(x, t)  # both (T, S)
        modes = np.sin(2 * np.pi * k * xx) * np.exp(dc * (2 * np.pi * k) ** 2 * tt)  # (N, Nk, T, S)
        u = np.einsum("nm,nmts->nts", A, modes)  # (N, T, S)
        return u.transpose(0, 2, 1)[..., np.newaxis]  # (N, S, T, 1)

    @staticmethod
    def _apply_transformation(u, transformation):
        '''Apply the named pointwise transformation to ``u`` (None leaves it unchanged).

        :raises ValueError: if ``transformation`` is not one of the supported names
        '''
        if transformation is None:
            return u
        if transformation == "abs":
            return np.abs(u)
        if transformation == "square":
            return u**2
        if transformation == "cube":
            return u**3
        if transformation == "sigmoid":
            return 1/(1+np.exp(-u))
        if transformation == "exp":
            return np.exp(u)
        raise ValueError(f"Transformation not recognized {transformation}")

    def input_shape(self):
        '''Return the per-sample input shape (Sx, T, V), where:
        Sx = spatial dimension length
        T = number of timesteps
        V = number of state variables (1 for this dataset)
        :return: tuple
        '''
        return self.u.shape[1:]

    def __len__(self):
        '''Return the number of generated trajectories.'''
        return len(self.u)

    def __getitem__(self, idx):
        '''Return the sample(s) at ``idx`` as ``(u, grid)``.

        :return u: self.u[idx]; for an integer idx, shape (Sx, T, 1)
        :return grid: self.x[idx]; for an integer idx, shape (Sx, 1)
        Sx = spatial dimension length, T = number of timesteps.
        '''
        return self.u[idx], self.x[idx]

    def unscale_data(self, u):
        '''Undo an affine scaling ``(u - mean) / std`` of the data.

        This class never sets ``std``/``mean`` or rescales ``self.u``. When those
        attributes are absent, this falls back to std = 1 and mean = 0 (identity)
        instead of raising AttributeError.
        '''
        std = getattr(self, "std", 1.0)
        mean = getattr(self, "mean", 0.0)
        return u * std + mean