import copy
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.data import Batch, Data
from torch import Tensor

from gpl.models.gin import GIN
from gpl.models.mlp import MLPClean
from gpl.models.gpl import GPLV2, gaussianKL, maskConstKL
from gpl.models.mi_bounds import critic, infonce_lower_bound, club_upper_bound_my

class AssignerMLP(nn.Module):
    """Score graph entities (nodes or edges) with an MLP.

    When ``assign_edge`` is True, each edge is described by the element-wise
    mean, max and min of its two endpoint embeddings, concatenated. The MLP
    input width is therefore ``3 * channels[0]``. Otherwise the node
    embeddings are scored directly.

    Args:
        channels: layer widths passed to ``MLPClean``; ``channels[0]`` is the
            node-embedding width. The sequence is copied and never mutated.
        dropout_p: dropout probability forwarded to ``MLPClean``.
        assign_edge: score edges instead of nodes.
    """

    def __init__(self, channels, dropout_p, assign_edge=False):
        super().__init__()
        self.assign_edge = assign_edge

        # Work on a copy. Writing into the caller's list (usually a config
        # entry) made every further construction from the same config
        # multiply the input width again (x3, x9, ...).
        channels = list(channels)
        if self.assign_edge:
            # edge feature = [mean, max, min] of the two endpoint embeddings
            channels[0] = channels[0] * 3
        self.feature_extractor = MLPClean(channels=channels, dropout=dropout_p, with_softmax=False)

    def forward(self, emb, edge_index, batch):
        """Return the ``feature_extractor`` output for every edge or node.

        Args:
            emb: node embeddings, indexed by node id along dim 0.
            edge_index: pair of index tensors (``edge_index[0]``,
                ``edge_index[1]``) giving each edge's endpoints; only used
                when ``assign_edge`` is True.
            batch: unused; kept for interface compatibility.

        Returns:
            One output row per edge (edge mode) or per node (node mode).
        """
        if self.assign_edge:
            first, second = edge_index
            # stack the two endpoints so the reductions run over dim 0
            endpoints = torch.stack([emb[first], emb[second]], dim=0)
            entity_feature = torch.cat(
                [
                    endpoints.mean(dim=0),
                    endpoints.max(dim=0).values,
                    endpoints.min(dim=0).values,
                ],
                dim=-1,
            )
        else:
            entity_feature = emb
        return self.feature_extractor(entity_feature)


class MaskSmoothLayer(nn.Module):
    """Blend a node or edge mask with a neighbourhood-averaged copy of itself.

    The output is ``(1 - gamma) * mask + gamma * smoothed_mask``.
    """

    def __init__(self, gamma):
        super().__init__()
        # weight given to the smoothed mask in the blend
        self.gamma = gamma

    def forward(self, mask, edge_index, assign_edge):
        """Smooth ``mask`` over the graph given by ``edge_index``.

        Args:
            mask: one value per edge (``assign_edge=True``) or per node,
                shaped ``[K, 1]`` or ``[K]``. The result has the same shape.
            edge_index: pair of index tensors giving each edge's endpoints.
            assign_edge: True if ``mask`` is an edge mask, False if it is a
                node mask.

        Returns:
            The blended mask, with the same shape as ``mask``.
        """
        flat = mask.reshape(-1)
        first, second = edge_index[0], edge_index[1]

        if assign_edge:
            # Edge -> node: each node takes the mean mask of the edges whose
            # edge_index[0] is that node. dim_size covers every id used
            # below, so indexing by edge_index[1] cannot go out of range.
            # Nodes with no such edge get scatter's empty-segment value.
            num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
            node_avg = scatter(flat, index=first, dim=0, dim_size=num_nodes, reduce='mean')
            # Node -> edge: average the two endpoint values.
            smoothed = (node_avg[first] + node_avg[second]) / 2
        else:
            # Node -> edge: average the two endpoint values.
            edge_avg = (flat[first] + flat[second]) / 2
            # Edge -> node: mean over the edges whose edge_index[0] is the
            # node. dim_size keeps one entry per node, so the result lines up
            # with `mask` even when trailing nodes have no such edge.
            smoothed = scatter(edge_avg, index=first, dim=0, dim_size=flat.size(0), reduce='mean')

        # Match the input layout, avoiding accidental [K] + [K, 1] broadcasting.
        smoothed = smoothed.reshape(mask.shape)
        return (1 - self.gamma) * mask + self.gamma * smoothed




def check_nan(x):
    """Enter the debugger if tensor ``x`` contains any NaN.

    Uses the built-in ``breakpoint()`` (PEP 553) rather than importing
    ``ipdb`` directly, so a missing ``ipdb`` install no longer turns a NaN
    into a ModuleNotFoundError. Set ``PYTHONBREAKPOINT=ipdb.set_trace`` to
    keep using ipdb, or ``PYTHONBREAKPOINT=0`` to disable the stop.

    Returns:
        True if ``x`` contains a NaN, otherwise False.
    """
    has_nan = bool(torch.isnan(x).any())
    if has_nan:
        breakpoint()
    return has_nan

class ConditionalProb(nn.Module):
    """q_theta head: Gaussian parameters for a full graph and its masked version.

    The same ``encoder`` embeds each graph twice: once unmasked (``x``) and
    once with ``edge_mask`` / ``node_mask`` applied (``y``). Node embeddings
    are sum-pooled per graph, then passed through the ``mean`` and ``std``
    MLP heads.

    Args:
        encoder: module exposing ``get_emb(x=..., edge_index=..., batch=...,
            edge_attr=..., [edge_atten=..., node_mask=...])``.
        mean_encoder_channels: layer widths shared by both heads.
        dropout_p: dropout probability for both heads.
    """

    def __init__(self, encoder, mean_encoder_channels, dropout_p):
        super().__init__()
        self.encoder = encoder
        self.mean = MLPClean(mean_encoder_channels, dropout=dropout_p, with_softmax=False)
        self.std = MLPClean(mean_encoder_channels, dropout=dropout_p, with_softmax=False)

    def forward(self, data, node_mask, edge_mask):
        """Compute the parameters used to build q_theta(g_y | g_x).

        Args:
            data: batched graph exposing ``x``, ``edge_index``, ``batch``,
                optionally ``edge_attr``, and ``.to(device)``. Presumably a
                PyG ``Batch`` — confirm with callers.
            node_mask: forwarded to ``encoder.get_emb`` as ``node_mask`` for
                the masked pass.
            edge_mask: forwarded as ``edge_atten`` for the masked pass; its
                device decides where ``data`` is moved.

        Returns:
            ``(mu_x, sigma_x, mu_y, sigma_y)``, one row per graph. ``mu`` rows
            are L2-normalised along dim 1; ``sigma`` values lie in
            ``(1e-6, 1 + 1e-6)``.
        """
        device = edge_mask.device
        data = data.to(device)
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None

        # node-level embeddings of the full graph (x) and the masked graph (y)
        repn_x = self.encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr)
        repn_y = self.encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr, edge_atten=edge_mask, node_mask=node_mask)

        # graph-level representations (sum pooling)
        rep_x = scatter(repn_x, batch, dim=0, reduce='sum')
        rep_y = scatter(repn_y, batch, dim=0, reduce='sum')

        # F.normalize clamps the norm at eps, so an all-zero row stays zero
        # instead of becoming NaN (0 / 0) as with a plain division by the norm.
        mu_x = F.normalize(self.mean(rep_x), p=2.0, dim=1)
        mu_y = F.normalize(self.mean(rep_y), p=2.0, dim=1)

        # Sigmoid keeps sigma in (0, 1); the 1e-6 offset avoids inf when
        # dividing by sigma later.
        sigma_x = self.std(rep_x).sigmoid() + 1e-6
        sigma_y = self.std(rep_y).sigmoid() + 1e-6

        return mu_x, sigma_x, mu_y, sigma_y


def ib_coeff_uppb_scheduler_callback(**kwargs):
    """Epoch callback that warms up ``model.ib_coeff_uppb``.

    Let ``max`` be ``all_hparams['framework']['ib_coeff_uppb']``:

    * epochs < 100: 0
    * epochs >= 100: ``max * min(1, 0.1 * (e - 100) / 5)``, where ``e`` is
      ``cur_epoch`` rounded down to a multiple of 5. The coefficient grows by
      10% of ``max`` every 5 epochs and reaches ``max`` at epoch 150.

    The value is derived from ``cur_epoch`` alone. Resuming training at any
    epoch therefore gives the same coefficient as an uninterrupted run,
    instead of relying on state set at epoch 0 or at a multiple of 5.

    Expected kwargs: ``model``, ``cur_epoch``, ``logger``, ``all_hparams``.
    """
    model = kwargs['model']
    cur_epoch = kwargs['cur_epoch']
    logger = kwargs['logger']
    all_hparams = kwargs['all_hparams']

    max_ib_coeff_uppb = all_hparams['framework']['ib_coeff_uppb']
    if cur_epoch < 100:
        model.ib_coeff_uppb = 0
    else:
        # Latest epoch at which the step-wise schedule updates (a multiple
        # of 5). The formula mirrors the original to keep identical floats.
        last_update_epoch = cur_epoch - cur_epoch % 5
        model.ib_coeff_uppb = max_ib_coeff_uppb * min(1, 0.1*(last_update_epoch-100)/5)

    logger.info(f'model.ib_coeff_uppb={model.ib_coeff_uppb}')
    

class GPLV5(GPLV2):
    """GPLV2 variant that learns an edge mask under an information-bottleneck constraint.

    Components:

    * an edge assigner scores edges;
    * a separate ``subg_encoder`` embeds the masked subgraph for
      classification;
    * a ``q_theta`` critic (``ConditionalProb``) is trained with the negated
      ``infonce_lower_bound`` of its log-probs.

    The main model can additionally be penalised by ``maskConstKL`` on the
    edge mask and/or ``club_upper_bound_my`` on q_theta's log-probs. The two
    parameter groups are updated on alternating epochs (see
    ``backward_pass``).
    """

    def __init__(self, encoder: GIN, config):
        super().__init__(encoder, config)

    def initialize(self):
        """Create the subgraph encoder, the mask assigner and the q_theta critic."""
        self.subg_encoder = copy.deepcopy(self.encoder)

        if self.assign_edge:
            self.edge_assigner = AssignerMLP(channels=self.model_config['edge_assigner_channels'], dropout_p=self.model_config['dropout_p'], assign_edge=True)
        else:
            self.node_assigner = AssignerMLP(channels=self.model_config['node_assigner_channels'], dropout_p=self.model_config['dropout_p'], assign_edge=False)

        # q_theta gets its own copy of the encoder so its updates do not
        # touch the main encoder's weights.
        self.q_theta = ConditionalProb(copy.deepcopy(self.encoder), self.model_config['mean_encoder_channels'], self.model_config['dropout_p'])

    def configure_optimizers(self):
        """Build two Adam optimizers.

        Returns:
            ``[optimizer_q, optimizer_other]``. ``optimizer_q`` owns every
            parameter whose name contains ``'q_theta'``; ``optimizer_other``
            owns all remaining parameters.
        """
        lr = self.config['training']['optimizer_params']['lr']
        l2 = self.config['training']['optimizer_params']['l2']
        # \theta of q_\theta(y|x)
        optimizer_q = torch.optim.Adam([p for n, p in self.named_parameters() if 'q_theta' in n], lr=lr, weight_decay=l2)
        # everything else
        optimizer_other = torch.optim.Adam([p for n, p in self.named_parameters() if 'q_theta' not in n], lr=lr, weight_decay=l2)

        return [optimizer_q, optimizer_other]

    def __loss__(self, clf_logits, y, node_mask, edge_mask, log_probs):
        """Compute the main-model loss plus loggable scalars.

        Args:
            clf_logits: classifier output, passed to ``self.criterion``.
            y: labels, passed to ``self.criterion``.
            node_mask: unused; kept for interface compatibility.
            edge_mask: edge mask, passed to ``maskConstKL``.
            log_probs: q_theta critic output, passed to
                ``club_upper_bound_my``.

        Returns:
            A dict containing:

            * ``'loss'``: tensor used for backward;
            * ``'pred_loss'``, ``'ib_loss'``: floats;
            * ``'eib_loss'`` / ``'ib_upper_loss'``: floats, present only when
              the corresponding constraint is enabled.
        """
        step_dict = {}
        device = edge_mask.device

        if self.pred_loss_coeff == 0.0:
            # Create it on the same device as the other loss terms.
            pred_loss = torch.tensor(0.0, device=device)
        else:
            pred_loss = self.criterion(clf_logits, y)
            pred_loss = pred_loss * self.pred_loss_coeff
        step_dict['pred_loss'] = pred_loss.item()

        # Accumulate out-of-place rather than with `+=` on a fresh leaf tensor.
        ib_const_loss = torch.tensor(0.0, device=device)
        if self.with_ib_constraint:
            if 'mask' in self.ib_constraint_type:
                ib_const_loss_mask = maskConstKL(edge_mask, self.get_r())
                ib_const_loss_mask = ib_const_loss_mask * self.ib_coeff_mask
                ib_const_loss = ib_const_loss + ib_const_loss_mask
                step_dict['eib_loss'] = ib_const_loss_mask.item()

            if 'upper_bound' in self.ib_constraint_type:
                mi_upper_bound_loss = club_upper_bound_my(log_probs)
                mi_upper_bound_loss = mi_upper_bound_loss * self.ib_coeff_uppb
                ib_const_loss = ib_const_loss + mi_upper_bound_loss
                step_dict['ib_upper_loss'] = mi_upper_bound_loss.item()

        loss = pred_loss + ib_const_loss

        step_dict['loss'] = loss
        step_dict['ib_loss'] = ib_const_loss.item()
        return step_dict

    def backward_pass(self, optimizers, step_results):
        """Alternate updates between the two parameter groups.

        Even epochs step ``optimizer_other`` on ``step_results['loss']``; odd
        epochs step ``optimizer_q`` on ``step_results['q_theta_loss']``. Each
        optimizer zeroes its own grads first, which discards gradients left
        over from the other group's backward.
        """
        optimizer_q, optimizer_other = optimizers

        if self.__trainer__.cur_epoch % 2 == 0:
            optimizer_other.zero_grad()
            step_results['loss'].backward()
            optimizer_other.step()
        else:
            optimizer_q.zero_grad()
            step_results['q_theta_loss'].backward()
            optimizer_q.step()

    def forward_pass(self, data, batch_idx, compute_loss=True):
        """Run one forward step and collect losses and outputs.

        Returns:
            The ``__loss__`` dict (empty when ``compute_loss`` is False),
            extended with:

            * ``'q_theta_loss'``;
            * ``'clf_logits'``, ``'node_mask'``, ``'edge_mask'``;
            * ``'y'``, ``'batch'``, ``'edge_index'``;
            * ``'exp_labels'``: ``data.edge_label``, or zeros when that
              attribute is absent.
        """
        return_dict = self.get_embs(data)

        # classification on the pooled subgraph embeddings
        clf_logits = self.classifier(return_dict['subg_embs'])

        # q_\theta parameters for the full graph (x) and masked graph (y)
        mu_x, sigma_x, mu_y, sigma_y = self.q_theta(data, return_dict['node_mask'], return_dict['edge_mask'])

        # q_theta is trained to maximise the InfoNCE lower bound
        log_probs = critic(mu_x, sigma_x, mu_y)
        q_loss = -1 * infonce_lower_bound(log_probs)

        data = data.to(self.device)
        if compute_loss:
            loss_dict = self.__loss__(clf_logits=clf_logits,
                                      y=data.y,
                                      node_mask=return_dict['node_mask'],
                                      edge_mask=return_dict['edge_mask'],
                                      log_probs=log_probs,
                                      )
        else:
            loss_dict = {}

        loss_dict['q_theta_loss'] = q_loss

        loss_dict['clf_logits'] = clf_logits
        loss_dict['node_mask'] = return_dict['node_mask']
        loss_dict['edge_mask'] = return_dict['edge_mask']

        loss_dict['y'] = data.y
        loss_dict['batch'] = data.batch
        loss_dict['edge_index'] = data.edge_index

        if not hasattr(data, 'edge_label'):
            loss_dict['exp_labels'] = torch.zeros((data.edge_index.shape[1]), device=data.edge_index.device)
        else:
            loss_dict['exp_labels'] = data.edge_label

        return loss_dict

    def get_mask(self, embs, edge_index, batch):
        """Sample an edge mask and derive a node mask.

        Without the IB constraint, both masks are all ones. With it, edges
        are scored by ``edge_assigner`` and sampled via ``self.sampling``. In
        ``'upper_bound'`` mode each node takes the max mask of its edges;
        otherwise the node mask is all ones.

        Returns:
            ``(edge_mask, node_mask)``, both reshaped to ``(-1, 1)``.
        """
        # Truthiness test, consistent with __loss__ (previously `is True`,
        # which disagreed with __loss__ for truthy non-bool config values).
        if self.with_ib_constraint:
            edge_assign_logits = self.edge_assigner(embs, edge_index, batch)
            edge_mask = self.sampling(edge_assign_logits, gumbel=self.gpl_config['use_gumbel'], hard=self.gpl_config['gumbel_hard'], tau=self.gpl_config['gumbel_tau'])

            if 'upper_bound' in self.ib_constraint_type:
                num_nodes = batch.size(0)
                first, second = edge_index[0], edge_index[1]
                # Explicit check instead of try/assert. Asserts are stripped
                # under `python -O`, which silently forced the fast path, and
                # `.max()` on an empty edge set raised instead of falling back.
                if first.numel() > 0 and int(first.max()) + 1 == num_nodes:
                    # Fast path: scatter output already has one entry per
                    # node. Nodes with no edge_index[0] edge get scatter's
                    # empty-segment value.
                    node_mask = scatter(edge_mask, index=first, reduce='max')
                    node_mask = node_mask.to(embs.device)
                else:
                    # Fallback: start at -1 (a "no edge" sentinel), take the
                    # max over edges by edge_index[1], then give nodes with no
                    # such edge a mask of 1.
                    # NOTE(review): the fast path groups by edge_index[0] but
                    # this one by edge_index[1]. They agree only for graphs
                    # stored with both edge directions — confirm intent.
                    node_mask = torch.full((num_nodes,), -1.0, dtype=torch.float, device=batch.device)
                    scatter(edge_mask, index=second, reduce='max', out=node_mask)
                    node_mask[node_mask == -1] = 1
            else:
                node_mask = torch.ones((embs.shape[0],), device=embs.device)
        else:
            edge_mask = torch.ones((edge_index.shape[1],), device=embs.device)
            node_mask = torch.ones((embs.shape[0],), device=embs.device)

        edge_mask = edge_mask.reshape(-1, 1)
        node_mask = node_mask.reshape(-1, 1)

        return edge_mask, node_mask

    def get_embs(self, data):
        """Compute masks and sum-pooled masked-subgraph embeddings.

        Returns:
            Dict with keys ``'node_mask'``, ``'edge_mask'`` (from
            ``get_mask``) and ``'subg_embs'``. ``'subg_embs'`` holds the
            ``subg_encoder`` node embeddings, sum-pooled per graph.
        """
        data = data.to(self.device)
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None
        assert self.encoder.graph_pooling is False, 'Should obtain node embeddings now'

        # get_mask only supports edge assignment (it uses self.edge_assigner)
        assert self.assign_edge is True

        # node-level embeddings used to score edges
        embs = self.encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr)

        # subgraph masks
        edge_mask, node_mask = self.get_mask(embs, edge_index, batch)

        # masked-subgraph node embeddings, sum-pooled per graph
        subg_embs = self.subg_encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr, edge_atten=edge_mask, node_mask=node_mask)
        subg_embs = scatter(subg_embs, batch, dim=0, reduce='sum')

        return {
            'node_mask': node_mask,
            'edge_mask': edge_mask,
            'subg_embs': subg_embs,
        }

    