
# import the project-local building blocks: weights, toy data, model and weight searcher
from weights import weights
from data import Toy
from model import logistic_regression
from weight_searcher import weight_searcher

from helpers import set_seed
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from matplotlib import pyplot as plt
import sys

def main():
    """Plot the validation loss over a grid of group weights and overlay the weight search.

    Steps performed:
      1. Create toy training, validation and test data with ``Toy(...).dgp_mv``.
      2. Run ``weight_searcher.optimize_weights`` starting from ``p_ood``.
      3. For every ``p`` on a grid in [0.05, 0.95], fit a model weighted with
         ``{1: p, 2: 1 - p}`` and record its loss on the validation data.
      4. Plot the loss curve together with the search trajectory, the final
         ``p_hat``, the starting point ``p_ood`` and the grid point closest to
         the training probability.

    Takes no arguments and returns nothing; the figure is displayed with
    ``plt.show()``.
    """
    # NOTE(review): `model` and `calc_loss_for_model` are used below but were
    # never imported by this file. They are assumed to live in the same
    # project modules as `logistic_regression` and `set_seed` -- confirm.
    from model import model
    from helpers import calc_loss_for_model

    # number of datapoints to be created
    n_train = 5000
    n_val = 1000

    # probability of group 1 in the training and in the target distribution
    ptr = 0.1
    pte = 0.5

    # Group-probability dicts, keyed like the {1: p, 2: 1 - p} dicts built in
    # the grid loop below. These names were previously referenced without ever
    # being defined.
    # NOTE(review): assumes group label 1 is the group drawn with probability
    # `p_1` by `Toy` -- confirm.
    p_train = {1: ptr, 2: 1 - ptr}
    p_ood = {1: pte, 2: 1 - pte}

    # set the other parameters
    beta_1 = 1
    beta_0 = 0
    sigma_1 = 1
    sigma_0 = 1
    gamma = 1
    a_1 = 0
    a_0 = 0
    d = 100
    mu = np.zeros(d)

    # Parameters shared by all three data sets. `a_0` / `a_1` are now passed to
    # the parameter of the same name (they were swapped before; both are 0, so
    # the generated data is unchanged).
    toy_kwargs = dict(beta_1=beta_1, beta_0=beta_0, sigma_1=sigma_1, sigma_0=sigma_0,
                      mu=mu, gamma=gamma, a_0=a_0, a_1=a_1, d=d)

    # instantiate obj. for training and validation + test
    set_seed(0)
    toy_data_tr = Toy(n=n_train, p_1=ptr, **toy_kwargs)
    toy_data_val = Toy(n=n_val, p_1=ptr, **toy_kwargs)
    toy_data_te = Toy(n=n_val, p_1=pte, **toy_kwargs)

    # create training, validation and test data
    X_train, y_train, g_train = toy_data_tr.dgp_mv(logistic=True)
    X_val, y_val, g_val = toy_data_val.dgp_mv(logistic=True)
    # The test split is not used further down; it is still generated so the
    # sequence of calls after `set_seed(0)` stays as it was.
    X_te, y_te, g_te = toy_data_te.dgp_mv(logistic=True)

    # create a logistic regression model (L1-penalised, liblinear solver)
    model_param = {'max_iter': 100,
                   'penalty': 'l1',
                   'C': 1,
                   'solver': 'liblinear',
                   'tol': 1e-4,
                   'verbose': 0,
                   'random_state': 0,
                   'fit_intercept': True,
                   'warm_start': False}
    logreg = LogisticRegression(**model_param)

    # grid of candidate probabilities for group 1 and storage for their losses
    loss_fn = log_loss
    options = 100
    possible_p = np.linspace(0.05, 0.95, options)
    loss_per_p = np.zeros((options, 1))

    # weights applied when evaluating the loss on the validation data
    weight_obj_val = weights(p_w=p_ood, p_train=p_train)

    # create a weight searcher object (note: 10e-4 == 1e-3)
    ws = weight_searcher(logreg, X_train, y_train, g_train, X_val, y_val, g_val, p_ood,
                         GDRO=False, weight_rounding=4, p_min=10e-4)

    # settings for the search
    start_p = p_ood
    T = 100
    lr = 0.9
    momentum = 0.5
    patience = T
    lr_schedule = 'linear'
    stable_exp = False
    subsample_weights = False
    lock_in_p_g = None
    verbose = True
    decay = 0.99

    # optimize the weights
    p_hat, p_per_t, loss_per_t = ws.optimize_weights(
        start_p, T, lr, momentum, patience=patience,
        verbose=verbose, lr_schedule=lr_schedule, stable_exp=stable_exp,
        subsample_weights=subsample_weights, lock_in_p_g=lock_in_p_g,
        save_trajectory=True, decay=decay)

    # loss on the validation data for each p on the grid
    for i, p in enumerate(possible_p):
        # set the weights for the groups
        p_w = {1: p, 2: 1 - p}

        # create the weight object
        weight_obj_p = weights(p_w=p_w, p_train=p_train)

        # create the model object and fit it; reseed so every fit starts from
        # the same random state
        model_obj = model(weight_obj_p, logreg)
        set_seed(0)
        model_obj.fit_model(X_train, y_train, g_train)

        # loss of the fitted model on the validation data
        loss_per_p[i] = calc_loss_for_model(model_obj, loss_fn, X_val, y_val, g_val,
                                            weights_obj=weight_obj_val, type_pred='probabilities')

    # plot the loss
    plt.plot(possible_p, loss_per_p)
    plt.xlabel('p')
    plt.ylabel('loss')
    plt.title('Loss for different p')

    # plot the p_hat, using the last recorded loss of the search
    loss_p_hat = loss_per_t[-1]
    plt.plot(p_hat[1], loss_p_hat, 'ro', markersize=10, zorder=10)

    # plot the trajectory of p_t with dots at each step
    plt.plot(p_per_t[:, 0], loss_per_t, linestyle='dashed', color='orange', marker='o')

    # plot the p_ood, using the first recorded loss of the search
    loss_p_ood = loss_per_t[0]
    plt.plot(p_ood[1], loss_p_ood, 'bo')

    # plot the p_tr - the grid point closest to it in possible_p
    idx = np.argmin(np.abs(possible_p - p_train[1]))
    loss_p_tr = loss_per_p[idx]
    plt.plot(possible_p[idx], loss_p_tr, 'go')

    plt.show()


    










   


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()