"""Train a model iteratively, using smoother distributions."""

import torch
import os
from isbmodel.dpf import IterativeTrainer
import hydra
import numpy as np
import time
from isbmodel.data_loaders import read_s_shape_observations, load_toy_data, TensorSampler
from isbmodel.plot import plot_trajectory_video

def _resolve_n_dim(raw_n_dim):
    """Return the dataset dimensionality from the config as a tuple.

    ``raw_n_dim`` may be a string such as ``"(2,)"``, which is evaluated as a
    Python expression exactly as before. It may also be an already-parsed
    sequence (e.g. a list written directly in the YAML config), which is
    converted to a tuple as-is.
    """
    if isinstance(raw_n_dim, str):
        # NOTE(review): eval executes arbitrary code from the config string; this
        # assumes the config file is trusted. ast.literal_eval would be safer if
        # n_dim never needs arithmetic expressions.
        raw_n_dim = eval(raw_n_dim)
    return tuple(raw_n_dim)


def _build_init_dist(config, data, device):
    """Build the initial distribution described by ``config.model.init_dist``.

    Only ``'gauss'`` is supported.
    - The mean is zero when ``config.filter.zero_mean`` is set; otherwise it is
      the mean of ``data`` along dim 0.
    - The covariance is the identity when ``config.filter.unit_var`` is set;
      otherwise it is the identity scaled by the variance of all entries of
      ``data`` pooled together (``np.std`` over the flattened array, squared).

    Both mean and covariance are float64.

    Raises:
        NotImplementedError: for any other ``init_dist`` type.
    """
    if config.model.init_dist != 'gauss':
        raise NotImplementedError(f'Initial distribution type {config.model.init_dist} not implemented')

    n_dim = config.dataset.n_dim
    if config.filter.zero_mean:
        mean = torch.zeros(*n_dim, device=device, dtype=torch.float64)
    else:
        # Cast so loc has the same float64 dtype as the zero-mean branch and the
        # covariance below. Mixing float32 loc with float64 scale_tril breaks
        # the matmul/triangular solve used for sampling and log_prob.
        mean = torch.mean(data, dim=0).to(dtype=torch.float64)

    cov = torch.eye(*n_dim, device=device, dtype=torch.float64)
    if not config.filter.unit_var:
        cov = cov * (np.std(data.detach().cpu().numpy()) ** 2)
    return torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)


def _resolve_drift_paths(config):
    """Return ``(init, fwd, bwd)`` drift-model paths for ``config.model.score``.

    Only the ``'nn'`` model class is supported. It has no initial drift model,
    so the first element is ``None``. The forward and backward paths live under
    ``<base_folder>/model/``.

    Raises:
        NotImplementedError: for any other model class.
    """
    model_class = config.model.score
    if model_class != 'nn':
        raise NotImplementedError(f'Model class {model_class} not implemented')
    model_dir = os.path.join(config.base_folder, 'model')
    return (None,
            os.path.join(model_dir, config.model.fwd_drift_file_name),
            os.path.join(model_dir, config.model.bwd_drift_file_name))


@hydra.main(config_path="configs", config_name="s_shape2d_iterative_smoother")
def train_iterative(config):
    """Train an ``IterativeTrainer`` on the S-shape data and plot smoothed trajectories.

    Steps:
    - Seed torch with 42 and select ``cuda:0`` if available, else the CPU.
    - Load one batch (10000) of ``s_shape`` toy data as the terminal samples.
    - Read the S-shape observations and build the initial distribution and
      drift-model paths.
    - Train the model.
    - Re-run the particle filter under ``torch.no_grad`` and write a trajectory
      video named ``'trained_model'``. It goes under
      ``<base_folder>/plots/isb_model/<dataset>/iterative_training/<date>/<HH-MM>``.

    Args:
        config: Hydra config injected by ``hydra.main``.
            ``config.dataset.n_dim`` is replaced in place by a tuple.

    Raises:
        AssertionError: if ``base_folder`` is unset or the dataset is not ``'s_shape'``.
        NotImplementedError: for an unsupported observation dataset, initial
            distribution type, or model class.
    """
    base_folder = config.base_folder
    assert base_folder is not None, 'The config argument base_folder has not been defined: modify the config file'

    torch.random.manual_seed(42)    # matters for selection of observations and for initial sample.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    config.dataset.n_dim = _resolve_n_dim(config.dataset.n_dim)

    print(device)
    if torch.cuda.is_available():
        print(f'Device: {torch.cuda.get_device_name(0)}')

    dataset_name = config.dataset.dataset_name
    # One timestamp for both components. Two separate strftime calls could
    # disagree if the clock crossed midnight (or a minute) in between.
    now = time.localtime()
    plot_folder = os.path.join(base_folder, 'plots', 'isb_model', dataset_name, 'iterative_training',
                               time.strftime("%Y-%m-%d", now), time.strftime("%H-%M", now))

    # get data
    assert dataset_name == 's_shape', f'Dataset {dataset_name} not recognized.'
    loaders, _ = load_toy_data([dataset_name], batch_size=10000, device=device)
    term_data = next(loaders[dataset_name])

    if config.dataset.obs_dataset == 's_shape':
        obs_ts, obs_samples, obs_times, rand_select = read_s_shape_observations(base_folder, config, device=device)
    else:
        raise NotImplementedError(f'Observation dataset {config.dataset.obs_dataset} not recognized/ not included in supplementary material.')

    init_dist = _build_init_dist(config, term_data, device)
    init_drift_model_path, fwd_drift_model_path, bwd_drift_model_path = _resolve_drift_paths(config)

    terminal_dist = TensorSampler(term_data)
    iterative_trainer = IterativeTrainer(config,
                                         plot_folder=plot_folder,
                                         init_dist=init_dist,
                                         fwd_drift_model_path=fwd_drift_model_path,
                                         bwd_drift_model_path=bwd_drift_model_path,
                                         init_drift_model_path=init_drift_model_path,
                                         device=device)
    iterative_trainer.train(obs_times, obs_samples, terminal_dist, rand_select, obs_ts,
                            plot=config.train.plot, batch_size=config.train.batch_size, lr=config.train.lr)

    with torch.no_grad():
        iterative_trainer.init_loop = False
        iterative_trainer.init_particle_filter_fwd()
        rev_obs = torch.flip(obs_samples, dims=[1])
        # NOTE(review): rev_times is empty, so rev_obs may never be conditioned on
        # inside generate_filtered_particles — confirm this is intended.
        rev_times = []
        smooth_particles, _, _, _ = iterative_trainer.generate_filtered_particles(rev_times, rev_obs)

    plot_xlim = [config.dataset.plot_x_min, config.dataset.plot_x_max]
    plot_ylim = [config.dataset.plot_y_min, config.dataset.plot_y_max]
    # NOTE(review): assumes rand_select is an integer step offset — confirm
    # against read_s_shape_observations.
    plot_trajectory_video(plot_folder, config.filter.n_steps - rand_select - 1, torch.flip(obs_ts.cpu(), dims=[1]),
                          smooth_particles, xlim=plot_xlim, ylim=plot_ylim, filename='trained_model')

 
if __name__ == "__main__":
    # The hydra.main decorator parses the command line and supplies the config
    # argument, so the entry point is invoked with no arguments here.
    train_iterative()