import copy
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch_scatter import scatter


from gpl.models.gin import GIN
from gpl.models.mlp import MLPClean
from gpl.models.gpl import GPLV2, gaussianKL, maskConstKL, supervisedCT
from gpl.models.mcr2 import MaximalCodingRateReduction

class AssignerMLPWithZ(nn.Module):
    """MLP that scores nodes or edges from their embeddings plus a per-graph vector ``Z``.

    When ``assign_edge`` is true, each edge is described by the element-wise
    mean / max / min of its two endpoint embeddings; otherwise each node is
    described by its own embedding. A ``Z``-shaped block is concatenated to
    that description before it is passed to the ``MLPClean`` feature extractor.
    """

    def __init__(self, channels, dropout_p, assign_edge=False):
        """Build the feature extractor.

        Args:
            channels: layer widths for ``MLPClean``. ``channels[0]`` is the
                embedding width; it is widened internally to account for the
                concatenated features. The caller's sequence is NOT modified.
            dropout_p: forwarded to ``MLPClean`` as ``dropout``.
            assign_edge: score edges (True) or nodes (False).
        """
        super().__init__()
        self.assign_edge = assign_edge

        # with_softmax must stay False (original author's requirement): forward()
        # returns whatever the extractor outputs without further processing.
        self.feature_extractor = MLPClean(
            channels=self._input_channels(channels, assign_edge),
            dropout=dropout_p,
            with_softmax=False,
        )

    @staticmethod
    def _input_channels(channels, assign_edge):
        """Return a copy of ``channels`` whose first width matches forward()'s input.

        Edge mode concatenates mean/max/min (3 * dim) and Z (dim); node mode
        concatenates the embedding (dim) and Z (dim). Z is assumed to have the
        same width as the embeddings, as in the original code.
        """
        # Work on a copy: mutating the caller's list would corrupt any other
        # module built from the same channel spec.
        widths = list(channels)
        emb_dim = widths[0]
        entity_dim = emb_dim * 3 if assign_edge else emb_dim
        widths[0] = entity_dim + emb_dim
        return widths

    def forward(self, emb, edge_index, batch, Z):
        """Compute assignment outputs for every edge (edge mode) or node (node mode).

        Args:
            emb: embeddings, indexed by the entries of ``edge_index`` / aligned
                with ``batch``.
            edge_index: unpacked into two index tensors (used only in edge mode).
            batch: index into ``Z`` for each row of ``emb``
                (presumably the graph id of each node -- confirm against caller).
            Z: per-graph vectors, indexed by ``batch``.

        Returns:
            The output of ``self.feature_extractor`` on the concatenated features.
        """
        if self.assign_edge:
            col, row = edge_index
            endpoints = torch.stack([emb[col], emb[row]], dim=0)
            pair_mean = endpoints.mean(dim=0)
            pair_max = endpoints.max(dim=0).values
            pair_min = endpoints.min(dim=0).values
            entity_feature = torch.cat([pair_mean, pair_max, pair_min], dim=-1)

            # Each edge takes the graph id of its second endpoint.
            Z_ext = Z[batch[row]]
        else:
            entity_feature = emb
            Z_ext = Z[batch]

        # Only the shape of Z reaches the MLP: its values are replaced by zeros.
        # NOTE(review): this looks like an ablation/debug switch -- confirm before
        # removing, since doing so changes the model's outputs.
        Z_ext = torch.zeros_like(Z_ext)

        features = torch.cat([entity_feature, Z_ext], dim=-1)
        return self.feature_extractor(features)


class GPLV4(GPLV2):
    def __init__(self, encoder: GIN, config):
        """Set up the parent model, then cache two settings used by this variant.

        Args:
            encoder: encoder passed through to the parent constructor.
            config: configuration passed through to the parent constructor.
        """
        super().__init__(encoder, config)

        # Both values are read from attributes that exist once the parent
        # constructor has run.
        gpl_settings = self.gpl_config
        self.prototype_num = gpl_settings['prototype_num']
        self.hidden_size = self.model_config.hidden_size

        
    def initialize(self, ):
        """Create the sub-graph encoder, the expert encoders and the expert base.

        Sets:
            self.subg_encoder: an independent deep copy of ``self.encoder``.
            self.experts: ``prototype_num`` independent deep copies of
                ``self.encoder``, held in an ``nn.ModuleList`` so their weights
                are registered with this module (visible to ``parameters()``,
                ``state_dict()`` and ``.to()``).
            self.exp_base: learnable tensor of shape
                ``(prototype_num, hidden_size)``.
        """
        self.subg_encoder = copy.deepcopy(self.encoder)

        # A plain Python list would hide the experts from the module system.
        self.experts = nn.ModuleList(
            copy.deepcopy(self.encoder) for _ in range(self.prototype_num)
        )

        # nn.Parameter needs a tensor, not a shape tuple. Create it next to the
        # encoder's weights so it is on the right device even when this runs
        # after the model has been moved.
        # NOTE(review): standard-normal initialisation is an assumption -- the
        # original code did not specify one; confirm the intended init.
        reference = next(self.encoder.parameters(), None)
        device = reference.device if reference is not None else None
        self.exp_base = nn.Parameter(
            torch.randn(self.prototype_num, self.hidden_size, device=device)
        )
        

    def __loss__(self, clf_logits, clf_labels, mean, std, node_mask, edge_mask, subg_embs, edge_index, batch):
        """Combine the prediction, IB-constraint, contrastive and rate losses.

        Args:
            clf_logits: classifier outputs, passed to ``self.criterion``.
            clf_labels: targets, passed to ``self.criterion`` and used to group
                ``subg_embs`` by class.
            mean, std: inputs to ``gaussianKL`` ('vector' / 'both' constraints).
            node_mask: unused here; kept for interface compatibility.
            edge_mask: input to ``maskConstKL`` ('mask' / 'both' constraints).
            subg_embs: embeddings selected per class with a boolean mask built
                from ``clf_labels`` (presumably one row per graph -- confirm
                against caller).
            edge_index, batch: unused here; kept for interface compatibility.

        Returns:
            dict with:
              * 'loss'      -- total loss (tensor),
              * 'pred_loss', 'ib_loss', 'ct_loss' -- Python floats,
              * 'vib_loss' / 'eib_loss' -- floats, only for the active IB terms,
              * 'rate_loss' -- tensor, only when ``self.with_rate_loss``.

        Raises:
            ValueError: if the IB constraint is enabled with an unknown
                ``self.ib_constraint_type``.
        """
        device = clf_logits.device
        step_dict = {}

        # --- prediction loss -------------------------------------------------
        if self.pred_loss_coeff == 0.0:
            pred_loss = torch.zeros((), device=device)
        else:
            pred_loss = self.criterion(clf_logits, clf_labels) * self.pred_loss_coeff
        step_dict['pred_loss'] = pred_loss.item()

        # --- information-bottleneck constraint --------------------------------
        if not self.with_ib_constraint:
            ib_const_loss = torch.zeros((), device=device)
        elif self.ib_constraint_type == 'vector':
            ib_const_loss = gaussianKL(mean, std) * self.ib_coeff_vector
            step_dict['vib_loss'] = ib_const_loss.item()
        elif self.ib_constraint_type == 'mask':
            ib_const_loss = maskConstKL(edge_mask, self.get_r()) * self.ib_coeff_mask
            step_dict['eib_loss'] = ib_const_loss.item()
        elif self.ib_constraint_type == 'both':
            vib_loss = gaussianKL(mean, std) * self.ib_coeff_vector
            eib_loss = maskConstKL(edge_mask, self.get_r()) * self.ib_coeff_mask
            ib_const_loss = vib_loss + eib_loss
            step_dict['vib_loss'] = vib_loss.item()
            step_dict['eib_loss'] = eib_loss.item()
        else:
            # Previously this fell through and crashed later with an
            # UnboundLocalError; fail with an actionable message instead.
            raise ValueError(
                f"Unknown ib_constraint_type {self.ib_constraint_type!r}; "
                "expected 'vector', 'mask' or 'both'."
            )

        # --- supervised contrastive loss --------------------------------------
        ct_loss = torch.zeros((), device=device)
        # Skipped while the queue is still empty (no warm-up handling here).
        if self.with_ct_loss and not self.queue_empty:
            unique_y = torch.unique(clf_labels).tolist()
            for y in unique_y:
                class_idx = (clf_labels == y).view(-1)
                this_class_embs = subg_embs[class_idx]
                same_class_embs = self.get_same_class_embs(y)
                other_class_embs = self.get_other_class_embs(this_y=y)
                ct_loss = ct_loss + supervisedCT(this_class_embs, same_class_embs, other_class_embs)
            # Average over the classes present in this batch, then scale.
            ct_loss = ct_loss / len(unique_y)
            ct_loss = ct_loss * self.ct_coeff

        loss = pred_loss + ib_const_loss + ct_loss

        # --- rate loss ---------------------------------------------------------
        if self.with_rate_loss:
            flat_labels = clf_labels.to(torch.int).reshape(-1)
            rate_loss = self.rate_loss_ins.forward_compress_loss(subg_embs, flat_labels, self.num_class)
            # The compress loss is negated before it joins the total.
            rate_loss = rate_loss * self.rate_loss_coeff * -1
            loss = loss + rate_loss
            # Stored as a tensor (not .item()) to match the original output.
            step_dict['rate_loss'] = rate_loss

        step_dict['loss'] = loss
        step_dict['ib_loss'] = ib_const_loss.item()
        step_dict['ct_loss'] = ct_loss.item()

        return step_dict


    def forward_pass(self, data, batch_idx, compute_loss=True):
        """Run one batch through the model and gather its outputs.

        Args:
            data: batch object handed to ``self.get_embs``; its ``edge_label``
                attribute is copied into the result.
            batch_idx: not used by this method.
            compute_loss: when true, the loss terms from ``self.__loss__`` are
                included in the returned dict.

        Returns:
            dict holding the loss entries (if requested) plus 'clf_logits',
            'y', 'node_mask', 'edge_mask', 'Z', 'batch', 'edge_index' and
            'exp_labels'.
        """
        (subg_embs, mean, std, clf_labels,
         edge_index, batch, node_mask, edge_mask) = self.get_embs(data)

        # Classify from the sub-graph embeddings.
        clf_logits = self.classifier(subg_embs)

        # The queue is refreshed only while training with the contrastive loss on.
        if self.with_ct_loss and self.training is True:
            self.update_queue(original_batch_data=data, embs=subg_embs)

        outputs = {}
        if compute_loss:
            outputs = self.__loss__(
                clf_logits=clf_logits,
                clf_labels=clf_labels,
                mean=mean,
                std=std,
                node_mask=node_mask,
                edge_mask=edge_mask,
                subg_embs=subg_embs,
                edge_index=edge_index,
                batch=batch,
            )

        # 'Z' carries the mean produced by get_embs.
        outputs.update({
            'clf_logits': clf_logits,
            'y': clf_labels,
            'node_mask': node_mask,
            'edge_mask': edge_mask,
            'Z': mean,
            'batch': batch,
            'edge_index': edge_index,
            'exp_labels': data.edge_label,
        })
        return outputs

    def get_mask(self, N, embs, edge_index, batch, sampled_Z):
        """Sample an edge mask and build an all-ones node mask.

        Args:
            N: not used by this method.
            embs: node embeddings; the first dimension sets the node-mask length.
            edge_index, batch, sampled_Z: forwarded unchanged to
                ``self.edge_assigner`` together with ``embs``.

        Returns:
            ``(edge_mask, node_mask)``, each shaped as a single column.
        """
        assign_logits = self.edge_assigner(embs, edge_index, batch, sampled_Z)
        # Gumbel sampling is always switched on here.
        edge_mask = self.sampling(assign_logits, gumbel=True).reshape(-1, 1)

        # No node is masked out: the node mask is all ones.
        num_nodes = embs.shape[0]
        node_mask = torch.ones((num_nodes, 1), device=embs.device)

        return edge_mask, node_mask



    def get_embs(self, data):
        """Encode a batch, sample an edge mask and embed the masked sub-graphs.

        Steps: node embeddings from ``self.encoder`` are sum-pooled per graph,
        mapped to ``mean`` / ``std``, used (via ``mean``) to sample an edge
        mask, and the masked graph is re-encoded by ``self.subg_encoder`` and
        sum-pooled again.

        Args:
            data: batch object providing ``x``, ``edge_index``, ``batch``, ``y``
                and ``edge_attr`` (which may be ``None``).

        Returns:
            Tuple ``(subgraph_embs, mean, std, y, edge_index, batch, node_mask,
            edge_mask)``; tensors taken from ``data`` are on the model's device.

        Raises:
            AssertionError: if the model is not configured the way this variant
                requires (see the checks below).
        """
        # Configuration checks come first so a mis-configured model fails
        # before any tensors are moved or encoded.
        assert self.encoder.graph_pooling is False, 'Should obtain node embeddings now'
        assert self.with_assign_matrix is True, 'GPLV4 requires with_assign_matrix=True'
        assert self.with_ib_constraint is True, 'GPLV4 requires with_ib_constraint=True'
        assert self.assign_edge is True, 'GPLV4 requires assign_edge=True'

        # next() avoids materialising the full parameter list on every call.
        device = next(self.parameters()).device
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)
        y = data.y.to(device)
        edge_attr = data.edge_attr.to(device) if data.edge_attr is not None else None

        N = x.shape[0]

        # Node-level embeddings, then one summed vector per graph.
        embs = self.encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr)
        Z = scatter(embs, batch, dim=0, reduce='sum')
        mean = self.mean_encoder(Z)
        std = F.relu(self.std_encoder(Z))  # ReLU keeps std non-negative

        # The mean itself is used as the sample; std is only returned.
        sampled_Z = mean

        edge_mask, node_mask = self.get_mask(N, embs, edge_index, batch, sampled_Z)

        # Re-encode with the sampled edge mask as attention, apply the node
        # mask, then sum-pool to one embedding per graph.
        new_embs = self.subg_encoder.get_emb(
            x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr, edge_atten=edge_mask
        )
        new_embs = new_embs * node_mask
        new_embs = scatter(new_embs, batch, dim=0, reduce='sum')

        return new_embs, mean, std, y, edge_index, batch, node_mask, edge_mask

    