import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from minicons import cwe

from approximation_model import NonLinearApproximator

import torchmetrics

from argparse import ArgumentParser

from paths import auth1_path

import os
# Turn off the HuggingFace `tokenizers` library's internal parallelism.
# NOTE(review): presumably set to silence the fork warning / avoid deadlocks
# when DataLoader workers are forked after the tokenizer has been used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class WiCModel(pl.LightningModule):
    '''
    Word-in-Context (WiC) binary classifier over contextual BERT embeddings.

    For each example, ``_build_batch`` extracts a ``bert-base-uncased``
    representation (layer ``hparams.layer``, via minicons' ``cwe.CWE``) of the
    span ``[i, i + 1]`` in each of the two contexts. It optionally maps both
    through a frozen ``NonLinearApproximator`` and concatenates them along
    dim 1. The result passes through an MLP encoder of ``hidden_layers - 1``
    Linear -> ReLU -> Dropout blocks and a linear decoder that emits a single
    logit. Training uses ``binary_cross_entropy_with_logits`` against labels
    'T' -> 1, 'F' -> 0.

    Args:
        hparams: argparse ``Namespace`` (or dict) with the options registered
            by ``add_model_specific_args``: input_size, hidden_layers,
            hidden_size, dropout, lr, approximator, layer.
    '''

    # Maps the dataset's string labels to BCE targets.
    LABEL2ID = {
        'T': 1,
        'F': 0
    }

    def __init__(self, hparams):
        super().__init__()
        # Newer Lightning versions make ``self.hparams`` read-only, so store the
        # namespace through save_hyperparameters() rather than by assignment.
        # This also records it in checkpoints for load_from_checkpoint().
        self.save_hyperparameters(hparams)
        self.input_size = self.hparams.input_size
        self.hidden_size = self.hparams.hidden_size
        self.dp = self.hparams.dropout
        self.hidden_layers = self.hparams.hidden_layers
        self.learning_rate = self.hparams.lr
        self.approximator = self.hparams.approximator
        self.layer = self.hparams.layer

        # Encoder: (hidden_layers - 1) Linear -> ReLU -> Dropout blocks.
        self.blocks = []
        in_features = self.input_size
        for _ in range(self.hidden_layers - 1):
            self.blocks.extend([
                nn.Linear(in_features, self.hidden_size),
                nn.ReLU(),
                nn.Dropout(p=self.dp),
            ])
            in_features = self.hidden_size

        self.encoder = nn.Sequential(*self.blocks)
        # Size the decoder from the encoder's real output width. With
        # hidden_layers <= 1 the encoder is empty and passes input_size
        # features straight through, so Linear(hidden_size, 1) would not fit.
        self.decoder = nn.Linear(in_features, 1)
        self.encoder.apply(self._init_weights)
        self.decoder.apply(self._init_weights)

        # bert-encoder. Fall back to CPU instead of crashing without CUDA.
        cwe_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.cwe = cwe.CWE('bert-base-uncased', cwe_device)

        # Frozen approximator applied to each context embedding when enabled.
        if self.approximator == 1:
            self.approximator_model = NonLinearApproximator.load_from_checkpoint(
                f'{auth1_path}/makesense_logs/bert/{self.layer}/version_2048_2_0-0001.ckpt',
                map_location=cwe_device,
            )
            self.approximator_model.eval()

            for param in self.approximator_model.parameters():
                param.requires_grad = False

        # metrics
        # NOTE(review): these metric objects are not used by the steps below
        # (accuracy is computed inline). They are kept so existing references
        # to them keep working.
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()

    def _init_weights(self, m):
        '''Xavier-normal weights and a constant 0.01 bias for every nn.Linear.'''
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            m.bias.data.fill_(0.01)

    def train(self, mode=True):
        '''
        Set train/eval mode, but keep the frozen approximator in eval mode.

        Lightning calls ``model.train()`` when fitting starts, which would
        otherwise recursively put ``approximator_model`` back into train mode.
        That would undo the ``eval()`` in ``__init__`` (relevant if the
        approximator has dropout or batch-norm layers).
        '''
        super().train(mode)
        if getattr(self, 'approximator', 0) == 1:
            self.approximator_model.eval()
        return self

    def forward(self, embedding):
        '''Return the raw logit(s) for a batch of concatenated embeddings (shape (..., 1)).'''
        y_hat = self.decoder(self.encoder(embedding))
        return y_hat

    def configure_optimizers(self):
        '''Adam over all parameters.

        Frozen approximator parameters never receive a gradient, so Adam
        leaves them untouched.
        '''
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        return optimizer

    def _shared_step(self, batch):
        '''Compute (BCE loss tensor, accuracy float) for one batch.'''
        x, y = self._build_batch(batch, approximator=self.approximator)
        # squeeze only the trailing logit dim, so a batch of size 1 keeps shape (1,)
        logits = self(x).squeeze(-1)
        loss = F.binary_cross_entropy_with_logits(logits, y)
        # The outputs are logits, so the decision boundary is 0 (sigmoid = 0.5).
        accuracy = ((logits >= 0) == y.bool()).float().mean().item()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        '''Log train loss/accuracy and return the loss for backprop.'''
        loss, accuracy = self._shared_step(batch)

        self.log('train_loss', loss, on_epoch=True)
        self.log('train_acc', accuracy, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        '''Log validation loss/accuracy.'''
        loss, accuracy = self._shared_step(batch)

        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', accuracy, on_epoch=True, prog_bar=True)

    @staticmethod
    def _to_queries(context):
        '''Turn a collated (texts, indices) pair into (text, [i, i + 1]) queries for CWE.'''
        # NOTE(review): assumes the default collate produced (list_of_str,
        # tensor_of_indices) per context — confirm against the DataModule.
        return [(text, [idx.item(), idx.item() + 1]) for text, idx in zip(*context)]

    def _build_batch(self, batch, approximator=0):
        '''
        Turn a raw batch into (features, float labels) on ``self.device``.

        Args:
            batch: tuple ``(context1, context2, pos, labels)``. ``pos`` is unused
                and ``labels`` is a sequence of 'T'/'F' strings.
            approximator: when 1, pass each context embedding through the
                frozen approximator before concatenating.

        Returns:
            ``(torch.cat((c1, c2), dim=1), labels)``, where labels are float32.
        '''
        context1, context2, _pos, labels = batch
        context1 = self._to_queries(context1)
        context2 = self._to_queries(context2)

        labels = torch.tensor(
            [self.LABEL2ID[label] for label in labels],
            dtype=torch.float32,
            device=self.device,
        )

        c1 = self.cwe.extract_representation(context1, self.layer).to(self.device)
        c2 = self.cwe.extract_representation(context2, self.layer).to(self.device)

        if approximator == 1:
            c1 = self.approximator_model(c1)
            c2 = self.approximator_model(c2)
        return torch.cat((c1, c2), dim=1), labels

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        '''Register this model's CLI options in a "WiC Model" group; return the parent parser.'''
        parser = parent_parser.add_argument_group("WiC Model")
        parser.add_argument(
            '--input_size',
            type=int,
            default=1536,
            help="Input embedding size (defaults to 1536)."
        )
        parser.add_argument(
            '--hidden_layers',
            type=int,
            default=2,
            help="Number of hidden layers (defaults to 2)."
        )
        parser.add_argument(
            '--hidden_size',
            type=int,
            default=512,
            help="Hidden Size (defaults to 512)."
        )
        parser.add_argument(
            '--dropout',
            type=float,
            default=0.5,
            help="Dropout probability."
        )
        parser.add_argument(
            '--lr',
            type=float,
            default=1e-3,
            help="Learning rate for approximation."
        )
        parser.add_argument(
            '--approximator',
            type=int,
            default=0,
            help="Whether an approximator needs to be loaded."
        )
        parser.add_argument(
            '--layer',
            type=int,
            default=None,
            help="layer of the approximator"
        )
        return parent_parser