import torch
import torch.nn.functional as F
from dataloaders.sequential_dataset import SequentialDataSet
from torch.utils.data import DataLoader, Dataset
import os
import glob
import h5py
import numpy as np
import einops
from einops import rearrange
import os

import h5py
import numpy as np

from scipy.fftpack import diff as psdiff
from scipy.integrate import solve_ivp
from scipy.linalg import expm
import json
import time
import os


class ShiftPDE:
    r'''Right-hand side of \partial_t u = -C \partial_xx u + B cos(x) u on a 2*pi-periodic domain.

    :param k_max: highest cosine frequency used by ``initial_condition``
    :param B: coefficient of the ``cos(x) u`` term
    :param C: coefficient of the ``-\partial_xx u`` term
    '''
    def __init__(self, k_max, B, C):
        self.k_max = k_max
        self.B = B
        self.C = C

    def initial_condition(self, x):
        '''Random initial condition sum_{k=0}^{k_max} A_k cos(k x) with A_k ~ N(0, 1).

        :param x: array of spatial points
        :return: float32 array with the same shape as ``x``
        '''
        # One amplitude per frequency 0..k_max (k_max + 1 values), so that zip()
        # below pairs every frequency with an amplitude instead of dropping mode k_max.
        k = np.arange(self.k_max + 1, dtype=np.float32)
        A = np.random.randn(k.size).astype(np.float32)

        u0 = np.zeros_like(x)
        for amplitude, frequency in zip(A, k):
            u0 += amplitude * np.cos(frequency * x)
        return u0.astype(np.float32)

    def derivative(self, t, u, x):
        '''Evaluate du/dt = -C u_xx + B cos(x) u.

        :param t: time; unused (the equation is autonomous) but required by solve_ivp
        :param u: samples of u on the grid ``x``
        :param x: spatial grid points
        :return: array of du/dt values, same shape as ``u``
        '''
        # Spectral (FFT) second derivative; assumes u is sampled uniformly over
        # exactly one 2*pi period.
        uxx = psdiff(u, period=2 * np.pi, order=2)
        return -self.C * uxx + self.B * np.cos(x) * u
    

def generate_trajectory(PDE, x, t):
    '''Integrate ``PDE`` from its random initial condition with solve_ivp (RK45).

    :param PDE: object providing ``initial_condition(x)`` and ``derivative(t, u, x)``
    :param x: spatial grid passed to both PDE callbacks
    :param t: output times; integration runs from ``t[0]`` to ``t[-1]``
    :return: ``sol.y`` from solve_ivp, i.e. one column per entry of ``t``
    :raises ValueError: with the solver message if solve_ivp does not finish successfully
    '''
    initial_state = PDE.initial_condition(x)

    def rhs(time, state):
        return PDE.derivative(time, state, x)

    solution = solve_ivp(rhs, [t[0], t[-1]], initial_state, t_eval=t, method='RK45')
    if solution.status == 0:
        return solution.y
    raise ValueError(solution.message)

def to_cosines(A, x):
    '''Evaluate the cosine series 2*A[0] + sum_{i>=1} A[i] cos(i x) at the points ``x``.

    NOTE(review): the constant coefficient is weighted by 2; confirm this
    convention matches how the coefficients ``A`` are produced.

    :param A: sequence of cosine coefficients (must be non-empty)
    :param x: array of evaluation points
    :return: array shaped like ``x`` (same dtype as ``x``)
    '''
    series = np.zeros_like(x)
    series += 2 * A[0]
    for mode, coefficient in enumerate(A[1:], start=1):
        series += coefficient * np.cos(mode * x)
    return series

def generate_trajectory_exponential(PDE, x, t, N, K):
    '''Generate one trajectory by time-stepping a truncated cosine-coefficient model.

    A random initial spectrum with ``PDE.k_max + 1`` standard-normal coefficients
    (modes 0..k_max) is zero-padded to ``N`` modes and advanced with explicit
    forward-Euler steps ``u <- u + dt * L @ u`` on the time grid ``t``. Despite
    the name, no matrix exponential is used.

    ``L`` has ``C * n**2`` on its diagonal and ``B`` on its +/-K off-diagonals,
    with ``L[0, K]`` set to ``2 * B``.

    :param PDE: object providing ``B``, ``C`` and ``k_max`` (e.g. ``ShiftPDE``)
    :param x: 1-D array of spatial points at which the solution is evaluated
    :param t: 1-D array of output times; consecutive differences are the Euler steps
    :param N: number of cosine modes kept
    :param K: mode offset of the ``B`` couplings
    :return: float32 array of shape ``(len(t), len(x))`` (time, space)
    :raises ValueError: if ``N < k_max + 1`` or ``K`` is not in ``[0, N)``
    '''
    B = PDE.B
    C = PDE.C
    k_max = PDE.k_max
    spatial_resolution = len(x)
    timesteps = len(t)
    if N < k_max + 1:
        raise ValueError(f"N={N} must be at least k_max + 1 = {k_max + 1}")
    if not 0 <= K < N:
        raise ValueError(f"K={K} must satisfy 0 <= K < N={N}")

    # Random initial spectrum for modes 0..k_max, zero-padded to N modes.
    A = np.random.randn(k_max + 1).astype(np.float32)
    u0f_ = np.zeros(N, dtype=np.float32)
    u0f_[:A.shape[0]] = A

    # Generator matrix: C * n**2 on the diagonal, B on the +/-K off-diagonals.
    # The K argument is honored (the cache file name of callers depends on it).
    L = np.diag(C * np.arange(N) ** 2).astype(np.float32)
    L += np.diag(B * np.ones(N - K).astype(np.float32), k=K)
    L += np.diag(B * np.ones(N - K).astype(np.float32), k=-K)
    L[0, K] = 2 * B

    # Forward-Euler integration. States are collected in a list so the result
    # dtype follows the precision of t (float32 stays float32).
    uf = [u0f_]
    for i in range(1, timesteps):
        dt = t[i] - t[i - 1]
        uf.append(uf[-1] + dt * L @ uf[-1])
    uf = np.array(uf)

    # Evaluate each spectral state on the spatial grid.
    ue = np.zeros((timesteps, spatial_resolution), dtype=np.float32)
    for i in range(timesteps):
        ue[i, :] = to_cosines(uf[i, :], x)
    return ue



class ShiftDataset(SequentialDataSet):
    r'''Trajectories produced by ``generate_trajectory_exponential`` for a
    ``ShiftPDE(k_max, B, C)`` configuration, cached in an HDF5 file in ``data_dir``.

    ``__getitem__(idx)`` returns ``(u, x)``:
    :return u: (S, T, 1) solution samples of trajectory ``idx``
    :return x: (S, 1) spatial grid points in [0, 2*pi)
    S = spatial dimension length (after ``reduced_resolution`` subsampling)
    T = number of stored timesteps

    :param num_samples: int, number of trajectories to generate
    :param generation_resolution: int, number of spatial points on [0, 2*pi) used for generation
    :param reduced_resolution: int, spatial subsampling stride applied to ``u`` and ``x``
    :param timesteps: int, target number of stored timesteps. The generation time grid is
        subsampled with stride ``generation_timesteps // timesteps``, so exactly ``timesteps``
        samples are stored only when ``generation_timesteps`` is a multiple of ``timesteps``
    :param generation_timesteps: int, number of points of the generation time grid on [0, T]
    :param T: float, final time
    :param B, C, k_max: forwarded to ``ShiftPDE``
    :param N, K: forwarded to ``generate_trajectory_exponential``
    :param data_dir: str, directory holding the HDF5 cache file
    :param if_test, test_fraction: only used to build the cache file name; no split is done here
    :raises ValueError: when generating, if ``timesteps > generation_timesteps``
    '''
    def __init__(self,
                num_samples,
                generation_resolution,
                reduced_resolution,
                timesteps,
                generation_timesteps,
                T,
                B,
                C,
                k_max,
                N,
                K,
                data_dir,
                if_test=False,
                test_fraction=0.1,
                ):
        # Attributes set below:
        #   self.dataset: full-resolution data, (num_samples, generation_resolution, T, 1)
        #   self.u:       spatially subsampled data, (num_samples, S, T, 1)
        #   self.x:       spatial grid, (S, 1)
        fname = f"shift_{num_samples}_{generation_resolution}_{reduced_resolution}_{timesteps}_{generation_timesteps}_{T}_{B}_{C}_{k_max}_{N}_{K}_{if_test}_{test_fraction}.h5"
        fpath = os.path.join(data_dir, fname)
        if os.path.exists(fpath):
            print("Loading dataset from", fpath)
            with h5py.File(fpath, 'r') as f:
                self.dataset = f['dataset'][:]
                self.u = f['u'][:]
                self.x = f['x'][:]
        else:
            # Temporal subsampling stride; validated before the (potentially slow)
            # generation loop instead of failing afterwards on a zero slice step.
            reduced_t = generation_timesteps // timesteps
            if reduced_t < 1:
                raise ValueError(
                    f"timesteps ({timesteps}) must not exceed generation_timesteps ({generation_timesteps})"
                )

            shiftPDE = ShiftPDE(k_max, B, C)

            x = np.linspace(0, 2 * np.pi, generation_resolution, endpoint=False, dtype=np.float32)
            gen_t = np.linspace(0, T, generation_timesteps, endpoint=True, dtype=np.float32)

            dataset = np.zeros((num_samples, generation_resolution, generation_timesteps), dtype=np.float32)
            for i in range(num_samples):
                # generate_trajectory_exponential returns (time, space); transpose to
                # (space, time) to match the (N, Sx, T) layout of `dataset`.
                dataset[i] = generate_trajectory_exponential(shiftPDE, x, gen_t, N, K).T

            dataset = dataset[..., ::reduced_t]  # (N, Sx, T)
            dataset = dataset[..., np.newaxis]  # (N, Sx, T, 1)
            self.dataset = dataset
            self.u = dataset[:, ::reduced_resolution, :, :]
            self.x = x[::reduced_resolution, np.newaxis]

            os.makedirs(data_dir, exist_ok=True)
            print("Saving dataset to", fpath)
            with h5py.File(fpath, 'w') as f:
                f.create_dataset('dataset', data=self.dataset)
                f.create_dataset('u', data=self.u)
                f.create_dataset('x', data=self.x)

    def input_shape(self):
        '''Returns the per-sample shape of ``u`` as a tuple (Sx, [Sy], [Sz], T, V), where:
        Sx, [Sy], [Sz] = spatial dimension lengths (only Sx for this 1-D dataset)
        T = number of timesteps
        V = number of state variables
        :return: tuple
        '''
        return self.u.shape[1:]

    def __len__(self):
        '''Number of trajectories.'''
        return len(self.u)

    def __getitem__(self, idx):
        '''Returns (u, x) for trajectory ``idx``.
        :return u: (S, T, 1) solution samples
        :return x: (S, 1) spatial grid, shared by all trajectories
        S = spatial dimension length
        T = number of timesteps
        '''
        return self.u[idx], self.x


