"""
Copyright 2025 [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os, sys, argparse, time, shutil
sys.path.append(os.pardir)
import yaml
import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from src.dataset import generate_dataset
from src.nn import OUFlow
from src.train import train
from src.utils import fix_random, nanstd, count_parameters
# from benchmark import gpr, latentsde, acssm, dspd



def _parse_args():
    """Parse the command-line arguments and ensure at least one action is requested.

    Returns:
        argparse.Namespace with ``config`` (config file stem under ``config/``),
        ``d`` (generate dataset) and ``t`` (train model).
    """
    parser = argparse.ArgumentParser()
    # required=True: without it a missing --config crashes later with an opaque
    # TypeError when building the config path.
    parser.add_argument('--config', type=str, required=True, help='Config file name.')
    parser.add_argument('-d', action='store_true', help='If specified, dataset is generated.')
    parser.add_argument('-t', action='store_true', help='If specified, the ML model is trained.')
    args = parser.parse_args()

    if not args.d and not args.t:
        print('Please specify -d or -t. If you want to generate dataset, specify -d. If you want to train the model, specify -t.',
              file=sys.stderr)
        # Non-zero status so shell scripts / schedulers see the usage error.
        sys.exit(1)
    return args


def _load_split(dataset_dir, split, device):
    """Load the ``<split>_x.pt`` / ``<split>_t.pt`` pair from *dataset_dir* onto *device*.

    Returns:
        Tuple ``(xs, ts)`` as stored on disk (for the training split these are
        consumed as PackedSequence objects), moved to *device*.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # dataset files produced by this script's -d step, never untrusted files.
    xs = torch.load(os.path.join(dataset_dir, f'{split}_x.pt'), weights_only=False).to(device)
    ts = torch.load(os.path.join(dataset_dir, f'{split}_t.pt'), weights_only=False).to(device)
    return xs, ts


def _train_ouflow(xs_train, ts_train, xs_val, ts_val, outdir, config_path, config, device):
    """Build and train the OUFlow model.

    Adds the per-feature normalization statistics ``x_base`` (NaN-aware mean)
    and ``x_scale`` (NaN-aware unbiased std) of the training data to
    ``config['model']`` before passing it to ``OUFlow``. A timestamped log file
    is written to ``<outdir>/log`` (starting as a copy of the config file), and
    checkpoints go to ``<outdir>/checkpoint``.
    """
    checkpoint_dir = os.path.join(outdir, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)

    model_config = config['model']
    if model_config['double_precision']:
        # PackedSequence.double() casts the data while keeping lengths/ordering.
        xs_train = xs_train.double()
        ts_train = ts_train.double()
    # NOTE(review): xs_val / ts_val are not cast here; presumably train() or the
    # model handles their dtype — confirm.

    # Pad a copy with NaN purely to compute per-feature statistics over all real
    # observations. The packed sequences themselves are passed on unchanged, so
    # their true lengths are preserved (re-packing the padded rows would turn
    # every sequence into a max-length one with NaN tails, unlike the val data).
    xs_padded = pad_packed_sequence(xs_train, batch_first=True, padding_value=torch.nan)[0]
    model_config['x_base'] = torch.nanmean(xs_padded, dim=(0, 1))
    model_config['x_scale'] = nanstd(xs_padded, dim=(0, 1), unbiased=True)
    dim = xs_padded.shape[-1]
    del xs_padded

    model = OUFlow(dim, **model_config).to(device)

    log_dir = os.path.join(outdir, 'log')
    os.makedirs(log_dir, exist_ok=True)
    now = time.strftime('%Y%m%d_%H%M%S')
    log_path = os.path.join(log_dir, f'{now}.log')
    # The log starts as a copy of the config so each run records its settings.
    shutil.copy(config_path, log_path)
    count_parameters(model, log_path)

    train(model, xs_train, ts_train, xs_val, ts_val, checkpoint_dir, log_path, config['train'])


def _train_benchmark(name, xs_train, ts_train, xs_val, ts_val, outdir, config_path, config):
    """Dispatch training to the benchmark module selected by ``config['model']['name']``.

    Raises:
        ValueError: if *name* is not one of the supported benchmark models.
    """
    # Imported lazily (the module-level benchmark import is commented out) so
    # only the selected benchmark module has to be importable.
    if name == 'GPR':
        from benchmark import gpr as module
    elif name == 'LatentSDE':
        from benchmark import latentsde as module
    elif name == 'ACSSM':
        from benchmark import acssm as module
    elif name == 'DSPD':
        from benchmark import dspd as module
    else:
        raise ValueError(
            f"Unknown model name {name!r}; expected one of 'GPR', 'LatentSDE', 'ACSSM', 'DSPD', "
            "or omit 'name' to train OUFlow."
        )
    module.train(xs_train, ts_train, xs_val, ts_val, outdir, config_path, config['model'], config['train'])


def main():
    """Entry point: optionally generate the dataset (-d) and/or train a model (-t).

    The config is read from ``config/<--config>.yaml`` next to this script and
    all outputs go to ``output/<experiment_name>/``.
    """
    args = _parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, 'config', args.config + '.yaml')
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    fix_random(config['seed'])

    outdir = os.path.join(current_dir, 'output', config['experiment_name'])
    os.makedirs(outdir, exist_ok=True)
    dataset_dir = os.path.join(outdir, 'dataset')

    if args.d:
        print('==== Dataset generation ====')
        os.makedirs(dataset_dir, exist_ok=True)
        generate_dataset(dataset_dir, config['dataset'])

    if args.t:
        print('==== Training ====')

        xs_train, ts_train = _load_split(dataset_dir, 'train', device)
        xs_val, ts_val = _load_split(dataset_dir, 'val', device)

        # No 'name' key selects the OUFlow model; otherwise it names a benchmark.
        if 'name' not in config['model']:
            _train_ouflow(xs_train, ts_train, xs_val, ts_val, outdir, config_path, config, device)
        else:
            _train_benchmark(config['model']['name'], xs_train, ts_train, xs_val, ts_val,
                             outdir, config_path, config)


if __name__ == '__main__':
    main()