'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  tools.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import Optional

import numpy as np

from src.data.data import AtomicStructures


def get_energy_shift_per_atom(structures: AtomicStructures,
                              n_species: int,
                              atomic_energies: Optional[np.ndarray] = None,
                              compute_regression_shift: bool = True,
                              regularization: float = 1.0) -> np.ndarray:
    """Computes energy shift parameters for each atomic species in the data set.

    If atomic energies are provided, they are subtracted from the total energy before
    computing the mean energy per atom and the regression solution.

    When ``compute_regression_shift`` is True, the shift for species ``s`` is

        atomic_energies[s] + energy_per_atom_mean + regression_shift[s],

    where ``energy_per_atom_mean`` is the mean per-atom energy left after removing the
    atomic energies, and ``regression_shift`` is the ridge-regularized least-squares
    solution mapping per-structure species counts to the energy that remains after also
    removing ``n_atoms * energy_per_atom_mean``.

    Args:
        structures (AtomicStructures): Atomic structures in the data set. Each structure
            must expose ``species`` (integer species indices), ``energy`` and ``n_atoms``.
            The collection is iterated only once.
        n_species (int): Total number of atom species/types.
        atomic_energies (np.ndarray, optional): Atomic energies, one per species.
            Defaults to None, which is treated as all zeros.
        compute_regression_shift (bool, optional): If False, only the (possibly zero)
            atomic energies are returned. Defaults to True.
        regularization (float, optional): Ridge strength added to the diagonal of the
            normal equations. Defaults to 1.0.

    Returns:
        np.ndarray: Atomic energy shift parameters, one entry per species.

    Raises:
        ValueError: If ``atomic_energies`` does not have ``n_species`` entries, or if the
            regression shift is requested for a data set that contains no atoms.
    """
    if atomic_energies is None:
        atomic_energies = np.zeros(n_species)
    else:
        # explicit check instead of `assert`, which is stripped when running with -O
        if len(atomic_energies) != n_species:
            raise ValueError(
                f'Expected {n_species} atomic energies, got {len(atomic_energies)}.')
        atomic_energies = np.array(atomic_energies)

    if not compute_regression_shift:
        return atomic_energies

    # Single pass over the data set, so one-shot iterables work too. Collect per-structure
    # species counts (regression features), atom counts, and the energy that remains
    # after subtracting the atomic energies of all atoms in the structure.
    species_counts = []
    residual_energies = []
    n_atoms = []
    for structure in structures:
        species = np.asarray(structure.species, dtype=np.int64)
        residual_energies.append(structure.energy - atomic_energies[species].sum())
        species_counts.append(np.bincount(species, minlength=n_species))
        n_atoms.append(structure.n_atoms)

    total_atoms = sum(n_atoms)
    if total_atoms == 0:
        raise ValueError('Cannot compute the regression shift: the data set contains no atoms.')
    energy_per_atom_mean = sum(residual_energies) / total_atoms

    # Regress from (n_per_species_1, ...) to energy - atomic_energies_sum - n_atoms * energy_per_atom_mean.
    # The mean and the atomic energies are subtracted beforehand so that the ridge
    # regularization does not shrink them.
    counts_matrix = np.stack(species_counts)  # one row per structure, one column per species
    targets = (np.asarray(residual_energies, dtype=float)
               - np.asarray(n_atoms, dtype=float) * energy_per_atom_mean)
    XTX = counts_matrix.T @ counts_matrix
    XTy = counts_matrix.T @ targets

    regression_shift = np.linalg.solve(XTX + regularization * np.eye(n_species), XTy)
    return regression_shift + energy_per_atom_mean + atomic_energies


def get_forces_rms(structures: AtomicStructures,
                   n_species: int) -> np.ndarray:
    """Computes the root mean square of forces across atomic structures in the data set.

    The squared force components are summed over all structures and divided by
    ``3 * total_atoms``, i.e. the RMS is taken per Cartesian component (this assumes
    ``structure.forces`` holds 3 components per atom). The same value is returned for
    every species.

    If the RMS is zero, or the data set contains no atoms, 1.0 is used instead, so the
    result can always be used as a (non-zero) scale.

    Args:
        structures (AtomicStructures): Atomic structures in the data set. Each structure
            must expose ``forces`` and ``n_atoms``.
        n_species (int): Total number of atom species.

    Returns:
        np.ndarray: Root mean square of forces, repeated once per species.
    """
    sq_forces_sum = 0.0
    atoms_sum = 0

    for structure in structures:
        sq_forces_sum += (structure.forces ** 2).sum()
        atoms_sum += structure.n_atoms

    if atoms_sum == 0:
        # nothing to estimate a scale from; fall back to unit scale, matching the
        # zero-RMS fallback below (avoids a ZeroDivisionError on empty data)
        return np.ones(n_species)

    forces_rms = np.sqrt(sq_forces_sum / atoms_sum / 3.0) * np.ones(n_species)

    # a zero scale would be unusable for normalization; replace it by one
    forces_rms[forces_rms == 0.0] = 1.0

    return forces_rms


def get_avg_n_neighbors(structures: AtomicStructures,
                        r_cutoff: float) -> float:
    """Computes the average number of neighbors in the data set. Use `skin=0` during training
    or inference if the neighbor list should not be re-computed.

    For every structure, the first index array returned by ``get_edge_index`` is counted per
    unique atom index. Atoms that do not appear in it (no neighbors within ``r_cutoff``)
    therefore do not contribute to the average, as in the MACE implementation.

    Adapted from MACE (https://github.com/ACEsuit/mace/blob/main/mace/modules/utils.py).

    Args:
        structures (AtomicStructures): Atomic structures in the data set.
        r_cutoff (float): Cutoff radius for computing the neighbor list.

    Returns:
        float: Average number of neighbors in the data set.

    Raises:
        ValueError: If no structure has any neighbor pair within ``r_cutoff`` (including an
            empty data set); the mean would otherwise silently be NaN.
    """
    n_neighbors = []

    for structure in structures:
        # set skin value to zero because otherwise the average number of neighbors is incorrect
        idx_i, _ = structure.get_edge_index(r_cutoff, skin=0.0)
        _, counts = np.unique(idx_i, return_counts=True)
        n_neighbors.extend(counts)

    if not n_neighbors:
        raise ValueError(
            f'Cannot compute the average number of neighbors: no neighbor pairs found '
            f'within r_cutoff={r_cutoff}.')

    return float(np.mean(n_neighbors))
