import warnings
from typing import Any

import pandas as pd
import subprocess
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import STEP_OUTPUT

from .data_util import get_data
from .models.gnn_model import get_gnn_
from .models.prediction_model import MLPNet
from ...base import *
from ...utils import get_kwargs
import os


class GRAPE(BaseImputerMixIn, Base):
    """Graph-based imputer: a GNN embeds the nodes and an MLP predicts edge values.

    ``self.model`` (built by ``get_gnn_``) turns node features plus the observed edges into node
    embeddings; ``self.impute_model`` (an ``MLPNet``) maps the concatenated embeddings of an edge's
    two end nodes to a single predicted value. Missing cells of a table are filled with the values
    predicted for the corresponding "test" edges produced by ``get_data``.
    """

    def __init__(self, lr, args, **kwargs):
        """Build the GNN encoder and the edge-value prediction head.

        Args:
            lr: Learning rate handed to the Adam optimizer in ``configure_optimizers``.
            args: Hyper-parameter container. Attributes read by this class: ``node_mode``,
                ``node_dim``, ``impute_hiddens`` (underscore-separated layer sizes, e.g. ``"64_32"``),
                ``impute_activation``, ``dropout`` and ``known``.
            **kwargs: Filtered through ``get_kwargs`` and forwarded to the base-class initializer.
        """
        super().__init__(**get_kwargs(**kwargs))
        self.scaler = 'minmax'
        self.lr = lr
        self.args = args
        self.allow_missing_on_train = True
        # NOTE(review): ``self.column_dim`` is not defined here; presumably set by the base class.
        # With ``node_mode != 0`` the node feature width is one larger than the column count.
        node_input_dim = self.column_dim if args.node_mode == 0 else self.column_dim + 1
        self.model = get_gnn_(node_input_dim, 1, self.args)
        # The head sees two node embeddings side by side, hence ``node_dim * 2`` inputs.
        hidden_sizes = [int(size) for size in args.impute_hiddens.split('_')]
        self.impute_model = MLPNet(args.node_dim * 2, 1,
                                   hidden_layer_sizes=hidden_sizes,
                                   hidden_activation=args.impute_activation,
                                   dropout=args.dropout)

    def fit(self, scenario=lambda x: x):
        """Load the training CSV, build the training graph and delegate to the base ``fit``.

        Args:
            scenario: Callable applied to the freshly loaded training frame before anything else
                (identity by default). It is also passed on to ``super().fit``.

        Returns:
            Whatever the base-class ``fit`` returns.
        """
        cfg = self._cfg
        args = self.args
        df = scenario(pd.read_csv(cfg.dataset.train_path))
        # NOTE(review): the transform is fitted via ``self._transform`` but applied via
        # ``self.tabular_transform`` — presumably the same object exposed by the base class; confirm.
        self._transform.fit(df)
        # Shuffle all rows reproducibly (sampling len(df) rows without replacement).
        df = df.sample(len(df), random_state=cfg.seed)
        converted = self.tabular_transform.transform(df, return_as_tensor=True)
        self.data = get_data(converted, args.node_mode)
        return super().fit(scenario)

    def configure_optimizers(self):
        """Create, store and return an Adam optimizer over all parameters (no weight decay)."""
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=0)
        return self.optimizer

    def on_train_start(self) -> None:
        """Move the training graph to the module's device and dtype before the first step."""
        self.data = self.data.to(self.device)
        self.data.train_edge_attr = self.data.train_edge_attr.to(self.dtype)
        # Keep the node features as a separate detached tensor and drop them from ``self.data``.
        self.x = self.data.x.clone().detach().to(self.dtype)
        del self.data.x

    def training_step(self, *args: Any, **kwargs: Any):
        """Run one full-graph training step and return the MSE loss.

        A random subset of the training edges is kept as GNN input (each edge is kept with
        probability ``self.args.known``); predictions are then made for *all* training edges and
        compared against ``train_labels``. Positional and keyword arguments are ignored.
        """
        data = self.data
        train_edge_index = data.train_edge_index
        train_edge_attr = data.train_edge_attr

        # The edge tensors are treated as two equally sized halves that share one mask
        # (presumably each edge is stored once per direction — confirm against ``get_data``).
        n_edges = train_edge_attr.shape[0] // 2
        known_mask = torch.rand(n_edges, device=self.device) < self.args.known
        double_known_mask = torch.cat((known_mask, known_mask), dim=0)
        known_edge_index = train_edge_index[:, double_known_mask]
        known_edge_attr = train_edge_attr[double_known_mask]

        x_embd = self.model(self.x, known_edge_attr, known_edge_index)
        pred = self.impute_model([x_embd[train_edge_index[0]], x_embd[train_edge_index[1]]])
        # Only the first half of the predictions is scored against the labels.
        pred_train = pred[:n_edges, 0]
        loss = F.mse_loss(pred_train, data.train_labels)
        self.log('t.loss', loss, prog_bar=True, on_epoch=True)
        return loss

    @torch.no_grad()
    def impute(self, df, seed=None, **kwargs):
        """Fill the missing cells of ``df`` with model predictions.

        Args:
            df: Frame to impute; it is converted with ``self.tabular_transform``.
            seed: Accepted for interface compatibility; not used by this implementation.
            **kwargs: Accepted for interface compatibility; not used by this implementation.

        Returns:
            The result of ``self.tabular_transform.inverse_transform`` applied to the transformed
            data after its NaN entries have been overwritten with predictions.
        """
        # NOTE(review): ``save_model_mode`` comes from the base class; nothing in this method
        # restores the previous train/eval mode — presumably handled by the caller. Confirm.
        self.save_model_mode()
        self.eval()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            converted = self.tabular_transform.transform(df, return_as_tensor=True)
            nan_mask = converted.isnan()

            data = get_data(converted, self.args.node_mode).to(self.device)
            observed_edge_index = data.train_edge_index
            # Cast inputs to the module dtype, mirroring ``on_train_start``; without this the
            # forward pass fails whenever the module does not run in the data's default dtype.
            observed_edge_attr = data.train_edge_attr.to(self.dtype)
            x = data.x.clone().detach().to(self.dtype)
            missing_edge_index = data.test_edge_index
            n_missing = data.test_edge_attr.shape[0] // 2

            x_embd = self.model(x, observed_edge_attr, observed_edge_index)
            pred = self.impute_model([x_embd[missing_edge_index[0], :], x_embd[missing_edge_index[1], :]])
            # Masked assignment needs matching dtype/device, so align predictions with ``converted``.
            pred_test = pred[:n_missing, 0].to(device=converted.device, dtype=converted.dtype)
            # Assumes the order of the test edges matches the order of NaNs in ``converted`` —
            # this depends on ``get_data``; confirm there.
            converted[nan_mask] = pred_test
        return self.tabular_transform.inverse_transform(converted)

    def _impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Intentionally a no-op (returns ``None``); imputation is implemented in ``impute``."""
        pass

    def on_save_checkpoint(self, checkpoint):
        """Store ``self.args`` in the checkpoint dict under the ``'args'`` key."""
        checkpoint['args'] = self.args

    def on_load_checkpoint(self, checkpoint):
        """Restore ``self.args`` from the checkpoint dict's ``'args'`` entry."""
        self.args = checkpoint['args']
