'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  strategies.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import time
from pathlib import Path
from typing import Dict, Optional, Union, Any

import numpy as np

import torch

import ase

from src.data.data import AtomicStructures, AtomicTypeConverter

from src.model.calculators import StructurePropertyCalculator
from src.model.forward import ForwardAtomisticNetwork, build_model, load_model_from_folder

from src.training.callbacks import FileLoggingCallback
from src.training.loss_fns import config_to_loss
from src.training.trainer import Trainer, eval_metrics

from src.utils.config import update_config
from src.utils.misc import save_object
from src.utils.torch_geometric import DataLoader


class TrainingStrategy:
    """Strategy for training interatomic potentials.

    Args:
        config (Optional[Dict[str, Any]], optional): Configuration file with parameters listed in 'utils/config.py'. 
                                                     The default parameters of 'utils/config.py' will be updated by 
                                                     those provided in 'config'. Defaults to None, in which case an 
                                                     empty update is passed to `update_config`.
    """
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        # Update config containing all parameters (including the model, training, fine-tuning, and evaluation)
        # We store all parameters in one config for simplicity. In the log and best folders the last training
        # config is stored and used when loading a model for inference.
        # `config` is optional: guard against None before taking the (shallow) copy, otherwise the
        # documented default would raise an AttributeError.
        self.config = update_config(config.copy() if config is not None else {})

    def run(self,
            train_structures: AtomicStructures,
            valid_structures: AtomicStructures,
            folder: Union[str, Path],
            model_seed: Optional[int] = None) -> ForwardAtomisticNetwork:
        """Runs training using provided training and validation structures.

        Note that this method mutates `self.config`: the entries 'n_train' and 'n_valid' are always 
        written, and 'model_seed' is overwritten if `model_seed` is provided.

        Args:
            train_structures (AtomicStructures): Training structures.
            valid_structures (AtomicStructures): Validation structures.
            folder (Union[str, Path]): Folder where the trained model is stored.
            model_seed (int, optional): Random seed to initialize the atomistic model. Defaults to None.

        Returns:
            ForwardAtomisticNetwork: Trained atomistic model (loaded from `folder` with key 'best' and 
                                     moved to the configured device).
        """
        # define atomic type converter
        atomic_type_converter = AtomicTypeConverter.from_type_list(self.config['atomic_types'])
        n_species = atomic_type_converter.get_n_type_names()
        
        # convert atomic numbers to type names
        train_structures = train_structures.to_type_names(atomic_type_converter, check=True)
        valid_structures = valid_structures.to_type_names(atomic_type_converter, check=True)
        
        # store the number of training and validation structures in config
        self.config['n_train'] = len(train_structures)
        self.config['n_valid'] = len(valid_structures)
        
        # build atomic data sets
        r_cutoff = self.config['r_cutoff']
        train_ds = train_structures.to_data(r_cutoff=r_cutoff, n_species=n_species)
        valid_ds = valid_structures.to_data(r_cutoff=r_cutoff, n_species=n_species)
        
        # update model seed if provided (can be used to re-run a calculation with a different seed) and build the model
        if model_seed is not None:
            self.config['model_seed'] = model_seed
        model = build_model(train_structures, n_species=n_species, **self.config)
            
        # define losses from config
        train_loss = config_to_loss(self.config['train_loss'])
        eval_losses = {loss_config['type']: config_to_loss(loss_config) 
                       for loss_config in self.config['eval_losses']}
        early_stopping_loss = config_to_loss(self.config['early_stopping_loss'])
        
        # define callbacks to track training
        callbacks = [FileLoggingCallback()]
        
        # define model training; batch sizes are clamped to the data set sizes so that a batch 
        # never asks for more structures than are available
        trainer = Trainer(model, model_path=folder, callbacks=callbacks, 
                          lr=self.config['lr'], lr_factor=self.config['lr_factor'], scheduler_patience=self.config['scheduler_patience'], 
                          max_epoch=self.config['max_epoch'], save_epoch=self.config['save_epoch'], validate_epoch=self.config['valid_epoch'],
                          train_batch_size=min(self.config['train_batch_size'], len(train_structures)),
                          valid_batch_size=min(self.config['eval_batch_size'], len(valid_structures)),
                          train_loss=train_loss, eval_losses=eval_losses, early_stopping_loss=early_stopping_loss,
                          max_grad_norm=self.config['max_grad_norm'], device=self.config['device'], 
                          amsgrad=self.config['amsgrad'], weight_decay=self.config['weight_decay'],
                          ema=self.config['ema'], ema_decay=self.config['ema_decay'])
        
        # train the model
        trainer.fit(train_ds=train_ds, valid_ds=valid_ds)
        
        # reload the model stored under the 'best' key (instead of returning the last training state) 
        # and move it to the configured device
        model = load_model_from_folder(folder, key='best')
        model = model.to(self.config['device'])
        
        return model


class EvaluationStrategy:
    """Strategy for evaluating the performance of interatomic potentials.

    Args:
        config (Optional[Dict[str, Any]], optional): Configuration file with parameters listed in 'utils/config.py'. 
                                                     The default parameters of 'utils/config.py' will be updated by those 
                                                     provided in 'config'. Defaults to None, in which case an empty 
                                                     update is passed to `update_config`.
    """
    def __init__(self, config: Optional[Dict[str, Any]] = None):
        # Update config containing all parameters (including the model, training, fine-tuning, and evaluation)
        # We store all parameters in one config for simplicity. In the log and best folders the last training
        # config is stored and used when loading a model for inference.
        # `config` is optional: guard against None before taking the (shallow) copy, otherwise the
        # documented default would raise an AttributeError.
        self.config = update_config(config.copy() if config is not None else {})

    def _build_dataset(self, structures: AtomicStructures):
        """Converts structures to type names and builds the atomic data set from them."""
        atomic_type_converter = AtomicTypeConverter.from_type_list(self.config['atomic_types'])
        
        # convert atomic numbers to type names
        structures = structures.to_type_names(atomic_type_converter, check=True)
        
        return structures.to_data(r_cutoff=self.config['r_cutoff'],
                                  n_species=atomic_type_converter.get_n_type_names())

    def _build_eval_losses(self):
        """Builds the evaluation losses from config and collects the output variables they require."""
        eval_losses = {loss_config['type']: config_to_loss(loss_config) 
                       for loss_config in self.config['eval_losses']}
        # de-duplicate the output variables requested by the individual losses
        eval_output_variables = list({variable 
                                      for loss_fn in eval_losses.values() 
                                      for variable in loss_fn.get_output_variables()})
        return eval_losses, eval_output_variables

    def _build_eval_loader(self, test_ds) -> DataLoader:
        """Builds a non-shuffling data loader with the evaluation batch size from config."""
        # pinned memory is only requested if the configured device is a CUDA device
        use_gpu = self.config['device'].startswith('cuda')
        return DataLoader(test_ds, batch_size=self.config['eval_batch_size'], shuffle=False, drop_last=False,
                          pin_memory=use_gpu, pin_memory_device=self.config['device'] if use_gpu else '')

    def run(self,
            model: ForwardAtomisticNetwork,
            test_structures: AtomicStructures,
            folder: Union[str, Path]) -> Dict[str, Any]:
        """Evaluates models using the provided test data set.

        The evaluated losses are additionally stored in 'test_results.json' inside `folder`.

        Args:
            model (ForwardAtomisticNetwork): Atomistic model.
            test_structures (AtomicStructures): Test structures.
            folder (Union[str, Path]): Folder where the evaluation results are stored.

        Returns:
            Dict[str, Any]: Results dictionary containing test error metrics.
        """
        folder = Path(folder)
        
        # apply property and ensemble calculators to models
        calc = StructurePropertyCalculator(model, training=True).to(self.config['device'])
        
        # build atomic data set, losses, and data loader
        test_ds = self._build_dataset(test_structures)
        eval_losses, eval_output_variables = self._build_eval_losses()
        test_dl = self._build_eval_loader(test_ds)
        
        # evaluate metrics on test data and store results as a .json file
        test_metrics = eval_metrics(calc=calc, dl=test_dl, eval_loss_fns=eval_losses,
                                    eval_output_variables=eval_output_variables, device=self.config['device'])
        save_object(folder / 'test_results.json', test_metrics['eval_losses'], use_json=True)
        
        return test_metrics['eval_losses']

    def run_on_configs(self,
                       model: ForwardAtomisticNetwork,
                       test_structures: AtomicStructures,
                       folder: Union[str, Path],
                       file_name: str):
        """Evaluates models on the provided test data set and stores configurations in an .xyz file with predicted total energies and atomic forces.

        Predicted energies are stored in `atoms.info['ICTP_energy']` and predicted forces in 
        `atoms.arrays['ICTP_forces']`. Forces are always computed, independent of the configured 
        evaluation losses, because they are always written to the output file.

        Args:
            model (ForwardAtomisticNetwork): Atomistic model.
            test_structures (AtomicStructures): Test structures.
            folder (Union[str, Path]): Folder where the evaluation results are stored.
            file_name (str): Name of the .xyz file, which stores predicted total energies and atomic forces.
        """
        # Importing a package does not load its submodules in general, so `import ase` at module 
        # level is not sufficient on its own; import the writer explicitly here.
        from ase.io import write as write_atoms
        
        folder = Path(folder)
        
        # atoms objects are created from the original structures, before the type-name conversion
        atoms_list = [s.to_atoms() for s in test_structures]
        
        # apply property and ensemble calculators to models
        calc = StructurePropertyCalculator(model, training=True).to(self.config['device'])
        
        # build atomic data set and data loader; the losses are only needed to decide which 
        # additional properties (stress, virials) are requested from the calculator
        test_ds = self._build_dataset(test_structures)
        _, eval_output_variables = self._build_eval_losses()
        test_dl = self._build_eval_loader(test_ds)
        
        # Collect data
        energies_list = []
        forces_list = []

        for batch in test_dl:
            # forces are requested unconditionally as they are read from the results below
            # NOTE(review): create_graph=True is kept from the original implementation; it is presumably 
            # not required for pure inference -- confirm against the calculator before changing it.
            results = calc(batch.to(self.config['device']), 
                           forces=True,
                           stress='stress' in eval_output_variables,
                           virials='virials' in eval_output_variables,
                           create_graph=True)

            energies_list.append(results['energy'].detach().cpu().numpy())
            
            # batch.ptr[1:-1] holds the interior split points between consecutive structures; using them 
            # yields exactly one forces array per structure (no empty trailing chunk)
            split_points = batch.ptr[1:-1].tolist()
            forces_list.extend(np.split(results['forces'].detach().cpu().numpy(), 
                                        indices_or_sections=split_points, axis=0))

        energies = np.concatenate(energies_list, axis=0)
        
        assert len(atoms_list) == len(energies) == len(forces_list)

        # Store data in atoms objects
        for atoms, energy, forces in zip(atoms_list, energies, forces_list):
            atoms.calc = None  # crucial
            atoms.info['ICTP_energy'] = energy
            atoms.arrays['ICTP_forces'] = forces

        # Write atoms to output path
        write_atoms(folder / file_name, images=atoms_list, format="extxyz")
    
    def measure_inference_time(self,
                               model: ForwardAtomisticNetwork,
                               test_structures: AtomicStructures,
                               folder: Union[str, Path],
                               batch_size: int = 100,
                               n_reps: int = 100) -> Dict[str, Any]:
        """Provide inference time for the defined batch size, i.e., atomic system size.

        Only the first batch of the test data is used. The results are additionally stored in 
        'timing_results.json' inside `folder`.

        Args:
            model (ForwardAtomisticNetwork): Atomistic model.
            test_structures (AtomicStructures): Test structures
            folder (Union[str, Path]): Folder where the results of the inference time measurement are stored.
            batch_size (int, optional): Evaluation batch size. Defaults to 100.
            n_reps (int, optional): Number of repetitions. Defaults to 100.

        Raises:
            ValueError: If `n_reps` is smaller than one.

        Returns:
            Dict[str, Any]: Results dictionary with the keys 'total_time', 'time_per_repetition', 
                            'time_per_structure', and 'time_per_atom' (all in seconds).
        """
        if n_reps < 1:
            # fail early instead of running into a ZeroDivisionError after the measurement
            raise ValueError(f'n_reps must be a positive integer, got {n_reps}.')
        
        folder = Path(folder)
        
        calc = StructurePropertyCalculator(model, training=False).to(self.config['device'])
        
        test_ds = self._build_dataset(test_structures)
        test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, drop_last=False)
        
        batch = next(iter(test_dl)).to(self.config['device'])
        
        # the first batch holds fewer than `batch_size` structures if the test set is smaller 
        # (drop_last=False), so normalize by the number of structures actually evaluated
        n_structures = min(batch_size, len(test_structures))
        
        use_gpu = self.config['device'].startswith('cuda')
        
        # need to re-iterate before time measurement
        for _ in range(10):
            calc(batch, forces=True, features=False)
        
        # start with the time measurement; CUDA kernels run asynchronously, so wait for pending 
        # work before reading the clock
        if use_gpu:
            torch.cuda.synchronize()
        
        # perf_counter is a monotonic, high-resolution clock intended for measuring short durations
        start_time = time.perf_counter()
        for _ in range(n_reps):
            calc(batch, forces=True, features=False)
            
        if use_gpu:
            torch.cuda.synchronize()
            
        total_time = time.perf_counter() - start_time
        time_per_repetition = total_time / n_reps
        
        to_save = {'total_time': total_time,
                   'time_per_repetition': time_per_repetition,
                   'time_per_structure': time_per_repetition / n_structures,
                   'time_per_atom': time_per_repetition / batch.n_atoms.sum().item()}
        save_object(folder / 'timing_results.json', to_save, use_json=True)
        
        return to_save
