from typing import Callable, Any, List, Tuple, Dict

import logging
import torch
import numpy as np
import time

from collections import defaultdict, namedtuple
from torch import Tensor

# TODO: deprecated, remove it
from expground.utils.logging import get_logger
from expground.utils import format

# TODO: deprecated
# Module-level logger named "ridge_rider::optimizer", created with DEBUG level
# through the project's `get_logger` helper. The functions in this module emit
# their timing / progress messages through it.
logger = get_logger("ridge_rider::optimizer", None, log_level=logging.DEBUG)


def _exact_evals_and_evecs(
    theta: Dict[str, Tensor],
    evaluation: Tensor,  # objective: Callable
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    """Compute a Hessian block and its eigen-decomposition for every parameter.

    Each parameter tensor is handled on its own: only the second derivatives of
    ``evaluation`` with respect to that single tensor are computed. Cross terms
    between different parameter tensors are not part of the result.

    Parameters
    ----------
    theta
        Dict[str, Tensor], parameter name -> parameter tensor. Every tensor has
        to take part in the graph of ``evaluation`` (the first-order
        ``torch.autograd.grad`` call is made without ``allow_unused``).
    evaluation
        Tensor, the scalar loss value to differentiate.

    Returns
    -------
    A tuple ``(hessians, evals, evecs)`` of three dicts keyed like ``theta``:

    * ``hessians[k]``: ``(n, n)`` array, ``n`` being the number of elements of
      ``theta[k]``.
    * ``evals[k]``: real parts of the eigenvalues, sorted in descending order.
    * ``evecs[k]``: real parts of the eigenvectors, one per row, in the same
      order as ``evals[k]``.
    """
    # evaluation = objective(theta)

    names = list(theta.keys())
    params = list(theta.values())

    # create_graph=True keeps the first-order gradients differentiable so that
    # they can be differentiated a second time below.
    gradients = torch.autograd.grad(evaluation, params, create_graph=True)
    hessians, evals, evecs = {}, {}, {}
    # XXX: make it run in parallel
    for name, gradient, param in zip(names, gradients, params):
        rows = []
        for g in gradient.flatten():
            row = None
            if g.requires_grad:
                # retain_graph=True: the same graph is walked once per gradient
                # entry. allow_unused=True: a gradient entry that does not
                # depend on `param` (e.g. loss = a * b) yields None instead of
                # raising "one of the differentiated tensors appears to not
                # have been used in the graph".
                row = torch.autograd.grad(
                    g, param, retain_graph=True, allow_unused=True
                )[0]
            # A gradient entry that is constant w.r.t. `param` has an all-zero
            # Hessian row.
            rows.append(torch.zeros_like(param) if row is None else row)
        hessian = torch.stack(rows)
        # One row per gradient entry, one column per parameter entry.
        np_hess = np.real(hessian.detach().numpy().reshape(len(hessian), -1))
        eig_vals, eig_vecs = np.linalg.eig(np_hess)
        # np.linalg.eig returns eigenvectors as columns; transpose to rows.
        eig_vecs = np.real(eig_vecs.T)
        eig_vals = np.real(eig_vals)
        # Largest eigenvalue first.
        order = np.argsort(-eig_vals)
        hessians[name] = np_hess
        evals[name] = np.take(eig_vals, order, axis=0)
        evecs[name] = np.take(eig_vecs, order, axis=0)

    return hessians, evals, evecs


def _get_ridges(
    policy: "PolicyInterface",
    theta: Dict[str, np.ndarray],
    objective: Callable,
    filter=True,
    return_vals=False,
    min_val=0.000001,
):
    """Collect the candidate ridges (eigen-directions) for every parameter.

    Parameters
    ----------
    policy
        Object whose ``get_weights()`` returns the name -> tensor dict that the
        Hessian is taken with respect to.
    theta
        Parameter values handed to ``objective``.
    objective
        Callable, ``objective(theta)`` returns the loss tensor.
    filter
        When true, keep only eigen-pairs whose eigenvalue is ``> min_val``;
        otherwise keep all of them.
    return_vals
        Accepted but not read inside this function.
    min_val
        Eigenvalue threshold used when ``filter`` is true.

    Returns
    -------
    ``(index, evecs, evals)``, three dicts keyed by parameter name. Every kept
    eigen-pair occurs twice, and the entries are ordered by their position in
    the descending eigenvalue list (``index[k]``).
    """
    loss: Tensor = objective(theta)
    tic = time.time()
    _, evals, evecs = _exact_evals_and_evecs(policy.get_weights(), loss)
    logger.debug(
        f"extract EigenVals and EigenVecs with time consumption={time.time() - tic} seconds"
    )

    index = {}
    for name in list(evals):
        values = evals[name]
        if filter:
            keep = [pos for pos, value in enumerate(values) if value > min_val]
        else:
            keep = range(len(values))

        kept_vals = np.take(values, keep, axis=0)
        kept_vecs = np.take(evecs[name], keep, axis=0)

        # Every kept ridge is listed twice.
        # NOTE(review): both copies of the eigenvector are identical here; if the
        # second copy is meant to be the opposite direction (-evec) this needs a
        # sign flip -- confirm against the intended algorithm.
        doubled_index = np.concatenate([keep, keep])
        doubled_vals = np.concatenate([kept_vals, kept_vals])
        doubled_vecs = np.concatenate([kept_vecs, kept_vecs])

        # Bring the two copies of each ridge next to each other.
        order = np.argsort(doubled_index)
        evecs[name] = np.take(doubled_vecs, order, axis=0)
        evals[name] = np.take(doubled_vals, order, axis=0)
        index[name] = np.take(doubled_index, order)

    return index, evecs, evals


def _update_ridge(
    policy: "PolicyInterface",
    theta: Dict[str, np.ndarray],
    objective: Callable,
    e_i: Dict[str, np.ndarray],
    lambda_i,
):
    """Pick, per parameter, the eigenvector at ``theta`` closest to ``e_i``.

    All eigen-pairs at ``theta`` are fetched (no eigenvalue filtering). For each
    parameter the candidate with the largest dot product with the previous
    direction ``e_i[k]`` is selected.

    Parameters
    ----------
    policy, theta, objective
        Forwarded to ``_get_ridges``.
    e_i
        Parameter name -> direction followed so far (flat vector).
    lambda_i
        Accepted but not read inside this function.

    Returns
    -------
    ``(selected_evecs, selected_evals, selected_overlaps)``, three dicts keyed
    by parameter name holding the chosen eigenvector, its eigenvalue and its
    dot product with the previous direction.
    """
    tic = time.time()
    _, evecs, evals = _get_ridges(policy, theta, objective, filter=False)
    logger.debug(f"step getting ridges, time consumption={time.time() - tic} seconds")

    selected_evecs, selected_evals, selected_overlaps = {}, {}, {}
    # `evecs` and `evals` come back from `_get_ridges` with the same keys.
    for name, candidates in evecs.items():
        previous = e_i[name]
        scores = [candidate @ previous for candidate in candidates]
        best = np.argmax(scores)
        selected_evecs[name] = candidates[best]
        selected_evals[name] = evals[name][best]
        selected_overlaps[name] = scores[best]

    return selected_evecs, selected_evals, selected_overlaps


def _end_ride(
    policy: "PolicyInterface",
    objective: Callable,
    theta: Dict[str, np.ndarray],
    e_i: Dict[str, np.ndarray],
    lambda_i: Dict[str, np.ndarray],
    overlaps: Dict[str, np.ndarray],
):
    """Decide whether the current ride has to stop.

    A parameter votes for stopping when either

    * its overlap with the previous direction dropped below 0.7 (the
      eigenvector can no longer be followed), or
    * its eigenvalue is negative and the direction ``e_i[k]`` has a negative
      dot product with the gradient of ``objective(theta)``.

    Returns
    -------
    bool, True as soon as any parameter votes for stopping.
    """
    loss = objective(theta)
    weights = policy.get_weights()
    names = list(weights.keys())
    raw_grads = torch.autograd.grad(loss, list(weights.values()), create_graph=False)
    gradients = {
        name: grad.detach().numpy().squeeze() for name, grad in zip(names, raw_grads)
    }

    stops = {}
    for name, overlap in overlaps.items():
        flat_grad = gradients[name].reshape(-1)
        # Short-circuits exactly like an if/elif chain: the curvature test is
        # only evaluated when the overlap is still large enough.
        stops[name] = overlap < 0.7 or (
            lambda_i[name] < 0 and flat_grad.dot(e_i[name]) < 0
        )

    # judge based on any or all
    return any(stops.values())


def _choose_from_archive(
    archive,
) -> Tuple[
    Dict[str, np.ndarray],
    Dict[str, np.ndarray],
    Dict[str, np.ndarray],
    Dict[str, np.ndarray],
    Dict[str, List[Any]],
]:
    """Pop the oldest entry of every per-parameter queue in ``archive``.

    Each entry is a 4-item sequence ``(theta, e_i, lambda_i, psi)``. The queues
    are modified in place (first element removed).

    Returns
    -------
    ``(theta, e_i, lambda_i, psi, archive)`` where the first four are dicts
    keyed like ``archive`` and the last one is the same ``archive`` object.
    """
    theta, e_i, lambda_i, psi = {}, {}, {}, {}
    for name, queue in archive.items():
        point, direction, curvature, fingerprint = queue.pop(0)
        theta[name] = point
        e_i[name] = direction
        lambda_i[name] = curvature
        psi[name] = fingerprint

    return theta, e_i, lambda_i, psi, archive


def _check_length(archive: Dict[str, List[Any]]):
    """Return the length of the shortest queue stored in ``archive``.

    The caller uses the result as "how many more entries can be popped from
    every queue". An archive without any queue therefore reports 0, so that an
    empty archive is never treated as if it still had entries to pop.
    """
    return min((len(queue) for queue in archive.values()), default=0)


def _update_archive(mis: Dict[str, np.ndarray], index, evecs, evals, psi, arxiv):
    """Append one new entry per ridge to every queue already in ``arxiv``.

    For each key ``k`` of ``arxiv`` and each position ``i`` in ``index[k]`` the
    entry ``[mis[k], evecs[k][i], evals[k][i], psi[k] + [index[k][i]]]`` is
    appended to ``arxiv[k]``. ``arxiv`` is modified in place; nothing is
    returned. Keys that are not present in ``arxiv`` are not touched.
    """
    for name, queue in arxiv.items():
        for pos, ridge_id in enumerate(index[name]):
            # Extend the fingerprint with the id of the ridge being added.
            fingerprint = psi[name] + [ridge_id]
            entry = [mis[name], evecs[name][pos], evals[name][pos], fingerprint]
            queue.append(entry)


def _ridge_riding(policy: "PolicyInterface", args):
    """Follow Hessian eigen-directions starting from the policy's weights.

    The current weights of ``policy`` are taken as the starting point. The
    ridges found there seed one queue per parameter name (the archive). While
    every queue still has an entry and fewer than ``args.N`` trials were run,
    one entry per parameter is popped and followed for at most
    ``args.max_ride_length`` steps of size ``args.alpha``; a ride also stops as
    soon as ``_end_ride`` says so. The end point of each ride is recorded, and
    the ridges found there are appended to the archive.

    Parameters
    ----------
    policy
        Provides ``get_weights()`` and ``objective()``.
    args
        Object with the attributes ``N``, ``max_ride_length``, ``alpha`` and
        ``min_val``.

    Returns
    -------
    dict with the keys ``"all_solutions"`` (list of ``[theta, psi, evaluation]``,
    one per trial) and ``"archive"`` (the remaining archive).
    """
    solutions = []
    mis: Dict[str, np.ndarray] = {
        name: weight.detach().numpy() for name, weight in policy.get_weights().items()
    }
    objective = policy.objective()

    # Seed the archive with the ridges found at the starting point.
    logger.debug("Initializing index and evecs with running `_get_ridges`...")
    index, evecs, evals = _get_ridges(
        policy, mis, objective, filter=True, return_vals=True, min_val=args.min_val
    )

    archive = defaultdict(list)
    archive_sample_info = dict.fromkeys(index.keys(), 0)
    for name, ridge_ids in index.items():
        # A parameter without any ridge never gets a queue in `archive`, so it
        # is ignored by `_check_length` and by the rides below.
        for pos, ridge_id in enumerate(ridge_ids):
            archive[name].append(
                [mis[name], evecs[name][pos], evals[name][pos], [ridge_id]]
            )
            archive_sample_info[name] += 1

    # FIXME: maybe a bug - zero samples for some thetas.
    logger.debug(
        f"updated archive samples:\n{format.pretty_dict(archive_sample_info, 1)}"
    )

    trials = 0
    logger.debug(f"Run main loop with maximum trails={args.N}")
    while _check_length(archive) > 0 and trials < args.N:
        trials += 1
        # Each of theta, e_i, lambda_i and psi is a dict keyed by parameter name.
        theta, e_i, lambda_i, psi, archive = _choose_from_archive(archive)
        logger.debug(f"------ [Trail #{trials}] ------")

        ride_steps = 0
        while ride_steps < args.max_ride_length:
            ride_steps += 1
            # Move every parameter one step along its current direction.
            theta = {
                name: value + args.alpha * e_i[name].reshape(value.shape)
                for name, value in theta.items()
            }
            # Re-align the directions with the eigenvectors at the new point.
            logger.debug("------ update ridges ......")
            e_i, lambda_i, overlap = _update_ridge(
                policy, theta, objective, e_i, lambda_i
            )
            if _end_ride(policy, objective, theta, e_i, lambda_i, overlap):
                break

        evaluation = objective(theta)
        solutions.append([theta, psi, evaluation])

        index, evecs, evals = _get_ridges(
            policy,
            theta,
            objective,
            filter=True,
            return_vals=True,
            min_val=args.min_val,
        )
        # NOTE(review): the new archive entries combine the ridges found at
        # `theta` with the starting point `mis`, not with `theta` itself --
        # confirm that restarting from `mis` is intended.
        _update_archive(mis, index, evecs, evals, psi, archive)

    return {"all_solutions": solutions, "archive": archive}


def _get_mis(policy):
    """Take one descent step on each parameter's squared gradient norm.

    For every tensor returned by ``policy.get_weights()`` the gradient of
    ``policy.loss()`` with respect to that tensor is computed, then the
    gradient of its squared norm. The negative of that second gradient is added
    in place to the tensor (step size 1). Finally the updated values are handed
    to ``policy.policy.update_trainable_parameters`` as numpy arrays.

    A tensor whose squared gradient norm does not depend on the tensor itself
    (the loss does not use it, or is linear in it) has a zero update and is
    left unchanged instead of raising.

    Returns
    -------
    None
    """
    loss = policy.loss()

    # Compute every update before touching any weight: the in-place additions
    # below would invalidate the graph needed by the remaining grad calls.
    updates = {}
    for name, weight in policy.get_weights().items():
        grad = torch.autograd.grad(
            loss, weight, allow_unused=True, create_graph=True
        )[0]
        # None: the loss does not depend on `weight`.
        # No grad_fn: the gradient is a constant (loss is linear in `weight`).
        if grad is None or not grad.requires_grad:
            continue
        grad_square = torch.sum(grad ** 2)
        # retain_graph=True because the graph of `loss` is shared by all weights.
        grad_norm_gradient = torch.autograd.grad(
            grad_square, weight, retain_graph=True, allow_unused=True
        )[0]
        if grad_norm_gradient is not None:
            updates[name] = -grad_norm_gradient

    with torch.no_grad():
        new_theta = {}
        for name, weight in policy.get_weights().items():
            if name in updates:
                weight += updates[name]
            new_theta[name] = weight.detach().numpy()
        policy.policy.update_trainable_parameters(new_theta)


# Hyper-parameters consumed by `_ridge_riding`:
#   N               - maximum number of trials of the main loop
#   max_ride_length - maximum number of steps taken within one ride
#   alpha           - step size applied along the followed direction
#   min_val         - eigenvalue threshold passed to `_get_ridges`
ArgsTuple = namedtuple(
    "ArgsTuple", ["N", "max_ride_length", "alpha", "min_val"]
)


def step(policy: "PolicyInterface", args: ArgsTuple):
    """Run one ridge-rider step and load the best solution into the policy.

    The policy is first moved by ``_get_mis``, then ``_ridge_riding`` collects
    candidate solutions. Every candidate is loaded into the policy and scored
    with ``policy.loss()``; the first candidate with the smallest loss is kept.

    Parameters
    ----------
    policy
        Provides ``loss()`` and ``policy.update_trainable_parameters(theta)``
        in addition to what ``_get_mis`` / ``_ridge_riding`` need.
    args
        Hyper-parameters forwarded to ``_ridge_riding``.

    Returns
    -------
    dict with the single key ``"total_loss"``: the loss of the selected
    solution, or the policy's current loss when no solution was found (the
    parameters are then left as they are).
    """
    # minimize 1st Jacobian to find the mis
    _get_mis(policy)
    results = _ridge_riding(policy, args)
    solutions = results["all_solutions"]
    logger.info(f"-------> Finished. Found {len(solutions)} solutions.\n")

    if not solutions:
        # Nothing to choose from (e.g. no ridge passed the eigenvalue filter or
        # args.N is 0): report the current loss instead of failing on a missing
        # best solution.
        logger.info("---- no solution found, keeping the current parameters")
        return {"total_loss": policy.loss()}

    best_index, best_loss = None, None
    for i, (theta, _psi, _evaluation) in enumerate(solutions):
        # assign theta then got loss
        policy.policy.update_trainable_parameters(theta)
        new_loss = policy.loss()
        # Strict comparison: on ties the earliest solution wins.
        if best_index is None or best_loss > new_loss:
            best_index, best_loss = i, new_loss

    # update network with best solution
    logger.info(f"---- #{best_index} solution is the best solution")
    policy.policy.update_trainable_parameters(solutions[best_index][0])
    return {"total_loss": best_loss}
