'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  forward.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import os
from pathlib import Path

from typing import *

import yaml

import numpy as np

import torch
import torch.nn as nn

from src.data.data import AtomicStructures
from src.data.tools import get_avg_n_neighbors, get_energy_shift_per_atom, get_forces_rms

from src.nn.layers import LinearLayer, RescaledSiLULayer, ScaleShiftLayer
from src.nn.representations import CartesianMACE

from src.utils.torch_geometric import Data
from src.utils.misc import load_object, save_object


def build_model(atomic_structures: Optional[AtomicStructures] = None,
                **config: Any):
    """Assembles a `ForwardAtomisticNetwork` from configuration keywords.

    Args:
        atomic_structures (Optional[AtomicStructures], optional): Atomic structures, typically the training set. 
                                                                  If given, they are used to compute the shift and 
                                                                  scale of the predictions and, if requested by the 
                                                                  config, the average number of neighbors. 
                                                                  If None, a zero shift and a unit scale are used. 
                                                                  Defaults to None.
        **config (Any): Model configuration. Keys read here are 'model_seed', 'n_species', 'n_interactions', 
                        'n_hidden_feats' and 'readout_MLP'; with `atomic_structures` also 'atomic_energies', 
                        'compute_regression_shift', 'compute_avg_n_neighbors' and 'r_cutoff'. The whole 
                        config is forwarded to `CartesianMACE` and stored in the returned model.

    Returns:
        ForwardAtomisticNetwork: Atomistic neural network.
    """
    # seed both torch and numpy before any module is created
    seed = config['model_seed']
    torch.manual_seed(seed)
    np.random.seed(seed)

    # shift/scale parameters for the output and, optionally, the average number of neighbors
    n_species = config['n_species']
    if atomic_structures is not None:
        shift_params = get_energy_shift_per_atom(atomic_structures,
                                                 n_species=n_species,
                                                 atomic_energies=config['atomic_energies'],
                                                 compute_regression_shift=config['compute_regression_shift'])
        scale_params = get_forces_rms(atomic_structures, n_species=n_species)
        if config['compute_avg_n_neighbors']:
            # must be written into the config before the representation is built from it
            config['avg_n_neighbors'] = get_avg_n_neighbors(atomic_structures, config['r_cutoff']).item()
    else:
        shift_params = np.zeros(n_species)
        scale_params = np.ones(n_species)

    # (semi-)local atomic representation
    representation = CartesianMACE(**config)

    # one readout per interaction: linear for all but the last one, which gets an MLP
    last_idx = config['n_interactions'] - 1
    readouts = nn.ModuleList([])
    for idx in range(last_idx + 1):
        if idx != last_idx:
            readouts.append(LinearLayer(config['n_hidden_feats'], 1))
            continue
        # layer widths of the MLP: hidden features -> readout_MLP sizes -> 1
        widths = [config['n_hidden_feats']] + config['readout_MLP'] + [1]
        mlp_layers = []
        for in_size, out_size in zip(widths[:-1], widths[1:]):
            mlp_layers += [LinearLayer(in_size, out_size), RescaledSiLULayer()]
        # no activation after the final linear layer
        mlp_layers.pop()
        readouts.append(nn.Sequential(*mlp_layers))

    scale_shift = ScaleShiftLayer(shift_params=shift_params, scale_params=scale_params)

    return ForwardAtomisticNetwork(representation=representation,
                                   readouts=readouts,
                                   scale_shift=scale_shift,
                                   config=config)


class ForwardAtomisticNetwork(nn.Module):
    """An atomistic model based on feed-forward neural networks.

    Args:
        representation (nn.Module): Local atomic representation layer. Called with the graph, it has to return 
                                    an iterable of atom features, one entry per readout.
        readouts (List[nn.Module]): Readout layers. A plain list/iterable of modules is wrapped into 
                                    a `nn.ModuleList` so that the readout parameters are registered.
        scale_shift (nn.Module): Scale/shift transformation applied to the output, i.e., energy re-scaling and shift.
                                 It is called as `scale_shift(atomic_energies, graph)`.
        config (Dict[str, Any]): Configuration the model was built from. It is stored and written by `save`.
    """
    def __init__(self,
                 representation: nn.Module,
                 readouts: List[nn.Module],
                 scale_shift: nn.Module,
                 config: Dict[str, Any]):
        super().__init__()
        # all necessary modules
        self.representation = representation
        # a plain Python list does not register its elements as sub-modules, i.e., their parameters would be
        # missing from `parameters()` and `state_dict()`; wrap everything that is not a `nn.ModuleList` yet
        self.readouts = readouts if isinstance(readouts, nn.ModuleList) else nn.ModuleList(readouts)
        self.scale_shift = scale_shift
        
        # provide config file to store it
        self.config = config

    def forward(self, graph: Data) -> torch.Tensor:
        """Computes atomic energies for the provided batch.

        Args:
            graph (Data): Atomic data graph.

        Returns:
            torch.Tensor: Atomic energies.
        """
        # compute representation (atom/node features)
        atom_feats_list = self.representation(graph)
        
        # apply a readout layer to each representation and drop the trailing feature dimension;
        # `zip` stops at the shorter of the two sequences
        atomic_energies_list = [readout(atom_feats).squeeze(-1)
                                for atom_feats, readout in zip(atom_feats_list, self.readouts)]
        # sum the contributions of all readouts
        atomic_energies = torch.sum(torch.stack(atomic_energies_list, dim=0), dim=0)
        
        # scale and shift the output
        atomic_energies = self.scale_shift(atomic_energies, graph)
        
        return atomic_energies

    def get_device(self) -> torch.device:
        """Provides device on which calculations are performed.
        
        Returns: 
            torch.device: Device of the first parameter of the representation.

        Raises:
            IndexError: If the representation has no parameters.
        """
        # only the first parameter is needed, so do not materialize the full parameter list
        first_param = next(self.representation.parameters(), None)
        if first_param is None:
            raise IndexError('The representation has no parameters to infer the device from.')
        return first_param.device

    def load_params(self, file_path: Union[str, Path]):
        """Loads network parameters from the file.

        Args:
            file_path (Union[str, Path]): Path to the file where network parameters are stored.
        """
        self.load_state_dict(load_object(file_path))

    def save_params(self, file_path: Union[str, Path]):
        """Stores network parameters to the file.

        Args:
            file_path (Union[str, Path]): Path to the file where network parameters are stored.
        """
        save_object(file_path, self.state_dict())

    def save(self, folder_path: Union[str, Path]):
        """Stores config ('config.yaml') and network parameters ('params.pkl') in the folder.

        Args:
            folder_path (Union[str, Path]): Path to the folder where network parameters are stored. 
                                            The folder is created if it does not exist.
        """
        folder_path = Path(folder_path)
        # writing into a missing folder would fail, so make sure it exists
        folder_path.mkdir(parents=True, exist_ok=True)
        (folder_path / 'config.yaml').write_text(yaml.safe_dump(self.config))
        self.save_params(folder_path / 'params.pkl')

    @staticmethod
    def from_folder(folder_path: Union[str, Path]) -> 'ForwardAtomisticNetwork':
        """Loads model from the defined folder.

        Args:
            folder_path (Union[str, Path]): Path to the folder where 'config.yaml' and 'params.pkl' are stored.

        Returns:
            ForwardAtomisticNetwork: The `ForwardAtomisticNetwork` object.
        """
        folder_path = Path(folder_path)
        config = yaml.safe_load((folder_path / 'config.yaml').read_text())
        # the model is built without atomic structures; the stored parameters are loaded afterwards
        # (local name is `model` and not `nn` to avoid shadowing the `torch.nn` alias)
        model = build_model(None, **config)
        model.load_params(folder_path / 'params.pkl')
        return model


def find_last_ckpt(folder: Union[Path, str]):
    """Finds the last/best checkpoint to load the model from.

    If the folder has exactly one entry, that entry is returned regardless of its name. 
    Otherwise, the entry named 'ckpt_<epoch>' with the largest integer `<epoch>` is returned; 
    entries without the 'ckpt_' prefix or with a non-integer suffix are ignored.

    Args:
        folder (Union[Path, str]): Path to the folder where checkpoints are stored.

    Returns:
        Path: Last checkpoint to load the model from.

    Raises:
        RuntimeError: If the folder is empty or, for two or more entries, none of them is named 'ckpt_<epoch>'.
    """
    # if no checkpoint exists raise an error
    entries = list(Path(folder).iterdir())
    if not entries:
        raise RuntimeError(f'Provided {folder} which is empty.')
    if len(entries) == 1:
        return entries[0]

    prefix = 'ckpt_'
    numbered = []
    for entry in entries:
        if not entry.name.startswith(prefix):
            continue
        try:
            epoch = int(entry.name[len(prefix):])
        except ValueError:
            # e.g. 'ckpt_tmp': not a numbered checkpoint, so it cannot be the newest one
            continue
        numbered.append((epoch, entry))

    if not numbered:
        # fail with the same error type as for the empty folder instead of an opaque argmax error
        raise RuntimeError(f'Provided {folder} which contains no checkpoints named {prefix}<epoch>.')
    # `max` returns the first maximal element, i.e., ties are resolved as with `np.argmax`
    return max(numbered, key=lambda pair: pair[0])[1]


def load_model_from_folder(model_path: Union[str, Path],
                           key: str = 'best') -> ForwardAtomisticNetwork:
    """Restores a model from the newest checkpoint found in `model_path/key`.

    Args:
        model_path (Union[str, Path]): Path to the model.
        key (str, optional): Name of the sub-folder of `model_path` to load from, 
                             i.e., the best or last stored model: 'best' and 'log'. 
                             Defaults to 'best'.

    Returns:
        ForwardAtomisticNetwork: Atomistic model.

    Raises:
        RuntimeError: If `model_path/key` does not exist.
    """
    path = os.path.join(model_path, key)
    if os.path.exists(path):
        # pick the checkpoint inside the sub-folder and rebuild the model from it
        last_ckpt = find_last_ckpt(path)
        return ForwardAtomisticNetwork.from_folder(last_ckpt)
    raise RuntimeError(f'Provided path to the {key} model does not exist: {path=}.')
