"""
Training and inference pipelines for MDShortcut diffusion models.

This module provides the core training and inference functionality for molecular dynamics
diffusion models, including:
- Training pipeline with optional shortcut consistency training
- Inference pipeline for generating new material samples
- Model saving and loading utilities
- Forward trajectory generation for visualization

The pipelines handle:
- Gradient clipping and optimization
- TensorBoard logging and progress tracking
- Sample generation with configurable guidance
- Trajectory saving for analysis
- Notification services for long-running experiments
"""
import os
from random import random as rand

import numpy as np
import torch
from ase.io import write as write_ase
from tqdm import trange, tqdm
import json

import utils
from data import save_dirs


def train(model, scheduler, train_dataloader, loss_func, lr, num_epochs, clip_grad_norm=None, weight_decay=0,
          start_epoch=0, save=None, shortcut=None, orig_model=None, init_r_cut=None):
    """Training pipeline for diffusion models with optional shortcut consistency training.

    For every batch a random diffusion time is drawn, the clean sample is noised with
    ``scheduler.forward_process`` and the model is trained to predict the returned targets.
    Optionally a shortcut consistency term is added: one reverse step of size ``dt`` is
    compared against two consecutive reverse steps of size ``dt / 2`` (computed without
    gradients).

    Args:
        model (nn.Module): The diffusion denoiser model to train.
        scheduler: Diffusion scheduler providing ``gen_random_t``, ``forward_process`` and
            ``reverse_step``.
        train_dataloader (DataLoader): DataLoader providing training samples.
        loss_func: Callable taking ``pos_tgt``, ``pred_pos``, ``el_tgt`` and ``pred_els`` keyword
            arguments and returning a 2-tuple whose first element is the loss tensor.
        lr (float): Learning rate for the Adam optimizer.
        num_epochs (int): Epoch index at which training stops (exclusive upper bound of the
            epoch range, which starts at ``start_epoch``).
        clip_grad_norm (float, optional): Maximum gradient norm for clipping. Defaults to None
            (no clipping).
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 0.
        start_epoch (int, optional): Starting epoch (for resumed training). Defaults to 0.
        save (dict, optional): Save configuration with keys:
            - 'name' (str): Name for saving logs and model checkpoints
            - 'epoch' (int, optional): Frequency (in epochs) of checkpoint saves; a missing,
              ``None`` or ``0`` value disables periodic checkpoints
            When ``None``, nothing is written: no forward trajectory, no TensorBoard log and
            no checkpoints.
        shortcut (dict, optional): Shortcut consistency training configuration with keys:
            - 'enabled' (bool): Whether to use shortcut consistency loss
            - 'shortcut_per' (float): Probability of applying the shortcut loss to a batch
            - 'reverse_step_params' (dict): Extra keyword arguments for ``scheduler.reverse_step``
        orig_model (nn.Module, optional): Uncompiled model whose state dict is saved instead of
            ``model``'s. Defaults to None.
        init_r_cut (float, optional): If given, ``update_edges(init_r_cut)`` is called on every
            dataset sample before training. Defaults to None.
    """
    if init_r_cut is not None:
        for sample in tqdm(train_dataloader.dataset, desc='Creating initial neighborlists'):
            sample.update_edges(init_r_cut)
    enable_shortcut = shortcut.get('enabled', False) if shortcut else False

    # `save` is optional; resolve everything derived from it once so that the rest of the
    # function never indexes into a missing configuration.
    run_name = save['name'] if save is not None else 'unnamed run'
    save_every = save.get('epoch', None) if save is not None else None

    save_forward_trajectory(scheduler, train_dataloader, save)

    optimizer = torch.optim.Adam(list(model.parameters()), lr=lr, weight_decay=weight_decay)
    model.train()

    # Prepare the log file (only when a save configuration was supplied).
    tb_logger = None
    if save is not None:
        os.makedirs(os.path.dirname(f"{save_dirs['log']}/{save['name']}"), exist_ok=True)
        tb_logger = utils.TensorBoardLogger(os.path.join(save_dirs['log'], save['name']))

    desc = 'Avg loss %.6f'
    utils.send_notification("MDShortcut", f"Training started for {run_name}")
    try:
        with trange(start_epoch, num_epochs, desc=desc % 0) as bar:
            for epoch in bar:
                loss_values = []
                for batch_i, clean_sample in enumerate(tqdm(train_dataloader, leave=False)):
                    # The global step is only needed for logging; computing it lazily keeps
                    # dataloaders without __len__ usable when logging is off.
                    global_step = epoch * len(train_dataloader) + batch_i if tb_logger is not None else None

                    t = scheduler.gen_random_t(clean_sample)
                    noisy_sample, (pos_tgt, el_tgt) = scheduler.forward_process(clean_sample, t)
                    pred_pos, pred_els = model(noisy_sample, t, torch.zeros_like(t))
                    loss, _ = loss_func(
                        pos_tgt=pos_tgt,
                        pred_pos=pred_pos,
                        el_tgt=el_tgt,
                        pred_els=pred_els
                    )

                    if enable_shortcut and rand() < shortcut['shortcut_per']:
                        rev_step_params = shortcut.get('reverse_step_params', {})
                        # Target: two half-size reverse steps, treated as a constant.
                        with torch.no_grad():
                            dt = torch.rand_like(t) * t
                            target_prev_sample, _ = scheduler.reverse_step(noisy_sample, t, dt/2, **rev_step_params)
                            target_prev_sample, _ = scheduler.reverse_step(target_prev_sample, t - dt/2, dt/2, **rev_step_params)
                        # Prediction: a single full-size reverse step, with gradients enabled.
                        pred_prev_sample, _ = scheduler.reverse_step(noisy_sample, t, dt, **rev_step_params)
                        shortcut_loss, _ = loss_func(
                            pos_tgt=target_prev_sample.get_positions(),
                            pred_pos=pred_prev_sample.get_positions(),
                            el_tgt=target_prev_sample.get_element_emb(),
                            pred_els=pred_prev_sample.get_element_emb()
                        )
                        # Out-of-place add: do not mutate the tensor handed back by loss_func.
                        loss = loss + shortcut_loss
                        if tb_logger is not None:
                            tb_logger.log_scalar('train/shortcut_loss', shortcut_loss.item(), global_step)

                    optimizer.zero_grad()
                    loss.backward()
                    if clip_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), clip_grad_norm)
                    optimizer.step()

                    batch_loss = loss.detach().item()
                    loss_values.append(batch_loss)
                    if tb_logger is not None:
                        tb_logger.log_scalar('train/batch_loss', batch_loss, global_step)

                # An empty dataloader yields no losses; report NaN without a numpy warning.
                epoch_loss = np.mean(loss_values) if loss_values else float('nan')
                bar.set_description(desc % epoch_loss)
                if tb_logger is not None:
                    tb_logger.log_scalar('train/epoch_loss', epoch_loss, epoch)

                # `save_every` is None or 0 when periodic checkpoints are disabled.
                if save_every and (epoch + 1) % save_every == 0:
                    save_model(model, save['name'], orig_model, epoch + 1)
                    if tb_logger is not None:
                        tb_logger.log_model_params(model, epoch)
                    utils.send_notification("MDShortcut", f"Epoch {epoch + 1} of {num_epochs} completed for {run_name}", 2)
        if save is not None:
            save_model(model, save['name'], orig_model)
    finally:
        # Close the logger even when training is interrupted.
        if tb_logger is not None:
            tb_logger.close()
    utils.send_notification("MDShortcut", f"Training completed for {run_name}", 5)


@torch.no_grad()
def infer(model, uncond_model, scheduler, infer_dataloader, n_steps, cond_w=0.0,
          save=None, guidance_fn=None, num_cand=16, non_stochastic=False, init_r_cut=None, final_step=True):
    """Inference pipeline for generating new material samples using diffusion models.

    For every batch of the dataloader a random starting sample is drawn with
    ``scheduler.gen_random_sample`` and denoised with ``scheduler.reverse_process``. The last
    state of each reverse trajectory is converted to ASE atoms, wrapped, and (if a save name is
    configured) appended to ``inferred.extxyz``; the per-sample properties are written to
    ``given_properties.json``.

    Args:
        model (nn.Module): The trained diffusion denoiser model. Only switched to eval mode
            here; it is not passed to the scheduler by this function.
        uncond_model (nn.Module): Unconditional model for classifier-free guidance, or None.
            Only switched to eval mode here.
        scheduler: Diffusion scheduler providing ``gen_random_sample`` and ``reverse_process``.
        infer_dataloader (DataLoader): DataLoader providing initial conditions/templates.
        n_steps (int): Number of reverse diffusion steps for inference.
        cond_w (float, optional): Classifier-free guidance weight. Defaults to 0.0.
        save (dict, optional): Save configuration with keys:
            - 'name' (str): Name for saving generated samples
            - 'trajectories' (dict): Trajectory saving config with keys:
                - 'enabled' (bool): Whether to save full trajectories
                - 'batches' (list or 'all'): Batch indices to save trajectories for, or the
                  string 'all' for every batch. Defaults to ``[0]``.
        guidance_fn (callable, optional): Additional guidance function. Defaults to None.
        num_cand (int, optional): Number of candidates for guidance selection. Defaults to 16.
        non_stochastic (bool, optional): Whether to use deterministic inference. Defaults to False.
        init_r_cut (float, optional): Value passed to ``set_init_r_cut`` on every starting
            sample. Defaults to None.
        final_step (bool, optional): Whether to take a final denoising step. Defaults to True.

    Returns:
        None: Results are saved to files if save configuration is provided.
    """
    model.eval()
    if uncond_model is not None:
        uncond_model.eval()

    save_name = save.get('name') if save else None
    # `or {}` tolerates an explicit `'trajectories': None` entry.
    traj_config = (save.get('trajectories') or {}) if save else {}
    enable_trajectories = traj_config.get('enabled', False)
    save_traj_batches = traj_config.get('batches', [0]) if enable_trajectories else []

    save_dir = None
    if save_name is not None:
        save_dir = f"{save_dirs['infer']}/{save_name}"
        os.makedirs(save_dir, exist_ok=True)

    # Running count of generated structures; used to give trajectory files globally unique
    # indices without keeping every generated structure alive for the whole run.
    n_inferred = 0
    all_sample_properties = []
    utils.send_notification("MDShortcut", f"Inference {save_name} started")
    for batch_i, clean_sample in tqdm(enumerate(infer_dataloader), desc='Infering',
                                      total=len(infer_dataloader)):
        sample = scheduler.gen_random_sample(clean_sample)
        sample.set_init_r_cut(init_r_cut)
        # Snapshot the properties as plain Python scalars before denoising starts.
        sample_properties = [
            {key: val.cpu().item() if val is not None else None for key, val in s.properties.items()}
            for s in sample.samples
        ]

        trajectory = scheduler.reverse_process(sample, n_steps, cond_w=cond_w, guidance_fn=guidance_fn,
                                               num_cand=num_cand, non_stochastic=non_stochastic,
                                               final_step=final_step)

        # Save the full reverse trajectory of every sample in this batch if requested.
        if save_name is not None and enable_trajectories and (save_traj_batches == 'all' or batch_i in save_traj_batches):
            ase_traj = [[at.to_ase_atoms() for at in batch.samples]
                        for batch in trajectory]
            for batch in ase_traj:
                for at in batch:
                    at.wrap()
            for i_sample in range(sample.get_batch_size()):
                traj = [s[i_sample] for s in ase_traj]
                write_ase(os.path.join(save_dir, f'traj-{n_inferred+i_sample:05d}.extxyz'), traj)

        # The last trajectory entry is the generated sample.
        final_sample = trajectory[-1]

        final_sample_ase = [at.to_ase_atoms() for at in final_sample.samples]
        for at in final_sample_ase:
            at.wrap()

        if save_name is not None:
            file_path = os.path.join(save_dir, 'inferred.extxyz')
            prop_path = os.path.join(save_dir, 'given_properties.json')
            # The first batch (re)creates the file, later batches append to it.
            write_ase(file_path, final_sample_ase, format='extxyz', append=batch_i != 0)
            all_sample_properties.extend(sample_properties)
            # Rewritten after every batch so the file on disk always matches inferred.extxyz.
            with open(prop_path, 'w') as f:
                json.dump(all_sample_properties, f)

        n_inferred += len(final_sample_ase)

    utils.send_notification("MDShortcut", f"Inference {save_name} completed")

def save_model(model, save_name, uncompiled_model=None, epoch=None):
    """Save the trained model parameters to a checkpoint file.

    The state dict is written to ``save_dirs['models']/<save_name>/<epoch:05d>.pt`` when
    ``epoch`` is given and to ``.../final.pt`` otherwise. The directory is created if needed.

    Args:
        model (nn.Module): the neural network model to save.
        save_name (str): an indicator of the model's name, used as the directory for saving this model.
        uncompiled_model (nn.Module, optional): if given, its state dict is saved instead of
            ``model``'s (presumably so checkpoints of a compiled model keep the plain
            parameter names — confirm against the caller).
        epoch (int, optional): indication of the saved model's number of trained epochs.
    """
    model_dir = os.path.join(save_dirs['models'], save_name)
    os.makedirs(model_dir, exist_ok=True)
    file_name = f'{epoch:05d}.pt' if epoch is not None else 'final.pt'
    # Compare against None explicitly: nn.Module containers define __len__, so a
    # truthiness test could silently fall back to `model` for an empty container.
    source = uncompiled_model if uncompiled_model is not None else model
    torch.save(source.state_dict(), os.path.join(model_dir, file_name))


def load_model(model, save_name, epoch=None):
    """Restore a model's parameters in place from a checkpoint on disk.

    The checkpoint is read from ``save_dirs['models']/<save_name>/``; tensors are mapped to
    the CPU while loading and only weights are deserialized.

    Args:
        model (nn.Module): Model that receives the loaded state dict.
        save_name (str): Directory name of the saved model.
        epoch (int, optional): Epoch of the checkpoint to load (file ``<epoch:05d>.pt``).
            If None, ``final.pt`` is loaded.
    """
    if epoch is None:
        checkpoint_file = 'final.pt'
    else:
        checkpoint_file = f'{epoch:05d}.pt'

    model_path = os.path.join(save_dirs['models'], save_name, checkpoint_file)
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Loaded model from {model_path}")


@torch.no_grad()
def save_forward_trajectory(scheduler, train_dataloader, save, n_steps=200):
    """Write the forward (noising) process of one batch to disk, one file per sample.

    The first batch of the dataloader is noised at ``n_steps + 1`` evenly spaced times between
    ``scheduler.t_min`` and ``scheduler.t_max``. Each sample's sequence of noised states is
    written to ``save_dirs['infer']/<name>/forward_trajectory/<index:02d>.extxyz``.

    Args:
        scheduler: diffusion scheduler with ``forward_process``, ``t_min`` and ``t_max``.
        train_dataloader (DataLoader): data iterator whose first batch is used.
        save (dict): save configuration with keys:
            - 'name': str, name for saving the trajectory
            If None, the function returns without doing anything.
        n_steps (int, optional): number of forward diffusion steps. Defaults to 200.
    """
    if save is None:
        return

    clean_sample = next(iter(train_dataloader))
    timesteps = np.linspace(scheduler.t_min, scheduler.t_max, n_steps + 1)

    # Noise the same clean batch once per timestep, using one shared time for all samples.
    noised_batches = []
    for t_val in tqdm(timesteps, desc='Creating forward trajectory', leave=False):
        device = clean_sample.get_positions().device
        t = torch.full((clean_sample.get_batch_size(),), t_val).float().to(device)
        noised, _ = scheduler.forward_process(clean_sample, t)
        noised_batches.append(noised)

    forward_save_dir = f"{save_dirs['infer']}/{save['name']}/forward_trajectory"
    os.makedirs(forward_save_dir, exist_ok=True)

    # frames[step][sample] -> wrapped ASE atoms.
    frames = []
    for batch in noised_batches:
        atoms_in_batch = [structure.to_ase_atoms() for structure in batch.samples]
        frames.append(atoms_in_batch)
    for atoms_in_batch in frames:
        for atoms in atoms_in_batch:
            atoms.wrap()

    # Regroup by sample so that each file holds one sample's full trajectory.
    for i_sample in range(clean_sample.get_batch_size()):
        sample_traj = [frame[i_sample] for frame in frames]
        out_path = os.path.join(forward_save_dir, f'{i_sample:02d}.extxyz')
        write_ase(out_path, sample_traj)
