import os
import argparse

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import trange
import pde

from model.configs.train_defaults import get_cfg_defaults
from model.rhs_model import RHSModel
from model.dataset import PDEDataset
import utils


def solve_pde(initial_conditions, delta_t, delta_x, t_len, x_len, model, context):
    """Integrate the model-driven PDE forward in time and return the recorded states.

    Args:
        initial_conditions: Values used as the data of the initial scalar field.
            The first and last entries are also used as the boundary values
            handed to ``PDE``.
        delta_t: Interval passed to the storage tracker; the solver itself is
            stepped with half of this value.
        delta_x: Spatial resolution; the number of grid cells is
            ``x_len // delta_x``.
        t_len: Passed to the solver as ``t_range``.
        x_len: Upper bound of the one-dimensional domain ``[0, x_len]``.
        model: Forwarded unchanged to ``PDE``.
        context: Forwarded unchanged to ``PDE``.

    Returns:
        ``np.array`` built from everything the in-memory storage recorded.
    """
    # NOTE(review): floor division on floats can under-count cells
    # (e.g. 1.0 // 0.1 == 9.0) — confirm callers pass values that divide cleanly.
    num_cells = x_len // delta_x
    domain = pde.CartesianGrid([[0.0, x_len]], num_cells)
    initial_state = pde.ScalarField(domain, data=initial_conditions)

    left_edge = {'value': initial_conditions[0]}
    right_edge = {'value': initial_conditions[-1]}
    equation = PDE([left_edge, right_edge], model, context)

    recorder = pde.MemoryStorage()
    solver_step = delta_t / 2.0
    snapshot_tracker = recorder.tracker(delta_t)
    equation.solve(initial_state, t_range=t_len, dt=solver_step, tracker=snapshot_tracker)

    return np.array(recorder.data)


class PDE(pde.PDEBase):
    """py-pde equation whose time derivative is produced by a learned model.

    Args:
        bc: Boundary-condition description. It is stored on the instance but
            is not read anywhere in this class.
        model: Callable invoked as ``model(t, state, context)``; it must return
            a pair whose second element is the predicted time derivative.
        context: Conditioning data for the model. It is converted to a float
            tensor and given a leading dimension of size one.
    """

    def __init__(self, bc, model, context):
        super().__init__()
        self.bc = bc
        self.model = model
        self.context = torch.FloatTensor(context).unsqueeze(0)

    def evolution_rate(self, state, t=0):
        """Evaluate the model at ``state`` and time ``t``.

        Args:
            state: Current field; its ``data`` is fed to the model with a
                leading dimension of size one.
            t: Current time, passed to the model as a ``(1, 1)`` float tensor.

        Returns:
            A ``ScalarField`` or ``VectorField`` on ``state.grid`` when
            ``state`` is one of those types; otherwise the raw NumPy array.
        """
        model_t = torch.FloatTensor([t]).unsqueeze(0)
        state_tensor = torch.FloatTensor(state.data).unsqueeze(0)

        # The solver only needs values, never gradients. Without no_grad the
        # output of a model with trainable parameters requires grad, and
        # calling .numpy() on such a tensor raises a RuntimeError.
        with torch.no_grad():
            _, dx_dt = self.model(model_t, state_tensor, self.context)
        dx_dt = dx_dt.squeeze().detach().cpu().numpy()

        if isinstance(state, pde.ScalarField):
            dx_dt = pde.ScalarField(state.grid, data=dx_dt)
        elif isinstance(state, pde.VectorField):
            dx_dt = pde.VectorField(state.grid, data=dx_dt)
        return dx_dt


def eval_model(config, model, dataloader, epoch, show_fig, path, device):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    The model is switched to evaluation mode for the duration of the call and
    its previous training/eval mode is restored afterwards.

    Args:
        config: Accepted for interface compatibility; not used here.
        model: Model called as ``model(t, state, context)``; it must also
            provide ``loss_func``.
        dataloader: Yields batches indexed as ``batch[0]`` (solution data),
            ``batch[1]`` (time) and ``batch[3]`` (context). Its dataset must
            expose ``delta_t``.
        epoch: Accepted for interface compatibility; not used here.
        show_fig: Accepted for interface compatibility; not used here.
        path: Accepted for interface compatibility; not used here.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(pde_loss, ae_recon_loss)``: each is the unweighted mean of
        the per-batch losses. With a single batch this equals that batch's
        loss. Assumes ``model.loss_func`` returns a scalar tensor — TODO confirm.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    pde_losses = []
    ae_losses = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                f_data = batch[0].to(device)
                t = batch[1].to(device)
                context = batch[3].to(device)

                recon_context, pred_rhs = model(t, f_data[:, 1], context)
                pde_losses.append(model.loss_func(f_data, dataloader.dataset.delta_t, pred_rhs))
                ae_losses.append(calc_ae_recon_loss(context, recon_context))
    finally:
        # Restore whatever mode the caller had, even if evaluation failed.
        model.train(was_training)

    if not pde_losses:
        raise ValueError('eval_model received a dataloader that yields no batches')

    # Aggregate over all batches instead of reporting only the last one.
    test_pde_loss = torch.stack(pde_losses).mean()
    ae_recon_loss = torch.stack(ae_losses).mean()
    return test_pde_loss, ae_recon_loss


def calc_ae_recon_loss(context, pred_context, loss_type='l2'):
    """Return the mean reconstruction error between ``context`` and ``pred_context``.

    Args:
        context: Reference values.
        pred_context: Reconstructed values, subtracted from ``context``.
        loss_type: ``'l1'`` for mean absolute error, ``'l2'`` (default) for
            mean squared error.

    Returns:
        The mean of the element-wise absolute or squared differences.

    Raises:
        ValueError: If ``loss_type`` is neither ``'l1'`` nor ``'l2'``.
    """
    # Reject unsupported modes before doing any arithmetic.
    if loss_type not in ('l1', 'l2'):
        raise ValueError(f'loss_type can be only l1 / l2, but got {loss_type}')

    residual = context - pred_context
    if loss_type == 'l1':
        return torch.abs(residual).mean()
    return (residual ** 2).mean()


def _select_device(config):
    """Return the CPU device when CUDA is unavailable or ``config.system.cpu`` is set, else CUDA."""
    if (not torch.cuda.is_available()) or config.system.cpu:
        return torch.device('cpu')
    return torch.device('cuda')


def _build_dataloaders(config):
    """Build the train dataset and the train/val/test dataloaders from ``config.data.path``."""
    data_dir = config.data.path
    # Positional order expected by PDEDataset: solutions, parameters, x, t, data config.
    dataset_files = (
        os.path.join(data_dir, 'sol_dataset.pkl'),
        os.path.join(data_dir, 'parameters_dataset.pkl'),
        os.path.join(data_dir, 'x_dataset.pkl'),
        os.path.join(data_dir, 't_dataset.pkl'),
        os.path.join(data_dir, 'config.yaml'),
    )

    train_dataset = PDEDataset(*dataset_files, mode='train',
                               t_len_pct=config.data.t_len_pct, data_size=config.data.size)
    train_dataloader = DataLoader(train_dataset, batch_size=config.train.batch_size, shuffle=True,
                                  num_workers=config.system.num_workers)

    # Validation and test sets are each served as one full-size batch.
    val_dataset = PDEDataset(*dataset_files, mode='val', t_len_pct=config.data.t_len_pct)
    val_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)

    test_dataset = PDEDataset(*dataset_files, mode='test', t_len_pct=config.data.t_len_pct)
    test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

    return train_dataset, train_dataloader, val_dataloader, test_dataloader


def _train_one_epoch(config, model, optimizer, dataloader, device):
    """Run one optimisation pass over ``dataloader``, updating ``model`` in place."""
    for batch in dataloader:
        f_data = batch[0].to(device)
        t = batch[1].to(device)
        context = batch[3].to(device)

        recon_context, pred_rhs = model(t, f_data[:, 1], context)

        pde_loss = model.loss_func(f_data, dataloader.dataset.delta_t, pred_rhs)
        ae_recon_loss = calc_ae_recon_loss(context, recon_context)

        # Weighted sum of the PDE loss and the context reconstruction loss.
        loss = config.train.pde_loss_coeff * pde_loss + config.train.ae_loss_coeff * ae_recon_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train(config):
    """Train an ``RHSModel`` on the dataset in ``config.data.path`` and save the result.

    Reads from ``config``: ``system.cpu``, ``system.seed``, ``system.num_workers``,
    ``results_path``, ``create_timestamp``, ``data.path``, ``data.t_len_pct``,
    ``data.size``, ``train.batch_size``, ``train.lr``, ``train.num_epochs``,
    ``train.pde_loss_coeff``, ``train.ae_loss_coeff`` and ``show_fig``.

    Side effects: creates the results directory, prints evaluation losses
    before training and after every epoch, and writes ``model_checkpoint.pkl``
    and ``config.yaml`` into the results directory.

    Args:
        config: Configuration object exposing the attributes listed above.
    """
    device = _select_device(config)
    utils.utils.set_seed(config.system.seed)
    path = config.results_path
    if config.create_timestamp:
        path = os.path.join(path, utils.utils.get_date_time_str())

    os.makedirs(path, exist_ok=True)

    # ----------------------------------------------------------------
    # ---------------------- Create datasets -------------------------
    # ----------------------------------------------------------------
    train_dataset, train_dataloader, val_dataloader, test_dataloader = _build_dataloaders(config)

    # ----------------------------------------------------------------
    # --------------- Create model and optimizer ---------------------
    # ----------------------------------------------------------------
    input_dim = train_dataset.context_shape
    model = RHSModel(input_dim, train_dataset.x.shape[0]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)

    # Evaluate the untrained model to have a baseline.
    # NOTE(review): this baseline uses the validation loader while the
    # per-epoch evaluation below uses the test loader; both are printed as
    # "Test error" — confirm which split is intended.
    test_pde_loss, ae_recon_loss = eval_model(config, model, val_dataloader, 'Pretraining', False, path, device)
    print(f'Test error pretraining. PDE loss = {float(test_pde_loss):.6f}, '
          f'ae loss = {float(ae_recon_loss):.6f}')

    # ----------------------------------------------------------------
    # --------------------- Start training ---------------------------
    # ----------------------------------------------------------------
    for epoch in trange(config.train.num_epochs):
        _train_one_epoch(config, model, optimizer, train_dataloader, device)

        test_pde_loss, ae_recon_loss = eval_model(config, model, test_dataloader, epoch + 1, config.show_fig, path, device)
        print(f'Test error on epoch [{epoch + 1}/{config.train.num_epochs}] '
              f'PDE loss = {float(test_pde_loss):.6f}, ae loss = {float(ae_recon_loss):.6f}')

    torch.save(model.state_dict(), os.path.join(path, 'model_checkpoint.pkl'))
    utils.utils.save_config(config, os.path.join(path, 'config.yaml'))
    print(f'Train complete. Model and outputs saved in {path}')


if __name__ == '__main__':
    # Script entry point: build the configuration from an optional config file
    # and an optional list of override tokens, then start training.
    cli = argparse.ArgumentParser(description="parse args")
    cli.add_argument('--config-file', type=str, default=None)
    cli.add_argument('--config-list', nargs="+", default=None)
    parsed = cli.parse_args()

    train(get_cfg_defaults(parsed.config_file, parsed.config_list))
