import numpy as np
import argparse
import sys
import os
import json
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
# import my package
from proto_classifier import ProtoClassifier
from query_oracle_features import get_feature_proto
import utils
from data import get_dataset_few_shot, XY2df
from config import n_class

def get_gpt_weights(N_cols,C_cols,dataset):
    """Load the per-column weights for ``dataset`` and average them.

    Every ``*.json`` file in ``./gpt_weights/{dataset}`` is read as a mapping
    from column name to weight. Within each file the weights of the requested
    columns (``N_cols + C_cols``) are normalized to sum to 1; the normalized
    weights are then averaged across files.

    Args:
        N_cols: list of numeric column names.
        C_cols: list of categorical column names.
        dataset: name of the sub-directory of ``./gpt_weights`` to read.

    Returns:
        dict mapping each column in ``N_cols + C_cols`` to its mean
        normalized weight.

    Raises:
        FileNotFoundError: if the directory is missing or holds no ``.json``
            file (the mean would otherwise silently be NaN).
        KeyError: if a weight file lacks one of the requested columns.
        ZeroDivisionError: if the requested weights of a file sum to zero.
    """
    cols = N_cols + C_cols
    weights_dir = f'./gpt_weights/{dataset}'

    normalized_list = []
    # Sorted so the float mean does not depend on the filesystem's listing order.
    for file in sorted(os.listdir(weights_dir)):
        if not file.endswith('.json'):
            continue
        with open(f'{weights_dir}/{file}') as f:
            weights = json.load(f)
        total = sum(weights[col] for col in cols)
        if total == 0:
            raise ZeroDivisionError(
                f'weights in {weights_dir}/{file} sum to zero over the requested columns')
        normalized_list.append({col: weights[col] / total for col in cols})

    if not normalized_list:
        raise FileNotFoundError(f'no .json weight files found in {weights_dir}')

    return {
        col: np.mean([normalized[col] for normalized in normalized_list])
        for col in cols
    }
        

def _as_column(frame, col):
    """Return column ``col`` of ``frame`` as an (n, 1) array."""
    return frame[col].values.reshape(-1, 1)


def _sorted_targets(per_target):
    """Return the keys of ``per_target`` ordered by ``np.sort``, as Python values."""
    return np.sort(list(per_target.keys())).tolist()


def _flatten_shots(shots):
    """Concatenate the per-shot value lists in ``shots`` into one flat list."""
    return [value for shot in shots for value in shot]


def _encode_numeric(weight, support_x_n, query_x_n, gpt_x_n, use_gpt, N_cols):
    """Standardize and weight every numeric column; return (train, test, gpt, labels)."""
    train_parts, test_parts, gpt_parts = [], [], []
    labels = None
    for col in N_cols:
        scaler = StandardScaler()
        # The scaler is fitted on the support values plus all generated values.
        fit_parts = []
        if support_x_n is not None:
            fit_parts.append(_as_column(support_x_n, col))
        if use_gpt:
            targets = _sorted_targets(gpt_x_n[col])
            for target in targets:
                pooled = np.array(_flatten_shots(gpt_x_n[col][target]))
                fit_parts.append(pooled.reshape(-1, 1))
        scaler.fit(np.concatenate(fit_parts, axis=0))

        if support_x_n is not None:
            train_parts.append(scaler.transform(_as_column(support_x_n, col)) * weight[col])
        test_parts.append(scaler.transform(_as_column(query_x_n, col)) * weight[col])

        if use_gpt:
            protos = []
            labels = []
            for target in targets:
                for shot_values in gpt_x_n[col][target]:
                    # One prototype per shot: the scaled mean of that shot's values.
                    shot_mean = np.array(shot_values).mean(axis=0).reshape(1, -1)
                    protos.append(scaler.transform(shot_mean))
                    labels.append(target)
            gpt_parts.append(np.concatenate(protos, axis=0) * weight[col])

    train = np.concatenate(train_parts, axis=1) if support_x_n is not None else None
    test = np.concatenate(test_parts, axis=1)
    gpt = np.concatenate(gpt_parts, axis=1) if use_gpt else None
    return train, test, gpt, labels


def _encode_categorical(weight, support_x_c, query_x_c, x_c_unlabel, gpt_x_c, use_gpt, C_cols):
    """One-hot encode and weight every categorical column; return (train, test, gpt, labels)."""
    train_parts, test_parts, gpt_parts = [], [], []
    labels = None
    for col in C_cols:
        ohe = OneHotEncoder()
        # Fit on support + query + unlabeled data so that categories present
        # only in the test data are still known to the encoder.
        fit_parts = [_as_column(query_x_c, col), _as_column(x_c_unlabel, col)]
        if support_x_c is not None:
            fit_parts.insert(0, _as_column(support_x_c, col))
        ohe.fit(np.concatenate(fit_parts, axis=0))

        if support_x_c is not None:
            train_parts.append(ohe.transform(_as_column(support_x_c, col)).toarray() * weight[col])
        test_parts.append(ohe.transform(_as_column(query_x_c, col)).toarray() * weight[col])

        if use_gpt:
            protos = []
            labels = []
            for target in _sorted_targets(gpt_x_c[col]):
                for shot_values in gpt_x_c[col][target]:
                    # One prototype per shot: the mean one-hot vector of its values.
                    encoded = ohe.transform(np.array(shot_values).reshape(-1, 1)).toarray()
                    protos.append(encoded.mean(axis=0).reshape(1, -1))
                    labels.append(target)
            gpt_parts.append(np.concatenate(protos, axis=0) * weight[col])

    train = np.concatenate(train_parts, axis=1) if support_x_c is not None else None
    test = np.concatenate(test_parts, axis=1)
    gpt = np.concatenate(gpt_parts, axis=1) if use_gpt else None
    return train, test, gpt, labels


def preprocess(weight,support_x_n, query_x_n,gpt_x_n, support_x_c, query_x_c,x_c_unlabel,gpt_x_c, gpt_shot,N_cols, C_cols):
    """Build weighted feature matrices for the support, query and generated data.

    Numeric columns are standardized with a per-column ``StandardScaler``;
    categorical columns are one-hot encoded with a per-column
    ``OneHotEncoder``. Every encoded column is multiplied by ``weight[col]``.
    When ``gpt_shot > 0``, each shot in ``gpt_x_n`` / ``gpt_x_c`` is reduced
    to a single prototype row labelled with its target.

    Args:
        weight: mapping from column name to its scalar weight.
        support_x_n, support_x_c: numeric / categorical support features,
            indexed by column name and exposing ``.values`` (presumably
            pandas DataFrames - confirm against caller), or None when there
            is no support set.
        query_x_n, query_x_c: numeric / categorical query features.
        gpt_x_n, gpt_x_c: ``{col: {target: [shot_values, ...]}}`` where each
            ``shot_values`` is a sequence of values for one shot. Only read
            when ``gpt_shot > 0``.
        x_c_unlabel: unlabeled categorical features, used only to fit the
            one-hot encoders.
        gpt_shot: generated prototypes are used only when this is > 0.
        N_cols, C_cols: numeric / categorical column names.

    Returns:
        ``(X_train, X_test, x_gpt, y_gpt)``. ``X_train`` is None when the
        support data is None; ``x_gpt`` and ``y_gpt`` are None unless
        ``gpt_shot > 0``. ``y_gpt`` holds the labels collected for the last
        processed column (categorical if any, otherwise numeric).

    Raises:
        ValueError: if both ``N_cols`` and ``C_cols`` are empty (previously
            this surfaced as an UnboundLocalError at the return statement).
    """
    has_numeric = len(N_cols) > 0
    has_categorical = len(C_cols) > 0
    if not has_numeric and not has_categorical:
        raise ValueError('preprocess needs at least one numeric or categorical column')

    use_gpt = gpt_shot > 0
    labels = None
    if has_numeric:
        x_n_train, x_n_test, x_n_gpt, labels = _encode_numeric(
            weight, support_x_n, query_x_n, gpt_x_n, use_gpt, N_cols)
    if has_categorical:
        x_c_train, x_c_test, x_c_gpt, labels = _encode_categorical(
            weight, support_x_c, query_x_c, x_c_unlabel, gpt_x_c, use_gpt, C_cols)

    if has_numeric and has_categorical:
        have_support = support_x_n is not None and support_x_c is not None
        X_train = np.concatenate((x_n_train, x_c_train), axis=1) if have_support else None
        X_test = np.concatenate((x_n_test, x_c_test), axis=1)
        x_gpt = np.concatenate((x_n_gpt, x_c_gpt), axis=1) if use_gpt else None
    elif has_numeric:
        X_train, X_test, x_gpt = x_n_train, x_n_test, x_n_gpt
    else:
        X_train, X_test, x_gpt = x_c_train, x_c_test, x_c_gpt

    y_gpt = np.array(labels) if use_gpt else None
    return X_train, X_test,x_gpt, y_gpt

def get_args(command=None):
    """Parse the experiment options.

    Args:
        command: optional command-line string; when given it is split on
            whitespace and parsed instead of the process arguments.

    Returns:
        The ``argparse.Namespace`` with ``shot``, ``seed``, ``data``,
        ``gpt_shot`` and ``dist_type``.
    """
    distance_choices = ['euclidean', 'manhattan', 'cosine', 'dot']
    option_specs = (
        ('--shot', {'type': int}),
        ('--seed', {'type': int}),
        ('--data', {'type': str}),
        ('--gpt_shot', {'type': int, 'default': 10}),
        ('--dist_type', {'type': str, 'choices': distance_choices, 'default': 'euclidean'}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    # parse_args(None) falls back to the process arguments.
    argv = None if command is None else command.split()
    return parser.parse_args(argv)
def get_file_prefix(args):
    """Return the result-file prefix built from the run's options.

    The prefix joins data, shot, gpt_shot, seed and dist_type with ``-`` and
    ends with the ``-weighted`` suffix.
    """
    fields = (args.data, args.shot, args.gpt_shot, args.seed, args.dist_type)
    return '{}-{}-{}-{}-{}-weighted'.format(*fields)
if __name__ == '__main__':
    # Script entry point: build weighted features from the few-shot support
    # set and/or the generated prototypes, fit a ProtoClassifier, and save
    # the AUC measured on the query set.
    args = get_args()

    _SEED = args.seed
    _DATA = args.data
    # --shot is divided by the dataset's class count to get shots per class.
    _SHOT = args.shot // n_class[_DATA]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not (_SHOT > 0 or args.gpt_shot > 0):
        raise ValueError('SHOT and gpt_shot cannot be both 0')

    df, info, X, y = get_dataset_few_shot(_DATA, _SHOT, _SEED)
    X, y = XY2df(X, y, info)
    N_cols = info['N_cols']

    # Drop categorical columns that have a single unique value in df.
    C_cols = [col for col in info['C_cols'] if len(df[col].unique()) > 1]
    info['C_cols'] = C_cols

    target = info['target']
    has_support = _SHOT >= 1
    X_n_support = X['support_x_n'] if has_support else None
    X_c_support = X['support_x_c'] if has_support else None
    X_n_query = X['query_x_n']
    X_c_query = X['query_x_c']
    X_n_unlabel = X['unlabeled_x_n']  # not used below
    X_c_unlabel = X['unlabeled_x_c']
    y_support = y['support_y'] if has_support else None
    y_query = y['query_y']

    # Observed values of every categorical column and of the target.
    cat_condidates = {col: list(df[col].unique()) for col in C_cols}
    cat_condidates[target] = list(df[target].unique())

    weight = get_gpt_weights(N_cols, C_cols, _DATA)

    x_n_gpt = get_feature_proto(info, cat_condidates, _DATA, N_cols, args.gpt_shot)
    x_c_gpt = get_feature_proto(info, cat_condidates, _DATA, C_cols, args.gpt_shot)

    X_train, X_test, x_gpt, y_gpt = preprocess(
        weight, X_n_support, X_n_query, x_n_gpt,
        X_c_support, X_c_query, X_c_unlabel, x_c_gpt,
        args.gpt_shot, N_cols, C_cols)

    y_train = y_support[target] if has_support else None
    y_test = y_query[target]

    if _SHOT > 0 and args.gpt_shot > 0:
        # Train on the real support rows plus the generated prototypes.
        X_train = np.concatenate((X_train, x_gpt), axis=0)
        y_train = np.concatenate((y_train, y_gpt), axis=0)
    elif args.gpt_shot > 0:
        # No support set: train on the generated prototypes only.
        X_train = x_gpt
        y_train = y_gpt
    # Otherwise only the support set is used and X_train / y_train stay as is.

    model = ProtoClassifier(X_train, y_train)

    pred, probs = model.predict(X_test, dist_type=args.dist_type, predict_type='nearest_prototype', return_probs=True)
    prefix = get_file_prefix(args)

    probIdx = utils.classLabel2IdxInProb(y_test, model.classes_)
    multiclass = len(model.classes_) > 2
    auc = utils.evaluate(probs, probIdx, multiclass=multiclass)
    print(f'auc: {auc}')
    # Create the output directory so a missing ./results does not discard the run.
    os.makedirs('./results', exist_ok=True)
    np.save(f'./results/{prefix}-auc.npy', auc)

