import torch
from torch.utils.data import DataLoader

from minicons import cwe

from tqdm import tqdm
import os

import csv

import argparse

from approximation_model import NonLinearApproximator

from paths import common_path

# Presumably read by the tokenizer backend used through minicons — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Command-line options: encoder name, representation variant, batch size, device.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default='bert-base-uncased')
parser.add_argument("--approximation", type=str, default="original")
parser.add_argument("--batchsize", type=int, default=128)
parser.add_argument("--device", type=str, default='cpu')
args = parser.parse_args()

# Module-level aliases used by the rest of the script.
model, approximation_mode = args.model, args.approximation
device, batchsize = args.device, args.batchsize

# Contextual word-embedding extractor for the requested model on the requested device.
bert = cwe.CWE(model, device)

def load_wic(file = "train"):
    """Load one split of the WiC dataset from ``../data/WiC_dataset``.

    Reads ``{file}/{file}.data.txt``, whose rows are tab-separated
    ``word, pos, idx1-idx2, sentence1, sentence2`` fields, and, for every
    split except ``"test"``, the matching ``{file}/{file}.gold.txt`` file
    holding one label per row.

    Args:
        file: Split name used to build both file paths (the script uses
            ``"train"`` and ``"dev"``; ``"test"`` has no gold file).

    Returns:
        A list with one tuple per data row:
        ``([sentence1, idx1], [sentence2, idx2], pos, label)`` for labelled
        splits and ``([sentence1, idx1], [sentence2, idx2], pos)`` for
        ``"test"``. ``idx1`` and ``idx2`` are ints; ``label`` is the stripped
        line from the gold file.

    Raises:
        ValueError: If a data row does not have exactly five fields or its
            index field is not ``int-int``.
        IndexError: If the gold file has fewer lines than the data file.
    """
    has_labels = file != "test"

    # Context managers close the handles as soon as the lines are read
    # instead of leaving that to garbage collection.
    with open(f"../data/WiC_dataset/{file}/{file}.data.txt", "r") as data_file:
        rows = [line.strip().split("\t") for line in data_file]

    gold = []
    if has_labels:
        with open(f"../data/WiC_dataset/{file}/{file}.gold.txt", "r") as gold_file:
            gold = [line.strip() for line in gold_file]

    dataset = []
    for i, (word, pos, idx, sentence1, sentence2) in enumerate(rows):
        # The index field packs both target positions as "idx1-idx2".
        idx1, idx2 = (int(part) for part in idx.split('-'))

        context1 = [sentence1, idx1]
        context2 = [sentence2, idx2]

        if has_labels:
            dataset.append((context1, context2, pos, gold[i]))
        else:
            dataset.append((context1, context2, pos))

    return dataset

# Training split and its batched loader.
train = load_wic('train')
train_dl = DataLoader(train, batch_size=batchsize, num_workers=4)

# Development split, used to measure validation accuracy.
validation = load_wic('dev')
val_dl = DataLoader(validation, batch_size=batchsize, num_workers=4)

# Gold labels are the strings 'T'/'F'; map them to 1/0.
label2id = dict(T=1, F=0)

def threshold_classifier(x1, x2, threshold):
    """Label each row pair 1 when its cosine similarity reaches ``threshold``.

    Args:
        x1: Tensor of representations, one row per example.
        x2: Tensor of representations paired row-wise with ``x1``.
        threshold: Minimum cosine similarity for predicting 1.

    Returns:
        A list of ints (1 or 0), one per row pair.
    """
    similarities = torch.cosine_similarity(x1, x2).tolist()
    return [int(score >= threshold) for score in similarities]

def accuracy(y_hat, y):
    """Return the fraction of positions where ``y_hat`` equals ``y``.

    Args:
        y_hat: Sequence of predicted labels.
        y: Sequence of reference labels, aligned with ``y_hat``.

    Returns:
        The mean of the element-wise matches as a Python float.
    """
    predictions = torch.tensor(y_hat)
    references = torch.tensor(y)
    matches = torch.eq(predictions, references)
    return matches.float().mean().item()

def estimate(dl, layer, threshold, approximation_mode = 'original'):
    """Classify every pair served by ``dl`` with a cosine-similarity threshold.

    For each batch, representations of the two contexts are extracted from
    the module-level ``bert`` at ``layer``, optionally passed through a
    ``NonLinearApproximator`` checkpoint, and compared with
    ``threshold_classifier``.

    Args:
        dl: Iterable of batches that unpack as
            ``(context1, context2, pos, labels)``. Each context is a
            ``[sentences, indices]`` pair whose index elements support
            ``.item()``; ``labels`` holds keys of ``label2id``.
        layer: Layer passed to ``bert.extract_representation``; it also
            selects the checkpoint directory.
        threshold: Minimum cosine similarity for predicting 1.
        approximation_mode: ``'laser'`` or ``'ser'`` loads the matching
            checkpoint and applies it to both representations. Any other
            value (e.g. ``'original'``) uses the representations as extracted.

    Returns:
        ``(predicted, target)``: two parallel lists of 0/1 ints.
    """
    # Checkpoint file name per approximation mode; modes not listed here
    # skip the approximator entirely.
    checkpoint_files = {
        'laser': 'version_laser_2048_2_0-0001.ckpt',
        'ser': 'version_ser_2048_2_0-0001-v1.ckpt',
    }

    approximator = None
    if approximation_mode in checkpoint_files:
        checkpoint = f'{common_path}/makesense_logs/bert/{layer}/{checkpoint_files[approximation_mode]}'
        approximator = NonLinearApproximator.load_from_checkpoint(checkpoint)
        approximator.eval()

    target = []
    predicted = []

    # Inference only: eval() does not turn off gradient tracking, so without
    # no_grad the approximator would record an autograd graph on every batch.
    with torch.no_grad():
        for context1, context2, _, labels in dl:
            # Each query is (sentence, [idx, idx + 1]); presumably the span
            # selects the single target word — confirm against minicons.
            queries1 = [(sentence, [idx.item(), idx.item() + 1]) for sentence, idx in zip(*context1)]
            queries2 = [(sentence, [idx.item(), idx.item() + 1]) for sentence, idx in zip(*context2)]

            c1 = bert.extract_representation(queries1, layer)
            c2 = bert.extract_representation(queries2, layer)

            if approximator is not None:
                c1 = approximator(c1)
                c2 = approximator(c2)

            predicted.extend(threshold_classifier(c1, c2, threshold))
            target.extend(label2id[label] for label in labels)

    return predicted, target

# Candidate cosine thresholds: 50 evenly spaced values from 0 to 1 inclusive.
threshold_space = torch.linspace(0.00, 1.00, steps = 50).tolist()

# Create the output directory before the sweep so a bad path fails right away
# instead of after every layer has been evaluated.
results_dir = "../results/threshold_results"
os.makedirs(results_dir, exist_ok=True)

# One (layer, best_threshold, train_accuracy, validation_accuracy) tuple per layer.
layer_stats = []

# NOTE(review): estimate() re-extracts representations for the whole split on
# every call, so this sweep repeats that work once per threshold. Caching the
# per-layer cosine similarities and thresholding them afterwards would avoid it.
for layer in range(bert.layers+1):
    best = 0
    best_threshold = 0
    print(f"Computing stats on WiC training set for layer {layer}:")
    for threshold in tqdm(threshold_space):
        train_predicted, train_target = estimate(train_dl, layer = layer, threshold = threshold, approximation_mode=approximation_mode)
        train_acc = accuracy(train_predicted, train_target)
        # Strict comparison keeps the lowest threshold among ties.
        if train_acc > best:
            best = train_acc
            best_threshold = threshold

    # Measure how the threshold tuned on the training set does on validation.
    val_predicted, val_target = estimate(val_dl, layer = layer, threshold = best_threshold, approximation_mode=approximation_mode)

    val_acc = accuracy(val_predicted, val_target)

    print(f"Train Accuracy: {best:.2f} Best Threshold: {best_threshold:.2f} Validation Accuracy: {val_acc:.2f}")

    layer_stats.append((layer, best_threshold, best, val_acc))

print("Saving Results..")

# newline="" lets the csv module control line endings; without it, platforms
# that translate "\n" end up with a blank line after every row.
with open(f"{results_dir}/{approximation_mode}_threshold.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(['model', 'class', 'layer', 'train_threshold', 'train', 'validation'])
    for layer, threshold, train_acc, validation_acc in layer_stats:
        writer.writerow([model, approximation_mode, layer, round(threshold, 4), round(train_acc, 4), round(validation_acc, 4)])