import os
import argparse

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import trange

from model.configs.train_defaults import get_cfg_defaults
from utils.eval_utils import show_compared_params
from model.pdexplain_model import PDExplain
from model.dataset import PDEDataset
import utils


def eval_model(config, model, dataloader, epoch, show_fig, path, device):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards, even on error.

    Args:
        config: Run configuration. Not used here; kept for call compatibility.
        model: Module called as ``model(t, f_data[:, 1], context)`` that also
            exposes ``loss_func``.
        dataloader: Yields batches indexed as ``(f_data, t, gt_params, context)``.
            Its ``dataset`` must expose ``x``, ``delta_t`` and ``delta_x``.
        epoch: Label interpolated into the figure title.
        show_fig: When true, plot predicted vs. ground-truth parameters for one
            sample of every batch.
        path: Output directory. Not used here; kept for call compatibility.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(test_pde_loss, ae_recon_loss, params_error)`` computed on the
        LAST batch yielded by ``dataloader``. Metrics are not aggregated
        across batches, so pass a loader that yields the whole split as a
        single batch to get split-level numbers.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Index (along the first axis of the predictions) of the sample to plot.
    fig_sample_idx = 80

    dataset = dataloader.dataset
    # The spatial grid does not depend on the batch, so build it once.
    x_data = torch.FloatTensor(dataset.x).to(device)

    metrics = None
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                f_data = batch[0].to(device)
                t = batch[1].to(device)
                gt_params = batch[2].to(device)
                # Append -f_data[:, 1] as one extra ground-truth channel on dim 2.
                gt_params = torch.cat((gt_params, -1.0 * f_data[:, 1].unsqueeze(2)), dim=2)
                context = batch[3].to(device)

                recon_context, pred_params = model(t, f_data[:, 1], context)
                test_pde_loss = model.loss_func(f_data, x_data, dataset.delta_t, dataset.delta_x, pred_params)
                ae_recon_loss = calc_ae_recon_loss(context, recon_context)
                params_error = ((gt_params - pred_params) ** 2).mean()
                metrics = (test_pde_loss, ae_recon_loss, params_error)

                if show_fig:
                    # Clamp so batches with fewer than 81 samples still plot
                    # (their last sample) instead of raising IndexError.
                    sample_idx = min(fig_sample_idx, pred_params.shape[0] - 1)
                    show_compared_params(pred_params[sample_idx].cpu().numpy(),
                                         gt_params[sample_idx].cpu().numpy(),
                                         title=f'Epoch {epoch}', show=True)
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    if metrics is None:
        raise ValueError('dataloader yielded no batches; nothing to evaluate')
    return metrics


def calc_ae_recon_loss(context, pred_context, loss_type='l2'):
    """Return the mean reconstruction error between ``context`` and ``pred_context``.

    Args:
        context: Reference values.
        pred_context: Reconstructed values, subtracted element-wise from ``context``.
        loss_type: ``'l1'`` for mean absolute error, ``'l2'`` (default) for
            mean squared error.

    Returns:
        The mean of the element-wise absolute (``'l1'``) or squared (``'l2'``)
        difference.

    Raises:
        ValueError: If ``loss_type`` is neither ``'l1'`` nor ``'l2'``.
    """
    # Reject unknown loss types before doing any arithmetic.
    if loss_type not in ('l1', 'l2'):
        raise ValueError(f'loss_type can be only l1 / l2, but got {loss_type}')

    residual = context - pred_context
    per_element = torch.abs(residual) if loss_type == 'l1' else residual ** 2
    return per_element.mean()


def _build_dataloaders(config):
    """Create the train/val/test datasets and dataloaders described by ``config``.

    All three splits are read from the same files under ``config.data.path``.
    The val and test loaders yield their whole split as a single batch.

    Returns:
        Tuple ``(train_dataset, train_dataloader, val_dataloader, test_dataloader)``.
    """
    data_dir = config.data.path
    # Positional arguments shared by every PDEDataset split.
    dataset_files = (
        os.path.join(data_dir, 'sol_dataset.pkl'),
        os.path.join(data_dir, 'parameters_dataset.pkl'),
        os.path.join(data_dir, 'x_dataset.pkl'),
        os.path.join(data_dir, 't_dataset.pkl'),
        os.path.join(data_dir, 'config.yaml'),
    )

    train_dataset = PDEDataset(*dataset_files, mode='train',
                               t_len_pct=config.data.t_len_pct, data_size=config.data.size, noise_coeff=config.data.noise)
    train_dataloader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True,
                                  num_workers=config.system.num_workers)

    val_dataset = PDEDataset(*dataset_files,
                             mode='val', t_len_pct=config.data.t_len_pct, noise_coeff=config.data.noise)
    val_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)

    test_dataset = PDEDataset(*dataset_files,
                              mode='test', t_len_pct=config.data.t_len_pct, noise_coeff=config.data.noise)
    test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

    return train_dataset, train_dataloader, val_dataloader, test_dataloader


def _train_one_epoch(config, model, optimizer, train_dataloader, x_data, device):
    """Run one optimisation pass over ``train_dataloader``.

    The optimised loss is
    ``pde_loss_coeff * pde_loss + ae_loss_coeff * ae_recon_loss``.
    The ground-truth parameters are used for monitoring only.

    Returns:
        Mean over batches of the squared error between the ground-truth and
        predicted parameters.
    """
    train_dataset = train_dataloader.dataset
    # Be explicit about the mode in case an evaluation pass changed it.
    model.train()

    batch_params_errors = []
    for batch in train_dataloader:
        f_data = batch[0].to(device)
        t = batch[1].to(device)
        gt_params = batch[2].to(device)
        context = batch[3].to(device)

        # Append -f_data[:, 1] as one extra ground-truth channel on dim 2.
        gt_params = torch.cat((gt_params, -1.0 * f_data[:, 1].unsqueeze(2)), dim=2)

        recon_context, pred_params = model(t, f_data[:, 1], context)

        batch_params_error = ((gt_params - pred_params) ** 2).mean()
        batch_params_errors.append(batch_params_error.item())

        pde_loss = model.loss_func(f_data, x_data, train_dataset.delta_t, train_dataset.delta_x, pred_params)
        ae_recon_loss = calc_ae_recon_loss(context, recon_context)

        loss = config.train.pde_loss_coeff * pde_loss + config.train.ae_loss_coeff * ae_recon_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return np.mean(batch_params_errors)


def train(config):
    """Train a PDExplain model and save the best checkpoint and the config.

    Steps:
      1. Pick the device, seed the RNGs and create the results directory
         (with a timestamp sub-directory when ``config.create_timestamp`` is set).
      2. Build the train/val/test dataloaders.
      3. Evaluate the untrained model once as a baseline.
      4. For every epoch: train, report metrics on the test split, and keep a
         copy of the weights whenever the validation PDE loss improves.
      5. Report the best model on the test split, then write
         ``model_checkpoint.pkl`` and ``config.yaml`` to the results directory.

    Args:
        config: Configuration node. The fields read here are ``system.cpu``,
            ``system.seed``, ``system.num_workers``, ``results_path``,
            ``create_timestamp``, ``data.*``, ``train.*``, ``pde_type`` and
            ``show_fig``.
    """
    use_cuda = torch.cuda.is_available() and not config.system.cpu
    device = torch.device('cuda' if use_cuda else 'cpu')
    utils.utils.set_seed(config.system.seed)
    path = config.results_path
    if config.create_timestamp:
        path = os.path.join(path, utils.utils.get_date_time_str())

    os.makedirs(path, exist_ok=True)

    # ----------------------------------------------------------------
    # ---------------------- Create datasets -------------------------
    # ----------------------------------------------------------------
    train_dataset, train_dataloader, val_dataloader, test_dataloader = _build_dataloaders(config)

    # ----------------------------------------------------------------
    # --------------- Create model and optimizer ---------------------
    # ----------------------------------------------------------------
    input_dim = train_dataset.context_shape
    model = PDExplain(input_dim, train_dataset.x.shape[0], config.pde_type).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)

    # Evaluate pretrained model to have a baseline
    # NOTE(review): this baseline runs on the val split although the message
    # says "Test error" - confirm which one is intended.
    test_pde_loss, ae_recon_loss, params_error = eval_model(config, model, val_dataloader, 'Pretraining', config.show_fig, path, device)
    print('Test error pretraining. PDE loss = %.6f, ae loss = %.6f, Params error = %.3f' % (test_pde_loss, ae_recon_loss, params_error))

    # Create a model for saving the best model found (on val set)
    best_model = PDExplain(input_dim, train_dataset.x.shape[0], config.pde_type).to(device)
    best_loss = np.inf

    # The spatial grid is the same for every batch, so move it to the device once.
    x_data = torch.FloatTensor(train_dataset.x).to(device)

    # ----------------------------------------------------------------
    # --------------------- Start training ---------------------------
    # ----------------------------------------------------------------
    for epoch in trange(config.train.num_epochs):
        train_params_error = _train_one_epoch(config, model, optimizer, train_dataloader, x_data, device)

        test_pde_loss, ae_recon_loss, params_error = eval_model(config, model, test_dataloader, epoch + 1, config.show_fig, path, device)
        print('Test error on epoch [%d/%d] PDE loss = %.6f, ae loss = %.6f, test params error = %.3f, train params error = %.3f' % (
            epoch + 1, config.train.num_epochs, test_pde_loss, ae_recon_loss, params_error, train_params_error))

        # Model selection uses the held-out validation split, not the training
        # loss. Figures are suppressed so each epoch is plotted only once.
        val_pde_loss, _, _ = eval_model(config, model, val_dataloader, epoch + 1, False, path, device)
        val_pde_loss = float(val_pde_loss)
        if val_pde_loss < best_loss:
            best_loss = val_pde_loss
            best_model.load_state_dict(model.state_dict())
            print('Found a new best model')

    test_pde_loss, ae_recon_loss, params_error = eval_model(config, best_model, test_dataloader, 'Final', False, path, device)
    print('Best model test error: PDE loss = %.6f, ae loss = %.6f, test params error = %.3f' % (
        test_pde_loss, ae_recon_loss, params_error))

    torch.save(best_model.state_dict(), os.path.join(path, 'model_checkpoint.pkl'))
    utils.utils.save_config(config, os.path.join(path, 'config.yaml'))
    print(f'Train complete. Model and outputs saved in {path}')


if __name__ == '__main__':
    # Command-line entry point: read the optional config file and the list of
    # config overrides, build the configuration from them, then train.
    cli = argparse.ArgumentParser(description="parse args")
    cli.add_argument('--config-file', type=str, default=None)
    cli.add_argument('--config-list', nargs="+", default=None)
    cli_args = cli.parse_args()

    train(get_cfg_defaults(cli_args.config_file, cli_args.config_list))
