from typing import IO, Optional, Union

import numpy as np
import scipy.optimize as opt
from math import sqrt
import time

from ase import Atoms
from ase.optimize.optimize import Optimizer


class Converged(Exception):
    """Signal from an optimizer callback that the convergence criterion is met.

    Raised inside the SciPy callback to abort the SciPy routine early; the
    ``run`` methods in this module catch and ignore it.
    """


class OptimizerConvergenceError(Exception):
    """The SciPy routine stopped without reaching the requested accuracy.

    Raised by the gradient-based wrappers in this module when SciPy reports
    ``warnflag == 2`` (loss of precision).
    """


class SciPyOptimizer(Optimizer):
    """General interface for SciPy optimizers

    Only the call to the optimizer is still needed: subclasses implement
    :meth:`call_fmin`, which receives the (scaled) force threshold and the
    maximum number of iterations.
    """

    def __init__(
        self,
        atoms: Atoms,
        logfile: Union[IO, str] = '-',
        trajectory: Optional[str] = None,
        callback_always: bool = False,
        alpha: float = 70.0,
        **kwargs,
    ):
        """Initialize object

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        trajectory: str
            Trajectory file used to store optimisation path.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        callback_always: bool
            Should the callback be run after each force call (also in the
            linesearch)

        alpha: float
            Initial guess for the Hessian (curvature of energy surface). A
            conservative value of 70.0 is the default, but number of needed
            steps to converge might be less if a lower value is used. However,
            a lower value also means risk of instability.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        """
        restart = None
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)
        # Number of gradient evaluations; incremented only in fprime().
        self.force_calls = 0
        self.callback_always = callback_always
        # Energies and gradients handed to SciPy are divided by H0 (see f()
        # and fprime()), because SciPy starts from the identity Hessian.
        self.H0 = alpha
        # Upper bound for self.nsteps during the current run(); set in run().
        self.max_steps = 0

    def x0(self):
        """Return x0 in a way SciPy can use

        This class is mostly usable for subclasses wanting to redefine the
        parameters (and the objective function)"""
        # SciPy works on a flat 1-D vector; f()/fprime() reshape it back.
        return self.optimizable.get_positions().reshape(-1)

    def f(self, x):
        """Objective function for use of the optimizers"""
        self.optimizable.set_positions(x.reshape(-1, 3))
        # Scale the problem as SciPy uses I as initial Hessian.
        return self.optimizable.get_potential_energy() / self.H0

    def fprime(self, x):
        """Gradient of the objective function for use of the optimizers"""
        self.optimizable.set_positions(x.reshape(-1, 3))
        self.force_calls += 1

        if self.callback_always:
            self.callback(x)

        # Remember that forces are minus the gradient!
        # Scale the problem as SciPy uses I as initial Hessian.
        return - self.optimizable.get_forces().reshape(-1) / self.H0

    def callback(self, x):
        """Callback function to be run after each iteration by SciPy

        This should also be called once before optimization starts, as SciPy
        optimizers only calls it after each iteration, while ase optimizers
        call something similar before as well.

        :meth:`callback`() can raise a :exc:`Converged` exception to signal the
        optimisation is complete. This will be silently ignored by
        :meth:`run`().

        Parameters
        ----------
        x: array
            Current parameter vector, as passed by SciPy or by
            :meth:`fprime`. It is not used: the forces are read from
            ``self.optimizable`` at its most recently set positions.
        """
        # force_calls is counted in fprime(); incrementing it here as well
        # would double-count. Do not index into x either -- it is a plain
        # parameter array, not a result mapping.
        if self.nsteps < self.max_steps:
            self.nsteps += 1
        f = self.optimizable.get_forces()
        self.log(f)
        self.call_observers()
        if self.converged(f):
            raise Converged

    def run(self, fmax=0.05, steps=100000000):
        """Relax until the force criterion is met or *steps* iterations pass.

        Parameters
        ----------
        fmax: float
            Force convergence criterion, stored as ``self.fmax``.
        steps: int
            Maximum number of SciPy iterations for this call.

        Returns
        -------
        The value of :meth:`converged` once the SciPy routine stops.
        Exceptions from :meth:`call_fmin` other than :exc:`Converged`
        propagate to the caller.
        """
        self.fmax = fmax

        try:
            # As SciPy does not log the zeroth iteration, we do that manually
            if self.nsteps == 0:
                self.log()
                self.call_observers()

            self.max_steps = steps + self.nsteps

            # Scale the problem as SciPy uses I as initial Hessian.
            self.call_fmin(fmax / self.H0, steps)
        except Converged:
            pass
        return self.converged()

    def dump(self, data):
        # Restart files are not supported (restart is always None).
        pass

    def load(self):
        # Restart files are not supported (restart is always None).
        pass

    def call_fmin(self, fmax, steps):
        """Run the SciPy routine; must be implemented by subclasses."""
        raise NotImplementedError

    def log(self, forces=None):
        """Write one progress line (preceded by a header when nsteps is 0).

        The bracketed number after the step count is ``self.force_calls``.
        """
        if forces is None:
            forces = self.optimizable.get_forces()
        # Largest per-atom force norm.
        fmax = sqrt((forces ** 2).sum(axis=1).max())
        e = self.optimizable.get_potential_energy()
        T = time.localtime()
        if self.logfile is not None:
            name = self.__class__.__name__
            if self.nsteps == 0:
                args = (" " * len(name), "Step", "Time", "Energy", "fmax")
                msg = "%s  %4s %8s %15s  %12s\n" % args
                self.logfile.write(msg)

            args = (name, self.nsteps, self.force_calls, T[3], T[4], T[5], e, fmax)
            msg = "%s:  %3d[%3d] %02d:%02d:%02d %15.6f %15.6f\n" % args
            self.logfile.write(msg)
            self.logfile.flush()


class SciPyFminCG(SciPyOptimizer):
    """Non-linear (Polak-Ribiere) conjugate gradient algorithm

    Thin wrapper around :func:`scipy.optimize.fmin_cg`.
    """

    def call_fmin(self, fmax, steps):
        """Run ``fmin_cg`` starting from :meth:`x0`.

        *fmax* is the force threshold already divided by ``H0``. SciPy gets
        a gradient tolerance ten times tighter, so convergence is normally
        detected by :meth:`callback` (via :exc:`Converged`) rather than by
        SciPy itself.

        Raises :exc:`OptimizerConvergenceError` when SciPy reports loss of
        precision (``warnflag == 2``).
        """
        *_, warnflag = opt.fmin_cg(
            self.f,
            self.x0(),
            fprime=self.fprime,
            gtol=fmax * 0.1,  # tighter than fmax: should never be reached
            norm=np.inf,  # max-norm of the gradient
            maxiter=steps,
            full_output=1,
            disp=0,
            callback=self.callback,
        )
        if warnflag != 2:
            return
        raise OptimizerConvergenceError(
            'Warning: Desired error not necessarily achieved '
            'due to precision loss')
        



class SciPyFminBFGS(SciPyOptimizer):
    """Quasi-Newton method (Broyden-Fletcher-Goldfarb-Shanno)

    Thin wrapper around :func:`scipy.optimize.fmin_bfgs`.
    """

    def call_fmin(self, fmax, steps):
        """Run ``fmin_bfgs`` starting from :meth:`x0`.

        *fmax* is the force threshold already divided by ``H0``. SciPy gets
        a gradient tolerance ten times tighter, so convergence is normally
        detected by :meth:`callback` (via :exc:`Converged`) rather than by
        SciPy itself.

        Raises :exc:`OptimizerConvergenceError` when SciPy reports loss of
        precision (``warnflag == 2``).
        """
        *_, warnflag = opt.fmin_bfgs(
            self.f,
            self.x0(),
            fprime=self.fprime,
            gtol=fmax * 0.1,  # tighter than fmax: should never be reached
            norm=np.inf,  # max-norm of the gradient
            maxiter=steps,
            full_output=1,
            disp=0,
            callback=self.callback,
        )
        if warnflag != 2:
            return
        raise OptimizerConvergenceError(
            'Warning: Desired error not necessarily achieved '
            'due to precision loss')


class SciPyGradientlessOptimizer(Optimizer):
    """General interface for gradient less SciPy optimizers

    Only the call to the optimizer is still needed: subclasses implement
    :meth:`call_fmin`, which receives ``xtol``, ``ftol`` and the maximum
    number of iterations.

    Note: If you redefine x0() and f(), you don't even need an atoms object.
    Redefining these also allows you to specify an arbitrary objective
    function.

    XXX: This is still a work in progress
    """

    def __init__(
        self,
        atoms: Atoms,
        logfile: Union[IO, str] = '-',
        trajectory: Optional[str] = None,
        callback_always: bool = False,
        **kwargs,
    ):
        """Initialize object

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        trajectory: str
            Trajectory file used to store optimisation path.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        callback_always: bool
            Stored as ``self.callback_always``; not used by the methods of
            this class, since there are no gradient calls to hook into.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        """
        restart = None
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)
        # Number of objective-function evaluations, incremented in f().
        self.function_calls = 0
        self.callback_always = callback_always
        # Not updated by this class; kept alongside function_calls for
        # symmetry with the gradient-based interface.
        self.force_calls = 0

    def x0(self):
        """Return x0 in a way SciPy can use

        This class is mostly usable for subclasses wanting to redefine the
        parameters (and the objective function)"""
        # SciPy works on a flat 1-D vector; f() reshapes it back.
        return self.optimizable.get_positions().reshape(-1)

    def f(self, x):
        """Objective function for use of the optimizers"""
        self.optimizable.set_positions(x.reshape(-1, 3))
        self.function_calls += 1
        # Unlike the gradient-based interface, the energy is not rescaled.
        return self.optimizable.get_potential_energy()

    def callback(self, x):
        """Callback function to be run after each iteration by SciPy

        This should also be called once before optimization starts, as SciPy
        optimizers only calls it after each iteration, while ase optimizers
        call something similar before as well.

        *x* (the current parameter vector) is ignored.
        """
        # Forces are deliberately not requested here: a gradient-free method
        # cannot assume they are available, so there is no force-based
        # logging and no Converged check -- SciPy's own xtol/ftol decide.
        self.call_observers()
        self.nsteps += 1

    def run(self, ftol=0.01, xtol=0.01, steps=100000000):
        """Minimise with the SciPy routine implemented in :meth:`call_fmin`.

        Parameters
        ----------
        ftol: float
            Tolerance on the objective function, stored as ``self.ftol``.
        xtol: float
            Tolerance on the parameters, stored as ``self.xtol``.
        steps: int
            Maximum number of SciPy iterations.

        Returns
        -------
        The value of :meth:`converged` after the SciPy routine stops.
        """
        self.xtol = xtol
        self.ftol = ftol
        # As SciPy does not log the zeroth iteration, we do that manually
        self.callback(None)
        try:
            self.call_fmin(xtol, ftol, steps)
        except Converged:
            pass
        # NOTE(review): converged() comes from ase's Optimizer and presumably
        # compares forces against self.fmax, which this run() never sets --
        # confirm it works for gradient-free use.
        return self.converged()

    def dump(self, data):
        # Restart files are not supported (restart is always None).
        pass

    def load(self):
        # Restart files are not supported (restart is always None).
        pass

    def call_fmin(self, xtol, ftol, steps):
        """Run the SciPy routine; must be implemented by subclasses."""
        raise NotImplementedError

    def log(self, forces=None):
        """Write one progress line (preceded by a header when nsteps is 0).

        Not called by :meth:`callback`, because forces may be unavailable.
        """
        if forces is None:
            forces = self.optimizable.get_forces()
        # Largest per-atom force norm.
        fmax = sqrt((forces ** 2).sum(axis=1).max())
        e = self.optimizable.get_potential_energy()
        T = time.localtime()
        if self.logfile is not None:
            name = self.__class__.__name__
            if self.nsteps == 0:
                args = (" " * len(name), "Step", "Time", "Energy", "fmax")
                msg = "%s  %4s %8s %15s  %12s\n" % args
                self.logfile.write(msg)

            args = (name, self.nsteps, T[3], T[4], T[5], e, fmax)
            msg = "%s:  %3d %02d:%02d:%02d %15.6f %15.6f\n" % args
            self.logfile.write(msg)
            self.logfile.flush()


class SciPyFmin(SciPyGradientlessOptimizer):
    """Nelder-Mead Simplex algorithm

    Wraps :func:`scipy.optimize.fmin`; only objective-function values are
    used, never forces.

    XXX: This is still a work in progress
    """

    def call_fmin(self, xtol, ftol, steps):
        """Run the downhill simplex from :meth:`x0` for at most *steps* iterations."""
        tolerances = {'xtol': xtol, 'ftol': ftol}
        opt.fmin(self.f, self.x0(), maxiter=steps, disp=0,
                 callback=self.callback, **tolerances)


class SciPyFminPowell(SciPyGradientlessOptimizer):
    """Powell's (modified) level set method

    Wraps :func:`scipy.optimize.fmin_powell`; only objective-function values
    are used, never forces.

    XXX: This is still a work in progress
    """

    def __init__(self, *args, **kwargs):
        """Parameters
        ----------
        direc: float, optional (keyword only)
            How much to change x to initially. Defaults to 0.04. The initial
            direction set is this value times the identity matrix of size
            ``len(self.x0())``.

        All other arguments are forwarded to
        :class:`SciPyGradientlessOptimizer`.
        """
        step = kwargs.pop('direc', None)
        SciPyGradientlessOptimizer.__init__(self, *args, **kwargs)

        if step is None:
            step = 0.04
        self.direc = np.eye(len(self.x0()), dtype=float) * step

    def call_fmin(self, xtol, ftol, steps):
        """Run Powell's method from :meth:`x0` using ``self.direc``."""
        tolerances = {'xtol': xtol, 'ftol': ftol}
        opt.fmin_powell(self.f, self.x0(), maxiter=steps, disp=0,
                        callback=self.callback, direc=self.direc,
                        **tolerances)
