"""Training script"""

import time
import gc
import os.path as osp
import random
import torch
from omegaconf import DictConfig
import hydra
import wandb
from data.gp_sample_function import GPSampleFunction
from typing import Dict
from utils.paths import get_split_dataset_path, get_exp_path, get_log_filepath
from utils.log import get_logger, Averager, log_fn
from utils.dataclasses import (
    ExperimentConfig,
    PredictionConfig,
    OptimizationConfig,
    SamplerConfig,
    LossConfig,
    TrainConfig,
    DataConfig,
)
from utils.config import (
    build_model,
    build_optimizer,
    build_scheduler,
    build_dataloader,
    load_checkpoint,
    save_checkpoint,
)
from utils.data import (
    has_nan_or_inf,
    set_all_seeds,
    get_dx_dy_datapaths,
    MultiFileHDF5Dataset,
)
from policy_learning import select_next_query, compute_policy_loss
from prediction import predict_with_metrics, prepare_prediction_dataset
from wandb_wrapper import init as wandb_init, save_artifact

# Module-level defaults. The FREQ_* values are counted in training steps
# (the loop variable called `epoch` in `train`).
# Default for `train(freq_print=...)`: a status line is logged when `epoch % freq_print == 0`.
FREQ_PRINT = 500
# Default for `train(freq_save=...)`: cadence of the rolling "ckpt.tar" checkpoint.
FREQ_SAVE = 1000
# Not referenced in the visible part of this file.
FREQ_PLOT = 5000
# Default for `train(freq_val=...)`; that parameter is not used in the visible body of `train`.
FREQ_VAL = 2000
# Default for `train(freq_log_grad=...)`, forwarded to `wandb.watch(log_freq=...)`.
FREQ_LOG_GRAD = 1000
# Default for `train(freq_save_extra=...)`: cadence of the extra "ckpt_epoch_{epoch}.tar" copies.
FREQ_SAVE_EXTRA = 5000
# Not referenced in the visible part of this file.
PLOT_PER_N_STEPS = -1
# Not referenced in the visible part of this file.
VAL_NUM_OPT_TASKS = 4
# Fraction of the burn-in steps during which `prepare_prediction_dataset` is called with warmup=True.
CTX_SIZE_BURNIN_RATIO = 0.8
# Passed as `zero_mean` to both `GPSampleFunction` and `MultiFileHDF5Dataset`.
ZERO_MEAN = True


@hydra.main(version_base=None, config_path="configs", config_name="train_v1.yaml")
def main(config: DictConfig):
    """Hydra entry point: unpack the composed config and hand off to `train`.

    The function pins the torch default dtype/device, converts each config
    group into its dataclass, opens the experiment log file (mode "w"),
    optionally initialises Weights & Biases, resolves the experiment
    directory and finally calls `train`.

    Args:
        config: Composed Hydra configuration. The groups read here are
            `experiment`, `prediction`, `optimization`, `loss`, `train`,
            `sampler`, `data`, `model`, `validation`, `reinit_optimizer`,
            `logging`, `extra` and (when wandb logging is on) `wandb`.
    """
    torch.set_default_dtype(torch.float32)
    torch.set_default_device("cpu")

    # One dataclass per config group
    experiment = ExperimentConfig(**config.experiment)
    prediction = PredictionConfig(**config.prediction)
    optimization = OptimizationConfig(**config.optimization)
    loss = LossConfig(**config.loss)
    training = TrainConfig(**config.train)
    sampler = SamplerConfig(**config.sampler)
    data = DataConfig(**config.data)

    # File logger wrapped into a simple callable
    log_path = get_log_filepath(
        group_name=experiment.model_name,
        expid=experiment.expid,
        prefix=experiment.task,
    )
    log = log_fn(get_logger(file_name=log_path, mode="w"))
    log(f"log_filename:\t{log_path}")

    # Weights & Biases is opt-in
    if experiment.log_to_wandb:
        log(f"wandb configuration:{config.wandb}\n")
        wandb_init(config=config, **config.wandb)

    # Directory that holds the checkpoints of this experiment
    run_dir = get_exp_path(model_name=experiment.model_name, expid=experiment.expid)
    log(f"exp_path:\t{run_dir}")

    # `logging` and `extra` are splatted, so their keys must match keyword
    # parameters of `train`.
    train(
        exp_path=run_dir,
        model_kwargs=config.model,  # TODO change to dataclass
        exp_cfg=experiment,
        opt_cfg=optimization,
        pred_cfg=prediction,
        train_cfg=training,
        sampler_cfg=sampler,
        loss_cfg=loss,
        data_cfg=data,
        log=log,
        val_cfg=config.validation,
        reinit_optimizer=config.reinit_optimizer,
        **config.logging,
        **config.extra
    )



def get_d_T_B(
    num_cur, num_total, B, interval=(0.0, 0.01, 0.02, 0.05, 1.0), warmup: bool = False
):
    """Pick the query-set size, horizon and batch size for a policy rollout.

    The schedule has four stages with query-set sizes (32, 64, 128, 256),
    horizons (30, 50, 100, 100) and batch-size caps (32, 8, 4, 4). Without
    warm-up the last stage is always used. With warm-up the stage is chosen
    from the progress ratio `num_cur / num_total`: the first threshold in
    `interval` that is >= the ratio selects bucket `i`, and the stage is
    `max(0, i - 1)`.

    Args:
        num_cur: Number of steps completed so far (may be negative, which
            maps to the first stage).
        num_total: Total number of steps in the schedule. A non-positive
            value is treated as a finished schedule (progress 1.0) instead
            of dividing by zero.
        B: Requested batch size; the result never exceeds it.
        interval: Sequence of progress thresholds. If the progress exceeds
            every threshold the last stage is used.
        warmup: Whether to follow the staged schedule instead of always
            using the last stage.

    Returns:
        Tuple `(d, T, B)`: query-set size, number of rollout steps (capped
        at `d`) and effective batch size.

    Raises:
        AssertionError: If `num_cur > num_total`.
    """
    assert num_cur <= num_total, f"{num_cur} > {num_total}"

    ds = [32, 64, 128, 256]
    ts = [30, 50, 100, 100]
    Bs = [32, 8, 4, 4]
    last_stage = len(ds) - 1

    if warmup:
        # Guard against an empty schedule (e.g. no post-burn-in steps).
        progress = num_cur / num_total if num_total > 0 else 1.0
        bucket = next(
            (i for i, thres in enumerate(interval) if progress <= thres), None
        )
        if bucket is None:
            # Progress beyond every threshold: the schedule is over.
            stage = last_stage
        else:
            # Buckets 0 and 1 both map to stage 0; clamp in case `interval`
            # has more entries than there are stages.
            stage = min(max(0, bucket - 1), last_stage)
    else:
        stage = last_stage

    d = ds[stage]
    T = min(ts[stage], d)
    B = min(Bs[stage], B)

    return d, T, B


def run_optimization_step(
    model,
    sampler_cfg,
    opt_cfg,
    exp_cfg,
    loss_cfg,
    data_cfg,
    epoch,
    num_burnin_epochs,
    num_after_burnin_epochs,
    log,
):
    """Roll out the query policy on freshly sampled GP tasks and compute its loss.

    Input/output dimensions are drawn at random from the sampler config, a
    `GPSampleFunction` is created and initialised, and the model then picks
    `T` queries in sequence. Per-step negative regrets and log-probabilities
    (plus entropies when `loss_cfg.entropy_coeff > 0`) are collected in
    `(T, B)` buffers and handed to `compute_policy_loss`.

    Args:
        model: Model passed to `select_next_query`; its `max_x_dim` and
            `max_y_dim` attributes are read.
        sampler_cfg: Sampler settings (`x_dim_list`, `y_dim_list`, `x_range`).
        opt_cfg: Optimization-task settings (batch size, regret type, ...).
        exp_cfg: Experiment settings; only `device` is read here.
        loss_cfg: Policy-loss settings.
        data_cfg: Data settings; only `dim_scatter_mode` is read here.
        epoch: Current training step.
        num_burnin_epochs: Number of burn-in steps preceding policy learning.
        num_after_burnin_epochs: Number of steps after burn-in.
        log: Logging callable (not used in this function).

    Returns:
        Tuple `(loss_acq, step_reward_mean, final_step_reward_mean,
        final_step_entropy_mean, d, B, T)`.
    """
    device = exp_cfg.device
    track_entropy = loss_cfg.entropy_coeff > 0.0

    # Draw the task dimensionality first so the RNG is consumed in the same order.
    x_dim = random.choice(sampler_cfg.x_dim_list)
    y_dim = random.choice(sampler_cfg.y_dim_list)

    d, T, B = get_d_T_B(
        num_cur=epoch - num_burnin_epochs,
        num_total=num_after_burnin_epochs,
        B=opt_cfg.batch_size,
    )

    task = GPSampleFunction(
        batch_size=B,
        x_dim=x_dim,
        y_dim=y_dim,
        max_x_dim=model.max_x_dim,
        max_y_dim=model.max_y_dim,
        dim_scatter_mode=data_cfg.dim_scatter_mode,
        d=d,
        sampler_config=sampler_cfg,
        use_grid_sampling=opt_cfg.use_grid_sampling,
        use_factorized_policy=opt_cfg.use_factorized_policy,
        num_samples=opt_cfg.num_samples,
        device=device,
        online_generate=True,
        restore_full_dim_later=True,
        zero_mean=ZERO_MEAN,
    )

    x_ctx, y_ctx, _, _ = task.init(
        num_initial_points=opt_cfg.num_initial_points,
        regret_type=opt_cfg.regret_type,
        compute_hv=False,
        compute_regret=False,
        device=device,
    )

    # The sampler decides the actual batch size; trust the context tensor.
    B = x_ctx.shape[0]

    # Trajectory buffers, time-major
    neg_regrets = torch.empty((T, B), device=device)
    log_probs = torch.empty((T, B), device=device)
    entropies = torch.empty((T, B), device=device) if track_entropy else None

    for step in range(T):
        query_results = select_next_query(
            model=model,
            input_bounds=sampler_cfg.x_range,
            x_ctx=x_ctx,
            y_ctx=y_ctx,
            x_mask=task.x_mask,
            y_mask=task.y_mask,
            d=d,
            t=step + 1,  # the policy receives a 1-based step counter
            T=T,
            use_grid_sampling=opt_cfg.use_grid_sampling,
            use_fixed_query_set=True,
            use_factorized_policy=opt_cfg.use_factorized_policy,
            use_time_budget=opt_cfg.use_time_budget,
            q_chunk=task.chunks,
            q_chunk_mask=task.chunk_mask,
            evaluate=False,
            read_cache=opt_cfg.read_cache,
            write_cache=opt_cfg.write_cache,
        )
        indices, logp, entropy = (
            query_results[1],
            query_results[2],
            query_results[3],
        )

        # Append the chosen points to the context and obtain the regret.
        x_ctx, y_ctx, _, regret = task.step(
            index_new=indices.unsqueeze(-1).unsqueeze(-1),
            x_ctx=x_ctx,
            y_ctx=y_ctx,
            compute_hv=False,
            compute_regret=True,
            regret_type=opt_cfg.regret_type,
        )

        # Reward is the negated regret; it comes from numpy and carries no gradient.
        neg_regrets[step] = (
            -torch.from_numpy(regret).to(device).requires_grad_(False)
        )
        log_probs[step] = logp
        if track_entropy:
            entropies[step] = entropy
        else:
            entropy = entropy.detach()

    loss_acq, step_rewards = compute_policy_loss(
        step_rewards=neg_regrets,
        log_probs=log_probs,
        use_cumulative_r=loss_cfg.use_cumulative_rewards,
        discount_factor=loss_cfg.discount_factor,
        batch_standardize=loss_cfg.batch_standardize,
        clip_rewards=loss_cfg.clip_rewards,
        sum_over_tra=loss_cfg.sum_over_trajectories,
        batch_first=False,
        entropy=entropies,
        entropy_coeff=loss_cfg.entropy_coeff,
    )

    # Scalar summaries for logging.
    # NOTE(review): the buffers above are (T, B); if `compute_policy_loss`
    # returns `step_rewards` in that same layout, `[:, -1]` selects the last
    # batch element rather than the last step - confirm the returned layout.
    step_reward_mean = step_rewards.mean().detach().item()
    final_step_reward_mean = step_rewards[:, -1].mean().detach().item()
    final_step_entropy_mean = entropy.mean().detach().item()

    # Drop rollout state explicitly before returning.
    del task, x_ctx, y_ctx, query_results

    return (
        loss_acq,
        step_reward_mean,
        final_step_reward_mean,
        final_step_entropy_mean,
        d,
        B,
        T,
    )


def train(
    exp_path: str,
    model_kwargs: Dict,
    exp_cfg: ExperimentConfig,
    opt_cfg: OptimizationConfig,
    pred_cfg: PredictionConfig,
    train_cfg: TrainConfig,
    sampler_cfg: SamplerConfig,
    loss_cfg: LossConfig,
    data_cfg: DataConfig,
    log: callable = print,
    freq_print: int = FREQ_PRINT,
    freq_save: int = FREQ_SAVE,
    freq_save_extra: int = FREQ_SAVE_EXTRA,
    freq_val: int = FREQ_VAL,
    freq_log_grad: int = FREQ_LOG_GRAD,
    val_cfg: DictConfig = None,
    reinit_optimizer: bool = False,
    num_after_burnin_epochs = None
):
    """Run the two-phase training loop (prediction burn-in, then policy learning).

    Every iteration of the inner loop processes one batch and is counted as
    one "epoch". During the first `num_burnin_epochs` iterations only the
    prediction loss is optimised; from `epoch >= num_burnin_epochs` onwards
    the policy (acquisition) loss from `run_optimization_step` is added.

    Args:
        exp_path: Experiment directory used for loading and saving checkpoints.
        model_kwargs: Model keyword arguments forwarded to `build_model`;
            `max_x_dim` / `max_y_dim` are also read from it.
        exp_cfg: Experiment settings (device, resume flag, seed, mode, ...).
        opt_cfg: Settings for the optimization (policy) task.
        pred_cfg: Settings for the prediction task (batch size, context sizes, ...).
        train_cfg: Optimizer/scheduler/dataloader settings.
        sampler_cfg: Sampler settings (dimension lists, ranges, ...).
        loss_cfg: Loss settings.
        data_cfg: Data settings.
        log: Callable used for all status messages.
        freq_print: A status line is logged when `epoch % freq_print == 0`.
        freq_save: The rolling checkpoint is written when `epoch % freq_save == 0`.
        freq_save_extra: An additional per-epoch copy of the checkpoint is
            written when a checkpoint is saved and `epoch % freq_save_extra == 0`.
        freq_val: Not used in this function body.
        freq_log_grad: `log_freq` passed to `wandb.watch`.
        val_cfg: Not used in this function body.
        reinit_optimizer: If True, checkpointed optimizer/scheduler states
            are not loaded; combined with `exp_cfg.resume` a fresh scheduler
            over the remaining epochs is built.
        num_after_burnin_epochs: Optional fixed number of post-burn-in
            epochs; when given, `num_total_epochs` is recomputed from it.
    """
    # Load checkpoint
    ckpt = load_checkpoint(
        exp_path=exp_path,
        device=exp_cfg.device,
        resume=exp_cfg.resume,
    )

    # -1 means "no step done yet"; the loop increments before each step.
    epoch = ckpt.get("epoch", -1)
    # seed = ckpt.get("seed", exp_cfg.seed)
    seed = exp_cfg.seed # TODO dataloader is not permuted when resuming from checkpoint; so always use the provided seed for data randomness
    model_state_dict = ckpt.get("model", {})
    optimizer_state_dict = ckpt.get("optimizer", {})
    scheduler_state_dict = ckpt.get("scheduler", {})

    # Set random seed
    set_all_seeds(seed)
    log(f"seed:\t{seed}, last_epoch:\t{epoch}\n")
    n_gpus = torch.cuda.device_count()  # the number of gpus

    # Setup dataloader
    max_x_dim = model_kwargs.get("max_x_dim", -1)
    max_y_dim = model_kwargs.get("max_y_dim", -1)
    sampler_cfg.assert_dims_within_limits(max_x_dim=max_x_dim, max_y_dim=max_y_dim)

    datapaths, _ = get_dx_dy_datapaths(
        path=get_split_dataset_path(split=exp_cfg.mode),
        x_dim_list=sampler_cfg.x_dim_list,
        y_dim_list=sampler_cfg.y_dim_list,
    )
    log("[train] datapaths:\n" + "\n".join(f"{dp}" for dp in datapaths))
    log("Creating MultiFileHDF5Dataset...")
    dataset = MultiFileHDF5Dataset(
        file_paths=datapaths,
        max_x_dim=max_x_dim,
        max_y_dim=max_y_dim,
        zero_mean=ZERO_MEAN,
        standardize=sampler_cfg.standardize,
        range_scale=sampler_cfg.y_range,
    )
    dataset_size = len(dataset)
    dataset_size_repeated = dataset_size * train_cfg.num_repeat_data

    # Compute number of epochs based on dataset size and batch size
    num_total_epochs = dataset_size_repeated // pred_cfg.batch_size
    num_burnin_epochs = int(train_cfg.burnin_ratio * num_total_epochs)
    if num_after_burnin_epochs is not None:
        # Pre-defined number of epochs after burn-in
        num_after_burnin_epochs = num_after_burnin_epochs
        num_total_epochs = num_burnin_epochs + num_after_burnin_epochs
    else: 
        num_after_burnin_epochs = num_total_epochs - num_burnin_epochs
    # Steps during which prepare_prediction_dataset is called with warmup=True.
    num_context_size_burnin_epochs = int(num_burnin_epochs * CTX_SIZE_BURNIN_RATIO)
    log(
        f"Compute epochs for dataset of size {dataset_size} and repeated {train_cfg.num_repeat_data} times, "
        f"with burnin ratio {train_cfg.burnin_ratio}.\n"
        f"num_total_epochs:\t{num_total_epochs}, num_burnin_epochs:\t{num_burnin_epochs}, "
        f"num_context_size_burnin_epochs:\t{num_context_size_burnin_epochs}"
    )

    # Setup model
    model = build_model(
        model_name=exp_cfg.model_name,
        model_kwargs=model_kwargs,
        use_factorized_policy=opt_cfg.use_factorized_policy,
    )
    model = model.to(exp_cfg.device)
    model_param_count = sum(p.numel() for p in model.parameters())
    log(f"Model parameters:\t{model_param_count}")

    if n_gpus > 1:
        # Simple data parallelism
        raise NotImplementedError

    if model_state_dict:
        log("Loading model state dict from checkpoint.")
        # strict=False: missing/unexpected keys in the checkpoint are tolerated.
        model.load_state_dict(model_state_dict, strict=False)

    if exp_cfg.log_to_wandb:
        wandb.watch(model, log="gradients", log_freq=freq_log_grad)

    # Setup optimizer
    optimizer = build_optimizer(
        model=model,
        optimizer_type=train_cfg.optimizer_type,
        lr=train_cfg.lr1,
        weight_decay=train_cfg.weight_decay,
    )
    if not reinit_optimizer and optimizer_state_dict:
        log("Loading optimizer state dict from checkpoint.")
        optimizer.load_state_dict(optimizer_state_dict)

    # Setup scheduler
    if exp_cfg.resume and reinit_optimizer:
        # Resuming with a fresh optimizer: schedule only the remaining epochs,
        # starting from scratch and without warm-up.
        assert (
            epoch > num_burnin_epochs
        ), "Can only re-initialize optimizer after burn-in phase."
        num_sche_epoch = num_total_epochs - epoch
        log(f"Re-build scheduler with {num_sche_epoch} epochs.")
        log(f"Current epoch: {epoch}")
        log(f"Scheduler warmup steps: {train_cfg.num_warmup_steps}")
        scheduler = build_scheduler(
            optimizer=optimizer,
            scheduler_type=train_cfg.scheduler_type,
            num_training_steps=num_sche_epoch,  # NOTE
            last_epoch=-1,
            num_warmup_steps=0, # NOTE
        )
    else:
        scheduler = build_scheduler(
            optimizer=optimizer,
            scheduler_type=train_cfg.scheduler_type,
            num_training_steps=num_total_epochs, 
            last_epoch=epoch,
            num_warmup_steps=train_cfg.num_warmup_steps,
        )
        if not reinit_optimizer and scheduler_state_dict:
            log("Loading scheduler state dict from checkpoint.")
            scheduler.load_state_dict(scheduler_state_dict)

    # Start training
    ravg = Averager()
    # NOTE(review): on a fresh start `epoch` is -1, so this floor division
    # yields -1 and the range below runs one extra round; it also divides by
    # zero when dataset_size < pred_cfg.batch_size - confirm this is intended.
    repeat_round_start = epoch // (dataset_size // pred_cfg.batch_size)
    # Single-iteration loop over the one dataset built above.
    for dataset in [dataset]:
        for repeat_round in range(repeat_round_start, train_cfg.num_repeat_data):
            # Setup dataloader
            dataloader = build_dataloader(
                dataset=dataset,
                batch_size=pred_cfg.batch_size,
                split=exp_cfg.mode,
                device=exp_cfg.device,
                num_workers=train_cfg.num_workers,
                prefetch_factor=train_cfg.prefetch_factor,
            )
            dataloader_iter = iter(dataloader)

            # Start training for current repeat round
            while epoch < num_total_epochs:
                # The counter advances before the batch is fetched, so it also
                # advances when the round ends (break) or a batch is skipped (continue).
                epoch += 1

                if epoch == num_burnin_epochs:
                    log(
                        f"Start policy learning at epoch {epoch}; "
                        f"Re-build optimizer and scheduler with lr2: {train_cfg.lr2}"
                    )
                    # NOTE(review): only the optimizer is rebuilt here; `scheduler`
                    # is not reassigned in this block although the log message
                    # mentions it - confirm whether the scheduler should be
                    # rebuilt for the new optimizer.
                    optimizer = build_optimizer(
                        model=model,
                        optimizer_type=train_cfg.optimizer_type,
                        lr=train_cfg.lr2,
                        weight_decay=train_cfg.weight_decay,
                    )

                t1 = time.time()

                model.train()
                optimizer.zero_grad()
                # Load batch
                batch = next(dataloader_iter, None)
                if batch is None:
                    # Dataloader exhausted: move on to the next repeat round.
                    log(f"[repeat_round={repeat_round}]: finished.")
                    break
                x, y, valid_x_counts, valid_y_counts = batch
                # Skip corrupted batches entirely (no optimizer/scheduler step).
                if has_nan_or_inf(x, "x", log) or has_nan_or_inf(y, "y", log):
                    continue

                x = x.to(exp_cfg.device)  # [B, N, max_x_dim]
                y = y.to(exp_cfg.device)  # [B, N, max_y_dim]
                valid_x_counts = valid_x_counts.to(exp_cfg.device)  # [B]
                valid_y_counts = valid_y_counts.to(exp_cfg.device)  # [B]

                # Prepare dataset for prediction task
                x, y, x_mask, y_mask, nc = prepare_prediction_dataset(
                    x=x,
                    y=y,
                    valid_x_counts=valid_x_counts,
                    valid_y_counts=valid_y_counts,
                    dim_scatter_mode=data_cfg.dim_scatter_mode,
                    min_nc=pred_cfg.min_nc,
                    max_nc=pred_cfg.max_nc,
                    warmup=epoch <= num_context_size_burnin_epochs,
                    sigma=data_cfg.sigma,
                )

                # Randomly split context and target: the first `nc` permuted
                # points form the context, the rest the targets.
                perm = torch.randperm(x.shape[1], device=x.device)
                idx1, idx2 = perm[:nc], perm[nc:]
                xc = x[:, idx1]
                yc = y[:, idx1]
                xt = x[:, idx2]
                yt = y[:, idx2]

                # Prediction forward (model + loss)
                loss_pre, mse_mean, _, _ = predict_with_metrics(
                    model=model,
                    x_ctx=xc,
                    y_ctx=yc,
                    x_tar=xt,
                    y_tar=yt,
                    x_mask=x_mask,
                    y_mask=y_mask,
                    compute_nll=True,
                    compute_mse=True,
                    compute_ktt=False,
                    reduce_nll=True,
                    reduce_mse=True,
                    read_cache=pred_cfg.read_cache,
                    write_cache=pred_cfg.write_cache,
                )

                # Optimization forward (model + loss)
                if epoch >= num_burnin_epochs:
                    # Free up prediction memory
                    del (
                        perm,
                        x,
                        y,
                        x_mask,
                        y_mask,
                        xc,
                        yc,
                        xt,
                        yt,
                        valid_x_counts,
                        valid_y_counts,
                    )
                    gc.collect()
                    torch.cuda.empty_cache()

                    (
                        loss_acq,
                        step_reward_mean,
                        final_step_reward_mean,
                        final_step_entropy_mean,
                        d,
                        B,
                        T,
                    ) = run_optimization_step(
                        model,
                        sampler_cfg,
                        opt_cfg,
                        exp_cfg,
                        loss_cfg,
                        data_cfg,
                        epoch,
                        num_burnin_epochs,
                        num_after_burnin_epochs,
                        log,
                    )
                else:
                    # Burn-in: no policy loss yet, log zeros as placeholders.
                    loss_acq = torch.tensor(0.0).detach()
                    step_reward_mean = 0.0
                    final_step_reward_mean = 0.0
                    final_step_entropy_mean = 0.0

                # Possible ablation study: can adjust the weight placed over the prediction loss
                # The comparison yields a bool, so the prediction loss is kept
                # with probability `pred_ratio` and zero-weighted otherwise.
                w = loss_cfg.loss_weight * (loss_cfg.pred_ratio > random.random())
                loss = w * loss_pre + loss_acq

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()

                # Logging
                epoch_time = time.time() - t1
                mse_dict = {
                    f"train/mse_{j}": mse_mean[j].detach().item()
                    for j in range(mse_mean.shape[0])
                }
                # d / B / T only exist once the policy phase has started, hence
                # the conditional expressions below.
                log_dict = {
                    "train/loss_pre": loss_pre.detach().item(),
                    "train/loss_acq": loss_acq.detach().item(),
                    "train/loss": loss.detach().item(),
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                    "train/step_reward": step_reward_mean,
                    "train/step_reward_final": final_step_reward_mean,
                    "train/step_entropy_final": final_step_entropy_mean,
                    "train/epoch_time": epoch_time,
                    "train/query_set_size": d if epoch >= num_burnin_epochs else 0,
                    "train/opt_batch_size": B if epoch >= num_burnin_epochs else 0,
                    "train/T": T if epoch >= num_burnin_epochs else 0,
                    **mse_dict,
                }
                ravg.batch_update(log_dict)
                if exp_cfg.log_to_wandb:
                    wandb.log(log_dict)

                # Print training status
                if epoch % freq_print == 0:
                    line = (
                        f"[epoch {epoch} / {num_total_epochs}] "
                        f"lr: {optimizer.param_groups[0]['lr']:.3e} "
                        f"[train] "
                        f"{ravg.info()}"
                    )
                    log(line)
                    ravg.reset()

                # Release the loss tensors before the next step.
                del loss, loss_acq, loss_pre
                gc.collect()
                torch.cuda.empty_cache()
                
                # Save checkpoint
                # NOTE(review): the loop increments `epoch` up to and including
                # `num_total_epochs`, while this condition targets
                # `num_total_epochs - 1` - confirm which step should be the
                # final checkpoint.
                if (epoch > 0 and epoch % freq_save == 0) or (
                    epoch == num_total_epochs - 1
                ):
                    log(f"Saving checkpoint at epoch {epoch} to {exp_path}")
                    ckpt, ckpt_filepath = save_checkpoint(
                        exp_path=exp_path,
                        model=model,
                        epoch=epoch,
                        seed=seed,
                        dataloader=dataloader,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        ckpt_name="ckpt.tar",
                        n_gpus=n_gpus,
                    )

                    # Keep an additional copy named after the epoch.
                    if epoch % freq_save_extra == 0:
                        epoch_ckpt_filepath = osp.join(
                            exp_path, f"ckpt_epoch_{epoch}.tar"
                        )
                        torch.save(ckpt, epoch_ckpt_filepath)

                    if exp_cfg.log_to_wandb:
                        save_artifact(
                            run=wandb.run,
                            local_path=ckpt_filepath,
                            name="checkpoint",
                            type="model",
                            log=log,
                        )

# Script entry point: `main` is wrapped by `@hydra.main` and is called without arguments.
if __name__ == "__main__":
    main()
