# Populate the h_hat loss predictions for each model into val_aggregate_seed.csv
import numpy as np
import pandas as pd
import random
from fairrisk_fmow_helper import *
from sklearn.model_selection import train_test_split
pd.options.display.float_format = "{:,.2f}".format
import os
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

# Name of the data split (not referenced elsewhere in the visible script).
split = 'val'
# Number of random 50/50 splits (seeds) to repeat the estimation over.
num_runs = 50
# Per-example table of metadata and model losses; new prediction columns are
# added to it and it is written back to this same file.
meta = pd.read_csv('val_aggregate_seed.csv')

def get_input_col(control_variable):
    """Expand control-variable names into the feature column names they cover.

    'y' expands to the 62 columns 'y_0' .. 'y_61', 'region' expands to the six
    columns 'region_0' .. 'region_5', 'clipClass' contributes no columns, and
    any other name is kept as a single column of the same name.

    Args:
        control_variable: iterable of control-variable names.

    Returns:
        A new list of column names, in the order the variables were given.
    """
    columns = []
    for name in control_variable:
        if name == 'y':
            columns.extend(f'y_{k}' for k in range(62))
        elif name == 'region':
            columns.extend(f'region_{k}' for k in range(6))
        elif name != 'clipClass':
            columns.append(name)
    return columns

def get_cvar_col_name(control_variable, loss_type):
    """Build the output-column suffix for a control-variable set and loss tag.

    The result is '_cvar', then '_<name>' for each control variable in order,
    then '_<loss_type>'; e.g. (['year', 'region'], 'entropy_0') gives
    '_cvar_year_region_entropy_0'.
    """
    return '_'.join(['_cvar', *control_variable, loss_type])

# Control-variable sets to condition on. Each inner list is expanded into
# feature columns by get_input_col and encoded in the output column name by
# get_cvar_col_name.
control_var_list = [
    ['year', 'region', 'lat', 'lon', 'cloud_cover'],
    ['year', 'region'],
    ['year', 'region', 'y'],
    ['year', 'region', 'y', 'lat', 'lon', 'cloud_cover'],
]

# Models whose '<name>_loss' column in `meta` is used as the fitting target.
model_list = [
    'CLIP_B16_ensemble_40_val_adj',
    'imgnt_dpn68',
    'imgnt_resnet50',
    'fmow_erm_seed0',
    'fmow_irm_seed0',
    'fmow_erm_ID_seed0',
    'CLIP_B16_ensemble_24_val_adj',
    'imgnt_resnet18',
    'CLIP_B16_ensemble_27_val_adj',
    'imgnt_vgg11',
]

# Populate the h_hat losses; this is for cross-entropy loss.
#
# For every seed, `meta` is split into two halves. For each model and each
# control-variable set, a model h is obtained from `gridsearch` on one half
# (features: the control-variable columns; target: the '<model>_loss' column).
# h's predictions on the other half are written back into `meta`. The halves
# then swap roles, so every row receives a prediction from an h fitted on the
# opposite half. Output column: '<model>_cvar_<vars...>_entropy_<seed>'.
for seed in np.arange(num_runs):
    seed_tag = str(int(seed))
    print('current seed: ' + seed_tag)
    if int(seed) % 5 == 1:
        # Periodic checkpoint (taken before this seed's work starts).
        print('saving...')
        meta.to_csv('val_aggregate_seed.csv', index=False)

    X_cvar1, X_cvar2 = train_test_split(meta, train_size=0.5, random_state=seed)
    X_list = {0: X_cvar1, 1: X_cvar2}

    for i in [0, 1]:
        print('cur i: ' + str(i))
        est_X = X_list[i]       # half used to fit h
        eva_X = X_list[1 - i]   # other half, which receives h's predictions
        for m_name in model_list:
            est_loss = est_X[m_name + '_loss']

            for control_var in control_var_list:
                # Derived unconditionally so a short control-variable list can
                # never reuse the features of a previous iteration (or hit an
                # undefined name on the first one).
                input_col = get_input_col(control_var)
                input_var = est_X[input_col]
                eva_var = eva_X[input_col]

                # NOTE(review): training features are passed as-is while the
                # prediction features below are cast to double; confirm that the
                # model returned by gridsearch tolerates the dtype difference.
                h = gridsearch(input_var, est_loss, 'XGBoost', seed)
                cvar_col_name = get_cvar_col_name(control_var, 'entropy_' + seed_tag)
                out_col = m_name + cvar_col_name

                # Write straight into `meta` by row label. Previously a new
                # column was assigned on the split copy `eva_X` first, which is
                # a chained-assignment pattern and leaked prediction columns
                # into the next fold's est_X.
                meta.loc[eva_X.index, out_col] = h.predict(eva_var.astype(np.double))

meta.to_csv('val_aggregate_seed.csv', index=False)
                