import torch.optim as optim
import torch.nn.functional as F
from copy import deepcopy
from deeprobust.graph import utils
import torch.nn as nn
import torch
from torch import Tensor
from torch_geometric.utils import dropout_adj


class BaseModel(nn.Module):
    """Shared inductive training / inference helpers for the graph models.

    Subclasses are expected to supply everything this class reads but does not
    define: ``forward(x, edge_index[, edge_weight])``, ``initialize()``, and the
    attributes ``name``, ``lr``, ``weight_decay``, ``device``, ``args`` (with a
    string ``args.dataset``) and ``eval_func(y, out)``. ``dropedge`` is optional.
    """

    # Datasets whose validation metric is computed on every node of each graph.
    _FULL_GRAPH_EVAL_DATASETS = ('fb100', 'ogb-arxiv', 'ogb-products', 'cora', 'amazon-photo', 'twitch-e')
    # Datasets whose loss and metric only use the nodes selected by ``graph.mask``.
    _MASKED_DATASETS = ('elliptic',)

    def __init__(self):
        super(BaseModel, self).__init__()

    def fit_inductive(self, data, train_iters=1000, initialize=True, verbose=False, patience=100, **kwargs):
        """Store the data splits and train with early stopping on validation.

        Args:
            data: indexable whose items 0, 1 and 2 become ``train_data``,
                ``val_data`` and ``test_data``.
            train_iters: maximum number of training epochs.
            initialize: call ``self.initialize()`` first when true.
            verbose: print progress messages.
            patience: epochs without validation improvement before stopping.
            **kwargs: accepted and ignored.
        """
        if initialize:
            self.initialize()

        self.train_data = data[0]
        self.val_data = data[1]
        self.test_data = data[2]
        # By default, it is trained with early stopping on validation
        self.train_with_early_stopping(train_iters, patience, verbose)

    def train_with_early_stopping(self, train_iters, patience, verbose):
        """Train with AdamW and restore the weights with the best validation score.

        Each epoch takes one optimizer step on the loss averaged over all
        training graphs, then scores the validation graphs with
        ``self.eval_func``. Training stops after ``patience`` consecutive
        epochs without a strict improvement. If at least one epoch improved on
        ``-inf``, the best epoch's weights are loaded back. Otherwise (e.g. a
        NaN metric, or ``train_iters == 0``) the current weights are kept.

        Raises:
            NotImplementedError: if ``self.args.dataset`` has no validation rule.
        """
        if verbose:
            print(f'--- training {self.name} model ---')
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        train_data, val_data = self.train_data, self.val_data
        masked = self.args.dataset in self._MASKED_DATASETS
        early_stopping = patience
        best_acc_val = float('-inf')
        # Stays None if no epoch ever beats -inf; without this guard the final
        # load_state_dict raised UnboundLocalError.
        weights = None
        i = 0  # keeps the final log line valid when train_iters == 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            loss_train = 0
            for dat in train_data:
                x, y = dat.graph['node_feat'].to(self.device), dat.label.to(self.device)
                edge_index = dat.graph['edge_index'].to(self.device)
                if getattr(self, 'dropedge', 0) != 0:
                    edge_index, _ = dropout_adj(edge_index, p=self.dropedge)
                output = self.forward(x, edge_index)
                if masked:
                    loss_train += self.sup_loss(y[dat.mask], output[dat.mask])
                else:
                    loss_train += self.sup_loss(y, output)
            loss_train = loss_train / len(train_data)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            acc_val = self._evaluate_val(val_data)
            if best_acc_val < acc_val:
                best_acc_val = acc_val
                weights = deepcopy(self.state_dict())
                early_stopping = patience
            else:
                early_stopping -= 1
            if early_stopping <= 0:
                break

        if verbose:
            print('--- early stopping at {0}, Best acc_val = {1} ---'.format(i, best_acc_val))
        if weights is not None:
            self.load_state_dict(weights)

    @torch.no_grad()
    def _evaluate_val(self, graphs):
        """Score all ``graphs`` jointly with ``self.eval_func`` (eval mode, no autograd).

        Labels and outputs of every graph are concatenated along dim 0. For
        masked datasets, only the nodes selected by ``graph.mask`` are kept.

        Raises:
            NotImplementedError: if ``self.args.dataset`` has no validation rule.
        """
        self.eval()
        dataset = self.args.dataset
        if dataset in self._FULL_GRAPH_EVAL_DATASETS:
            masked = False
        elif dataset in self._MASKED_DATASETS:
            masked = True
        else:
            raise NotImplementedError
        y_all, out_all = [], []
        for graph in graphs:
            x_val = graph.graph['node_feat'].to(self.device)
            edge_index_val = graph.graph['edge_index'].to(self.device)
            out = self.forward(x_val, edge_index_val)
            y = graph.label
            if masked:
                y, out = y[graph.mask], out[graph.mask]
            y_all.append(y.to(self.device))
            out_all.append(out)
        return self.eval_func(torch.cat(y_all, dim=0), torch.cat(out_all, dim=0))

    def sup_loss(self, y, pred):
        """Supervised loss of logits ``pred`` against labels ``y``.

        For 'twitch-e', 'fb100' and 'elliptic' this is ``BCEWithLogitsLoss``.
        Single-column integer labels are one-hot encoded to ``pred``'s width.
        Multi-column labels, or labels for a single-logit ``pred``, are used
        directly as float targets. For other datasets it is cross-entropy
        (``log_softmax`` + ``NLLLoss``) against integer class indices.

        Args:
            y: labels, shaped ``(N,)`` or ``(N, 1)`` class indices, or ``(N, C)``
                targets for the BCE datasets.
            pred: ``(N, C)`` logits.

        Returns:
            Scalar loss tensor.
        """
        if y.dim() == 1:
            # Accept flat label vectors as well as the (N, 1) column layout.
            y = y.unsqueeze(1)
        if self.args.dataset in ('twitch-e', 'fb100', 'elliptic'):
            if y.shape[1] == 1 and pred.shape[1] > 1:
                # Size the one-hot from the logits rather than y.max() + 1. A batch
                # (or mask subset) lacking the highest class would otherwise yield
                # a target narrower than pred.
                true_label = F.one_hot(y.squeeze(1), pred.shape[1])
            else:
                # Multi-column targets, or a single-logit binary head: use y as-is.
                true_label = y
            criterion = nn.BCEWithLogitsLoss()
            loss = criterion(pred, true_label.to(torch.float))
        else:
            out = F.log_softmax(pred, dim=1)
            target = y.squeeze(1)
            criterion = nn.NLLLoss()
            loss = criterion(out, target)
        return loss

    @torch.no_grad()
    def predict(self, x=None, edge_index=None, edge_weight=None):
        """Run ``forward`` in eval mode without autograd and return its output.

        If either ``x`` or ``edge_index`` is omitted, both are taken from
        ``self.test_data.graph`` and moved to ``self.device``. Any ``x`` or
        ``edge_index`` that was passed is then ignored. ``edge_weight`` is
        forwarded as the third positional argument of ``forward``.
        """
        self.eval()
        if x is None or edge_index is None:
            # NOTE(review): fit_inductive stores data[2] as test_data, while
            # val_data (data[1]) is iterated as a collection of graphs. If data[2]
            # is likewise a list, `.graph` fails here — confirm against callers.
            x, edge_index = self.test_data.graph['node_feat'], self.test_data.graph['edge_index']
            x, edge_index = x.to(self.device), edge_index.to(self.device)
        return self.forward(x, edge_index, edge_weight)

    def _ensure_contiguousness(self,
                               x,
                               edge_idx,
                               edge_weight):
        """Return ``(x, edge_idx, edge_weight)`` with contiguous memory where applicable.

        ``x`` is left untouched when sparse. ``edge_idx`` is only converted if it
        has a ``contiguous`` method. ``edge_weight`` may be ``None``.
        """
        if not x.is_sparse:
            x = x.contiguous()
        if hasattr(edge_idx, 'contiguous'):
            edge_idx = edge_idx.contiguous()
        if edge_weight is not None:
            edge_weight = edge_weight.contiguous()
        return x, edge_idx, edge_weight


class BatchNorm1d_plus(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` plus ``forward_plus``, which can blend running and batch statistics."""

    def forward_plus(self, x: Tensor, weights=None, debug=0) -> Tensor:
        """Normalize ``x`` with batch (or blended) statistics and apply the affine map.

        Statistics are the mean and biased variance of ``x`` over dim 0. Unlike
        ``forward``, this method never updates ``running_mean``/``running_var``.

        Args:
            x: input whose dim 1 indexes the features/channels.
            weights: ``None`` to use the batch statistics alone. Otherwise the
                mixing coefficient ``stat = running * (1 - weights) + batch * weights``,
                given as a scalar or anything broadcastable against the statistics.
            debug: unused; kept for API compatibility.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        batch_mean = torch.mean(x, 0)
        batch_var = torch.var(x, 0, unbiased=False)
        if weights is None:
            mean, var = batch_mean, batch_var
        else:
            mean = self.running_mean * (1 - weights) + batch_mean * weights
            var = self.running_var * (1 - weights) + batch_var * weights
        # One broadcast over all features instead of a Python loop per column.
        y = (x - mean) / torch.sqrt(var + self.eps)
        if self.weight is not None:  # affine=False leaves weight/bias as None
            # Per-channel (dim 1) scale/shift; trailing singleton dims keep this
            # aligned with dim 1 for inputs with more than two dims.
            channel_shape = (-1,) + (1,) * (x.dim() - 2)
            y = y * self.weight.reshape(channel_shape) + self.bias.reshape(channel_shape)
        # The original wrote into x.clone(), so the result kept x's dtype.
        return y.to(x.dtype)

