from pathlib import Path
import os 
from itertools import product as iter_product
from typing import Iterable, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def apply_regression_coefficient(x: str, y: str, log: bool = False):
    """
    Build a function which fits a least-squares line on a DataFrame:
        y = a x + b
    Or, if `log` is True, the same line fitted on the logarithms:
        log(y) = a log(x) + b    <=>    y = exp(b) * x^a

    Parameters
    ----------
    x : name of the column holding the explanatory variable.
    y : name of the column holding the response variable.
    log : if True, both columns go through `np.log` before the fit.
        Defaults to False so that the two-argument call shown under
        "Usage" works.

    Returns
    -------
    A callable taking a DataFrame and returning a Series with entries
    'a' (slope) and 'b' (intercept). When `log` is True, 'b' is the
    intercept in log space: the multiplicative factor of the power law
    is exp(b), not b.

    Usage
    -----
    Useful with `apply`:
        df.groupby(<keys>).apply(apply_regression_coefficient('x', 'y'))
    """
    def f(df: pd.DataFrame):
        x_val, y_val = df[x], df[y]
        if log:
            x_val, y_val = np.log(x_val), np.log(y_val)
        x_mean, y_mean = x_val.mean(), y_val.mean()
        # Centre once and reuse for both the covariance and the variance term.
        x_centered = x_val - x_mean
        # Closed-form simple linear regression: slope = cov(x, y) / var(x).
        a = np.dot(x_centered, y_val - y_mean) / np.sum(x_centered**2)
        # The fitted line goes through the point of means.
        b = y_mean - a * x_mean
        return pd.Series({'a': a, 'b': b})
    return f


def regression_coefficients(df: pd.DataFrame, x: str, y: str, groupby: Union[str, Iterable[str]], log: bool):
    """
    Fit one regression per group and collect the coefficients.

    The result is a DataFrame with:
        - index: the `groupby` key(s) (a MultiIndex when several keys are given)
        - columns: (a, b), as produced by `apply_regression_coefficient`

    The input DataFrame is copied before grouping.
    """
    fit_group = apply_regression_coefficient(x, y, log)
    grouped = df.copy().groupby(groupby)
    return grouped.apply(fit_group)


def linear_trend(df: pd.DataFrame, x: str, y: str, a: float, log: bool) -> pd.Series:
    """
    Given a DataFrame and a fixed slope `a`, returns a Series z s.t:
        z = a x + b
    Or, if `log` is True:
        z = exp(a log(x) + b) = exp(b) * x^a
    with b minimizing:
        sum((y - z)**2)
    (computed on the logarithms of y and z when `log` is True).

    Parameters
    ----------
    df : DataFrame holding the columns `x` and `y`.
    x : name of the column holding the explanatory variable.
    y : name of the column holding the response variable.
    a : slope of the trend line (exponent of the power law when `log`).
    log : if True, the intercept is fitted in log-log space and the
        result is mapped back with `np.exp`.

    Returns
    -------
    A Series of trend values, one per row of `df`.

    Usage
    -----
    To plot a trend line.
    """
    x_val, y_val = df[x], df[y]
    if log:
        x_val, y_val = np.log(x_val), np.log(y_val)
    # For a fixed slope, the least-squares intercept makes the line go
    # through the point of means.
    b = y_val.mean() - a * x_val.mean()
    trend = a * x_val + b
    # In log mode the line lives in log space; map it back to the data scale.
    return np.exp(trend) if log else trend


def rate_risk(r, t, alpha):
    """
    Returns:
        1 - 1 / (1 + 2 * alpha * (min(r, t - 1/2) + 1/2))
    """
    capped = min(r, t - 1/2)
    denominator = 1 + (capped + 1/2) * 2 * alpha
    return 1 - 1 / denominator


def rate_reg(r, t, alpha):
    """
    Returns:
        alpha / (1 + 2 * alpha * (min(r, t - 1/2) + 1/2))
    """
    capped = min(r, t - 1/2)
    denominator = 1 + (capped + 1/2) * 2 * alpha
    return alpha / denominator

def add_ci(x, mean, std, log: bool, edges_kwargs, fill_kwargs, ax: plt.Axes = None):
    """
    Draws a band of one `std` around `mean` on an Axes.

    The band goes from `mean - std` to `mean + std`; if `log` is True,
    both bounds go through `np.exp` before being drawn. The two edges are
    plotted as lines with `edges_kwargs`, and the area between them is
    filled with `fill_kwargs`.

    If `ax` is None, the current Axes (`plt.gca()`) is used.

    Returns the Axes that was drawn on.
    """
    target = plt.gca() if ax is None else ax
    lower, upper = mean - std, mean + std
    if log:
        lower, upper = np.exp(lower), np.exp(upper)
    # Lower edge first, then upper edge, then the fill between them.
    target.plot(x, lower, **edges_kwargs)
    target.plot(x, upper, **edges_kwargs)
    target.fill_between(x, lower, upper, **fill_kwargs)
    return target