import torch
import torch.backends.cudnn
import torch.backends.cuda
# Global numerics configuration for this script.
# cuDNN: allow kernel autotuning (benchmark) and do not force deterministic
# algorithms, trading bit-exact reproducibility for speed.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# Keep full FP32 precision: disable TF32 for both cuDNN convolutions and
# CUDA matmuls.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)
import numpy as np
import os
import phys
from utils import to_pickle
from train_ode import get_args


def train(args, t_max=200.0, n_points=1000):
    """Build a PhysicsModel, roll out one trajectory and pickle orbit/energy.

    Despite its name, no optimization happens here. The model is constructed,
    put in evaluation mode and integrated from the fixed initial state
    (q, v) = (1, 0).

    Args:
        args: Parsed CLI namespace (from ``train_ode.get_args``). Reads
            ``seed``, ``act``, ``model``, ``solver``, ``finde`` and
            ``path_pkl``. Sets ``args.input_dim = 2`` as a side effect.
        t_max: End of the evaluation time grid (the grid starts at 0).
        n_points: Number of evaluation times in ``[0, t_max]``.

    Returns:
        ``(model, stats)``. ``stats`` is the dict that is also pickled to
        ``args.path_pkl``. Its ``'orbit'`` key holds the output of
        ``model.get_orbit``. Its ``'energy'`` key holds ``0.5 * (q**2 + v**2)``,
        where ``q = orbit[..., 0]`` and ``v = orbit[..., 1]``.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = torch.get_default_dtype()

    # The state has two components, read back below as (q, v).
    args.input_dim = 2

    print('Initializing model and data...')
    # Seed before constructing the model so that parameter initialization is
    # reproducible for a given --seed.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Evaluation only. Scope the autograd-off setting to this work instead of
    # calling torch.set_grad_enabled(False), which would remain in effect for
    # the caller after this function returns.
    with torch.no_grad():
        model = phys.PhysicsModel(args.input_dim, 1,
                                  act=args.act, model=args.model, solver=args.solver,
                                  data_mean=None, data_std=None, finde=args.finde)
        model = model.to(device)
        model.train(False)  # evaluation mode

        print('Making orbit and energy...')
        t_eval = torch.from_numpy(np.linspace(0, t_max, n_points)).to(device, dtype)
        # Initial state (q, v) = (1, 0), batch of one.
        u0 = torch.tensor([[1.0, 0.0]], device=device, dtype=dtype)
        orbit = model.get_orbit(u0, t_eval=t_eval)
        q, v = orbit[..., 0], orbit[..., 1]  # type: ignore
        stats = {
            'orbit': orbit,
            # Harmonic-oscillator energy, assuming unit mass and stiffness.
            'energy': 0.5 * (q**2 + v**2),
        }

    to_pickle(stats, args.path_pkl)
    return model, stats


def result_label(args):
    """Return the result-file label derived from the CLI arguments.

    Format: ``[derivartive-]<name>-<solver>[-finde,<variant>,<num>,<keeprate>]``.
    The FINDE suffix is added only when ``args.finde`` is truthy. The prefix
    is added only when ``args.train_deriv`` is truthy.
    """
    label = '{}-{}'.format(args.name, args.solver)
    if args.finde:
        label += '-finde,{},{},{}'.format(args.finde.variant, args.finde.num, args.finde.keeprate)
    if args.train_deriv:
        # NOTE(review): 'derivartive' looks like a typo for 'derivative'. It is
        # kept because it is part of the output file name, which other
        # scripts may rely on.
        label = 'derivartive-' + label
    return label


if __name__ == "__main__":
    args = get_args()
    # Create the output directory if needed. exist_ok avoids the race in a
    # separate exists() check followed by makedirs().
    os.makedirs(args.save_dir, exist_ok=True)
    result_path = '{}/phys-{}'.format(args.save_dir, result_label(args))
    args.path_pkl = '{}.pkl'.format(result_path)

    model, stats = train(args)


# python train_ode_massspring.py --solver dopri5
# python train_ode_massspring.py --solver dopri5   --finde continuous --finde_num 0
# python train_ode_massspring.py --solver dopri5   --finde discrete --finde_num 0
# python train_ode_massspring.py --solver leapfrog
# python train_ode_massspring.py --solver leapfrog --finde continuous --finde_num 0
# python train_ode_massspring.py --solver leapfrog --finde discrete --finde_num 0
