'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  loss_fns.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import torch

from typing import Dict, List, Any, Callable

from src.utils.torch_geometric import Data
from src.utils.misc import recursive_detach


class LossFunction:
    """Interface for losses that are evaluated per batch and reduced afterwards.

    Concrete loss functions override all three methods; every method of this 
    base class raises 'NotImplementedError'.
    """
    def compute_batch_loss(self,
                           results: Dict[str, torch.Tensor],
                           batch: Data) -> torch.Tensor:
        """Evaluates the loss on one batch.

        Args:
            results (Dict[str, torch.Tensor]): Results dictionary; see 'calculators.py'.
            batch (Data): Atomic data graph.

        Returns:
            torch.Tensor: Loss value on batch.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError

    def reduce_overall(self,
                       losses: List[Any],
                       n_structures_total: int,
                       n_atoms_total: int) -> torch.Tensor:
        """Merges the per-batch losses into one value for the whole data set.

        Args:
            losses (List[Any]): List of losses evaluated on batches.
            n_structures_total (int): Total number of structures in the data set.
            n_atoms_total (int): Total number of atoms in the data set.

        Returns:
            torch.Tensor: Loss value on the data set.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError

    def get_output_variables(self) -> List[str]:
        """Names the properties this loss needs from the model.

        Returns:
            List[str]: List of keys (e.g. ['forces']) that should be present in the 'results' 
                       parameter of compute_batch_loss(). This is used in 'calculators.py' 
                       to determine which quantities should be computed.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError


class SingleLossFunction(LossFunction):
    """Loss that depends on exactly one property.

    The reference value is read from the batch (as attribute 'output_variable') and 
    the prediction from the results dictionary (under the same key).

    Args:
        output_variable (str): Property key (e.g. 'forces').
        batch_loss (Callable[[torch.Tensor, torch.Tensor, Data], torch.Tensor]): Loss function which is evaluated on a batch. 
                                                                                 It is called as batch_loss(y, y_pred, batch=batch), 
                                                                                 i.e., the batch is passed by keyword.
        overall_reduction (Callable[[torch.Tensor, int, int], torch.Tensor], optional): Function that reduces the concatenated 
                                                                                        batch losses, given the total numbers of 
                                                                                        structures and atoms. If None, the 
                                                                                        concatenated losses are returned unreduced. 
                                                                                        Defaults to None.
    """
    def __init__(self,
                 output_variable: str,
                 batch_loss: Callable[[torch.Tensor, torch.Tensor, Data], torch.Tensor],
                 overall_reduction: Callable[[torch.Tensor, int, int], torch.Tensor]=None):
        self.output_variable = output_variable
        self.batch_loss = batch_loss
        self.overall_reduction = overall_reduction

    def compute_batch_loss(self,
                           results: Dict[str, torch.Tensor],
                           batch: Data) -> torch.Tensor:
        """Evaluates 'batch_loss' for the configured property on one batch.

        Args:
            results (Dict[str, torch.Tensor]): Results dictionary; must contain 'output_variable'.
            batch (Data): Atomic data graph; must expose 'output_variable' as an attribute.

        Returns:
            torch.Tensor: Loss value on batch.
        """
        prediction = results[self.output_variable]
        reference = getattr(batch, self.output_variable)
        # reference first, prediction second; the batch is handed over by keyword
        return self.batch_loss(reference, prediction, batch=batch)

    def reduce_overall(self,
                       losses: List[Any],
                       n_structures_total: int,
                       n_atoms_total: int) -> torch.Tensor:
        """Concatenates the batch losses and applies the overall reduction, if any.

        Args:
            losses (List[Any]): List of losses evaluated on batches.
            n_structures_total (int): Total number of structures in the data set.
            n_atoms_total (int): Total number of atoms in the data set.

        Returns:
            torch.Tensor: Reduced loss, or the concatenated batch losses if no 
                          overall reduction was provided.
        """
        # zero-dimensional losses get a leading axis so that they can be concatenated
        as_vectors = [loss.unsqueeze(0) if loss.dim() == 0 else loss for loss in losses]
        stacked = torch.cat(as_vectors, dim=0)
        if self.overall_reduction is not None:
            return self.overall_reduction(stacked, n_structures_total, n_atoms_total)
        return stacked

    def get_output_variables(self) -> List[str]:
        """Returns the single property key required by this loss.

        Returns:
            List[str]: One-element list containing 'output_variable'.
        """
        return [self.output_variable]


class WeightedSumLossFunction(LossFunction):
    """Linear combination of several loss functions.

    Batch losses are kept separately for every wrapped loss function; weighting and 
    summation happen only in 'reduce_overall()'.

    Args:
        loss_fns (List[LossFunction]): List of loss functions.
        weights (List[float]): List of weights; weights[i] multiplies the reduced value of loss_fns[i].
        overall_reduction (Callable[[torch.Tensor, int, int], torch.Tensor], optional): Stored, but currently unused by this class. 
                                                                                        Defaults to None.
    """
    def __init__(self,
                 loss_fns: List[LossFunction],
                 weights: List[float],
                 overall_reduction: Callable[[torch.Tensor, int, int], torch.Tensor]=None):
        self.loss_fns = loss_fns
        self.weights = weights
        self.overall_reduction = overall_reduction

    def get_output_variables(self) -> List[str]:
        """Collects the property keys of all wrapped loss functions.

        Returns:
            List[str]: Keys of all loss functions, concatenated in order 
                       (keys requested by several loss functions appear repeatedly).
        """
        return [key for loss_fn in self.loss_fns for key in loss_fn.get_output_variables()]

    def compute_batch_loss(self,
                           results: Dict[str, torch.Tensor],
                           batch: Data):
        """Evaluates every wrapped loss function on one batch.

        Args:
            results (Dict[str, torch.Tensor]): Results dictionary; see 'calculators.py'.
            batch (Data): Atomic data graph.

        Returns:
            list: Batch losses, one entry per wrapped loss function (same order as 'loss_fns').
        """
        per_fn_losses = []
        for loss_fn in self.loss_fns:
            per_fn_losses.append(loss_fn.compute_batch_loss(results, batch))
        return per_fn_losses

    def reduce_overall(self,
                       losses: List[Any],
                       n_structures_total: int,
                       n_atoms_total: int) -> torch.Tensor:
        """Reduces each wrapped loss over all batches and sums the weighted results.

        Args:
            losses (List[Any]): List of the per-batch lists returned by 'compute_batch_loss()'.
            n_structures_total (int): Total number of structures in the data set.
            n_atoms_total (int): Total number of atoms in the data set.

        Returns:
            torch.Tensor: Weighted sum of the reduced losses.
        """
        weighted_terms = []
        for idx, loss_fn in enumerate(self.loss_fns):
            weight = self.weights[idx]
            # pick the idx-th entry of every batch, i.e., all batch losses of this loss function
            own_losses = [batch_losses[idx] for batch_losses in losses]
            weighted_terms.append(weight * loss_fn.reduce_overall(own_losses, n_structures_total, n_atoms_total))
        return sum(weighted_terms)


class TotalLossTracker:
    """Collects batch losses and reduces them to an overall loss on request.

    Batch losses are appended to an internal list for the lifetime of the tracker.

    Args:
        loss_fn (LossFunction): Loss function.
        requires_grad (bool): If False, loss values are detached from the 
                              computational graph to save RAM.
    """
    def __init__(self,
                 loss_fn: LossFunction,
                 requires_grad: bool):
        self.loss_fn = loss_fn
        self.requires_grad = requires_grad
        self.batch_losses = []

    def append_batch(self,
                     results: Dict[str, torch.Tensor],
                     batch: Data):
        """Evaluates the loss on a single batch and stores it.

        Args:
            results (Dict[str, torch.Tensor]): Results dictionary; see 'calculators.py'.
            batch (Data): Atomic data graph.
        """
        loss = self.loss_fn.compute_batch_loss(results, batch)
        # without gradients the loss is detached, which allows freeing the memory 
        # of the computation graph attached to results
        stored = loss if self.requires_grad else recursive_detach(loss)
        self.batch_losses.append(stored)

    def compute_final_result(self,
                             n_structures_total: int,
                             n_atoms_total: int) -> torch.Tensor:
        """Reduces all stored batch losses to the overall loss.

        Args:
            n_structures_total (int): Total number of structures in the data set.
            n_atoms_total (int): Total number of atoms in the data set.

        Returns:
            torch.Tensor: Overall loss.
        """
        collected = self.batch_losses
        return self.loss_fn.reduce_overall(collected, n_structures_total, n_atoms_total)


# (row, column) index pairs of the upper triangle (including the diagonal) of a 3x3 matrix:
# (0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)
TRIU_IDXS = torch.combinations(torch.arange(0, 3), r=2, with_replacement=True)
# the same six entries as positions in the row-major flattened (length 9) matrix
TRIU_IDXS_FLAT = 3 * TRIU_IDXS[:, 0] + TRIU_IDXS[:, 1]


def get_triu(tensor: torch.Tensor) -> torch.Tensor:
    """Computes flattened upper-triangular representation of 3x3 tensors.

    Args:
        tensor (torch.Tensor): A single 3x3 tensor or a batch of them, e.g., of shape (N, 3, 3); 
                               the number of elements must be a multiple of 9. The tensor 
                               does not need to be contiguous.

    Returns: 
        torch.Tensor: Tensor of shape (N, 6) holding, for each 3x3 tensor, the entries 
                      (0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2) in this order.
    """
    # reshape() instead of view(): view() raises for non-contiguous inputs (e.g., transposed tensors),
    # while reshape() gives the same result for contiguous inputs and copies only if necessary
    return tensor.reshape(-1, 9)[:, TRIU_IDXS_FLAT]


# Building blocks of the metric registry. Each metric consists of
#   * an error term comparing reference 'y' and prediction 'y_pred',
#   * a batch-level accumulation of that error (sum or maximum), and
#   * a data-set-level reduction of the accumulated batch values.

def _error(y, y_pred, batch):
    # plain difference between reference and prediction
    return y - y_pred


def _triu_error(y, y_pred, batch):
    # difference of the flattened upper-triangular parts (see get_triu())
    return get_triu(y) - get_triu(y_pred)


def _error_per_atom(y, y_pred, batch):
    # difference divided by the 'n_atoms' attribute of the batch
    return (y - y_pred) / batch.n_atoms


def _triu_error_per_atom(y, y_pred, batch):
    # upper-triangular difference divided by 'n_atoms', broadcast over the components
    return (get_triu(y) - get_triu(y_pred)) / batch.n_atoms[:, None]


def _energy_square_error_by_atoms(y, y_pred, batch):
    # squared error is divided by 'n_atoms', i.e., the error itself by sqrt(n_atoms)
    return ((y - y_pred).square() / batch.n_atoms).sum()


def _stress_square_error_by_atoms(y, y_pred, batch):
    # same as above for the upper-triangular stress components
    return ((get_triu(y) - get_triu(y_pred)).square() / batch.n_atoms[:, None]).sum()


def _sum_abs(error):
    return lambda y, y_pred, batch: error(y, y_pred, batch).abs().sum()


def _sum_square(error):
    return lambda y, y_pred, batch: error(y, y_pred, batch).square().sum()


def _sum_fourth_power(error):
    return lambda y, y_pred, batch: (error(y, y_pred, batch) ** 4).sum()


def _max_abs(error):
    return lambda y, y_pred, batch: error(y, y_pred, batch).abs().max()


def _per_structure(n_structures, n_atoms):
    return n_structures


def _per_force_component(n_structures, n_atoms):
    # presumably three Cartesian force components per atom
    return 3 * n_atoms


def _per_triu_component(n_structures, n_atoms):
    # six upper-triangular components per structure
    return 6 * n_structures


def _total(losses, n_structures, n_atoms):
    return losses.sum()


def _largest(losses, n_structures, n_atoms):
    return losses.max()


def _mean(count):
    return lambda losses, n_structures, n_atoms: losses.sum() / count(n_structures, n_atoms)


def _root_mean(count):
    return lambda losses, n_structures, n_atoms: (losses.sum() / count(n_structures, n_atoms)).sqrt()


def _fourth_root_mean(count):
    return lambda losses, n_structures, n_atoms: (losses.sum() / count(n_structures, n_atoms)) ** 0.25


# Registry of all available metrics, keyed by '<quantity>_<error type>'.
# Error types: sae/sse (summed absolute/squared error), mae/mse/rmse (their means and the 
# root of the mean squared error), l4 (fourth root of the mean fourth-power error), and 
# maxe (maximum absolute error).
METRICS = {
    # energy: means are taken over structures
    'energy_sae': SingleLossFunction('energy', _sum_abs(_error), _total),
    'energy_mae': SingleLossFunction('energy', _sum_abs(_error), _mean(_per_structure)),
    'energy_sse': SingleLossFunction('energy', _sum_square(_error), _total),
    'energy_mse': SingleLossFunction('energy', _sum_square(_error), _mean(_per_structure)),
    'energy_rmse': SingleLossFunction('energy', _sum_square(_error), _root_mean(_per_structure)),
    'energy_l4': SingleLossFunction('energy', _sum_fourth_power(_error), _fourth_root_mean(_per_structure)),
    'energy_maxe': SingleLossFunction('energy', _max_abs(_error), _largest),
    # forces: means are taken over 3 * n_atoms values
    'forces_sae': SingleLossFunction('forces', _sum_abs(_error), _total),
    'forces_mae': SingleLossFunction('forces', _sum_abs(_error), _mean(_per_force_component)),
    'forces_sse': SingleLossFunction('forces', _sum_square(_error), _total),
    'forces_mse': SingleLossFunction('forces', _sum_square(_error), _mean(_per_force_component)),
    'forces_rmse': SingleLossFunction('forces', _sum_square(_error), _root_mean(_per_force_component)),
    'forces_l4': SingleLossFunction('forces', _sum_fourth_power(_error), _fourth_root_mean(_per_force_component)),
    'forces_maxe': SingleLossFunction('forces', _max_abs(_error), _largest),
    # stress: upper-triangular components only; means are taken over 6 * n_structures values
    'stress_sae': SingleLossFunction('stress', _sum_abs(_triu_error), _total),
    'stress_mae': SingleLossFunction('stress', _sum_abs(_triu_error), _mean(_per_triu_component)),
    'stress_sse': SingleLossFunction('stress', _sum_square(_triu_error), _total),
    'stress_mse': SingleLossFunction('stress', _sum_square(_triu_error), _mean(_per_triu_component)),
    'stress_rmse': SingleLossFunction('stress', _sum_square(_triu_error), _root_mean(_per_triu_component)),
    'stress_l4': SingleLossFunction('stress', _sum_fourth_power(_triu_error), _fourth_root_mean(_per_triu_component)),
    'stress_maxe': SingleLossFunction('stress', _max_abs(_triu_error), _largest),
    # virials: treated exactly like stress
    'virials_sae': SingleLossFunction('virials', _sum_abs(_triu_error), _total),
    'virials_mae': SingleLossFunction('virials', _sum_abs(_triu_error), _mean(_per_triu_component)),
    'virials_sse': SingleLossFunction('virials', _sum_square(_triu_error), _total),
    'virials_mse': SingleLossFunction('virials', _sum_square(_triu_error), _mean(_per_triu_component)),
    'virials_rmse': SingleLossFunction('virials', _sum_square(_triu_error), _root_mean(_per_triu_component)),
    'virials_l4': SingleLossFunction('virials', _sum_fourth_power(_triu_error), _fourth_root_mean(_per_triu_component)),
    'virials_maxe': SingleLossFunction('virials', _max_abs(_triu_error), _largest),
    # energy per atom: error is divided by n_atoms before accumulation
    'energy_per_atom_sae': SingleLossFunction('energy', _sum_abs(_error_per_atom), _total),
    'energy_per_atom_mae': SingleLossFunction('energy', _sum_abs(_error_per_atom), _mean(_per_structure)),
    'energy_per_atom_sse': SingleLossFunction('energy', _sum_square(_error_per_atom), _total),
    'energy_per_atom_mse': SingleLossFunction('energy', _sum_square(_error_per_atom), _mean(_per_structure)),
    'energy_per_atom_rmse': SingleLossFunction('energy', _sum_square(_error_per_atom), _root_mean(_per_structure)),
    'energy_per_atom_l4': SingleLossFunction('energy', _sum_fourth_power(_error_per_atom), _fourth_root_mean(_per_structure)),
    'energy_per_atom_maxe': SingleLossFunction('energy', _max_abs(_error_per_atom), _largest),
    # squared errors divided by n_atoms
    'energy_by_sqrt_atoms_sse': SingleLossFunction('energy', _energy_square_error_by_atoms, _total),
    'energy_by_sqrt_atoms_mse': SingleLossFunction('energy', _energy_square_error_by_atoms, _mean(_per_structure)),
    'stress_by_sqrt_atoms_sse': SingleLossFunction('stress', _stress_square_error_by_atoms, _total),
    # stress per atom: upper-triangular error is divided by n_atoms before accumulation
    'stress_per_atom_mae': SingleLossFunction('stress', _sum_abs(_triu_error_per_atom), _mean(_per_triu_component)),
    'stress_per_atom_rmse': SingleLossFunction('stress', _sum_square(_triu_error_per_atom), _root_mean(_per_triu_component)),
    'stress_per_atom_mse': SingleLossFunction('stress', _sum_square(_triu_error_per_atom), _mean(_per_triu_component)),
}

def config_to_loss(config: dict) -> LossFunction:
    """Builds the loss function described by a config dictionary.

    A config of type 'weighted_sum' is resolved recursively from its 'losses' 
    (list of configs) and 'weights' entries. Any other type is looked up in METRICS 
    and the registered loss function object is returned as is.

    Args:
        config (dict): Dictionary containing loss types and other details.

    Returns:
        LossFunction: Loss function defined by config file.

    Raises:
        ValueError: If the type is neither 'weighted_sum' nor a key of METRICS.
    """
    loss_type = config['type']

    # composite loss: translate every sub-config first, then combine
    if loss_type == 'weighted_sum':
        components = [config_to_loss(sub_config) for sub_config in config['losses']]
        return WeightedSumLossFunction(components, config['weights'])

    # elementary loss: must be one of the registered metrics
    if loss_type not in METRICS:
        raise ValueError(f'Not implemented loss "{loss_type}"! Available losses: {list(METRICS.keys())}.')
    return METRICS[loss_type]
