import os
import platform
import sys
import time
import pickle
from IPython.display import clear_output
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

import seaborn as sns

import sklearn
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

### Base functions
def fix_seed(seed):
    """Seed the random number generators used by this module.

    Seeds NumPy's global generator and the built-in ``random`` module, then
    stores the seed (as a string) in the ``PYTHONHASHSEED`` environment
    variable.

    Note: the interpreter reads ``PYTHONHASHSEED`` at startup, so the
    environment assignment is inherited by child processes started afterwards
    but does not change string hashing in the already-running process.

    Args:
        seed: Seed value handed to ``np.random.seed`` and ``random.seed``.
    """
    # NumPy first, then the built-in generator (same order as always).
    for seeder in (np.random.seed, random.seed):
        seeder(seed)

    # Hash seed, exported through the process environment.
    hash_seed = str(seed)
    os.environ["PYTHONHASHSEED"] = hash_seed
    
def fun_NormScale(data, params=None):
    """Standardize ``data`` as ``(data - mean) / std``.

    Args:
        data: Object supporting ``.mean()``, ``.std()`` and arithmetic
            (for example a NumPy array or a pandas object).
        params: Optional mapping with keys ``'Mean'`` and ``'Std'``. When
            given, these values are used instead of statistics computed from
            ``data`` (e.g. to apply training-set scaling to other data).

    Returns:
        A tuple ``(scaled, used_params)`` where ``used_params`` is a new dict
        ``{'Mean': ..., 'Std': ...}`` holding the statistics that were applied.
        No guard is applied for a zero ``Std``.
    """
    if params is not None:
        # Re-use previously fitted statistics.
        center, spread = params['Mean'], params['Std']
    else:
        # Fit the statistics on the data itself.
        center, spread = data.mean(), data.std()

    scaled = (data - center) / spread
    return scaled, {'Mean': center, 'Std': spread}

def fun_invNormScale(data, params):
    """Undo the standardization: return ``data * Std + Mean``.

    Args:
        data: Scaled value(s) supporting multiplication and addition.
        params: Mapping with keys ``'Mean'`` and ``'Std'``, such as the
            second element returned by ``fun_NormScale``.

    Returns:
        The value(s) mapped back to the original scale.
    """
    offset, scale = params['Mean'], params['Std']
    restored = data * scale + offset
    return restored

def plot_scatter(y_obs_list,
                 y_prd_list,
                 title_list,
                 plt_row,
                 plt_col,
                 position_list,
                 col_list,
                 alpha_list,
                 fig_size,
                 save_name,
                 title,
                 show_flg=True):
    """Draw observation-vs-prediction scatter panels and save the figure.

    One square panel is drawn per entry of ``position_list``. Each panel
    shows the scatter of predictions against observations on identical x/y
    ranges, a thin identity line, and the correlation coefficient, RMSE and
    MAE printed in the upper-left corner.

    Args:
        y_obs_list: Sequence of observed-value arrays, one per panel.
        y_prd_list: Sequence of predicted-value arrays, one per panel.
        title_list: Sequence of panel titles.
        plt_row: Number of rows in the subplot grid.
        plt_col: Number of columns in the subplot grid.
        position_list: Subplot index of each panel within the grid; its
            length determines how many panels are drawn.
        col_list: Marker color for each panel.
        alpha_list: Marker opacity for each panel.
        fig_size: Figure size passed to ``plt.figure(figsize=...)``.
        save_name: Destination passed to ``fig.savefig``.
        title: Figure-level title.
        show_flg: When falsy, the figure is closed after saving so it is not
            displayed; when truthy (default) it is left open.

    Raises:
        IndexError: If any per-panel list is shorter than ``position_list``.
    """
    fig = plt.figure(figsize=fig_size)

    for i_plt, position in enumerate(position_list):
        y_obs = y_obs_list[i_plt]
        y_prd = y_prd_list[i_plt]

        ax = fig.add_subplot(plt_row, plt_col, position,
                             title=title_list[i_plt],
                             xlabel='Observation',
                             ylabel='Prediction')
        ax.scatter(y_obs, y_prd, color=col_list[i_plt], alpha=alpha_list[i_plt])

        # Use the widest auto-scaled range on both axes so the identity line
        # runs corner to corner of a square panel.
        xy_min = min(ax.get_xlim()[0], ax.get_ylim()[0])
        xy_max = max(ax.get_xlim()[1], ax.get_ylim()[1])
        ax.axis('square')
        ax.set_xlim([xy_min, xy_max])
        ax.set_ylim([xy_min, xy_max])
        ax.grid(color='gray', linestyle='dotted', linewidth=1, alpha=0.5)

        # RMSE is taken as sqrt(MSE) instead of `squared=False`, which newer
        # scikit-learn releases no longer accept.
        corr = np.corrcoef(y_prd, y_obs)[0, 1]
        rmse = np.sqrt(mean_squared_error(y_obs, y_prd))
        mae = mean_absolute_error(y_obs, y_prd)
        for y_pos, label, value in ((0.93, 'Corr : ', corr),
                                    (0.87, 'RMSE : ', rmse),
                                    (0.81, 'MAE : ', mae)):
            ax.text(0.03, y_pos, label + str(round(value, 4)),
                    size=15, transform=ax.transAxes)

        # Identity line across the displayed range (limits are already fixed,
        # so this does not rescale the axes).
        ax.plot([xy_min, xy_max], [xy_min, xy_max], color='gray', linewidth=0.5)

    # Reserve the top 10% of the figure for the overall title.
    fig.tight_layout(rect=[0, 0, 1, 0.90])

    fig.suptitle(title, fontsize=20)

    fig.savefig(save_name)
    if not show_flg:
        plt.close(fig)