import os
from collections import defaultdict
import argparse
import csv
from tqdm import tqdm

from torch.utils.data import DataLoader
import torch
from minicons import cwe

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

import numpy as np

from whic_utils import load_whic, pairwise_context, pairwise_direction
from whic_torch import WHiCProbe
from approximation_model import NonLinearApproximator

from sklearn.metrics import f1_score, classification_report
from paths import auth1_path

# Command-line interface: which approximation variant to probe, which layer
# to use, and a device string.
parser = argparse.ArgumentParser()
parser.add_argument("--approximation", type=str, default="original")
parser.add_argument("--layer", type=int, default=12)
parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()

# Module-level aliases used by the rest of the script.
# NOTE(review): `device` is not referenced anywhere else in the visible code.
LAYER, approximation_mode, device = args.layer, args.approximation, args.device

# Load the three WHiC splits (calls happen in the order train, dev, test).
train, val, test = (load_whic(split) for split in ('train', 'dev', 'test'))

# Only the training loader shuffles; evaluation loaders use larger batches.
train_dl = DataLoader(train, batch_size=32, shuffle=True, num_workers=4)
val_dl = DataLoader(val, batch_size=128, num_workers=4)
test_dl = DataLoader(test, batch_size=128, num_workers=4)

# presumably (input size, hidden size, learning rate, ...) -- confirm against
# the WHiCProbe constructor in whic_torch.
probe = WHiCProbe(2 * 768, 256, 1e-3, approximation_mode, LAYER)

# Both early stopping and checkpointing track the same quantity: the lowest
# validation loss seen so far.
_val_loss_tracking = dict(monitor='val_loss', mode='min', verbose=True)

# Stop once 'val_loss' has failed to improve for 5 consecutive checks.
early_stop_callback = EarlyStopping(patience=5, **_val_loss_tracking)

# Keep only the single best checkpoint (weights only), named after the
# approximation mode and layer of this run.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{auth1_path}/makesense_logs/whic/",
    filename=f"whic_probe_{approximation_mode}_{LAYER}",
    save_top_k=1,
    save_weights_only=True,
    period=1,
    **_val_loss_tracking
)

# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# TensorBoard run directory: whiclogs/WHiCProbesShuffle/l<layer>_<mode>.
logger = TensorBoardLogger(
    "whiclogs/",
    name='WHiCProbesShuffle',
    version=f"l{LAYER}_{approximation_mode}"
)

# NOTE(review): `gpus` is hard-coded here; the `--device` command-line option
# is not consulted when building the trainer.
trainer = pl.Trainer(
    gpus='0',
    precision=16,
    max_epochs=40,
    gradient_clip_val=1.0,
    logger=logger,
    checkpoint_callback=checkpoint_callback,
    callbacks=[early_stop_callback, lr_monitor]
)

# Train the probe, validating on the dev split.
trainer.fit(probe, train_dl, val_dl)

# Evaluate on the test split; `test_metrics[0]['test_f1']` is written to the
# results CSV at the end of the script.
test_metrics = trainer.test(probe, test_dl)

best_model_path = checkpoint_callback.best_model_path
print(f'best model = {best_model_path}')

# `load_from_checkpoint` is a classmethod that returns a *new* module, so
# calling it on `probe` and discarding the result leaves `probe` with its
# final-epoch weights.  Load the checkpoint's state dict into the existing
# instance instead so the pairwise evaluations below use the best weights.
if best_model_path:
    probe.load_state_dict(
        torch.load(best_model_path, map_location='cpu')['state_dict']
    )
else:
    # No checkpoint was written (e.g. 'val_loss' was never logged); fall back
    # to the weights currently held by `probe` rather than crashing.
    print('no best checkpoint found; keeping the current probe weights')

# The pairwise evaluations run on CPU with the probe in inference mode.
probe.to('cpu')
probe.eval()

def directional(layer = LAYER):
    """Evaluate the probe on the directional test pairs.

    `pairwise_direction('test')` is indexed as `[0]` (items the probe should
    label 1) and `[1]` (items the probe should label 0); the two are compared
    position by position, so they are presumably aligned pairs -- confirm
    against whic_utils.

    Args:
        layer: Unused; kept for backward compatibility. The module-level
            `probe` and `approximation_mode` determine what is evaluated.

    Returns:
        A tuple `(pairwise_accuracy, similar_predictions)`:
        the fraction of pairs where the first item is predicted 1 and the
        second is predicted 0, and the fraction of pairs where both items
        receive the same prediction.
    """
    test_directional = pairwise_direction('test')

    def predict(dataset):
        # Binary predictions (sigmoid(logit) >= 0.5) for every item in `dataset`.
        loader = DataLoader(dataset, num_workers=4, batch_size=100)
        predictions = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch in loader:
                inputs, _ = probe._build_batch(batch, approximation_mode=approximation_mode)
                # reshape(-1) rather than squeeze(): a final batch holding a
                # single item must stay 1-d so that tolist() returns a list.
                logits = probe(inputs).reshape(-1)
                predictions.extend((logits.sigmoid() >= 0.5).int().tolist())
        return torch.tensor(predictions)

    positive = predict(test_directional[0])
    negative = predict(test_directional[1])

    pairwise_accuracy = ((positive == 1) & (negative == 0)).float().mean().item()
    similar_predictions = (positive == negative).float().mean().item()
    return pairwise_accuracy, similar_predictions

def contextual(layer = LAYER):
    """Evaluate the probe on the contextual test pairs, averaged per word.

    `pairwise_context('test')` is indexed as `[0]` (items the probe should
    label 1), `[1]` (items the probe should label 0) and `[2]`, a mapping
    from word to the indices of that word's pairs.

    Args:
        layer: Unused; kept for backward compatibility. The module-level
            `probe` and `approximation_mode` determine what is evaluated.

    Returns:
        The mean over words of the per-word pairwise accuracy, i.e. the
        fraction of a word's pairs where the first item is predicted 1 and
        the second is predicted 0.
    """
    test_context = pairwise_context('test')

    def predict(dataset):
        # Binary predictions (sigmoid(logit) >= 0.5) for every item in `dataset`.
        loader = DataLoader(dataset, num_workers=4, batch_size=100)
        predictions = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch in loader:
                inputs, _ = probe._build_batch(batch, approximation_mode=approximation_mode)
                # reshape(-1) rather than squeeze(): a final batch holding a
                # single item must stay 1-d so it can be extended from.
                logits = probe(inputs).reshape(-1)
                predictions.extend((logits.sigmoid() >= 0.5).int().tolist())
        return torch.tensor(predictions)

    # Build each prediction tensor once instead of once per word.
    positive = predict(test_context[0])
    negative = predict(test_context[1])

    results = {}
    for word, idx in test_context[2].items():
        correct = (positive[idx] == 1) & (negative[idx] == 0)
        results[word] = correct.float().mean().item()

    return np.mean(list(results.values()))

# Pairwise evaluations of the reloaded probe.
directional_accuracy, directional_similar = directional(LAYER)
contextual_sensitivity = contextual(LAYER)

results_path = f"../results/whic_torch_results/layer_{LAYER}_{approximation_mode}_shuffle.csv"
# Make sure the output directory exists so a finished run cannot lose its
# results to a missing folder.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

# newline="" is required by the csv module; without it, platforms that
# translate line endings get a blank line after every row.
with open(results_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(['layer', 'class', 'f1', 'directional_accuracy', 'directional_similar', 'contextual_sensitivity'])
    writer.writerow([LAYER, approximation_mode, test_metrics[0]['test_f1'], directional_accuracy, directional_similar, contextual_sensitivity])