import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.nn import functional as F
from typing import Tuple
#from common.augmentation import to_coords
import random
from sentence_transformers import SentenceTransformer, InputExample
from transformers import BertModel

import transformers
#from models.llama3.reference_impl.generation import Llama
from transformers import AutoTokenizer, AutoModel

from tqdm import tqdm

class PDEDataset(Dataset):
    """Load samples of a 1D PDE dataset stored in HDF5, get items according to PDE.

    Every sample is exposed ``len(augmentation) + 1`` times so that each
    augmentation can be applied to it; see ``__getitem__``.
    """

    def __init__(self,
                 path: str,
                 pde: str,
                 mode: str,
                 resolution: list=None,
                 augmentation = None,
                 augmentation_ratio: float=0.0,
                 shift: str='fourier',
                 load_all: bool=False,
                 num_samples: int=-1,
                 device: str = 'cuda:0',
                 clip: bool=False) -> None:
        """Initialize the dataset object.

        Args:
            path: path to the HDF5 dataset; it must contain a group named
                ``mode`` with datasets ``u``, ``alpha``, ``beta``, ``gamma``,
                ``x`` and ``t``
            pde: string of PDE
            mode: [train, valid, test]
            resolution: resolution of the dataset [nt, nx]; stored only,
                defaults to (250, 100)
            augmentation: list of augmentation callables ``(sol, shift) -> sol``
                with ``sol = (u, coords)``; defaults to no augmentation
            augmentation_ratio: probability of augmenting a training sample
            shift: shift mode forwarded to the augmentations
            load_all: convert all data to tensors on ``device`` and close the file
            num_samples: number of samples to read; -1 (default) reads all
            device: device used when ``load_all`` is set
            clip: sentence-embedding mode; not implemented for the 1D dataset
        Returns:
            None
        Raises:
            NotImplementedError: if ``clip`` is True.
        """
        super().__init__()
        # The sentence-embedding (CLIP) path was never implemented for 1D data;
        # fail fast before opening the file instead of hitting a bare ``raise``.
        self.clip = clip
        if self.clip:
            raise NotImplementedError("clip=True is not supported by PDEDataset")

        f = h5py.File(path, 'r')
        self.mode = mode
        self.pde = pde
        self.resolution = (250, 100) if resolution is None else resolution
        self.data = f[self.mode]

        self.num_samples = len(self.data["u"]) if num_samples == -1 else num_samples
        self.u = self.data["u"][:self.num_samples]
        self.length = len(self.u)
        self.alpha = self.data["alpha"][:self.num_samples]
        self.beta = self.data["beta"][:self.num_samples]
        self.gamma = self.data["gamma"][:self.num_samples]

        # x and t are shared grids (treated as 1D below: x[0], x[-1], len(x)),
        # so they must not be truncated to the number of samples.
        self.x = torch.tensor(np.array(self.data["x"][:]))
        self.t = torch.tensor(np.array(self.data["t"][:]))

        self.tmin = self.t[0]
        self.tmax = self.t[-1]
        self.nt = len(self.t)
        # NOTE(review): PDEDataset2D divides by (nt - 1) for dt; confirm which
        # spacing convention is intended here.
        self.dt = (self.tmax - self.tmin) / self.nt

        self.xmin = self.x[0]
        self.xmax = self.x[-1]
        self.nx = len(self.x)
        self.dx = (self.xmax - self.xmin) / self.nx

        self.augmentation = [] if augmentation is None else augmentation
        self.shift = shift
        self.augmentation_ratio = augmentation_ratio

        self.device = device

        if load_all:
            self.u = torch.tensor(self.u[:]).to(device)
            self.alpha = torch.tensor(self.alpha[:]).to(device)
            self.beta = torch.tensor(self.beta[:]).to(device)
            self.gamma = torch.tensor(self.gamma[:]).to(device)
            self.x = self.x.to(device)
            self.t = self.t.to(device)

            f.close()
        # Otherwise the file stays open because ``self.data`` still refers to it.

    def __len__(self):
        """Number of samples times the number of copies per sample."""
        return self.length * (len(self.augmentation) + 1)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Get data item.

        Item ``idx`` refers to sample ``idx // (len(augmentation) + 1)`` and
        copy ``idx % (len(augmentation) + 1)`` of it. Copy ``k`` is augmented
        with ``augmentation[k]`` (training mode only, with probability
        ``augmentation_ratio``). The last copy has no augmentation of its own
        and is always returned unchanged.

        Args:
            idx (int): data index
        Returns:
            torch.Tensor: trajectory ``u`` (possibly augmented)
            torch.Tensor: spatial coordinates ``x``
            dict: equation specific parameters ``alpha``, ``beta``, ``gamma``
        """
        n_copies = len(self.augmentation) + 1
        t_idx = idx % n_copies
        idx = idx // n_copies
        u = self.u[idx]
        x = self.x

        variables = {
            'alpha': self.alpha[idx],
            'beta': self.beta[idx],
            'gamma': self.gamma[idx],
        }

        if self.mode == "train" and self.augmentation is not None:
            # random.random() is drawn first so the RNG stream is unchanged;
            # the index guard prevents augmentation[len(augmentation)] IndexErrors.
            if self.augmentation_ratio > random.random() and t_idx < len(self.augmentation):
                t = self.t
                # NOTE(review): ``to_coords`` must be provided by the module
                # (its import from common.augmentation is commented out).
                X = to_coords(x, t)

                if not torch.is_tensor(u):
                    u = torch.tensor(u)

                sol = (u, X)
                sol = self.augmentation[t_idx](sol, self.shift)
                u = sol[0]

        return u, x, variables
        
class PDEDataset2D(Dataset):
    """Load samples of a 2D PDE dataset stored in HDF5, get items according to PDE.

    Samples can be restricted to a subset of Heat / Advection / Burgers
    equations (see ``choose_subset``). They can optionally be paired with a
    natural-language description of their governing equation, given either as
    raw sentences or as sentence embeddings.
    """

    def __init__(self,
                 path: str,
                 pde: str,
                 mode: str,
                 resolution: list=None,
                 augmentation = None,
                 augmentation_ratio: float=0.0,
                 shift: str='fourier',
                 load_all: bool=False,
                 device: str='cuda:0',
                 num_samples: int=-1,
                 clip: bool=False,
                 llm: str=None,
                 sentence: bool=False,
                 downsample: int=1,
                 debug: bool=False,
                 subset: str='heat,adv,burger',
                 coeff: bool=True,
                 qualitative: bool=False
                 ) -> None:
        """Initialize the dataset object.

        Args:
            path: path to the HDF5 dataset; it must contain a group named
                ``mode`` with datasets ``u``, ``x``, ``t``, ``nu``, ``ax``,
                ``ay``, ``cx`` and ``cy``
            pde: string of PDE
            mode: [train, valid, test]
            resolution: base resolution of the dataset [nt, nx, ny]; stored
                only, defaults to (100, 64, 64)
            augmentation: data augmentation callable ``(u, pde, shift) -> u``
            augmentation_ratio: probability to augment a training sample
            shift: shift mode forwarded to ``augmentation``
            load_all: unused; all data is always loaded onto ``device``
            device: device holding the trajectories, coefficients and grids
            num_samples: number of samples drawn from the chosen subset;
                -1 (default) keeps all of them
            clip: return sentence embeddings alongside each sample
            llm: SentenceTransformer model name (default all-MiniLM-L6-v2)
            sentence: return raw, space-padded sentences instead of embeddings
            downsample: spatial stride applied to ``u`` and ``x``
            debug: only changes the informational ``num_samples`` of non-train splits
            subset: subset specification, see ``choose_subset``
            coeff: include coefficient values in the sentences
            qualitative: include a qualitative description in the sentences
        Returns:
            None
        Raises:
            ValueError: if a sample matches no known PDE, or a padded sentence
                is longer than 650 characters, while building sentences.
        """
        super().__init__()
        self.mode = mode
        self.pde = pde
        self.downsample = downsample
        self.resolution = (100, 64, 64) if resolution is None else resolution
        self.llm = llm
        self.debug = debug
        self.subset = subset
        self.device = device

        # Different embedding strategies
        self.clip = clip                # Use LLM
        self.sentence = sentence        # Return sentences and train LLM end-to-end
        self.coeff = coeff              # Use sentence information
        self.qualitative = qualitative  # Include qualitative information

        f = h5py.File(path, 'r')
        try:
            self.data = f[self.mode]

            if mode == 'train':
                self.num_samples = len(self.data["u"]) if num_samples == -1 else num_samples
            else:
                # NOTE(review): informational only; choose_subset below samples
                # according to the ``num_samples`` argument instead.
                self.num_samples = 50 if self.debug else 768  # Use entire validation set

            ds = self.downsample
            # Get data
            self.u = torch.Tensor(self.data["u"][:][..., ::ds, ::ds]).to(self.device)

            # Get coefficients
            self.nu = torch.Tensor(self.data["nu"][:]).to(self.device)
            self.ax = torch.Tensor(self.data["ax"][:]).to(self.device)
            self.ay = torch.Tensor(self.data["ay"][:]).to(self.device)
            self.cx = torch.Tensor(self.data["cx"][:]).to(self.device)
            self.cy = torch.Tensor(self.data["cy"][:]).to(self.device)
            # One row per sample: (nu, ax, ay, cx, cy)
            self.coeffs = torch.cat((self.nu.unsqueeze(0), self.ax.unsqueeze(0),
                                     self.ay.unsqueeze(0), self.cx.unsqueeze(0), self.cy.unsqueeze(0)), dim=0).T

            # Choose subset of data
            self.total_samples = len(self.u)
            self.choose_subset(self.subset, n=num_samples)

            # Get grid and time info
            self.x = torch.tensor(np.array(self.data["x"][:]))[..., ::ds, ::ds].to(self.device)
            self.t = torch.tensor(np.array(self.data["t"][:])).to(self.device)
        finally:
            # All needed data has been copied into tensors; close even on error.
            f.close()

        # Get potentially useful variables from space and time
        self.tmin = self.t[0]
        self.tmax = self.t[-1]
        self.nt = len(self.t)
        self.dt = torch.Tensor([(self.tmax - self.tmin) / (self.nt-1)]*len(self.t)).unsqueeze(0)

        self.xmin = self.x[0, 0, 0]
        self.xmax = self.x[0, 0, -1]
        self.nx = len(self.x[0, 0])
        self.dx = (self.xmax - self.xmin) / self.nx

        self.augmentation = augmentation
        self.shift = shift
        self.augmentation_ratio = augmentation_ratio

        if self.clip or self.sentence:
            self._build_sentences()

    def _build_sentences(self) -> None:
        """Fill ``self.sentences`` or ``self.sentence_embeddings`` for every chosen sample."""
        # Only get sentence_embedder if we're not returning whole sentences
        if self.llm is not None and not self.sentence:
            print("LOADING LLM TO GPU")
            self.sentence_embedder = SentenceTransformer(self.llm, device='cuda')
        elif not self.sentence:
            self.sentence_embedder = SentenceTransformer("all-MiniLM-L6-v2", device='cpu')

        self.sentence_embeddings = []
        self.sentences = []
        print("Getting sentence embeddings...")
        for idx in tqdm(self.indexes):
            sentence = self._describe_sample(idx)
            if self.sentence:
                # Pad with spaces to a common length so sentences batch together.
                sentence = sentence.ljust(400)
                if len(sentence) > 650:
                    raise ValueError("Sentence for sample {} is too long ({} characters)".format(idx, len(sentence)))
                self.sentences.append(sentence)
            else:
                self.sentence_embeddings.append(self.sentence_embedder.encode(sentence))
        print("Done.")

    def _describe_sample(self, idx) -> str:
        """Return the natural-language description of the PDE of sample ``idx``.

        Raises:
            ValueError: if the coefficients match none of Burgers, Advection or Heat
                (previously this silently reused the previous sample's sentence).
        """
        nu, ax, ay = self.nu[idx], self.ax[idx], self.ay[idx]
        cx, cy = self.cx[idx], self.cy[idx]

        # Burgers
        if nu != 0 and cx != 0 and cy != 0:
            sentence = 'Burgers equation models a conservative system that can develop shock wave discontinuities.'
            sentence += ' Burgers equation is a first order quasilinear hyperbolic partial differential equation.'
            if self.coeff:
                # Each fragment is formatted with its own values.
                sentence += ' In this case, the advection term has a coefficient of {} in the x direction, '.format(cx)
                sentence += '{} in the y direction, and the diffusion term has a coefficient of {}.'.format(cy, nu)

                if self.qualitative:
                    ratio = ((cx + cy)**2)**(0.5)/nu
                    cls = 'strongly' if(ratio > 100) else 'weakly'
                    sim = ' not ' if(ratio > 100) else ' '
                    sentence += ' This system is {} advection dominanted and does{}behave similarly to heat equation.'.format(cls, sim)
                    sentence += ' Ths predicted state should have shocks.' if(cls == 'strongly') else \
                                ' The predicted state should look smoother than the inputs'
        # Advection
        elif ax != 0 and ay != 0:
            sentence = 'The Advection equation models bulk transport of a substance or quantity. It does not develop shocks.'
            sentence += ' The Advection equation is a linear hyperbolic partial differential equation.'
            if self.coeff:
                sentence += ' In this case, the advection term has a coefficient of {} in the x direction, '.format(ax)
                sentence += '{} in the y direction.'.format(ay)

                if self.qualitative:
                    adv = ((ax + ay)**2)**(0.5)
                    cls = 'strongly' if(adv > 2) else 'weakly'
                    sentence += ' This system is {} advective.'.format(cls)
                    sentence += ' The predicted state should look like the input but shifted in space.'
        # Heat
        elif nu != 0 and cx == 0 and cy == 0:
            sentence = 'The Heat equation models how a quantity such as heat diffuses through a given region.'
            sentence += ' The Heat equation is a linear parabolic partial differential equation.'
            if self.coeff:
                sentence += ' In this case, the diffusion term has a coefficient of {}.'.format(nu)

                if self.qualitative:
                    cls = 'strongly' if(nu > 0.01) else 'weakly'
                    sentence += ' This system is {} diffusive.'.format(cls)
                    sentence += ' The predicted state should look smoother than the inputs.'
        else:
            raise ValueError("Could not identify the PDE of sample {}".format(idx))

        sentence += " This system has periodic boundary conditions."
        return sentence

    def __len__(self):
        """Number of samples in the chosen subset."""
        return len(self.indexes)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, list, torch.Tensor]:
        """
        Get data item.

        Args:
            idx (int): position within ``self.indexes``
        Returns:
            If ``clip`` and not ``sentence``: ``u`` with a trailing channel axis
            padded by 3 zero channels, ``x`` with its first axis moved last,
            ``self.coeffs[idx]``, ``dt`` and the sentence embedding.
            If ``clip`` and ``sentence``: ``u``, permuted ``x``, the coefficient
            dict and the raw sentence.
            Otherwise: ``u``, permuted ``x`` and the coefficient dict
            (keys ``nu``, ``ax``, ``ay``, ``cx``, ``cy``).
        """
        original_idx = idx
        idx = self.indexes[idx]
        u = self.u[idx]
        x = self.x

        variables = {
            'nu': self.nu[idx],
            'ax': self.ax[idx],
            'ay': self.ay[idx],
            'cx': self.cx[idx],
            'cy': self.cy[idx],
        }

        if self.mode == "train" and self.augmentation is not None:
            if self.augmentation_ratio > random.random(): # augment data w/ probability augmentation_ratio
                pde = self.get_PDE(variables)
                if not torch.is_tensor(u):
                    u = torch.tensor(u)
                u = self.augmentation(u, pde, self.shift)

        if self.clip and not self.sentence:
            # Allocate the padding next to ``u`` so the cat also works on GPU.
            padding = torch.zeros(u.shape[0], u.shape[1], u.shape[2], 3, dtype=u.dtype, device=u.device)
            return_u = torch.cat((u.unsqueeze(-1), padding), dim=-1)
            # NOTE(review): coeffs is indexed by position, which only matches
            # ``idx`` after get_data() has re-indexed self.coeffs; confirm.
            return return_u, \
                   x.permute(1,2,0), \
                   self.coeffs[original_idx], \
                   self.dt[0][0], \
                   self.sentence_embeddings[original_idx]

        elif self.clip and self.sentence:
            return u, x.permute(1,2,0), variables, self.sentences[original_idx]
        else:
            return u, x.permute(1,2,0), variables

    def get_data(self):
        """Return the chosen trajectories with a trailing channel axis.

        Side effects: re-indexes ``self.coeffs`` to the chosen subset and
        converts ``self.sentence_embeddings`` to a tensor (``None`` if they were
        never built). NOTE(review): calling this twice re-indexes coeffs twice.
        """
        self.coeffs = self.coeffs[self.indexes]
        try:
            self.sentence_embeddings = torch.Tensor(self.sentence_embeddings)
        except AttributeError:
            self.sentence_embeddings = None
        return self.u[self.indexes].unsqueeze(-1)

    def get_PDE(self, variables):
        """Name the PDE ('advection', 'burgers' or 'heat') from a coefficient dict.

        Raises:
            ValueError: if every relevant coefficient is zero.
        """
        if variables['ax'] != 0 and variables['ay'] != 0:
            return "advection"
        elif variables["cx"] != 0 and variables["cy"] != 0:
            return "burgers"
        elif variables["nu"] != 0:
            return "heat"
        else:
            raise ValueError("PDE not found")

    def choose_subset(
            self,
            chosen: str = 'heat,adv,burger',
            reverse: bool = False,
            n: int = None,
            ):
        """
        Choose subset of the dataset; sets ``self.indexes``.

        Args:
            chosen: str
                string of chosen PDEs and coefficient filters.
                DO NOT USE ANY SPACES!
                Example:
                    'heat,nu>0.5,adv,ax<0.4,burger,cx<0.3'

                Ranges:
                    nu:
                        - burgers: [7.5e-3, 1.5e-2]
                        - heat: [3e-3, 2e-2]
                    ax, ay: [0.1, 2.5]
                    cx, cy: [0.5, 1.0]
            reverse: bool.
                if True, choose all PDEs except the specified ones
            n: int or None
                number of samples drawn at random (np.random) from the subset;
                None or a negative value keeps the whole subset
        Returns:
            None
        """
        gs = chosen.split(',')

        def _no_samples():
            # All-False mask created on the coefficients' device so the filters
            # below never combine CPU and GPU tensors.
            return torch.zeros(self.total_samples, dtype=torch.bool, device=self.device)

        if 'adv' in gs:
            adv = ((self.ax!=0) | (self.ay!=0)) & ((self.cx==0) & (self.cy==0)) & (self.nu==0)
        else:
            adv = _no_samples()

        if 'burger' in gs:
            burger = ((self.ax==0) & (self.ay==0)) & ((self.cx!=0) | (self.cy!=0)) & (self.nu!=0)
        else:
            burger = _no_samples()

        if 'heat' in gs:
            heat = ((self.ax==0) & (self.ay==0)) & ((self.cx==0) & (self.cy==0)) & (self.nu!=0)
        else:
            heat = _no_samples()

        if 'ns' in gs:
            # NOTE(review): this class never sets ``visc``/``amp``; 'ns' raises
            # AttributeError here.
            ns = (self.visc != 0) & (self.amp != 0)
        else:
            ns = _no_samples()

        for g in gs:
            if '>' in g:
                attr, val = g.split('>')
                compare = torch.gt
            elif '<' in g:
                attr, val = g.split('<')
                compare = torch.lt
            else:
                continue
            # Unknown attributes are ignored.
            if attr in ['ax', 'ay']:
                adv = adv & compare(getattr(self, attr), float(val))
            elif attr in ['cx', 'cy']:
                burger = burger & compare(getattr(self, attr), float(val))
            elif attr in ['nu']:
                mask = compare(getattr(self, attr), float(val))
                burger = burger & mask
                heat = heat & mask

        which = heat.to(self.device) | adv.to(self.device) | burger.to(self.device) | ns.to(self.device)
        if reverse:
            which = ~which

        self.indexes = torch.arange(self.total_samples, device=which.device)[which]

        # Negative n (the constructor default is -1) means "keep everything";
        # np.random.choice would otherwise fail on a negative size.
        if type(n) is int and n >= 0:
            if n > len(self.indexes):
                print(f"You want {n} samples but there are only {len(self.indexes)} available. Overriding {n} to {len(self.indexes)}")
                self.num_samples = len(self.indexes)
                n = len(self.indexes)

            self.indexes = self.indexes[np.random.choice(len(self.indexes), n, replace=False)]

        # Check number of equations
        eq_dict = {"heat": 0, "adv": 0, "burgers": 0}
        for idx in self.indexes:
            eq_dict[self.get_eq(idx)] += 1

        print(eq_dict)

    def get_eq(self, idx):
        """Classify sample ``idx``: 'adv' if nu == 0, 'heat' if cx == 0, else 'burgers'."""
        nu = self.nu[idx]
        cx = self.cx[idx]

        if nu == 0:
            return "adv"
        if cx == 0:
            return "heat"
        else:
            return "burgers"


class BurgersPDEDataset2D(Dataset):
    """Load samples of a 2D PDE Dataset, get items according to PDE"""

    def __init__(self,
                 pde: str,
                 mode: str,
                 resolution: list=None,
                 augmentation = None,
                 augmentation_ratio: float=0.0,
                 shift: str='fourier',
                 load_all: bool=False,
                 device: str='cuda:0',
                 num_samples: int=-1,
                 clip: bool=False,
                 llm: str=None,
                 sentence: bool=False,
                 downsample: int=1,
                 debug: bool=False,
                 subset: str='heat,adv,burger',
                 coeff: bool=True,
                 bcs: bool=False,
                 qualitative: bool=False,
                 seed: int=0
                 ) -> None:
        """Initialize the dataset object
        Args:
            path: path to dataset
            pde: string of PDE 
            mode: [train, valid, test]
            resolution: base resolution of the dataset [nt, nx, ny]
            augmentation: Data augmentation object
            augmentation_ratio: Probability to augment data
            load_all: load all the data into memory
            device: if load_all, load data onto device
        Returns:
            None
        """
        super().__init__()

        print("\nSEED: {}\n".format(seed))
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.mode = mode
        self.pde = pde
        self.downsample = downsample
        self.resolution = (100, 64, 64) if resolution is None else resolution
        self.llm = llm
        self.debug = debug
        self.subset = subset
        self.device = device

        if(pde == 'heat'):
            #path = "/home/PATH_TO_DATA/Burgers_Data_Gen/heat_500/"
            path = "/home/PATH_TO_DATA/Burgers_Data_Gen/another_new_heat_900/"
            #path = "/home/PATH_TO_DATA/Burgers_Data_Gen/periodic_new_heat_120/"
        elif(pde == 'burger'):
            #path = "/home/PATH_TO_DATA/Burgers_Data_Gen/burger_500/"
            path = "/home/PATH_TO_DATA/Burgers_Data_Gen/another_new_burger_900/"
            #path = "/home/PATH_TO_DATA/Burgers_Data_Gen/periodic_new_burger_50/"

        #self.u = torch.Tensor(np.load("{}/all_data.npy".format(path))).unsqueeze(-1)
        self.u = torch.tensor(np.load("{}/all_data.npy".format(path)), device='cpu').unsqueeze(-1)
        self.t = np.load("{}/all_time.npy".format(path))
        self.coeffs = torch.Tensor(np.load("{}/all_coeff.npy".format(path)))
        self.bc_name = np.load("{}/all_bc_name.npy".format(path))
        self.bc_fac = torch.Tensor(np.load("{}/all_bc_fac.npy".format(path)))
        self.ic_name = np.load("{}/all_ic_name.npy".format(path))
        self.ic_fac = torch.Tensor(np.load("{}/all_ic_fac.npy".format(path)))
        self.x = torch.Tensor(np.load("{}/all_grid.npy".format(path)))

        # Different embedding strategies (Maybe rename to be more clear...)
        self.clip = clip                # Use LLM
        self.sentence = sentence        # Return sentences and train LLM end-to-end
        self.coeff = coeff              # Use sentence information
        self.bcs = bcs
        self.qualitative = qualitative  # Include qualitative information

        # TODO: Fix this
        self.nu = self.coeffs[:,0]
        self.cx = self.coeffs[:,1]
        self.cy = self.coeffs[:,2]

        #if(mode == 'train'):
        #    self.num_samples = len(self.u) if(num_samples == -1) else num_samples
        #else:
        #    self.num_samples = 50 if(self.debug) else 768 # Use entire validation set
        self.num_samples = len(self.u) if(num_samples == -1) else num_samples

        #idxs = torch.randperm(self.u.shape[0])

        #####
        ### Selecting various BC/IC combinations for debugging
        #####
        #bc_sub = self.bc_name != 'periodic'
        #bc_sub = self.bc_name == 'derivative'
        #bc_sub = self.bc_name == 'value'
        bc_sub = [True]*len(self.bc_name)

        #ic_sub = self.ic_name == 'exp'
        #ic_sub = self.ic_name != 'prod_sin'
        #ic_sub = self.ic_name == 'sum_sin'

        #ic_sub1 = self.ic_name == 'sum_sin'
        #ic_sub2 = self.ic_name == 'prod_sin'
        #ic_sub = np.logical_or(ic_sub1, ic_sub2)
        ic_sub = [True]*len(bc_sub)

        sub = np.logical_and(ic_sub, bc_sub)

        select_idxs = torch.arange(0, len(self.u))[sub].cpu()
        idxs = select_idxs[torch.randperm(len(select_idxs)).cpu()]

        # 80-20 split
        idxs = idxs[:self.num_samples]
        if(mode == 'train'):
            idxs = idxs[:int(0.8*len(idxs))]
        elif(mode == 'val'):
            idxs = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        else:
            idxs = idxs[int(0.9*len(idxs)):]


        # Slice data for split
        #self.u = self.u[idxs].cuda().permute(0,2,3,1,4)
        self.u = self.u[idxs].permute(0,2,3,1,4)
        print()
        print(self.u.min(), self.u.max())
        print()
        #self.u = self.u[idxs].cuda()#.permute(0,2,3,1,4)
        #self.len = self.u.shape[0]*self.u.shape[1]

        self.coeffs = self.coeffs[idxs].cuda()
        self.nu = self.nu[idxs].cuda()
        self.cx = self.cx[idxs].cuda()
        self.cy = self.cy[idxs].cuda()
        self.bc_name = self.bc_name[idxs.cpu()]
        self.bc_fac = self.bc_fac[idxs]
        self.ic_name = self.ic_name[idxs.cpu()]
        self.ic_fac = self.ic_fac[idxs]

        # Get potentially useful variables from space and time
        self.tmin = self.t[0]
        self.tmax = self.t[-1]
        self.nt = len(self.t)
        self.dt = torch.Tensor([(self.tmax - self.tmin) / (self.nt-1)]*len(self.t)).unsqueeze(0)

        self.xmin = self.x[0, 0, 0]
        self.xmax = self.x[0, 0, -1]
        self.nx = len(self.x[0, 0])
        self.dx = (self.xmax - self.xmin)/ self.nx
        
        self.augmentation = augmentation
        self.shift = shift
        self.augmentation_ratio = augmentation_ratio

        if(self.clip or self.sentence):

            # Only get sentence_embedder if we're not returning whole sentences
            ####if(self.llm is not None and not self.sentence):
            ####    print("LOADING LLM TO GPU")
            ####    self.sentence_embedder = SentenceTransformer(self.llm, device='cuda')
            ####elif(not self.sentence):
            ####    self.sentence_embedder = SentenceTransformer("all-MiniLM-L6-v2", device='cpu')
            if(self.llm == 'meta-llama/Meta-Llama-3.1-8B'):
                self.sentence_embedder = transformers.pipeline(
                    "feature-extraction",
                    #"summarization",
                    model="meta-llama/Meta-Llama-3.1-8B",
                    device="cuda",
                )
            else:
                self.sentence_embedder = SentenceTransformer(self.llm, device='cuda')

            self.sentence_embeddings = []
            self.sentences = []
            print("Getting sentence embeddings...")
            for idx in tqdm(range(len(self.u))):
            #for idx in idxs:
                # Burgers
                if(self.nu[idx] != 0 and self.cx[idx] != 0 and self.cy[idx] != 0):
                    ratio = ((self.cx[idx] + self.cy[idx])**2)**(0.5)/self.nu[idx]
                    sentence = 'Burgers equation models a conservative system that can develop shock wave discontinuities.'
                    sentence += ' Burgers equation is a first order quasilinear hyperbolic partial differential equation.'

                    #TODO Fix this too
                    if(self.bcs):
                        sentence += " This system has {} boundary conditions.".format(self.bc_name[idx])
                        if(self.bc_name[idx] == 'derivative'):
                            sentence += " Neumann boundary conditions have a constant gradient."
                            sentence += " In this case we have a gradient of {} on the boundary.".format(self.bc_fac[idx])
                        elif(self.bc_name[idx] == 'value'):
                            sentence += " Dirichlet boundary conditions have a constant value."
                            sentence += " In this case we have a value of {} on the boundary.".format(self.bc_fac[idx])
                        #elif(self.bc_name[idx] == 'auto_periodic_neumann'):
                        elif(self.bc_name[idx] == 'periodic'):
                            sentence += " The simulation space is a torus."
                        else:
                            raise ValueError("Issue with {} BC".format(self.bc_name[idx]))

                    if(self.coeff):
                        sentence += ' In this case, the advection term has a coefficient of {} in the x direction, '.format(self.cx[idx])
                        sentence += '{} in the y direction, and the diffusion term has a coefficient of {}.'.format(self.cy[idx],
                                                                                                          self.nu[idx])

                    if(self.qualitative):
                        cls = 'advection' if(ratio > 100) else 'diffusion'
                        sim = ' not ' if(ratio > 100) else ' '
                        sentence += ' This system is {} dominated and does{}behave similarly to heat equation.'.format(cls, sim)
                        sentence += ' The predicted state should develop shocks.' if(cls == 'strongly') else \
                                    ' The predicted state should look smoother than the inputs.'

                # Heat
                elif(self.nu[idx] != 0 and self.cx[idx] == 0 and self.cy[idx] == 0):
                    sentence = 'The Heat equation models how a quantity such as heat diffuses through a given region.'
                    sentence += ' The Heat equation is a linear parabolic partial differential equation.'
                    if(self.bcs):
                        sentence += " This system has {} boundary conditions.".format(self.bc_name[idx])
                        if(self.bc_name[idx] == 'derivative'):
                            sentence += " Neumann boundary conditions have a constant gradient."
                            sentence += " In this case we have a gradient of {} on the boundary.".format(self.bc_fac[idx])
                        elif(self.bc_name[idx] == 'value'):
                            sentence += " Dirichlet boundary conditions have a constant value."
                            sentence += " In this case we have a value of {} on the boundary.".format(self.bc_fac[idx])
                        #elif(self.bc_name[idx] == 'auto_periodic_neumann'):
                        elif(self.bc_name[idx] == 'periodic'):
                            sentence += " The simulation space is a torus."
                        else:
                            raise ValueError("Issue with {} BC".format(self.bc_name[idx]))

                    if(self.coeff):
                        sentence += ' In this case, the diffusion term has a coefficient of {}.'.format(self.nu[idx])

                    if(self.qualitative):
                        cls = 'strongly' if(self.nu[idx] > 0.01) else 'weakly'
                        sentence += ' This system is {} diffusive.'.format(cls)
                        sentence += ' The predicted state should look smoother than the inputs.'

                if(self.sentence):
                    self.sentences.append(sentence)
                else:
                    if(self.llm == 'meta-llama/Meta-Llama-3.1-8B'):
                        output = torch.Tensor(self.sentence_embedder(sentence))
                        self.sentence_embeddings.append(output.mean(dim=1))
                    else:
                        self.sentence_embeddings.append(torch.Tensor(self.sentence_embedder.encode(sentence)).unsqueeze(0))
            print("Done.")

            if(not self.sentence):
                self.sentence_embeddings = torch.cat(self.sentence_embeddings, dim=0)
                self.sentence_embeddings = self.sentence_embeddings.cuda()
        self.u = self.u.cuda()
        self.x = self.x.cuda()
        self.coeffs = self.coeffs.cuda()

        try:
            del self.sentence_embedder
        except AttributeError:
            pass
        torch.cuda.empty_cache()
        print("\nNUMBER OF SAMPLES: {}\n".format(len(self.u)))

    def __len__(self):
        """Return the number of stored trajectories, i.e. ``len(self.u)``."""
        samples = self.u
        return len(samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, list, torch.Tensor]:
        """
        Get data item
        Args:
            idx (int): data index
        Returns:
            When ``self.clip`` is truthy, a 5-tuple:
                ``self.u[idx]`` (trajectory of sample ``idx``),
                ``self.x`` (spatial grid, shared by all samples),
                ``self.coeffs[idx]`` (equation coefficients),
                ``self.dt[0][0]`` (time step),
                ``self.sentences[idx]`` if ``self.sentence`` is set, otherwise
                ``self.sentence_embeddings[idx]``.
            Otherwise a 4-tuple ``(u, x, coeffs[idx], coeffs[idx])``.
        """
        # The former ``return_u`` (u padded with 3 zero channels) was never
        # returned; it only cost an allocation + cat per call and raised for
        # non-4-D samples or when the zeros were not on u's device.
        u = self.u[idx]
        x = self.x
        coeffs = self.coeffs[idx]

        if self.clip:
            text = self.sentences[idx] if self.sentence else self.sentence_embeddings[idx]
            return u, x, coeffs, self.dt[0][0], text

        # NOTE(review): the fourth element repeats the coefficients, whereas the
        # other datasets in this file return ``dt`` here — confirm what callers expect.
        return u, x, coeffs, coeffs

    def get_data(self):
        """
        Apply the current ``self.indexes`` selection and return the selected trajectories.

        Side effects: ``self.coeffs`` is replaced by its rows at ``self.indexes``
        (so calling this twice indexes twice), and ``self.sentence_embeddings``
        is converted with ``torch.Tensor``, or set to ``None`` when the attribute
        does not exist. The embeddings themselves are not re-indexed.

        Returns:
            torch.Tensor: ``self.u[self.indexes]`` with a trailing singleton axis.
        """
        chosen = self.indexes
        self.coeffs = self.coeffs[chosen]
        try:
            self.sentence_embeddings = torch.Tensor(self.sentence_embeddings)
        except AttributeError:
            # ``sentence_embeddings`` was never set on this instance
            self.sentence_embeddings = None
        trajectories = self.u[chosen]
        return trajectories.unsqueeze(-1)
    

    def get_PDE(self, variables):
        """
        Name the PDE family described by a coefficient mapping.

        A family is recognised when ANY of its coefficient components is
        non-zero, matching the ``|`` tests used by ``choose_subset``:
        non-zero ``ax`` or ``ay`` -> advection; otherwise non-zero ``cx`` or
        ``cy`` -> burgers; otherwise non-zero ``nu`` -> heat.

        Args:
            variables: mapping with keys 'ax', 'ay', 'cx', 'cy' and 'nu'
        Returns:
            str: "advection", "burgers" or "heat"
        Raises:
            ValueError: if all of these coefficients are zero
        """
        # ``or`` (not ``and``): a single non-zero component is enough, otherwise
        # e.g. ay == 0 advection fell through to ValueError.
        if variables['ax'] != 0 or variables['ay'] != 0:
            return "advection"
        if variables["cx"] != 0 or variables["cy"] != 0:
            return "burgers"
        if variables["nu"] != 0:
            return "heat"
        raise ValueError("PDE not found")


    def choose_subset(
            self,
            chosen: str = 'heat,adv,burger',
            reverse: bool = False,
            n: int = None,
            ):
        """
        Restrict the dataset to a subset of equation families and coefficient ranges.

        The selected sample indices are stored in ``self.indexes``. Afterwards,
        the number of selected samples per family (as classified by
        ``self.get_eq``) is printed.

        Args:
            chosen: str
                Comma-separated tokens, DO NOT USE ANY SPACES.
                Family tokens: 'heat', 'adv', 'burger', 'ns'.
                Filter tokens: '<attr>><value>' or '<attr><<value>' with attr in
                {ax, ay, cx, cy, nu}; ax/ay filters restrict advection, cx/cy
                restrict Burgers, nu restricts both Burgers and heat. Filter
                tokens naming any other attribute are ignored.
                Example:
                    'heat,nu>0.5,adv,ax<0.4,burger,cx<0.3'

                Ranges:
                    nu:
                        - burgers: [7.5e-3, 1.5e-2]
                        - heat: [3e-3, 2e-2]
                    ax, ay: [0.1, 2.5]
                    cx, cy: [0.5, 1.0]
            reverse: bool
                If True, keep every sample that is NOT selected.
            n: int or None
                If an ``int``, draw ``n`` distinct samples (NumPy global RNG) from
                the selection. A value larger than the selection is clamped, and
                in that case ``self.num_samples`` is set to the selection size.
        Returns:
            None
        """
        tokens = chosen.split(',')
        total = self.total_samples

        def nothing():
            # All-False mask for a family that was not requested
            return torch.zeros(total).bool()

        adv = (((self.ax != 0) | (self.ay != 0)) & ((self.cx == 0) & (self.cy == 0)) & (self.nu == 0)
               if 'adv' in tokens else nothing())
        burger = (((self.ax == 0) & (self.ay == 0)) & ((self.cx != 0) | (self.cy != 0)) & (self.nu != 0)
                  if 'burger' in tokens else nothing())
        heat = (((self.ax == 0) & (self.ay == 0)) & ((self.cx == 0) & (self.cy == 0)) & (self.nu != 0)
                if 'heat' in tokens else nothing())
        ns = ((self.visc != 0) & (self.amp != 0)) if 'ns' in tokens else nothing()

        # Coefficient-range filters; '>' takes precedence when both symbols appear
        for token in tokens:
            if '>' in token:
                attr, val = token.split('>')
                above = True
            elif '<' in token:
                attr, val = token.split('<')
                above = False
            else:
                continue

            if attr not in ('ax', 'ay', 'cx', 'cy', 'nu'):
                continue

            field = getattr(self, attr)
            threshold = float(val)
            keep = field > threshold if above else field < threshold

            if attr in ('ax', 'ay'):
                adv = adv & keep
            elif attr in ('cx', 'cy'):
                burger = burger & keep
            else:
                # nu is shared by the two viscous families
                burger = burger & keep
                heat = heat & keep

        device = self.device
        which = heat.to(device) | adv.to(device) | burger.to(device) | ns.to(device)
        if reverse:
            which = ~which

        self.indexes = torch.arange(total, device=which.device)[which]

        if type(n) is int:
            available = len(self.indexes)
            if n > available:
                print(f"You want {n} samples but there are only {available} available. Overriding {n} to {available}")
                self.num_samples = available
                n = available
            self.indexes = self.indexes[np.random.choice(available, n, replace=False)]

        # Check number of equations
        counts = {"heat": 0, "adv": 0, "burgers": 0}
        for sample in self.indexes:
            counts[self.get_eq(sample)] += 1
        print(counts)


    def get_eq(self, idx):
        """
        Classify sample ``idx`` into an equation family.

        Returns:
            str: "adv" when ``nu`` is zero; otherwise "heat" when both Burgers
            coefficients ``cx`` and ``cy`` are zero; otherwise "burgers".
        """
        if self.nu[idx] == 0:
            return "adv"
        # Check both components, consistent with the burger mask in
        # ``choose_subset`` ((cx != 0) | (cy != 0)); checking only cx counted
        # cx == 0, cy != 0 Burgers samples as heat.
        if self.cx[idx] == 0 and self.cy[idx] == 0:
            return "heat"
        return "burgers"


class NSDataset2D(Dataset):
    """2D incompressible Navier-Stokes vorticity dataset read from a single HDF5 file.

    Every top-level group of the file holds ``samples_per_equation`` trajectories
    of one equation. The group name is split on ``"_"``; its first two fields are
    parsed as the viscosity and the forcing amplitude and stored in ``self.coeffs``.
    Trajectories, grid and time axis are moved to CUDA during construction.
    """

    def __init__(self,
                 initial_step=10,
                 saved_folder='./data/',
                 reduced_resolution=4,
                 reduced_resolution_t=1,
                 reduced_batch=1,
                 sim_time=-1,
                 mode="train",
                 test_ratio=0.1,
                 val_ratio=0.1,
                 num_samples=None,
                 return_text=False,
                 rollout_length=10,
                 split_style='initial_condition',
                 samples_per_equation=12,
                 seed=0,
                 device='cuda:0',
                 clip=None,
                 llm=None,
                 coeff=True,
                 bcs=False,
                 qualitative=False,
                 sentence=False,
                 ):
        """
        Load, shuffle and split the Navier-Stokes trajectories.

        :param initial_step: time steps taken as initial condition (stored only), defaults to 10
        :param reduced_resolution: spatial stride applied to ``u``, ``a``, ``X`` and ``Y``
        :param sim_time: stored only
        :param mode: 'train', 'val' or 'test'; any other value keeps every selected sample
        :param test_ratio: fraction of the selected samples placed in the test split
        :param val_ratio: fraction of the selected samples placed in the validation split
        :param num_samples: number of shuffled samples kept before splitting; ``None`` or a
            negative value keeps all of them, larger values are clamped to the total
        :param samples_per_equation: number of trajectories stored in each HDF5 group
        :param seed: seed passed to ``np.random.seed`` (reseeds the global NumPy RNG)
        :param clip: if truthy, ``__getitem__`` also returns the sentence/embedding
        :param llm: embedding model name used by ``get_sentences``
        :param coeff: include numerical coefficients in the sentences
        :param bcs: include boundary-condition text in the sentences
        :param qualitative: include qualitative viscosity/forcing text in the sentences
        :param sentence: keep raw sentences instead of computing embeddings

        ``saved_folder``, ``reduced_resolution_t``, ``reduced_batch`` and ``device`` are
        accepted for interface compatibility but are not used.
        """
        # Define path to files
        self.file_path = "/home/PATH_TO_DATA/2D_NS_DATA/2d_ns_30s_256_370eq.h5"
        self.return_text = return_text
        self.rollout_length = rollout_length
        self.split_style = split_style
        self.samples_per_equation = samples_per_equation

        # Time steps used as initial conditions
        self.initial_step = initial_step

        # Get sentence info
        self.clip = clip
        self.llm = llm
        self.bcs = bcs
        self.coeff = coeff
        self.qualitative = qualitative
        self.sentence = sentence
        self.sim_time = sim_time

        # One handle for everything, always closed (previously a second handle leaked).
        self.h5_file = h5py.File(self.file_path, 'r')
        try:
            # One group per equation
            self.data_list = list(self.h5_file.keys())
            self._select_indices(mode, test_ratio, val_ratio, num_samples, seed)
            self.data_list = np.array(self.data_list)
            self._load_samples(reduced_resolution)
        finally:
            self.h5_file.close()

        self.get_sentences()
        self.coeffs = self.coeffs.cuda()

        print("DATA SHAPE: {}".format(self.u.shape))
        print("NUM AVAILABLE IDXS: {}".format(len(self.available_idxs)))
        print("NUM IDXS: {}".format(len(self.idxs)))
        print("{} good samples.".format(len(self.u)))

    def _select_indices(self, mode, test_ratio, val_ratio, num_samples, seed):
        """Shuffle all (group, trajectory) indices, keep ``num_samples`` and apply the split."""
        total = len(self.data_list) * self.samples_per_equation
        # None / negative means "all"; the old code crashed on None and, for -1,
        # dropped the last sample and produced an empty train split.
        if num_samples is None or num_samples < 0:
            num_samples = total
        num_samples = min(num_samples, total)

        idxs = list(range(total))
        np.random.seed(seed)
        np.random.shuffle(idxs)
        self.idxs = idxs[:num_samples]

        # Split indices
        train_idx = int(num_samples * (1 - test_ratio - val_ratio))
        val_idx = int(num_samples * (1 - test_ratio))
        if mode == 'train':
            self.idxs = self.idxs[:train_idx]
        elif mode == 'val':
            self.idxs = self.idxs[train_idx:val_idx]
        elif mode == 'test':
            self.idxs = self.idxs[val_idx:]

    def _load_samples(self, reduced_resolution):
        """Read the selected trajectories, grid, time axes and coefficients from ``self.h5_file``."""
        if not self.idxs:
            raise ValueError("NSDataset2D: no samples selected for this split.")

        r = reduced_resolution
        spe = self.samples_per_equation
        n = len(self.idxs)

        self.time = []
        self.w0 = []
        self.available_idxs = []
        self.u = None
        self.coeffs = torch.empty(n, 2).float()
        grid = None

        for idx, i in tqdm(enumerate(self.idxs), total=n):
            dl_idx, tr_idx = divmod(i, spe)
            seed_group = self.h5_file[self.data_list[dl_idx]]

            # Read only trajectory ``tr_idx`` (not the whole group's array), then subsample
            data = seed_group['u'][tr_idx][::r, ::r, ...]
            w0 = seed_group['a'][tr_idx][::r, ::r, np.newaxis]

            # Prepend the initial condition along axis 2
            complete_data = np.concatenate((w0, data), axis=2)
            if self.u is None:
                # Sized from the data instead of the former hard-coded (64, 64, 101)
                self.u = torch.empty((n,) + complete_data.shape).float()
            self.u[idx] = torch.Tensor(complete_data)

            # Add initial time
            time = list(seed_group['t'][:])
            time.insert(0, 0.0)
            self.time.append(time)

            # Only the grid of the first selected sample is kept
            if grid is None:
                x = seed_group['X'][:][::r, ::r, np.newaxis]
                y = seed_group['Y'][:][::r, ::r, np.newaxis]
                grid = np.dstack((x, y))

            # Get coeffs from the group name
            split_dl = self.data_list[dl_idx].split("_")
            self.coeffs[idx][0] = float(split_dl[0])
            self.coeffs[idx][1] = float(split_dl[1])

        # Arrange data
        self.u = self.u.unsqueeze(-1).cuda()

        # Grid to tensor
        self.x = torch.Tensor(grid).cuda()

        # Time to tensor; dt from the first sample's first two time stamps
        self.time = torch.Tensor(np.array(self.time)).cuda()
        self.dt = (self.time[0][1] - self.time[0][0]).reshape((1, 1, 1)).cuda()

    def get_sentences(self):
        """Build one description per sample; store raw sentences or their embeddings."""
        # Only get sentence_embedder if we're not returning whole sentences
        if self.llm is not None and not self.sentence:
            print("LOADING LLM TO GPU")
            if self.llm == 'meta-llama/Meta-Llama-3.1-8B':
                self.sentence_embedder = transformers.pipeline(
                    "feature-extraction",
                    model="meta-llama/Meta-Llama-3.1-8B",
                    device="cuda",
                )
            else:
                self.sentence_embedder = SentenceTransformer(self.llm, device='cuda')
        elif not self.sentence:
            self.sentence_embedder = SentenceTransformer("all-MiniLM-L6-v2", device='cpu')

        self.sentence_embeddings = []
        self.sentences = []
        print("Getting sentence embeddings...")
        for idx in tqdm(range(len(self.u))):

            sentence = 'The incompressible Navier Stokes equations describe the motion of a viscous fluid with constant density.'
            sentence += ' We are predicting the vorticity field, which describes the local spinning motion of the fluid.'

            #TODO Fix this too
            if self.bcs:
                sentence += " This system has periodic boundary conditions."
                sentence += " The simulation space is a torus."

            if self.coeff:
                sentence += ' In this case, the viscosity is {0:.2e}.'.format(self.coeffs[idx][0])
                sentence += ' This system is driven by a forcing term of the form f(x,y) = A*(sin(2*pi*(x+y)) + cos(2*pi*(x+y)))'
                sentence += ' with amplitude A={0:.2e}.'.format(self.coeffs[idx][1])

            if self.qualitative:
                # Coefficients are rescaled before being compared to the thresholds
                visc = 1000000*self.coeffs[idx][0]
                amp = 1000 *self.coeffs[idx][1]

                # NOTE(review): the 'sytem'/'evolvution' typos are left as-is because
                # changing the text changes the embeddings.
                # Information based on viscosity
                if visc >= 1.:
                    sentence += ' This system has high viscosity and will not develop small scale structure.'
                elif visc >= 0.01:
                    sentence += ' This sytem has moderate viscosity and will have some small scale structure.'
                else:
                    sentence += ' This system has low viscosity and will have chaotic evolution with small scale structure.'

                # Information based on forcing term amplitude
                if amp >= 7.:
                    sentence += ' This system has a strong forcing term and evolution will be heavily influenced by it.'
                elif amp >= 3.:
                    sentence += ' This system has a moderate forcing term and evolution will be moderately influenced by it.'
                else:
                    sentence += ' This system has a weak forcing term and evolvution will be weakly influenced by it.'

            if self.sentence:
                self.sentences.append(sentence)
            elif self.llm == 'meta-llama/Meta-Llama-3.1-8B':
                # Mean-pool the pipeline's token features over dim 1
                output = torch.Tensor(self.sentence_embedder(sentence))
                self.sentence_embeddings.append(output.mean(dim=1))
            else:
                self.sentence_embeddings.append(torch.Tensor(self.sentence_embedder.encode(sentence)).unsqueeze(0))

        if not self.sentence:
            self.sentence_embeddings = torch.cat(self.sentence_embeddings, dim=0)
            self.sentence_embeddings = self.sentence_embeddings.cuda()

        # No embedder is created in sentence mode; an unconditional ``del`` raised AttributeError.
        try:
            del self.sentence_embedder
        except AttributeError:
            pass
        torch.cuda.empty_cache()
        print("Done.")

    def __len__(self):
        """Number of loaded trajectories."""
        return self.u.shape[0]

    def __getitem__(self, idx):
        '''
        Return ``(u[idx], x, coeffs[idx], dt)``; when ``self.clip`` is truthy, also
        the sentence (``self.sentence`` set) or the sentence embedding for ``idx``.
        '''
        if self.clip and not self.sentence:
            return self.u[idx], \
                   self.x, \
                   self.coeffs[idx], \
                   self.dt[0][0], \
                   self.sentence_embeddings[idx]
        elif self.clip and self.sentence:
            return self.u[idx], \
                   self.x, \
                   self.coeffs[idx], \
                   self.dt[0][0], \
                   self.sentences[idx]
        else:
            return self.u[idx], \
                   self.x, \
                   self.coeffs[idx], \
                   self.dt[0][0]


class MultiDataset(Dataset):
    """Concatenation of several 2D PDE datasets with a shared 5-entry coefficient layout.

    Samples of the sub-datasets are stored back to back in ``self.u`` and
    ``self.coeff`` (all of dataset 0, then all of dataset 1, ...). Per-dataset
    grids and time steps live in ``self.grids`` / ``self.dt``. Sentences or
    sentence embeddings stay on the sub-datasets and are looked up through a
    global-index -> (dataset, local index) map.
    """

    def __init__(self, filenames,
                 initial_step=10,
                 saved_folders=None,
                 reduced_resolution=1,
                 reduced_resolution_t=1,
                 reduced_batch=1,
                 mode='train',
                 if_test=False,
                 test_ratio=0.1,
                 num_samples_max = -1,
                 sim_time = -1,

                 clip = False,
                 llm = None,
                 sentence = False,
                 coeff = False,
                 eq_coeff = False,
                 qualitative = False,
                 time = False,

                 image_size = None,
                 bcs = False,
                 normalize = False,
                 seed = 0
                 ):
        """
        Args:
            filenames: dataset names; 'navierstokes' builds an ``NSDataset2D``, any other
                name is passed as ``pde`` to ``BurgersPDEDataset2D``.
            saved_folders: one entry per filename (only printed). ``None`` pairs every
                filename with ``None``; an explicit shorter list truncates ``filenames``.
            mode, num_samples_max, clip, llm, sentence, eq_coeff, qualitative, bcs, seed:
                forwarded to the sub-datasets (``eq_coeff`` as their ``coeff``).
            sim_time: first entry of the ``resolution`` passed to ``BurgersPDEDataset2D``.

        ``initial_step``, ``reduced_resolution``, ``reduced_resolution_t``,
        ``reduced_batch``, ``if_test``, ``test_ratio``, ``coeff``, ``time``,
        ``image_size`` and ``normalize`` are accepted but not used.
        """
        self.dsets, self.u, self.grids, self.dt, self.coeff = [], [], [], [], []

        self.clip = clip
        self.sentence = sentence
        print("MultiDataset")

        # Previously a mutable ``[]`` default, which also made zip() load nothing.
        if saved_folders is None:
            saved_folders = [None] * len(filenames)

        lengths = []
        for fname, saved_folder in zip(filenames, saved_folders):
            print(fname, saved_folder)
            if fname == 'navierstokes':
                dset = NSDataset2D(
                    mode=mode,
                    device='cuda:0',
                    num_samples=num_samples_max,
                    clip=clip,
                    llm=llm,
                    bcs=bcs,
                    coeff=eq_coeff,
                    qualitative=qualitative,
                    sentence=sentence,
                    seed=seed,
                )
            else:
                dset = BurgersPDEDataset2D(
                    pde=fname,
                    mode=mode,
                    resolution=[sim_time, 64, 64],
                    augmentation=[],
                    augmentation_ratio=0.0,
                    shift='None',
                    load_all=False,
                    device='cuda:0',
                    num_samples=num_samples_max,
                    clip=clip,
                    llm=llm,
                    subset=None,
                    bcs=bcs,
                    coeff=eq_coeff,
                    qualitative=qualitative,
                    sentence=sentence,
                    seed=seed,
                )
            self.dsets.append(dset)

            #TODO See if interpolating is any better.
            print()
            print(dset.u.min(), dset.u.max())
            print()
            self.u.append(dset.u.cpu())
            lengths.append(len(dset.u))
            self.grids.append(dset.x.unsqueeze(0))
            self.dt.append(dset.dt[0][0][0])

            # Heat/Burgers has 3 coefficient values, NS has 2; pad both to 5.
            # The padding is created on c's device/dtype so torch.cat does not mix devices.
            c = dset.coeffs
            if c.shape[1] == 3:  # Zeros go last
                pad = torch.zeros((c.shape[0], 2), dtype=c.dtype, device=c.device)
                self.coeff.append(torch.cat((c, pad), dim=1))
            else:  # Zeros go first
                pad = torch.zeros((c.shape[0], 3), dtype=c.dtype, device=c.device)
                self.coeff.append(torch.cat((pad, c), dim=1))

            # Drop the sub-dataset copies; the concatenated tensors replace them
            del dset.u
            del dset.x
            del dset.coeffs
            del dset.dt

        self.u = torch.cat(self.u, dim=0).cuda()
        self.grids = torch.cat(self.grids, dim=0).cuda()
        self.coeff = torch.cat(self.coeff, dim=0).cuda()
        self._index_samples(lengths)
        print(self.u.shape, self.grids.shape, self.coeff.shape)
        torch.cuda.empty_cache()

    def _index_samples(self, lengths):
        """Map each global sample index to (sub-dataset index, index within that sub-dataset)."""
        self._sample_dset = []
        self._sample_local = []
        for dset_idx, count in enumerate(lengths):
            self._sample_dset.extend([dset_idx] * count)
            self._sample_local.extend(range(count))

    def __len__(self):
        """Total number of samples over all sub-datasets."""
        return len(self.u)

    def __getitem__(self, idx):
        """
        Return ``(u, grid, coeff, dt)`` for global sample ``idx``, plus the sentence
        (``self.sentence`` set) or the sentence embedding (``self.clip`` set).
        """
        # Samples were concatenated back to back, so the owning dataset comes
        # from the length map (the former ``idx % len(self.dsets)`` mismatched
        # text/grid/dt with the trajectory whenever there was >1 dataset).
        dset_idx = self._sample_dset[idx]
        sample_idx = self._sample_local[idx]

        u = self.u[idx]
        grid = self.grids[dset_idx]
        coeff = self.coeff[idx]  # per-sample (was coeff[dset_idx] in the non-text branch)
        dt = self.dt[dset_idx]

        if self.sentence:
            return u, grid, coeff, dt, self.dsets[dset_idx].sentences[sample_idx]
        if self.clip:
            return u, grid, coeff, dt, self.dsets[dset_idx].sentence_embeddings[sample_idx]
        return u, grid, coeff, dt

