import numpy as np
import pandas as pd
from scipy.stats import t
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneOut
from sklearn.metrics import r2_score
import statsmodels.api as sm


def r2_loocv(X, y):
    """Return the leave-one-out cross-validated R^2 of a linear fit of y on X.

    Every row is predicted by a ``LinearRegression`` trained on all the
    other rows. The score is ``1 - PRESS / TSS``, where PRESS is the sum of
    squared held-out prediction errors and TSS is the sum of squared
    deviations of y from its mean.

    X: predictors, indexed positionally via ``.iloc`` (presumably a
       pandas DataFrame).
    y: response, indexed positionally via ``.iloc`` and row-aligned with X.
    """
    actual = []
    predicted = []
    for fit_rows, held_out in LeaveOneOut().split(X):
        regressor = LinearRegression().fit(X.iloc[fit_rows], y.iloc[fit_rows])
        predicted.append(regressor.predict(X.iloc[held_out]))
        actual.append(y.iloc[held_out])

    # Both lists hold one length-1 item per row -> arrays of shape (n, 1).
    actual = np.array(actual)
    predicted = np.array(predicted)
    press = np.sum((actual - predicted) ** 2)
    tss = np.sum((actual - np.mean(actual)) ** 2)
    return 1 - press / tss


def forward_selection(df, y_label, x_label_list, x_label_choices,
                      alpha=0.05, min_improve=0.02):
    """Pick the next predictor to add in a forward stepwise regression.

    Seed step (``x_label_list`` empty): choose the candidate whose absolute
    Pearson correlation with ``y_label`` is largest. Ranking by |r| is the
    same as ranking by single-variable R^2, which matches the R^2-based
    criterion of the later steps. The seed is accepted unconditionally and
    reported with a sentinel p-value of 0.0.

    Later steps: for each candidate not yet in ``x_label_list``, fit OLS of
    ``y_label`` on the current set plus the candidate (with a constant). A
    candidate qualifies when its two-sided t-test p-value is ``<= alpha``
    and its LOOCV R^2 improves on the current set's by ``>= min_improve``.
    The qualifying candidate with the largest improvement wins; ties go to
    the earliest one in ``x_label_choices``.

    Args:
        df: DataFrame holding the response and candidate columns.
        y_label: Name of the response column.
        x_label_list: Already-selected predictor columns (a list).
        x_label_choices: Candidate predictor columns.
        alpha: Significance threshold for the candidate's coefficient.
        min_improve: Minimum LOOCV R^2 gain required to add a candidate.

    Returns:
        ``(label, pvalue, r2_loocv)`` for the chosen predictor, where
        ``r2_loocv`` is the LOOCV R^2 of the model including it, or
        ``(None, 1., 0.0)`` when no candidate qualifies.
    """
    y = df[y_label]

    if not x_label_list:
        best_label, best_corr = None, -np.inf
        for label in x_label_choices:
            # abs(): a strong negative association is as predictive as a
            # strong positive one. A NaN correlation (e.g. a constant
            # column) compares False and so is never selected.
            corr = abs(y.corr(df[label]))
            if corr > best_corr:
                best_label, best_corr = label, corr
        if best_label is None:
            return None, 1., 0.0
        return best_label, 0.0, r2_loocv(df[[best_label]], y)

    r2_loocv_base = r2_loocv(df[x_label_list], y)
    best = (None, 1., 0.0)
    best_improve = -np.inf
    for label in x_label_choices:
        if label in x_label_list:
            continue
        columns = x_label_list + [label]
        result = sm.OLS(y, sm.add_constant(df[columns])).fit()
        # t.sf avoids the cancellation in 1 - t.cdf, which rounds to 0.0
        # for large |t|.
        pvalue = 2 * t.sf(abs(result.tvalues[label]), result.df_resid)
        if not pvalue <= alpha:  # also rejects a NaN p-value
            continue
        r2_loocv_new = r2_loocv(df[columns], y)
        improve = r2_loocv_new - r2_loocv_base
        # Strict '>' keeps the first candidate on ties.
        if improve >= min_improve and improve > best_improve:
            best, best_improve = (label, pvalue, r2_loocv_new), improve
    return best


def loocv_analysis(df, label, choice_list):
    """Repeatedly call ``forward_selection`` until a step's p-value exceeds 0.05.

    Returns three parallel lists, one entry per accepted step: the chosen
    column names, their p-values, and the LOOCV R^2 after each addition.
    The returned name list is also the list passed to ``forward_selection``
    as the already-selected set.
    """
    selected, pvals, scores = [], [], []
    next_label, next_p, next_r2 = forward_selection(df, label, selected, choice_list)
    # Written as ``not ... > 0.05`` to mirror the original stop test exactly.
    while not next_p > 0.05:
        selected.append(next_label)
        pvals.append(next_p)
        scores.append(next_r2)
        next_label, next_p, next_r2 = forward_selection(df, label, selected, choice_list)
    return selected, pvals, scores


def control_analysis(df, label, chosen, control):
    """Estimate the effect of ``chosen`` on ``label`` while controlling for ``control``.

    Fits OLS ``label ~ const + chosen + control``. A constant is added
    unless one of the columns is already constant (``sm.add_constant``
    default).

    Returns:
        ``(pvalue, param, control_param)``: the two-sided t-test p-value of
        the ``chosen`` coefficient, the ``chosen`` coefficient, and the
        ``control`` coefficient.
    """
    x = sm.add_constant(df[[chosen, control]])
    result = sm.OLS(df[label], x).fit()

    # t.sf instead of 1 - t.cdf: the subtraction cancels to 0.0 for large |t|.
    pvalue = 2 * t.sf(abs(result.tvalues[chosen]), result.df_resid)
    return pvalue, result.params[chosen], result.params[control]


def analyze(path="junior_golf.csv"):
    """Run the LOOCV forward-selection and control analyses and print the results.

    Args:
        path: CSV file to load (defaults to ``"junior_golf.csv"``).
    """
    df = pd.read_csv(path)

    # State,Population,Courses,MHI,Solar,PGA,LPGA,
    # BoyParticipants,BoyTop50,BoyTop50PP,BoyTop100,BoyTop100PP,BoyTop200,BoyTop200PP,
    # GirlParticipants,GirlTop50,GirlTop50PP,GirlTop100,GirlTop100PP,GirlTop200,GirlTop200PP

    # LOOCV Analysis
    # For each experiment, variables are selected automatically until no
    # further variable qualifies. Prints the chosen variables, their
    # p-values, and the LOOCV R^2 after each addition.
    base_choices = ['Population', 'Courses', 'MHI', 'Solar', 'PGA', 'LPGA']
    # Experiments for Top N players
    experiments = [('BoyTop50', ['BoyParticipants', 'BoyTop50PP']),
                   ('BoyTop100', ['BoyParticipants', 'BoyTop100PP']),
                   ('BoyTop200', ['BoyParticipants', 'BoyTop200PP']),
                   ('GirlTop50', ['GirlParticipants', 'GirlTop50PP']),
                   ('GirlTop100', ['GirlParticipants', 'GirlTop100PP']),
                   ('GirlTop200', ['GirlParticipants', 'GirlTop200PP'])]

    # # Experiments for participants
    # experiments = [('BoyParticipants', []),
    #                ('GirlParticipants', [])]

    for label, additional_choices in experiments:
        chosen_labels, p_value_list, r2_loocv_list = loocv_analysis(df, label, base_choices + additional_choices)
        print("----->>>>>", label, chosen_labels, p_value_list, r2_loocv_list)

    # Control Analysis
    # For each experiment, compute \beta_X and its p-value for X while
    # controlling for Participants. Prints the p-value, \beta_X, and
    # \beta_{Participants} for each experiment.
    control_experiments = [('BoyTop50', 'BoyParticipants'),
                           ('BoyTop100', 'BoyParticipants'),
                           ('BoyTop200', 'BoyParticipants'),
                           ('GirlTop50', 'GirlParticipants'),
                           ('GirlTop100', 'GirlParticipants'),
                           ('GirlTop200', 'GirlParticipants')]
    for chosen in ['PGA', 'PP', 'Solar']:
        for label, control in control_experiments:
            # 'PP' stands for the per-experiment "<label>PP" column (e.g.
            # BoyTop50PP). Resolve it into a separate name so the outer loop
            # variable is never rebound mid-iteration.
            predictor = label + 'PP' if chosen == 'PP' else chosen
            print("----->>>>>", label, predictor, control, control_analysis(df, label, predictor, control))


if __name__ == '__main__':
    # Script entry point: run the full analysis with the default CSV.
    analyze()