import copy
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Batch
from torch_scatter import scatter
from collections import defaultdict, deque


from gpl.models.gin import GIN
from gpl.models.mlp import MLPClean
from gpl.models.mcr2 import MaximalCodingRateReduction
from gpl.training import get_optimizer


# Module-wide cosine embedding loss instance, configured with a margin of -0.5.
cosine_similarity = nn.CosineEmbeddingLoss(margin=-0.5)

def supervisedCT(this_class_embs, same_class_embs, other_class_embs):
    """Supervised contrastive loss for the embeddings of a single class.

    Every (anchor, positive) pair becomes one row of logits: column 0 holds
    the anchor-positive dot product and the remaining columns hold the dot
    products between that anchor and all negatives.  The loss is the softmax
    cross-entropy with column 0 as the target.

    Args:
        this_class_embs: 2-D anchor embeddings; gradients flow through these.
        same_class_embs: 2-D positive embeddings; detached before use.
        other_class_embs: 2-D negative embeddings; detached before use.

    Returns:
        Scalar tensor with the mean cross-entropy over all pairs.
    """
    # Positives and negatives are treated as constants.
    positives = same_class_embs.detach()
    negatives = other_class_embs.detach()

    device = this_class_embs.device
    temperature = 1.0
    num_anchors = this_class_embs.shape[0]

    # Keep a random subset of the positives: at most one per anchor, capped at 1024.
    keep = torch.randperm(positives.shape[0], device=device)[:min(num_anchors, 1024)]
    positives = positives[keep]
    num_pos = positives.shape[0]
    num_pairs = num_anchors * num_pos

    # [N1, N2] -> [N1*N2, 1]: one row per (anchor, positive) pair.
    pos_scores = torch.mm(this_class_embs, positives.T).reshape(num_pairs, 1)

    # [N1, N3] -> [N1*N2, N3]: each anchor's negative scores repeated once per positive.
    neg_scores = torch.mm(this_class_embs, negatives.T)
    neg_scores = neg_scores.repeat_interleave(num_pos, dim=0)

    pair_logits = torch.cat((pos_scores, neg_scores), dim=1) / temperature
    targets = torch.zeros((num_pairs,), dtype=torch.int64, device=device)
    return F.cross_entropy(pair_logits, targets)


def cosineSimCT(this_class_embs, same_class_embs, other_class_embs, margin=-0.5):
    """Cosine-embedding contrastive loss over all anchor/positive and anchor/negative pairs.

    Positive pairs are pulled together (loss ``1 - cos``) and negative pairs
    are pushed apart (loss ``max(0, cos - margin)``).

    Args:
        this_class_embs: 2-D anchor embeddings, one row per sample.
        same_class_embs: 2-D embeddings paired with every anchor as positives.
        other_class_embs: 2-D embeddings paired with every anchor as negatives.
        margin: Margin applied to negative pairs; defaults to -0.5.

    Returns:
        Scalar tensor: mean positive-pair loss plus mean negative-pair loss.
    """
    num_anchors, dim = this_class_embs.shape

    def _pair_loss(others, label):
        # Mean cosine-embedding loss over the Cartesian product anchors x others.
        num_pairs = num_anchors * others.shape[0]
        if num_pairs == 0:
            # No pairs to score: contribute zero instead of the NaN a mean over nothing gives.
            return this_class_embs.new_zeros(())
        anchors = this_class_embs[:, None, :].expand(-1, others.shape[0], -1).reshape(num_pairs, dim)
        paired = others[None, :, :].expand(num_anchors, -1, -1).reshape(num_pairs, dim)
        target = this_class_embs.new_full((num_pairs,), label)
        return F.cosine_embedding_loss(anchors, paired, target, margin=margin)

    # The cosine embedding loss only acts on targets 1 (similar) and -1
    # (dissimilar); a target of 0 contributes nothing, so negatives must be
    # labelled -1.
    return _pair_loss(same_class_embs, 1.0) + _pair_loss(other_class_embs, -1.0)


def gaussianKL(mean, std):
    """KL-style regulariser between a diagonal Gaussian and a standard normal.

    Formula reference:
    https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/

    Each term is summed over dim 1 and averaged over dim 0.  The constant
    dimensionality term of the closed form is not included.

    Args:
        mean: 2-D tensor of means.
        std: 2-D tensor of standard deviations, same shape as `mean`.

    Returns:
        Scalar tensor ``0.5 * (|mean|^2 + sum(std^2) - sum(log(std^2 + 1e-6)))``.
    """
    variance = std.square()
    squared_mean = (mean * mean).sum(dim=1).mean()
    trace = variance.sum(dim=1).mean()
    # Log-determinant of the diagonal covariance; 1e-6 keeps the log finite at std == 0.
    log_det = (variance + 1e-6).log().sum(dim=1).mean()
    return 1 / 2 * (squared_mean + trace - log_det)

def maskConstKL(mask, r):
    """Element-wise KL divergence between Bernoulli(mask) and Bernoulli(r), averaged.

    Args:
        mask: Tensor of probabilities.
        r: Prior probability the mask is compared against.

    Returns:
        Scalar tensor with the mean divergence; 1e-6 offsets keep logs and
        divisions finite.
    """
    eps = 1e-6
    complement = 1 - mask
    on_term = mask * torch.log(mask / r + eps)
    off_term = complement * torch.log(complement / (1 - r + eps) + eps)
    return (on_term + off_term).mean()

class Criterion(nn.Module):
    """Prediction loss that adapts to binary, multi-class and multi-label targets."""

    def __init__(self, num_class, multi_label):
        super().__init__()
        self.num_class = num_class
        self.multi_label = multi_label
        print(f'[INFO] [criterion] Using multi_label: {self.multi_label}')

    def forward(self, logits, targets):
        """Return the loss for `logits` against `targets`.

        * single-label with two classes: binary cross-entropy on one logit column;
        * single-label with more than two classes: softmax cross-entropy;
        * otherwise: binary cross-entropy restricted to entries whose target
          is not NaN.
        """
        if not self.multi_label:
            if self.num_class == 2:
                binary_targets = targets.float().view(-1, 1)
                return F.binary_cross_entropy_with_logits(logits, binary_targets)
            if self.num_class > 2:
                return F.cross_entropy(logits, targets.long())

        # NaN != NaN, so this mask is False exactly where the label is missing.
        labeled = targets == targets
        return F.binary_cross_entropy_with_logits(logits[labeled], targets[labeled].float())

class AssignerMLP(nn.Module):
    """MLP that produces assignment logits for nodes or for edges.

    Args:
        channels: Layer widths passed to `MLPClean`.  In edge mode the first
            width is tripled internally because each edge is described by the
            concatenated mean/max/min of its two endpoint embeddings.
        dropout_p: Dropout value forwarded to `MLPClean`.
        assign_edge: If True, score edges; otherwise score nodes.
    """

    def __init__(self, channels, dropout_p, assign_edge=False):
        super().__init__()
        self.assign_edge = assign_edge

        # Work on a copy: the caller's list (typically taken from the config)
        # must not be modified, otherwise building the model twice from the
        # same config would triple the input width again.
        channels = list(channels)
        if self.assign_edge:
            channels[0] = channels[0] * 3

        # with_softmax=False is required: this module must return raw logits.
        self.feature_extractor = MLPClean(channels=channels, dropout=dropout_p, with_softmax=False)

    def forward(self, emb, edge_index):
        """Return assignment logits.

        Args:
            emb: Node embeddings, indexed by node id along dim 0.
            edge_index: Pair of index tensors selecting the two endpoints of
                every edge; only used in edge mode.

        Returns:
            One row of logits per edge (edge mode) or per node (node mode).
        """
        if not self.assign_edge:
            return self.feature_extractor(emb)

        col, row = edge_index
        endpoints = torch.stack([emb[col], emb[row]], dim=0)
        # The three element-wise statistics are symmetric in the two endpoints.
        endpoint_mean = endpoints.mean(dim=0)
        endpoint_max = endpoints.max(dim=0).values
        endpoint_min = endpoints.min(dim=0).values
        edge_features = torch.cat([endpoint_mean, endpoint_max, endpoint_min], dim=-1)
        return self.feature_extractor(edge_features)
    

class GPLV2(nn.Module):
    def __init__(self, encoder: GIN, config):
        """Store the encoder, read the framework options and build the sub-modules.

        `config` is indexed with the keys 'model', 'framework' and 'training'
        and must also expose a `device` attribute.
        """
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.model_config = config['model']
        self.gpl_config = config['framework']
        self.training_config = config['training']

        fw = self.gpl_config

        # --- task description
        self.num_class = fw['num_class']
        self.multi_label = fw['multi_label']

        # --- assignment matrix / information-bottleneck options
        self.with_assign_matrix = fw['with_assign_matrix']
        self.with_cluster_loss = fw['with_cluster_loss']
        self.with_ib_constraint = fw['with_ib_constraint']
        self.ib_constraint_type = fw['ib_constraint_type']
        # Assignment & group selector: coefficient on the ib loss.
        self.ib_coeff_mask = fw['ib_coeff_mask']
        self.ib_coeff_vector = fw['ib_coeff_vector']
        self.ib_coeff_uppb = fw['ib_coeff_uppb']
        self.assign_edge = fw['assign_edge']

        self.pred_loss_coeff = fw['pred_loss_coeff']

        # --- contrastive loss / queue options
        self.with_ct_loss = fw['with_ct_loss']
        self.raw_data_in_queue = fw['raw_data_in_queue']
        self.use_momentum_module = fw['use_momentum_module']
        self.momentum_coeff = fw['momentum_coeff']  # generally 0.999
        self.queue_max_len = fw['queue_max_len']
        self.ct_coeff = fw['ct_loss_coeff']
        self.fix_ib = fw['fix_ib']

        # --- schedule of the prior r
        self.fix_r = fw['fix_r']
        self.decay_interval = fw['decay_interval']
        self.init_r = fw['init_r']
        self.final_r = fw['final_r']
        self.decay_r = fw['decay_r']
        self.criterion = Criterion(fw['num_class'], fw['multi_label'])
        self.with_rate_loss = fw['with_rate_loss']
        self.rate_loss_coeff = fw['rate_loss_coeff']

        self.with_reconstruct = fw['with_reconstruct']
        self.recon_loss_coeff = fw['recon_loss_coeff']
        self.pos_weight = fw['pos_weight']

        # --- basic initialisation: one bounded queue per class
        self.queue = {label: deque(maxlen=self.queue_max_len) for label in range(self.num_class)}

        # Binary single-label tasks use a single logit; everything else one logit per class.
        if self.num_class == 2 and not self.multi_label:
            output_dim = 1
        else:
            output_dim = self.num_class
        # clf_channels is assumed to exclude the final output dim, which is
        # determined by the num_class computed from the dataset.
        assert len(self.model_config['clf_channels']) == 2
        clf_channels = self.model_config['clf_channels'] + [output_dim]
        print('[clf_channels]:', clf_channels)

        self.classifier = MLPClean(clf_channels, dropout=0, with_softmax=False)

        self.initialize()
        self.device = config.device
        
    def get_r(self):
        if self.fix_r:
            r = self.final_r
        else:
            current_epoch = self.__trainer__.cur_epoch
            r = self.init_r - current_epoch // self.decay_interval * self.decay_r
            if r < self.final_r:
                r = self.final_r
        return r

    def _get_ib_coeff(self):
        current_epoch = self.__trainer__.cur_epoch

        if self.fix_ib:
            return self.ib_coeff
        else:
            if current_epoch >= 100:
                return self.ib_coeff
            else:
                start = 0.9
                end = self.ib_coeff
                val = start + (end-start) * (100-current_epoch)/100
                return max(val, end)

    def _get_ct_coeff(self):
        current_epoch = self.__trainer__.cur_epoch
        if self.fix_ct:
            return self.ct_coeff
        else:
            if current_epoch < 100:
                return 0
            else:
                start = 0.01
                end = 0.03
                ct = start + (end-start) * (current_epoch-100)//100
                return min(ct, end)


    def initialize(self):
        """Create the trainable sub-modules that sit next to the main encoder."""
        # The classifier branch starts from an independent copy of the encoder.
        self.classifier_encoder = copy.deepcopy(self.encoder)

        model_cfg = self.model_config
        dropout_p = model_cfg['dropout_p']

        # Exactly one assigner is built, depending on what is being scored.
        if self.assign_edge:
            self.edge_assigner = AssignerMLP(
                channels=model_cfg['edge_assigner_channels'],
                dropout_p=dropout_p,
                assign_edge=True,
            )
        else:
            self.node_assigner = AssignerMLP(
                channels=model_cfg['node_assigner_channels'],
                dropout_p=dropout_p,
                assign_edge=False,
            )

        # Heads producing the mean and the standard deviation of the embedding.
        self.mean_encoder = MLPClean(model_cfg['mean_encoder_channels'], dropout=dropout_p, with_softmax=False)
        self.std_encoder = MLPClean(model_cfg['std_encoder_channels'], dropout=dropout_p, with_softmax=False)

        self.mcr2_loss = MaximalCodingRateReduction()
        self.mcr2_batchnorm = nn.BatchNorm1d(64)

        self.ib_coeff_scheduler = None
        self.ct_coeff_scheduler = None


    @property
    def queue_empty(self):
        empty = False
        for i in range(self.num_class):
            if len(self.queue[i]) == 0:
                empty = True
                break
        return empty

    
    def update_momentum_modules(self,):
        for p1, p2 in zip(self.momentum_encoder.parameters(), self.encoder.parameters()):
            p1 = self.momentum_coeff * p1 + (1-self.momentum_coeff) * p2

        for p1, p2 in zip(self.momentum_assigner.parameters(), self.assigner.parameters()):
            p1 = self.momentum_coeff * p1 + (1-self.momentum_coeff) * p2
            

    def get_queue_instance_embs(self, queue):
        if self.raw_data_in_queue:
            instances = list(queue)
            batch_data = Batch.from_data_list(instances)
            queue_instance_embs = self.get_embs(batch_data)[0]
        else:
            queue_instance_embs = torch.cat(list(queue), dim=0)

        return queue_instance_embs
    
    def get_same_class_embs(self, y):
        assert not self.queue_empty
        queue = self.queue[y]
        embs = self.get_queue_instance_embs(queue)
        return embs

    def get_other_class_embs(self, this_y):
        assert not self.queue_empty
        all_embs = []
        for y in range(self.num_class):
            if y != this_y:
                queue = self.queue[y]
                embs = self.get_queue_instance_embs(queue)
                all_embs.append(embs)

        all_embs = torch.cat(all_embs, dim=0)
        return all_embs

    def update_queue(self, original_batch_data, embs):
        clf_labels = original_batch_data.y
        unique_y = torch.unique(clf_labels).cpu().numpy().tolist()

        for y in unique_y:
            class_idx = clf_labels == y
            class_idx = class_idx.view(-1)
            if self.raw_data_in_queue:
                instances = original_batch_data[class_idx] # instance version
                self.queue[y].append( Batch.from_data_list( instances) )
            else:
                self.queue[y].append( embs[class_idx].detach() ) # embedding version
        


    def __loss__(self, clf_logits, clf_labels, mean, std, node_mask, edge_mask, sampled_embs, edge_index, batch):
        if self.pred_loss_coeff == 0.0:
            pred_loss = torch.tensor(0.0)
        else:
            pred_loss = self.criterion(clf_logits, clf_labels)
            pred_loss = pred_loss * self.pred_loss_coeff
        
        cluster_loss = torch.tensor(0.0)
        
        if self.with_ib_constraint: # information bottleneck constraint
            if self.ib_constraint_type == 'vector':
                ib_const_loss = gaussianKL(mean, std)
                ib_const_loss = ib_const_loss * self._get_ib_coeff()
            elif self.ib_constraint_type == 'edge_prob':
                assert self.with_assign_matrix is True
                mask_value = edge_mask

                ib_const_loss = maskConstKL(mask_value, self.get_r())
                ib_const_loss = ib_const_loss * self.ib_coeff
            elif self.ib_constraint_type == 'both':
                raise NotImplementedError
        else:
            ib_const_loss = torch.tensor(0.0)
        
        if self.with_ct_loss and not self.queue_empty : # without warm up
            ct_loss = torch.tensor(0.0).to(mean.device)
            unique_y = torch.unique(clf_labels).cpu().numpy().tolist()
            for y in unique_y: # all queues are not empty
                class_idx = clf_labels == y
                class_idx = class_idx.view(-1)
                this_class_embs = mean[class_idx]

                same_class_embs = self.get_same_class_embs(y)
                other_class_embs = self.get_other_class_embs(this_y=y)
                
                this_class_ct_loss = supervisedCT(this_class_embs, same_class_embs, other_class_embs)
                ct_loss += this_class_ct_loss
            ct_loss = ct_loss / len(unique_y)
            ct_loss = ct_loss * self._get_ct_coeff()

        else:
            ct_loss = torch.tensor(0.0) 
        
        if self.with_rate_reduction_loss:
            X = embs
            Y = clf_labels.to(torch.int).reshape(-1)
            X = F.normalize(X)
            import ipdb; ipdb.set_trace()
            mcr2_loss = self.mcr2_loss(X, Y, num_classes=self.num_class) # logdet has a problem!
            mcr2_loss = mcr2_loss * self.rate_reduction_loss_coeff
        else:
            mcr2_loss = torch.tensor(0.0)
            

        # overall loss
        
        loss = pred_loss + ib_const_loss + ct_loss + mcr2_loss


        step_dict = {
            'loss': loss,
            'pred_loss': pred_loss.item(),
            'cluster_loss': cluster_loss.item(),
            
        }

       
        if self.with_ib_constraint:
            step_dict['ib_loss'] = ib_const_loss
        if self.with_ct_loss:
            step_dict['ct_loss'] = ct_loss

        return step_dict


    def configure_optimizers(self):
        """Build the optimizer described by `training_config['optimizer_params']`."""
        params = self.training_config['optimizer_params']
        return get_optimizer(self, params['optimizer_type'], params['lr'], params['l2'])


    def forward_pass(self, data, batch_idx):
        """Run one batch through the model and return the loss dictionary.

        The dictionary returned by `__loss__` is extended with the classifier
        logits, the labels, the node/edge masks, the batch vector, the edge
        index and `data.edge_label` (stored under 'exp_labels').
        """
        (mean, std, graph_embs, clf_labels,
         edge_index, batch, node_mask, edge_mask) = self.get_embs(data)

        # Classification: with the 'vector' constraint the classifier sees a
        # reparameterised sample, otherwise the mean embedding itself.
        use_sampling = self.with_ib_constraint and self.ib_constraint_type == 'vector'
        if use_sampling:
            noise = torch.normal(mean=0.0, std=1.0, size=mean.shape).to(mean.device)
            sampled = mean + noise * std
        else:
            sampled = mean
        clf_logits = self.classifier(sampled)

        # During training, push the current batch into the class queues
        # before the loss is computed.
        if self.with_ct_loss and self.training is True:
            self.update_queue(original_batch_data=data, embs=mean)

        loss_dict = self.__loss__(
            clf_logits=clf_logits,
            clf_labels=clf_labels,
            mean=mean,
            std=std,
            node_mask=node_mask,
            edge_mask=edge_mask,
            sampled_embs=sampled,
            edge_index=edge_index,
            batch=batch,
        )

        # Attach everything downstream consumers read from the step output.
        loss_dict.update({
            'clf_logits': clf_logits,
            'y': clf_labels,
            'node_mask': node_mask,
            'edge_mask': edge_mask,
            'batch': batch,
            'edge_index': edge_index,
            'exp_labels': data.edge_label,
        })
        return loss_dict
    
    def sampling(self, logits: Tensor, gumbel: bool, hard: bool=False, tau: float=1.0):
        if self.training and gumbel:
            mask = F.gumbel_softmax(logits, hard=hard, tau=tau) # [N, K]
        else:
            mask = logits.softmax(dim=1)
        
        mask = mask[:, 1] 
        return mask


    @staticmethod
    def concrete_sample(att_log_logit, temp, training):
        if training:
            random_noise = torch.empty_like(att_log_logit).uniform_(1e-10, 1 - 1e-10)
            random_noise = torch.log(random_noise) - torch.log(1.0 - random_noise) # 
            att_bern = ((att_log_logit + random_noise) / temp).sigmoid() # binary concrete distribution, 
        else:
            att_bern = (att_log_logit).sigmoid()
        return att_bern

    
    @staticmethod
    def lift_edge_att_to_node_att(edge_att, edge_index, N):
        
        node_att_0 = scatter(edge_att, index=edge_index[0], reduce='mean')
        node_att_1 = scatter(edge_att, index=edge_index[1], reduce='mean')
        if len(node_att_0) < len(node_att_1):
            node_att = node_att_1
        else:
            node_att = node_att_0

        
        return node_att
    
    @staticmethod
    def lift_node_att_to_edge_att(node_att, edge_index):
        src_lifted_att = node_att[edge_index[0]]
        dst_lifted_att = node_att[edge_index[1]]
        edge_att = src_lifted_att * dst_lifted_att
        
        return edge_att
    
    def get_embs(self, data):
        """Compute masked graph-level embeddings for a batch.

        Steps:
            1. Optionally derive node and edge masks from the encoder's node
               embeddings through the edge or node assigner.
            2. Re-encode the graph with `classifier_encoder` using the edge
               mask as edge attention, multiply by the node mask and sum the
               node embeddings per graph.
            3. With the 'vector' constraint, map the graph embedding to a mean
               and a non-negative standard deviation; otherwise the graph
               embedding is the mean and the standard deviation is zero.

        Returns:
            Tuple ``(mean, std, graph_embs, y, edge_index, batch, node_mask,
            edge_mask)``; both masks have a trailing dimension of size 1.
        """
        data = data.to(self.device)
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        y = data.y
        edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None

        assert self.encoder.graph_pooling is False, 'Should obtain node embeddings now'
        N = x.shape[0]

        if self.with_assign_matrix:  # with mask or not
            # Node-level embeddings that drive the assignment.
            embs = self.encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr)

            use_gumbel = self.ib_constraint_type == 'edge_prob'
            gumbel_hard = self.gpl_config['gumbel_hard']
            gumbel_tau = self.gpl_config['gumbel_tau']
            if self.assign_edge:
                edge_assign_logits = self.edge_assigner(embs, edge_index)  # one row per edge
                edge_mask = self.sampling(edge_assign_logits, gumbel=use_gumbel, hard=gumbel_hard, tau=gumbel_tau)
                # NOTE: flagged as error-prone by the original author.
                node_mask = self.lift_edge_att_to_node_att(edge_mask, edge_index, N)
            else:
                node_assign_logits = self.node_assigner(embs, edge_index)  # one row per node
                node_mask = self.sampling(node_assign_logits, gumbel=use_gumbel, hard=gumbel_hard, tau=gumbel_tau)
                edge_mask = self.lift_node_att_to_edge_att(node_mask, edge_index)

            edge_mask = edge_mask.reshape(-1, 1)
            node_mask = node_mask.reshape(-1, 1)
        else:
            # No masking: keep every node and every edge.  The masks are
            # created on the device of the (already moved) input features.
            node_mask = torch.ones((N, 1), device=x.device)
            edge_mask = torch.ones((len(edge_index[0]), 1), device=x.device)

        # Node-level embeddings under the edge attention, then masked sum-pooling per graph.
        new_embs = self.classifier_encoder.get_emb(x=x, edge_index=edge_index, batch=batch, edge_attr=edge_attr, edge_atten=edge_mask)
        new_embs = new_embs * node_mask
        new_embs = scatter(new_embs, batch, dim=0, reduce='sum')  # [B, dim]

        graph_embs = new_embs
        if self.with_ib_constraint and self.ib_constraint_type == 'vector':
            mean = self.mean_encoder(graph_embs)
            # ReLU keeps the predicted standard deviation non-negative.
            std = F.relu(self.std_encoder(graph_embs))
        else:
            mean = graph_embs
            std = torch.zeros_like(mean)

        return mean, std, graph_embs, y, edge_index, batch, node_mask, edge_mask

    