import torch
import torch.backends.cudnn
import torch.backends.cuda
# Global numeric settings, applied as soon as this module is imported:
# - cuDNN may pick non-deterministic kernels and autotunes them (benchmark=True) for speed;
#   runs are therefore not bit-for-bit reproducible on GPU even with a fixed seed.
# - TF32 is disabled for both matmul and cuDNN so float32 math keeps full precision.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Tensors created without an explicit dtype (and train()'s `dtype`) are float32.
torch.set_default_dtype(torch.float32)
import argparse
import numpy as np
import os
import sys
import phys
import time
from utils import L2_loss, to_pickle, from_pickle, DotDict


def get_args():
    """Parse the command-line options for one training run.

    The four ``--finde*`` options are folded into a single ``args.finde``
    attribute: ``None`` when ``--finde`` is not given, otherwise a ``DotDict``
    with the keys ``variant``, ``num``, ``keeprate`` and ``hnn``. The
    individual ``finde_*`` attributes are removed from the namespace.

    Returns:
        argparse.Namespace with the parsed options (plus ``feature=True``).
    """
    # (flag, add_argument keyword options), in the order they appear in --help.
    option_specs = (
        # run control
        ('--noretry', dict(dest='noretry', action='store_true', help='not do a finished trial.')),
        ('--load', dict(dest='load', action='store_true', help='load existing data.')),
        # network, experiments
        ('--hidden_dim', dict(default=200, type=int, help='number of hidden units.')),
        ('--lr', dict(default=1e-3, type=float, help='initial learning rate.')),
        ('--batch_size', dict(default=200, type=int, help='batch size.')),
        ('--act', dict(default='tanh', type=str, help='activation function.')),
        ('--total_steps', dict(default=100000, type=int, help='number of iterations.')),
        ('--train_deriv', dict(dest='train_deriv', action='store_true', help='training using derivative.')),
        ('--adjoint', dict(dest='adjoint', action='store_true', help='use adjoint method for odeint.')),
        # display
        ('--log_freq', dict(default=200, type=int, help='number of steps between prints.')),
        ('--verbose', dict(dest='verbose', action='store_true', help='verbose?.')),
        ('--name', dict(default='2body', type=str, help='dataset name; 2body, 2pend, ...')),
        ('--seed', dict(default=0, type=int, help='random seed.')),
        ('--save_dir', dict(default='./experiments/', type=str, help='where to save the trained model.')),
        ('--result_dir', dict(default='results', type=str, help='where to save the results.')),
        # model
        ('--model', dict(default='hnn', type=str, help='base model.')),
        ('--solver', dict(default='dopri5', type=str, help='numerical integrator.')),
        ('--norm', dict(dest='norm', action='store_true', help='data normalization at the first layer.')),
        ('--postfix', dict(default=None, type=str, help='postfix for saved files.')),
        ('--tag', dict(default=None, type=str, help='meaningless tag.')),
        # FINDE
        ('--finde', dict(default=None, type=str, help='use FINDE.')),
        ('--finde_num', dict(default=0, type=int, help='number of first integrals.')),
        ('--finde_keeprate', dict(default=1.0, type=float, help='keep rate for first integrals during training.')),
        ('--finde_hnn', dict(default=False, action='store_true', help='use the Hamiltonian as a first integral.')),
    )

    parser = argparse.ArgumentParser(description=None)
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    parser.set_defaults(feature=True)
    namespace = parser.parse_args()

    # Gather the FINDE settings before stripping the individual attributes.
    variant = namespace.finde
    finde_fields = {key: getattr(namespace, 'finde_' + key) for key in ('num', 'keeprate', 'hnn')}
    for attr in ('finde', 'finde_num', 'finde_keeprate', 'finde_hnn'):
        delattr(namespace, attr)

    if variant is None:
        namespace.finde = None
    else:
        namespace.finde = DotDict(variant=variant, **finde_fields)
    return namespace


def get_diff(u_reshaped, dt):
    """Pair each state with its successor and compute the forward difference.

    Args:
        u_reshaped: tensor whose axis 1 is time and whose last axis holds the
            state features (the caller views it as ``(-1, n_steps + 1, dim)``).
        dt: time step between consecutive samples.

    Returns:
        ``(u1, u2, diff_u)``, each flattened to ``(-1, dim)`` and detached from
        autograd. ``u1`` holds every time index but the last, ``u2`` every time
        index but the first, and ``diff_u = (u2 - u1) / dt``.
    """
    n_features = u_reshaped.shape[-1]
    before = u_reshaped[:, :-1].contiguous().view(-1, n_features)
    after = u_reshaped[:, 1:].contiguous().view(-1, n_features)
    rate = (after - before) / dt
    return tuple(t.detach() for t in (before, after, rate))


def train(args):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = torch.get_default_dtype()
    torch.set_grad_enabled(False)

    # load data
    import importlib
    if os.path.exists('experiments/data{}.py'.format(args.name)):
        dataset = importlib.import_module('experiments.data{}'.format(args.name))
    else:
        dataset = importlib.import_module('experiments.data{}'.format(args.name.capitalize()))
    get_dataset, get_energies=dataset.get_dataset, dataset.get_energies
    data = get_dataset(args.name, args.save_dir, verbose=True)

    args.input_dim = data['u'].shape[-1]

    print('Initializing model and data...')
    # set random seed after data generation
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    data_mean = data['u'].reshape(-1, args.input_dim).mean(axis=0) if args.norm else None
    data_std = data['u'].reshape(-1, args.input_dim).std(axis=0) if args.norm else None
    model = phys.PhysicsModel(args.input_dim, args.hidden_dim,
                              act=args.act, model=args.model, solver=args.solver,
                              data_mean=data_mean, data_std=data_std, finde=args.finde)
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.lr, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.total_steps)

    # load model if needed
    if args.load:
        model.load_state_dict(torch.load(args.path_tar, map_location=device))
        stats = from_pickle(args.path_pkl)
        args.total_steps = 0
        print('Model successfully loaded from {}'.format(args.path_tar))
    else:
        stats = {}

    # prepate data
    dt = data['dt']
    u_train = torch.tensor(data['u'], requires_grad=True, device=device, dtype=dtype)
    u_test = torch.tensor(data['test']['u'], requires_grad=True, device=device, dtype=dtype)
    dudt_train = torch.tensor(data['dudt'], device=device, dtype=dtype)
    dudt_test = torch.tensor(data['test']['dudt'], device=device, dtype=dtype)
    u1_train, u2_train, diff_u_train = get_diff(u_train.view(-1, data['meta']['n_steps'] + 1, args.input_dim), dt)
    u1_test, u2_test, diff_u_test = get_diff(u_test.view(-1, data['test']['meta']['n_steps'] + 1, args.input_dim), dt)

    print('Entering the training loop...')
    # train loop
    if args.total_steps > 0:
        stats['loss_train'] = []
        stats['loss_test'] = []
    time_clock = time.time()
    for itr in range(1, args.total_steps + 1):
        model.train(True)
        with torch.enable_grad():
            # train itr
            if args.train_deriv:
                ixs = torch.randint(0, u_train.shape[0], (args.batch_size,))
                dudt_hat = model.time_derivative(u_train[ixs])
                loss_train = L2_loss(dudt_train[ixs], dudt_hat)
            else:
                ixs = torch.randint(0, u1_train.shape[0], (args.batch_size,))
                diff_u_hat = model.discrete_time_derivative(u1_train[ixs], dt=dt, x2=u2_train[ixs])
                loss_train = L2_loss(diff_u_train[ixs], diff_u_hat)
        optim.zero_grad()
        loss_train.backward()
        optim.step()
        scheduler.step()

        # run test data
        model.train(False)
        if args.train_deriv:
            ixs = torch.randint(0, u_test.shape[0], (args.batch_size,))
            dudt_hat = model.time_derivative(u_test[ixs])
            loss_test = L2_loss(dudt_test[ixs], dudt_hat)
        else:
            ixs = torch.randint(0, u1_test.shape[0], (args.batch_size,))
            diff_u_hat = model.discrete_time_derivative(u1_test[ixs], dt=dt, x2=u2_test[ixs])
            loss_test = L2_loss(diff_u_test[ixs], diff_u_hat)

        # logging
        stats['loss_train'].append(loss_train.item())
        stats['loss_test'].append(loss_test.item())
        if args.verbose and itr % args.log_freq == 0:
            time_clock_new = time.time()
            print("itr {}, time {:.2f}s, loss_train {:.4e}, loss_test {:.4e}, lr {}"
                  .format(itr, time_clock_new - time_clock, loss_train.item(), loss_test.item(), optim.param_groups[0]['lr']))
            time_clock = time_clock_new

    model.train(False)
    torch.save(model.state_dict(), args.path_tar)

    test_is_implicit = model.solver_eval in phys.SOLVER_LIST_ADDITIONAL_IMPLICIT \
        or (model.finde is not None and model.finde.is_discrete)
    batch_size_test = args.batch_size if not test_is_implicit else 1

    print('Making first integrals errors...')
    if model.finde is not None and model.finde.quantities is not None:
        u_flatten_test = u_test.reshape(-1, args.input_dim)
        v_test = torch.cat([model.finde.quantities(u_flatten_test[idx:idx + args.batch_size])
                                for idx in range(0, len(u_flatten_test), args.batch_size)], dim=0).reshape(*u_test.shape[:-1],-1)
        stats['v_test'] = v_test.cpu().numpy()
        u_flatten_train = u_train.reshape(-1, args.input_dim)
        v_train = torch.cat([model.finde.quantities(u_flatten_train[idx:idx + args.batch_size])
                                for idx in range(0, len(u_flatten_train), args.batch_size)], dim=0).reshape(*u_train.shape[:-1],-1)
        stats['u_train'] = u_train.cpu().numpy()
        stats['v_train'] = v_train.cpu().numpy()

    print('Making derivative errors...')
    u_flatten_test = u_test.reshape(-1, args.input_dim)
    dudt_flatten_test = dudt_test.reshape(-1, args.input_dim)
    dudt_hat_test = torch.cat([model.time_derivative(u_flatten_test[idx:idx + args.batch_size])
                               for idx in range(0, len(u_flatten_test), args.batch_size)], dim=0) if not test_is_implicit else dudt_flatten_test
    stats['dudt_test'] = dudt_flatten_test.cpu().numpy()
    stats['dudt_hat_test'] = dudt_hat_test.cpu().numpy()
    stats['deriv_error_test'] = (stats['dudt_test'] - stats['dudt_hat_test']).__pow__(2).mean().item()
    print("deriv MSE {:.4e}".format(stats['deriv_error_test']))

    print('Making 1-step prediction errors...')
    dudt_hat = torch.cat([model.discrete_time_derivative(u1_test[idx:idx + batch_size_test], dt=dt)
                          for idx in range(0, len(u1_test), batch_size_test) if print(idx,'/',len(u1_test),end='\r') is None], dim=0)
    u2_hat_test = dudt_hat * dt + u1_test
    stats['u1_test'] = u1_test.reshape(-1, data['test']['meta']['n_steps'], args.input_dim).cpu().numpy()
    stats['u2_test'] = u2_test.reshape(-1, data['test']['meta']['n_steps'], args.input_dim).cpu().numpy()
    stats['u2_hat_test'] = u2_hat_test.reshape(-1, data['test']['meta']['n_steps'], args.input_dim).cpu().numpy()
    stats['1step_error_test'] = (stats['u2_test'] - stats['u2_hat_test']).__pow__(2).mean().item()
    print("1-step MSE {:.4e}".format(stats['1step_error_test']))

    print('Making orbit and energy errors...')
    t_eval = torch.from_numpy(data['test']['t_eval']).to(device)
    stats['orbit_model_test'] = torch.concat([model.get_orbit(u_test[idx:idx + batch_size_test, 0], t_eval=t_eval)
                                              for idx in range(0, u_test.shape[0], batch_size_test)  if print(idx,'/',len(u_test),end='\r') is None], dim=1).transpose(1, 0).cpu().numpy()
    stats['energies_model_test'] = get_energies(stats['orbit_model_test'])

    stats['orbit_data_train'] = data['u']
    stats['orbit_data_test'] = data['test']['u']
    stats['energies_data_test'] = data['test']['energies']

    stats['orbit_error_test'] = (stats['orbit_data_test'][:, 1:] - stats['orbit_model_test'][:, 1:]).__pow__(2).mean().item()
    stats['energy_error_test'] = (stats['energies_data_test']['energy'][:, 1:] - stats['energies_model_test']['energy'][:, 1:]).__pow__(2).mean().item()

    print("state MSE {:.4e}".format(stats['orbit_error_test']))
    print("energy MSE {:.4e}".format(stats['energy_error_test']))

    to_pickle(stats, args.path_pkl)

    return model, stats


if __name__ == "__main__":
    args = get_args()
    # save: make sure the results directory exists
    result_root = '{}/{}'.format(args.save_dir, args.result_dir)
    os.makedirs(result_root, exist_ok=True)

    # The run label encodes the settings that distinguish runs and names every output file.
    label = args.name
    label = label + '-{}'.format(args.model)
    label = label + '-{}'.format(args.solver)
    label = label + '-{}'.format(args.postfix) if args.postfix else label
    label = label + '-norm' if args.norm else label
    label = label + '-finde,{}{},{},{}'.format(args.finde.variant, 'HNN' if args.finde.hnn else '', args.finde.num, args.finde.keeprate) if args.finde else label
    label = label + '-seed{}'.format(args.seed)
    # NOTE: 'derivartive-' is misspelled, but it is part of existing result file names
    # (and therefore of the --noretry check below), so it is kept as-is.
    label = ('derivartive-' if args.train_deriv else '') + label
    result_path = '{}/phys-{}'.format(result_root, label)
    args.path_tar = '{}.tar'.format(result_path)
    args.path_pkl = '{}.pkl'.format(result_path)
    args.path_txt = '{}.txt'.format(result_path)

    # The summary .txt is written last, so its presence marks a finished run.
    if args.noretry and os.path.exists(args.path_txt):
        print(args.path_txt)
        print('====== already done:', ' '.join(sys.argv), flush=True)
        sys.exit()

    print('====== not yet:', ' '.join(sys.argv), flush=True)
    print('====== entering:', ' '.join(sys.argv))
    model, stats = train(args)
    keys = ['deriv_error_test', '1step_error_test', 'energy_error_test', 'orbit_error_test', ]
    with open(args.path_txt, 'w') as fout:
        print('#', *keys, sep='\t', file=fout)
        print(*[stats[k] for k in keys], sep='\t', file=fout)

    print('====== ended:', ' '.join(sys.argv), flush=True)
