from typing import Any, Dict, Tuple
import torch.nn as nn
import numpy as np
import torch
from lightning import LightningModule
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import MaxMetric, MeanMetric, MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError, \
    MinMetric, Accuracy, Precision, Recall, F1Score, AUROC
from torchmetrics.functional import accuracy, precision, recall, f1_score, mean_absolute_error, mean_squared_error, mean_absolute_percentage_error
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from .powergpt_components.revin import RevIN
from .PowerGPT import patch_masking, MaskMSELoss, create_patch


class Output(object):
    """Plain container for evaluation results.

    Every field starts out as ``None``. Besides the overall ``loss`` and the
    classification scores (``acc``, ``pre``, ``rec``, ``f1``, ``auroc``), the
    container holds one ``<group>_mse`` / ``<group>_mae`` pair for each of the
    groups ``zb``, ``gb``, ``qy``, ``hy``, ``cs`` and ``zj``.
    """

    def __init__(self):
        # Overall loss and classification scores.
        for overall_field in ("loss", "acc", "pre", "rec", "f1", "auroc"):
            setattr(self, overall_field, None)

        # Per-group regression errors, created in a fixed order.
        for group in ("zb", "gb", "qy", "hy", "cs", "zj"):
            for error_kind in ("mse", "mae"):
                setattr(self, f"{group}_{error_kind}", None)

def get_model(model):
    """Return the wrapped module of a (Distributed)DataParallel container.

    Objects that are not ``DistributedDataParallel`` or ``nn.DataParallel``
    instances are returned unchanged.
    """
    parallel_wrappers = (DistributedDataParallel, nn.DataParallel)
    if isinstance(model, parallel_wrappers):
        return model.module
    return model

class PowerGPTModule(LightningModule):

    def __init__(
            self,
            name,
            net: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            scheduler: torch.optim.lr_scheduler,
    ) -> None:
        """Initialize a `PowerGPTModule`.

        :param name: Identifier stored on the module as `self.name`.
        :param net: The model to train. It must expose `patch_len`, `stride`, `mask_ratio`,
            `head_type` and `n_vars` (plus `target_dim` when `head_type` is
            `'classification'`).
        :param optimizer: The optimizer to use for training. It is not used in this
            constructor beyond being captured in `self.hparams`.
        :param scheduler: The learning rate scheduler to use for training. It is not used in
            this constructor beyond being captured in `self.hparams`.
        """
        super().__init__()
        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net
        self.patch_len = self.net.patch_len
        self.stride = self.net.stride
        self.mask_ratio = self.net.mask_ratio
        self.name = name

        # loss function, selected by the head the network was built with
        if self.net.head_type in ('pretrain', 'imputation'):
            self.criterion = MaskMSELoss()
        elif self.net.head_type == 'prediction':
            self.criterion = nn.MSELoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

        stages = ('train', 'val', 'test')

        # Regression metrics (pretrain, imputation, prediction). The empty prefix
        # gives the overall `<stage>_mse` / `_mae` / `_mape` metrics; the others
        # give the per-category `<stage>_<category>_<metric>` variants.
        regression_metrics = (
            ('mse', MeanSquaredError),
            ('mae', MeanAbsoluteError),
            ('mape', MeanAbsolutePercentageError),
        )
        for prefix in ('', 'zb_', 'gb_', 'industry_', 'area_', 'city_', 'province_'):
            for suffix, metric_cls in regression_metrics:
                for stage in stages:
                    setattr(self, f'{stage}_{prefix}{suffix}', metric_cls())

        # classification
        if self.net.head_type == 'classification':
            if self.net.target_dim != 2:
                metric_kwargs = dict(task='multiclass', average='macro',
                                     num_classes=self.net.target_dim)
            else:
                metric_kwargs = dict(task='binary', average='macro')

            classification_metrics = (
                ('acc', Accuracy),
                ('pre', Precision),
                ('rec', Recall),
                ('f1', F1Score),
                ('auroc', AUROC),
            )
            for stage in stages:
                for suffix, metric_cls in classification_metrics:
                    setattr(self, f'{stage}_{suffix}', metric_cls(**metric_kwargs))
            # No explicit reset here: the metrics were just constructed, and the
            # `val_*_best` trackers are only created further down. Resetting them
            # at this point raised AttributeError for every classification model.

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking the best validation values seen so far
        # NOTE(review): for classification the *maximum* validation loss is
        # tracked, which looks unintended for a loss value -- confirm before
        # relying on "val/loss_best".
        self.val_loss_best = MaxMetric() if self.net.head_type == 'classification' else MinMetric()

        self.val_mse_best = MinMetric()
        self.val_mae_best = MinMetric()
        self.val_mape_best = MinMetric()

        self.val_acc_best = MaxMetric()
        self.val_pre_best = MaxMetric()
        self.val_rec_best = MaxMetric()
        self.val_f1_best = MaxMetric()
        self.val_auroc_best = MaxMetric()

        # per-batch outputs accumulated by `test_step`
        self.test_step_pred = []
        self.test_step_y = []
        self.test_step_types = []
        self.test_step_masks = []

        self.revin = RevIN(num_features=net.n_vars, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run `self.net` on `x` and return its output unchanged.

        :param x: The input handed to the wrapped network. Within this class
            `model_step` passes the whole batch object here.
        :return: Whatever `self.net` returns for `x`.
        """
        net_output = self.net(x)
        return net_output

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins.

        Lightning runs validation sanity checks before training starts by
        default, so the validation loss / MSE / MAE / MAPE metrics and their
        best-so-far trackers are reset here to drop anything those checks
        recorded.
        """
        validation_trackers = (
            self.val_loss, self.val_loss_best,
            self.val_mse, self.val_mse_best,
            self.val_mae, self.val_mae_best,
            self.val_mape, self.val_mape_best,
        )
        for tracker in validation_trackers:
            tracker.reset()


    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Patch (and, for pretraining, mask) a batch before it is moved to the device.

        The batch is modified in place and returned:

        - `batch.batch_size_` is set to the first dimension of `batch.input_id`.
        - `batch.x_cov` is transposed on dims (2, 1) and replaced by its patches.
        - For the prediction, classification and imputation heads `batch.x` is
          replaced by its patches. Classification also flattens `batch.y`, and
          imputation sets `batch.mask` to `batch.val_mask` with a trailing
          singleton dimension.
        - For any other head (the pretraining path) `batch.x` is normalised with
          `self.revin`, then `patch_masking` supplies the new `batch.x`,
          `batch.y` and `batch.mask`.

        :param batch: The batch object produced by the dataloader.
        :param dataloader_idx: Index of the dataloader (unused).
        :return: The same batch object, updated in place.
        """
        series = batch.x
        batch.batch_size_ = batch.input_id.shape[0]
        cov_patches, _ = create_patch(
            batch.x_cov.transpose(2, 1), patch_len=self.patch_len, stride=self.stride
        )
        batch.x_cov = cov_patches

        head = self.net.head_type
        if head in ("prediction", "classification", "imputation"):
            patches, _ = create_patch(series, stride=self.stride, patch_len=self.patch_len)
            batch.x = patches
            if head == "classification":
                batch.y = batch.y.reshape(-1)
            elif head == "imputation":
                batch.mask = batch.val_mask.unsqueeze(-1)
            return batch

        # Remaining heads: normalise, then hide a share of the patches.
        normalised = self.revin(series, 'norm', types=batch.node_attr)
        masked_patches, patch_targets, patch_mask = patch_masking(
            normalised, stride=self.stride, patch_len=self.patch_len, mask_ratio=self.mask_ratio
        )
        batch.x = masked_patches
        batch.y = patch_targets
        # Repeat the mask along a new last dimension of size `patch_len`.
        batch.mask = patch_mask.unsqueeze(-1).repeat(1, 1, 1, self.patch_len)
        return batch

    def model_step(
            self, batch: Any
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        The batch is updated in place: `batch.y`, `batch.node_attr` and (where
        present) `batch.mask` are trimmed to the first `batch.batch_size_`
        entries, and the network output is trimmed the same way.

        :param batch: A batch object prepared by `on_before_batch_transfer`.

        :return: A tuple containing (in order):
            - The loss tensor.
            - The predictions (class indices for the classification head).
            - The trimmed targets (`batch.y`).
        """
        head = self.net.head_type
        n = batch.batch_size_

        if head == "prediction":
            batch.y = batch.y[:n]
            pred = self.forward(batch)[:n]
            batch.node_attr = batch.node_attr[:n]
            loss = self.criterion(pred, batch.y)
        elif head == "classification":
            pred = self.forward(batch)[:n]
            batch.y = batch.y[:n]
            loss = self.criterion(pred, batch.y)
            batch.node_attr = batch.node_attr[:n]
            # Keep only the index of the highest-scoring class.
            pred = torch.max(pred, dim=-1)[1]
        elif head == "imputation":
            batch.y = batch.y[:n]
            pred = self.forward(batch)[:n]
            pred = pred.reshape(pred.shape[0], -1, 1)
            batch.mask = batch.mask[:n]
            batch.node_attr = batch.node_attr[:n]
            loss = self.criterion(pred, batch.y, batch.mask)
        else:
            batch.y = batch.y[:n]
            pred = self.forward(batch)[:n]
            batch.mask = batch.mask[:n]
            batch.node_attr = batch.node_attr[:n]
            loss = self.criterion(pred, batch.y, batch.mask)

        return loss, pred, batch.y

    # (`node_attr` code, metric-name group) pairs that get their own validation
    # metrics. "industry" and "province" metrics exist but are not tracked here.
    _NODE_TYPE_GROUPS = ((2, "zb"), (3, "gb"), (0, "area"), (1, "city"))

    def _update_task_metrics(self, stage: str, batch: Any, preds: torch.Tensor,
                             targets: torch.Tensor) -> list:
        """Update the overall metrics of `stage` and return their short names.

        For heads that carry a mask (everything except prediction and
        classification) only positions where `batch.mask == 1` are scored.
        """
        head = self.net.head_type
        if head == "classification":
            names = ["acc", "pre", "rec", "f1"]
            if self.net.target_dim == 2:
                # NOTE(review): `preds` holds argmax class indices at this point;
                # AUROC is normally computed from scores -- confirm this is intended.
                names.append("auroc")
        else:
            names = ["mse", "mae", "mape"]
            if head != "prediction":
                # Comparing with 1 yields a boolean selection whether the mask
                # is stored as bool or as 0/1 numbers.
                keep = batch.mask == 1
                preds, targets = preds[keep], targets[keep]

        for metric_name in names:
            getattr(self, f"{stage}_{metric_name}")(preds, targets)
        return names

    def _update_group_metrics(self, stage: str, node_types: torch.Tensor, preds: torch.Tensor,
                              targets: torch.Tensor, mask: Any = None) -> None:
        """Update and log the per-group MSE/MAE/MAPE for groups present in the batch.

        :param stage: Metric prefix, e.g. `"val"`.
        :param node_types: Per-sample group codes (`batch.node_attr`).
        :param preds: Predictions for the batch.
        :param targets: Targets for the batch.
        :param mask: Optional mask; when given, only positions equal to 1 are scored.
        """
        for code, group in self._NODE_TYPE_GROUPS:
            members = node_types == code
            if torch.sum(members) == 0:
                continue

            group_preds, group_targets = preds[members], targets[members]
            if mask is not None:
                keep = mask[members] == 1
                group_preds, group_targets = group_preds[keep], group_targets[keep]

            kinds = ("mse", "mae", "mape")
            metrics = [getattr(self, f"{stage}_{group}_{kind}") for kind in kinds]
            for metric in metrics:
                metric(group_preds, group_targets)
            for kind, metric in zip(kinds, metrics):
                self.log(f"{stage}/{group}_{kind}", metric,
                         on_step=False, on_epoch=True, prog_bar=True)

    def _log_epoch_metrics(self, stage: str, names) -> None:
        """Log `self.<stage>_<name>` as `<stage>/<name>` at epoch level for each name."""
        for metric_name in names:
            self.log(f"{stage}/{metric_name}", getattr(self, f"{stage}_{metric_name}"),
                     on_step=False, on_epoch=True, prog_bar=True)

    def training_step(
            self, batch: Any, batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch object prepared by `on_before_batch_transfer`.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.train_loss(loss)
        metric_names = self._update_task_metrics("train", batch, preds, targets)
        self._log_epoch_metrics("train", ["loss"] + metric_names)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        "Lightning hook that is called when a training epoch ends."
        pass

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        In addition to the overall metrics, the prediction and imputation heads
        update per-group metrics keyed on `batch.node_attr`.

        :param batch: A batch object prepared by `on_before_batch_transfer`.
        :param batch_idx: The index of the current batch.
        """
        head = self.net.head_type
        loss, preds, targets = self.model_step(batch)

        # update metrics
        self.val_loss(loss)
        metric_names = self._update_task_metrics("val", batch, preds, targets)

        if head in ("prediction", "imputation"):
            group_mask = batch.mask if head == "imputation" else None
            self._update_group_metrics("val", batch.node_attr, preds, targets, group_mask)

        self._log_epoch_metrics("val", ["loss"] + metric_names)

    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        loss = self.val_loss.compute()  # get current val loss
        self.val_loss_best(loss)  # update best so far val loss
        # log `val_loss_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/loss_best", self.val_loss_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        Predictions and targets are appended to `self.test_step_pred` /
        `self.test_step_y`. The prediction and imputation heads also record
        `batch.node_attr`, and imputation additionally records `batch.mask`.
        Other heads record nothing.

        :param batch: A batch object prepared by `on_before_batch_transfer`.
        :param batch_idx: The index of the current batch.
        """
        head = self.net.head_type
        if head not in ("prediction", "classification", "imputation"):
            return

        _, preds, targets = self.model_step(batch)
        self.test_step_pred.append(preds)
        self.test_step_y.append(targets)
        if head != "classification":
            self.test_step_types.append(batch.node_attr)
        if head == "imputation":
            self.test_step_masks.append(batch.mask)


    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends.

        Concatenates the predictions and node types collected by `test_step`,
        prints how many entries each holds and saves them as NumPy arrays to
        `./rep_preds.pt` and `./rep_types.pt`. A list that is empty (for
        example, no node types are collected for the classification head) is
        skipped instead of crashing `torch.cat`. All accumulator lists are
        cleared afterwards so they do not carry over into a later test run.
        """
        collected = (
            ('./rep_preds.pt', self.test_step_pred),
            ('./rep_types.pt', self.test_step_types),
        )
        try:
            arrays = [
                (path, torch.cat(chunks).cpu().numpy())
                for path, chunks in collected
                if chunks
            ]
            for _, array in arrays:
                print(len(array))
            for path, array in arrays:
                torch.save(array, path)
        finally:
            # Free the per-batch tensors even if concatenation or saving fails.
            for buffer in (self.test_step_pred, self.test_step_y,
                           self.test_step_types, self.test_step_masks):
                buffer.clear()
    # def on_test_epoch_end(self) -> None:
    #     """Lightning hook that is called when a test epoch ends."""
    #     preds = torch.cat(self.test_step_pred).cpu().numpy()
    #     trues = torch.cat(self.test_step_y).cpu().numpy()

    #     if self.net.head_type in ["imputation", "prediction"]:
    #         types = torch.cat(self.test_step_types).cpu().numpy()
    #         preds = [preds[i].reshape(-1) for i in range(preds.shape[0])]
    #         trues = [trues[i].reshape(-1) for i in range(trues.shape[0])]
    #         types = types.tolist()
    #         from collections import Counter
    #         print(Counter(types))
    #         if self.net.head_type == "imputation":
    #             masks = torch.cat(self.test_step_masks).cpu().numpy() == 1
    #             masks = [masks[i].reshape(-1) for i in range(masks.shape[0])]
    #     zb_trues = []
    #     gb_trues = []
    #     qy_trues = []
    #     # hy_trues = []
    #     cs_trues = []
    #     # zj_trues = []

    #     zb_preds = []
    #     gb_preds = []
    #     qy_preds = []
    #     # hy_preds = []
    #     cs_preds = []
    #     # zj_preds = []

    #     tot_preds = []
    #     tot_trues = []

    #     if self.net.head_type == "imputation":
    #         for pred, true, type, mask in tqdm(zip(preds, trues, types, masks)):
    #             tot_preds.append(pred[mask])
    #             tot_trues.append(true[mask])
    #             if type == 2:
    #                 zb_preds.append(pred[mask])
    #                 zb_trues.append(true[mask])
    #             if type == 3:
    #                 gb_preds.append(pred[mask])
    #                 gb_trues.append(true[mask])
    #             if type == 0:
    #                 qy_preds.append(pred[mask])
    #                 qy_trues.append(true[mask])
    #             # if type == 2:
    #             #     hy_preds.append(pred[mask])
    #             #     hy_trues.append(true[mask])
    #             if type == 1:
    #                 cs_preds.append(pred[mask])
    #                 cs_trues.append(true[mask])
    #             # if type == 5:
    #             #     zj_preds.append(pred[mask])
    #             #     zj_trues.append(true[mask])
    #     elif self.net.head_type == "prediction":
    #         for pred, true, type in tqdm(zip(preds, trues, types)):
    #             tot_preds.append(pred)
    #             tot_trues.append(true)
    #             if type == 2:
    #                 zb_preds.append(pred)
    #                 zb_trues.append(true)
    #             if type == 3:
    #                 gb_preds.append(pred)
    #                 gb_trues.append(true)
    #             if type == 0:
    #                 qy_preds.append(pred)
    #                 qy_trues.append(true)
    #             # if type == 2:
    #             #     hy_preds.append(pred)
    #             #     hy_trues.append(true)
    #             if type == 1:
    #                 cs_preds.append(pred)
    #                 cs_trues.append(true)
    #             # if type == 5:
    #             #     zj_preds.append(pred)
    #             #     zj_trues.append(true)

    #     else:
    #         tot_preds = preds
    #         tot_trues = trues

    #     if self.net.head_type in ["imputation", "prediction"]:
    #         trues = torch.from_numpy(np.concatenate(tot_trues).reshape(-1))
    #         preds = torch.from_numpy(np.concatenate(tot_preds).reshape(-1))

    #         try:
    #             zb_preds = torch.from_numpy(np.concatenate(zb_preds).reshape(-1))
    #             zb_trues = torch.from_numpy(np.concatenate(zb_trues).reshape(-1))
    #         except:
    #             zb_preds = None

    #         try:
    #             gb_preds = torch.from_numpy(np.concatenate(gb_preds).reshape(-1))
    #             gb_trues = torch.from_numpy(np.concatenate(gb_trues).reshape(-1))
    #         except:
    #             gb_preds = None

    #         try:
    #             qy_preds = torch.from_numpy(np.concatenate(qy_preds).reshape(-1))
    #             qy_trues = torch.from_numpy(np.concatenate(qy_trues).reshape(-1))
    #         except:
    #             qy_preds = None

    #         # try:
    #         #     hy_preds = torch.from_numpy(np.concatenate(hy_preds).reshape(-1))
    #         #     hy_trues = torch.from_numpy(np.concatenate(hy_trues).reshape(-1))
    #         # except:
    #         #     hy_preds = None

    #         try:
    #             cs_preds = torch.from_numpy(np.concatenate(cs_preds).reshape(-1))
    #             cs_trues = torch.from_numpy(np.concatenate(cs_trues).reshape(-1))
    #         except:
    #             cs_preds = None

    #         # try:
    #         #     zj_preds = torch.from_numpy(np.concatenate(zj_preds).reshape(-1))
    #         #     zj_trues = torch.from_numpy(np.concatenate(zj_trues).reshape(-1))
    #         # except:
    #         #     zj_preds = None

    #     model_output = Output()
    #     if self.net.head_type == 'classification':
    #         preds = torch.from_numpy(preds).long()
    #         trues = torch.from_numpy(trues).long()

    #         if self.net.target_dim == 2:
    #             model_output.acc = accuracy(preds, trues, task='binary', average='macro')
    #             model_output.pre = precision(preds, trues, task='binary', average='macro')
    #             model_output.rec = recall(preds, trues, task='binary', average='macro')
    #             model_output.f1 = f1_score(preds, trues, task='binary', average='macro')
    #             model_output.auroc = roc_auc_score(trues.cpu().numpy(), preds.cpu().numpy())
    #         else:
    #             model_output.acc = accuracy(preds, trues, task='multiclass', average='macro', num_classes=self.net.target_dim)
    #             model_output.pre = precision(preds, trues, task='multiclass', average='macro', num_classes=self.net.target_dim)
    #             model_output.rec = recall(preds, trues, task='multiclass', average='macro', num_classes=self.net.target_dim)
    #             model_output.f1 = f1_score(preds, trues, task='multiclass', average='macro', num_classes=self.net.target_dim)
    #             model_output.auroc = 0

    #     else:

    #         model_output.mse = mean_squared_error(preds, trues)
    #         model_output.mae = mean_absolute_error(preds, trues)

    #         model_output.zb_mse = mean_squared_error(zb_preds, zb_trues) if zb_preds is not None else 0
    #         model_output.zb_mae = mean_absolute_error(zb_preds, zb_trues) if zb_preds is not None else 0
    #         model_output.zb_mape = mean_absolute_percentage_error(zb_preds, zb_trues) if zb_preds is not None else 0
    #         model_output.gb_mse = mean_squared_error(gb_preds, gb_trues) if gb_preds is not None else 0
    #         model_output.gb_mae = mean_absolute_error(gb_preds, gb_trues) if gb_preds is not None else 0
    #         model_output.gb_mape = mean_absolute_percentage_error(gb_preds, gb_trues) if gb_preds is not None else 0
    #         model_output.qy_mse = mean_squared_error(qy_preds, qy_trues) if qy_preds is not None else 0
    #         model_output.qy_mae = mean_absolute_error(qy_preds, qy_trues) if qy_preds is not None else 0
    #         model_output.qy_mape = mean_absolute_percentage_error(qy_preds, qy_trues) if qy_preds is not None else 0
    #         # model_output.hy_mse = mean_squared_error(hy_preds, hy_trues) if hy_preds is not None else 0
    #         # model_output.hy_mae = mean_absolute_error(hy_preds, hy_trues) if hy_preds is not None else 0
    #         model_output.cs_mse = mean_squared_error(cs_preds, cs_trues) if cs_preds is not None else 0
    #         model_output.cs_mae = mean_absolute_error(cs_preds, cs_trues) if cs_preds is not None else 0
    #         model_output.cs_mape = mean_absolute_percentage_error(cs_preds, cs_trues) if cs_preds is not None else 0
    #         # model_output.zj_mse = mean_squared_error(zj_preds, zj_trues) if zj_preds is not None else 0
    #         # model_output.zj_mae = mean_absolute_error(zj_preds, zj_trues) if zj_preds is not None else 0
    #     torch.save(preds, f'./result/PPT_preds_{self.name}.pt')
    #     torch.save(trues, f"./result/PPT_trues_{self.name}.pt")
    #     if self.net.head_type in ["imputation", "prediction"]:
    #         print(f'\tTest MSE     : {model_output.mse:2.4f}\n'
    #                      f'\tTest MAE     : {model_output.mae:2.4f}\n'
    #                      f'\tTest ZB_MSE     : {model_output.zb_mse:2.4f}\n'
    #                      f'\tTest ZB_MAE     : {model_output.zb_mae:2.4f}\n'
    #                      f'\tTest ZB_MAPE     : {model_output.zb_mape:2.4f}\n'
    #                      f'\tTest GB_MSE     : {model_output.gb_mse:2.4f}\n'
    #                      f'\tTest GB_MAE     : {model_output.gb_mae:2.4f}\n'
    #                      f'\tTest GB_MAPE     : {model_output.gb_mape:2.4f}\n'
    #                     #  f'\tTest HY_MSE     : {model_output.hy_mse:2.4f}\n'
    #                     #  f'\tTest HY_MAE     : {model_output.hy_mae:2.4f}\n'
    #                      f'\tTest QY_MSE     : {model_output.qy_mse:2.4f}\n'
    #                      f'\tTest QY_MAE     : {model_output.qy_mae:2.4f}\n'
    #                      f'\tTest QY_MAPE     : {model_output.qy_mape:2.4f}\n'
    #                      f'\tTest CS_MSE     : {model_output.cs_mse:2.4f}\n'
    #                      f'\tTest CS_MAE     : {model_output.cs_mae:2.4f}\n'
    #                      f'\tTest CS_MAPE     : {model_output.cs_mape:2.4f}\n'
    #                     #  f'\tTest ZJ_MSE     : {model_output.zj_mse:2.4f}\n'
    #                     #  f'\tTest ZJ_MAE     : {model_output.zj_mae:2.4f}\n'
    #                      )
    #     else:
    #         print(f'\tTest Acc     : {model_output.acc:2.4f}\n'
    #                      f'\tTest Pre     : {model_output.pre:2.4f}\n'
    #                      f'\tTest Rec     : {model_output.rec:2.4f}\n'
    #                      f'\tTest F1     : {model_output.f1:2.4f}\n'
    #                      f'\tTest AUROC     : {model_output.auroc:2.4f}\n'
    #                      )



    def freeze(self):
        """
        Make the head the only trainable part of the network.

        The network is unwrapped with ``get_model`` first. If the unwrapped
        model has a ``head`` attribute, every parameter is switched to
        ``requires_grad = False`` and then the head's parameters are switched
        back to ``requires_grad = True``. If there is no ``head`` attribute
        the method does nothing.

        NOTE(review): this has the same name as ``LightningModule.freeze`` and
        so presumably overrides it - confirm that callers expect this
        head-only behaviour rather than Lightning's built-in one.
        """
        model = get_model(self.net)
        if not hasattr(model, 'head'):
            # Nothing to fine-tune separately: leave all flags untouched.
            return

        # First turn gradients off everywhere ...
        for weight in model.parameters():
            weight.requires_grad = False
        # ... then turn them back on for the head only.
        for weight in model.head.parameters():
            weight.requires_grad = True

    def unfreeze(self):
        """
        Make every parameter of the network trainable again.

        The network is unwrapped with ``get_model`` and ``requires_grad`` is
        set to ``True`` on all of its parameters.
        """
        model = get_model(self.net)
        for weight in model.parameters():
            weight.requires_grad = True

    def configure_optimizers(self) -> Dict[str, Any]:
        """Build the optimizer with per-module learning rates and, optionally, an LR scheduler.

        The network's parameters are split into three groups:

        * ``backbone``: learning rate ``1e-6``
        * ``relational_gcn_layers``: learning rate ``5e-5``
        * ``head``: learning rate ``5e-5``

        Every group sets ``lr`` explicitly, so a learning rate pre-configured on
        ``self.hparams.optimizer`` is expected to act only as an unused default
        (see the torch.optim per-parameter options documentation).

        The sub-modules are looked up on ``get_model(self.net)`` so that the
        lookup also works when ``self.net`` is wrapped in ``DataParallel`` or
        ``DistributedDataParallel``, consistent with ``freeze``/``unfreeze``.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict with key ``"optimizer"`` and, when ``self.hparams.scheduler`` is not
            ``None``, an additional ``"lr_scheduler"`` config that monitors ``"val/loss"``
            and steps once per epoch.
        """
        backbone_lr = 1e-6
        finetune_lr = 5e-5

        # Unwrap (Distributed)DataParallel so the attribute access below does not fail.
        model = get_model(self.net)
        param_groups = [
            {'params': model.backbone.parameters(), 'lr': backbone_lr},
            {'params': model.relational_gcn_layers.parameters(), 'lr': finetune_lr},
            {'params': model.head.parameters(), 'lr': finetune_lr},
        ]
        # `self.hparams.optimizer` is called as a factory taking the parameter groups.
        optimizer = self.hparams.optimizer(param_groups)

        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }


# Placeholder entry point: executing this module as a script does nothing.
if __name__ == "__main__":
    ...
