import argparse
import random
import warnings
import numpy as np

import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_scatter import scatter
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose
from torch_geometric.nn import global_add_pool

from tu_datasets import TUDataset
from evaluate_embedding import evaluate_embedding
from gin import WGINConv, ChebNetII_V2, GPRGNN_V2, BernNet_V2
from view_generator import ViewLearner
from utils import initialize_edge_weight, initialize_node_features, set_tu_dataset_y_shape
# Suppress every warning category for the whole process, including warnings
# raised by the libraries imported above.
warnings.filterwarnings('ignore')


def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)
    with the same value and asks cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


class GIN(torch.nn.Module):
    """Graph encoder made of stacked ``WGINConv`` layers with batch norm.

    Each layer wraps a two-layer MLP in a ``WGINConv`` followed by
    ``BatchNorm1d``.  Node embeddings are sum-pooled per graph with
    ``global_add_pool``.

    Args:
        num_dataset_features: Width of the input node features.
        emb_dim: Hidden / output width of every layer.
        num_gc_layers: Number of convolution layers.
        drop_ratio: Dropout probability applied after every layer.
        pooling_type: ``"standard"`` pools only the last layer
            (``out_graph_dim == emb_dim``); ``"layerwise"`` pools every layer
            and concatenates (``out_graph_dim == emb_dim * num_gc_layers``).
        is_infograph: With ``"layerwise"`` pooling, also return the
            concatenation of all layers' node embeddings instead of only the
            last layer's.
        device: Device that ``get_embeddings`` moves batches to.

    Raises:
        NotImplementedError: If ``pooling_type`` is not one of the two
            supported values.
    """

    def __init__(self, num_dataset_features, emb_dim=300, num_gc_layers=5, drop_ratio=0.0, pooling_type="standard", is_infograph=False, device=None):
        super(GIN, self).__init__()
        self.device = device
        self.pooling_type = pooling_type
        self.emb_dim = emb_dim
        self.num_gc_layers = num_gc_layers
        self.drop_ratio = drop_ratio
        self.is_infograph = is_infograph

        if self.pooling_type == "standard":
            self.out_graph_dim = self.emb_dim
        elif self.pooling_type == "layerwise":
            self.out_graph_dim = self.emb_dim * self.num_gc_layers
        else:
            raise NotImplementedError

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        for layer in range(num_gc_layers):
            # Only the first layer consumes the raw dataset features.
            in_dim = emb_dim if layer else num_dataset_features
            mlp = Sequential(Linear(in_dim, emb_dim), ReLU(), Linear(emb_dim, emb_dim))
            self.convs.append(WGINConv(mlp))
            self.bns.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batch, x, edge_index, edge_attr=None, edge_weight=None):
        """Encode a batch of graphs.

        Args:
            batch: Graph id of every node, used for pooling.
            x: Node feature matrix.
            edge_index: Edge list passed to every convolution.
            edge_attr: Accepted so the encoder can be called with the
                ``(batch, x, edge_index, edge_attr, edge_weight)`` convention;
                it is not used by this encoder.
            edge_weight: Optional per-edge weights handed to every
                ``WGINConv``.  ``None`` (the default) reproduces the previous
                behaviour.
                NOTE(review): assumes the third positional argument of
                ``WGINConv.forward`` is the edge weight - confirm in ``gin``.

        Returns:
            ``(graph_embedding, node_embedding)``; see the class docstring for
            how ``pooling_type`` and ``is_infograph`` shape them.
        """
        xs = []
        last_layer = self.num_gc_layers - 1
        for layer, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            x = bn(conv(x, edge_index, edge_weight))
            # No non-linearity after the final layer.
            if layer != last_layer:
                x = F.relu(x)
            x = F.dropout(x, self.drop_ratio, training=self.training)
            xs.append(x)

        if self.pooling_type == "standard":
            return global_add_pool(x, batch), x
        if self.pooling_type == "layerwise":
            xpool = torch.cat([global_add_pool(h, batch) for h in xs], 1)
            if self.is_infograph:
                return xpool, torch.cat(xs, 1)
            return xpool, x
        raise NotImplementedError

    def get_embeddings(self, loader, is_rand_label=False):
        """Compute graph embeddings and labels for every batch in ``loader``.

        Runs under ``torch.no_grad()``; the caller is responsible for putting
        the module in eval mode.

        Args:
            loader: Iterable of batches.  A batch may also be a list, in which
                case its first element is used.
            is_rand_label: Read labels from ``data.rand_label`` instead of
                ``data.y``.

        Returns:
            ``(embeddings, labels)`` as NumPy arrays concatenated over all
            batches along axis 0.
        """
        ret = []
        y = []
        with torch.no_grad():
            for data in loader:
                if isinstance(data, list):
                    data = data[0]
                data = data.to(self.device)
                batch, x, edge_index = data.batch, data.x, data.edge_index
                if x is None:
                    # Featureless graphs get a constant one-dimensional feature.
                    x = torch.ones((batch.shape[0], 1)).to(self.device)
                graph_emb, _ = self.forward(batch, x, edge_index)

                ret.append(graph_emb.cpu().numpy())
                labels = data.rand_label if is_rand_label else data.y
                y.append(labels.cpu().numpy())
        ret = np.concatenate(ret, 0)
        y = np.concatenate(y, 0)
        return ret, y


class GInfoMinMax(torch.nn.Module):
    """Encoder plus MLP projection head trained with a contrastive loss.

    Args:
        encoder: Module returning ``(graph_embedding, node_embedding)``.
        embedding_dim: Width of the encoder's graph embedding, i.e. the input
            width of the projection head.
        proj_hidden_dim: Hidden and output width of the projection head.
    """

    def __init__(self, encoder, embedding_dim, proj_hidden_dim=300):
        super(GInfoMinMax, self).__init__()
        self.encoder = encoder
        self.input_proj_dim = embedding_dim
        self.proj_head = Sequential(
            Linear(self.input_proj_dim, proj_hidden_dim),
            ReLU(inplace=True),
            Linear(proj_hidden_dim, proj_hidden_dim),
        )

        self.init_emb()

    def init_emb(self):
        """Xavier-initialise every ``Linear`` weight (encoder included) and zero its bias."""
        for m in self.modules():
            if isinstance(m, Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def forward(self, batch, x, edge_index, edge_attr=None, edge_weight=None):
        """Encode a batch and project the graph embedding.

        Args:
            batch: Graph id of every node.
            x: Node feature matrix.
            edge_index: Edge list.
            edge_attr: Optional edge attributes forwarded to the encoder.
            edge_weight: Optional per-edge weights forwarded to the encoder.

        Returns:
            ``(projected_graph_embedding, node_embedding)``.

        When neither optional argument is given the encoder is called with
        three arguments, exactly as before, so encoders that only accept
        ``(batch, x, edge_index)`` keep working.  Otherwise the encoder must
        accept ``(batch, x, edge_index, edge_attr, edge_weight)``.
        """
        if edge_attr is None and edge_weight is None:
            z, node_emb = self.encoder(batch, x, edge_index)
        else:
            z, node_emb = self.encoder(batch, x, edge_index, edge_attr, edge_weight)
        return self.proj_head(z), node_emb

    @staticmethod
    def calc_loss(x, x_aug, temperature=0.2, sym=True, loss_type='info'):
        """Contrastive loss between two batches of embeddings.

        Row ``i`` of ``x`` and row ``i`` of ``x_aug`` form the positive pair;
        every other pairing in the batch is a negative.  Similarity is the
        cosine similarity, scaled by ``temperature`` and exponentiated.

        Args:
            x: 2-D tensor of embeddings, one row per graph.
            x_aug: Embeddings of the augmented views, same shape as ``x``.
            temperature: Softmax temperature.
            sym: Average the loss over both directions (columns and rows of
                the similarity matrix) when true; rows only otherwise.
            loss_type: Accepted for interface compatibility; this
                implementation does not consult it.

        Returns:
            Scalar loss tensor.
        """
        batch_size, _ = x.size()
        norm_outer = torch.einsum('i,j->ij', x.norm(dim=1), x_aug.norm(dim=1))
        sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / norm_outer
        sim_matrix = torch.exp(sim_matrix / temperature)
        pos_sim = sim_matrix[range(batch_size), range(batch_size)]

        def directional_loss(dim):
            # The positive pair is excluded from the denominator.
            ratio = pos_sim / (sim_matrix.sum(dim=dim) - pos_sim)
            return -torch.log(ratio).mean()

        loss_1 = directional_loss(1)
        if not sym:
            return loss_1
        loss_0 = directional_loss(0)
        return (loss_0 + loss_1) / 2.0
	

def _sample_edge_weight(edge_logits, device, temperature=1.0, bias=0.0001):
    """Turn edge logits into stochastic keep-weights in (0, 1).

    Adds the noise ``log(eps) - log(1 - eps)`` (``eps`` uniform) to the
    logits, divides by ``temperature`` and applies a sigmoid.
    """
    # bias keeps eps away from 0 and 1 so both logarithms stay finite.
    eps = (bias - (1 - bias)) * torch.rand(edge_logits.size()) + (1 - bias)
    gate_inputs = (torch.log(eps) - torch.log(1 - eps)).to(device)
    gate_inputs = (gate_inputs + edge_logits) / temperature
    return torch.sigmoid(gate_inputs).squeeze()


def _edge_drop_regularizer(batch, edge_weight):
    """Mean over graphs of the average edge drop probability ``1 - edge_weight``.

    Graphs without edges are left out of the mean.
    """
    row, _ = batch.edge_index
    edge_batch = batch.batch[row]  # graph id of every edge
    edge_drop_out_prob = 1 - edge_weight
    uni, edge_batch_num = edge_batch.unique(return_counts=True)
    if uni.numel() == 0:
        # No edges in the whole batch: nothing to regularise.
        return edge_drop_out_prob.sum()
    sum_pe = scatter(edge_drop_out_prob, edge_batch, reduce="sum")
    # uni is sorted and aligned with its counts, so this is the per-graph mean.
    return (sum_pe[uni] / edge_batch_num).mean()


def run(args):
    """Train the encoder and the view learner adversarially, then evaluate.

    Each batch is processed in two steps:

    1. The view learner is updated to *maximise* the contrastive loss between
       the original graphs and their edge-reweighted views, minus
       ``args.reg_lambda`` times the edge-drop regulariser.
    2. The model is updated to *minimise* the contrastive loss against a
       freshly sampled (detached) view.

    After training, embeddings are evaluated five times with
    ``evaluate_embedding`` and the mean / std accuracy is printed.

    The model is called as ``model(batch, x, edge_index, None, edge_weight)``
    for augmented views, so the model and its encoder must accept that
    five-argument form.

    Args:
        args: Parsed command-line namespace (see ``arg_parse``).

    Raises:
        ValueError: If ``args.encoder`` names an unsupported encoder.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(args)
    print('------------------------')
    setup_seed(args.seed)

    # Fall back to calc_loss's own default when the namespace has no `loss`.
    loss_type = getattr(args, 'loss', 'info')

    my_transforms = Compose([initialize_node_features, initialize_edge_weight, set_tu_dataset_y_shape])
    dataset = TUDataset('data', args.dataset, transform=my_transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    dataloader_eval = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    try:
        dataset_num_features = dataset[0].x.shape[1]
    except (AttributeError, IndexError, TypeError):
        # Missing or malformed node features: treat as a single feature.
        dataset_num_features = 1
    print(dataset_num_features)

    spectral_encoders = {
        'ChebNetII_V2': ChebNetII_V2,
        'GPRGNN_V2': GPRGNN_V2,
        'BernNet_V2': BernNet_V2,
    }
    if args.encoder == 'GIN':
        encoder = GIN(num_dataset_features=1, emb_dim=args.emb_dim, num_gc_layers=args.num_layers, drop_ratio=args.drop_ratio, pooling_type=args.pooling_type, device=device)
        embedding_dim = args.emb_dim
        # The pooled graph embedding is wider than emb_dim with "layerwise"
        # pooling; the projection head must be sized for what GIN really emits.
        proj_input_dim = encoder.out_graph_dim
    elif args.encoder in spectral_encoders:
        encoder = spectral_encoders[args.encoder](dataset_num_features, args, device=device)
        embedding_dim = dataset_num_features * (args.K + 1)
        proj_input_dim = embedding_dim
    else:
        raise ValueError(f"Unknown encoder: {args.encoder!r}")

    model = GInfoMinMax(encoder, proj_input_dim, proj_hidden_dim=embedding_dim).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=args.model_lr)

    view_encoder = GIN(num_dataset_features=1, emb_dim=args.emb_dim, num_gc_layers=args.num_layers, drop_ratio=args.drop_ratio, pooling_type=args.pooling_type, device=device)
    view_learner = ViewLearner(view_encoder, emb_dim=embedding_dim, mlp_edge_model_dim=args.mlp_edge_model_dim).to(device)
    view_optimizer = torch.optim.Adam(view_learner.parameters(), lr=args.view_lr)

    model_losses = []
    view_losses = []
    view_regs = []
    for epoch in range(1, args.epochs + 1):
        model_loss_all = 0
        view_loss_all = 0
        reg_all = 0
        for batch in dataloader:
            batch = batch.to(device)

            # --- train view learner to maximise the contrastive loss ---
            view_learner.train()
            view_learner.zero_grad()
            model.eval()

            x, _ = model(batch.batch, batch.x, batch.edge_index)
            edge_logits = view_learner(batch.batch, batch.x, batch.edge_index, None)
            batch_aug_edge_weight = _sample_edge_weight(edge_logits, device)

            # The augmented view must use the sampled edge weights; otherwise
            # the contrastive term carries no gradient to the view learner.
            x_aug, _ = model(batch.batch, batch.x, batch.edge_index, None, batch_aug_edge_weight)

            reg = _edge_drop_regularizer(batch, batch_aug_edge_weight)

            view_loss = model.calc_loss(x, x_aug, loss_type=loss_type) - (args.reg_lambda * reg)
            view_loss_all += view_loss.item() * batch.num_graphs
            reg_all += reg.item()
            # gradient ascent formulation
            (-view_loss).backward()
            view_optimizer.step()

            # --- train model to minimise the contrastive loss ---
            model.train()
            view_learner.eval()
            model.zero_grad()

            x, _ = model(batch.batch, batch.x, batch.edge_index)
            # Same call form as in the view step above.
            # NOTE(review): assumes ViewLearner.forward takes edge_attr as its
            # fourth argument - confirm in view_generator.
            edge_logits = view_learner(batch.batch, batch.x, batch.edge_index, None)
            # Detached: this step must not update the view learner.
            batch_aug_edge_weight = _sample_edge_weight(edge_logits, device).detach()

            x_aug, _ = model(batch.batch, batch.x, batch.edge_index, None, batch_aug_edge_weight)

            model_loss = model.calc_loss(x, x_aug, loss_type=loss_type)
            model_loss_all += model_loss.item() * batch.num_graphs
            # standard gradient descent formulation
            model_loss.backward()
            model_optimizer.step()

        fin_model_loss = model_loss_all / len(dataloader)
        fin_view_loss = view_loss_all / len(dataloader)
        fin_reg = reg_all / len(dataloader)

        model_losses.append(fin_model_loss)
        view_losses.append(fin_view_loss)
        view_regs.append(fin_reg)

    accs = []
    for _ in range(5):
        model.eval()
        if isinstance(model.encoder, GIN):
            # GIN keeps its own device and takes no device argument here.
            emb, y = model.encoder.get_embeddings(dataloader_eval, is_rand_label=False)
        else:
            # Imported encoders keep the original call form.
            emb, y = model.encoder.get_embeddings(dataloader_eval, device, is_rand_label=False)
        acc_mean, acc_std = evaluate_embedding(emb, y)
        accs.append(acc_mean)
    print(f"(E) Test: acc = {np.mean(accs):.4f} +- {np.std(accs):.4f}")


def arg_parse():
    """Build the command-line parser for the AD-GCL TU experiment and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options.  Includes ``loss``, which the
        training loop reads as ``args.loss`` and forwards to ``calc_loss`` as
        ``loss_type``; it defaults to ``'info'``, the default of ``calc_loss``.
    """
    parser = argparse.ArgumentParser(description='AD-GCL TU')

    # Data / runtime
    parser.add_argument('--dataset', type=str, default='IMDB-BINARY', help='Dataset')
    parser.add_argument('--gpu', type=str, default='2')
    parser.add_argument('--pooling_type', type=str, default='standard', help='GNN Pooling Type Standard/Layerwise')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--num_layers', type=int, default=5, help='Number of GNN layers before pooling')
    parser.add_argument('--seed', type=int, default=42)

    # Optimisation
    parser.add_argument('--mlp_edge_model_dim', type=int, default=64, help='embedding dimension')
    parser.add_argument('--epochs', type=int, default=150, help='Train Epochs')
    parser.add_argument('--emb_dim', type=int, default=32, help='embedding dimension')
    parser.add_argument('--model_lr', type=float, default=0.001, help='Model Learning rate.')
    parser.add_argument('--view_lr', type=float, default=0.001, help='View Learning rate.')

    parser.add_argument('--drop_ratio', type=float, default=0.0, help='Dropout Ratio / Probability')
    parser.add_argument('--reg_lambda', type=float, default=5.0, help='View Learner Edge Perturb Regularization Strength')
    # Read by the training loop as args.loss; it was missing from the parser.
    parser.add_argument('--loss', type=str, default='info', help='Contrastive loss type passed to calc_loss.')

    # Encoder selection and spectral-encoder options
    parser.add_argument('--encoder', type=str, default='GIN')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')
    parser.add_argument('--use_bn', action='store_true')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha for APPN/GPRGNN.')
    parser.add_argument('--dprate', type=float, default=0.0, help='dropout for propagation layer.')
    parser.add_argument('--q', type=int, default=0, help='The constant for ChebBase.')
    parser.add_argument('--Init', type=str, choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'], default='PPR', help='initialization for GPRGNN.')
    return parser.parse_args()


if __name__ == '__main__':
    # Script entry point: print a blank line, then parse the command line
    # and start the training / evaluation run.
    print()
    run(arg_parse())