import copy
import json
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from DONDataClass import LinearAdvection, Burgers, LaxSod

# Seed the torch, NumPy and stdlib random generators with the same fixed
# value (0), in that order.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(0)

# Resolve the run configuration from the command line. Three modes:
#   * no argument   -> quick run: 1 epoch on "shocktube", written to "TestDON";
#   * one argument  -> resolution study on "advection": int(argv[1]) is used
#                      for both "num_sensor" and "n_out";
#   * five or more  -> fully explicit run:
#                        argv[1] = output folder
#                        argv[2] = training_properties (JSON)
#                        argv[3] = branch_architecture (JSON)
#                        argv[4] = trunk_architecture  (JSON)
#                        argv[5] = example name
# Any other argument count is rejected up front with a usage message instead
# of failing later with an IndexError on sys.argv.
_n_cli_args = len(sys.argv) - 1

if _n_cli_args <= 1:
    # Both built-in modes use the same network architectures.
    trunk_architecture = {
        "n_hidden_layers": 6,
        "neurons": 256,
        "act_string": "leaky_relu",
        "dropout_rate": 0.0,
        "n_basis": 200,
    }
    branch_architecture = {
        "n_hidden_layers": 3,
        "neurons": 256,
        "act_string": "leaky_relu",
        "dropout_rate": 0.0,
        "kernel_size": 3,
    }

    if _n_cli_args == 0:
        training_properties = {
            "batch_size": 10,
            "epochs": 1,
            "learning_rate": 5e-4,
            "retrain": 4,
            "num_sensor": 64,
            "n_out": 64,
        }
        folder = "TestDON"
        which_example = "shocktube"
    else:
        training_properties = {
            "batch_size": 10,
            "epochs": 10000,
            "learning_rate": 5e-4,
            "retrain": 4,
            "num_sensor": int(sys.argv[1]),
            "n_out": int(sys.argv[1]),
        }
        # The resolution is encoded in the folder name so runs do not collide.
        folder = "ResolutionStudyDON" + str(training_properties["n_out"])
        which_example = "advection"
else:
    if _n_cli_args < 5:
        raise ValueError(
            "expected 0, 1 or at least 5 command-line arguments "
            "(folder, training_properties JSON, branch_architecture JSON, "
            "trunk_architecture JSON, which_example), got " + str(_n_cli_args)
        )

    training_properties = json.loads(sys.argv[2])
    branch_architecture = json.loads(sys.argv[3])
    trunk_architecture = json.loads(sys.argv[4])
    folder = sys.argv[1]
    which_example = sys.argv[5]

# Make sure the output directory exists before anything is written into it.
# os.makedirs also creates missing parent directories (os.mkdir raises
# FileNotFoundError for a nested path such as "runs/exp1"), and exist_ok=True
# tolerates the directory being created by another process between the
# isdir() check and the call.
if not os.path.isdir(folder):
    print("Generated new folder")
    os.makedirs(folder, exist_ok=True)

# Use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# TensorBoard writer logging into the run folder.
writer = SummaryWriter(log_dir=folder)

# Unpack the hyperparameters used below into module-level names.
n_basis = trunk_architecture["n_basis"]
learning_rate, epochs, batch_size, n_out, num_sens = (
    training_properties[name]
    for name in ("learning_rate", "epochs", "batch_size", "n_out", "num_sensor")
)

# Copy the "retrain" value into both sub-network configurations.
trunk_architecture["retrain"] = branch_architecture["retrain"] = training_properties["retrain"]

# Map each supported example name to the DONDataClass class that builds it.
_example_classes = {
    "advection": LinearAdvection,
    "burgers": Burgers,
    "shocktube": LaxSod,
}

if which_example not in _example_classes:
    raise ValueError("the variable which_example has to be one between burgers, advection, shocktube")

# All three classes are constructed with the same positional arguments.
example = _example_classes[which_example](
    trunk_architecture, branch_architecture, device, batch_size, n_out, num_sens, n_basis
)

# Write each configuration dict to "<folder>/<name>.txt". The one-row frame is
# transposed so every dict entry becomes its own row (index = key), and the
# file is overwritten on every run (mode='w').
for _stem, _settings in (
    ("training_properties", training_properties),
    ("branch_architecture", branch_architecture),
    ("trunk_architecture", trunk_architecture),
):
    df = pd.DataFrame.from_dict([_settings]).T
    df.to_csv(folder + '/' + _stem + '.txt', header=False, index=True, mode='w')

# Pull the network, the data loaders and the two grids out of the example.
model = example.model

train_loader = example.train_loader
test_loader = example.test_loader
grid = example.grid          # passed to the model together with training batches
grid_val = example.grid_val  # passed to the model together with test batches

# Adam with a small weight decay; the scheduler multiplies the learning rate
# by 0.999 every 50 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.999)

loss = torch.nn.L1Loss()
freq_print = 1
p = 1  # exponent of the relative error norm used on the test set

# Start from +inf rather than a hard-coded 100 (%): with a finite sentinel a
# run whose relative test error never drops below it would never save a
# checkpoint and would report the sentinel as its "best" error.
best_model_testing_error = float("inf")
# Training loop: one optimisation pass over train_loader per epoch, followed
# by an evaluation pass over test_loader. Per-epoch scalars go to TensorBoard,
# the best model so far is checkpointed to "<folder>/model.pkl", and
# "<folder>/errors.txt" is rewritten with the latest numbers.
# The try/finally guarantees the TensorBoard writer is flushed and closed even
# if training is interrupted, so buffered events are not lost.
try:
    for epoch in range(epochs):
        # The bar is created with disable=True; the tqdm calls are kept so it
        # can be re-enabled by flipping that flag.
        with tqdm(unit="batch", disable=True) as tepoch:
            model.train()
            tepoch.set_description(f"Epoch {epoch}")
            train_mse = 0.0
            for step, (input_batch, output_batch) in enumerate(train_loader):
                optimizer.zero_grad()
                input_batch = input_batch.to(device)
                output_batch = output_batch.to(device)

                output_pred_batch = model(input_batch, grid)
                # Relative L1 loss: L1(pred, target) / L1(0, target).
                # zeros_like already allocates on output_batch's device.
                loss_f = loss(output_pred_batch, output_batch) / loss(torch.zeros_like(output_batch), output_batch)
                loss_f.backward()
                optimizer.step()
                # Incremental mean of the batch losses seen so far this epoch
                # (despite the name, this is the relative L1 loss above).
                train_mse = train_mse * step / (step + 1) + loss_f.item() / (step + 1)
                tepoch.set_postfix({'Batch': step + 1, 'Train loss (in progress)': train_mse})

            writer.add_scalar("train_loss/train_loss", train_mse, epoch)

            with torch.no_grad():
                model.eval()
                test_relative_l2 = 0.0
                for step, (input_batch, output_batch) in enumerate(test_loader):
                    input_batch = input_batch.to(device)
                    output_batch = output_batch.to(device)
                    output_pred_batch = model(input_batch, grid_val)

                    # Relative error with exponent p, expressed in percent.
                    loss_f = (torch.mean(abs(output_pred_batch - output_batch) ** p) / torch.mean(abs(output_batch) ** p)) ** (1 / p) * 100
                    test_relative_l2 += loss_f.item()
                # Average over the test batches.
                test_relative_l2 /= len(test_loader)

                writer.add_scalar("val_loss/val_loss", test_relative_l2, epoch)

                # Checkpoint whenever the test error improves.
                if test_relative_l2 < best_model_testing_error:
                    model = model.eval()
                    best_model_testing_error = test_relative_l2
                    best_model = copy.deepcopy(model)
                    torch.save(best_model, folder + "/model.pkl")
                    writer.add_scalar("val_loss/Best Relative Testing Error", best_model_testing_error, epoch)

            tepoch.set_postfix({'Train loss': train_mse, "Relative Val loss": test_relative_l2})

            # Opened with 'w', so the file only ever holds the latest epoch.
            with open(folder + '/errors.txt', 'w') as file:
                file.write("Training Error: " + str(train_mse) + "\n")
                file.write("Best Testing Error: " + str(best_model_testing_error) + "\n")
                file.write("Current Epoch: " + str(epoch) + "\n")
                file.write("Params: " + str(0) + "\n")
            scheduler.step()
            # Learning-rate floor: once the rate falls below 1e-5, set gamma
            # to 1 so later StepLR decays leave it unchanged.
            if scheduler.get_last_lr()[0] < 1e-5:
                scheduler.gamma = 1
finally:
    writer.close()
