import torch
import h5py
from FeedForwardNetModules import BranchNet, TrunkNet, DeepOnet2, BranchNetConv1D
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import random

# Seed every RNG this module relies on so runs are repeatable: NumPy's global
# RNG picks the output-point subsets (np.random.choice), and torch's default
# generator is presumably used when the project networks are initialized.
_GLOBAL_SEED = 0
random.seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)


class LinearAdvection:
    """Linear-advection benchmark: data loaders, trunk coordinates and a DeepONet.

    Reads the first 1024 samples of the ``training`` split and the first 128
    samples of the ``validation`` split from
    ``data_benchmarks/LinearAdvection.h5``. Each target function is stored at
    ``true_size`` points. Training targets are restricted to ``n_out`` randomly
    chosen points, and validation targets to 1024 randomly chosen points. The
    matching trunk coordinates come from ``linspace(0, 1, true_size)``, so this
    assumes the stored outputs lie on that uniform grid.

    Attributes set by ``__init__``:
        true_size: number of stored output points per sample (2048).
        model: ``DeepOnet2(BranchNet, TrunkNet)`` moved to ``device``.
        train_loader / test_loader: ``DataLoader`` over (inputs, outputs);
            only the training loader shuffles.
        grid / grid_val: ``(k, 1)`` trunk coordinates on ``device`` for the
            training / validation output points.
    """

    def __init__(self, trunk_properties, branch_properties, device, batch_size, n_out, num_sens, n_basis):
        """Build the datasets, loaders, coordinate grids and the model.

        Args:
            trunk_properties: architecture spec forwarded to ``TrunkNet``.
            branch_properties: architecture spec forwarded to ``BranchNet``.
            device: torch device for the model and the coordinate grids.
            batch_size: batch size of both data loaders.
            n_out: number of output points sampled per training function
                (``np.random.choice`` raises if it exceeds ``true_size``).
            num_sens: number of input sensors; selects the
                ``sensor_samples/uniform/<num_sens>`` dataset in the file.
            n_basis: size passed as the second argument to both ``BranchNet``
                and ``TrunkNet``.

        Draws from NumPy's global RNG, so the sampled points depend on its state.
        """
        self.true_size = 2048
        # Sorted, duplicate-free indices: h5py fancy indexing requires an
        # increasing selection, and the trunk grid is indexed identically.
        idx = np.sort(np.random.choice(np.arange(0, self.true_size), n_out, replace=False))
        idx_val = np.sort(np.random.choice(np.arange(0, self.true_size), 1024, replace=False))

        training_inputs, training_outputs = self.get_data(1024, idx_=idx, which="training", num_sensor=num_sens)
        testing_inputs, testing_outputs = self.get_data(128, idx_=idx_val, which="validation", num_sensor=num_sens)

        # One grid serves both splits; advanced indexing below returns copies.
        grid = torch.linspace(0, 1, self.true_size).unsqueeze(1).to(device)

        branch = BranchNet(training_inputs.shape[1], n_basis, network_architecture=branch_properties)
        trunk = TrunkNet(1, n_basis, network_architecture=trunk_properties)

        self.model = DeepOnet2(branch, trunk).to(device)
        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)
        self.grid = grid[idx]
        self.grid_val = grid[idx_val]

    def get_data(self, n_samples, idx_, which="training", num_sensor=1024):
        """Read ``sample_0`` .. ``sample_{n_samples-1}`` from one split of the file.

        Args:
            n_samples: number of consecutive samples to read.
            idx_: sorted 1-D integer array of output points to keep.
            which: top-level group name (``"training"`` or ``"validation"``).
            num_sensor: number of sensor values per input function.

        Returns:
            ``(inputs, outputs)`` float32 tensors of shapes
            ``(n_samples, num_sensor)`` and ``(n_samples, len(idx_))``.
        """
        data_name_file = "LinearAdvection.h5"
        data_inputs = np.zeros((n_samples, num_sensor))
        data_outputs = np.zeros((n_samples, idx_.shape[0]))
        # The context manager guarantees the HDF5 handle is closed, even on error.
        with h5py.File('data_benchmarks/' + data_name_file, 'r') as reader:
            split = reader[which]
            for i in range(n_samples):
                sample = split["sample_" + str(i)]
                data_inputs[i, :] = sample["sensor_samples"]["uniform"][str(num_sensor)]["sensor_values"][:]
                data_outputs[i, :] = sample["output"][idx_]

        return torch.tensor(data_inputs, dtype=torch.float32), torch.tensor(data_outputs, dtype=torch.float32)


class Burgers:
    """Burgers-equation benchmark: data loaders, trunk coordinates and a DeepONet.

    Reads the first 950 samples of the ``training`` split and the first 74
    samples of the ``validation`` split from
    ``data_benchmarks/BurgersEquation006_01b.h5``. Each target function is
    stored at ``true_size`` points. Training targets are restricted to ``n_out``
    randomly chosen points. Validation uses 1024 points, which with
    ``true_size == 1024`` is every point. The trunk coordinates come from
    ``linspace(0, 1, true_size)``, so this assumes the stored outputs lie on
    that uniform grid.

    Attributes set by ``__init__``:
        true_size: number of stored output points per sample (1024).
        model: ``DeepOnet2(BranchNet, TrunkNet)`` moved to ``device``.
        train_loader / test_loader: ``DataLoader`` over (inputs, outputs);
            only the training loader shuffles.
        grid / grid_val: ``(k, 1)`` trunk coordinates on ``device`` for the
            training / validation output points.
    """

    def __init__(self, trunk_properties, branch_properties, device, batch_size, n_out, num_sens, n_basis):
        """Build the datasets, loaders, coordinate grids and the model.

        Args:
            trunk_properties: architecture spec forwarded to ``TrunkNet``.
            branch_properties: architecture spec forwarded to ``BranchNet``.
            device: torch device for the model and the coordinate grids.
            batch_size: batch size of both data loaders.
            n_out: number of output points sampled per training function
                (``np.random.choice`` raises if it exceeds ``true_size``).
            num_sens: number of input sensors; selects the
                ``sensor_samples/uniform/<num_sens>`` dataset in the file.
            n_basis: size passed as the second argument to both ``BranchNet``
                and ``TrunkNet``.

        Draws from NumPy's global RNG, so the sampled points depend on its state.
        """
        self.true_size = 1024
        # Sorted, duplicate-free indices: h5py fancy indexing requires an
        # increasing selection, and the trunk grid is indexed identically.
        idx = np.sort(np.random.choice(np.arange(0, self.true_size), n_out, replace=False))
        # Selects all 1024 points here, but the call still advances NumPy's
        # global RNG, so it is kept for reproducibility of later draws.
        idx_val = np.sort(np.random.choice(np.arange(0, self.true_size), 1024, replace=False))

        training_inputs, training_outputs = self.get_data(950, idx_=idx, which="training", num_sensor=num_sens)
        testing_inputs, testing_outputs = self.get_data(74, idx_=idx_val, which="validation", num_sensor=num_sens)

        # One grid serves both splits; advanced indexing below returns copies.
        grid = torch.linspace(0, 1, self.true_size).unsqueeze(1).to(device)

        branch = BranchNet(training_inputs.shape[1], n_basis, network_architecture=branch_properties)
        trunk = TrunkNet(1, n_basis, network_architecture=trunk_properties)

        self.model = DeepOnet2(branch, trunk).to(device)
        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)
        self.grid = grid[idx]
        self.grid_val = grid[idx_val]

    def get_data(self, n_samples, idx_, which="training", num_sensor=1024):
        """Read ``sample_0`` .. ``sample_{n_samples-1}`` from one split of the file.

        Args:
            n_samples: number of consecutive samples to read.
            idx_: sorted 1-D integer array of output points to keep.
            which: top-level group name (``"training"`` or ``"validation"``).
            num_sensor: number of sensor values per input function.

        Returns:
            ``(inputs, outputs)`` float32 tensors of shapes
            ``(n_samples, num_sensor)`` and ``(n_samples, len(idx_))``.
        """
        data_name_file = "BurgersEquation006_01b.h5"
        data_inputs = np.zeros((n_samples, num_sensor))
        data_outputs = np.zeros((n_samples, idx_.shape[0]))
        # The context manager guarantees the HDF5 handle is closed, even on error.
        with h5py.File('data_benchmarks/' + data_name_file, 'r') as reader:
            split = reader[which]
            for i in range(n_samples):
                sample = split["sample_" + str(i)]
                data_inputs[i, :] = sample["sensor_samples"]["uniform"][str(num_sensor)]["sensor_values"][:]
                data_outputs[i, :] = sample["output"][idx_]

        return torch.tensor(data_inputs, dtype=torch.float32), torch.tensor(data_outputs, dtype=torch.float32)


class LaxSod:
    """Lax/Sod shock-tube benchmark: data loaders, trunk coordinates and a DeepONet.

    Reads the first 1024 samples of the ``training`` split and the first 128
    samples of the ``validation`` split from
    ``data_benchmarks/LaxSodShockTube.h5``. Inputs have 3 channels per sensor
    and are returned channels-first. The target is the last column of each
    stored output row. Training targets are restricted to ``n_out`` randomly
    chosen points, and validation targets to 1024 randomly chosen points. The
    trunk coordinates come from ``linspace(-5, 5, true_size)``, so this assumes
    the stored outputs lie on that uniform grid.

    Attributes set by ``__init__``:
        true_size: number of stored output points per sample (2048).
        model: ``DeepOnet2(BranchNetConv1D, TrunkNet)`` moved to ``device``.
        train_loader / test_loader: ``DataLoader`` over (inputs, outputs);
            only the training loader shuffles.
        grid / grid_val: ``(k, 1)`` trunk coordinates on ``device`` for the
            training / validation output points.
    """

    def __init__(self, trunk_properties, branch_properties, device, batch_size, n_out, num_sens, n_basis):
        """Build the datasets, loaders, coordinate grids and the model.

        Args:
            trunk_properties: architecture spec forwarded to ``TrunkNet``.
            branch_properties: architecture spec forwarded to ``BranchNetConv1D``.
            device: torch device for the model and the coordinate grids.
            batch_size: batch size of both data loaders.
            n_out: number of output points sampled per training function
                (``np.random.choice`` raises if it exceeds ``true_size``).
            num_sens: number of input sensors; selects the
                ``sensor_samples/uniform/<num_sens>`` dataset in the file.
            n_basis: size passed as the second argument to both
                ``BranchNetConv1D`` and ``TrunkNet``.

        Draws from NumPy's global RNG, so the sampled points depend on its state.
        """
        self.true_size = 2048
        # Sorted, duplicate-free indices: h5py fancy indexing requires an
        # increasing selection, and the trunk grid is indexed identically.
        idx = np.sort(np.random.choice(np.arange(0, self.true_size), n_out, replace=False))
        idx_val = np.sort(np.random.choice(np.arange(0, self.true_size), 1024, replace=False))

        training_inputs, training_outputs = self.get_data(1024, idx_=idx, which="training", num_sensor=num_sens)
        testing_inputs, testing_outputs = self.get_data(128, idx_=idx_val, which="validation", num_sensor=num_sens)

        # One grid serves both splits; advanced indexing below returns copies.
        grid = torch.linspace(-5, 5, self.true_size).unsqueeze(1).to(device)

        # Inputs are channels-first, so shape[1] is the channel count (3).
        branch = BranchNetConv1D(training_inputs.shape[1], n_basis, network_architecture=branch_properties)
        trunk = TrunkNet(1, n_basis, network_architecture=trunk_properties)

        self.model = DeepOnet2(branch, trunk).to(device)

        self.train_loader = DataLoader(TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)
        self.grid = grid[idx]
        self.grid_val = grid[idx_val]

    def get_data(self, n_samples, idx_, which="training", num_sensor=2048):
        """Read ``sample_0`` .. ``sample_{n_samples-1}`` from one split of the file.

        Args:
            n_samples: number of consecutive samples to read.
            idx_: sorted 1-D integer array of output rows to keep.
            which: top-level group name (``"training"`` or ``"validation"``).
            num_sensor: number of sensors per input function.

        Returns:
            ``(inputs, outputs)`` float32 tensors. ``inputs`` has shape
            ``(n_samples, 3, num_sensor)`` (channels-first). ``outputs`` has
            shape ``(n_samples, len(idx_))`` and holds the last column of each
            selected output row.
        """
        data_name_file = "LaxSodShockTube.h5"
        # Each sensor_values dataset is expected to fit a (num_sensor, 3) slot.
        data_inputs = np.zeros((n_samples, num_sensor, 3))
        data_outputs = np.zeros((n_samples, idx_.shape[0]))
        # The context manager guarantees the HDF5 handle is closed, even on error.
        with h5py.File('data_benchmarks/' + data_name_file, 'r') as reader:
            split = reader[which]
            for i in range(n_samples):
                sample = split["sample_" + str(i)]
                data_inputs[i, :] = sample["sensor_samples"]["uniform"][str(num_sensor)]["sensor_values"][:]
                # Stored output is 2-D; only its last column is the target.
                data_outputs[i, :] = sample["output"][idx_][:, -1]

        # (N, L, C) -> (N, C, L): presumably the layout a Conv1d-based branch
        # expects; confirm against BranchNetConv1D.
        inputs = torch.tensor(data_inputs, dtype=torch.float32).permute(0, 2, 1)
        return inputs, torch.tensor(data_outputs, dtype=torch.float32)
