from itertools import product as iter_product
import json
import datetime as dt
from typing import Dict, Iterable, List, Union

import numpy as np
from scipy.linalg import solve_triangular
import pandas as pd
from scipy.linalg import solve_triangular

from tqdm import tqdm
from cyanure import BinaryClassifier

from gscfunc.kernels import spline_kernel
from gscfunc.losses import LogisticLoss, Loss
from gscfunc.config import scratch_dir

# Root directory under which each experiment of this script creates its own
# sub-directory (see `check_xp_params`, which writes `config.json` there).
save_dir = scratch_dir / 'spline_logistic'
# Fail at import time when the directory is missing.
# NOTE(review): `assert` statements are stripped under `python -O`, in which
# case a missing directory would only surface later, when writing results.
assert save_dir.is_dir()


def proximal_newton(reg: float, loss: "Loss", prox_point: np.ndarray, initial_guess: np.ndarray, precision: float, T: np.ndarray, y: np.ndarray, n_iter_max: int = 20):
    """
    Returns `x` s.t `x` approximately minimizes:
    1/n sum_i loss(y_i, [T' x]_i) + reg/2 || x - prox_point ||^2
    using undamped Newton steps (no line search).

    Parameters
    ----------
    reg : float
        Weight of the proximal (squared distance) term.
    loss : Loss
        Object exposing `all_derivatives(y, y_pred)` returning, per sample,
        the loss value and its first and second derivatives w.r.t. `y_pred`.
    prox_point : np.ndarray
        Center of the proximal term, of length `T.shape[0]`.
    initial_guess : np.ndarray
        Starting coefficient (copied, never modified in place).
    precision : float
        Iterations stop once the decrement `grad' H^{-1} grad` is at most this value.
    T : np.ndarray
        2-D matrix of shape (dimension, n); predictions are `T.T @ x`.
    y : np.ndarray
        The `n` targets forwarded to `loss`.
    n_iter_max : int, optional
        Maximum number of Newton steps, by default 20.

    Returns
    -------
    dict
        'iter', 'f', 'norm_df' and 'decrement' are lists with one entry per
        Newton step, all evaluated at the iterate *before* the step is taken
        ('f' is the objective above, 'norm_df' the squared Euclidean norm of
        the gradient divided by n). 'coeff' is the final coefficient.
    """
    # `n` of the objective is the number of samples, i.e. the length of T' x.
    # For a square T (the Cholesky factor used in this file) both sizes agree.
    dim, n_samples = T.shape
    identity = np.eye(dim)

    cur_precision = + np.inf
    cur_iter = 0
    cur_coeff = initial_guess.copy()

    results = {'iter': [],
               'f': [],
               'norm_df': [],
               'decrement': [],
               'coeff': None,
               }

    while cur_precision > precision and cur_iter < n_iter_max:
        y_pred = T.T @ cur_coeff
        f, df, ddf = loss.all_derivatives(y, y_pred)
        shift = cur_coeff - prox_point
        grad = T @ df / n_samples + reg * shift
        hess = T @ (ddf[:, np.newaxis] * T.T) / n_samples + reg * identity
        newton_direction = np.linalg.solve(hess, grad)
        cur_precision = np.dot(grad, newton_direction)

        # Record the diagnostics at the current iterate, before moving, so that
        # the objective, gradient and decrement all refer to the same point.
        results['iter'].append(cur_iter)
        results['f'].append(np.sum(f) / n_samples + reg / 2 * np.sum(shift ** 2))
        results['norm_df'].append(np.sum(grad ** 2) / n_samples)
        results['decrement'].append(cur_precision)

        cur_coeff = cur_coeff - newton_direction
        cur_iter += 1
    results['coeff'] = cur_coeff

    # Only warn when the budget ran out *without* reaching the precision.
    if cur_iter >= n_iter_max and cur_precision > precision:
        print("Warn: max number if iterations reached.")
    return results


def _config_to_jsonable(value):
    """Converts numpy arrays/scalars to plain Python objects for `json.dump`."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def check_xp_params(ns, regs, seeds, rs, t, samples_mc, tol, early_stopping, cyan, save_to_dir, xp_name, alpha):
    """
    Checks the experiments params and save them.
    Could be put in a class to avoid repetition, but this might be simpler this way.

    The raw (un-normalized) parameters are written to `<xp_dir>/config.json`
    when `save_to_dir` is true; this happens before the validation below.

    Returns
    -------
    tuple
        `(ns, regs, seeds, alpha, samples_mc, xp_dir)` where `ns` is sorted,
        de-duplicated and cast to int, `regs` is sorted in decreasing order,
        `seeds` is de-duplicated and cast to int, `alpha` is cast to int,
        `samples_mc` is a regular grid of `int(samples_mc)` points on [0, 1]
        and `xp_dir` is the time-stamped directory of the experiment.

    Raises
    ------
    AssertionError
        If `alpha` is odd, or if `(r + 1/2) * alpha + 1/2` is not an even
        number for some `r` in `rs`.
    """
    xp_dir = save_dir / (dt.datetime.now().strftime("%m-%d_%Hh%Mm%Ss") + "_" + xp_name)
    # Save the params
    if save_to_dir:
        xp_dir.mkdir(exist_ok=True)
        config_to_save = {
            'ns': ns,
            'regs': regs,
            'seeds': seeds,
            'rs': rs,
            't': t,
            'samples_mc': samples_mc,
            'tol': tol,
            'early_stopping': early_stopping,
            'save_to_dir': save_to_dir,
            'cyan': cyan,
            'xp_name': xp_name,
            'alpha': alpha,
        }
        # Overwrite rather than append: appending a second JSON document to an
        # existing file would make it unparsable. The `default` hook converts
        # numpy arrays and scalars wherever they appear (also inside lists).
        with open(xp_dir / 'config.json', 'w') as fp:
            json.dump(config_to_save, fp, default=_config_to_jsonable)

    # Put low n/high lambda first
    ns = np.sort(np.unique(ns).astype('int'))
    regs = np.flip(np.sort(regs))
    seeds = np.unique(seeds).astype('int')
    alpha = int(alpha)
    samples_mc = np.linspace(0, 1, int(samples_mc))

    # Check that we will be able to compute the spline kernel
    assert alpha % 2 == 0, f"alpha must be even, got {alpha}"
    # The degree of the target only depends on r (not on the number of steps).
    for r in rs:
        degree_target = (r+1/2)*alpha + 1/2
        assert degree_target % 2 == 0, f"target degree must be even, got {degree_target} for r={r}"

    # Returns the params
    return ns, regs, seeds, alpha, samples_mc, xp_dir


def it_estimator(T: np.ndarray, y: np.ndarray, risk_target: float, regs: np.ndarray, t: int, kernel_mc: np.ndarray, target_mc: np.ndarray, tol: float, cyan: bool, early_stopping: bool, cyan_model: "BinaryClassifier" = None) -> List[Dict[str, Union[np.ndarray, float, int]]]:
    """
    Computes the iterated Tikhonov (IT) estimator with matrix T, for each
    regularization in `regs` and each number of proximal steps in 1..t.

    Parameters
    ----------
    T : np.ndarray
        Upper triangular matrix; predictions on the training points are `T.T @ coeff`.
    y : np.ndarray
        Training labels; they are compared with `np.sign` of the predictions.
    risk_target : float
        Risk of the target function, subtracted from the risk to get 'excess_risk'.
    regs : np.ndarray
        Regularizations, processed in the given order.
    t : int
        Maximum number of proximal steps per regularization.
    kernel_mc : np.ndarray
        Kernel matrix between the evaluation points and the training points.
    target_mc : np.ndarray
        Target function evaluated on the evaluation points.
    tol : float
        Tolerance forwarded to the solver.
    cyan : bool
        Use `cyan_model` (Cyanure) if true, `proximal_newton` otherwise.
    early_stopping : bool
        If true, as soon as step number k+1 yields a higher risk than the best
        risk seen so far for that step number, the following regularizations
        are only run for their first k steps.
    cyan_model : BinaryClassifier, optional
        The solver used when `cyan` is true.

    Returns
    -------
    list of dict
        One dictionary per computed (reg, step) with keys 'reg', 't' (1-based
        step number), 'loss_train', 'loss_l2', 'loss_h', 'risk_mc',
        'excess_risk' and 'n_steps'.

    Raises
    ------
    ValueError
        If `cyan` is true but no `cyan_model` is given.
    """
    if cyan and cyan_model is None:
        raise ValueError("`cyan_model` must be provided when `cyan` is True.")

    n = T.shape[0]
    dtype = T.dtype
    results = []
    # Number of steps still worth computing; only ever shrinks (early stopping).
    max_steps = t

    # Starting point for Tikhonov (only used by Newton method)
    warm_start_t0 = np.zeros(n, dtype=dtype)
    prox_point = np.empty_like(warm_start_t0)  # Proximal point at each step
    if early_stopping:  # Best risk achieved by each step number so far
        best_risk_mc = np.empty(t, dtype=dtype)
        best_risk_mc.fill(+np.inf)

    for reg in regs:
        # Values for Tikhonov estimator (first step): proximal point at the origin
        prox_point.fill(0.)
        warm_start = warm_start_t0.copy()

        for k in range(max_steps):
            if cyan:
                solver_output = cyan_model.fit(T.T, y, w0=prox_point, lambd=reg, solver='qning-ista',
                                               verbose=False, restart=False, it0=5, tol=tol, max_epochs=1000)
                n_steps = solver_output[0, -1]
                coeff = cyan_model.get_weights()
            else:
                solver_output = proximal_newton(reg=reg, loss=LogisticLoss(
                ), prox_point=prox_point, initial_guess=warm_start, precision=tol, T=T, y=y, n_iter_max=20)
                # Index of the last Newton iteration (i.e. number of iterations - 1)
                n_steps = solver_output['iter'][-1]
                coeff = solver_output['coeff']

            pred_train = T.T @ coeff
            loss_train = np.mean(y != np.sign(pred_train))

            # coeff = T alpha, hence alpha = T^{-1} coeff and predictions are K_mc alpha
            mc_pred = kernel_mc @ solve_triangular(T, coeff, lower=False)
            mc_l2loss = np.mean((mc_pred - target_mc)**2)/2
            # Logistic loss averaged over labels drawn with P(y=1) = sigmoid(target_mc)
            mc_risk = np.mean(np.log(1 + np.exp(-mc_pred))/(1 + np.exp(-target_mc)) +
                              np.log(1 + np.exp(+mc_pred))/(1 + np.exp(+target_mc)))
            mc_hloss = np.mean(np.exp(target_mc)/(1 + np.exp(target_mc))**2 * (mc_pred - target_mc) ** 2)

            results.append({
                'reg': reg,
                't': k+1,
                'loss_train': loss_train,
                'loss_l2': mc_l2loss,
                'loss_h': mc_hloss,
                'risk_mc': mc_risk,
                'excess_risk': mc_risk - risk_target,
                'n_steps': n_steps,
            })

            # Save this coefficient for the next step
            warm_start = coeff.copy()
            prox_point = coeff.copy()

            if k == 0:
                # Save the Tikhonov solution to warm start the next regularization
                warm_start_t0 = coeff.copy()

            if early_stopping:
                if mc_risk <= best_risk_mc[k]:
                    best_risk_mc[k] = mc_risk
                else:
                    # We stop computing lower regularizations for step numbers > k:
                    # it would only result in overfitting, hence higher risk.
                    # `min` keeps a later, larger k from undoing the truncation.
                    max_steps = min(max_steps, k)
    return results


def learning_rate(
    ns: Iterable[int],
    regs: Iterable[float],
    seeds: Iterable[int],
    rs: Iterable[int],
    t: int,
    samples_mc: int,
    tol: float,
    early_stopping: bool,
    cyan: bool = True,
    save_to_dir: bool = True,
    xp_name: str = "",
    alpha: float = 2.,
) -> pd.DataFrame:
    """Runs the IT(reg, t) estimator over a grid of (seed, n, r) and collects its losses.

    For every (seed, n) a sample `x` is drawn and the spline kernel is
    factorized; for every `r` binary labels are drawn and `it_estimator`
    is run on all regularizations and numbers of steps.

    Parameters
    ----------
    ns : Iterable[int]
        The number of samples to draw.
    regs : Iterable[float]
        The regularization to test.
    seeds : Iterable[int]
        The seed for the random number generator to ensure reproducibility.
    rs : Iterable[int]
        The source condition for the target function. Values need not be
        integers (the script entry point passes `k + 1/4`); only `int(r)`
        enters the seed of the labels.
    t : int
        The maximum number of proximal steps to do.
    samples_mc : int
        The number of evaluation points used to estimate the losses
        (`check_xp_params` turns it into a regular grid on [0, 1]).
    tol : float
        The tolerance on the error of the solver
    early_stopping : bool
        Forwarded to `it_estimator`; if true, steps that started to increase
        the risk are skipped for the following regularizations.
    cyan : bool, optional
        Whether to use Cyanure or a Newton method to compute IT, by default True
    save_to_dir : bool, optional
        Whether to save the results on disk, by default True
    xp_name : str, optional
        A name to describe the directory of the experiment, by default ""
    alpha : float, optional
        The capacity condition, by default 2.

    Returns
    -------
    pd.DataFrame
        DataFrame of results: one row per computed (seed, n, r, reg, t), with
        the columns produced by `it_estimator` plus 'r', 'n' and 'seed'.
        `None` is returned when the parameter grid is empty.
    """
    # Simple preprocessing on the params
    ns, regs, seeds, alpha, samples_mc, xp_dir = check_xp_params(
        ns, regs, seeds, rs, t, samples_mc, tol, early_stopping, cyan, save_to_dir, xp_name, alpha)

    pb_x = tqdm(iter_product(seeds, ns), total=len(seeds)*len(ns))
    results = None
    for seed, n in pb_x:
        # x data
        # NOTE(review): distinct (seed, n) pairs with the same sum share the same
        # generator state (e.g. (1, 100) and (0, 101)) -- confirm this is acceptable.
        rng_x = np.random.default_rng(seed+n)  # Taking seed+n to avoid same data across n
        x = rng_x.random(size=n)
        K = spline_kernel(x[:, np.newaxis], x[np.newaxis, :], alpha)
        T = np.linalg.cholesky(K).T  # K = T' T, T is upper triangular
        kernel_mc = spline_kernel(samples_mc[:, np.newaxis], x[np.newaxis, :], alpha)

        # The same solver object is reused for every r of this (seed, n)
        cyan_model = BinaryClassifier('logistic', penalty='l2-prox',
                                      fit_intercept=False) if cyan else None
        for r in rs:
            # y data
            # Take a different seed to be able to reproduce the results with specific r
            # NOTE(review): two values of r with the same integer part get the same seed.
            rng_y = np.random.default_rng(int(r)*1000+seed)

            # We evaluate the optimum on x, on the mc samples and we set the labels.
            degree_target = int((r+1/2)*alpha + 1/2)
            target_x = spline_kernel(x, 0, degree_target)
            target_mc = spline_kernel(samples_mc, 0, degree_target)

            # Labels in {-1, +1}, with P(y = +1 | x) = sigmoid(target(x))
            y_label_x = rng_y.random(size=n) < (1/(1 + np.exp(-target_x)))
            y_label_x = y_label_x.astype('float32')
            y_label_x[y_label_x == 0] = -1

            # Alternative (disabled): estimate the risk of the target from sampled labels.
            # y_label_mc = rng_y.random(size=samples_mc.shape[0]) < (1/(1 + np.exp(-target_mc)))
            # y_label_mc = y_label_mc.astype('float32')
            # y_label_mc[y_label_mc == 0] = -1

            # risk_target = np.mean(np.log(1+np.exp(-y_label_mc * target_mc)))
            # Risk of the target: logistic loss averaged over both labels, each
            # weighted by its probability under the model above.
            risk_target = np.mean(np.log(1 + np.exp(-target_mc))/(1 + np.exp(-target_mc)) +
                                  np.log(1 + np.exp(+target_mc))/(1 + np.exp(+target_mc)))

            pb_x.set_postfix(seed=seed, n=n, r=r)

            it_grid_results = it_estimator(T=T, y=y_label_x, risk_target=risk_target, regs=regs, t=t,
                                           kernel_mc=kernel_mc, target_mc=target_mc, tol=tol, cyan=cyan, early_stopping=early_stopping, cyan_model=cyan_model)

            # Updates the dataframe with x values
            it_grid_results = pd.DataFrame(it_grid_results)
            it_grid_results[['r', 'n', 'seed']] = r, n, seed

            results = it_grid_results if results is None else pd.concat([results, it_grid_results])

            # Rewrite the cumulative results after every (seed, n, r) so that an
            # interrupted run still leaves the rows computed so far on disk.
            if save_to_dir:
                results.to_csv(xp_dir / 'results.csv', index=False)
    return results


if __name__ == "__main__":
    # Full-scale experiment: 20 sample sizes between 1e2 and 1e4 (excluded),
    # 50 regularizations between 1e-4 and 1, 100 seeds, solved with Cyanure.
    experiment = dict(
        ns=np.logspace(2, 4, 20, endpoint=False),
        regs=np.logspace(-4, 0, 50),
        seeds=np.arange(100),
        rs=[1/4 + k for k in (0, 3, 10)],
        t=8,
        samples_mc=int(1e4),
        tol=1e-10,
        early_stopping=True,
        cyan=True,
        save_to_dir=True,
        xp_name='cyan_n4',
        alpha=2.,
    )
    df = learning_rate(**experiment)

    # Small configuration for a quick check (nothing is written to disk):
    # df = learning_rate(
    #     ns=np.logspace(1, 2, 5),
    #     regs=np.logspace(-4, 0, 20),
    #     seeds=np.arange(10),
    #     rs=[1/4 + k for k in range(3)],
    #     t=3,
    #     samples_mc=int(1e4),
    #     tol=1e-12,
    #     cyan=True,
    #     early_stopping=True,
    #     save_to_dir=False,
    #     xp_name='test_cyan',
    #     alpha=2.,
    # )
