import os
import numpy as np
import sympy as sp
import torch
from torch.utils.data import DataLoader
from parser_utils import get_args
from dataset import get_dataset
from sindy import SINDyRegression
from train import train_DISR, train_DISR_lbfgs
from evaluation.eval_eq import eval_sindy_regressor, di_sindy_truth, equivsindy_r_truth

if __name__ == '__main__':
    args = get_args()

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    args = vars(args)

    train_dataset, valid_dataset, test_dataset = get_dataset(args)
    if args['sindy_optimizer'] != 'lbfgs':
        train_loader = DataLoader(train_dataset, batch_size=args['batch_size'], shuffle=True)
    else:
        data_size = int(len(train_dataset) * args['lbfgs_subsample'])
        train_loader = DataLoader(train_dataset, batch_size=data_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=args['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args['batch_size'], shuffle=False)

    terms = ['t', 'x', 'u', 'dudt', 'dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx']
    if args['model'] == 'di-sindy':
        if args['pde'] in ['KdV', 'KS', 'Burgers']:
            differential_invariants = ['dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx', 'dudt + u * dudx']
        elif args['pde'] == 'nKdV':
            differential_invariants = ['dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx', 'exp(-t/50) * dudt + u * dudx']
    elif args['model'] == 'equivsindy-r':
        if args['pde'] in ['KdV', 'KS', 'Burgers']:
            differential_invariants = ['u', 'dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx',
                                       'u ** 2', 'u * dudx', 'u * dudxdx', 'u * dudxdxdx', 'u * dudxdxdxdx',
                                       'dudx ** 2', 'dudx * dudxdx', 'dudx * dudxdxdx', 'dudx * dudxdxdxdx',
                                       'dudxdx ** 2', 'dudxdx * dudxdxdx', 'dudxdx * dudxdxdxdx',
                                       'dudxdxdx ** 2', 'dudxdxdx * dudxdxdxdx',
                                       'dudxdxdxdx ** 2',
                                       'dudt']
        elif args['pde'] == 'nKdV':
            differential_invariants = ['u', 'dudx', 'dudxdx', 'dudxdxdx', 'dudxdxdxdx',
                                       'u ** 2', 'u * dudx', 'u * dudxdx', 'u * dudxdxdx', 'u * dudxdxdxdx',
                                       'dudx ** 2', 'dudx * dudxdx', 'dudx * dudxdxdx', 'dudx * dudxdxdxdx',
                                       'dudxdx ** 2', 'dudxdx * dudxdxdx', 'dudxdx * dudxdxdxdx',
                                       'dudxdxdx ** 2', 'dudxdxdx * dudxdxdxdx',
                                       'dudxdxdxdx ** 2',
                                       'exp(-t/50) * dudt']
    if args['pde'] in ['KdV', 'KS', 'Burgers']:
        generator = [['1 + 0 * u', '0 * u', '0 * u', '0 * u', '0 * u', '-dudx']]
        generator = sp.lambdify(terms, generator)
    elif args['pde'] == 'nKdV':
        generator = [['exp(-t/50)', '0 * u', '0 * u', '0 * u', '0 * u', '0 * u', 'dudt / 50 * exp(-t/50)'],
                     ['0 * u', '1 + 0 * u', '0 * u', '0 * u', '0 * u', '0 * u', '-dudx * exp(t/50)']]
        generator = sp.lambdify(terms, generator, modules={'exp': torch.exp})
    regressor = SINDyRegression(terms=terms, differential_invariants=differential_invariants, **args).to(args['device'])

    if args['sindy_optimizer'] == 'lbfgs':
        train_fn = train_DISR_lbfgs
    else:
        train_fn = train_DISR
    train_fn(
        generator=generator,
        regressor=regressor,
        train_loader=train_loader,
        test_loader=test_loader,
        **args
    )

    if not os.path.exists(f'saved_models/{args["save_dir"]}'):
        os.makedirs(f'saved_models/{args["save_dir"]}')
    torch.save(regressor.state_dict(), f'saved_models/{args["save_dir"]}/regressor.pt')

    print('\n=== Evaluation ===\n')
    if args['model'] == 'di-sindy':
        true_eq = di_sindy_truth[args['pde']]
    elif args['model'] == 'equivsindy-r':
        true_eq = equivsindy_r_truth[args['pde']]
    coef, cf, mse, cf_all, mse_all = eval_sindy_regressor(regressor, true_eq, args['model'])
    print(f'Correct form: {cf}')
    print(f'MSE: {np.where(cf, mse, 0.0)}')
    print(f'MSE (any): {mse}')
    eval_results = {
            'coefficients': coef,
            'correct_form': cf,
            'mse': mse,
            'correct_form_all': cf_all,
            'mse_all': mse_all,
        }
    eval_save_dir = f'eval_results/{args["save_dir"]}'
    if not os.path.exists(eval_save_dir):
        os.makedirs(eval_save_dir)
    np.savez(f'{eval_save_dir}/seed{seed}.npz', **eval_results)
