import argparse
import csv
import functools
import os
from collections import defaultdict

import numpy as np
import torch
from minicons import cwe
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, f1_score, matthews_corrcoef
from sklearn.model_selection import GridSearchCV, PredefinedSplit
from sklearn.neural_network import MLPClassifier
from torch.utils.data import DataLoader
from tqdm import tqdm

from approximation_model import NonLinearApproximator
from paths import auth1_path
from whic_utils import load_whic, pairwise_context, pairwise_direction

# Disable HuggingFace tokenizers' internal parallelism; whic_set uses
# multi-worker DataLoaders (num_workers=8), which fork worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

parser = argparse.ArgumentParser(
    description="Probe each BERT layer on WHiC with MLP classifiers."
)
parser.add_argument(
    "--approximation",
    default="original",
    type=str,
    # Restrict to the modes whic_set knows about so a typo cannot silently
    # fall back to raw features while the results file carries the typo name.
    choices=["original", "laser", "ser"],
    help="'original' probes raw BERT features; 'laser'/'ser' first pass them "
         "through the matching per-layer NonLinearApproximator checkpoint.",
)
parser.add_argument("--batchsize", default=128, type=int,
                    help="DataLoader batch size used for feature extraction.")
# The encoder used to be hard-coded to 'cuda:1' (and this flag ignored); the
# default keeps that behaviour for runs that do not pass --device.
parser.add_argument("--device", default="cuda:1", type=str,
                    help="Torch device string for the BERT encoder.")
parser.add_argument("--hidden_size", default=100, type=int,
                    help="Width of the MLP probe's single hidden layer.")
parser.add_argument("--penalty", default=0.001, type=float,
                    help="L2 penalty (MLPClassifier alpha) of the probe.")
args = parser.parse_args()

approximation_mode = args.approximation
device = args.device
batchsize = args.batchsize
hs = args.hidden_size
penalty = args.penalty

bert = cwe.CWE('bert-base-uncased', device)

# Per-layer checkpoint file names (under makesense_logs/bert/<layer>/) of the
# approximators that can be applied on top of the encoder representations.
_APPROXIMATOR_CHECKPOINTS = {
    'laser': 'version_laser_2048_2_0-0001.ckpt',
    'ser': 'version_ser_2048_2_0-0001-v1.ckpt',
}


def whic_set(dataset, layer=0, approximation_mode=approximation_mode, model=bert):
    """Encode a WHiC-style pair dataset into probe features.

    Each dataset item is unpacked as ``(context1, context2, label)``. Each
    context is a ``(text, index)`` pair, which is turned into
    ``(text, [index, index + 1])`` and passed to
    ``model.extract_representation`` at ``layer``. The meaning of that span
    is defined by minicons. The two representations are optionally passed
    through the approximator checkpoint for ``approximation_mode``, then
    concatenated along dim 1.

    Args:
        dataset: dataset of ``(context1, context2, label)`` items.
        layer: encoder layer to extract.
        approximation_mode: ``'original'`` (raw encoder features), or one of
            ``'laser'`` / ``'ser'`` to apply the per-layer approximator.
        model: minicons ``cwe.CWE`` encoder.

    Returns:
        ``(vectors, labels)`` as numpy arrays, one row / entry per item.

    Raises:
        ValueError: if ``approximation_mode`` is not recognised.
    """
    if approximation_mode == 'original':
        approximator = None
    elif approximation_mode in _APPROXIMATOR_CHECKPOINTS:
        checkpoint = (
            f'{auth1_path}/makesense_logs/bert/{layer}/'
            f'{_APPROXIMATOR_CHECKPOINTS[approximation_mode]}'
        )
        approximator = NonLinearApproximator.load_from_checkpoint(checkpoint)
        approximator.eval()
    else:
        # Previously unknown modes silently produced raw features.
        raise ValueError(
            f"Unknown approximation_mode {approximation_mode!r}; expected "
            f"'original' or one of {sorted(_APPROXIMATOR_CHECKPOINTS)}"
        )

    data_dl = DataLoader(dataset, num_workers=8, batch_size=batchsize)

    batch_vectors = []
    labels = []
    for context1, context2, label in data_dl:
        # Default collation transposes each batch of (text, index) pairs into
        # [texts, index_tensor]; zip them back into per-item pairs.
        context1, context2 = [
            [(text, [idx.item(), idx.item() + 1]) for text, idx in zip(*ctx)]
            for ctx in (context1, context2)
        ]

        c1 = model.extract_representation(context1, layer)
        c2 = model.extract_representation(context2, layer)

        if approximator is not None:
            # Inference only: don't build an autograd graph for the approximator.
            with torch.no_grad():
                c1 = approximator(c1)
                c2 = approximator(c2)

        labels.extend(int(x) for x in label)
        batch_vectors.append(torch.cat((c1, c2), 1))

    vectors = torch.cat(batch_vectors).numpy()
    labels = torch.tensor(labels).numpy()

    return vectors, labels

def load_train_splits(layer):
    """Build the probe training data for one encoder layer.

    Features for the WHiC ``train`` and ``dev`` splits are stacked into one
    matrix. A ``PredefinedSplit`` is built alongside it: train rows get
    ``-1`` (never used for validation) and dev rows get fold ``0``, so it
    can be passed as ``cv=`` to a scikit-learn search.

    Args:
        layer: encoder layer passed through to ``whic_set``.

    Returns:
        ``(X, y, pds)``: stacked train+dev features, labels, and the split.
    """
    train_X, train_y = whic_set(load_whic('train'), layer=layer)
    val_X, val_y = whic_set(load_whic('dev'), layer=layer)
    # The test split is not returned, so it is not encoded here (doing so
    # cost a full, discarded forward pass over the test set per layer).

    X = np.concatenate((train_X, val_X), axis=0)
    y = np.concatenate((train_y, val_y), axis=0)

    cv_vals = [-1] * len(train_X) + [0] * len(val_X)
    pds = PredefinedSplit(test_fold=cv_vals)

    return X, y, pds

@functools.lru_cache(maxsize=1)
def _directional_features(layer):
    """Return ``(first_X, second_X)`` features of the directional test pairs.

    The features depend only on ``layer``, not on the probe. They are cached
    for the most recent layer so the several probes evaluated for that layer
    share one encoding pass.
    """
    test_directional = pairwise_direction('test')
    first_X, _ = whic_set(test_directional[0], layer=layer)
    second_X, _ = whic_set(test_directional[1], layer=layer)
    return first_X, second_X


def directional_accuracy(clf, layer):
    """Pairwise directional accuracy of ``clf`` on the WHiC test set.

    ``pairwise_direction('test')`` yields two datasets, assumed aligned by
    position. A position counts as correct only if ``clf`` predicts 1 on the
    first dataset's item AND 0 on the second dataset's item.

    Args:
        clf: fitted classifier with a ``predict`` method.
        layer: encoder layer the features are extracted from.

    Returns:
        Fraction of correctly ordered pairs.
    """
    first_X, second_X = _directional_features(layer)
    correct = (clf.predict(first_X) == 1) & (clf.predict(second_X) == 0)
    return correct.mean()

@functools.lru_cache(maxsize=1)
def _contextual_features(layer):
    """Return ``(first_X, second_X, word_rows)`` for the contextual test pairs.

    ``word_rows`` is the third element of ``pairwise_context('test')``. It maps
    each word to the row indices belonging to it. Cached for the most recent
    layer because the features do not depend on the probe.
    """
    test_context = pairwise_context('test')
    first_X, _ = whic_set(test_context[0], layer=layer)
    second_X, _ = whic_set(test_context[1], layer=layer)
    return first_X, second_X, test_context[2]


def contextual_accuracy(clf, layer):
    """Word-averaged pairwise contextual accuracy of ``clf`` on WHiC test.

    For each word, a row counts as correct only if ``clf`` predicts 1 on the
    first dataset's row AND 0 on the second dataset's row. Each word's
    accuracy is computed over its own rows, and the result is the unweighted
    mean over words.

    Args:
        clf: fitted classifier with a ``predict`` method.
        layer: encoder layer the features are extracted from.

    Returns:
        Mean per-word pairwise accuracy.
    """
    first_X, second_X, word_rows = _contextual_features(layer)

    # Predictions are per-row, so predict each set once and slice per word.
    first_pred = clf.predict(first_X)
    second_pred = clf.predict(second_X)

    per_word = [
        ((first_pred[idx] == 1) & (second_pred[idx] == 0)).mean()
        for idx in word_rows.values()
    ]
    return np.mean(per_word)

seeds = [100, 42, 1234, 211, 2511]
results = []

# Create the output directory before probing, so the final write cannot fail
# on a missing directory after every layer has already been processed.
results_dir = "../results/whic_probing_results"
os.makedirs(results_dir, exist_ok=True)
results_path = (
    f"{results_dir}/{approximation_mode}_probing_results_"
    f"{hs}_{str(penalty).replace('.', '-')}.csv"
)

for layer in range(bert.layers + 1):

    print(f'Probing layer {layer}')

    # The PredefinedSplit is only needed for hyperparameter search, which is
    # not run here: hidden size and L2 penalty come from the command line.
    X, y, _ = load_train_splits(layer)

    test_X, test_y = whic_set(load_whic('test'), layer=layer)

    seed_f1 = []
    seed_directional = []
    seed_context = []

    for seed in tqdm(seeds):
        clf = MLPClassifier(
            random_state=seed,
            max_iter=200,
            n_iter_no_change=5,
            learning_rate_init=0.001,
            early_stopping=True,
            hidden_layer_sizes=(hs,),
            alpha=penalty,
        )

        # Trains on train + dev; early_stopping holds out its own random
        # validation fraction from X.
        clf.fit(X, y)

        seed_f1.append(f1_score(test_y, clf.predict(test_X), average="weighted"))
        seed_directional.append(directional_accuracy(clf, layer))
        seed_context.append(contextual_accuracy(clf, layer))

    # Mean and population standard deviation (ddof=0) across seeds.
    f1_mean, f1_sd = np.mean(seed_f1), np.std(seed_f1)
    directional_mean, directional_sd = np.mean(seed_directional), np.std(seed_directional)
    context_mean, context_sd = np.mean(seed_context), np.std(seed_context)

    print(f'Layer {layer} F1: {f1_mean} Directional: {directional_mean} Contextual: {context_mean}')

    results.append((approximation_mode, layer, f1_mean, f1_sd, directional_mean, directional_sd, context_mean, context_sd))

# newline="" is required by the csv module to avoid blank rows on Windows.
with open(results_path, "w", newline="") as f:
    writer = csv.writer(f)
    # NOTE: the 'accuracy' / 'accuracy_sd' columns hold the weighted F1 score;
    # the header name is kept for compatibility with existing result files.
    writer.writerow(['class', 'layer', 'accuracy', 'accuracy_sd', 'directional', 'directional_sd', 'context', 'context_sd'])
    writer.writerows(results)