import copy
import json
import os
import random
import sys

import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from FeedForwardNetModules import FeedForwardNN, TrunkNet, BranchNetConv, ShiftDeepONet2D

# Make runs reproducible: seed the Python, NumPy and PyTorch RNGs with 0.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
# Prefer the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cpu' if not torch.cuda.is_available() else 'cuda')


def get_data(n_samples, idx_, which="training", num_sensor=64):
    """Load Riemann-problem samples from the four HDF5 benchmark files.

    Each file ``data_benchmarks/{k}_Riemann30LR.h5`` (k = 0..3) contributes one
    input channel (its ``sensor_values`` dataset). The targets come from the
    ``"output"`` dataset of file 3 only, restricted to the indices ``idx_``.

    Args:
        n_samples: number of samples (``sample_0`` .. ``sample_{n_samples-1}``)
            to read from the group ``which`` of each file.
        idx_: 1-D integer numpy array of flat indices into the 256 x 256 grid.
            It selects both the output values and the returned coordinates.
            Presumably the ``"output"`` dataset is that grid flattened
            row-major (TODO confirm against the data generator).
        which: HDF5 group to read, e.g. ``"training"`` or ``"validation"``.
        num_sensor: sensor resolution. It selects
            ``sensor_samples/uniform/<num_sensor>/sensor_values``.

    Returns:
        Tuple of float32 tensors ``(inputs, outputs, grid)`` with shapes
        ``(n_samples, 4, num_sensor, num_sensor)``, ``(n_samples, len(idx_))``
        and ``(len(idx_), 2)``. Each grid row is the ``(y, x)`` coordinate in
        [0, 1] of the corresponding selected point.
    """
    import contextlib

    file_names = ["0_Riemann30LR.h5", "1_Riemann30LR.h5", "2_Riemann30LR.h5", "3_Riemann30LR.h5"]
    n_channels = len(file_names)
    sensor_values_path = 'sensor_samples/uniform/' + str(num_sensor) + '/sensor_values'

    # Store directly in float32: the tensors are returned as float32 anyway,
    # and a float64 staging copy of the whole dataset only doubles peak memory.
    data_inputs = np.zeros((n_samples, n_channels, num_sensor, num_sensor), dtype=np.float32)
    data_outputs = np.zeros((n_samples, idx_.shape[0]), dtype=np.float32)

    # ExitStack guarantees every HDF5 file is closed, even if a read fails.
    with contextlib.ExitStack() as stack:
        groups = [stack.enter_context(h5py.File('data_benchmarks/' + name, 'r'))[which]
                  for name in file_names]
        # Reusable float64 read buffer; read_direct fully overwrites it each time.
        input_fun = np.empty((num_sensor, num_sensor))
        for i in range(n_samples):
            sample_path = 'sample_' + str(i)
            for channel, group in enumerate(groups):
                group[sample_path][sensor_values_path].read_direct(input_fun)
                data_inputs[i, channel] = input_fun
            # Targets are taken from the last file only.
            output_fun_E = groups[-1][sample_path]["output"][:]
            data_outputs[i, :] = output_fun_E[idx_]

    # Build the (y, x) coordinate grid once. The flat index k maps to
    # row k // size_x (y) and column k % size_x (x), i.e. row-major order.
    size_x = size_y = 256
    xs = np.linspace(0, 1, size_x).astype(np.float32)
    ys = np.linspace(0, 1, size_y).astype(np.float32)
    grid_y, grid_x = np.meshgrid(ys, xs, indexing='ij')
    grid = np.stack((grid_y, grid_x), axis=-1).reshape(size_y * size_x, 2)

    return torch.from_numpy(data_inputs), torch.from_numpy(data_outputs), torch.from_numpy(grid[idx_])


# Run configuration. With no CLI arguments a small default run is used;
# otherwise the script expects:
#   <folder> <training_properties JSON> <branch_architecture JSON> <trunk_architecture JSON>
if len(sys.argv) == 1:
    folder = "TestSDON2D"
    training_properties = {
        "batch_size": 10,
        "epochs": 1,
        "learning_rate": 5e-4,
        "retrain": 42,
        "num_sensor": 256,
        "n_out": 200 * 200
    }
    trunk_architecture = {
        "n_hidden_layers": 6,
        "neurons": 256,
        "act_string": "softsign",
        "dropout_rate": 0.0,
        "n_basis": 32
    }

    branch_architecture = {
        "n_hidden_layers": 4,
        "neurons": 32,
        "act_string": "softsign",
        "dropout_rate": 0.0,
        "kernel_size": 3
    }

else:
    # Fail with a readable message instead of an IndexError on missing arguments.
    if len(sys.argv) < 5:
        sys.exit("usage: " + sys.argv[0]
                 + " <folder> <training_properties_json> <branch_architecture_json> <trunk_architecture_json>")

    training_properties = json.loads(sys.argv[2])

    branch_architecture = json.loads(sys.argv[3])

    trunk_architecture = json.loads(sys.argv[4])

    folder = sys.argv[1]

if not os.path.isdir(folder):
    print("Generated new folder")
    # makedirs also creates missing parents of a nested output path.
    os.makedirs(folder, exist_ok=True)

# Save every configuration dict inside the run folder as a headerless
# two-column text file (key, value), overwriting any earlier copy.
for _cfg_path_suffix, _cfg_dict in (
    ('/training_properties.txt', training_properties),
    ('/branch_architecture.txt', branch_architecture),
    ('/trunk_architecture.txt', trunk_architecture),
):
    df = pd.DataFrame.from_dict([_cfg_dict]).T
    df.to_csv(folder + _cfg_path_suffix, header=False, index=True, mode='w')

# Hyper-parameters used by the rest of the script.
n_basis = trunk_architecture["n_basis"]
learning_rate = training_properties["learning_rate"]
epochs = training_properties["epochs"]
batch_size = training_properties["batch_size"]
n_out = training_properties["n_out"]
num_sens = training_properties["num_sensor"]
# Optional dataset sizes; the defaults reproduce the previous fixed values.
n_train = training_properties.get("n_train", 1024)
n_test = training_properties.get("n_test", 128)

# Forward "retrain" to the sub-network constructors via their architecture dicts.
trunk_architecture["retrain"] = training_properties["retrain"]
branch_architecture["retrain"] = training_properties["retrain"]

# NOTE: this is an alias, not a copy; the shift/scale nets share the branch config dict.
shift_scale_architecture = branch_architecture
true_size = 256 * 256  # number of points on the full 256 x 256 output grid
# Training supervises a random, sorted subset of n_out grid points.
idx = np.sort(np.random.choice(np.arange(0, true_size), n_out, replace=False))
# Validation uses every grid point. The sorted permutation equals
# np.arange(true_size); the random draw is kept so the global NumPy RNG
# stream stays unchanged for any later consumer.
idx_val = np.sort(np.random.choice(np.arange(0, true_size), true_size, replace=False))

training_inputs, training_outputs, grid = get_data(n_train, idx_=idx, which="training", num_sensor=num_sens)
testing_inputs, testing_outputs, grid_val = get_data(n_test, idx_=idx_val, which="validation", num_sensor=num_sens)

# Construction order is kept as-is: it determines the order in which the
# PyTorch RNG is consumed during parameter initialisation.
branch = BranchNetConv(training_inputs.shape[1], n_basis, network_architecture=branch_architecture)
trunk = TrunkNet(2 * n_basis, n_basis, network_architecture=trunk_architecture)
scale_net = BranchNetConv(training_inputs.shape[1], n_basis, network_architecture=shift_scale_architecture)
shift_net = BranchNetConv(training_inputs.shape[1], n_basis, network_architecture=shift_scale_architecture)
bias_net = FeedForwardNN(1, n_basis, network_architecture=trunk_architecture)

model = ShiftDeepONet2D(branch, trunk, shift_net, scale_net, bias_net, n_basis, device)
model = model.to(device)
n_params = 0.  # placeholder: the parameter count is not computed here
writer = SummaryWriter(log_dir=folder)

batch_acc = 1  # not referenced by the visible script
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(training_inputs, training_outputs), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(testing_inputs, testing_outputs), batch_size=batch_size, shuffle=False)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999, verbose=True)

loss = torch.nn.L1Loss()
freq_print = 1  # not referenced by the visible script
p = 1  # exponent of the relative L^p validation error (p = 1 -> relative L1)
# Start at +inf (not 100 %) so the first finite validation error always
# produces a checkpoint, even when it exceeds 100 %.
best_model_testing_error = float("inf")

# Main loop. Each epoch makes one optimisation pass over train_loader, then a
# validation pass over test_loader. The model with the lowest validation
# error so far is written to <folder>/model.pkl.
try:
    for epoch in range(epochs):
        # The progress bar is disabled (disable=True); the tqdm calls are no-ops.
        with tqdm(unit="batch", disable=True) as tepoch:
            model.train()
            tepoch.set_description(f"Epoch {epoch}")
            train_mse = 0.0
            for step, (input_batch, output_batch) in enumerate(train_loader):
                optimizer.zero_grad()
                input_batch = input_batch.to(device)
                output_batch = output_batch.to(device)

                output_pred_batch = model(input_batch, grid)
                # Relative L1 loss: the L1 error divided by the L1 norm of the target.
                loss_f = loss(output_pred_batch, output_batch) / loss(torch.zeros_like(output_batch).to(device), output_batch)
                loss_f.backward()

                # Running mean of the per-batch losses within this epoch.
                train_mse = train_mse * step / (step + 1) + loss_f.item() / (step + 1)
                tepoch.set_postfix({'Batch': step + 1, 'Train loss (in progress)': train_mse})

                optimizer.step()
                optimizer.zero_grad()

            writer.add_scalar("train_loss/train_loss", train_mse, epoch)

            with torch.no_grad():
                model.eval()
                test_relative_l2 = 0.0
                for step, (input_batch, output_batch) in enumerate(test_loader):
                    input_batch = input_batch.to(device)
                    output_batch = output_batch.to(device)
                    output_pred_batch = model(input_batch, grid_val)

                    # Relative L^p error of the batch, in percent.
                    loss_f = (torch.mean(abs(output_pred_batch - output_batch) ** p) / torch.mean(abs(output_batch) ** p)) ** (1 / p) * 100
                    test_relative_l2 += loss_f.item()
                # Unweighted mean over batches (the last batch may be smaller).
                test_relative_l2 /= len(test_loader)

                writer.add_scalar("val_loss/val_loss", test_relative_l2, epoch)

                if test_relative_l2 < best_model_testing_error:
                    best_model_testing_error = test_relative_l2
                    best_model = copy.deepcopy(model)
                    torch.save(best_model, folder + "/model.pkl")
                    writer.add_scalar("val_loss/Best Relative Testing Error", best_model_testing_error, epoch)

            tepoch.set_postfix({'Train loss': train_mse, "Relative Val loss": test_relative_l2})

            # Overwrite the summary file every epoch so it reflects the latest state.
            with open(folder + '/errors.txt', 'w') as error_file:
                error_file.write("Training Error: " + str(train_mse) + "\n")
                error_file.write("Best Testing Error: " + str(best_model_testing_error) + "\n")
                error_file.write("Current Epoch: " + str(epoch) + "\n")
                error_file.write("Params: " + str(0) + "\n")
            scheduler.step()
            # Stop decaying the learning rate once it drops below 1e-5.
            if scheduler.get_last_lr()[0] < 1e-5:
                scheduler.gamma = 1
finally:
    # Flush and close the TensorBoard writer, even on interruption, so the
    # trailing events are not lost.
    writer.close()