import copy
import torch
import torchmetrics
from network import *
from utils import *
import pytorch_lightning as pl
from lightning.pytorch.utilities import grad_norm
# Runs at import time: fixes Lightning's global seed to 0 so repeated runs
# start from the same RNG state (workers=True is forwarded to Lightning).
pl.seed_everything(0, workers=True)

# Column indices into the six-wide feasibility prediction/label tensors.
# ACTION and TOP..LEFT are used by GRNLit.on_test_epoch_end to pick columns.
ACTION = 0
# Not referenced in the visible part of this file; presumably flags which of
# the six feasibility columns are grasp directions -- TODO confirm.
GRASPS = [False, True, True, True, True, True]
TOP = 1
FRONT = 2
REAR = 3
RIGHT = 4
LEFT = 5
# BLOCKING and SIDES are likewise not referenced in the visible code;
# NOTE(review): they mirror ACTION/GRASPS (index 0 plus five sides) -- confirm
# their intended use against the modules that import this file.
BLOCKING = 0
SIDES = [False, True, True, True, True, True]

class GRNLit(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``GRN`` model.

    The model is called as ``self.model(batch, mode)`` and returns a triple
    ``(feasibility_preds, IK_preds, GO_preds)``.  The loss is always the
    feasibility BCE over ``batch.movable_mask``; an IK BCE term is added when
    ``"IK"`` is in ``args.edge_features`` and ``10 *`` a GO MSE term is added
    when ``"GO"`` is in ``args.edge_features``.

    Training uses ``BCEWithLogitsLoss`` while validation/test use ``BCELoss``,
    so the model is presumably expected to emit logits in ``"train"`` mode and
    probabilities in ``"val"``/``"test"`` mode -- confirm against ``GRN``.

    Args:
        args: Run configuration; this class reads ``device``, ``lr``,
            ``weight_decay``, ``debug``, ``edge_features`` and ``batch_size``.
        hyperparameters: Not used directly; captured by
            ``save_hyperparameters`` (``args`` is excluded there).
    """

    def __init__(self, args, hyperparameters):
        super().__init__()
        self.model = GRN(args=args).to(args.device)
        self.args = args
        self.BCELogitsLoss = nn.BCEWithLogitsLoss()
        self.BCELoss = nn.BCELoss()
        self.MSELoss = nn.MSELoss()
        self.F1Score = torchmetrics.F1Score(task="binary")
        self.Precision = torchmetrics.Precision(task="binary")
        self.Recall = torchmetrics.Recall(task="binary")
        self.AUC = torchmetrics.AUROC(task="binary")
        self.Specificity = torchmetrics.Specificity(task="binary")
        self.ConfusionMatrix = torchmetrics.ConfusionMatrix(task="binary")
        self.ROC = torchmetrics.ROC(task="binary")
        self.PrecisionRecallCurve = torchmetrics.PrecisionRecallCurve(task="binary")
        self.MSE = torchmetrics.MeanSquaredError()
        self.MAE = torchmetrics.MeanAbsoluteError()
        self.R2 = torchmetrics.R2Score()
        # Per-step results buffered until the matching *_epoch_end hook.
        self.validation_outputs = []
        self.test_outputs = []
        self.save_hyperparameters(ignore=['args', 'train_loss_function', 'val_loss_function'])

    def forward(self, data):
        """Run the wrapped model in ``"predict"`` mode."""
        return self.model(data, "predict")

    def configure_optimizers(self):
        """Return an Adam optimizer configured from ``args.lr`` / ``args.weight_decay``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        return optimizer

    def on_before_optimizer_step(self, optimizer):
        """Log per-layer gradient 2-norms when ``args.debug`` is set."""
        # If using mixed precision, the gradients are already unscaled here
        if self.args.debug:
            norms = grad_norm(self.model, norm_type=2)
            self.log_dict(norms)

    def _zero_metric(self):
        """Placeholder scalar used for metrics that are disabled or undefined."""
        return torch.tensor(0., device=torch.device(self.args.device))

    def _multitask_loss(self, batch, feasibility_preds, IK_preds, GO_preds, bce):
        """Combine the feasibility, IK and GO loss terms for one batch.

        Args:
            batch: Batch exposing ``movable_mask``, ``proximity_mask``,
                ``F_labels``, ``IK_labels`` and ``GO_labels``.
            feasibility_preds, IK_preds, GO_preds: Model outputs for ``batch``.
            bce: Binary cross-entropy criterion matching the model's mode
                (``self.BCELogitsLoss`` for training, ``self.BCELoss`` otherwise).

        Returns:
            Scalar loss tensor.
        """
        movable = batch.movable_mask
        # Accumulate out-of-place so no term is mutated through an alias.
        loss = bce(feasibility_preds[movable].view(-1), batch.F_labels[movable].view(-1))
        if "IK" in self.args.edge_features:
            loss = loss + bce(IK_preds[movable].view(-1), batch.IK_labels[movable].view(-1))
        if "GO" in self.args.edge_features:
            proximity = batch.proximity_mask == True  # noqa: E712 (keeps the original mask semantics)
            go_preds = GO_preds[proximity].view(-1)
            # Mean-reduced MSE over an empty selection is NaN and would poison
            # the whole loss, so the GO term is skipped for such batches.
            if go_preds.numel() > 0:
                loss = loss + 10 * self.MSELoss(go_preds, batch.GO_labels[proximity].view(-1))
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        batch = batch.to(self.args.device)
        feasibility_preds, IK_preds, GO_preds = self.model(batch, "train")
        loss = self._multitask_loss(batch, feasibility_preds, IK_preds, GO_preds, self.BCELogitsLoss)

        # "step" is logged as the epoch index alongside the loss.
        self.log_dict({"train_loss": loss, "step": float(self.current_epoch)},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss plus per-task metrics for one batch.

        Returns:
            Dict with ``val_loss``, ``Feasibility`` (F1), ``IK`` (F1) and
            ``GO`` (MAE).  Disabled tasks, and GO batches without any selected
            proximity entries, report ``0``.
        """
        batch = batch.to(self.args.device)
        feasibility_preds, IK_preds, GO_preds = self.model(batch, "val")
        loss = self._multitask_loss(batch, feasibility_preds, IK_preds, GO_preds, self.BCELoss)

        movable = batch.movable_mask
        feasibility_val = self.F1Score(feasibility_preds[movable].view(-1), batch.F_labels[movable].view(-1))
        if "IK" in self.args.edge_features:
            IK_val = self.F1Score(IK_preds[movable].view(-1), batch.IK_labels[movable].view(-1))
        else:
            IK_val = self._zero_metric()

        # Default first: previously GO_val stayed unbound (UnboundLocalError)
        # when GO was enabled but the batch had no proximity entries.
        GO_val = self._zero_metric()
        if "GO" in self.args.edge_features:
            proximity = batch.proximity_mask == True  # noqa: E712
            go_targets = batch.GO_labels[proximity].view(-1)
            if go_targets.numel() > 0:
                GO_val = self.MAE(GO_preds[proximity].view(-1), go_targets)

        self.validation_outputs.append({'val_loss': loss, 'F_val': feasibility_val, 'IK_val': IK_val, 'GO_val': GO_val})
        self.log_dict({"val_loss": loss, "Feasibility": feasibility_val, "IK": IK_val, "GO": GO_val, "step": float(self.current_epoch)},
                      on_step=False, on_epoch=True, prog_bar=True, logger=False, batch_size=self.args.batch_size)

        return {'val_loss': loss, "Feasibility": feasibility_val, "IK": IK_val, "GO": GO_val}

    def test_step(self, batch, batch_idx):
        """Compute the test loss and buffer masked predictions/labels.

        Feasibility and IK tensors are row-filtered by ``batch.movable_mask``;
        GO tensors are row-filtered by column 0 of ``batch.proximity_mask``.
        """
        batch = batch.to(self.args.device)
        feasibility_preds, IK_preds, GO_preds = self.model(batch, "test")
        loss = self._multitask_loss(batch, feasibility_preds, IK_preds, GO_preds, self.BCELoss)

        movable = batch.movable_mask
        proximity_rows = batch.proximity_mask[:, 0]
        outputs = {
            "feasibility_preds": feasibility_preds[movable, :],
            "feasibility_labels": batch.F_labels[movable, :],
            "IK_preds": IK_preds[movable, :],
            "IK_labels": batch.IK_labels[movable, :],
            "GO_preds": GO_preds[proximity_rows, :],
            "GO_labels": batch.GO_labels[proximity_rows, :],
        }
        self.test_outputs.append(outputs)

        return {"loss": loss, **outputs}

    def on_train_epoch_end(self):
        """Emit a blank line on stdout after each training epoch."""
        print()

    def on_validation_epoch_end(self):
        """Average the buffered per-batch validation values and write them to the logger."""
        tags = ('val_loss', 'F_val', 'IK_val', 'GO_val')
        epoch_means = {
            tag: torch.stack([output[tag] for output in self.validation_outputs]).mean()
            for tag in tags
        }
        self.validation_outputs.clear()

        # Uses the logger's ``add_scalar``; assumes a TensorBoard-style experiment object.
        for tag in tags:
            self.logger.experiment.add_scalar(tag, epoch_means[tag], global_step=self.current_epoch)

    def on_test_epoch_end(self):
        """Compute per-column test metrics over all buffered test batches and log them."""
        F_preds = torch.cat([tmp['feasibility_preds'] for tmp in self.test_outputs])
        F_labels = torch.cat([tmp['feasibility_labels'] for tmp in self.test_outputs])
        IK_preds = torch.cat([tmp['IK_preds'] for tmp in self.test_outputs])
        IK_labels = torch.cat([tmp['IK_labels'] for tmp in self.test_outputs])
        GO_preds = torch.cat([tmp['GO_preds'] for tmp in self.test_outputs])
        GO_labels = torch.cat([tmp['GO_labels'] for tmp in self.test_outputs])
        self.test_outputs.clear()

        action_F = self.F1Score(F_preds[:, ACTION], F_labels[:, ACTION])
        top_F = self.F1Score(F_preds[:, TOP], F_labels[:, TOP])
        front_F = self.F1Score(F_preds[:, FRONT], F_labels[:, FRONT])
        rear_F = self.F1Score(F_preds[:, REAR], F_labels[:, REAR])
        right_F = self.F1Score(F_preds[:, RIGHT], F_labels[:, RIGHT])
        left_F = self.F1Score(F_preds[:, LEFT], F_labels[:, LEFT])
        grasp_F1 = torch.stack([top_F, front_F, rear_F, right_F, left_F])
        grasp_F1_mean = torch.mean(grasp_F1)
        grasp_F1_std = torch.std(grasp_F1)

        # IK and GO tensors are indexed 0..4 here (no leading action column).
        if "IK" in self.args.edge_features:
            top_IK, front_IK, rear_IK, right_IK, left_IK = [
                self.F1Score(IK_preds[:, col], IK_labels[:, col]) for col in range(5)
            ]
        else:
            top_IK = front_IK = rear_IK = right_IK = left_IK = 0

        if "GO" in self.args.edge_features:
            top_ic, front_ic, rear_ic, right_ic, left_ic = [
                self.MAE(GO_preds[:, col], GO_labels[:, col]) for col in range(5)
            ]
        else:
            top_ic = front_ic = rear_ic = right_ic = left_ic = 0

        self.log_dict({"Action_F1": action_F, "Grasp_F1_mean": grasp_F1_mean, "Grasp_F1_std": grasp_F1_std,
                       "Top_F1": top_F, "Front_F1": front_F, "Rear_F1": rear_F, "Right_F1": right_F, "Left_F1": left_F,
                       "Top_IK": top_IK, "Front_IK": front_IK, "Rear_IK": rear_IK, "Right_IK": right_IK, "Left_IK": left_IK,
                       "Top_GO": top_ic, "Front_GO": front_ic, "Rear_GO": rear_ic, "Right_GO": right_ic, "Left_GO": left_ic},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=self.args.batch_size)
        

class AGFModuleLit(pl.LightningModule):
    """Lightning wrapper that trains and evaluates an ``AGFModule``.

    The model is called as ``self.model(batch, mode)`` and its output is
    indexed as a six-column tensor (columns 0..5), restricted to the rows
    selected by ``batch.movable_mask``.  Training averages six per-column
    ``BCEWithLogitsLoss`` terms; validation does the same with ``BCELoss``
    (so the model presumably emits probabilities outside ``"train"`` mode --
    confirm against ``AGFModule``).

    Args:
        args: Run configuration; this class reads ``device``, ``lr``,
            ``weight_decay``, ``debug`` and ``batch_size``.
        hyperparameters: Not used directly; captured by
            ``save_hyperparameters`` (``args`` is excluded there).
    """

    def __init__(self, args, hyperparameters):
        super().__init__()
        self.model = AGFModule(args)
        self.args = args
        self.train_loss_function = nn.BCEWithLogitsLoss()
        self.val_loss_function = nn.BCELoss()
        binary = {"task": "binary"}
        self.F1Score = torchmetrics.F1Score(**binary)
        self.Precision = torchmetrics.Precision(**binary)
        self.Recall = torchmetrics.Recall(**binary)
        self.AUC = torchmetrics.AUROC(**binary)
        self.Specificity = torchmetrics.Specificity(**binary)
        self.ConfusionMatrix = torchmetrics.ConfusionMatrix(**binary)
        self.ROC = torchmetrics.ROC(**binary)
        self.PrecisionRecallCurve = torchmetrics.PrecisionRecallCurve(**binary)
        # Per-step results buffered until the matching *_epoch_end hook.
        self.validation_outputs = []
        self.test_outputs = []
        self.save_hyperparameters(ignore=['args', 'train_loss_function', 'val_loss_function'])

    @staticmethod
    def _average(terms):
        """Sum ``terms`` left to right and divide by how many there are."""
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total / len(terms)

    def forward(self, data):
        """Run the wrapped model in ``"predict"`` mode."""
        return self.model(data, "predict")

    def configure_optimizers(self):
        """Return an Adam optimizer configured from ``args.lr`` / ``args.weight_decay``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
        )

    def on_before_optimizer_step(self, optimizer):
        """Log per-layer gradient 2-norms, but only when ``args.debug`` is set."""
        if not self.args.debug:
            return
        # With mixed precision the gradients have already been unscaled here.
        self.log_dict(grad_norm(self.model, norm_type=2))

    def training_step(self, batch, batch_idx):
        """Return the mean of the six per-column training losses for ``batch``."""
        batch = batch.to(self.args.device)
        logits = self.model(batch, "train")
        column_losses = [
            self.train_loss_function(logits[batch.movable_mask, col], batch.F_labels[batch.movable_mask, col])
            for col in range(6)
        ]
        loss = self._average(column_losses)
        self.log_dict(
            {"train_loss": loss, "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=self.args.batch_size,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss, F1, precision and recall for ``batch``.

        The result is buffered in ``self.validation_outputs`` and returned as
        a dict keyed ``val_loss`` / ``Val_F1`` / ``Val_Precision`` / ``Val_Recall``.
        """
        batch = batch.to(self.args.device)
        preds = self.model(batch, "val")
        column_losses = [
            self.val_loss_function(preds[batch.movable_mask, col], batch.F_labels[batch.movable_mask, col])
            for col in range(6)
        ]
        loss = self._average(column_losses)

        F1 = self.F1Score(preds[batch.movable_mask], batch.F_labels[batch.movable_mask])
        Prec = self.Precision(preds[batch.movable_mask], batch.F_labels[batch.movable_mask])
        Rec = self.Recall(preds[batch.movable_mask], batch.F_labels[batch.movable_mask])

        self.validation_outputs.append({'val_loss': loss, 'Val_F1': F1, 'Val_Precision': Prec, 'Val_Recall': Rec})
        # "objective" duplicates F1 under a second name in the logged values.
        self.log_dict(
            {"val_loss": loss, "F1": F1, "Prec": Prec, "Rec": Rec, "objective": F1,
             "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=False,
            batch_size=self.args.batch_size,
        )

        return {'val_loss': loss, 'Val_F1': F1, 'Val_Precision': Prec, 'Val_Recall': Rec}

    def test_step(self, batch, batch_idx):
        """Compute the test loss and buffer the movable-row predictions/labels."""
        batch = batch.to(self.args.device)
        preds = self.model(batch, "test")
        loss = self.val_loss_function(preds[batch.movable_mask], batch.F_labels[batch.movable_mask])

        kept = {"preds": preds[batch.movable_mask, :], "labels": batch.F_labels[batch.movable_mask, :]}
        self.test_outputs.append(kept)

        return {"loss": loss, **kept}

    def on_train_epoch_end(self):
        """Emit a blank line on stdout after each training epoch."""
        print()

    def on_validation_epoch_end(self):
        """Average the buffered per-batch validation values and write them to the logger."""
        tags = ('val_loss', 'Val_F1', 'Val_Precision', 'Val_Recall')
        epoch_means = {
            tag: torch.stack([step[tag] for step in self.validation_outputs]).mean()
            for tag in tags
        }
        self.validation_outputs.clear()

        for tag in tags:
            self.logger.experiment.add_scalar(tag, epoch_means[tag], global_step=self.current_epoch)

    def on_test_epoch_end(self):
        """Compute overall and per-column test metrics from the buffered batches and log them."""
        preds = torch.cat([step['preds'] for step in self.test_outputs])
        targets = torch.cat([step['labels'] for step in self.test_outputs])
        self.test_outputs.clear()

        F1 = self.F1Score(preds, targets)
        Prec = self.Precision(preds, targets)
        Rec = self.Recall(preds, targets)

        # Column 0 is logged as "Action_F1"; columns 1..5 as Top/Front/Rear/Right/Left.
        per_column = [self.F1Score(preds[:, col], targets[:, col]) for col in range(6)]
        Action_F1, Top_F1, Front_F1, Rear_F1, Right_F1, Left_F1 = per_column
        grasp_scores = torch.stack(per_column[1:])
        grasp_F1_mean = torch.mean(grasp_scores)
        grasp_F1_std = torch.std(grasp_scores)

        self.log_dict(
            {"Test_F1": F1, "Test_Precision": Prec, "Test_Recall": Rec,
             "Action_F1": Action_F1, "Grasp_F1_mean": grasp_F1_mean, "Grasp_F1_std": grasp_F1_std,
             "Top_F1": Top_F1, "Front_F1": Front_F1, "Rear_F1": Rear_F1,
             "Right_F1": Right_F1, "Left_F1": Left_F1},
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=self.args.batch_size,
        )
        

#===============================================================================================================================================
            

class GOModuleLit(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a ``GOModule`` regressor.

    Batches are ``(inputs, labels, masks)`` triples.  Predictions and labels
    are indexed as five-column tensors (columns 0..4); for each column only
    the rows where ``masks[:, col] == 1`` contribute.  The loss is the mean of
    the five per-column MSE terms.

    Args:
        args: Run configuration; this class reads ``lr``, ``weight_decay``,
            ``debug`` and ``batch_size``.
        hyperparameters: Not used directly; captured by
            ``save_hyperparameters`` (``args`` is excluded there).
    """

    def __init__(self, args, hyperparameters):
        super().__init__()
        self.model = GOModule()
        self.args = args
        self.train_loss_function = nn.MSELoss()
        self.val_loss_function = nn.MSELoss()
        binary = {"task": "binary"}
        self.F1Score = torchmetrics.F1Score(**binary)
        self.Precision = torchmetrics.Precision(**binary)
        self.Recall = torchmetrics.Recall(**binary)
        self.AUC = torchmetrics.AUROC(**binary)
        self.Specificity = torchmetrics.Specificity(**binary)
        self.ConfusionMatrix = torchmetrics.ConfusionMatrix(**binary)
        self.ROC = torchmetrics.ROC(**binary)
        self.PrecisionRecallCurve = torchmetrics.PrecisionRecallCurve(**binary)
        self.MAE = torchmetrics.MeanAbsoluteError()
        self.MSE = torchmetrics.MeanSquaredError()
        self.R2 = torchmetrics.R2Score()
        # Per-step results buffered until the matching *_epoch_end hook.
        self.validation_outputs = []
        self.test_outputs = []
        self.save_hyperparameters(ignore=['args', 'train_loss_function', 'val_loss_function'])

    @staticmethod
    def _average(terms):
        """Sum ``terms`` left to right and divide by how many there are."""
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total / len(terms)

    @staticmethod
    def _masked_columns(preds, labels, masks):
        """Yield ``(preds, labels)`` for columns 0..4, keeping rows where the mask equals 1."""
        for col in range(5):
            yield preds[masks[:, col] == 1, col], labels[masks[:, col] == 1, col]

    def forward(self, x, mask):
        """Run the model in ``"predict"`` mode on a deep copy of ``x``.

        NOTE(review): this passes ``mask`` as an extra positional argument,
        whereas the step methods call ``self.model(inputs, mode)`` -- confirm
        which call shape ``GOModule`` actually accepts.
        """
        return self.model(copy.deepcopy(x), mask, "predict")

    def configure_optimizers(self):
        """Return an Adam optimizer configured from ``args.lr`` / ``args.weight_decay``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
        )

    def on_before_optimizer_step(self, optimizer):
        """Log per-layer gradient 2-norms, but only when ``args.debug`` is set."""
        if not self.args.debug:
            return
        # With mixed precision the gradients have already been unscaled here.
        self.log_dict(grad_norm(self.model, norm_type=2))

    def training_step(self, batch, batch_idx):
        """Return the mean of the five masked per-column training losses."""
        inputs, labels, masks = batch
        preds = self.model(inputs, "train")
        side_losses = [
            self.train_loss_function(p.view(-1), t.view(-1))
            for p, t in self._masked_columns(preds, labels, masks)
        ]
        loss = self._average(side_losses)
        self.log_dict(
            {"train_loss": loss, "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=self.args.batch_size,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss/MAE/MSE/R2 per column and return their five-column averages."""
        inputs, labels, masks = batch
        preds = self.model(inputs, "val")
        per_side = [
            self.compute_metrics(p, t)
            for p, t in self._masked_columns(preds, labels, masks)
        ]
        # Transpose [(loss, MAE, MSE, R2), ...] into one tuple per quantity.
        side_losses, side_MAE, side_MSE, side_R2 = zip(*per_side)
        loss = self._average(side_losses)
        MAE = self._average(side_MAE)
        MSE = self._average(side_MSE)
        R2 = self._average(side_R2)

        self.validation_outputs.append({'val_loss': loss, 'Val_MAE': MAE, 'Val_MSE': MSE, 'Val_R2': R2})
        # "objective" duplicates MSE under a second name in the logged values.
        self.log_dict(
            {"val_loss": loss, "MAE": MAE, "MSE": MSE, "R2": R2, "objective": MSE,
             "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=False,
            batch_size=self.args.batch_size,
        )
        return {'val_loss': loss, 'Val_MAE': MAE, 'Val_MSE': MSE, 'Val_R2': R2}

    def test_step(self, batch, batch_idx):
        """Compute the masked test loss and buffer raw predictions, labels and masks."""
        inputs, labels, masks = batch
        preds = self.model(inputs, "test")
        side_losses = [
            self.val_loss_function(p.view(-1), t.view(-1))
            for p, t in self._masked_columns(preds, labels, masks)
        ]
        loss = self._average(side_losses)
        self.test_outputs.append({"preds": preds, "labels": labels, "masks": masks})
        return {"loss": loss, "preds": preds, "labels": labels}

    def on_train_epoch_end(self):
        """Emit a blank line on stdout after each training epoch."""
        print()

    def on_validation_epoch_end(self):
        """Average the buffered per-batch validation values and write them to the logger."""
        tags = ('val_loss', 'Val_MAE', 'Val_MSE', 'Val_R2')
        epoch_means = {
            tag: torch.stack([step[tag] for step in self.validation_outputs]).mean()
            for tag in tags
        }
        self.validation_outputs.clear()
        for tag in tags:
            self.logger.experiment.add_scalar(tag, epoch_means[tag], global_step=self.current_epoch)

    def on_test_epoch_end(self):
        """Compute per-column and averaged regression metrics over all buffered test batches."""
        preds = torch.cat([step['preds'] for step in self.test_outputs])
        labels = torch.cat([step['labels'] for step in self.test_outputs])
        masks = torch.cat([step['masks'] for step in self.test_outputs])
        self.test_outputs.clear()

        per_side = [
            self.compute_metrics(p, t)
            for p, t in self._masked_columns(preds, labels, masks)
        ]
        side_losses, side_MAE, side_MSE, side_R2 = zip(*per_side)
        top_MAE, front_MAE, rear_MAE, right_MAE, left_MAE = side_MAE
        loss = self._average(side_losses)
        MAE = self._average(side_MAE)
        MSE = self._average(side_MSE)
        R2 = self._average(side_R2)

        self.log_dict(
            {"test_loss": loss, "MAE": MAE, "MSE": MSE, "R2": R2,
             "Top_MAE": top_MAE, "Front_MAE": front_MAE, "Rear_MAE": rear_MAE,
             "Right_MAE": right_MAE, "Left_MAE": left_MAE, "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=self.args.batch_size,
        )

    def compute_metrics(self, preds, targets):
        """Return ``(loss, MAE, MSE, R2)`` for flattened ``preds`` against ``targets``.

        ``loss`` comes from ``self.val_loss_function``; the other three come
        from the torchmetrics objects created in ``__init__``.
        """
        flat_preds = preds.view(-1)
        flat_targets = targets.view(-1)
        loss = self.val_loss_function(flat_preds, flat_targets)
        MAE = self.MAE(flat_preds, flat_targets)
        MSE = self.MSE(flat_preds, flat_targets)
        R2 = self.R2(flat_preds, flat_targets)
        return loss, MAE, MSE, R2
        
       

#===============================================================================================================================================

class IKModuleLit(pl.LightningModule):
    """Lightning wrapper that trains and evaluates an ``IKModule`` classifier.

    Batches are ``(inputs, labels)`` pairs.  Training averages five
    per-column ``BCEWithLogitsLoss`` terms (columns 0..4) weighted by
    ``args.pos_weight``; validation and test apply ``BCELoss`` to the full
    prediction tensor (so the model presumably emits probabilities outside
    ``"train"`` mode -- confirm against ``IKModule``).

    Args:
        args: Run configuration; this class reads ``pos_weight``, ``device``,
            ``lr``, ``weight_decay``, ``debug`` and ``batch_size``.
        hyperparameters: Not used directly; captured by
            ``save_hyperparameters`` (``args`` is excluded there).
    """

    def __init__(self, args, hyperparameters):
        super().__init__()
        self.model = IKModule()
        self.args = args
        positive_weight = torch.tensor([args.pos_weight], device=args.device)
        self.train_loss_function = nn.BCEWithLogitsLoss(pos_weight=positive_weight)
        self.val_loss_function = nn.BCELoss()
        binary = {"task": "binary"}
        self.F1Score = torchmetrics.F1Score(**binary)
        self.Precision = torchmetrics.Precision(**binary)
        self.Recall = torchmetrics.Recall(**binary)
        self.AUC = torchmetrics.AUROC(**binary)
        self.Specificity = torchmetrics.Specificity(**binary)
        self.ConfusionMatrix = torchmetrics.ConfusionMatrix(**binary)
        self.ROC = torchmetrics.ROC(**binary)
        self.PrecisionRecallCurve = torchmetrics.PrecisionRecallCurve(**binary)
        # Per-step results buffered until the matching *_epoch_end hook.
        self.validation_outputs = []
        self.test_outputs = []
        self.save_hyperparameters(ignore=['args', 'train_loss_function', 'val_loss_function'])

    def forward(self, x):
        """Run the model in ``"predict"`` mode on a deep copy of ``x``."""
        return self.model(copy.deepcopy(x), "predict")

    def configure_optimizers(self):
        """Return an Adam optimizer configured from ``args.lr`` / ``args.weight_decay``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
        )

    def on_before_optimizer_step(self, optimizer):
        """Log per-layer gradient 2-norms, but only when ``args.debug`` is set."""
        if not self.args.debug:
            return
        # With mixed precision the gradients have already been unscaled here.
        self.log_dict(grad_norm(self.model, norm_type=2))

    def training_step(self, batch, batch_idx):
        """Return the mean of the five per-column training losses for ``batch``."""
        inputs, labels = batch
        logits = self.model(inputs, "train")
        side_losses = [
            self.train_loss_function(logits[:, col], labels[:, col])
            for col in range(5)
        ]
        # Accumulate left to right, then divide by the number of columns.
        loss = side_losses[0]
        for term in side_losses[1:]:
            loss = loss + term
        loss = loss / 5
        self.log_dict(
            {"train_loss": loss, "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=self.args.batch_size,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss, F1, precision and recall for ``batch``.

        The result is buffered in ``self.validation_outputs`` and returned as
        a dict keyed ``val_loss`` / ``Val_F1`` / ``Val_Precision`` / ``Val_Recall``.
        """
        inputs, labels = batch
        preds = self.model(inputs, "val")
        loss = self.val_loss_function(preds, labels)
        F1 = self.F1Score(preds, labels)
        Prec = self.Precision(preds, labels)
        Rec = self.Recall(preds, labels)

        self.validation_outputs.append({'val_loss': loss, 'Val_F1': F1, 'Val_Precision': Prec, 'Val_Recall': Rec})
        # "objective" duplicates F1 under a second name in the logged values.
        self.log_dict(
            {"val_loss": loss, "F1": F1, "Prec": Prec, "Rec": Rec, "objective": F1,
             "step": float(self.current_epoch)},
            on_step=False, on_epoch=True, prog_bar=True, logger=False,
            batch_size=self.args.batch_size,
        )

        return {'val_loss': loss, 'Val_F1': F1, 'Val_Precision': Prec, 'Val_Recall': Rec}

    def test_step(self, batch, batch_idx):
        """Compute the test loss and buffer the predictions/labels for epoch-end metrics."""
        inputs, labels = batch
        preds = self.model(inputs, "test")
        loss = self.val_loss_function(preds, labels)
        self.test_outputs.append({"preds": preds, "labels": labels})
        return {"loss": loss, "preds": preds, "labels": labels}

    def on_train_epoch_end(self):
        """Emit a blank line on stdout after each training epoch."""
        print()

    def on_validation_epoch_end(self):
        """Average the buffered per-batch validation values and write them to the logger."""
        tags = ('val_loss', 'Val_F1', 'Val_Precision', 'Val_Recall')
        epoch_means = {
            tag: torch.stack([step[tag] for step in self.validation_outputs]).mean()
            for tag in tags
        }
        self.validation_outputs.clear()

        for tag in tags:
            self.logger.experiment.add_scalar(tag, epoch_means[tag], global_step=self.current_epoch)

    def on_test_epoch_end(self):
        """Compute overall and per-column test metrics from the buffered batches and log them."""
        preds = torch.cat([step['preds'] for step in self.test_outputs])
        labels = torch.cat([step['labels'] for step in self.test_outputs])
        self.test_outputs.clear()

        F1 = self.F1Score(preds, labels)
        Prec = self.Precision(preds, labels)
        Rec = self.Recall(preds, labels)
        # Columns 0..4 are logged as Top/Front/Rear/Right/Left.
        top_F1, front_F1, rear_F1, right_F1, left_F1 = [
            self.F1Score(preds[:, col], labels[:, col]) for col in range(5)
        ]
        self.log_dict(
            {"Test_F1": F1, "Test_Precision": Prec, "Test_Recall": Rec,
             "Top_F1": top_F1, "Front_F1": front_F1, "Rear_F1": rear_F1,
             "Right_F1": right_F1, "Left_F1": left_F1},
            on_step=False, on_epoch=True, prog_bar=True, logger=True,
            batch_size=self.args.batch_size,
        )
       
