from itertools import product as iter_product
from pathlib import Path
import json
import datetime as dt
from typing import Iterable

import numpy as np
import pandas as pd

from tqdm import tqdm

from gscfunc.kernels import spline_kernel
from gscfunc.config import scratch_dir
from gscfunc.utils import *

# Root directory under which every experiment run of this script creates its
# own timestamped sub-directory.
save_dir = scratch_dir / 'spline_ls'
# Fail at import time if the directory is missing: this module never creates
# it. Note that `assert` statements are skipped when Python runs with -O.
assert save_dir.is_dir()


def spectral_it(diag, reg, t):
    """Spectral filter of iterated Tikhonov regularization.

    Evaluates ``diag**(-1) * (1 - (reg / (diag + reg))**t)`` using the normal
    arithmetic (and, for arrays, broadcasting) rules of the operands.

    Parameters
    ----------
    diag
        Value(s) of the spectrum at which the filter is evaluated.
        The expression is not defined for ``diag == 0`` (division by zero).
    reg
        Regularization parameter.
    t
        Exponent applied to the shrinkage factor, i.e. the number of
        iterations.

    Returns
    -------
    The filter evaluated at ``diag``.
    """
    # Factor by which one regularized step shrinks the residual.
    shrinkage = reg / (diag + reg)
    residual_after_t_steps = shrinkage**t
    return diag**(-1) * (1 - residual_after_t_steps)


def check_xp_params(ns, regs, seeds, rs, t, samples_mc, early_stopping, sigma, save_to_dir, xp_name, alpha):
    """
    Checks the experiment params, saves them, and returns their normalized form.
    Could be put in a class to avoid repetition, but this might be simpler this way.

    Parameters
    ----------
    ns, regs, seeds, rs, t, samples_mc, early_stopping, sigma, xp_name, alpha
        The raw experiment parameters. They are written as given (numpy arrays
        converted to lists) to ``config.json`` when ``save_to_dir`` is true.
    save_to_dir : bool
        Whether to create the experiment directory and write ``config.json``.

    Returns
    -------
    tuple
        ``(ns, regs, seeds, alpha, samples_mc, xp_dir)`` where

        * ``ns`` is a sorted integer array without duplicates,
        * ``regs`` is sorted in decreasing order,
        * ``seeds`` is a sorted integer array without duplicates,
        * ``alpha`` is converted with ``int``,
        * ``samples_mc`` is a regular grid of ``int(samples_mc)`` points on [0, 1],
        * ``xp_dir`` is the (timestamped) directory of the experiment.

    Raises
    ------
    AssertionError
        If ``alpha`` or one of the target degrees ``(r+1/2)*alpha + 1/2`` is
        not even.
    """
    xp_dir = save_dir / (dt.datetime.now().strftime("%m-%d_%Hh%Mm%Ss") + "_" + xp_name)
    # Save the params
    if save_to_dir:
        xp_dir.mkdir(exist_ok=True)
        config_to_save = {
            'ns': ns,
            'regs': regs,
            'seeds': seeds,
            'rs': rs,
            't': t,
            'samples_mc': samples_mc,
            'early_stopping': early_stopping,
            'save_to_dir': save_to_dir,
            'sigma': sigma,
            'xp_name': xp_name,
            'alpha': alpha,
        }
        # Put numpy array to list if any
        config_to_save = {k: (v.tolist() if isinstance(v, np.ndarray) else v)
                          for k, v in config_to_save.items()}
        # 'w' rather than 'a': appending a second JSON document to an existing
        # file would make it unparsable.
        with open(xp_dir / 'config.json', 'w') as fp:
            json.dump(config_to_save, fp)

    # Put low n/high lambda first.
    # Cast to int *before* deduplicating: truncation can map distinct floats to
    # the same integer. np.unique already returns a sorted array.
    ns = np.unique(np.asarray(ns).astype('int'))
    regs = np.flip(np.sort(regs))
    seeds = np.unique(np.asarray(seeds).astype('int'))
    # NOTE(review): int() silently truncates a non-integer alpha.
    alpha = int(alpha)
    samples_mc = np.linspace(0, 1, int(samples_mc))

    # Check that we will be able to compute the spline kernel
    assert alpha % 2 == 0, f"alpha must be even, got {alpha}"
    # The target degree only depends on r (not on the number of steps t).
    for r in rs:
        degree_target = (r+1/2)*alpha + 1/2
        assert degree_target % 2 == 0, (
            f"(r+1/2)*alpha + 1/2 must be even, got {degree_target} for r={r}")

    # Returns the params
    return ns, regs, seeds, alpha, samples_mc, xp_dir


def it_estimator(K: np.ndarray, D: np.ndarray, U: np.ndarray, y: np.ndarray, y_rotated: np.ndarray, regs: np.ndarray, t: int, kernel_mc: np.ndarray, target_mc: np.ndarray, early_stopping: bool):
    """Evaluate iterated Tikhonov estimators on a grid of (reg, number of steps).

    For every ``reg`` in ``regs`` and every number of steps ``k+1`` with
    ``k < t``, the coefficients ``(U / n) @ (spectral_it(D / n, reg, k+1) * y_rotated)``
    are computed (``n = K.shape[0]``) and used to predict on the training
    points (``K @ coeff``) and on the Monte-Carlo points (``kernel_mc @ coeff``).

    Parameters
    ----------
    K : np.ndarray
        Kernel matrix on the training points.
    D, U : np.ndarray
        Spectrum and basis passed to the spectral filter; the caller in this
        file passes the eigen-decomposition of ``K``.
    y : np.ndarray
        Training labels.
    y_rotated : np.ndarray
        Labels expressed in the basis ``U`` (the caller in this file passes
        ``U.T @ y``).
    regs : np.ndarray
        Regularization values, visited in the given order. Early stopping
        assumes they are sorted in decreasing order.
    t : int
        Maximum number of iterated-Tikhonov steps.
    kernel_mc : np.ndarray
        Kernel between the Monte-Carlo points and the training points.
    target_mc : np.ndarray
        Target function evaluated at the Monte-Carlo points.
    early_stopping : bool
        If true, once the Monte-Carlo loss of IT(k+1) stops improving for some
        ``reg``, IT(l) with l >= k+1 is no longer evaluated for the following
        values of ``regs``.

    Returns
    -------
    list of dict
        One dict per evaluated pair, with keys ``'reg'``, ``'t'``,
        ``'loss_train'`` and ``'loss_l2'`` (half mean squared errors).
    """
    n = K.shape[0]
    dtype = K.dtype
    results = []

    # Loop invariants. `1/n * U @ v` parses as `((1/n) * U) @ v`, so the
    # original rescaled the whole matrix for every (reg, k); do it once.
    spectrum = D/n
    scaled_U = 1/n * U

    if early_stopping:  # Best loss achieved by IT(k)
        best_loss_mc = np.full(t, +np.inf, dtype=dtype)

    max_steps = t
    for reg in regs:
        # range() is evaluated once per reg: a cap lowered below only takes
        # effect for the following regularization values.
        for k in range(max_steps):
            coeff = scaled_U @ (spectral_it(spectrum, reg, k+1) * y_rotated)
            pred_train = K @ coeff
            loss_train = np.mean((pred_train - y)**2)/2

            pred_mc = kernel_mc @ coeff
            loss_mc = np.mean((pred_mc - target_mc)**2)/2

            results.append({
                'reg': reg,
                't': k+1,
                'loss_train': loss_train,
                'loss_l2': loss_mc,
            })

            if early_stopping:
                if loss_mc <= best_loss_mc[k]:
                    best_loss_mc[k] = loss_mc
                else:
                    # We stop computing lower regularization for IT(l>=k):
                    # It will only result in overfitting, hence higher l2_loss.
                    # Keep the smallest failing k: a plain `t = k` would be
                    # overwritten by a later (larger) failing k of this pass.
                    max_steps = min(max_steps, k)
    return results


def learning_rate(
    ns: Iterable[int],
    regs: Iterable[float],
    seeds: Iterable[int],
    rs: Iterable[float],
    t: int,
    samples_mc: int,
    early_stopping: bool,
    sigma: float = 1.,
    save_to_dir: bool = True,
    xp_name: str = "",
    alpha: float = 2.,
) -> pd.DataFrame:
    """Computes the l2 loss for IT(reg, t) with X,Y ~ n, seed, r, alpha.

    Parameters
    ----------
    ns : Iterable[int]
        The number of samples to draw.
    regs : Iterable[float]
        The regularization to test.
    seeds : Iterable[int]
        The seed for the random number generator to ensure reproducibility.
    rs : Iterable[float]
        The source condition for the target function. The target degree
        ``(r+1/2)*alpha + 1/2`` must be even.
    t : int
        The maximum number of proximal steps to do.
    samples_mc : int
        The number of samples for the MC estimation of the l2 loss.
    early_stopping : bool
        Forwarded to ``it_estimator``: whether to stop evaluating IT(k) for
        lower regularization once its MC loss stops improving.
    sigma : float, optional
        Standard deviation of the Gaussian noise added to the labels,
        by default 1.
    save_to_dir : bool, optional
        Whether to save the results on disk, by default True
    xp_name : str, optional
        A name to describe the directory of the experiment, by default ""
    alpha : float, optional
        The capacity condition, by default 2.

    Returns
    -------
    pd.DataFrame
        DataFrame of results; each column is a parameter or a loss, with one
        row per evaluated (seed, n, r, reg, t). ``None`` if there was nothing
        to evaluate (e.g. empty ``seeds``, ``ns`` or ``rs``).
    """
    # Simple preprocessing on the params
    ns, regs, seeds, alpha, samples_mc, xp_dir = check_xp_params(
        ns, regs, seeds, rs, t, samples_mc, early_stopping, sigma, save_to_dir, xp_name, alpha)

    results_path = xp_dir / 'results.csv'
    # One DataFrame per (seed, n, r); concatenated once at the end instead of
    # growing a DataFrame inside the loop (which copies everything each time).
    frames = []

    pb_x = tqdm(iter_product(seeds, ns), total=len(seeds)*len(ns))
    for seed, n in pb_x:
        # x data
        # Taking seed+n to avoid same data across n.
        # NOTE(review): seed+n is not unique per (seed, n) pair, e.g.
        # (seed=1, n=100) and (seed=0, n=101) seed the same generator.
        # Left unchanged to keep existing runs reproducible.
        rng_x = np.random.default_rng(seed+n)
        x = rng_x.random(size=n)
        K = spline_kernel(x[:, np.newaxis], x[np.newaxis, :], alpha)
        D, U = np.linalg.eigh(K)
        # eigh returns eigenvalues in ascending order: put the largest first.
        D, U = np.flip(D), np.fliplr(U)
        kernel_mc = spline_kernel(samples_mc[:, np.newaxis], x[np.newaxis, :], alpha)

        for r in rs:
            # y data
            # Take a different seed to be able to reproduce the results with specific r
            rng_y = np.random.default_rng(int(r)*1000+seed)

            # We evaluate the optimum on x and the mc samples.
            degree_target = int((r+1/2)*alpha + 1/2)
            target_x = spline_kernel(x, 0, degree_target)
            target_mc = spline_kernel(samples_mc, 0, degree_target)

            y_label = target_x + rng_y.normal(loc=0, scale=sigma, size=n)
            y_label_rotated = U.T @ y_label

            pb_x.set_postfix(seed=seed, n=n, r=r)

            it_grid_results = it_estimator(K=K, D=D, U=U, y=y_label, y_rotated=y_label_rotated, regs=regs, t=t, kernel_mc=kernel_mc, target_mc=target_mc, early_stopping=early_stopping)

            # Updates the dataframe with x values
            it_grid_results = pd.DataFrame(it_grid_results)
            it_grid_results[['r', 'n', 'seed']] = r, n, seed

            if save_to_dir:
                # Checkpoint after every (seed, n, r): append only the new rows
                # rather than rewriting the whole file each time.
                is_first_chunk = not frames
                it_grid_results.to_csv(
                    results_path,
                    mode='w' if is_first_chunk else 'a',
                    header=is_first_chunk,
                    index=False,
                )
            frames.append(it_grid_results)

    return pd.concat(frames) if frames else None


if __name__ == "__main__":
    # Full-scale experiment: 20 sample sizes between 1e2 and 1e4 (excluded),
    # 50 regularization values, 100 seeds, up to 8 iterated-Tikhonov steps.
    run_config = dict(
        ns=np.logspace(2, 4, 20, endpoint=False),
        regs=np.logspace(-4, 0, 50),
        seeds=np.arange(100),
        rs=[1/4 + k for k in [0, 3, 10]],
        t=8,
        samples_mc=int(1e3),
        sigma=1.,
        early_stopping=False,
        save_to_dir=True,
        xp_name='n4',
        alpha=2.,
    )
    df = learning_rate(**run_config)

    # Small configuration for a quick run without writing to disk
    # (swap it with `run_config` above to use it):
    # run_config = dict(
    #     ns=np.logspace(1, 2, 5),
    #     regs=np.logspace(-4, 0, 20),
    #     seeds=np.arange(10),
    #     rs=[1/4 + k for k in range(3)],
    #     t=3,
    #     samples_mc=int(1e3),
    #     sigma=1.,
    #     early_stopping=True,
    #     save_to_dir=False,
    #     xp_name='test',
    #     alpha=2.,
    # )