import copy
import json
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from ResnetDataClass import LinearAdvection, Burgers, LaxSod, Riemann

# Fix the seed of every RNG family used by the script (torch, NumPy and
# Python's `random`) before any data or model is built.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)

# Run configuration. With no CLI arguments a small debug run is configured;
# otherwise four positional arguments are expected:
#   <folder> <training_properties JSON> <net_architecture JSON> <which_example>
if len(sys.argv) == 1:
    folder = "TestFNN"
    training_properties = {
        "batch_size": 25,
        "epochs": 1,
        "learning_rate": 1e-4,
        "retrain": 42
    }
    net_architecture = {
        "n_hidden_layers": 6,
        "neurons": 128,
        "act_string": "tanh",
        "dropout_rate": 0.0,
    }
    which_example = "riemann"

else:
    if len(sys.argv) < 5:
        # Previously a partial argument list crashed with an opaque IndexError;
        # report the expected usage instead.
        raise SystemExit(
            f"usage: {sys.argv[0]} <folder> <training_properties_json> "
            "<net_architecture_json> <which_example>"
        )
    folder = sys.argv[1]
    training_properties = json.loads(sys.argv[2])
    net_architecture = json.loads(sys.argv[3])
    which_example = sys.argv[4]

# Hyper-parameters consumed directly by the optimiser and the training loop.
learning_rate = training_properties["learning_rate"]
epochs = training_properties["epochs"]
batch_size = training_properties["batch_size"]
# Copied into the architecture dict so it reaches the example class and is
# recorded in net_architecture.txt (presumably a seed/run index — confirm).
net_architecture["retrain"] = training_properties["retrain"]

# os.makedirs (unlike os.mkdir) also creates missing parent directories, so
# nested output paths such as "runs/burgers/exp1" work.
if not os.path.isdir(folder):
    print("Generated new folder")
    os.makedirs(folder)

writer = SummaryWriter(log_dir=folder)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Supported CLI example names -> class that builds the model and data loaders.
_EXAMPLE_CLASSES = {
    "advection": LinearAdvection,
    "burgers": Burgers,
    "shocktube": LaxSod,
    "riemann": Riemann,
}
if which_example not in _EXAMPLE_CLASSES:
    raise ValueError("the variable which_example has to be one between burgers, advection, shocktube, riemann")
example = _EXAMPLE_CLASSES[which_example](net_architecture, device, batch_size)

# Keep a record of the run configuration in the output folder: one
# "key,value" text file per configuration dict.
for config_name, config in (
    ("training_properties", training_properties),
    ("net_architecture", net_architecture),
):
    df = pd.DataFrame.from_dict([config]).T
    df.to_csv(folder + '/' + config_name + '.txt', header=False, index=True, mode='w')

# Model and data loaders come from the selected example class.
model = example.model
n_params = model.print_size()  # logged as "Params" in errors.txt
train_loader = example.train_loader
test_loader = example.test_loader

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)
# scheduler.step() is called once per epoch in the loop below, so the
# learning rate is multiplied by 0.999 every 100 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.999)

loss = torch.nn.L1Loss()
freq_print = 1
# Exponent of the relative L^p validation error (p = 1 -> relative L1).
p = 1
# Start at +inf so the first evaluated epoch always writes a checkpoint.
# The former sentinel of 100 (%) meant model.pkl was never saved when the
# relative validation error never dropped below 100%.
best_model_testing_error = float("inf")
# Main loop: each epoch runs one optimisation pass over train_loader, then a
# validation pass over test_loader, and checkpoints whenever the validation
# error improves on the best seen so far.
for epoch in range(epochs):
    # The bar is created with disable=True, so nothing is rendered; set
    # disable=False to get live progress output.
    with tqdm(unit="batch", disable=True) as tepoch:
        model.train()
        tepoch.set_description(f"Epoch {epoch}")
        # Running mean of the per-batch relative L1 training loss (the name
        # is historical: this is not a mean-squared error).
        train_mse = 0.0
        running_relative_train_mse = 0.0
        for step, (input_batch, output_batch) in enumerate(train_loader):
            optimizer.zero_grad()
            input_batch = input_batch.to(device)
            output_batch = output_batch.to(device)

            output_pred_batch = model(input_batch)
            # Relative L1 loss: mean|pred - target| / mean|target|.
            # zeros_like already allocates on output_batch's device.
            # NOTE(review): divides by zero if a batch's targets are all zero.
            loss_f = loss(output_pred_batch, output_batch) / loss(torch.zeros_like(output_batch), output_batch)
            loss_f.backward()
            optimizer.step()
            # Incremental mean over the batches seen so far in this epoch.
            train_mse = train_mse * step / (step + 1) + loss_f.item() / (step + 1)
            tepoch.set_postfix({'Batch': step + 1, 'Train loss (in progress)': train_mse})

        writer.add_scalar("train_loss/train_loss", train_mse, epoch)

        with torch.no_grad():
            model.eval()
            # Despite the name, this accumulates the relative L^p error in
            # percent, (mean|e|^p / mean|y|^p)^(1/p) * 100, with the exponent p
            # set before the loop (p = 1, i.e. relative L1), averaged over batches.
            test_relative_l2 = 0.0
            for step, (input_batch, output_batch) in enumerate(test_loader):
                input_batch = input_batch.to(device)
                output_batch = output_batch.to(device)
                output_pred_batch = model(input_batch)

                loss_f = (torch.mean(abs(output_pred_batch - output_batch) ** p) / torch.mean(abs(output_batch) ** p)) ** (1 / p) * 100
                test_relative_l2 += loss_f.item()
            test_relative_l2 /= len(test_loader)

            writer.add_scalar("val_loss/val_loss", test_relative_l2, epoch)

            if test_relative_l2 < best_model_testing_error:
                # model.eval() was called above, so the checkpoint is pickled
                # in eval mode; the whole module object is saved, not a state_dict.
                best_model_testing_error = test_relative_l2
                best_model = copy.deepcopy(model)
                torch.save(best_model, folder + "/model.pkl")
                writer.add_scalar("val_loss/Best Relative Testing Error", best_model_testing_error, epoch)

        tepoch.set_postfix({'Train loss': train_mse, "Relative Val loss": test_relative_l2})

        # Overwritten every epoch so the file always reflects the latest state.
        with open(folder + '/errors.txt', 'w') as file:
            file.write("Training Error: " + str(train_mse) + "\n")
            file.write("Best Testing Error: " + str(best_model_testing_error) + "\n")
            file.write("Current Epoch: " + str(epoch) + "\n")
            file.write("Params: " + str(n_params) + "\n")
        scheduler.step()
        # Once the LR has decayed below 1e-5, stop further decay: StepLR
        # multiplies by scheduler.gamma, so gamma = 1 freezes the rate.
        if scheduler.get_last_lr()[0] < 1e-5:
            scheduler.gamma = 1

# SummaryWriter writes events asynchronously and was never flushed or
# closed; flush so the last epochs' scalars reach disk before exit.
writer.flush()
