
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool, BatchNorm, GCNConv, GINConv, MessagePassing
from torch_geometric.utils import add_self_loops
from captum.attr import IntegratedGradients
from model_pretrain import PretrainGIN, GINet



class MLP(torch.nn.Module):
    """Fully connected classifier over a precomputed feature vector.

    Architecture: ``args.mlp_fc_layer`` blocks of Linear -> BatchNorm -> ReLU,
    each of width ``args.hidden_dim``, followed by a linear head producing
    ``args.output_dim`` outputs and a log-softmax over dim 1.

    Because the first Linear takes ``args.hidden_dim`` inputs, the input's last
    dimension must equal ``args.hidden_dim`` whenever ``mlp_fc_layer > 0``.
    """

    def __init__(self, args):
        super().__init__()

        width = args.hidden_dim
        self.fc = torch.nn.ModuleList()
        self.bn = torch.nn.ModuleList()
        for _ in range(args.mlp_fc_layer):
            linear = torch.nn.Linear(width, width)
            norm = BatchNorm(width, allow_single_element=True)
            self.fc.append(linear)
            self.bn.append(norm)

        self.out = torch.nn.Linear(width, args.output_dim)

    def embed(self, fp):
        """Return the hidden representation of ``fp`` after every FC block."""
        hidden = fp
        for linear, norm in zip(self.fc, self.bn):
            hidden = F.relu(norm(linear(hidden)))
        return hidden

    def forward(self, fp, need_emb=True):
        """Return log-probabilities for ``fp``.

        When ``need_emb`` is False, ``fp`` is treated as an already computed
        embedding (the output of ``embed``). In that case only the output head
        is applied.
        """
        hidden = self.embed(fp) if need_emb else fp
        return F.log_softmax(self.out(hidden), 1)
    














class GCN(torch.nn.Module):
    """Graph classifier built from GCN convolutions.

    Pipeline:

    1. Linear atom embedding (``args.mol_emb_dim -> args.hidden_dim``) + ELU.
    2. ``args.gcn_graph_conv_layer`` blocks of GCNConv -> BatchNorm -> ReLU.
    3. Sum pooling of node vectors per graph (``global_add_pool``).
    4. ``args.gcn_x_fc_layer`` blocks of Linear -> BatchNorm -> ReLU.
    5. Linear head to ``args.output_dim`` followed by log-softmax over dim 1.
    """

    def __init__(self, args):
        super().__init__()

        dim = args.hidden_dim
        self.atom_embedding = torch.nn.Linear(args.mol_emb_dim, dim)

        self.graph_conv = torch.nn.ModuleList()
        self.graph_bn = torch.nn.ModuleList()
        for _ in range(args.gcn_graph_conv_layer):
            self.graph_conv.append(GCNConv(dim, dim))
            self.graph_bn.append(BatchNorm(dim, allow_single_element=True))

        self.x_fc = torch.nn.ModuleList()
        self.x_bn = torch.nn.ModuleList()
        for _ in range(args.gcn_x_fc_layer):
            self.x_fc.append(torch.nn.Linear(dim, dim))
            self.x_bn.append(BatchNorm(dim, allow_single_element=True))

        self.out = torch.nn.Linear(dim, args.output_dim)

    def embed(self, graph):
        """Return one pooled embedding vector per graph in the batched ``graph``.

        Reads ``graph.x``, ``graph.edge_index`` and ``graph.batch``.
        """
        node_feats, edges, graph_ids = graph.x, graph.edge_index, graph.batch

        h = F.elu(self.atom_embedding(node_feats))
        for conv, norm in zip(self.graph_conv, self.graph_bn):
            h = F.relu(norm(conv(h, edges)))

        # Collapse node vectors into a single vector per graph.
        h = global_add_pool(h, graph_ids)

        for linear, norm in zip(self.x_fc, self.x_bn):
            h = F.relu(norm(linear(h)))

        return h

    def forward(self, graph, need_emb=True):
        """Return log-probabilities for ``graph``.

        When ``need_emb`` is False, ``graph`` is treated as an already computed
        embedding (the output of ``embed``). In that case only the output head
        is applied.
        """
        emb = self.embed(graph) if need_emb else graph
        logits = self.out(emb)
        return F.log_softmax(logits, 1)















class GINL(torch.nn.Module):
    """GIN graph classifier fused with a fingerprint branch and an optional pretrained encoder.

    Graph branch:

    * Linear atom embedding + ELU.
    * ``args.gin_graph_conv_layer`` GINConv layers (inner MLP: Linear -> ReLU),
      each followed by BatchNorm and ReLU.
    * Sum pooling per graph.
    * ``args.gin_x_fc_layer`` blocks of Linear -> BatchNorm -> ReLU.

    Fingerprint branch: ``args.gin_fp_fc_layer`` blocks of Linear -> BatchNorm
    -> ReLU applied to ``graph.fp``. The result is added to the graph branch.

    When ``need_pretrain`` is True, a ``PretrainGIN`` encoder (``emb_dim=300``,
    ``layer_num=5``) is applied to the pretrain inputs. Its output is projected
    to ``args.hidden_dim`` and added with weight ``self.beta``.
    """

    def __init__(self, args, need_pretrain=False):
        super().__init__()

        self.need_pretrain = need_pretrain
        self.beta = 0.1   # weight of the pretrained-encoder contribution in embed()

        self.atom_embedding = torch.nn.Linear(args.mol_emb_dim, args.hidden_dim)

        self.graph_conv = torch.nn.ModuleList()
        self.graph_bn = torch.nn.ModuleList()
        for _ in range(args.gin_graph_conv_layer):
            gin_mlp = torch.nn.Sequential(torch.nn.Linear(args.hidden_dim, args.hidden_dim),
                                          torch.nn.ReLU())
            self.graph_conv.append(GINConv(nn=gin_mlp))
            self.graph_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.x_fc = torch.nn.ModuleList()
        self.x_bn = torch.nn.ModuleList()
        for _ in range(args.gin_x_fc_layer):
            self.x_fc.append(torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.x_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.fp_fc = torch.nn.ModuleList()
        self.fp_bn = torch.nn.ModuleList()
        for _ in range(args.gin_fp_fc_layer):
            self.fp_fc.append(torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.fp_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.out = torch.nn.Linear(args.hidden_dim, args.output_dim)

        if need_pretrain:
            self.pretrain = PretrainGIN(args, emb_dim=300, layer_num=5)
            self.xp_fc = torch.nn.Linear(300, args.hidden_dim)   # 300 == PretrainGIN emb_dim above

    def embed(self, graph):
        """Return the fused per-graph embedding for a batched ``graph``.

        Always reads ``graph.x``, ``graph.edge_index``, ``graph.batch`` and
        ``graph.fp``. The pretrain-only fields ``graph.xp``,
        ``graph.edgep_index`` and ``graph.edgep_attr`` are read only when
        ``need_pretrain`` is True, so batches without them work otherwise.

        NOTE(review): ``graph.fp`` is assumed to be ``[num_graphs, hidden_dim]``.
        It is fed to ``Linear(hidden_dim, ...)`` and added to the pooled graph
        vectors.
        """
        x, edge_index, batch, fp = graph.x, graph.edge_index, graph.batch, graph.fp

        x = F.elu(self.atom_embedding(x))

        for conv, bn in zip(self.graph_conv, self.graph_bn):
            x = F.relu(bn(conv(x, edge_index)))

        # Node vectors -> one vector per graph.
        x = global_add_pool(x, batch)

        for fc, bn in zip(self.x_fc, self.x_bn):
            x = F.relu(bn(fc(x)))

        for fc, bn in zip(self.fp_fc, self.fp_bn):
            fp = F.relu(bn(fc(fp)))

        x = x + fp

        if self.need_pretrain:
            # presumably PretrainGIN returns one 300-d vector per graph (xp_fc expects 300 inputs)
            xp = self.pretrain(graph.xp, graph.edgep_index, graph.edgep_attr, batch)
            x = x + self.beta * self.xp_fc(xp)

        return x

    def forward(self, graph, need_emb=True):
        """Return log-probabilities over ``args.output_dim`` classes.

        When ``need_emb`` is False, ``graph`` is treated as an already computed
        embedding (the output of ``embed``). In that case only the output head
        is applied.
        """
        x = self.embed(graph) if need_emb else graph
        return F.log_softmax(self.out(x), 1)
















class GINE(torch.nn.Module):
    """GIN variant with a two-layer inner MLP, a fingerprint branch and an optional pretrained encoder.

    It differs from ``GINL`` in three ways:

    * Each GINConv's inner MLP is Linear -> ReLU -> Linear.
    * Layer counts come from the ``args.gine_*`` settings.
    * The fingerprint branch applies Linear -> ReLU only; BatchNorm is not
      applied there.

    When ``need_pretrain`` is True, a ``PretrainGIN`` encoder (``emb_dim=300``,
    ``layer_num=5``) is applied to the pretrain inputs. Its output is projected
    to ``args.hidden_dim`` and added with weight ``self.beta``.
    """

    def __init__(self, args, need_pretrain=False):
        super().__init__()

        self.need_pretrain = need_pretrain
        self.beta = 0.1   # weight of the pretrained-encoder contribution in embed()

        self.atom_embedding = torch.nn.Linear(args.mol_emb_dim, args.hidden_dim)

        self.graph_conv = torch.nn.ModuleList()
        self.graph_bn = torch.nn.ModuleList()
        for _ in range(args.gine_graph_conv_layer):
            gin_mlp = torch.nn.Sequential(torch.nn.Linear(args.hidden_dim, args.hidden_dim),
                                          torch.nn.ReLU(),
                                          torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.graph_conv.append(GINConv(nn=gin_mlp))
            self.graph_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.x_fc = torch.nn.ModuleList()
        self.x_bn = torch.nn.ModuleList()
        for _ in range(args.gine_x_fc_layer):
            self.x_fc.append(torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.x_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        # fp_bn is not used in embed(); it is still constructed so the module's
        # parameters and state_dict keys stay unchanged (strict checkpoint loads
        # of the full network depend on them).
        self.fp_fc = torch.nn.ModuleList()
        self.fp_bn = torch.nn.ModuleList()
        for _ in range(args.gine_fp_fc_layer):
            self.fp_fc.append(torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.fp_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.out = torch.nn.Linear(args.hidden_dim, args.output_dim)

        if need_pretrain:
            self.pretrain = PretrainGIN(args, emb_dim=300, layer_num=5)
            self.xp_fc = torch.nn.Linear(300, args.hidden_dim)   # 300 == PretrainGIN emb_dim above

    def embed(self, graph):
        """Return the fused per-graph embedding for a batched ``graph``.

        Always reads ``graph.x``, ``graph.edge_index``, ``graph.batch`` and
        ``graph.fp``. The pretrain-only fields ``graph.xp``,
        ``graph.edgep_index`` and ``graph.edgep_attr`` are read only when
        ``need_pretrain`` is True, so batches without them work otherwise.

        NOTE(review): ``graph.fp`` is assumed to be ``[num_graphs, hidden_dim]``.
        It is fed to ``Linear(hidden_dim, ...)`` and added to the pooled graph
        vectors.
        """
        x, edge_index, batch, fp = graph.x, graph.edge_index, graph.batch, graph.fp

        x = F.elu(self.atom_embedding(x))

        for conv, bn in zip(self.graph_conv, self.graph_bn):
            x = F.relu(bn(conv(x, edge_index)))

        # Node vectors -> one vector per graph.
        x = global_add_pool(x, batch)

        for fc, bn in zip(self.x_fc, self.x_bn):
            x = F.relu(bn(fc(x)))

        # BatchNorm is deliberately not applied on the fingerprint branch in this variant.
        for fc in self.fp_fc:
            fp = F.relu(fc(fp))

        x = x + fp

        if self.need_pretrain:
            # presumably PretrainGIN returns one 300-d vector per graph (xp_fc expects 300 inputs)
            xp = self.pretrain(graph.xp, graph.edgep_index, graph.edgep_attr, batch)
            x = x + self.beta * self.xp_fc(xp)

        return x

    def forward(self, graph, need_emb=True):
        """Return log-probabilities over ``args.output_dim`` classes.

        When ``need_emb`` is False, ``graph`` is treated as an already computed
        embedding (the output of ``embed``). In that case only the output head
        is applied.
        """
        x = self.embed(graph) if need_emb else graph
        return F.log_softmax(self.out(x), 1)
    













class GIND(torch.nn.Module):
    """GIN classifier fused with a fingerprint branch and an optional precomputed 128-d vector.

    The graph and fingerprint branches match ``GINL`` and use the ``args.gin_*``
    layer counts. When ``need_pretrain`` is True, ``graph.xp`` is linearly
    projected from 128 to ``args.hidden_dim`` and added with weight
    ``self.beta``. Unlike ``GINL``/``GINE``, no encoder is run on it here.
    """

    def __init__(self, args, need_pretrain=False):
        super().__init__()

        self.need_pretrain = need_pretrain
        self.beta = 0.1   # weight of the projected graph.xp contribution in embed()

        self.atom_embedding = torch.nn.Linear(args.mol_emb_dim, args.hidden_dim)

        self.graph_conv = torch.nn.ModuleList()
        self.graph_bn = torch.nn.ModuleList()
        for _ in range(args.gin_graph_conv_layer):
            gin_mlp = torch.nn.Sequential(torch.nn.Linear(args.hidden_dim, args.hidden_dim),
                                          torch.nn.ReLU())
            self.graph_conv.append(GINConv(nn=gin_mlp))
            self.graph_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.x_fc = torch.nn.ModuleList()
        self.x_bn = torch.nn.ModuleList()
        for _ in range(args.gin_x_fc_layer):
            self.x_fc.append(torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.x_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.fp_fc = torch.nn.ModuleList()
        self.fp_bn = torch.nn.ModuleList()
        for _ in range(args.gin_fp_fc_layer):
            self.fp_fc.append(torch.nn.Linear(args.hidden_dim, args.hidden_dim))
            self.fp_bn.append(BatchNorm(args.hidden_dim, allow_single_element=True))

        self.out = torch.nn.Linear(args.hidden_dim, args.output_dim)

        if need_pretrain:
            # NOTE(review): assumes graph.xp carries 128 features per graph — confirm against the dataset.
            self.xp_fc = torch.nn.Linear(128, args.hidden_dim)

    def embed(self, graph):
        """Return the fused per-graph embedding for a batched ``graph``.

        Always reads ``graph.x``, ``graph.edge_index``, ``graph.batch`` and
        ``graph.fp``. ``graph.xp`` is read only when ``need_pretrain`` is True,
        so batches without it work otherwise.
        """
        x, edge_index, batch, fp = graph.x, graph.edge_index, graph.batch, graph.fp

        x = F.elu(self.atom_embedding(x))

        for conv, bn in zip(self.graph_conv, self.graph_bn):
            x = F.relu(bn(conv(x, edge_index)))

        # Node vectors -> one vector per graph.
        x = global_add_pool(x, batch)

        for fc, bn in zip(self.x_fc, self.x_bn):
            x = F.relu(bn(fc(x)))

        for fc, bn in zip(self.fp_fc, self.fp_bn):
            fp = F.relu(bn(fc(fp)))

        x = x + fp

        if self.need_pretrain:
            x = x + self.beta * self.xp_fc(graph.xp)

        return x

    def forward(self, graph, need_emb=True):
        """Return log-probabilities over ``args.output_dim`` classes.

        When ``need_emb`` is False, ``graph`` is treated as an already computed
        embedding (the output of ``embed``). In that case only the output head
        is applied.
        """
        x = self.embed(graph) if need_emb else graph
        return F.log_softmax(self.out(x), 1)
















class Model(torch.nn.Module):
    """Train/predict wrapper around one network architecture from this file.

    ``args.architecture`` selects the wrapped network, stored as ``self.model``:

    * ``'mlp'`` -> ``MLP``
    * ``'gcn'`` -> ``GCN``
    * ``'ginl'`` -> ``GINL``
    * ``'ginlp'`` -> ``GINL`` with its ``pretrain`` encoder loaded from
      ``args.pretrain_file``
    * ``'gine'`` -> ``GINE``
    * ``'ginep'`` -> ``GINE`` with weights from ``args.pretrain_file``.
      ``args.mode == 'e'`` loads only the ``pretrain`` encoder; ``'a'`` loads
      the whole network.
    * ``'gind'`` -> ``GIND`` with ``need_pretrain=True``

    Training minimises NLL loss with Adam under bfloat16 autocast. It can add an
    anchored L2 penalty (``args.anchored``) and a GRPO-style auxiliary term
    (``args.strategy == 'grpo'``).

    Raises:
        ValueError: if ``args.architecture`` is not one of the names above.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Checkpoints are loaded with map_location="cpu" for two reasons:
        # - the freshly built network still lives on the CPU at this point (it
        #   is moved to self.device below);
        # - GPU-saved files can then also be loaded on CPU-only machines.
        arch = args.architecture
        if arch == 'mlp':
            self.model = MLP(args)
        elif arch == 'gcn':
            self.model = GCN(args)

        elif arch == 'ginl':
            self.model = GINL(args)
        elif arch == 'ginlp':
            self.model = GINL(args, need_pretrain=True)
            self.model.pretrain.load_state_dict(torch.load(args.pretrain_file, map_location="cpu"), strict=True)

        elif arch == "gine":
            self.model = GINE(args)
        elif arch == "ginep":
            self.model = GINE(args, need_pretrain=True)
            if args.mode == "e":
                self.model.pretrain.load_state_dict(torch.load(args.pretrain_file, map_location="cpu"), strict=True)
            elif args.mode == "a":
                self.model.load_state_dict(torch.load(args.pretrain_file, map_location="cpu"), strict=True)

        elif arch == "gind":
            self.model = GIND(args, need_pretrain=True)

        else:
            raise ValueError(f"Unknown architecture: {arch!r}")

        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_type)

        self.loss_fn = torch.nn.NLLLoss()
        self.train_loss = []   # mean batch loss of every epoch run by train()

        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.scaler = torch.cuda.amp.GradScaler()

        if args.anchored:
            # Snapshot the initial (already on-device) weights for anchored regularisation.
            # Detached clones: the L2 penalty's backward then does not also compute and
            # accumulate gradients into the snapshot tensors.
            self.model.anchor_weights = {
                name: param.detach().clone() for name, param in self.model.named_parameters()
            }

    def train(self, dataloader):
        """Fit the wrapped network for ``args.epochs`` epochs on ``dataloader``.

        Each batch is indexed as ``batch[0]`` (network input) and ``batch[1]``
        (class labels). After every epoch, the mean per-batch loss is appended
        to ``self.train_loss``.

        This method shadows ``torch.nn.Module.train(mode)``. ``Module.eval()``
        and ``Module.train(mode)`` call it with a bool; in that case it only
        switches the training mode instead of starting a training run.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        # BatchNorm layers must use and update batch statistics while fitting.
        self.model.train()

        use_grpo = self.args.strategy == "grpo"
        ig = IntegratedGradients(self.model) if use_grpo else None
        # Reference network for the KL term: a copy of the weights before any update.
        model_ref = deepcopy(self.model) if use_grpo else None

        for _ in range(self.args.epochs):
            # "Old" network for the ratio term: a snapshot taken at the start of each epoch.
            model_old = deepcopy(self.model) if use_grpo else None

            epoch_loss = 0
            epoch_item = 0

            for batch in dataloader:
                x, y = batch[0].to(self.device), batch[1].to(self.device)
                # Flatten labels to shape [N]. Unlike squeeze(), this keeps a batch
                # of one as [1], so NLLLoss still gets a 1-D target.
                target = y.reshape(-1)

                # Autocast wraps only the forward pass and the loss computation.
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    y_hat = self.model(x)
                    if y_hat.dim() == 1:   # restore the batch axis of a single squeezed sample
                        y_hat = y_hat.unsqueeze(0)
                    loss = self.loss_fn(y_hat, target)

                    if self.args.anchored:
                        # Anchored loss added to the regular loss according to Pearce et al. (2018).
                        loss = loss + self._anchored_l2(len(y))

                    if use_grpo:
                        loss = loss + self._grpo_loss(ig, model_ref, model_old, x, y_hat, target)

                # Backward and optimizer step run outside autocast, as PyTorch recommends.
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                epoch_item += 1

            self.train_loss.append(epoch_loss / epoch_item)

    def _anchored_l2(self, batch_size):
        """Return ``l2_lambda / batch_size * sum ||param - anchor||^2`` over all parameters."""
        scale = self.args.l2_lambda / batch_size
        l2_loss = 0
        for name, param in self.model.named_parameters():
            diff = param - self.model.anchor_weights[name]
            l2_loss += scale * torch.mul(diff, diff).sum()
        return l2_loss

    @staticmethod
    def _minmax(t, eps=1e-12):
        """Rescale ``t`` to [0, 1]. A constant tensor maps to zeros instead of NaN."""
        t_min = torch.min(t)
        return (t - t_min) / (torch.max(t) - t_min).clamp_min(eps)

    def _grpo_loss(self, ig, model_ref, model_old, x, y_hat, target):
        """Return the batch-mean GRPO-style auxiliary loss.

        The loss combines three parts:

        * A ratio of current to ``model_old`` probabilities, clipped to
          ``1 +/- grpo_epsilon``.
        * A min-max normalised advantage, ``1 - (label - p(class 1))^2``. For
          label-0 samples predicted as class 1 with probability > 0.5, the
          squared error is scaled by ``1 - normalised IG attribution``.
        * A KL-style penalty against ``model_ref``, weighted by ``grpo_beta``.

        The whole term is scaled by ``grpo_lambda``.
        """
        with torch.no_grad():
            y_hat_ref = model_ref(x)
            y_hat_old = model_old(x)
            x_embed = self.model.embed(x)
        # Attribute class 1 to the embedding; (False,) makes forward skip embed().
        attributions, _ = ig.attribute(x_embed, additional_forward_args=(False,), target=1, n_steps=20, method='gausslegendre', return_convergence_delta=True)

        ratio = torch.mean(torch.exp(y_hat - y_hat_old), dim=1)

        attributions = torch.norm(attributions, p=1, dim=1)
        uncertainty = self._minmax(attributions)
        p_pos = torch.exp(y_hat[:, 1])
        uncertainty_ratio = torch.where((target == 0) & (p_pos > 0.5), 1 - uncertainty, torch.ones_like(uncertainty))

        adv = 1 - (target - p_pos) ** 2 * uncertainty_ratio
        adv = self._minmax(adv)

        kl = torch.mean(torch.exp(y_hat_ref - y_hat) - (y_hat_ref - y_hat) - 1, dim=1)

        eps = self.args.grpo_epsilon
        grpo_loss = -(torch.clamp(ratio, 1 - eps, 1 + eps) * adv - self.args.grpo_beta * kl) * self.args.grpo_lambda
        return torch.mean(grpo_loss)

    def predict(self, dataloader):
        """Return the wrapped network's outputs for every sample in ``dataloader``.

        The network is put in eval mode for the duration of the call, so
        BatchNorm uses and does not modify its running statistics. Its previous
        mode is restored afterwards.

        Returns:
            A float32 CPU tensor with one row per sample, in dataloader order.
            If ``dataloader`` yields nothing, an empty 1-D tensor.
        """
        was_training = self.model.training
        self.model.eval()
        chunks = []
        try:
            with torch.no_grad(), torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                for batch in dataloader:
                    y_hat = self.model(batch[0].to(self.device))
                    if y_hat.dim() == 1:   # restore the batch axis of a single squeezed sample
                        y_hat = y_hat.unsqueeze(0)
                    chunks.append(y_hat.cpu())
        finally:
            self.model.train(was_training)

        if not chunks:
            return torch.tensor([])
        # Concatenate once (linear, not quadratic) and cast to float32. This
        # matches the dtype of the old incremental cat onto torch.tensor([]).
        return torch.cat(chunks, 0).float()
    













class Ensemble(torch.nn.Module):
    """Ensemble of ``args.ensemble_size`` separately constructed ``Model`` members.

    Members are kept in a plain dict keyed ``0 .. ensemble_size - 1``. They are
    not registered as submodules.
    """

    def __init__(self, args):
        # torch.nn.Module subclasses must run Module.__init__ before the instance is
        # used as a Module; it sets up the parameter/module registries and `training`.
        super().__init__()

        # NOTE(review): these seeds are drawn but never applied, so each member is
        # initialised from the current global torch RNG state. The draw is kept so
        # the global NumPy RNG stream is consumed exactly as before.
        seeds = np.random.randint(0, 1000, (args.ensemble_size))

        self.models = {i: Model(args) for i, _ in enumerate(seeds)}

    def train(self, dataloader):
        """Train every member, one after another, on the same ``dataloader``."""
        for member in self.models.values():
            member.train(dataloader)

    def predict(self, dataloader):
        """Stack each member's ``predict`` output along a new member axis (dim 1)."""
        logits_N_K_C = torch.stack([m.predict(dataloader) for m in self.models.values()], dim=1)

        return logits_N_K_C   # [N, ensemble_num, class_num]

    def __getitem__(self, item):
        """Return the member stored under index ``item``."""
        return self.models[item]

