import os
from omegaconf import OmegaConf, DictConfig
# Export the `env` section of configs/config.yaml as process environment
# variables. This runs before the remaining imports, presumably so that the
# libraries imported below already see these variables at import time.
env = OmegaConf.load("configs/config.yaml").env
for k, v in env.items():
    # os.environ only accepts str keys/values; YAML scalars such as ints or
    # booleans would otherwise raise TypeError here.
    os.environ[str(k)] = str(v)

import sys
import hydra
import logging
import warnings
from utils import *
import numpy as np
import pandas as pd
from tqdm import tqdm
from stats_count import *
from transformers import AutoConfig
from sklearn import linear_model, preprocessing
from sklearn.metrics import accuracy_score, matthews_corrcoef, f1_score

# Import-time runtime setup for the whole script.
warnings.filterwarnings('ignore')  # silence all Python warnings (incl. sklearn's)

# Reset tqdm's bookkeeping of live progress bars (a private attribute) and set
# its monitor interval to 0, which tqdm uses to disable its monitor thread.
tqdm._instances.clear()
tqdm.monitor_interval = 0
sys.stdout.flush()

np.random.seed(42)  # For reproducibility.
logging.basicConfig(format="%(message)s", level=logging.INFO)

def pred_by_Xy(X_train, y_train, X_test, classifier, verbose=False, scale=True):
    """Fit ``classifier`` on the training data and predict labels for ``X_test``.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels.
        X_test: Test feature matrix; scaled with the same scaler as ``X_train``.
        classifier: An unfitted sklearn-style estimator exposing ``fit`` and
            ``predict``.
        verbose: If True, print the training-set MCC, accuracy and F1.
        scale: If True, standardize both matrices with a ``StandardScaler``
            fitted on ``X_train`` only.

    Returns:
        A 4-tuple ``(test_predictions, train_mcc, train_accuracy, train_f1)``.
        The training metrics are measured on the (scaled) data the classifier
        was fitted on; ``f1_score`` uses sklearn's default averaging.
    """
    if scale:
        # Fit the scaler on the training data only, then apply it to both sets.
        scaler = preprocessing.StandardScaler().fit(X_train)
        X_train = scaler.transform(X_train)
        X_test = scaler.transform(X_test)

    classifier.fit(X_train, y_train)

    # Predict on the training set once and reuse it for every metric.
    train_pred = classifier.predict(X_train)
    train_matt = matthews_corrcoef(y_train, train_pred)
    train_acc = accuracy_score(y_train, train_pred)
    train_f1 = f1_score(y_train, train_pred)

    if verbose:
        print("train matt:", train_matt)
        print("train acc: ", train_acc)
        print("train f1: ", train_f1)

    return classifier.predict(X_test), train_matt, train_acc, train_f1

def _resolve_data_name(name):
    """Return the dataset prefix used to name the ``<prefix>_train``/``<prefix>_test`` splits.

    Names containing ``hypo`` or ``trofi`` keep every ``_``-separated part
    except the last one; every other name keeps only its first part. A name
    without any ``_`` is returned unchanged.
    """
    parts = name.split('_')
    # Each substring must be tested on its own: ``'hypo' or 'trofi' in name``
    # is always truthy because the non-empty literal 'hypo' short-circuits
    # the ``or``.
    if ('hypo' in name or 'trofi' in name) and len(parts) > 1:
        return "_".join(parts[:-1])
    return parts[0]


def _load_features(model_name, subset, n_layers, limit):
    """Load the ``(old, ripser, template)`` feature arrays saved for ``subset``.

    Each array is truncated to its first ``limit`` entries along the axis that
    indexes examples: axis 3 for the old and template features, axis 2 for the
    ripser features.
    """
    prefix = f'./experiments/{model_name}/{subset}/features/all_heads_{n_layers}_layers'
    # NOTE: allow_pickle=True can execute arbitrary code from a crafted file;
    # only load feature files produced by this project's own pipeline.
    old = np.load(f'{prefix}_s_e_v_c_b0b1_lists_array_6_thrs_MAX_LEN_128_{subset}.npy',
                  allow_pickle=True)[:, :, :, :limit, :]
    ripser = np.load(f'{prefix}_MAX_LEN_128_{subset}_ripser.npy',
                     allow_pickle=True)[:, :, :limit, :]
    templ = np.load(f'{prefix}_MAX_LEN_128_{subset}_template.npy',
                    allow_pickle=True)[:, :, :, :limit]
    return old, ripser, templ


def _check_example_counts(split, n_labels, old, ripser, templ):
    """Compare each feature array's example count with the number of labelled rows.

    Raises:
        ValueError: if any array holds fewer examples than there are labels,
            because indexing the missing examples would fail.

    Logs a warning when an array holds more examples than labels, since the
    surplus examples are ignored.
    """
    counts = {"old": old.shape[3], "ripser": ripser.shape[2], "template": templ.shape[3]}
    short = {kind: c for kind, c in counts.items() if c < n_labels}
    if short:
        raise ValueError(
            f"{split}: {n_labels} labelled rows but feature arrays only hold {short} examples")
    if any(c > n_labels for c in counts.values()):
        logging.warning("%s: feature arrays %s hold more examples than the %d labelled rows; "
                        "the extra examples are ignored", split, counts, n_labels)


def _build_features(old, ripser, templ, n_examples):
    """Return one flat feature vector per example, concatenating the three feature groups."""
    return [
        np.concatenate((old[:, :, :, i, :].flatten(),
                        ripser[:, :, i, :].flatten(),
                        templ[:, :, :, i].flatten()))
        for i in range(n_examples)
    ]


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig):
    """Grid-search logistic-regression classifiers on precomputed feature files.

    Reads ``cfg.model.model_path``, ``cfg.model.model_name``, ``cfg.data.name``
    and ``cfg.data.input_dir``. It loads the train/test CSVs and the matching
    feature ``.npy`` files, then fits an L2 logistic regression for every
    ``(C, max_iter)`` pair. Finally it prints the best test accuracy and F1,
    or "Data is not labeled" when those metrics could not be computed.
    """
    model_path = cfg.model.model_path
    model_name = cfg.model.model_name
    config = AutoConfig.from_pretrained(model_path)
    num_layers = getattr(config, "num_hidden_layers", None)
    if num_layers is None:
        # The layer count is part of every feature file name.
        raise ValueError(f"Model config for {model_path!r} has no 'num_hidden_layers'")

    max_examples_to_train = 10**10  # effectively "use every example"
    solver = "lbfgs"
    is_dual = False
    data_name = _resolve_data_name(cfg.data.name)

    train_subset = f'{data_name}_train'
    test_subset = f'{data_name}_test'

    train_data = pd.read_csv(f'{cfg.data.input_dir}/{train_subset}.csv').reset_index(drop=True)
    test_data = pd.read_csv(f'{cfg.data.input_dir}/{test_subset}.csv').reset_index(drop=True)

    if 'sarcasm_v2' in cfg.data.input_dir:
        # String labels: "sarc" -> 1, anything else -> 0.
        train_data["labels"] = (train_data["labels"] == "sarc").astype(int)
        test_data["labels"] = (test_data["labels"] == "sarc").astype(int)

    train_feats = _load_features(model_name, train_subset, num_layers, max_examples_to_train)
    test_feats = _load_features(model_name, test_subset, num_layers, max_examples_to_train)
    train_data = train_data[:max_examples_to_train]

    _check_example_counts("train", len(train_data), *train_feats)
    _check_example_counts("test", len(test_data), *test_feats)

    X_train = _build_features(*train_feats, len(train_data))
    y_train = train_data["labels"]
    X_test = _build_features(*test_feats, len(test_data))
    y_test = test_data["labels"]

    C_range = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 2]
    max_iter_range = [1, 2, 3, 5, 10, 25, 50, 100, 500, 1000, 2000]

    # Test-set metrics per (C, max_iter).
    matt_scores, acc_scores, f1_scores = {}, {}, {}
    # Training-set metrics and raw test predictions, kept for inspection only.
    matt_scores_train, acc_scores_train, f1_scores_train = {}, {}, {}
    results = {}

    print(f'train features path : ./experiments/{cfg.model.model_name}/{train_subset}/features')
    print(f'test features path : ./experiments/{cfg.model.model_name}/{test_subset}/features')

    for C in tqdm(C_range, total=len(C_range), file=sys.stdout):
        for max_iter in max_iter_range:
            key = (C, max_iter)
            classifier = linear_model.LogisticRegression(penalty='l2', C=C, max_iter=max_iter,
                                                         dual=is_dual, solver=solver)
            result, train_matt, train_acc, train_f1 = pred_by_Xy(X_train, y_train, X_test, classifier)
            results[key] = result

            matt_scores_train[key] = train_matt
            acc_scores_train[key] = train_acc
            f1_scores_train[key] = train_f1

            try:
                matt_scores[key] = matthews_corrcoef(y_test, result)
                acc_scores[key] = accuracy_score(y_test, result)
                f1_scores[key] = f1_score(y_test, result)
            except (ValueError, TypeError):
                # Presumably raised by sklearn for unusable test labels (e.g.
                # missing or mismatched types); reported as unlabeled below.
                pass

    print("---")
    if not acc_scores or not f1_scores:
        print("Data is not labeled")
        return

    # NOTE(review): hyperparameters are picked by *test* score, and accuracy
    # and F1 are maximised independently, so these are optimistic best-of-grid
    # numbers rather than held-out estimates.
    best_acc_key = max(acc_scores, key=acc_scores.get)
    best_f1_key = max(f1_scores, key=f1_scores.get)
    print(f"{data_name}")
    print("Accuracy:", round(acc_scores[best_acc_key], 3))
    print("F1:", round(f1_scores[best_f1_key], 3))


if __name__ == "__main__":
    main()