import copy
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Batch
from torch_scatter import scatter
from collections import defaultdict, deque
from torch_geometric.nn import global_mean_pool


from gpl.models.gin import GIN
from gpl.models.mlp import MLPClean
from gpl.models.mcr2 import MaximalCodingRateReduction
from gpl.training import get_optimizer
from gpl.models.gpl import Criterion


class Prediction(nn.Module):
    """Graph-level prediction head on top of a node-embedding encoder.

    Node embeddings produced by ``encoder.get_emb`` are mean-pooled per graph
    and passed through an ``MLPClean`` classifier; the loss is computed by
    ``Criterion``.
    """

    def __init__(self, encoder: GIN, config):
        """Build the classifier and loss from ``config``.

        Args:
            encoder: Encoder exposing ``get_emb(...)`` and a ``graph_pooling``
                flag; ``get_embs`` requires ``graph_pooling`` to be ``False``.
            config: Object supporting item access for the ``'model'``,
                ``'framework'`` and ``'training'`` sections *and* a
                ``device`` attribute.

        Raises:
            AssertionError: If ``config['model']['clf_channels']`` does not
                contain exactly two entries.
        """
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.model_config = config['model']
        self.gpl_config = config['framework']
        self.training_config = config['training']

        self.num_class = self.gpl_config['num_class']
        self.multi_label = self.gpl_config['multi_label']

        self.criterion = Criterion(self.num_class, self.multi_label)

        ##################################################### basic initialize

        # Single-label binary classification uses one output logit; every
        # other setting uses one logit per class.
        output_dim = 1 if self.num_class == 2 and not self.multi_label else self.num_class
        assert len(self.model_config['clf_channels']) == 2
        # List concatenation builds a new list, so the config is not mutated.
        clf_channels = self.model_config['clf_channels'] + [output_dim]
        print('[clf_channels]:', clf_channels)

        self.classifier = MLPClean(clf_channels, dropout=0, with_softmax=False)

        #####################################################
        self.device = config.device

    def configure_optimizers(self):
        """Create the optimizer described by ``training.optimizer_params``.

        Reads ``optimizer_type``, ``lr`` and ``l2`` from the training config
        and forwards them, together with this module, to ``get_optimizer``.

        Returns:
            Whatever ``get_optimizer`` returns.
        """
        opt_params = self.training_config['optimizer_params']
        return get_optimizer(
            self,
            opt_params['optimizer_type'],
            opt_params['lr'],
            opt_params['l2'],
        )

    def forward_pass(self, data, batch_idx):
        """Run encoder, pooling and classifier on ``data`` and compute the loss.

        Args:
            data: Batch object providing ``x``, ``edge_index``, ``batch`` and
                ``y`` (and optionally ``edge_attr``).
            batch_idx: Unused by this method.

        Returns:
            dict with keys ``'loss'``, ``'clf_logits'``, ``'y'``, ``'batch'``
            and ``'edge_index'``.
        """
        # NOTE(review): ``get_embs`` calls ``data.to(self.device)`` but the
        # attributes below are read from the caller's ``data`` object, so this
        # relies on ``.to`` moving the object in place — confirm for the data
        # type in use.
        graph_embs = self.get_embs(data)['graph_embs']
        clf_logits = self.classifier(graph_embs)

        return {
            'loss': self.criterion(clf_logits, data.y),
            'clf_logits': clf_logits,
            'y': data.y,
            'batch': data.batch,
            'edge_index': data.edge_index,
        }

    def get_embs(self, data):
        """Compute mean-pooled graph embeddings for ``data``.

        Args:
            data: Batch object providing ``x``, ``edge_index`` and ``batch``;
                ``edge_attr`` is forwarded when present, otherwise ``None``.
                Labels are not needed here.

        Returns:
            dict with the single key ``'graph_embs'``.

        Raises:
            AssertionError: If ``self.encoder.graph_pooling`` is not ``False``.
        """
        data = data.to(self.device)
        batch = data.batch
        edge_attr = getattr(data, 'edge_attr', None)

        assert self.encoder.graph_pooling is False, 'Should obtain node embeddings now'
        # node-level embeddings
        node_embs = self.encoder.get_emb(
            x=data.x,
            edge_index=data.edge_index,
            batch=batch,
            edge_attr=edge_attr,
        )

        return {'graph_embs': global_mean_pool(node_embs, batch)}

    