"""
MDShortcut: Molecular Dynamics Diffusion Models for Material Design

This module serves as the main entry point for training and inference with diffusion models
for molecular dynamics simulations and material design. It provides a command-line interface
for configuring and running experiments with different model architectures, schedulers, and
loss functions.

The system supports:
- EGNN-based denoising models for atomic structure generation
- SDE and Flow-based noise schedulers for diffusion processes
- Flexible loss functions with configurable position and element weights
- Guidance functions for conditional generation (e.g., charge balance)
- Model compilation for improved performance

Usage:
    python src/main.py -s config.yaml -g 0 [--compile]
"""
import json
import os
from argparse import ArgumentParser

import torch
import yaml
from torch.utils.data import DataLoader

import data
import pipeline
from models.denoisers import egnn
from models.modules.loss import MaterialLoss
from models.modules.schedule import MaterialSDESchedule, MaterialFlowSchedule
from models.modules.guidance import ChargeBalanceGuide
from dotenv import load_dotenv


def main():
    """Run training and/or inference as described by the command-line settings file.

    Parses the command line, builds every component via ``init_components``, then
    runs ``pipeline.train`` if a ``train`` section is present and enabled, followed
    by ``pipeline.infer`` if an ``infer`` section is present and enabled.

    Optional settings sections: ``load`` (checkpoint to resume from), ``train``,
    ``infer``. Within ``train``/``infer`` the ``params`` dict is optional; within
    ``infer`` the ``data`` dict falls back to ``train.data`` when absent.
    """
    setting, device, compile_model = parse_args_and_setting()
    model, orig_model, scheduler, loss_func, guidance_fn, uncond_model = init_components(setting, device, compile_model)

    # Resume the epoch counter from the loaded checkpoint, if any. The 'load'
    # section is optional, matching how init_components treats it.
    load_setting = setting.get('load')
    start_epoch = load_setting['epoch'] if load_setting is not None and load_setting['enabled'] else 0

    # Training process.
    train_setting = setting.get('train')
    if train_setting is not None and train_setting['enabled']:
        train_dataloader = build_dataloader(train_setting['data'], device)
        pipeline.train(
            model=model, orig_model=orig_model,
            scheduler=scheduler, train_dataloader=train_dataloader,
            start_epoch=start_epoch, loss_func=loss_func,
            **(train_setting.get('params') or {})
        )

    # Inference process.
    infer_setting = setting.get('infer')
    if infer_setting is not None and infer_setting['enabled']:
        # Resolve the fallback lazily: only touch setting['train'] when the infer
        # section has no 'data' of its own, so infer-only configs work.
        if 'data' in infer_setting:
            infer_data_setting = infer_setting['data']
        else:
            infer_data_setting = setting['train']['data']
        infer_dataloader = build_dataloader(infer_data_setting, device)
        pipeline.infer(
            model=model, uncond_model=uncond_model,
            scheduler=scheduler, infer_dataloader=infer_dataloader,
            guidance_fn=guidance_fn,
            **(infer_setting.get('params') or {})
        )


def init_components(setting, device, compile_model):
    """Initializes the model, scheduler, loss function, and optional guidance from settings.

    Args:
        setting (dict): Dictionary containing model configurations.
            Expected structure (every ``params`` entry is optional and defaults to {}):
            model:
              name: str  # Model type ('egnn_denoiser')
              params: dict  # Model-specific parameters
            scheduler:
              name: str  # Scheduler type ('sde' or 'flow')
              params: dict  # Scheduler-specific parameters
            loss:
              params:  # Keyword arguments for MaterialLoss, e.g.
                norm_type: str  # Type of norm to use (e.g., 'l1', 'l2', 'huber')
                position_weight: float  # Weight for position loss
                element_weight: float  # Weight for element loss
            load:  # optional
              enabled: bool  # Whether to load pretrained model
              name: str  # Name of model to load
              epoch: int  # Epoch to load from
            load_uncond:  # optional
              enabled: bool  # Whether to load unconditioned model
              name: str  # Name of unconditioned model to load
              epoch: int  # Epoch to load from
            guidance:  # optional
              name: str  # Guidance function name (e.g., 'charge_balance')
              params: dict  # Guidance-specific parameters
        device (str): Device to load the model to ('cuda:0' or 'cpu')
        compile_model (bool): Whether to compile the model using torch.compile()

    Returns:
        tuple:
            - model (nn.Module): The initialized (possibly compiled) model
            - orig_model (nn.Module): Original uncompiled model (if compile_model=True) or None
            - scheduler (MaterialSDESchedule or MaterialFlowSchedule): The noise scheduler
            - loss_func (MaterialLoss): The loss function
            - guidance_fn (nn.Module): The guidance function or None
            - uncond_model (nn.Module): Unconditioned model or None

    Raises:
        NotImplementedError: If the model, scheduler, or guidance name is unknown.
    """
    def _params(section):
        """Return a section's 'params' dict; a missing or empty (None) entry becomes {}."""
        return section.get('params') or {}

    def _init_model(model_name, params):
        """Instantiate the denoiser registered under ``model_name``."""
        if model_name == 'egnn_denoiser':
            return egnn.EgnnDenoiser(**params)
        raise NotImplementedError(f'Unknown model name: {model_name}')

    model_setting = setting['model']
    model_params = _params(model_setting)

    # Initialize denoiser model, restore its weights if requested, then move it to the device.
    model = _init_model(model_setting['name'], model_params)
    load_setting = setting.get('load')
    if load_setting is not None and load_setting['enabled']:
        pipeline.load_model(model, load_setting['name'], epoch=load_setting['epoch'])
    model = model.to(device)

    # Initialize the unconditioned model with the same architecture but without properties.
    uncond_model = None
    uncond_setting = setting.get('load_uncond')
    if uncond_setting is not None and uncond_setting['enabled']:
        # Copy so the conditioned model's settings dict is not mutated.
        uncond_model_params = dict(model_params)
        uncond_model_params['d_prop_embed'] = None
        uncond_model_params['properties'] = None
        uncond_model = _init_model(model_setting['name'], uncond_model_params)
        pipeline.load_model(uncond_model, uncond_setting['name'], epoch=uncond_setting['epoch'])
        uncond_model = uncond_model.to(device)

    n_model_params = sum(p.numel() for p in model.parameters())
    print(f'Model has {n_model_params} parameters')

    if compile_model:
        # Backup the uncompiled model for parameter caching.
        orig_model = model
        # NOTE: torch.compile(..., mode='reduce-overhead') was previously noted as giving
        # 2x-3x performance on a V100 GPU; it is left disabled here.
        model = torch.compile(orig_model)
    else:
        orig_model = None

    loss_func = MaterialLoss(**_params(setting['loss']))

    scheduler_setting = setting['scheduler']
    scheduler_name = scheduler_setting['name']
    scheduler_classes = {'sde': MaterialSDESchedule, 'flow': MaterialFlowSchedule}
    if scheduler_name not in scheduler_classes:
        raise NotImplementedError(f'Unknown scheduler name: {scheduler_name}')
    # The scheduler receives the (possibly compiled) model.
    scheduler = scheduler_classes[scheduler_name](
        model=model, uncond_model=uncond_model, **_params(scheduler_setting)
    )

    guidance_fn = None
    guidance_setting = setting.get('guidance')
    if guidance_setting is not None:
        guidance_fn_name = guidance_setting['name']
        if guidance_fn_name == 'charge_balance':
            guidance_fn = ChargeBalanceGuide(**_params(guidance_setting))
        else:
            raise NotImplementedError(f'Unknown guidance function name: {guidance_fn_name}')
        print(f'Using guidance function {type(guidance_fn).__name__}')

    return model, orig_model, scheduler, loss_func, guidance_fn, uncond_model


def parse_args_and_setting():
    """Parses command line arguments and loads configuration settings.

    Command-line options:
        -s/--setting: path to a JSON (*.json) or YAML (*.yml / *.yaml) settings
            file; the extension check is case-insensitive. Required.
        -g/--cuda: index of the CUDA device to use (default 0). Falls back to
            'cpu' when CUDA is not available.
        --compile: compile the model with torch.compile().

    Returns:
        tuple:
            - setting (dict): Loaded configuration dictionary from YAML/JSON file
            - device (str): Device string ('cuda:N' or 'cpu')
            - compile_model (bool): Whether to compile the model with torch.compile()

    Raises:
        NotImplementedError: If the settings file extension is not supported.
        ValueError: If the settings file does not contain a mapping (e.g. it is empty).
    """
    parser = ArgumentParser()
    parser.add_argument('-s', '--setting',
                        help='path to the setting file to use', type=str, required=True)
    parser.add_argument('-g', '--cuda',
                        help='index of the cuda (GPU) device to use', type=int, default=0)
    parser.add_argument('--compile', help='whether to compile the model with torch.compile',
                        action='store_true', default=False)
    args = parser.parse_args()

    # Only fill in CUDA_VISIBLE_DEVICES when the user has not set it. Overwriting a
    # user-provided mask such as "2,3" with the renumbered "0,1" would point child
    # processes (and, if CUDA is not yet initialized, possibly this process) at the
    # wrong physical GPUs. When unset, "0,...,n-1" simply lists every device.
    os.environ.setdefault(
        'CUDA_VISIBLE_DEVICES',
        ",".join(str(i) for i in range(torch.cuda.device_count())),
    )
    os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'

    device = f'cuda:{args.cuda}' if torch.cuda.is_available() and args.cuda is not None else 'cpu'
    if device != 'cpu':
        torch.set_float32_matmul_precision('high')  # for better performance on some gpus
    print('device: ' + device)

    compile_model = args.compile
    if compile_model:
        print('compiling model')

    # Load the setting file. It can be either JSON file (*.json) or YAML file (*.yml or *.yaml).
    lowered_path = args.setting.lower()
    with open(args.setting, 'r', encoding='utf-8') as fp:
        if lowered_path.endswith('.json'):
            setting = json.load(fp)
        elif lowered_path.endswith(('.yml', '.yaml')):
            setting = yaml.safe_load(fp)
        else:
            raise NotImplementedError('The settings file has an unsupported format.')

    # An empty YAML file loads as None; fail here with a clear message instead of
    # deep inside component initialization.
    if not isinstance(setting, dict):
        raise ValueError(f'The settings file {args.setting} must contain a mapping at the top level.')

    return setting, device, compile_model


def build_dataloader(data_setting, device):
    """Create a DataLoader over a ``data.MaterialDataset`` described by ``data_setting``.

    Args:
        data_setting (dict): Dataset / dataloader configuration with keys:
            dataset: keyword arguments forwarded to ``data.MaterialDataset``, e.g.
                atom_src:
                    type: str  # Source format ('extxyz', '3D_empty')
                    params: dict  # Source-specific parameters
                property_src:
                    type: str  # Property source ('file', 'files', 'augment', 'empty')
                    params: dict  # Property-specific parameters
                target_density: float  # Target density for the material
                filter_charge_bal: bool  # Whether to filter for charge balance
            dataloader: keyword arguments forwarded to ``torch.utils.data.DataLoader``, e.g.
                batch_size: int  # Number of samples per batch
                num_workers: int  # Number of worker processes
                shuffle: bool  # Whether to shuffle the data
            collate: optional keyword arguments for ``data.MaterialCollateFn``
        device (str): Device handed to the collate function ('cuda:0' or 'cpu')

    Returns:
        DataLoader: Loader that batches the dataset with ``data.MaterialCollateFn``.
    """
    material_dataset = data.MaterialDataset(**data_setting['dataset'])

    # The 'collate' section is optional; without it the collate function gets only the device.
    collate_kwargs = data_setting.get('collate', {})
    batch_collator = data.MaterialCollateFn(device, **collate_kwargs)

    loader_kwargs = data_setting['dataloader']
    return DataLoader(dataset=material_dataset, collate_fn=batch_collator, **loader_kwargs)


if __name__ == '__main__':
    # load_dotenv() (python-dotenv) is called before main() so that any environment
    # variables it provides are already set when settings and components are built.
    load_dotenv()
    main()
