"""
Module for extracting information from ORCA output files. Can be used to read arbitrary
information from the output files, as well as gradient and hessian files generated by
the ORCA code package. Contains several predefined parsers.
"""

from typing import Optional, List, Dict, Union

import logging
import os
import numpy as np
from ase import Atoms, units
from tqdm import tqdm

import schnetpack as spk
from schnetpack import properties

# Configure the root logger when this module is imported. The level is read from the
# LOGLEVEL environment variable and falls back to INFO if it is not set.
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# Names exported by a star-import of this module.
__all__ = [
    "ppm2au",
    "OrcaParserException",
    "OrcaParser",
    "format_dipole_derivatives",
    "format_polarizability_derivatives",
    "OrcaOutputParser",
    "OrcaFormatter",
    "OrcaPropertyParser",
    "OrcaMainFileParser",
    "OrcaHessianFileParser",
]

# Conversion factor from ppm to atomic units; it is applied to the shielding tensors
# by the shielding formatter of the main file parser. `units.alpha` is the fine
# structure constant and the factor 1e6 accounts for the parts-per-million scaling.
ppm2au = 2.0 / (units.alpha**2 * 1e6)


class OrcaParserException(Exception):
    """Exception type associated with the :obj:`OrcaParser` class."""


class OrcaParser:
    """
    Main parser utility for ORCA output files. Runs over a list of output files, extracts the data
    and stores it into a formatted ASE database in SchNetPack format (created via
    :obj:`schnetpack.data.ASEAtomsData`). This class makes use of the :obj:`OrcaMainFileParser` and
    :obj:`OrcaHessianFileParser` defined below.

    Args:
        dbpath (str): Path to the target database.
        target_properties (list): List of properties to extract from ORCA files.
        filter (dict, optional): Dictionary giving the name of a property and a threshold value. Entries
                                 in the output files whose norm exceeds the threshold are discarded.
                                 This can be used to e.g. screen for numerical noise in implicit
                                 solvent computations, etc.
        mask_charges (bool, optional): If the ORCA calculation used external charges, these are removed from the
                                       positions and atom types read by the parser.
        property_units (dict, optional): Units of the target properties, passed on as `property_unit_dict`
                                         when the database is created. Target properties without an entry
                                         get the unit 1.0. The passed dictionary is copied, not modified.
        distance_unit (str or float, optional): Passed on as `distance_unit` when the database is created.
    """

    # Properties read from the main ORCA output file
    main_properties = [
        properties.energy,
        properties.forces,
        properties.dipole_moment,
        properties.polarizability,
        properties.shielding,
    ]
    # Properties read from the separate Hessian file
    hessian_properties = [
        properties.hessian,
        properties.dipole_derivatives,
        properties.polarizability_derivatives,
    ]
    # Properties that get an additional leading dimension before being stored
    molecular_properties = [properties.dipole_moment, properties.polarizability]
    file_extensions = {properties.forces: ".engrad", properties.hessian: ".oinp.hess"}
    # Per-atom entries that are truncated when external charges are masked
    atomistic = ["atoms", properties.forces, properties.shielding]

    def __init__(
        self,
        dbpath: str,
        target_properties: List,
        filter: Optional[Dict[str, float]] = None,
        mask_charges: bool = False,
        property_units: Optional[Dict[str, Union[str, float]]] = None,
        distance_unit: Union[str, float] = 1.0,
    ):
        self.dbpath = dbpath

        # Work on a copy: a shared `{}` default or the caller's dictionary must not be
        # modified when the missing entries are filled in below.
        self.property_units = {} if property_units is None else dict(property_units)
        for p in target_properties:
            self.property_units.setdefault(p, 1.0)

        # External fields are not parsed; zero-valued placeholders are stored for them
        dummy_fields = (properties.electric_field, properties.magnetic_field)

        main_properties = []
        hessian_properties = []
        for p in target_properties:
            if p in self.main_properties:
                main_properties.append(p)
            elif p in self.hessian_properties:
                hessian_properties.append(p)
            elif p not in dummy_fields:
                print("Unrecognized property {:s}".format(p))

        dummy_properties = [p for p in dummy_fields if p in target_properties]

        self.all_properties = main_properties + hessian_properties + dummy_properties

        self.atomsdata = spk.data.ASEAtomsData.create(
            dbpath, distance_unit=distance_unit, property_unit_dict=self.property_units
        )

        # The main file parser is always needed
        self.main_parser = OrcaMainFileParser(
            target_properties=main_properties + ["atoms"]
        )

        if len(hessian_properties) > 0:
            self.hessian_parser = OrcaHessianFileParser(
                target_properties=hessian_properties
            )
        else:
            self.hessian_parser = None

        # Set up filter dictionary to e.g. remove numerically unstable solvent computations
        self.filter = filter

        # If requested, mask Q charges introduced by Orca
        self.mask_charges = mask_charges

    def parse_data(self, data_files: List[str], buffer_size: int = 10):
        """
        Reads in a list of ORCA output files, extracts the data, performs reformatting and
        then stores structures and properties to an ASE database. Files are processed in
        sorted order. Files that do not exist, did not terminate normally, could not be
        parsed, or are rejected by the filter are skipped.

        Args:
            data_files (list): List of the paths to the ORCA output files.
            buffer_size (int, optional): Collects a certain number of molecules before writing them to
                                         the database.
        """
        atom_buffer = []
        property_buffer = []

        for file in tqdm(sorted(data_files), ncols=100):
            if not os.path.exists(file):
                logging.warning(f"Skipping missing output file {file}")
                continue

            # Named `props` so the module-level `properties` import is not shadowed
            atoms, props = self._parse_molecule(file)
            if props is None:
                continue

            # Filter properties for problematic values
            if self.filter is not None:
                filtered = False
                for p in self.filter:
                    if np.linalg.norm(props[p]) > self.filter[p]:
                        filtered = True
                        logging.info(f"Filtered output {file} due to {p}")
                if filtered:
                    continue

            atom_buffer.append(atoms)
            property_buffer.append(props)

            if len(atom_buffer) >= buffer_size:
                self.atomsdata.add_systems(property_buffer, atom_buffer)
                atom_buffer = []
                property_buffer = []

        # Collect leftovers
        if len(atom_buffer) > 0:
            self.atomsdata.add_systems(property_buffer, atom_buffer)

    def _parse_molecule(self, datafile: str):
        """
        Parser for a single molecule. This routine first checks for convergence
        of the calculation and then extracts all properties as instructed in the
        individual parsers.

        Args:
            datafile (str): Path to the datafile.

        Returns:
            (ase.Atoms, dict):
                atoms:
                    ase.Atoms object holding atom types and positions.
                properties:
                    Dictionary containing all extracted properties. Naming conventions are the
                    same as used in :obj:`schnetpack.Properties`.

                `(None, None)` is returned if the calculation did not terminate normally or a
                main property could not be parsed; `(atoms, None)` if the Hessian file is
                missing or one of its properties could not be parsed.
        """
        # check if computation converged
        if not self._check_convergence(datafile):
            return None, None

        # Get main properties
        self.main_parser.parse_file(datafile)
        main_properties = self.main_parser.get_parsed()

        # Reject incomplete outputs before anything is sliced or converted. The "atoms"
        # entry is a pair (types, coordinates), so both parts have to be checked.
        for p, value in main_properties.items():
            if value is None or (p == "atoms" and any(v is None for v in value)):
                print("Error parsers {:s}".format(p))
                return None, None

        if self.mask_charges:
            main_properties = self._mask_charges(main_properties)

        atoms = None
        target_properties = {}
        for p, value in main_properties.items():
            if p == "atoms":
                atypes, coords = value
                atoms = Atoms(atypes, coords)
            else:
                target_properties[p] = value.astype(np.float64)

        if self.hessian_parser is not None:
            hessian_file = (
                os.path.splitext(datafile)[0] + self.file_extensions[properties.hessian]
            )
            if not os.path.exists(hessian_file):
                print("Could not open Hessian file {:s}".format(hessian_file))
                return atoms, None

            self.hessian_parser.parse_file(hessian_file)
            hessian_properties = self.hessian_parser.get_parsed()

            for p, value in hessian_properties.items():
                # The parsed value (not the key) signals a section that was not found
                if value is None:
                    print("Error parsers {:s}".format(p))
                    return atoms, None
                elif p == properties.dipole_derivatives:
                    target_properties[p] = format_dipole_derivatives(value)
                elif p == properties.polarizability_derivatives:
                    target_properties[p] = format_polarizability_derivatives(value)
                else:
                    target_properties[p] = value

        # Unsqueeze first dimension for new batch format
        for p in target_properties:
            if p in self.molecular_properties:
                target_properties[p] = target_properties[p][None, ...]

        # Add dummy fields if requested; the two fields are independent of each other
        if properties.electric_field in self.all_properties:
            target_properties[properties.electric_field] = np.zeros((1, 3))
        if properties.magnetic_field in self.all_properties:
            target_properties[properties.magnetic_field] = np.zeros((1, 3))

        return atoms, target_properties

    @staticmethod
    def _check_convergence(datafile: str):
        """
        Check whether calculation has converged by searching for the
        typical ORCA closing statement in the second to last line of the file.

        Args:
            datafile (str): Path to the output file.

        Returns:
            bool: Indicator if computation finished successfully. Files with fewer
                  than two lines are reported as not converged.
        """
        from collections import deque

        # Only the last two lines are kept while streaming through the file
        with open(datafile) as f:
            tail = deque(f, maxlen=2)

        if len(tail) < 2:
            return False

        return tail[0].strip() == "****ORCA TERMINATED NORMALLY****"

    def _mask_charges(self, main_properties: Dict[str, np.array]):
        """
        Remove the external charges Q introduced in orca input file. This is
        only necessary, if the charges are given in the input file. This in
        turn is necessary to get the right shielding tensors, as Orca is buggy
        in this case. All other properties need to be taken from a computation
        with external charges given in a file, since external charges in the
        input file are added to the potential energy and the dipole moment...

        The number of entries whose type is not "Q" is counted and the leading
        entries of all atomistic properties are kept (this assumes that the
        charges are listed after the atoms). The dictionary is modified in place.

        Args:
            main_properties (dict): main property dictionary generated by the
                                    :obj:`OrcaParser`.

        Returns:
            dict: Updated property dictionary with the structures, atom types
                  and atomistic properties masked.
        """
        n_atoms = np.sum(main_properties["atoms"][0] != "Q")

        for p in main_properties:
            if p in self.atomistic:
                if p == "atoms":
                    main_properties[p][0] = main_properties[p][0][:n_atoms]
                    main_properties[p][1] = main_properties[p][1][:n_atoms]
                else:
                    main_properties[p] = main_properties[p][:n_atoms]

        return main_properties


def format_dipole_derivatives(target_property: np.array):
    """
    Bring the extracted dipole derivatives into the layout
    Natoms x (dx dy dz) x (property x y z).

    Args:
        target_property (numpy.array): Two-dimensional array; its first axis is
            split into groups of three rows.

    Returns:
        numpy.array: Array reshaped to (rows // 3, 3, 3).
    """
    n_rows, _ = target_property.shape
    return target_property.reshape(n_rows // 3, 3, 3)


def format_polarizability_derivatives(target_property: np.array):
    """
    Bring the extracted polarizability derivatives into the layout
    Natoms x (dx dy dz) x (property Tensor).

    Each input row holds the six upper-triangular entries of a 3 x 3 tensor,
    which is expanded to the full symmetric tensor.

    Args:
        target_property (numpy.array): Two-dimensional array with six entries per row.

    Returns:
        numpy.array: Array of shape (rows // 3, 3, 3, 3).
    """
    n_rows, _ = target_property.shape
    n_atoms = n_rows // 3
    packed = target_property.reshape(n_atoms, 3, 6)

    upper_rows, upper_cols = np.triu_indices(3)
    full = np.zeros((n_atoms, 3, 3, 3))
    # Fill the upper triangle and mirror it to the lower one
    full[:, :, upper_rows, upper_cols] = packed
    full[:, :, upper_cols, upper_rows] = packed
    return full


class OrcaFormatter:
    """
    Format raw ORCA data collected by an :obj:`OrcaPropertyParser`. Behavior is determined by the datatype option.
    This is e.g. used to extract the correct gradient values from the associated text block taken from the ORCA
    output and convert it to a properly formatted force array.

    Args:
        position (int): Position to start formatting. If no stop is provided returns only value at position, otherwise
                        all values between position and stop are returned. (Only used for `vector` mode)
        stop (int, optional): Stop value for range. (Only used for `vector` mode)
        datatype (str, optional): Change formatting behavior. The possible options are:

                                   - `vector`:
                                        Formats data between position and stop argument, if provided
                                        converting it to the type given in the converter.
                                   - `matrix`:
                                        Formats collected matrix data, which ORCA prints in blocks of
                                        columns, into a single two-dimensional numpy.ndarray. Ignores
                                        position and stop.
                                   - `shielding`:
                                        Collects the total shielding tensors. Ignores position and stop.
        converter (type, optional): Convert the individual entries to this type.
        skip_first (int, optional): If not None, skip the first N lines (default=None).
        unit (float, optional): If not None, the formatted array is multiplied by this factor.
        default (float, optional): Default value to be returned if nothing is parsed (e.g. 1.0 for vacuum in case of
                         dielectric constant.
    """

    def __init__(
        self,
        position,
        stop: Optional[int] = None,
        datatype: str = "vector",
        converter: type = np.double,
        skip_first: Optional[int] = None,
        unit: Optional[float] = None,
        default: Optional[float] = None,
    ):
        self.position = position
        self.stop = stop
        self.datatype = datatype
        self.converter = converter
        # Dimension of the most recently formatted matrix (set by `_format_matrix`)
        self.matrix_dim = None
        self.skip_first = skip_first
        self.unit = unit
        self.default = default

    def format(self, parsed: List[str]):
        """
        Format the raw parsed data according to the given instructions.

        Args:
            parsed (list): List of raw parsed lines taken from the output file.

        Returns:
            numpy.array: Formatted numpy array holding the processed properties. If `parsed`
            is None, the default value wrapped in an array is returned (or None if no default
            was given). None is also returned if no lines remain after skipping.

        Raises:
            NotImplementedError: If the datatype is not one of the supported options.
        """
        if parsed is None:
            if self.default is not None:
                return np.array([self.default])
            return None

        if self.skip_first is not None:
            parsed = parsed[self.skip_first :]

        if len(parsed) == 0:
            return None

        if self.datatype == "vector":
            formatted = self._format_vector(parsed)
        elif self.datatype == "matrix":
            formatted = self._format_matrix(parsed)
        elif self.datatype == "shielding":
            formatted = self._format_shielding(parsed)
        else:
            raise NotImplementedError(
                "Unrecognized data type {:s}".format(self.datatype)
            )

        if self.unit is not None:
            # Out-of-place multiplication, so that e.g. integer data can be scaled
            # by a float unit (an in-place update cannot change the dtype).
            formatted = formatted * self.unit

        return formatted

    def _format_vector(self, parsed: List[str]):
        """
        Take numerical entries in a line and collect them into an array.
        It is possible to extract only certain slices of an array.
        Although the basic quantity is called `vector`, this also is used to
        extract matrices, such as forces. The special matrix formatter deals
        with cases where ORCA stores special matrices, such as Hamiltonians
        and Hessians. In this case it introduces line breaks after a fixed number
        of entries, leading this routine to fail.

        Args:
            parsed (list): List of raw parsed lines taken from the output file.

        Returns:
            numpy.array: Formatted numpy array holding the processed properties.
        """
        vector = []
        for line in parsed:
            entries = line.split()
            if self.stop is None:
                vector.append(self.converter(entries[self.position]))
            else:
                vector.append(
                    [self.converter(x) for x in entries[self.position : self.stop]]
                )

        vector = np.array(vector)

        # Remove leading dimension if only one line is read (for dipole moment)
        if vector.shape[0] == 1 and vector.size != 1:
            vector = vector[0]

        return vector

    def _format_matrix(self, parsed: List[str]):
        """
        Format raw extracted matrices. Unlike the vector formatter, this routine
        deals with cases where ORCA stores special matrices, such as Hamiltonians
        and Hessians. In this case the matrix is printed in blocks of columns,
        each consisting of a header line with the column indices followed by one
        line per matrix row (row index, then the entries). The blocks are joined
        back into full rows here.

        Args:
            parsed (list): List of raw parsed lines taken from the output file.

        Returns:
            numpy.array: Formatted numpy array holding the processed properties.
        """
        # The final line is the last row of the last block, so its leading row index
        # gives the dimension. This does not depend on how many columns a block has
        # and therefore also covers single-block matrices and matrices whose size is
        # a multiple of the block width.
        self.matrix_dim = int(parsed[-1].split()[0]) + 1

        # Every block consists of one header line plus `matrix_dim` row lines
        block_len = self.matrix_dim + 1

        matrix = [[] for _ in range(self.matrix_dim)]

        for start in range(0, len(parsed), block_len):
            rows = parsed[start + 1 : start + block_len]
            for i, entry in enumerate(rows):
                # Drop the leading row index and append the columns of this block
                matrix[i].extend(self.converter(x) for x in entry.split()[1:])

        return np.array(matrix)

    def _format_shielding(self, parsed: List[str]):
        """
        Format the raw shielding tensors taken from the ORCA output. The lines
        between "Total shielding tensor (ppm):" and "Diagonalized sT*s matrix:"
        are collected as one tensor; all other lines are ignored.

        Args:
            parsed (list): List of raw parsed lines taken from the output file.

        Returns:
            numpy.array: Formatted numpy array holding the processed properties.
        """
        shielding = []
        current_shielding = []
        parse = False
        for line in parsed:
            if line.startswith("Total shielding tensor (ppm):"):
                parse = True
            elif parse:
                if line.startswith("Diagonalized sT*s matrix:"):
                    shielding.append(current_shielding)
                    current_shielding = []
                    parse = False
                else:
                    current_shielding.append([self.converter(x) for x in line.split()])

        return np.array(shielding)


class OrcaPropertyParser:
    """
    Line-wise property parser for ORCA output files. Collection begins at a line starting with the start flag
    and ends at a line starting with the stop flag (or any flag of a list of stop flags). The lines in between
    are stored. If :obj:`OrcaFormatter` objects are provided, the stored lines are formatted upon retrieval.

    Args:
        start (str): begins to collect data starting from this string
        stop (str/list(str)): stops data collection if any of these strings is encountered. If None, only the
                              line containing the start flag itself is stored.
        formatters (OrcaFormatter): OrcaFormatter (or iterable of them) to convert collected data
    """

    def __init__(
        self,
        start: str,
        stop: Union[str, List[str]],
        formatters: Optional[Union[OrcaFormatter, List[OrcaFormatter]]] = None,
    ):
        self.start = start
        self.stop = stop
        self.formatters = formatters

        # Collection state: whether lines are currently recorded, and what was recorded
        self.read = False
        self.parsed = None

    def parse_line(self, line: str):
        """
        Process one line of the output file and update the collected data.

        Args:
            line (str): line of Orca output file
        """
        text = line.strip()

        # Separator lines and blank lines never carry data
        if not text or text.startswith("---------"):
            return

        if text.startswith(self.start):
            # A repeated start flag discards what was collected before, so only the
            # last occurrence in a file is kept.
            single_line = self.stop is None
            self.parsed = [text] if single_line else []
            self.read = not single_line
            return

        if not self.read:
            return

        if isinstance(self.stop, list):
            hit_stop = any(text.startswith(flag) for flag in self.stop)
        else:
            hit_stop = text.startswith(self.stop)

        if hit_stop:
            self.read = False
        else:
            self.parsed.append(text)

    def get_parsed(self):
        """
        Return the collected data. Without formatters the raw lines are returned. With a single
        formatter its result is returned, and with an iterable of formatters a list holding one
        result per formatter.

        Returns:
            numpy.array: Formatted data.
        """
        formatters = self.formatters
        if formatters is None:
            return self.parsed
        if hasattr(formatters, "__iter__"):
            return [single.format(self.parsed) for single in formatters]
        return formatters.format(self.parsed)

    def reset(self):
        """
        Forget all collected data and stop reading.
        """
        self.parsed = None
        self.read = False


class OrcaOutputParser:
    """
    Generic parser for ORCA output files. Every line of a file is handed to each parser in the 'parsers'
    dictionary; afterwards the results are gathered in a dictionary using the same keys as the parsers. If a
    parser holds a list of formatters, a list of the parsed entries is stored under its key.

    Args:
        parsers (dict[str->callable]): dictionary of :obj:`OrcaPropertyParser`,
                                       each with their own :obj:`OrcaFormatter`.
    """

    def __init__(self, parsers: Dict[str, OrcaPropertyParser]):
        self.parsers = parsers
        self.parsed = None

    def parse_file(self, path: str):
        """
        Read the file line by line and feed every line to all parsers. The results are stored
        in a dictionary which can be accessed with :obj:`get_parsed`.

        Args:
            path (str): path to Orca output file.
        """
        # Start from a clean state so nothing from a previous file is carried over
        for property_parser in self.parsers.values():
            property_parser.reset()

        with open(path, "r") as handle:
            for raw_line in handle:
                for property_parser in self.parsers.values():
                    property_parser.parse_line(raw_line)

        self.parsed = {}
        for key, property_parser in self.parsers.items():
            self.parsed[key] = property_parser.get_parsed()

    def get_parsed(self):
        """
        Return the data collected during the last call to :obj:`parse_file`.

        Returns:
            dict[str->list]: Dictionary of data entries according to parser keys
            (None if no file has been parsed yet).
        """
        return self.parsed


class OrcaMainFileParser(OrcaOutputParser):
    """
    Ready-made :obj:`OrcaOutputParser` for the main ORCA output file. It covers the structure
    (atom types and positions), energies, forces, dipole moments, polarizabilities and nuclear
    shielding tensors.

    Args:
        target_properties (list, optional): Properties to parse. If None, all properties known to
                                            this parser are parsed. Unknown entries are reported
                                            and ignored.
    """

    # Properties this parser is able to extract
    target_properties = [
        "atoms",
        properties.forces,
        properties.energy,
        properties.dipole_moment,
        properties.polarizability,
        properties.shielding,
    ]

    # Beginning of the line that opens the data section of a property
    starts = {
        "atoms": "CARTESIAN COORDINATES (ANGSTROEM)",
        properties.forces: "CARTESIAN GRADIENT",
        properties.energy: "FINAL SINGLE POINT ENERGY",
        properties.dipole_moment: "Total Dipole Moment",
        properties.polarizability: "The raw cartesian tensor (atomic units):",
        properties.shielding: "CHEMICAL SHIFTS",
    }

    # Beginning of the line that closes the data section. None marks properties
    # for which only the start line itself is stored.
    stops = {
        "atoms": "CARTESIAN COORDINATES (A.U.)",
        properties.forces: "Difference to translation invariance",
        properties.energy: None,
        properties.dipole_moment: None,
        properties.polarizability: "diagonalized tensor:",
        properties.shielding: "CHEMICAL SHIELDING SUMMARY",
    }

    # Instructions for converting the collected lines into arrays. The structure uses
    # two formatters, one for the atom types and one for the coordinates.
    formatters = {
        "atoms": (
            OrcaFormatter(0, converter=str),
            OrcaFormatter(1, stop=4, unit=1.0 / units.Bohr),
        ),
        properties.energy: OrcaFormatter(4),
        properties.forces: OrcaFormatter(3, stop=6, unit=-1.0),
        properties.dipole_moment: OrcaFormatter(4, stop=7),
        properties.polarizability: OrcaFormatter(0, stop=4),
        properties.shielding: OrcaFormatter(0, datatype="shielding", unit=ppm2au),
    }

    def __init__(self, target_properties: Optional[List[str]] = None):
        if target_properties is None:
            selected = self.target_properties
        else:
            selected = []
            for name in target_properties:
                if name in self.target_properties:
                    selected.append(name)
                else:
                    print("Cannot parse property {:s}".format(name))

        # One line parser per requested property, configured from the class tables
        parsers = {}
        for name in selected:
            parsers[name] = OrcaPropertyParser(
                self.starts[name], self.stops[name], formatters=self.formatters[name]
            )

        super().__init__(parsers)


class OrcaHessianFileParser(OrcaMainFileParser):
    """
    Ready-made :obj:`OrcaOutputParser` for the hessian output file written by ORCA if some higher
    order derivatives are requested. It covers Hessians as well as Cartesian dipole moment and
    polarizability derivatives. Only the lookup tables differ from :obj:`OrcaMainFileParser`;
    the set-up of the individual parsers is inherited.

    Args:
        target_properties (list, optional): Properties to parse. If None, all properties known to
                                            this parser are parsed.
    """

    # Properties this parser is able to extract
    target_properties = [
        properties.hessian,
        properties.dipole_derivatives,
        properties.polarizability_derivatives,
    ]

    # Beginning of the line that opens the data section of a property
    starts = {
        properties.hessian: "$hessian",
        properties.dipole_derivatives: "$dipole_derivatives",
        properties.polarizability_derivatives: "$polarizability_derivatives",
    }

    # Beginning of the line that closes the data section
    stops = {
        properties.hessian: "$vibrational_frequencies",
        properties.dipole_derivatives: "#",
        properties.polarizability_derivatives: "#",
    }

    # All formatters skip the first collected line of their section
    formatters = {
        properties.hessian: OrcaFormatter(0, datatype="matrix", skip_first=1),
        properties.dipole_derivatives: OrcaFormatter(0, stop=4, skip_first=1),
        properties.polarizability_derivatives: OrcaFormatter(0, stop=6, skip_first=1),
    }

    def __init__(self, target_properties: Optional[List[str]] = None):
        super().__init__(target_properties)
