import argparse
from datetime import datetime

import torch

import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from torch.utils.data import DataLoader, random_split
from paths import auth1_path, auth2_path
# from sklearn.model_selection import train_test_split

from operator import itemgetter

from approximation_model import NonLinearApproximator

# Command-line interface for the approximation experiment.
parser = argparse.ArgumentParser(description="Approximation experiment arguments.")

# Experiment-level options as (flag, type, default, help) rows; they are
# registered in this order so the generated --help output stays stable.
_EXPERIMENT_OPTIONS = (
    ("--model", str, "bert", "Model name."),
    (
        "--layer",
        int,
        0,
        "Layer in the model whose embeddings are going to be approximated.",
    ),
    (
        "--mode",
        str,
        "laser",
        "Choose which retrofitting the model will be optimized to approximate (laser vs ser)",
    ),
    (
        "--batchsize",
        int,
        128,
        "Batch size during training, validation, and testing.",
    ),
    ("--savedir", str, "NonLinearApproximator", "Checkpointing directory"),
)
for _flag, _type, _default, _help in _EXPERIMENT_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# The model class contributes its own hyper-parameter flags to the same parser.
parser = NonLinearApproximator.add_model_specific_args(parser)

args = parser.parse_args()

# The model is configured directly from the parsed namespace.
approximator = NonLinearApproximator(args)

## Loading data
# Source embeddings and the retrofitted (`--mode`) target embeddings for the
# chosen model and layer.
original_embedding = torch.load(f"{auth2_path}/context_div/ms_embs/{args.model}/original/original_all_{args.layer}.pt")
# `.float()` casts the target tensor to float32 before the type check below.
laser_embedding = torch.load(f"{auth2_path}/context_div/ms_embs/{args.model}/{args.mode}/{args.mode}_all_{args.layer}.pt").float()

# Explicit check instead of `assert`, which is skipped when Python runs with -O.
if original_embedding.type() != laser_embedding.type():
    raise TypeError(
        "Embedding tensor types differ: "
        f"{original_embedding.type()} vs {laser_embedding.type()}"
    )

# One (input, target) pair per row; row i of both tensors is paired together.
paired = list(zip(original_embedding, laser_embedding))


def _read_indices(path: str) -> list:
    """Return the integer indices stored one per line in the file at *path*.

    Blank lines are skipped, and the file handle is closed as soon as the
    indices have been read.
    """
    with open(path, "r") as handle:
        return [int(line) for line in handle if line.strip()]


# The train/test/val split is fixed by precomputed index files (this replaces
# the earlier sklearn `train_test_split` approach).
train_idx = _read_indices("../data/train_idx.txt")
test_idx = _read_indices("../data/test_idx.txt")
val_idx = _read_indices("../data/val_idx.txt")

# Plain indexing behaves uniformly for index lists of any length, whereas
# `itemgetter(*idx)` returns a bare item for a single index and fails on none.
train = [paired[i] for i in train_idx]
test = [paired[i] for i in test_idx]
val = [paired[i] for i in val_idx]

# Only the training loader shuffles; test and validation keep file order.
train_dl = DataLoader(train, batch_size = args.batchsize, num_workers = 4, shuffle = True)
test_dl = DataLoader(test, batch_size = args.batchsize, num_workers = 4)
val_dl = DataLoader(val, batch_size = args.batchsize, num_workers = 4)

# Learning rate with its dot replaced so it is usable in file and version names.
lr_tag = str(args.lr).replace('.', '-')

checkpoint_dir = f"{auth1_path}/makesense_logs/{args.model}/{args.layer}/"
checkpoint_name = f"version_{args.mode}_{args.hidden_size}_{args.hidden_layers}_{lr_tag}"
run_version = f"l{args.layer}_{lr_tag}_{args.hidden_size}x{args.hidden_layers}"

# Keep only the single checkpoint with the lowest val_loss, weights only.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=checkpoint_name,
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    period=1,
    verbose=True,
)

# Stop training once val_loss stops decreasing (patience of 5).
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    verbose=True,
)

# TensorBoard runs go under logs/<savedir>/<run_version>.
logger = TensorBoardLogger("logs/", name=args.savedir, version=run_version)

# 16-bit precision on one GPU, gradient clipping at 1.0, at most 40 epochs.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=40,
    gradient_clip_val=1.0,
    logger=logger,
    checkpoint_callback=checkpoint_callback,
    callbacks=[early_stop_callback],
)

# Train with validation after fitting, then evaluate on the held-out test split.
trainer.fit(approximator, train_dl, val_dl)

trainer.test(approximator, test_dl)