import argparse
from datetime import datetime

import torch

import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from torch.utils.data import DataLoader, random_split

from operator import itemgetter

from paths import auth1_path

# from wic_model import WiCModel
from wic_sigmoid import WiCModel

# Command-line interface: experiment-level options are registered first, then
# the model class is given the parser to add its own options to.
parser = argparse.ArgumentParser(description="Word in context experiment arguments.")

# Each entry is (flag, type, default, help); registered in this order.
_CLI_OPTIONS = (
    ("--model", str, "bert", "Model name."),
    ("--batchsize", int, 32, "Batch size during training, validation, and testing."),
    ("--savedir", str, "WiCModel", "Checkpointing directory"),
)
for _flag, _kind, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

# NOTE(review): attributes read later in this script (layer, lr, approximator,
# hidden_size, hidden_layers) are not declared above, so presumably they are
# registered here by the model class — confirm in wic_sigmoid.
parser = WiCModel.add_model_specific_args(parser)
args = parser.parse_args()

# The classifier is built from the fully parsed namespace before any data loads.
wic_classifier = WiCModel(args)

def load_wic(file="train", data_dir="../data/WiC_dataset"):
    """Load one split of the WiC (Word-in-Context) dataset.

    Reads ``{data_dir}/{file}/{file}.data.txt``. Each line holds five
    tab-separated fields: target word, part of speech, an ``"idx1-idx2"``
    index pair, sentence 1 and sentence 2. For every split except ``"test"``
    it also reads ``{file}.gold.txt``, which holds one label per line and is
    aligned with the data file.

    Args:
        file: Split name. It is also the sub-directory and the file-name
            prefix (e.g. ``"train"``, ``"dev"``, ``"test"``).
        data_dir: Root directory of the dataset. The default is the path this
            script has always used.

    Returns:
        A list with one tuple per example.
        - Labelled splits yield ``(context1, context2, pos, label)``.
        - The ``"test"`` split yields ``(context1, context2, pos)``.
        Each context is ``[sentence, index]``, where ``index`` is an ``int``.
        ``label`` is the stripped line from the gold file, left as a string.

    Raises:
        ValueError: if a data line does not split into five fields, if the
            index pair is not two ``-``-separated integers, or if the gold
            file's line count differs from the data file's.
    """
    labelled = file != "test"

    # Context managers close both files even if parsing fails part-way.
    with open(f"{data_dir}/{file}/{file}.data.txt", "r", encoding="utf-8") as fh:
        rows = [line.strip().split("\t") for line in fh]

    gold = []
    if labelled:
        with open(f"{data_dir}/{file}/{file}.gold.txt", "r", encoding="utf-8") as fh:
            gold = [line.strip() for line in fh]
        # Labels are matched to examples by position, so a length mismatch
        # means the two files are out of sync; fail loudly instead of
        # mislabelling or silently dropping examples.
        if len(gold) != len(rows):
            raise ValueError(
                f"WiC split {file!r}: {len(rows)} data lines but "
                f"{len(gold)} gold labels"
            )

    dataset = []
    for i, (_word, pos, idx, sentence1, sentence2) in enumerate(rows):
        idx1, idx2 = (int(part) for part in idx.split("-"))
        context1 = [sentence1, idx1]
        context2 = [sentence2, idx2]
        if labelled:
            dataset.append((context1, context2, pos, gold[i]))
        else:
            dataset.append((context1, context2, pos))

    return dataset

# Labelled splits: "train" for fitting, "dev" for validation / model selection.
train, val = load_wic('train'), load_wic('dev')

# Both loaders share batch size and worker count; only training reshuffles.
_loader_kwargs = {"batch_size": args.batchsize, "num_workers": 4}
train_dl = DataLoader(train, shuffle=True, **_loader_kwargs)
val_dl = DataLoader(val, **_loader_kwargs)

# Learning-rate tag shared by the checkpoint file name and the logger version.
# Dots are replaced so the value reads cleanly inside file/directory names.
lr_tag = str(args.lr).replace('.', '-')

# Keep only the single best checkpoint (lowest 'val_loss'), saving weights only.
# `period` is intentionally omitted: 1 is already the default, and the argument
# was deprecated in Lightning 1.3 and removed in 1.5.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{auth1_path}/makesense_logs/wic/{args.model}/{args.layer}/",
    filename=f"wic_model_{args.approximator}_{args.hidden_size}_{args.hidden_layers}_{lr_tag}",
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    save_weights_only=True,
)

# Stop after 3 consecutive epochs without a 'val_loss' improvement.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min',
)

logger = TensorBoardLogger(
    "wiclogs/",
    name=args.savedir,
    version=f"l{args.layer}_{lr_tag}_{args.hidden_size}x{args.hidden_layers}",
)

# The ModelCheckpoint goes in `callbacks`. Passing an instance through
# `Trainer(checkpoint_callback=...)` was deprecated in Lightning 1.1 and
# later dropped.
trainer = pl.Trainer(
    precision=16,
    # NOTE(review): older Lightning releases parsed the string '1' as "GPU
    # index 1", while v1.5+ parse it as "one GPU". Confirm which is intended
    # before upgrading.
    gpus='1',
    gradient_clip_val=1.0,
    max_epochs=20,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
)

# Dataloaders are passed positionally; their keyword names differ across
# Lightning versions.
trainer.fit(wic_classifier, train_dl, val_dl)
