import torch
import torch.backends.cudnn
import torch.backends.cuda
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_default_dtype(torch.float64)
import numpy as np
import os
import sys
import phys
import time
from utils import L2_loss, to_pickle, from_pickle
from train_ode import get_args


def get_diff(u_reshaped, dt):
    """Pair consecutive time steps of each trajectory and form their forward difference.

    Args:
        u_reshaped: tensor whose axis 1 indexes time steps and whose last two
            axes hold the per-step state; axis 0 presumably indexes
            trajectories (TODO confirm against callers).
        dt: time increment that divides the forward difference.

    Returns:
        Tuple ``(u1, u2, diff_u)``. Each element is flattened to shape
        ``(-1, *u_reshaped.shape[-2:])`` and detached from autograd:
        ``u1`` holds steps ``0..T-2``, ``u2`` holds steps ``1..T-1`` and
        ``diff_u`` is ``(u2 - u1) / dt``.
    """
    trailing = u_reshaped.shape[-2:]

    def _flatten(chunk):
        # Slicing along axis 1 breaks contiguity, so materialise before viewing.
        return chunk.contiguous().view(-1, *trailing)

    earlier = _flatten(u_reshaped[:, :-1])
    later = _flatten(u_reshaped[:, 1:])
    rate = (later - earlier) / dt
    return tuple(t.detach() for t in (earlier, later, rate))


def _batched_cat(fn, n, batch_size, dim=0, progress_end=None):
    """Evaluate ``fn`` on consecutive index windows and concatenate the results.

    Args:
        fn: callable ``fn(start, stop)`` returning a tensor for the window
            ``[start, stop)``.
        n: number of items to cover; windows start at 0 and advance by
            ``batch_size`` (the last window may be shorter).
        batch_size: window width.
        dim: dimension along which per-window results are concatenated.
        progress_end: when not None, ``print(start, '/', n, end=progress_end)``
            is emitted before each window is evaluated.

    Returns:
        ``torch.cat`` of the per-window results along ``dim``.
    """
    chunks = []
    for start in range(0, n, batch_size):
        if progress_end is not None:
            print(start, '/', n, end=progress_end)
        chunks.append(fn(start, start + batch_size))
    return torch.cat(chunks, dim=dim)


def _load_dataset_module(name):
    """Import ``experiments.data<name>``, falling back to the capitalized name.

    The lowercase module is chosen only when ``experiments/data<name>.py``
    exists relative to the current working directory.
    """
    import importlib
    if os.path.exists('experiments/data{}.py'.format(name)):
        return importlib.import_module('experiments.data{}'.format(name))
    return importlib.import_module('experiments.data{}'.format(name.capitalize()))


def train(args):
    """Train a ``phys.PhysicsModelPDE1d`` on a PDE dataset and evaluate it on the test split.

    Training runs ``args.total_steps`` Adam iterations with cosine-annealed
    learning rate. It fits either the time derivative ``dudt``
    (``args.train_deriv``) or finite differences of consecutive states. When
    ``args.load`` is set, weights and stats are loaded from
    ``args.path_tar`` / ``args.path_pkl`` and training is skipped.

    Side effects:
        * sets the global default dtype to float64 and disables autograd
          globally (``torch.set_grad_enabled(False)``);
        * writes ``args.input_dim`` / ``args.input_width`` (and
          ``args.total_steps = 0`` when loading);
        * saves the state dict to ``args.path_tar`` and pickles ``stats`` to
          ``args.path_pkl`` (after training and again after evaluation).

    Returns:
        ``(model, stats)``. ``stats`` holds the loss curves, the test
        predictions and the ``deriv_error_test``, ``1step_error_test``,
        ``orbit_error_test`` and ``energy_error_test`` MSEs.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    torch.set_default_dtype(torch.float64)
    dtype = torch.get_default_dtype()
    # Autograd is globally off; it is re-enabled only for the training forward pass.
    torch.set_grad_enabled(False)

    # Check the divisor first so batch_div == 0 fails the assertion instead of
    # raising ZeroDivisionError in the modulo below.
    assert args.batch_div > 0
    assert args.batch_size % args.batch_div == 0

    # load data
    dataset = _load_dataset_module(args.name)
    get_dataset, get_energies = dataset.get_dataset, dataset.get_energies
    data = get_dataset(args.name, args.save_dir, verbose=True)

    args.input_dim = data['u'].shape[-2]
    args.input_width = data['u'].shape[-1]

    print('Initializing model and data...')
    # set random seed after data generation
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.norm:
        # Per-channel statistics: move the channel axis (-2) last, flatten the rest.
        u_channels = data['u'].transpose(0, 1, 3, 2).reshape(-1, args.input_dim)
        data_mean = u_channels.mean(axis=0).reshape(args.input_dim, 1)
        data_std = u_channels.std(axis=0).reshape(args.input_dim, 1)
    else:
        data_mean = data_std = None
    model = phys.PhysicsModelPDE1d(args.input_dim, args.hidden_dim,
                                   act=args.act, model=args.model, solver=args.solver,
                                   data_mean=data_mean, data_std=data_std, finde=args.finde)
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.lr, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.total_steps)

    if args.adjoint:
        from torchdiffeq import odeint_adjoint
        model.odeint = odeint_adjoint

    # load model if needed
    if args.load:
        model.load_state_dict(torch.load(args.path_tar, map_location=device))
        stats = from_pickle(args.path_pkl)
        args.total_steps = 0  # skip training, only re-run the evaluation below
        print('Model successfully loaded from {}'.format(args.path_tar), flush=True)
    else:
        stats = {}

    # prepare data
    dt = data['dt']
    n_steps_train = data['meta']['n_steps']
    n_steps_test = data['test']['meta']['n_steps']
    u_train = torch.tensor(data['u'], requires_grad=True, device=device, dtype=dtype)
    u_test = torch.tensor(data['test']['u'], requires_grad=True, device=device, dtype=dtype)
    dudt_train = torch.tensor(data['dudt'], device=device, dtype=dtype)
    dudt_test = torch.tensor(data['test']['dudt'], device=device, dtype=dtype)
    u1_train, u2_train, diff_u_train = get_diff(
        u_train.view(-1, n_steps_train + 1, args.input_dim, args.input_width), dt)
    u1_test, u2_test, diff_u_test = get_diff(
        u_test.view(-1, n_steps_test + 1, args.input_dim, args.input_width), dt)

    print('Entering the training loop...')
    # train loop
    if args.total_steps > 0:
        stats['loss_train'] = []
        stats['loss_test'] = []
    time_clock = time.time()
    for itr in range(1, args.total_steps + 1):
        model.train(True)
        with torch.enable_grad():
            if args.train_deriv:
                # NOTE(review): this samples along u_train's leading axis, i.e. whole
                # trajectories (the orbit evaluation below treats u as
                # (trajectory, time, dim, width)), whereas the discrete branch samples
                # single transitions. Presumably time_derivative accepts the extra
                # time axis -- confirm.
                ixs = torch.randint(0, u_train.shape[0], (args.batch_size,))
                dudt_hat = model.time_derivative(u_train[ixs])
                loss_train = L2_loss(dudt_train[ixs], dudt_hat)
            else:
                ixs = torch.randint(0, u1_train.shape[0], (args.batch_size,))
                diff_u_hat = model.discrete_time_derivative(u1_train[ixs], dt=dt, x2=u2_train[ixs])
                loss_train = L2_loss(diff_u_train[ixs], diff_u_hat)
        optim.zero_grad()
        loss_train.backward()
        optim.step()
        scheduler.step()

        # run test data
        model.train(False)
        if args.train_deriv:
            ixs = torch.randint(0, u_test.shape[0], (args.batch_size,))
            dudt_hat = model.time_derivative(u_test[ixs])
            loss_test = L2_loss(dudt_test[ixs], dudt_hat)
        else:
            ixs = torch.randint(0, u1_test.shape[0], (args.batch_size,))
            diff_u_hat = model.discrete_time_derivative(u1_test[ixs], dt=dt, x2=u2_test[ixs])
            loss_test = L2_loss(diff_u_test[ixs], diff_u_hat)

        # logging
        stats['loss_train'].append(loss_train.item())
        stats['loss_test'].append(loss_test.item())
        if args.verbose and itr % args.log_freq == 0:
            time_clock_new = time.time()
            print("itr {}, time {:.2f}s, loss_train {:.4e}, loss_test {:.4e}, lr {}"
                  .format(itr, time_clock_new - time_clock, loss_train.item(), loss_test.item(), optim.param_groups[0]['lr']))
            time_clock = time_clock_new

    model.train(False)
    torch.save(model.state_dict(), args.path_tar)
    to_pickle(stats, args.path_pkl)

    test_is_implicit = model.solver_eval in phys.SOLVER_LIST_ADDITIONAL_IMPLICIT \
        or (model.finde is not None and model.finde.is_discrete)
    # Implicit evaluation is done one sample at a time.
    batch_size_test = args.batch_size // args.batch_div if not test_is_implicit else 1

    print('Making derivative errors...')
    u_flatten_test = u_test.reshape(-1, args.input_dim, args.input_width)
    dudt_flatten_test = dudt_test.reshape(-1, args.input_dim, args.input_width)
    if test_is_implicit:
        # NOTE(review): no derivative is computed for implicit solvers; the ground
        # truth is substituted, so deriv_error_test is reported as 0.
        dudt_hat_test = dudt_flatten_test
    else:
        dudt_hat_test = _batched_cat(
            lambda start, stop: model.time_derivative(u_flatten_test[start:stop]),
            len(u_flatten_test), batch_size_test)
    stats['dudt_test'] = dudt_flatten_test.cpu().numpy()
    stats['dudt_hat_test'] = dudt_hat_test.cpu().numpy()
    stats['deriv_error_test'] = (stats['dudt_test'] - stats['dudt_hat_test']).__pow__(2).mean().item()
    print("deriv MSE {:.4e}".format(stats['deriv_error_test']))

    print('Making 1-step prediction errors...')
    diff_u_hat_test = _batched_cat(
        lambda start, stop: model.discrete_time_derivative(u1_test[start:stop], dt=dt),
        len(u1_test), batch_size_test, progress_end='\r')
    u2_hat_test = diff_u_hat_test * dt + u1_test
    # u1/u2 have n_steps transitions per trajectory (one fewer than the states).
    traj_shape = (-1, n_steps_test, args.input_dim, args.input_width)
    stats['u1_test'] = u1_test.reshape(*traj_shape).cpu().numpy()
    stats['u2_test'] = u2_test.reshape(*traj_shape).cpu().numpy()
    stats['u2_hat_test'] = u2_hat_test.reshape(*traj_shape).cpu().numpy()
    stats['1step_error_test'] = (stats['u2_test'] - stats['u2_hat_test']).__pow__(2).mean().item()
    print("1-step MSE {:.4e}".format(stats['1step_error_test']))

    print('Making orbit and energy errors...')
    t_eval = torch.from_numpy(data['test']['t_eval']).to(device)
    # Orbits start from each trajectory's first state. Per-batch results are
    # concatenated along dim 1 and then transposed, so the stored array's
    # leading axis indexes test trajectories.
    orbits = _batched_cat(
        lambda start, stop: model.get_orbit(u_test[start:stop, 0], t_eval=t_eval),
        u_test.shape[0], batch_size_test, dim=1, progress_end='\n')
    stats['orbit_model_test'] = orbits.transpose(1, 0).cpu().numpy()
    stats['energies_model_test'] = get_energies(stats['orbit_model_test'])

    stats['orbit_data_train'] = data['u']
    stats['orbit_data_test'] = data['test']['u']
    stats['energies_data_test'] = data['test']['energies']

    # Index 0 along time is the shared initial state, so it is excluded.
    stats['orbit_error_test'] = (stats['orbit_data_test'][:, 1:] - stats['orbit_model_test'][:, 1:]).__pow__(2).mean().item()
    stats['energy_error_test'] = (stats['energies_data_test']['energy'][:, 1:] - stats['energies_model_test']['energy'][:, 1:]).__pow__(2).mean().item()

    print("state MSE {:.4e}".format(stats['orbit_error_test']))
    print("energy MSE {:.4e}".format(stats['energy_error_test']))

    to_pickle(stats, args.path_pkl)

    return model, stats


if __name__ == "__main__":
    args = get_args()
    # save
    os.makedirs('{}/{}'.format(args.save_dir, args.result_dir), exist_ok=True)

    # The result-file label encodes the run configuration.
    label = '{}-{}-{}'.format(args.name, args.model, args.solver)
    if args.postfix:
        label += '-{}'.format(args.postfix)
    if args.norm:
        label += '-norm'
    if args.finde:
        label += '-finde,{}{},{},{}'.format(args.finde.variant, 'HNN' if args.finde.hnn else '',
                                            args.finde.num, args.finde.keeprate)
    label += '-seed{}'.format(args.seed)
    # NOTE: 'derivartive' is misspelled but is part of existing result file names
    # (and of the already-done check below), so it is intentionally kept.
    if args.train_deriv:
        label = 'derivartive-' + label
    result_path = '{}/{}/phys-{}'.format(args.save_dir, args.result_dir, label)
    args.path_tar = '{}.tar'.format(result_path)
    args.path_pkl = '{}.pkl'.format(result_path)
    args.path_txt = '{}.txt'.format(result_path)

    # The .txt summary is written last, so its presence marks a finished run.
    if os.path.exists(args.path_txt) and args.noretry:
        print(args.path_txt)
        print('====== already done:', ' '.join(sys.argv), flush=True)
        sys.exit()

    print('====== not yet:', ' '.join(sys.argv), flush=True)
    print('====== entering:', ' '.join(sys.argv))
    model, stats = train(args)
    keys = ['deriv_error_test', '1step_error_test', 'energy_error_test', 'orbit_error_test', ]
    with open(args.path_txt, 'w', encoding='utf-8') as of:
        print('#', *keys, sep='\t', file=of)
        print(*[stats[k] for k in keys], sep='\t', file=of)

    print('====== ended:', ' '.join(sys.argv), flush=True)
