from pynvml import *
import psutil
import numpy as np
import json
import pandas as pd

from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from sklearn.linear_model import HuberRegressor


def get_mem_usage_stats(msg):
    """Print per-GPU memory and host RAM/CPU usage, each line tagged with ``msg``.

    For every device NVML reports, prints its total and free memory in GiB.
    Then prints one summary line of host statistics gathered via psutil.

    Args:
        msg: Label embedded in every printed line (e.g. the current stage).
    """
    gib = 1024 ** 3
    # Use the names provided by ``from pynvml import *`` directly instead of a
    # ``nvml`` submodule attribute, which not every pynvml distribution exposes.
    nvmlInit()
    try:
        for idx in range(nvmlDeviceGetCount()):
            handle = nvmlDeviceGetHandleByIndex(idx)
            info = nvmlDeviceGetMemoryInfo(handle)
            total_mem = info.total / gib
            free_mem = info.free / gib
            print(f"State {msg} Device {idx}: total mem {total_mem:.2f}, free mem {free_mem:.2f}")
    finally:
        # NVML keeps an init reference count; pair every nvmlInit with a shutdown.
        nvmlShutdown()

    # Take a single snapshot so "available" and "percent" describe the same moment.
    vmem = psutil.virtual_memory()
    # NOTE(review): psutil.cpu_percent() is CPU utilisation since the previous
    # call (not memory), and the first call in a process returns 0.0.
    print(f"State {msg} Total virtual memory available {vmem.available//1024**3} GB | Virtual memory % used: {np.round(vmem.percent)}% | CPU memory % used: {np.round(psutil.cpu_percent())}%")


def load_gradient_stats(path):
    """Read a gradient-statistics JSON file into one long DataFrame.

    The file is read as ``{task_name: {'0': {example_id: {stat: value}}}}``.
    For each task, the mapping under key ``'0'`` is turned into rows, one per
    example id. The id is stored as an integer ``example_num`` column and the
    lower-cased task name as ``task``.

    Args:
        path: Path to the JSON file.

    Returns:
        The concatenation of all per-task frames. The index is not reset, so
        index labels repeat across tasks.
    """
    with open(path) as handle:
        raw = json.load(handle)

    def _task_frame(task_name, payload):
        # Columns of DataFrame(payload) are example ids; transpose to one row per example.
        frame = (
            pd.DataFrame(payload['0'])
            .T
            .reset_index()
            .rename(columns={'index': 'example_num'})
        )
        frame['example_num'] = frame['example_num'].astype(int)
        frame['task'] = task_name.lower()
        return frame

    return pd.concat([_task_frame(name, payload) for name, payload in raw.items()])


def load_proba_stats(path):
    """Load per-example probability statistics and aggregate them per example.

    The JSON file is read as ``{task_name: {'0': {example_id: stats}}}``. Each
    ``stats`` dict holds the list-valued entries ``target_proba``, ``target``,
    ``model_pred_proba`` and ``model_pred``, which must have equal length within
    one example because they are exploded together. It also holds the scalar
    entries ``avg_error`` and ``avg_confidence``.

    Args:
        path: Path to the JSON file.

    Returns:
        One row per ``(task, example_num)``. The mean of ``target_proba``,
        ``model_pred_proba``, ``avg_error`` and ``avg_confidence`` is stored as
        float64. The unique values of ``target`` and ``model_pred`` are stored
        as arrays. ``task`` is lower-cased.
    """
    list_cols = ['target_proba', 'target', 'model_pred_proba', 'model_pred']
    mean_cols = ['target_proba', 'model_pred_proba', 'avg_error', 'avg_confidence']

    with open(path) as f:
        ff = json.load(f)

    ls = []
    for k in ff.keys():
        # Columns of DataFrame(...) are example ids; transpose to one row per example.
        temp = pd.DataFrame(ff[k]['0']).T.reset_index().rename(columns={'index': 'example_num'})
        # One row per list element, exploding the parallel lists together.
        temp = temp.explode(list_cols)
        temp['example_num'] = temp['example_num'].astype(int)
        temp['task'] = k.lower()
        ls.append(temp)

    df_proba = pd.concat(ls)
    # Transposing a mixed-type frame and exploding list cells both leave these
    # columns as ``object`` dtype. Cast them so the means are computed
    # numerically and returned as float64.
    df_proba = df_proba.astype({col: float for col in mean_cols})

    df_proba = df_proba.groupby(['task', 'example_num']).agg({
        'target_proba': 'mean',
        'model_pred_proba': 'mean',
        'avg_error': 'mean',
        'avg_confidence': 'mean',
        'target': 'unique',
        'model_pred': 'unique',
    }).reset_index()
    return df_proba


def load_auc(path, tasks_to_exclude):
    """Load per-task extrapolation AUCs, sorted ascending, minus excluded tasks.

    The JSON file is read as ``{task_name: {'extrapolation_auc': value, ...}}``.

    Args:
        path: Path to the JSON file.
        tasks_to_exclude: Task names to drop. They are matched against the task
            names exactly as spelled in the file (no lower-casing happens here).

    Returns:
        DataFrame with a ``task`` column followed by ``auc`` (renamed from
        ``extrapolation_auc``) and any other per-task fields. Rows are ordered
        by ascending AUC and the index is 0..n-1.
    """
    with open(path) as handle:
        raw = json.load(handle)

    # Columns of DataFrame(raw) are task names; transpose to one row per task.
    per_task = pd.DataFrame(raw).T
    ranked = per_task.sort_values('extrapolation_auc').reset_index()
    ranked = ranked.rename(columns={'extrapolation_auc': 'auc', 'index': 'task'})

    keep = ~ranked['task'].isin(tasks_to_exclude)
    return ranked.loc[keep].reset_index(drop=True)


def _null_diagonal(block):
    """Return a float copy of ``block`` with its positional diagonal set to NaN."""
    # Work on an explicit copy. Writing through ``DataFrame.values`` only reaches
    # the frame when pandas happens to return a writable view: it is a no-op for
    # mixed dtypes and fails under Copy-on-Write.
    values = block.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(values, np.nan)
    return pd.DataFrame(values, index=block.index, columns=block.columns)


def _symmetrize_blocks(blocks):
    """Stack batch blocks row-wise and fill each NaN cell from its transposed cell.

    The stacked frame must be square. Its column labels are reused as both the
    row and column labels of the result.
    """
    stacked = pd.concat(blocks)
    labels = stacked.columns
    n_rows, n_cols = stacked.shape
    if n_rows != n_cols:
        raise ValueError(
            f"stacked cosine blocks must form a square matrix, got {n_rows}x{n_cols}"
        )
    # Relabel both axes positionally so that fillna aligns cell (i, j) with
    # the transposed cell (j, i).
    positional = stacked.set_axis(range(n_rows), axis=0).set_axis(range(n_cols), axis=1)
    filled = positional.fillna(positional.T)
    return filled.set_axis(labels, axis=0).set_axis(labels, axis=1)


def load_cosine_stats(path):
    """Load pairwise cosine similarities and average them per example.

    The JSON file is read as ``{task: {iteration: {batch: {a: {b: sim}}}}}``.
    The batch key ``'all'`` is skipped. For each batch, a block is built with
    the inner keys as rows and the outer keys as columns, and its positional
    diagonal is set to NaN. The blocks of one iteration are stacked into a
    square matrix, and each missing cell is filled from its transposed cell.
    The column means, ignoring NaN, give one ``cos_sim`` per example and
    iteration.

    Args:
        path: Path to the JSON file.

    Returns:
        DataFrame with columns ``example_num`` (int), ``cos_sim`` (float),
        ``task`` (lower-cased) and ``cos_sim_log`` (``log2(cos_sim + 1)``).
        There is one row per example and iteration, and the index is not reset.

    Raises:
        ValueError: if an iteration has no non-``'all'`` batches, or its
            stacked blocks are not square.
    """
    with open(path) as f:
        ff = json.load(f)

    ls_df = []
    for task, iterations in ff.items():
        per_iteration = []
        for batches in iterations.values():
            blocks = [
                _null_diagonal(pd.DataFrame.from_dict(batch, orient='index').T)
                for b_num, batch in batches.items()
                if b_num != 'all'
            ]
            sim = _symmetrize_blocks(blocks)
            per_iteration.append(sim.mean(axis=0).reset_index())
        t = pd.concat(per_iteration).rename(columns={'index': 'example_num', 0: 'cos_sim'})
        t['task'] = task.lower()
        ls_df.append(t)

    df_cos = pd.concat(ls_df)
    df_cos['example_num'] = df_cos['example_num'].astype(int)
    df_cos['cos_sim'] = df_cos['cos_sim'].astype(float)
    df_cos['cos_sim_log'] = np.log2(df_cos['cos_sim'] + 1)

    return df_cos


def run_leave_one_out_huber(df_in, depvar, predvar):
    """Leave-one-task-out evaluation of a Huber (robust) linear regression.

    For each distinct value of ``df_in['task']``, a ``HuberRegressor`` with
    default settings is fitted on all other tasks and used to predict the
    held-out task.

    Args:
        df_in: DataFrame with a ``task`` column, the ``depvar`` column and every
            column named in ``predvar``. Each task must occupy exactly one row;
            ``.item()`` raises ``ValueError`` otherwise.
        depvar: Name of the dependent-variable column.
        predvar: Predictor column name(s). A single string is treated as a
            one-element list, so the feature matrix stays 2-D.

    Returns:
        DataFrame with one row per held-out task and columns ``task``, ``pred``,
        ``depvar``, ``abs_diff``, ``diff`` (observed minus predicted) and
        ``sqrd_err``.
    """
    if isinstance(predvar, str):
        # sklearn requires a 2-D X; a bare column name would select a Series.
        predvar = [predvar]

    results = {}
    for task in df_in['task'].unique():
        train = df_in[df_in['task'] != task].reset_index(drop=True)
        heldout = df_in[df_in['task'] == task].reset_index(drop=True)  # one held-out task

        huber = HuberRegressor().fit(train[predvar], train[depvar])
        pred = huber.predict(heldout[predvar]).item()
        actual = heldout[depvar].item()
        diff = actual - pred

        results[task] = {
            'pred': pred,
            'depvar': actual,
            'abs_diff': abs(diff),
            'diff': diff,
            'sqrd_err': diff ** 2,
        }

    return pd.DataFrame(results).T.reset_index().rename(columns={'index': 'task'})


def run_leave_one_out(df_in, depvar, predvar):
    """Leave-one-task-out evaluation of an ordinary-least-squares model.

    For each distinct value of ``df_in['task']``, the formula
    ``depvar ~ p1 + p2 + ...`` is fitted with statsmodels ``ols`` on all other
    tasks. The held-out task is then predicted as
    ``sum(coef_p * x_p) + Intercept``.

    Args:
        df_in: DataFrame with a ``task`` column, the ``depvar`` column and every
            column named in ``predvar``. Each task must occupy exactly one row;
            ``.item()`` raises ``ValueError`` otherwise.
        depvar: Name of the dependent-variable column.
        predvar: Predictor column name(s). A single string is treated as a
            one-element list instead of being iterated character by character.

    Returns:
        DataFrame with one row per held-out task. It has ``coeff_<p>`` and
        ``pval_<p>`` for each predictor, then ``intercept``, ``pval_intercept``,
        ``pred``, ``depvar``, ``abs_diff``, ``diff`` (observed minus predicted)
        and ``sqrd_err``.
    """
    if isinstance(predvar, str):
        predvar = [predvar]

    # The formula does not depend on the held-out task, so build it once.
    formula = f"{depvar} ~ " + " + ".join(predvar)

    results = {}
    for task in df_in['task'].unique():
        train = df_in[df_in['task'] != task].reset_index(drop=True)
        heldout = df_in[df_in['task'] == task].reset_index(drop=True)  # one held-out task

        lm = ols(formula, train).fit()

        row = {}
        pred = pd.Series(0, index=heldout.index)
        for p in predvar:
            coef = lm.params[p].item()
            pred = pred + heldout[p] * coef
            row[f'coeff_{p}'] = coef
            row[f'pval_{p}'] = lm.pvalues[p].item()

        intercept = lm.params['Intercept'].item()
        pred = pred + intercept
        residual = heldout[depvar] - pred

        row['intercept'] = intercept
        row['pval_intercept'] = lm.pvalues['Intercept'].item()
        row['pred'] = pred.item()
        row['depvar'] = heldout[depvar].item()
        row['abs_diff'] = residual.abs().item()
        row['diff'] = residual.item()
        row['sqrd_err'] = (residual ** 2).item()
        results[task] = row

    return pd.DataFrame(results).T.reset_index().rename(columns={'index': 'task'})

def get_model_prefix(model_name):
    """Return the model-family tag contained in ``model_name``, or None.

    The match is a case-insensitive substring test. Families are tried in the
    order mistral, llama, qwen, smollm, and the first hit wins.
    """
    lowered = model_name.lower()
    for family in ("mistral", "llama", "qwen", "smollm"):
        if family in lowered:
            return family
    return None
