import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from minicons import cwe

from approximation_model import NonLinearApproximator

import torchmetrics

from argparse import ArgumentParser

from paths import auth1_path

from sklearn.metrics import f1_score

import os
# Turn off HuggingFace tokenizers' internal parallelism for this process
# (presumably to avoid its fork/deadlock warning with DataLoader workers — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

class WHiCProbe(pl.LightningModule):
    """Binary probe over a pair of contextual word representations.

    For every example, the representation of a target position in two
    contexts is extracted from layer ``layer`` of ``bert-base-uncased``
    (through minicons' ``cwe.CWE``). Both representations are optionally
    passed through a frozen ``NonLinearApproximator`` ('laser' or 'ser').
    They are then concatenated along dim 1 and scored by a one-hidden-layer
    ReLU MLP whose single output is used as a logit.

    Args:
        input_size: Width of the concatenated pair representation fed to the
            encoder (must match the last dim of ``torch.cat((c1, c2), dim=1)``).
        hidden_size: Width of the hidden ReLU layer.
        lr: Adam learning rate.
        approximator: ``'laser'`` or ``'ser'`` to load the matching frozen
            approximator checkpoint and route representations through it; any
            other value uses the raw layer representations.
        layer: Encoder layer passed to ``extract_representation``; also
            selects the approximator checkpoint directory.
        cwe_device: Device string for the minicons encoder. Defaults to the
            previously hard-coded ``'cuda:1'``.
    """

    # Checkpoint file (under ``{auth1_path}/makesense_logs/bert/<layer>/``) of
    # the frozen NonLinearApproximator for each supported ``approximator``.
    _APPROXIMATOR_CHECKPOINTS = {
        'laser': 'version_laser_2048_2_0-0001.ckpt',
        'ser': 'version_ser_2048_2_0-0001-v1.ckpt',
    }

    # Maps the string labels found in batches to binary targets.
    _LABEL2ID = {'1': 1, '0': 0}

    def __init__(self, input_size, hidden_size, lr, approximator, layer, cwe_device='cuda:1'):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.learning_rate = lr
        self.approximator = approximator
        self.layer = layer

        self.encoder = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU()
        )
        self.decoder = nn.Linear(self.hidden_size, 1)
        self.encoder.apply(self._init_weights)
        self.decoder.apply(self._init_weights)

        self.save_hyperparameters()
        # Contextual-embedding extractor for bert-base-uncased.
        self.cwe = cwe.CWE('bert-base-uncased', cwe_device)

        checkpoint_name = self._APPROXIMATOR_CHECKPOINTS.get(self.approximator)
        if checkpoint_name is not None:
            checkpoint = f'{auth1_path}/makesense_logs/bert/{self.layer}/{checkpoint_name}'
            self.approximator_model = NonLinearApproximator.load_from_checkpoint(checkpoint)
            # Frozen: inference mode and no gradients. train() below keeps it
            # in eval mode even when Lightning switches the probe to training.
            self.approximator_model.eval()
            self.approximator_model.requires_grad_(False)

    def train(self, mode=True):
        """Set train/eval mode on the probe while keeping the approximator in eval.

        Lightning calls ``model.train()`` when training (re)starts. Without
        this override, that call would undo the ``eval()`` set in
        ``__init__``, re-enabling any dropout / batch-norm updates inside
        the frozen approximator.
        """
        super().train(mode)
        if hasattr(self, 'approximator_model'):
            self.approximator_model.eval()
        return self

    def _init_weights(self, m):
        """Xavier-normal weights and a constant 0.01 bias for every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.01)

    def forward(self, embedding):
        """Return the raw logits of shape ``(batch, 1)`` for ``embedding``."""
        return self.decoder(self.encoder(embedding))

    def configure_optimizers(self):
        """Adam with weight decay, halving the LR when ``val_loss`` plateaus."""
        # Frozen approximator parameters never receive gradients, so Adam skips them.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=1, threshold=1e-3
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss'},
        }

    def _compute_loss_and_f1(self, batch):
        """Score ``batch`` and return ``(BCE-with-logits loss, weighted F1)``.

        The F1 is computed per batch with sklearn; ``self.log(..., on_epoch=True)``
        averages these per-batch values, which is not identical to the F1
        over the whole epoch.
        """
        x, y = self._build_batch(batch, approximation_mode=self.approximator)
        # Drop only the trailing logit dim so a batch of one keeps shape (1,),
        # matching the target; a bare squeeze() would yield a 0-d tensor.
        y_hat = self(x).squeeze(-1)
        loss = F.binary_cross_entropy_with_logits(y_hat, y)
        predictions = (y_hat.detach().cpu().sigmoid() >= 0.5).int().numpy()
        f1 = f1_score(y.cpu().numpy(), predictions, average='weighted')
        return loss, f1

    def training_step(self, batch, batch_idx):
        loss, f1 = self._compute_loss_and_f1(batch)
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_f1', f1, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, f1 = self._compute_loss_and_f1(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_f1', f1, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, f1 = self._compute_loss_and_f1(batch)
        self.log('test_loss', loss, on_epoch=True)
        self.log('test_f1', f1, on_epoch=True)

    def _build_batch(self, batch, approximation_mode='original'):
        """Turn a raw batch into ``(features, labels)`` on ``self.device``.

        Args:
            batch: ``(context1, context2, labels)``. Each context is presumably
                collated as ``(sentences, indices)`` with tensor indices
                (``.item()`` is called on them) — confirm against the dataset.
                Labels are the strings ``'1'`` / ``'0'``.
            approximation_mode: ``'laser'`` or ``'ser'`` to pass both
                representations through ``self.approximator_model``; any
                other value uses the raw layer representations.

        Returns:
            ``torch.cat((c1, c2), dim=1)`` and a float32 label tensor.

        Raises:
            KeyError: if a label is neither ``'1'`` nor ``'0'``.
        """
        context1, context2, labels = batch

        def to_queries(context):
            # Re-pair each sentence with its index as the [i, i + 1] span given
            # to minicons (presumably the target word's position — confirm).
            return [(sentence, [index.item(), index.item() + 1])
                    for sentence, index in zip(*context)]

        queries1 = to_queries(context1)
        queries2 = to_queries(context2)

        targets = torch.tensor(
            [self._LABEL2ID[label] for label in labels],
            dtype=torch.float32,
            device=self.device,
        )

        c1 = self.cwe.extract_representation(queries1, self.layer).to(self.device)
        c2 = self.cwe.extract_representation(queries2, self.layer).to(self.device)

        if approximation_mode in self._APPROXIMATOR_CHECKPOINTS:
            # The approximator is frozen; no_grad avoids building an autograd graph.
            with torch.no_grad():
                c1 = self.approximator_model(c1)
                c2 = self.approximator_model(c2)

        return torch.cat((c1, c2), dim=1), targets