import copy
import json
import os
import sys

import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from FNODataClass import Burgers, LinearAdvection, LaxSod, Riemann

# ---------------------------------------------------------------------------
# Run configuration, selected by the number of command-line arguments:
#
#   <script>                                  built-in smoke-test settings
#   <script> ARG                              same settings; ARG is only read when
#                                             which_example == "advection", which
#                                             these defaults do not select
#   <script> WIDTH MODES RUN_NAME             linear-advection run with the given
#                                             FNO width / number of Fourier modes
#   <script> FOLDER TRAIN_JSON ARCH_JSON EXAMPLE
#                                             fully specified run; TRAIN_JSON and
#                                             ARCH_JSON are JSON objects
# ---------------------------------------------------------------------------
_USAGE = (
    "usage:\n"
    "  <script> [ARG]\n"
    "  <script> WIDTH MODES RUN_NAME\n"
    "  <script> FOLDER TRAINING_PROPERTIES_JSON FNO_ARCHITECTURE_JSON EXAMPLE\n"
    "EXAMPLE is one of: burgers, advection, shocktube, riemann"
)

if len(sys.argv) in (1, 2):
    # Quick smoke-test configuration.
    training_properties = {
        "batch_size": 10,
        "epochs": 1,
        "learning_rate": 5e-4,
        "retrain": 4,
    }
    fno_architecture_ = {
        "width": 64,
        "modes": 16,
        "n_layers": 3,
    }
    which_example = "riemann"
    folder = "TestFNO"

elif len(sys.argv) == 4:
    # Linear-advection study: width and modes come from the command line.
    folder = "TrainedModels/LinearAdvectionModes/" + sys.argv[3]
    training_properties = {
        "batch_size": 10,
        "epochs": 10000,
        "learning_rate": 5e-4,
        "retrain": 42,
    }
    fno_architecture_ = {
        "width": int(sys.argv[1]),
        "modes": int(sys.argv[2]),
        "n_layers": 3,
    }
    which_example = "advection"
elif len(sys.argv) >= 5:
    # Fully specified run; both configurations are passed as JSON strings.
    folder = sys.argv[1]
    training_properties = json.loads(sys.argv[2])
    fno_architecture_ = json.loads(sys.argv[3])
    which_example = sys.argv[4]
else:
    # Exactly two user arguments match none of the supported forms. Previously
    # this fell into the JSON branch and crashed with an IndexError on sys.argv[3].
    sys.exit(_USAGE)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TensorBoard event files for this run are written into the run folder.
writer = SummaryWriter(log_dir=folder)

# Optimisation hyper-parameters, read from the training configuration.
learning_rate, epochs, batch_size = (
    training_properties[key] for key in ("learning_rate", "epochs", "batch_size")
)
# Copy "retrain" into the architecture dict. That dict is passed to the example
# constructors below and is also dumped to net_architecture.txt.
fno_architecture_["retrain"] = training_properties["retrain"]

# Map each supported example name to the class that builds its model and loaders.
_example_classes = {
    "advection": LinearAdvection,
    "burgers": Burgers,
    "shocktube": LaxSod,
    "riemann": Riemann,
}
if which_example not in _example_classes:
    raise ValueError("the variable which_example has to be one between burgers, advection, shocktube, riemann")

_example_kwargs = {}
# NOTE(review): when len(sys.argv) == 2 the argument parsing selects "riemann",
# so this resolution override for "advection" is currently unreachable.
if which_example == "advection" and len(sys.argv) == 2:
    _example_kwargs["res"] = int(sys.argv[1])

example = _example_classes[which_example](fno_architecture_, device, batch_size, **_example_kwargs)

# Make sure the run folder exists. os.makedirs (unlike os.mkdir) also creates
# missing parents such as "TrainedModels/LinearAdvectionModes/".
# NOTE(review): SummaryWriter(log_dir=folder) may already have created this
# directory, in which case this branch does not run.
if not os.path.isdir(folder):
    print("Generated new folder")
    os.makedirs(folder, exist_ok=True)

# Persist both configurations as two-column "key,value" text files.
for config, filename in (
    (training_properties, "training_properties.txt"),
    (fno_architecture_, "net_architecture.txt"),
):
    df = pd.DataFrame.from_dict([config]).T
    df.to_csv(os.path.join(folder, filename), header=False, index=True, mode="w")

# Pull the network and its data loaders out of the example object.
model = example.model
# The value returned by print_size() is reported as "Params" in errors.txt.
n_params = model.print_size()
train_loader, test_loader = example.train_loader, example.test_loader

# Plain Adam without weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=0)
# NOTE(review): gamma=1 makes this StepLR a no-op, so the learning rate never
# changes. It is presumably kept as a hook for enabling step decay later.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=100, gamma=1)

# Mean absolute error. The training loop divides it by the target's mean magnitude.
loss = torch.nn.L1Loss()
# freq_print is not referenced by the visible training loop (NOTE(review)).
# p is the exponent of the relative error reported on the test set.
freq_print, p = 1, 1
# Best relative test error (in percent) seen so far. Starting from +inf ensures
# the first evaluated model is always checkpointed. With the previous sentinel
# of 100, nothing was saved if the error never dropped below 100 %.
best_model_testing_error = float("inf")
for epoch in range(epochs):
    # The progress bar is disabled; it only carries description/postfix info.
    with tqdm(unit="batch", disable=True) as tepoch:
        model.train()
        tepoch.set_description(f"Epoch {epoch}")
        train_mse = 0.0
        for step, (input_batch, output_batch) in enumerate(train_loader):
            optimizer.zero_grad()
            input_batch = input_batch.to(device)
            output_batch = output_batch.to(device)

            output_pred_batch = model(input_batch)

            # Relative L1 loss: the batch error divided by the batch's mean |target|.
            # zeros_like already allocates on output_batch's device.
            loss_f = loss(output_pred_batch, output_batch) / loss(torch.zeros_like(output_batch), output_batch)

            loss_f.backward()
            optimizer.step()
            # Incremental mean of the per-batch training loss over this epoch.
            train_mse = train_mse * step / (step + 1) + loss_f.item() / (step + 1)
            tepoch.set_postfix({'Batch': step + 1, 'Train loss (in progress)': train_mse})

        writer.add_scalar("train_loss/train_loss", train_mse, epoch)

        with torch.no_grad():
            model.eval()
            # Mean over test batches of the relative L^p error, in percent
            # (p = 1 as configured above, despite the "l2" in the name).
            test_relative_l2 = 0.0
            for step, (input_batch, output_batch) in enumerate(test_loader):
                input_batch = input_batch.to(device)
                output_batch = output_batch.to(device)
                output_pred_batch = model(input_batch)
                loss_f = (torch.mean(abs(output_pred_batch - output_batch) ** p) / torch.mean(abs(output_batch) ** p)) ** (1 / p) * 100
                test_relative_l2 += loss_f.item()
            test_relative_l2 /= len(test_loader)

            writer.add_scalar("val_loss/val_loss", test_relative_l2, epoch)

            # Checkpoint the whole model object whenever the test error improves.
            if test_relative_l2 < best_model_testing_error:
                best_model_testing_error = test_relative_l2
                best_model = copy.deepcopy(model)
                torch.save(best_model, folder + "/model.pkl")
                writer.add_scalar("val_loss/Best Relative Testing Error", best_model_testing_error, epoch)

        tepoch.set_postfix({'Train loss': train_mse, "Relative Val loss": test_relative_l2})

        # Overwrite the summary each epoch so it always reflects the latest state.
        with open(folder + '/errors.txt', 'w') as file:
            file.write("Training Error: " + str(train_mse) + "\n")
            file.write("Best Testing Error: " + str(best_model_testing_error) + "\n")
            file.write("Current Epoch: " + str(epoch) + "\n")
            file.write("Params: " + str(n_params) + "\n")
        scheduler.step()

# Flush any pending TensorBoard events and release the event file.
writer.close()
