from minicons import cwe

from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

import csv

from torch.utils.data import DataLoader
import torch

from sklearn.metrics import classification_report, accuracy_score

import numpy as np

import os

from approximation_model import NonLinearApproximator

import argparse

from tqdm import tqdm

from paths import auth1_path

# Set TOKENIZERS_PARALLELISM explicitly before any tokenization happens
# (presumably read by the HuggingFace tokenizers backend -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Command-line configuration of the probing run.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default = 'bert-base-uncased', type = str)
parser.add_argument("--approximation", default = "original", type = str)
parser.add_argument("--batchsize", default = 128, type = int)
parser.add_argument("--device", default = 'cpu', type = str)
parser.add_argument("--hidden_size", default = 100, type = int)
args = parser.parse_args()

# Module-level aliases used by the rest of the script.
model_name = args.model
approximation_mode = args.approximation
device = args.device
batchsize = args.batchsize
hs = args.hidden_size

# Place the model on the device requested with --device. The flag used to be
# parsed but ignored (the device was hard-coded to 'cuda:1'); pass
# `--device cuda:1` to reproduce that placement.
model = cwe.CWE(model_name, device)

def load_wic(file = "train"):
    """Load one split of the WiC dataset from ``../data/WiC_dataset``.

    Reads ``{file}/{file}.data.txt``, where every line holds five
    tab-separated fields: target word, part of speech, an ``"i-j"`` index
    pair, sentence 1 and sentence 2. For every split except ``"test"`` the
    parallel file ``{file}/{file}.gold.txt`` (one label per line) is read too.

    Args:
        file: Split name; used both as directory name and as file prefix.

    Returns:
        A list with one tuple per data line, in file order:
        ``([sentence1, idx1], [sentence2, idx2], pos, label)`` for labelled
        splits, or ``([sentence1, idx1], [sentence2, idx2], pos)`` for
        ``"test"``. ``idx1``/``idx2`` are ints, ``label`` is the raw gold
        string.

    Raises:
        ValueError: If a data line does not have exactly five fields or its
            index field is not of the form ``"<int>-<int>"``.
        IndexError: If the gold file has fewer lines than the data file.
    """
    prefix = f"../data/WiC_dataset/{file}/{file}"
    has_gold = file != "test"

    # Context managers close the handles deterministically (the files were
    # previously left open); the encoding is pinned so reading does not depend
    # on the platform's locale.
    with open(f"{prefix}.data.txt", "r", encoding = "utf-8") as handle:
        rows = [line.strip().split("\t") for line in handle]

    gold = []
    if has_gold:
        with open(f"{prefix}.gold.txt", "r", encoding = "utf-8") as handle:
            gold = [line.strip() for line in handle]

    dataset = []
    for i, (_word, pos, idx, sentence1, sentence2) in enumerate(rows):
        # "i-j": position of the target word in sentence 1 and in sentence 2.
        idx1, idx2 = (int(part) for part in idx.split('-'))

        context1 = [sentence1, idx1]
        context2 = [sentence2, idx2]

        if has_gold:
            dataset.append((context1, context2, pos, gold[i]))
        else:
            dataset.append((context1, context2, pos))

    return dataset

# Maps the gold label strings ('T' / 'F') to the integer class ids used as
# probe targets.
label2id = dict(T = 1, F = 0)

def wic_set(dataset = 'train', layer = 0, approximation_mode = approximation_mode, model = model, batch_size = batchsize):
    """Build probe inputs for one WiC split at one model layer.

    For every example the representations of the target word in its two
    contexts are extracted with ``model.extract_representation`` and
    concatenated along dimension 1. When ``approximation_mode`` is ``'laser'``
    or ``'ser'`` each representation is first passed through a
    ``NonLinearApproximator`` restored from the checkpoint for ``layer``; any
    other mode uses the representations unchanged.

    Args:
        dataset: WiC split name passed to ``load_wic``; ``'test'`` has no labels.
        layer: Layer index passed to ``model.extract_representation`` and used
            to select the approximator checkpoint directory.
        approximation_mode: ``'laser'``, ``'ser'``, or anything else for no
            approximation.
        model: Object exposing ``extract_representation(contexts, layer)``.
        batch_size: Number of examples per extraction batch (defaults to the
            ``--batchsize`` command-line value).

    Returns:
        ``(vectors, labels)`` as NumPy arrays for labelled splits, or only
        ``vectors`` for ``'test'``. ``vectors`` has one row per example
        (assumes ``extract_representation`` returns 2-D CPU tensors --
        TODO confirm).
    """
    data = load_wic(dataset)
    # Use the configurable batch size; it was previously hard-coded to 128,
    # which made --batchsize a no-op.
    data_dl = DataLoader(data, num_workers = 4, batch_size = batch_size)

    has_labels = dataset != 'test'

    # Checkpoint file per approximation mode. os.path.join makes the path
    # correct whether or not auth1_path ends with a separator (the two modes
    # previously disagreed on this).
    # NOTE(review): the 'bert' directory is fixed regardless of which model is
    # probed -- confirm this is intended.
    checkpoint_files = {
        'laser': 'version_laser_2048_2_0-0001.ckpt',
        'ser': 'version_ser_2048_2_0-0001-v1.ckpt',
    }
    approximator = None
    if approximation_mode in checkpoint_files:
        checkpoint = os.path.join(
            auth1_path, 'makesense_logs', 'bert', str(layer),
            checkpoint_files[approximation_mode],
        )
        approximator = NonLinearApproximator.load_from_checkpoint(checkpoint)
        approximator.eval()

    batch_vectors = []
    labels = []

    for batch in data_dl:
        if has_labels:
            context1, context2, pos, label = batch
            labels.extend(label2id[x] for x in label)
        else:
            context1, context2, pos = batch

        # The DataLoader collates the [sentence, idx] pairs column-wise; zip
        # them back into per-example pairs and turn each index into the
        # [idx, idx + 1] span handed to extract_representation.
        context1 = [(c, [i.item(), i.item() + 1]) for c, i in zip(*context1)]
        context2 = [(c, [i.item(), i.item() + 1]) for c, i in zip(*context2)]

        c1 = model.extract_representation(context1, layer)
        c2 = model.extract_representation(context2, layer)

        if approximator is not None:
            # Inference only: skip autograd bookkeeping instead of building a
            # graph and detaching from it afterwards.
            with torch.no_grad():
                c1 = approximator(c1)
                c2 = approximator(c2)

        batch_vectors.append(torch.cat((c1, c2), 1))

    # Concatenating the per-batch matrices along dim 0 gives one row per
    # example, in dataset order.
    vectors = torch.cat(batch_vectors).numpy()

    if has_labels:
        return vectors, torch.tensor(labels).numpy()

    return vectors

# Random seeds for the probe; accuracy is averaged over them for each layer.
seeds = [100, 42, 1234, 211, 2511]
results = []

# Create the output directory before the (long) probing loop so a missing
# directory cannot cause the results to be lost at the very end.
results_dir = "../results/wic_probing_results"
os.makedirs(results_dir, exist_ok = True)

# Probe indices 0..model.layers inclusive (index 0 is presumably the
# embedding layer -- confirm against the model wrapper).
for layer in range(model.layers+1):

    print(f'Probing layer {layer}')

    train_X, train_y = wic_set('train', layer, approximation_mode=approximation_mode)
    val_X, val_y = wic_set('dev', layer, approximation_mode=approximation_mode)

    seed_accuracy = []
    for seed in tqdm(seeds):
        # One-hidden-layer MLP probe; only the random seed varies between runs.
        mlp_probe = MLPClassifier(
            hidden_layer_sizes = (hs, ),
            random_state=seed, 
            max_iter=150, 
            learning_rate_init=0.001,
            n_iter_no_change=5
        )

        mlp_probe.fit(train_X, train_y)

        accuracy = accuracy_score(val_y, mlp_probe.predict(val_X))
        seed_accuracy.append(accuracy)

    # Mean and (population) standard deviation of dev accuracy across seeds.
    seed_mean = np.mean(seed_accuracy)
    seed_sd = np.std(seed_accuracy)

    print(f'Layer {layer} Accuracy: {seed_mean} +- {seed_sd}')

    results.append((model_name, approximation_mode, layer, seed_mean, seed_sd))

# newline="" lets the csv module control line endings, as its documentation
# requires; without it some platforms emit blank rows between records.
with open(f"{results_dir}/{approximation_mode}_{hs}_probing_results.csv", "w", newline = "") as f:
    writer = csv.writer(f)
    writer.writerow(['model', 'class', 'layer', 'accuracy', 'accuracy_sd'])
    writer.writerows(results)