import torch
import h5py
from FNOModules import FNO1d, FNO2d
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import random

# Seed the three random number generators this module imports (PyTorch, NumPy
# and the standard-library `random`) with the same fixed value at import time.
# NOTE(review): presumably this is for run-to-run reproducibility (e.g. the
# shuffled training DataLoaders below draw from the PyTorch generator) --
# confirm that importing this module is meant to reset global RNG state.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)


class LinearAdvection:
    """FNO1d model plus data loaders for the linear advection benchmark.

    On construction the class reads 1024 "training" and 128 "validation"
    samples from ``data_benchmarks/LinearAdvection.h5``, sub-samples them to
    roughly ``res`` grid points and wraps them in ``DataLoader`` objects.

    Attributes set by ``__init__``:
        res: requested number of grid points (out of 2048).
        model: ``FNO1d`` instance moved to ``device``.
        train_loader: shuffled ``DataLoader`` over the training samples.
        test_loader: unshuffled ``DataLoader`` over the validation samples.
    """

    def __init__(self, network_properties, device, batch_size, res=2048):
        """Load the data and build the model.

        Args:
            network_properties: forwarded to ``FNO1d`` as ``fno_architecture``.
            device: device the model is moved to with ``.to(device)``.
            batch_size: batch size used by both data loaders.
            res: target resolution; must satisfy ``0 < res <= 2048``.

        Raises:
            ValueError: if ``res`` is outside ``(0, 2048]``.
        """
        self.res = res
        training_inputs, training_outputs = self.get_data(1024, "training")
        testing_inputs, testing_outputs = self.get_data(128, "validation")

        self.model = FNO1d(fno_architecture=network_properties, nfun=1).to(device)

        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)

    def get_data(self, n_samples, which="training"):
        """Read ``n_samples`` samples from the HDF5 group named ``which``.

        Every ``step``-th grid point is kept, where
        ``step = int(2048 / self.res)``; when ``res`` does not divide 2048 the
        resulting number of points therefore differs from ``res``.

        Args:
            n_samples: number of samples to read (``sample_0`` ...
                ``sample_{n_samples - 1}``).
            which: name of the top-level group in the file.

        Returns:
            A pair of float32 tensors ``(inputs, outputs)`` with shapes
            ``(n_samples, n_points, 2)`` and ``(n_samples, n_points)``.
            ``inputs[..., 0]`` holds the grid coordinate in ``[0, 1)`` and
            ``inputs[..., 1]`` the sub-sampled "input" dataset.

        Raises:
            ValueError: if ``self.res`` is outside ``(0, 2048]``.
        """
        # The code indexes each stored function with indices up to 2047, so
        # it assumes the stored functions have (at least) 2048 points.
        full_res = 2048

        # Validate before touching the file: res > 2048 would give step == 0
        # (an opaque ZeroDivisionError inside np.arange) and res < 0 would
        # silently produce empty tensors.
        if not 0 < self.res <= full_res:
            raise ValueError(
                "res must satisfy 0 < res <= {}, got {!r}".format(full_res, self.res)
            )

        step = int(full_res / self.res)
        idx = np.arange(0, full_res, step)
        grid = np.arange(0, full_res) / full_res
        n_points = idx.shape[0]

        # Allocate directly in float32: the tensors are returned as float32
        # anyway, so this avoids a float64 staging copy.
        data_inputs = np.zeros((n_samples, n_points, 2), dtype=np.float32)
        data_outputs = np.zeros((n_samples, n_points), dtype=np.float32)

        # The grid channel is identical for every sample; fill it once.
        data_inputs[:, :, 0] = grid[idx]

        # The context manager guarantees the file handle is closed, even if a
        # sample is missing and the loop raises.
        with h5py.File('data_benchmarks/' + "LinearAdvection.h5", 'r') as reader:
            split = reader[which]
            for i in range(n_samples):
                sample = split["sample_" + str(i)]
                data_inputs[i, :, 1] = sample["input"][:][idx]
                data_outputs[i, :] = sample["output"][:][idx]

        return torch.from_numpy(data_inputs), torch.from_numpy(data_outputs)


class Burgers:
    """FNO1d model plus data loaders for the Burgers benchmark.

    On construction the class reads 950 "training" and 74 "validation"
    samples from ``data_benchmarks/BurgersEquation006_01b.h5`` and wraps them
    in ``DataLoader`` objects.

    Attributes set by ``__init__``:
        model: ``FNO1d`` instance moved to ``device``.
        train_loader: shuffled ``DataLoader`` over the training samples.
        test_loader: unshuffled ``DataLoader`` over the validation samples.
    """

    def __init__(self, network_properties, device, batch_size):
        """Load the data and build the model.

        Args:
            network_properties: forwarded to ``FNO1d`` as ``fno_architecture``.
            device: device the model is moved to with ``.to(device)``.
            batch_size: batch size used by both data loaders.
        """
        training_inputs, training_outputs = self.get_data(950, "training")
        testing_inputs, testing_outputs = self.get_data(74, "validation")

        self.model = FNO1d(fno_architecture=network_properties, nfun=1).to(device)

        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)

    def get_data(self, n_samples, which="training"):
        """Read ``n_samples`` samples from the HDF5 group named ``which``.

        Args:
            n_samples: number of samples to read (``sample_0`` ...
                ``sample_{n_samples - 1}``).
            which: name of the top-level group in the file.

        Returns:
            A pair of float32 tensors ``(inputs, outputs)`` with shapes
            ``(n_samples, 1024, 2)`` and ``(n_samples, 1024)``.
            ``inputs[..., 0]`` holds the grid coordinate in ``[0, 1)`` and
            ``inputs[..., 1]`` the first 1024 values of the "input" dataset.
        """
        n_points = 1024
        # Selects the first 1024 entries of each stored function, exactly as
        # the fancy index did before (no sub-sampling for this benchmark).
        idx = np.arange(0, n_points, 1)
        grid = np.arange(0, n_points) / n_points

        # Allocate directly in float32: the tensors are returned as float32
        # anyway, so this avoids a float64 staging copy.
        data_inputs = np.zeros((n_samples, n_points, 2), dtype=np.float32)
        data_outputs = np.zeros((n_samples, n_points), dtype=np.float32)

        # The grid channel is identical for every sample; fill it once.
        data_inputs[:, :, 0] = grid[idx]

        # The context manager guarantees the file handle is closed, even if a
        # sample is missing and the loop raises.
        with h5py.File('data_benchmarks/' + "BurgersEquation006_01b.h5", 'r') as reader:
            split = reader[which]
            for i in range(n_samples):
                sample = split["sample_" + str(i)]
                data_inputs[i, :, 1] = sample["input"][:][idx]
                data_outputs[i, :] = sample["output"][:][idx]

        return torch.from_numpy(data_inputs), torch.from_numpy(data_outputs)


class LaxSod:
    """FNO1d model plus data loaders for the Lax-Sod shock tube benchmark.

    On construction the class reads 1024 "training" and 128 "validation"
    samples from ``data_benchmarks/LaxSodShockTube.h5`` and wraps them in
    ``DataLoader`` objects.

    Attributes set by ``__init__``:
        model: ``FNO1d`` instance (built with ``nfun=3``) moved to ``device``.
        train_loader: shuffled ``DataLoader`` over the training samples.
        test_loader: unshuffled ``DataLoader`` over the validation samples.
    """

    def __init__(self, network_properties, device, batch_size):
        """Load the data and build the model.

        Args:
            network_properties: forwarded to ``FNO1d`` as ``fno_architecture``.
            device: device the model is moved to with ``.to(device)``.
            batch_size: batch size used by both data loaders.
        """
        training_inputs, training_outputs = self.get_data(1024, "training")
        testing_inputs, testing_outputs = self.get_data(128, "validation")

        self.model = FNO1d(fno_architecture=network_properties, nfun=3).to(device)

        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)

    def get_data(self, n_samples, which="training"):
        """Read ``n_samples`` samples from the HDF5 group named ``which``.

        Args:
            n_samples: number of samples to read (``sample_0`` ...
                ``sample_{n_samples - 1}``).
            which: name of the top-level group in the file.

        Returns:
            A pair of float32 tensors ``(inputs, outputs)`` with shapes
            ``(n_samples, 2048, 4)`` and ``(n_samples, 2048)``.
            ``inputs[..., 0]`` holds the grid coordinate in ``[-5, 5]``,
            ``inputs[..., 1:]`` the "input" dataset and ``outputs`` the last
            column of the "output" dataset.
        """
        n_points = 2048
        grid = np.linspace(-5, 5, n_points)

        # Allocate directly in float32: the tensors are returned as float32
        # anyway, so this avoids a float64 staging copy.
        data_inputs = np.zeros((n_samples, n_points, 4), dtype=np.float32)
        data_outputs = np.zeros((n_samples, n_points), dtype=np.float32)

        # The grid channel is identical for every sample; fill it once.
        data_inputs[:, :, 0] = grid

        # The context manager guarantees the file handle is closed, even if a
        # sample is missing and the loop raises.
        with h5py.File('data_benchmarks/' + "LaxSodShockTube.h5", 'r') as reader:
            split = reader[which]
            for i in range(n_samples):
                sample = split["sample_" + str(i)]
                # assumes "input" broadcasts to (2048, 3) -- TODO confirm
                data_inputs[i, :, 1:] = sample["input"][:]
                # Only the last column of the stored output is used as target.
                data_outputs[i, :] = sample["output"][:][:, -1]

        return torch.from_numpy(data_inputs), torch.from_numpy(data_outputs)


class Riemann:
    """FNO2d model plus data loaders for the 2D Riemann benchmark.

    On construction the class reads 1024 "training" and 128 "validation"
    samples from the four files ``data_benchmarks/{0,1,2,3}_Riemann30LR.h5``
    and wraps them in ``DataLoader`` objects.

    Attributes set by ``__init__``:
        model: ``FNO2d`` instance moved to ``device``.
        train_loader: shuffled ``DataLoader`` over the training samples.
        test_loader: unshuffled ``DataLoader`` over the validation samples.
    """

    def __init__(self, network_properties, device, batch_size):
        """Load the data and build the model.

        Args:
            network_properties: forwarded to ``FNO2d`` as ``fno_architecture``.
            device: passed to ``FNO2d`` and used to move the model.
            batch_size: batch size used by both data loaders.
        """
        training_inputs, training_outputs = self.get_data(1024, "training")
        testing_inputs, testing_outputs = self.get_data(128, "validation")

        self.model = FNO2d(fno_architecture=network_properties, device=device).to(device)

        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)

    def get_data(self, n_samples, which="training"):
        """Read ``n_samples`` samples from the HDF5 group named ``which``.

        Input channel ``c`` comes from the "input" dataset of file
        ``c_Riemann30LR.h5``; the target is the "output" dataset of file 3.
        NOTE(review): the original code named the four channels rho, mx, my
        and E -- presumably density, the two momenta and energy; confirm
        against the data set description.

        Args:
            n_samples: number of samples to read (``sample_0`` ...
                ``sample_{n_samples - 1}``).
            which: name of the top-level group in each file.

        Returns:
            A pair of float32 tensors ``(inputs, outputs)`` with shapes
            ``(n_samples, 256, 256, 4)`` and ``(n_samples, 256, 256, 1)``.
        """
        # Function-scope import keeps this block self-contained.
        from contextlib import ExitStack

        # Order matters: the position in this tuple is the input channel.
        file_names = (
            "0_Riemann30LR.h5",
            "1_Riemann30LR.h5",
            "2_Riemann30LR.h5",
            "3_Riemann30LR.h5",
        )

        # Allocate directly in float32: the tensors are returned as float32
        # anyway, and the float64 staging buffer for 1024 samples alone would
        # take 2 GiB.
        data_inputs = np.zeros((n_samples, 256, 256, 4), dtype=np.float32)
        data_outputs = np.zeros((n_samples, 256, 256, 1), dtype=np.float32)

        # ExitStack closes every file that was opened, even if a later open
        # or a sample lookup raises.
        with ExitStack() as stack:
            splits = [
                stack.enter_context(h5py.File('data_benchmarks/' + name, 'r'))[which]
                for name in file_names
            ]
            for i in range(n_samples):
                str_sample = "sample_" + str(i)
                # Write each field straight into its channel instead of
                # building a concatenated temporary per sample.
                for channel, split in enumerate(splits):
                    data_inputs[i, :, :, channel] = split[str_sample]["input"][:].reshape(256, 256)
                # The target is taken from the last file only.
                data_outputs[i, :, :, :] = splits[3][str_sample]["output"][:].reshape(256, 256, 1)

        return torch.from_numpy(data_inputs), torch.from_numpy(data_outputs)
