import torch
import torch.nn as nn
from torch_geometric.utils import degree
import numpy as np
from utils import graph_visualize

class DIR(nn.Module):
    """Wrap a classifier and an edge scorer to train with a causal/non-causal edge split.

    Each graph's edges are scored by ``extractor``; the top ``causal_ratio``
    fraction forms the "causal" subgraph and the rest the "conf" subgraph.
    Both subgraphs are passed through ``clf`` and the resulting losses are
    combined in :meth:`forward_pass`.

    Args:
        clf: Module providing ``get_emb(data)``,
            ``forward_pass(x, pos, edge_attr, edge_index, batch, edge_weight, with_enc=...)``,
            ``clf_out(rep)`` and ``conf_out(rep)``.
        extractor: Callable applied to the concatenated endpoint embeddings of
            every edge; its output is flattened to one score per edge.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        config: Mapping with keys ``'alpha'`` and ``'causal_ratio'``.
    """

    def __init__(self, clf, extractor, criterion, config):
        super().__init__()
        self.clf = clf
        self.extractor = extractor
        self.criterion = criterion
        # Snapshot of the device at construction time. It is NOT refreshed by a
        # later ``.to(...)``, so internal code derives devices from live tensors
        # instead. Falls back to CPU when the wrapped modules have no parameters.
        first_param = next(self.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device('cpu')
        self.alpha = config['alpha']
        self.with_optim = True
        self.ratio = config['causal_ratio']

    def forward_pass(self, data, epoch, phase):
        """Compute the losses for one batch.

        Args:
            data: Batch object; ``y``, ``edge_index``, ``edge_attr``, ``batch``
                and ``pos`` are read (directly or via ``rationale_generator``).
            epoch: Current epoch; scales the DIR term as ``alpha * epoch ** 1.6``.
            phase: ``'train'`` enables the full objective; any other value
                returns only the causal prediction loss.

        Returns:
            When ``phase != 'train'``:
                ``(causal_loss, {'pred': float}, causal_out)``.
            When ``phase == 'train'``:
                ``((conf_loss, batch_loss), loss_dict, causal_out)`` where
                ``batch_loss = causal_loss + alpha_prime * DIR_loss`` and
                ``loss_dict`` has keys ``'conf_loss'``, ``'pred'``, ``'DIR_loss'``.
        """
        alpha_prime = self.alpha * (epoch ** 1.6)

        # Split every graph into its causal and non-causal (conf) parts.
        (causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch, causal_pos), \
        (conf_x, conf_edge_index, conf_edge_attr, conf_edge_weight, conf_batch, conf_pos), \
            _ = self.rationale_generator(data)

        causal_rep = self.clf.forward_pass(
            causal_x, causal_pos, causal_edge_attr, causal_edge_index, causal_batch, causal_edge_weight,
            with_enc=False)
        causal_out = self.clf.clf_out(causal_rep)

        # The conf representation is detached, so gradients from the conf branch
        # stop at ``clf.conf_out`` and do not reach the shared layers.
        conf_rep = self.clf.forward_pass(
            conf_x, conf_pos, conf_edge_attr, conf_edge_index, conf_batch, conf_edge_weight,
            with_enc=False).detach()
        conf_out = self.clf.conf_out(conf_rep)

        # NaN != NaN, so this mask is False exactly where the label is NaN.
        is_labeled = data.y == data.y
        target = data.y.to(torch.float32)[is_labeled]

        causal_loss = self.criterion(causal_out.to(torch.float32)[is_labeled], target)
        if phase != 'train':
            return causal_loss, {'pred': causal_loss.item()}, causal_out

        conf_loss = self.criterion(conf_out.to(torch.float32)[is_labeled], target)

        # One loss per row of ``conf_out``: the causal predictions are gated by
        # that single conf prediction. The empty seed lives on the same device
        # as the predictions (not on the possibly stale ``self.device``).
        env_terms = [torch.tensor([], device=causal_out.device)]
        for conf in conf_out:
            rep_out = self.get_comb_pred(causal_out, conf)
            tmp = self.criterion(rep_out.to(torch.float32)[is_labeled], target)
            env_terms.append(tmp.unsqueeze(0))
        env_loss = torch.cat(env_terms)

        DIR_loss = (env_loss.mean() + torch.var(env_loss * conf_rep.size(0)))
        batch_loss = causal_loss + alpha_prime * DIR_loss
        # The caller is expected to optimise batch_loss and conf_loss.
        loss_dict = {'conf_loss': conf_loss.item(), 'pred': causal_loss.item(), 'DIR_loss': DIR_loss.item()}
        return (conf_loss, batch_loss), loss_dict, causal_out

    def get_comb_pred(self, causal_pred, conf_pred):
        """Return ``sigmoid(conf_pred) * causal_pred`` with ``conf_pred`` detached."""
        conf_pred_tmp = conf_pred.detach()
        return torch.sigmoid(conf_pred_tmp) * causal_pred

    @staticmethod
    def _split_by_score(scores, n_reserve):
        """Split edge positions into the ``n_reserve`` highest scores and the rest.

        Args:
            scores: 1-D numpy array with one score per edge.
            n_reserve: Number of edges to keep.

        Returns:
            ``(idx_reserve, idx_drop)`` as numpy index arrays. The order inside
            each group is unspecified (``np.argpartition`` semantics).
        """
        n_edges = scores.shape[0]
        if n_reserve >= n_edges:
            # np.argpartition rejects kth == len(scores); this covers graphs
            # without edges and a ratio that keeps every edge.
            everything = np.arange(n_edges, dtype=np.intp)
            return everything, everything[n_edges:]
        rank = np.argpartition(-scores, n_reserve)
        return rank[:n_reserve], rank[n_reserve:]

    def rationale_generator(self, data):
        """Score all edges and split each graph into causal and conf subgraphs.

        Args:
            data: Batch object; ``edge_index``, ``edge_attr``, ``batch`` and
                ``pos`` are read, and it is passed to ``clf.get_emb`` and
                ``split_batch``.

        Returns:
            A 3-tuple ``(causal, conf, pred_edge_weight)`` where ``causal`` and
            ``conf`` are each ``(x, edge_index, edge_attr, edge_weight, batch, pos)``
            for the relabelled subgraph and ``pred_edge_weight`` is the flat
            tensor of raw edge scores. Conf edge weights are the negated scores.
        """
        x, _ = self.clf.get_emb(data)

        # Score every edge from the embeddings of its two endpoints.
        row, col = data.edge_index
        edge_rep = torch.cat([x[row], x[col]], dim=-1)
        pred_edge_weight = self.extractor(edge_rep).view(-1)

        device = x.device
        # Every list starts with the same empty seed the tensors used to start
        # from, so dtype promotion and the empty-batch result are unchanged;
        # concatenating once avoids re-copying the accumulator on each graph.
        causal_index_parts = [torch.LongTensor([[], []]).to(device)]
        causal_weight_parts = [torch.tensor([]).to(device)]
        causal_attr_parts = [torch.tensor([]).to(device)]
        conf_index_parts = [torch.LongTensor([[], []]).to(device)]
        conf_weight_parts = [torch.tensor([]).to(device)]
        conf_attr_parts = [torch.tensor([]).to(device)]

        edge_indices, _, _, num_edges, cum_edges = split_batch(data)

        for edge_index, N, C in zip(edge_indices, num_edges, cum_edges):
            n_reserve = int(self.ratio * N)
            edge_attr = data.edge_attr[C:C + N]
            single_mask = pred_edge_weight[C:C + N]
            single_mask_detach = single_mask.detach().cpu().numpy()

            # idx_reserve: causal edges; idx_drop: non-causal edges.
            idx_reserve, idx_drop = self._split_by_score(single_mask_detach, n_reserve)
            idx_reserve = torch.as_tensor(idx_reserve, dtype=torch.long, device=device)
            idx_drop = torch.as_tensor(idx_drop, dtype=torch.long, device=device)

            causal_index_parts.append(edge_index[:, idx_reserve])
            conf_index_parts.append(edge_index[:, idx_drop])

            causal_weight_parts.append(single_mask[idx_reserve])
            # NOTE: the non-causal weights are the negated scores.
            conf_weight_parts.append(-1 * single_mask[idx_drop])
            causal_attr_parts.append(edge_attr[idx_reserve])
            conf_attr_parts.append(edge_attr[idx_drop])

        causal_edge_index = torch.cat(causal_index_parts, dim=1)
        conf_edge_index = torch.cat(conf_index_parts, dim=1)
        causal_edge_weight = torch.cat(causal_weight_parts)
        conf_edge_weight = torch.cat(conf_weight_parts)
        causal_edge_attr = torch.cat(causal_attr_parts)
        conf_edge_attr = torch.cat(conf_attr_parts)

        causal_x, causal_edge_index, causal_batch, causal_pos = relabel(x, causal_edge_index, data.batch, data.pos)
        conf_x, conf_edge_index, conf_batch, conf_pos = relabel(x, conf_edge_index, data.batch, data.pos)

        return (causal_x, causal_edge_index, causal_edge_attr, causal_edge_weight, causal_batch, causal_pos), \
               (conf_x, conf_edge_index, conf_edge_attr, conf_edge_weight, conf_batch, conf_pos), \
               pred_edge_weight


def relabel(x, edge_index, batch, pos=None):
    """Restrict node data to the nodes used by ``edge_index`` and renumber them.

    Args:
        x: Per-node tensor indexed along dim 0 by node id.
        edge_index: Integer tensor unpackable into two rows of node ids.
        batch: Per-node tensor indexed by node id.
        pos: Optional per-node tensor indexed by node id.

    Returns:
        ``(x, edge_index, batch, pos)`` restricted to the sorted unique node
        ids that occur in ``edge_index``; the returned ``edge_index`` refers
        to positions in that restricted ordering. ``pos`` stays ``None`` when
        it was not given.
    """
    total_nodes = x.size(0)
    kept = torch.unique(edge_index)

    kept_x = x[kept]
    kept_batch = batch[kept]
    kept_pos = None if pos is None else pos[kept]

    src, _ = edge_index
    # Old id -> new id lookup table. Ids that are absent from the subgraph stay
    # at -1, but they are never looked up because every id in edge_index is in
    # ``kept``.
    old_to_new = src.new_full((total_nodes,), -1)
    old_to_new[kept] = torch.arange(kept.size(0), device=src.device)
    remapped = old_to_new[edge_index]

    return kept_x, remapped, kept_batch, kept_pos


def split_batch(g):
    """Split a batched graph's ``edge_index`` into per-graph chunks and sizes.

    Args:
        g: Batch object; ``edge_index``, ``batch`` and ``x`` are read.

    Returns:
        ``(edge_indices, num_nodes, cum_nodes, num_edges, cum_edges)``:
        a tuple of per-graph ``edge_index`` chunks, the node count per graph,
        the node offset of each graph, the edge count per chunk, and the edge
        offset of each chunk.

    Note:
        Splitting into consecutive column chunks assumes the edges of each
        graph are stored contiguously and in graph order -- TODO confirm
        against the data loader. The edge-based outputs cover graph ids up to
        the last graph that has an edge, so they may be shorter than
        ``num_nodes`` -- presumably, per ``degree``'s default sizing; verify.
    """
    # Edges per graph: look up the graph id of every edge's source node and
    # count occurrences of each id.
    split = degree(g.batch[g.edge_index[0]], dtype=torch.long).tolist()
    edge_indices = torch.split(g.edge_index, split, dim=1)

    num_nodes = degree(g.batch, dtype=torch.long)
    # Exclusive prefix sum: the offset of each graph's first node.
    cum_nodes = torch.cat([g.batch.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]])

    # The chunk widths are exactly the split sizes computed above.
    num_edges = torch.tensor(split, dtype=torch.long).to(g.x.device)
    cum_edges = torch.cat([g.batch.new_zeros(1), num_edges.cumsum(dim=0)[:-1]])
    return edge_indices, num_nodes, cum_nodes, num_edges, cum_edges
