import dataclasses
import logging
import os
from typing import Optional, Tuple, Dict, List
from LieCG import so13

import numpy as np
import torch.nn.functional
from LorentzMACE.tools import torch_geometric
from e3nn import o3
from torch.optim.swa_utils import AveragedModel, SWALR
from torch_ema import ExponentialMovingAverage

from LorentzMACE import data, tools, modules


@dataclasses.dataclass(init=True, repr=True, eq=True)
class SubsetCollection:
    """Container for the configuration subsets consumed by the training script.

    ``train`` and ``valid`` are wrapped into the training and validation data
    loaders; ``tests`` holds additional named subsets, stored as
    ``(name, configurations)`` pairs, that are evaluated alongside them.
    """

    train: data.Configurations  # configurations used to fit the model
    valid: data.Configurations  # configurations used for validation
    tests: List[Tuple[str, data.Configurations]]  # (subset name, configurations) pairs


def get_dataset(downloads_dir: str, dataset: str, subset: Optional[str],
                split: Optional[int], train_size: int = 10000,
                valid_size: int = 100, seed: int = 1) -> SubsetCollection:
    """Load ``dataset`` and draw disjoint training and validation subsets from it.

    Args:
        downloads_dir: Directory handed to the dataset loader.
        dataset: Dataset name; only ``'tag_jet'`` is supported.
        subset: Currently unused.
        split: Currently unused.
        train_size: Number of configurations sampled for training.
        valid_size: Number of configurations sampled for validation.
        seed: Seed of the random generator, so the split is reproducible.

    Returns:
        A ``SubsetCollection`` whose ``train`` and ``valid`` lists share no
        configuration, and with no test subsets.

    Raises:
        ValueError: If the dataset holds fewer than ``train_size + valid_size``
            configurations.
        RuntimeError: If ``dataset`` is not recognised.
    """
    if dataset == 'tag_jet':
        ref_configs = data.load_tag_jet(directory=downloads_dir)
        num_samples = train_size + valid_size
        if num_samples > len(ref_configs):
            raise ValueError(
                f'Dataset {dataset!r} has {len(ref_configs)} configurations, '
                f'but {num_samples} (train={train_size}, valid={valid_size}) were requested')

        # Sample *indices* without replacement: sampling with replacement would
        # let the same configuration land in both train and valid (leaking
        # validation data), and choosing from the configs directly would make
        # numpy coerce the configuration objects into an array.
        rng = np.random.default_rng(seed)
        indices = rng.choice(len(ref_configs), size=num_samples, replace=False)
        sampled = [ref_configs[i] for i in indices]

        return SubsetCollection(train=sampled[:train_size],
                                valid=sampled[train_size:],
                                tests=[])
    raise RuntimeError(f'Unknown dataset: {dataset}')


# Lookup table from the ``--gate`` command-line string to the gate activation
# passed to the model constructors; the string 'None' maps to ``None``.
gate_dict = dict([
    ('abs', torch.abs),
    ('tanh', torch.tanh),
    ('silu', torch.nn.functional.silu),
    ('leakyRelu', torch.nn.functional.leaky_relu),
    ('None', None),
])


def _build_data_loader(configs, args, shuffle: bool, drop_last: bool):
    """Convert ``configs`` to ``AtomicData`` with the CLI cutoffs and wrap them in a DataLoader."""
    return torch_geometric.dataloader.DataLoader(
        dataset=[data.AtomicData.from_config(c, cutoff_in=args.r_max_in, cutoff_out=args.r_max_out) for c in configs],
        batch_size=args.batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )


def _build_loss_fn(args) -> torch.nn.Module:
    """Return the loss module selected by ``args.loss``.

    Raises:
        RuntimeError: If ``args.loss`` names an unsupported loss.
    """
    if args.loss == 'classification':
        return modules.ClassificationLoss()
    raise RuntimeError(f'Unknown loss: {args.loss}')


def _build_model(args, model_config: dict) -> torch.nn.Module:
    """Instantiate the architecture selected by ``args.model``.

    ``model_config`` holds the keyword arguments shared by every architecture;
    the remaining ones are read from ``args``.

    Raises:
        RuntimeError: If ``args.model`` names an unsupported architecture.
    """
    extra_kwargs = {}
    if args.model == 'LorentzBOTNet':
        model_cls = modules.LorentzBOTNet
    elif args.model == 'SingleReadoutModel':
        model_cls = modules.SingleReadoutModel
    elif args.model == 'LorentzMACEModel':
        model_cls = modules.LorentzMACEModel
        # Only the MACE model takes a correlation order.
        extra_kwargs['correlation'] = args.correlation
    else:
        raise RuntimeError(f'Unknown model: {args.model}')

    return model_cls(
        **model_config,
        gate=gate_dict[args.gate],
        interaction_cls_first=modules.interaction_classes[args.interaction_first],
        MLP_irreps=so13.Lorentz_Irreps(args.MLP_irreps),
        readout_irreps=so13.Lorentz_Irreps(args.readout_irreps),
        radial_basis_cls=modules.basis_classes[args.radial_basis],
        use_cutoff=args.use_cutoff,
        scale=args.scale,
        device=args.device,
        **extra_kwargs,
    )


def _build_lr_scheduler(args, optimizer: torch.optim.Optimizer, steps_per_epoch: int):
    """Return the learning-rate scheduler selected by ``args.scheduler``.

    ``steps_per_epoch`` is the number of training batches per epoch; it is used
    to size the cosine period over the whole run.

    Raises:
        RuntimeError: If ``args.scheduler`` names an unsupported scheduler.
    """
    if args.scheduler == 'ExponentialLR':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.lr_scheduler_gamma)
    if args.scheduler == 'ReduceLROnPlateau':
        # NOTE(review): attribute spelled 'scheduler_partience' must match the
        # argument parser defined in tools (not visible here) — confirm.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=args.lr_factor,
                                                          patience=args.scheduler_partience)
    if args.scheduler == 'CosineAnnealingLR':
        # T_max must be an int number of scheduler steps; max(1, ...) avoids a
        # zero period when the training set is smaller than one batch.
        T_max = max(1, steps_per_epoch * args.max_num_epochs)
        logging.info(f'Using cosine scheduler with period {T_max}')
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=T_max,
                                                          eta_min=args.eta_min)
    raise RuntimeError(f'Unknown scheduler: {args.scheduler}')


def _format_metrics(metrics: Dict) -> str:
    """Render ``metrics`` as ``key=value`` pairs, floats with four decimals."""
    parts = []
    for key, value in metrics.items():
        if isinstance(value, (float, np.floating)):
            parts.append(f'{key}={value:.4f}')
        else:
            parts.append(f'{key}={value}')
    return ', '.join(parts)


def main() -> None:
    """Run the full pipeline configured on the command line.

    Parses arguments, loads the dataset, and builds the loss, model, optimizer
    and learning-rate scheduler. It then trains via ``tools.train``, optionally
    with SWA and/or EMA, evaluates on the train/valid/test subsets, and saves
    the whole model to ``<checkpoints_dir>/<tag>.model``.

    Raises:
        RuntimeError: If ``--loss``, ``--model`` or ``--scheduler`` names an
            unsupported option.
    """
    args = tools.build_default_arg_parser().parse_args()
    tag = tools.get_tag(name=args.name, seed=args.seed)

    # Setup
    tools.set_seeds(args.seed)
    tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir)
    logging.info(f'Configuration: {args}')
    device = tools.init_device(args.device)
    tools.set_default_dtype(args.default_dtype)

    # Data preparation
    collections = get_dataset(downloads_dir=args.downloads_dir,
                              dataset=args.dataset,
                              subset=args.subset,
                              split=args.split)
    logging.info(
        f'Number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, '
        f'tests={[len(test_configs) for name, test_configs in collections.tests]}'
    )

    # Particles types
    num_elements = 2
    # yapf: disable
    logging.info(f'Number of particle types: {num_elements}')

    train_loader = _build_data_loader(collections.train, args, shuffle=True, drop_last=True)
    valid_loader = _build_data_loader(collections.valid, args, shuffle=False, drop_last=False)

    loss_fn = _build_loss_fn(args)
    logging.info(loss_fn)

    if args.compute_avg_num_neighbors:
        args.avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader)
    logging.info(f'Average number of neighbors: {args.avg_num_neighbors:.3f}')

    # Build model
    logging.info('Building model')
    model_config = dict(
        r_max=args.r_max_out,
        num_bessel=args.num_radial_basis,
        num_polynomial_cutoff=args.num_cutoff_basis,
        max_ell=args.max_ell,
        interaction_cls=modules.interaction_classes[args.interaction],
        num_interactions=args.num_interactions,
        num_elements=num_elements,
        hidden_irreps=so13.Lorentz_Irreps(args.hidden_irreps),
        avg_num_neighbors=args.avg_num_neighbors,
    )
    model = _build_model(args, model_config)
    model.to(device)

    # Optimizer
    param_options = dict(
        params=model.parameters(),
        lr=args.lr,
        amsgrad=args.amsgrad,
        weight_decay=args.weight_decay,
    )

    optimizer: torch.optim.Optimizer
    if args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(**param_options)
    else:
        optimizer = torch.optim.Adam(**param_options)

    logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + '_train')

    # len(train_loader) is the number of batches per epoch (drop_last=True already applied).
    lr_scheduler = _build_lr_scheduler(args, optimizer, steps_per_epoch=len(train_loader))
    checkpoint_handler = tools.CheckpointHandler(directory=args.checkpoints_dir, tag=tag, keep=args.keep_checkpoints)

    start_epoch = 0
    if args.restart_latest:
        opt_start_epoch = checkpoint_handler.load_latest(state=tools.CheckpointState(model, optimizer, lr_scheduler),
                                                         device=device)
        if opt_start_epoch is not None:
            start_epoch = opt_start_epoch

    swa: Optional[tools.SWAContainer] = None
    if args.swa:
        swa = tools.SWAContainer(
            model=AveragedModel(model),
            scheduler=SWALR(optimizer=optimizer, swa_lr=args.lr, anneal_epochs=1, anneal_strategy='linear'),
            start=10,
        )
        logging.info(f'Using stochastic weight averaging (after {swa.start} epochs)')

    ema: Optional[ExponentialMovingAverage] = None
    if args.ema:
        ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)

    logging.info(model)
    logging.info(f'Number of parameters: {tools.count_parameters(model)}')
    logging.info(f'Optimizer: {optimizer}')

    tools.train(
        model=model,
        loss_fn=loss_fn,
        train_loader=train_loader,
        valid_loader=valid_loader,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        checkpoint_handler=checkpoint_handler,
        eval_interval=args.eval_interval,
        start_epoch=start_epoch,
        max_num_epochs=args.max_num_epochs,
        logger=logger,
        patience=args.patience,
        device=device,
        swa=swa,
        ema=ema,
    )

    if swa:
        logging.info('Building averaged model')
        # Update batch norm statistics for the swa_model at the end (actually we are not using bn)
        torch.optim.swa_utils.update_bn(train_loader, swa.model)
        model = swa.model.module
    else:
        epoch = checkpoint_handler.load_latest(state=tools.CheckpointState(model, optimizer, lr_scheduler),
                                               device=device)
        logging.info(f'Loaded model from epoch {epoch}')

    # Evaluation on test datasets
    logging.info('Computing metrics for training, validation, and test sets')
    logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + '_eval')
    for name, subset in [('train', collections.train), ('valid', collections.valid)] + collections.tests:
        # Rebuilt without shuffling or dropping so every configuration is evaluated.
        data_loader = _build_data_loader(subset, args, shuffle=False, drop_last=False)

        loss, metrics = tools.evaluate(model, loss_fn=loss_fn, data_loader=data_loader, device=device)
        # Log whatever metrics evaluate() returns instead of hard-coding
        # energy/force regression keys, which a classification run may not produce.
        details = _format_metrics(metrics)
        logging.info(f"Subset '{name}': loss={loss:.4f}" + (f', {details}' if details else ''))
        metrics['subset'] = name
        metrics['name'] = args.name
        metrics['seed'] = args.seed
        logger.log(metrics)

    # Save entire model
    model_path = os.path.join(args.checkpoints_dir, tag + '.model')
    logging.info(f'Saving model to {model_path}')
    torch.save(model, model_path)

    logging.info('Done')

if __name__ == '__main__':
    # Run the training pipeline only when this file is executed as a script,
    # not when it is imported.
    main()