import os

import time
import ml_collections
import wandb

from jax import random
import jax.numpy as jnp
import orbax.checkpoint as ocp


from src.model import CVit
from src.utils import create_optimizer, create_train_state, create_checkpoint_manager,  create_train_step, create_eval_step
from src.data_pipeline import create_dataloaders, batch_parser

from dr_pipeline import create_dr_datasets


def _mean_eval_metrics(eval_step_fn, state, test_iter, num_batches):
    """Average the two metrics returned by ``eval_step_fn`` over test batches.

    Args:
        eval_step_fn: Callable invoked as ``eval_step_fn(state, batch)`` that
            returns an ``(l2_error, smse)`` pair.
        state: Current train state forwarded to ``eval_step_fn``.
        test_iter: Iterator yielding raw test batches.
        num_batches: Number of batches to draw from ``test_iter``.

    Returns:
        A ``(mean_l2_error, mean_smse)`` tuple.
    """
    l2_errors = []
    smses = []
    for _ in range(num_batches):
        batch = batch_parser(next(test_iter))
        l2_error, smse = eval_step_fn(state, batch)
        l2_errors.append(l2_error)
        smses.append(smse)
    return jnp.array(l2_errors).mean(), jnp.array(smses).mean()


def train_and_evaluate(config: ml_collections.ConfigDict):
    """Train a ``CVit`` model, periodically evaluating, logging and checkpointing.

    The loop runs for ``config.training.num_steps`` steps. Every
    ``config.logging.log_interval`` steps it averages the evaluation metrics
    over ``config.logging.eval_steps`` test batches, logs them to Weights &
    Biases and prints them. If the training loss at such a step is non-finite
    or at least 10x the loss recorded at the last checkpoint, the train state
    is restored from the latest checkpoint (when one exists). Checkpoints are
    written every ``config.saving.save_interval`` steps while the loss stays
    below 2x the last checkpointed loss, and always on the final step.

    Args:
        config: Experiment configuration. Fields read here: ``model``,
            ``wandb.project``, ``wandb.name``, ``saving``, ``dataset``,
            ``seed``, ``training.num_steps``, ``logging.log_interval`` and
            ``logging.eval_steps``.
    """
    # Model, optimizer (with its learning-rate schedule) and initial state.
    model = CVit(**config.model)
    lr, tx = create_optimizer(config)
    state = create_train_state(config, model, tx)

    # Checkpoints live under <cwd>/<wandb run name>/ckpt.
    ckpt_path = os.path.join(os.getcwd(), config.wandb.name, "ckpt")
    ckpt_mngr = create_checkpoint_manager(config.saving, ckpt_path)

    train_step_fn = create_train_step(model)
    eval_step_fn = create_eval_step(model)

    # The last two return values are not used by the training loop.
    train_dataset, test_dataset, _mean, _std = create_dr_datasets(
        filename=config.dataset.path,
        prev_steps=config.dataset.prev_steps,
        pred_steps=config.dataset.pred_steps,
        train_samples=config.dataset.train_samples,
        test_samples=config.dataset.test_samples,
        downsample_factor=config.dataset.downsample_factor,
    )
    train_iter, test_iter = create_dataloaders(config.dataset, train_dataset, test_dataset)

    wandb_config = config.wandb
    wandb.init(project=wandb_config.project, name=wandb_config.name)

    rng = random.PRNGKey(config.seed + 1)
    start_time = time.time()
    # Loss at the most recent checkpoint; reference for the blowup/save tests.
    last_loss = 1.0
    final_step = config.training.num_steps - 1

    for step in range(config.training.num_steps):
        # Use a fresh subkey for the batch and keep `rng` only as the carry,
        # so no key is both consumed and split again.
        rng, batch_rng = random.split(rng)

        batch = batch_parser(next(train_iter), batch_rng, config.dataset.num_query_points)
        state, loss = train_step_fn(state, batch)

        if step % config.logging.log_interval == 0:
            l2_error, smse = _mean_eval_metrics(
                eval_step_fn, state, test_iter, config.logging.eval_steps
            )

            log_dict = {'loss': loss, 'l2_error': l2_error, 'smse': smse, 'lr': lr(state.step)}
            wandb.log(log_dict, step)
            end_time = time.time()
            print(
                f"step: {step}, loss: {loss:.3e}, test error: {l2_error:.3e}, "
                f"test smse: {smse:.3e}, time: {end_time - start_time:.3e}"
            )
            start_time = end_time

            # NaN compares False against everything, so it has to be tested
            # explicitly to count as a blowup.
            loss_diverged = not jnp.isfinite(loss) or loss >= last_loss * 10
            if loss_diverged:
                latest_step = ckpt_mngr.latest_step()
                if latest_step is None:
                    # Nothing has been saved yet, so there is nothing to revert to.
                    print("Loss blowup detected, but no checkpoint is available yet")
                else:
                    print("Loss blowup detected, reverting to last checkpoint")
                    # Make sure any in-flight save has completed before reading it.
                    ckpt_mngr.wait_until_finished()
                    state = ckpt_mngr.restore(latest_step, args=ocp.args.StandardRestore(state))
                # Never checkpoint a diverged state: skip the save below.
                continue

        save_due = step % config.saving.save_interval == 0 and loss < 2.0 * last_loss
        if save_due or step == final_step:
            ckpt_mngr.save(step, args=ocp.args.StandardSave(state))
            last_loss = loss

    # Block until pending checkpoint writes are finished before returning.
    ckpt_mngr.wait_until_finished()










