import json
import os
import csv
import argparse
from tqdm import tqdm, trange

from minicons import cwe
from torch import cosine_similarity
from torch.utils.data import DataLoader
from approximation_model import NonLinearApproximator

from paths import auth1_path

# Command-line configuration for the run. All options are optional.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default='bert-base-uncased')
parser.add_argument("--approximation_mode", type=str, default="F")
parser.add_argument("--batchsize", type=int, default=70)
parser.add_argument("--device", type=str, default='cpu')
args = parser.parse_args()

# Plain module-level aliases used by the rest of the script.
approximation_mode = args.approximation_mode
device = args.device
batch_size = args.batchsize
model_name = args.model

# Filesystem-friendly variant of the model name: "-" and "/" both become "_"
# (str.replace is a no-op when the character is absent).
directory_name = model_name.replace("-", "_").replace("/", "_")

def load_approximator(layer = 0, approximation_mode = 'laser'):
    """Load the frozen approximator checkpoint for one layer.

    The checkpoint is read from
    ``{auth1_path}/makesense_logs/bert/{layer}/<file>``, where ``<file>``
    depends on ``approximation_mode``. The loaded module is switched to
    eval mode and every parameter has ``requires_grad`` turned off.

    Args:
        layer: Layer number; used as a directory component of the
            checkpoint path.
        approximation_mode: Either ``"laser"`` or ``"ser"``.

    Returns:
        The object returned by ``NonLinearApproximator.load_from_checkpoint``.

    Raises:
        ValueError: If ``approximation_mode`` is not a supported mode.
            (Previously this surfaced as an ``UnboundLocalError`` at the
            ``return`` statement.)
    """
    # NOTE(review): the path is hard-coded to the "bert" log directory
    # regardless of which model the script was started with -- confirm.
    checkpoint_files = {
        "laser": "version_laser_2048_2_0-0001.ckpt",
        "ser": "version_ser_2048_2_0-0001-v1.ckpt",
    }
    if approximation_mode not in checkpoint_files:
        raise ValueError(
            f"Unknown approximation_mode {approximation_mode!r}; "
            f"expected one of {sorted(checkpoint_files)}"
        )

    checkpoint = f'{auth1_path}/makesense_logs/bert/{layer}/{checkpoint_files[approximation_mode]}'
    approximator = NonLinearApproximator.load_from_checkpoint(checkpoint)
    approximator.eval()

    # The approximator is only used for inference here, so freeze it.
    for param in approximator.parameters():
        param.requires_grad = False

    return approximator

# Ensure the per-model results directory exists. exist_ok=True replaces the
# separate exists() check, which could race with another process creating
# the same directory between the check and the makedirs call.
os.makedirs(f"../data/results/{directory_name}", exist_ok=True)

def load_wic(file = "train"):
    """Read one split of the WiC dataset from ``../data/WiC_dataset``.

    Each line of ``{file}.data.txt`` holds five tab-separated fields:
    word, part of speech, an ``"i-j"`` index pair, and two sentences.
    For every split except ``"test"`` the matching ``{file}.gold.txt``
    supplies one label per line, aligned with the data file by position.

    Args:
        file: Split name; used for both the sub-directory and the file stem.

    Returns:
        A list with one tuple per data line:
        ``([sentence1, i], [sentence2, j], pos, label)``, or
        ``([sentence1, i], [sentence2, j], pos)`` for the ``"test"`` split.
        ``i`` and ``j`` are ints parsed from the index field (presumably the
        target word's position in each sentence -- verify against the
        dataset README).
    """
    stem = f"../data/WiC_dataset/{file}/{file}"
    has_gold = file != "test"

    # Context managers close the files deterministically; the encoding is
    # pinned so the result does not depend on the platform locale.
    with open(f"{stem}.data.txt", "r", encoding="utf-8") as data_file:
        rows = [line.strip().split("\t") for line in data_file]

    gold = []
    if has_gold:
        with open(f"{stem}.gold.txt", "r", encoding="utf-8") as gold_file:
            gold = [line.strip() for line in gold_file]

    dataset = []
    for i, (_word, pos, idx, sentence1, sentence2) in enumerate(rows):
        idx1, idx2 = (int(part) for part in idx.split('-'))

        context1 = [sentence1, idx1]
        context2 = [sentence2, idx2]

        if has_gold:
            dataset.append((context1, context2, pos, gold[i]))
        else:
            dataset.append((context1, context2, pos))

    return dataset

# Evaluate on the labelled splits only (train + dev); both carry gold labels.
train = load_wic("train")
val = load_wic("dev")

supervised_wic_dl = DataLoader(train + val, batch_size = batch_size)

model = cwe.CWE(model_name, device)

# The per-layer CSVs are written here; create the directory up front so the
# open() below cannot fail on a missing path.
output_dir = f"../results/delta/{directory_name}"
os.makedirs(output_dir, exist_ok=True)

use_approximator = approximation_mode == "laser" or approximation_mode == "ser"

for l in range(model.layers+1):
    results = []
    print(f"Computing similarities in layer {l}:")

    # The checkpoint depends only on the layer, so load it once per layer
    # instead of re-reading it from disk for every batch.
    approximator = None
    if use_approximator:
        approximator = load_approximator(l, approximation_mode=approximation_mode)

    for batch in tqdm(supervised_wic_dl):
        context1, context2, pos, labels = batch
        # After collation each context is a [sentences, indices] pair;
        # regroup it into one (sentence, index) pair per example.
        context1, context2 = [list(zip(*x)) for x in [context1, context2]]
        # Measure the batch size only after regrouping: before it, len() of
        # a context is 2 (the pair), which truncated the zip below to two
        # rows per batch.
        current_batch_size = len(context1)
        # Turn each index i into the span [i, i + 1].
        context1 = [(c, [i.item(), i.item()+1]) for c, i in context1]
        context2 = [(c, [i.item(), i.item()+1]) for c, i in context2]

        c1 = model.extract_representation(context1, l)
        c2 = model.extract_representation(context2, l)

        if approximator is not None:
            c1 = approximator(c1)
            c2 = approximator(c2)

        layer_sim = cosine_similarity(c1, c2).tolist()
        results.extend(list(zip([l]*current_batch_size, labels, pos, layer_sim)))

    # newline="" lets the csv module control line endings itself.
    with open(f"{output_dir}/layer_{l}_{approximation_mode}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["layer", "label", "pos", "cosine"])
        writer.writerows(results)
