import torch
import torch.nn.functional as F
from torchmetrics import F1Score
from torch import optim
import pytorch_lightning as pl


class Classifier_Lighting(pl.LightningModule):
    """Shared Lightning training / evaluation logic for document classifiers.

    This base class does not define the model itself. None of the following
    are set here, so a subclass (or its ``__init__``) is expected to provide:

    * ``forward(documents_ids, src_key_padding_mask, matrix_mask)`` returning a
      2-tuple ``(out, attention_weights)`` where ``out`` holds per-class scores
      along dim 1 (presumably logits of shape ``(batch, num_classes)`` --
      confirm against the subclass);
    * ``self.criterion``: loss callable invoked as ``criterion(out, labels)``;
    * ``self.lr``: learning rate used by ``configure_optimizers``;
    * ``self.num_classes``: number of target classes, used for macro-F1.

    Batches are dicts containing at least ``'documents_ids'``,
    ``'src_key_padding_mask'``, ``'matrix_mask'`` and ``'labels'``;
    ``predict`` additionally reads ``'article_id'``.
    """

    def __init__(self):
        super().__init__()

    def training_step(self, batch, batch_idx):
        """Compute loss/metrics on a training batch, log each one with a
        ``Train_`` prefix, and return the loss for backpropagation."""
        loss = self.forward_performance(batch)
        for k, v in loss.items():
            self.log("Train_" + k, v, prog_bar=True, batch_size=len(batch['documents_ids']))
        return loss["loss"]

    def validation_step(self, batch, batch_idx):
        """Compute loss/metrics on a validation batch, log each one with a
        ``Val_`` prefix, and return the loss."""
        loss = self.forward_performance(batch)
        for k, v in loss.items():
            self.log("Val_" + k, v, prog_bar=True, batch_size=len(batch['documents_ids']))
        return loss["loss"]

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr ``self.lr``."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward_performance(self, data):
        """Run the model on one batch and return its loss and batch metrics.

        Args:
            data: batch dict (see class docstring).

        Returns:
            dict with ``"loss"`` (criterion output), ``"f1-ma"`` (macro-averaged
            multiclass F1 of the argmax predictions on this batch) and
            ``"acc"`` (fraction of argmax predictions equal to the labels).
        """
        # Stateless functional metric: scoring a single batch needs no Metric
        # module, so avoid constructing one (and moving it to the device) on
        # every step.
        from torchmetrics.functional import f1_score

        out, _ = self(
            data['documents_ids'].to(self.device),
            data['src_key_padding_mask'].to(self.device),
            data['matrix_mask'].to(self.device),
        )
        # Keep labels on the same device as the model outputs (no-op when the
        # batch was already placed there).
        labels = data['labels'].to(self.device)

        loss = self.criterion(out, labels)
        pred = out.argmax(dim=1)
        acc = (pred == labels).sum() / len(labels)
        f1_ma = f1_score(
            pred,
            labels,
            task='multiclass',
            num_classes=self.num_classes,
            average="macro",
        )
        return {"loss": loss, "f1-ma": f1_ma, 'acc': acc}

    def predict(self, test_loader, cpu_store=True, flag_file=False):
        """Run inference over ``test_loader`` and collect predictions and metadata.

        Switches the module to eval mode (it is not switched back) and runs
        without gradient tracking.

        Args:
            test_loader: iterable of batch dicts (see class docstring); each
                batch must also contain ``'article_id'``.
            cpu_store: if True, each batch of predictions is moved to CPU and
                converted to NumPy before being stored; if False, predictions
                stay as tensors and are packed into a ``torch.Tensor`` (default
                float dtype) at the end.
            flag_file: if true, the model is not run; only document ids, labels
                and article identifiers are collected, and no predictions or
                attention weights are gathered.

        Returns:
            Tuple ``(preds, full_attn_weights, all_labels, all_doc_ids,
            all_article_identifiers)``. Attention weights are stored exactly as
            returned by the model (they are not moved to CPU).
        """
        self.eval()
        preds = []
        full_attn_weights = []
        all_labels = []
        all_doc_ids = []
        all_article_identifiers = []

        with torch.no_grad():
            for data in test_loader:
                if not flag_file:
                    out, att_w = self(
                        data['documents_ids'].to(self.device),
                        data['src_key_padding_mask'].to(self.device),
                        data['matrix_mask'].to(self.device),
                    )
                    # Iterates att_w along its first dim -- presumably one entry
                    # per sample; confirm against the subclass's forward.
                    full_attn_weights.extend(att_w)
                    pred = out.argmax(dim=1)
                    if cpu_store:
                        pred = pred.detach().cpu().numpy()
                    preds += list(pred)

                all_doc_ids.extend(data['documents_ids'])
                all_labels.extend(data['labels'])
                all_article_identifiers.extend(data['article_id'])

            if not cpu_store:
                preds = torch.Tensor(preds)

        return preds, full_attn_weights, all_labels, all_doc_ids, all_article_identifiers

    def predict_single(self, batch_single, cpu_store=True):
        """Predict on a single already-collated batch.

        Switches the module to eval mode and runs without gradient tracking.

        Args:
            batch_single: batch dict with ``'documents_ids'``,
                ``'src_key_padding_mask'`` and ``'matrix_mask'``.
            cpu_store: if True, predictions are returned as a list of NumPy
                values; if False, as a ``torch.Tensor`` (default float dtype).

        Returns:
            Tuple ``(preds, att_w)`` where ``att_w`` is the second element
            returned by the model's forward, unchanged.
        """
        self.eval()
        with torch.no_grad():
            out, att_w = self(
                batch_single['documents_ids'].to(self.device),
                batch_single['src_key_padding_mask'].to(self.device),
                batch_single['matrix_mask'].to(self.device),
            )
            pred = out.argmax(dim=1)
            if cpu_store:
                pred = pred.detach().cpu().numpy()
            preds = list(pred)

        if not cpu_store:
            preds = torch.Tensor(preds)

        return preds, att_w
