import os.path as osp
import torch
import torch_geometric
import torch.nn as nn
import numpy as np
import json
from copy import deepcopy

from aug import TUDataset_aug as TUDataset
from aug import permute_edges_graphon
from torch_geometric.data import DataLoader
import json
from torch_geometric.utils import dense_to_sparse, to_dense_adj, mask_feature

from losses import *
from gin import Encoder
from evaluate_embedding import evaluate_embedding
from model import *
from sigl_tools import *

from arguments import arg_parse
import random
import pickle
from sigl_tools import plot_tsne_2d
import warnings
warnings.filterwarnings("ignore")
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')


class GcnInfomax(nn.Module):
  """InfoGraph-style model: local/global mutual-information loss on top of ``Encoder``.

  Graph-level (``y``) and node-level (``M``) encoder outputs are passed through
  separate ``FF`` heads and scored with ``local_global_loss_`` using the 'JSD'
  measure. When ``args.prior`` is set, a ``PriorDiscriminator`` term weighted by
  ``gamma`` is added.
  """

  def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1):
    super(GcnInfomax, self).__init__()

    self.alpha = alpha
    self.beta = beta
    self.gamma = gamma
    # Reads the module-level `args` produced by arg_parse() in the script body.
    self.prior = args.prior

    self.embedding_dim = hidden_dim * num_gc_layers
    # `dataset_num_features` is a module-level global set by the script body.
    self.encoder = Encoder(dataset_num_features, hidden_dim, num_gc_layers)

    self.local_d = FF(self.embedding_dim)
    self.global_d = FF(self.embedding_dim)

    if self.prior:
        self.prior_d = PriorDiscriminator(self.embedding_dim)

    self.init_emb()

  def init_emb(self):
    """Xavier-uniform initialise every nn.Linear weight and zero its bias."""
    for m in self.modules():
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

  def forward(self, x, edge_index, batch, num_graphs):
    """Return the local/global MI loss (plus the optional prior term) for a batch.

    Args:
        x: node features, or None for feature-less graphs (a constant 1 per node is used).
        edge_index: graph connectivity passed to the encoder.
        batch: node-to-graph assignment vector; its length is the node count.
        num_graphs: unused; kept for interface compatibility.

    Returns:
        ``local_global_loss_(...)`` plus the prior term (0 when prior is disabled).
    """
    if x is None:
        # Build the constant features on the same device as the batch rather
        # than on a hard-coded module-level device.
        x = torch.ones(batch.shape[0], device=batch.device)

    y, M = self.encoder(x, edge_index, batch)

    g_enc = self.global_d(y)
    l_enc = self.local_d(M)

    measure = 'JSD'
    local_global_loss = local_global_loss_(l_enc, g_enc, edge_index, batch, measure)

    if self.prior:
        # Adversarial prior: discriminate uniform samples from graph embeddings.
        prior = torch.rand_like(y)
        term_a = torch.log(self.prior_d(prior)).mean()
        term_b = torch.log(1.0 - self.prior_d(y)).mean()
        PRIOR = - (term_a + term_b) * self.gamma
    else:
        PRIOR = 0

    return local_global_loss + PRIOR


def _exp_cosine_sim(a, b, T):
    """Return exp(cos(a[i], b[j]) / T) for every pair (i, j) as a [len(a), len(b)] matrix."""
    a_abs = a.norm(dim=1)
    b_abs = b.norm(dim=1)
    sim = torch.einsum('ik,jk->ij', a, b) / torch.einsum('i,j->ij', a_abs, b_abs)
    return torch.exp(sim / T)


class simclr(nn.Module):
  """Contrastive model: ``Encoder`` graph embeddings followed by a 2-layer MLP projection head.

  ``loss_cal`` / ``loss_cal2`` compute InfoNCE losses between the projected
  embeddings of two views of the same batch of graphs.
  """

  def __init__(self, hidden_dim, num_gc_layers, alpha=0.5, beta=1., gamma=.1):
    super(simclr, self).__init__()

    # alpha/beta/gamma/prior are stored for parity with GcnInfomax; this class's
    # forward pass and losses do not read them.
    self.alpha = alpha
    self.beta = beta
    self.gamma = gamma
    self.prior = False # args.prior

    self.embedding_dim = hidden_dim * num_gc_layers
    # `dataset_num_features` is a module-level global set by the script body.
    self.encoder = Encoder(dataset_num_features, hidden_dim, num_gc_layers)

    self.proj_head = nn.Sequential(
        nn.Linear(self.embedding_dim, self.embedding_dim),
        nn.ReLU(inplace=True),
        nn.Linear(self.embedding_dim, self.embedding_dim),
    )

    self.init_emb()

  def init_emb(self):
    """Xavier-uniform initialise every nn.Linear weight and zero its bias."""
    for m in self.modules():
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

  def forward(self, x, edge_index, batch, num_graphs):
    """Encode a batch of graphs and return the projected graph embeddings.

    Args:
        x: node features, or None for feature-less graphs (a constant 1 per node is used).
        edge_index: graph connectivity passed to the encoder.
        batch: node-to-graph assignment vector; its length is the node count.
        num_graphs: unused; kept for interface compatibility.
    """
    if x is None:
        # Build the constant features on the batch's device rather than on a
        # hard-coded module-level device.
        x = torch.ones(batch.shape[0], device=batch.device)

    y, _ = self.encoder(x, edge_index, batch)
    return self.proj_head(y)

  def loss_cal(self, x, x_aug, T=0.2):
    """InfoNCE loss between original and augmented embeddings.

    The positive for row i is (x[i], x_aug[i]); all (x[i], x_aug[j]) with j != i
    are negatives. The positive is excluded from the denominator.

    Args:
        x: original embeddings, shape [batch_size, dim].
        x_aug: augmented embeddings, same shape; row i is the positive for x[i].
        T: softmax temperature.

    Returns:
        Scalar loss tensor.
    """
    sim_matrix = _exp_cosine_sim(x, x_aug, T)
    pos_sim = sim_matrix.diagonal()
    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
    return -torch.log(loss).mean()

  def loss_cal2(self, x, x_aug1, clusters=None, T=0.2):
    """InfoNCE loss whose negatives come only from other clusters.

    For row i the positive is pos_i = exp(cos(x[i], x_aug1[i]) / T). Negatives are
    (x[i], x_aug1[j]) for every j whose cluster label differs from i's; when
    ``clusters`` is None every j != i is a negative. The loss is

        -mean_i log(pos_i / (pos_i + sum_{j in neg(i)} exp(cos(x[i], x_aug1[j]) / T)))

    so it is always >= 0, and a row with no negatives contributes 0.

    Args:
        x: original embeddings, shape [batch_size, dim].
        x_aug1: augmented embeddings, same shape; row i is the positive for x[i].
        clusters: optional per-sample cluster labels (list or 1-D tensor, length batch_size).
        T: softmax temperature.

    Returns:
        Scalar loss tensor.
    """
    sim_matrix = _exp_cosine_sim(x, x_aug1, T)
    pos_sim = sim_matrix.diagonal()

    batch_size = x.size(0)
    if clusters is None:
        neg_mask = ~torch.eye(batch_size, dtype=torch.bool, device=x.device)
    else:
        clusters = torch.as_tensor(clusters, device=x.device)
        # Different-cluster pairs only; this also excludes the diagonal.
        neg_mask = clusters.unsqueeze(1) != clusters.unsqueeze(0)

    # Mask AFTER exponentiation: masking the raw similarity first would turn each
    # excluded pair into exp(0) = 1 and still count it in the denominator.
    neg_sim = (sim_matrix * neg_mask).sum(dim=1)
    return -torch.log(pos_sim / (pos_sim + neg_sim)).mean()
  


def setup_seed(seed):
    """Seed every RNG the script uses and request deterministic cuDNN kernels.

    torch (CPU and all CUDA devices), PyTorch Geometric, NumPy and Python's
    ``random`` module all receive the same ``seed``.
    """
    for seed_fn in (torch.manual_seed,
                    torch_geometric.seed_everything,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(seed)


if __name__ == '__main__':
    import os  # only os.path is imported at file level; os.makedirs is needed below

    args = arg_parse()
    print(args)
    seed = args.seed if args.seed is not None else random.randint(1, 10000)
    print("Seed: ", seed)
    setup_seed(seed)

    accuracies = {'val': [], 'test': []}
    # Best-validation results. These stay None when no evaluation epoch is reached
    # (epochs < args.log_interval) instead of raising NameError when logging.
    acc_val_final = None
    cur_test_value = None

    epochs = 20  # default is 20
    # epochs = 10 if args.DS in ['REDDIT-MULTI-5K', 'COLLAB', 'NCI1'] else epochs
    batch_size = 128  # default is 128
    batch_size = 32 if args.DS in ['MUTAG', 'PROTEINS'] else batch_size
    batch_size = 512 if args.DS in ['REDDIT-MULTI-5K', 'COLLAB', 'NCI1'] else batch_size
    lr = args.lr
    DS = args.DS
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', DS)

    if DS == 'Sim':
        print("Loading graphons and graphs")
        with open("data/MGCL/cluster_labels_sim.pkl", 'rb') as f:
            dataset, cluster_labels_list, graphons_list, models_ISGL_list = pickle.load(f)
        print("Graphons loaded")
        dataset_eval = [(g, deepcopy(g)) for g in dataset]
        dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)
        # simclr()'s Encoder and the summary print read the global
        # dataset_num_features, which this branch must therefore define.
        # Assumes the pickled graphs are PyG Data objects, as the loop below treats them.
        dataset_num_features = dataset[0].num_node_features or 1
    else:
        print("Loading graphons")
        with open("data/MGCL/cluster_labels_" + DS + ".pkl", 'rb') as f:
            cluster_labels_list, graphons_list, models_ISGL_list = pickle.load(f)
        print("Graphons loaded")
        # aug='none' returns each graph with a dummy augmented view as (data, data_aug);
        # the dummy view is replaced by the graphon-based augmentation below.
        dataset = TUDataset(root=path, name=DS, aug='none')

        try:
            dataset_num_features = dataset.get_num_feature()
        except Exception:
            dataset_num_features = 1

        dataset = list(dataset)
        dataset_eval = list(TUDataset(root=path, name=DS, aug='none').shuffle())
        dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size)

    # Pair every graph with an augmentation drawn from the graphon of its cluster,
    # and tag it with its cluster label (consumed by loss_cal2).
    for i, item in enumerate(dataset):
        data_i = item if DS == 'Sim' else item[0]
        cluster_idx = cluster_labels_list[i]
        data_aug_self = permute_edges_graphon(
            deepcopy(data_i), graphons_list[cluster_idx], models_ISGL_list[cluster_idx],
            drop_percent=args.Rpos)
        data_i.cluster = cluster_idx
        dataset[i] = (data_i, data_aug_self)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Data loaded")

    setup_seed(seed)
    model = simclr(args.hidden_dim, args.num_gc_layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print(model)

    print('================')
    print('lr: {}'.format(lr))
    print('num_features: {}'.format(dataset_num_features))
    print('hidden_dim: {}'.format(args.hidden_dim))
    print('num_gc_layers: {}'.format(args.num_gc_layers))
    print('================')

    Epoch_losses = []
    for epoch in range(1, epochs + 1):

        loss_all = 0
        model.train()
        for data, data_aug1 in dataloader:
            optimizer.zero_grad()
            data = data.to(device)
            data_aug1 = data_aug1.to(device)
            x = model(data.x, data.edge_index, data.batch, data.num_graphs)
            x_aug1 = model(data_aug1.x, data_aug1.edge_index, data_aug1.batch, data_aug1.num_graphs)

            loss = model.loss_cal2(x, x_aug1, data.cluster)
            # Weight each batch's mean loss by its number of graphs.
            loss_all += loss.item() * data.num_graphs
            loss.backward()
            optimizer.step()

        # loss_all is a sum over graphs, so normalise by the number of graphs
        # (not the number of batches) to get the mean per-graph loss.
        loss_epoch = loss_all / len(dataset)
        print('Epoch {}, Loss {}'.format(epoch, loss_epoch))
        Epoch_losses.append(loss_epoch)

        if epoch % args.log_interval == 0:
            model.eval()
            emb_final, y_final = model.encoder.get_embeddings(dataloader_eval)
            acc_val, acc_test = evaluate_embedding(emb_final, y_final)

            accuracies['val'].append(round(acc_val, 8))
            accuracies['test'].append(round(acc_test, 8))
            # Report the test accuracy at the epoch with the best validation
            # accuracy (first such epoch on ties, as list.index returns).
            select_val_index = accuracies['val'].index(max(accuracies['val']))
            acc_val_final = accuracies['val'][select_val_index]
            cur_test_value = accuracies['test'][select_val_index]

            print('Eval | Epochs {}: val {}, test {}, cur max test {}'.format(epoch, accuracies['val'][-1], accuracies['test'][-1], cur_test_value))

    log_dir = osp.join('logs_new', args.DS)
    os.makedirs(log_dir, exist_ok=True)
    with open(osp.join(log_dir, args.DS + '_results.log'), 'a+') as f:
        f.write('DS:{},Seed:{},Epochs:{},Batch_size:{},Rpos:{},Lr:{}, Acc_val:{}, Acc_test:{}\n'.format(
            args.DS, seed, epochs, batch_size, args.Rpos, lr, acc_val_final, cur_test_value))

if __name__ == '__main__':
    # Plot the per-epoch training loss. This lives under the main guard because
    # epochs, Epoch_losses, args and seed are only defined when run as a script;
    # at import time it raised NameError.
    # NOTE(review): `plt` is not imported explicitly in this file; presumably it
    # comes from one of the star imports. Confirm.
    import os

    os.makedirs(osp.join('logs_new', args.DS), exist_ok=True)
    plt.figure(figsize=(6, 4))
    plt.plot(range(1, epochs+1), Epoch_losses)
    plt.xlabel('Epoch')
    plt.xticks(range(1, epochs+1))
    plt.ylabel('Loss')
    plt.grid()
    plt.title('MGCL loss vs Epoch, ' + args.DS )
    plt.tight_layout()
    plt.savefig('logs_new/' + args.DS + '/lossepochs_seed' + str(seed) + '.png')
    plt.close()
