import torch
import logging
from data_utils import CustomDataset
from torch_geometric.utils import negative_sampling
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import networkx as nx
from data_utils import heat_diffusion
from torch_geometric.data import Data


def get_encoder_trainer(model, config, G):
    """Build the encoder trainer selected by ``config.encoder_model``.

    Args:
        model: encoder model to train.
        config: experiment config; reads ``encoder_model`` and
            ``encoder.<name>.train``.
        G: networkx graph, only used by the "vae" branch. Edge endpoints are
            used directly as tensor indices, so nodes are presumably labelled
            0..N-1 — TODO confirm against the graph loader.

    Returns:
        A ``Node2VecTrainer`` ("node2vec") or a ``VAETrainer`` ("vae").

    Raises:
        ValueError: if ``config.encoder_model`` is not a supported name.
    """
    if config.encoder_model == "node2vec":
        return Node2VecTrainer(model, config.encoder.node2vec.train)

    if config.encoder_model == "vae":
        train_cfg = config.encoder.vae.train
        # Same device rule VAETrainer applies to the model, so data and model
        # cannot end up on different devices.
        device = "cuda" if train_cfg.cuda == 1 else "cpu"

        # Node features: heat-diffusion signals computed from the graph Laplacian.
        L = nx.laplacian_matrix(G).toarray()
        heats = heat_diffusion(L, num_signals=16, diffusion_rate=0.02, sample_from="normal")
        # Assumes heats is (num_signals, num_nodes), so heats.T has one row per
        # node — TODO confirm against heat_diffusion.
        x = torch.tensor(heats.T, dtype=torch.float)

        # reshape(-1, 2) keeps an edgeless graph as a (0, 2) tensor so that the
        # transpose/cat below still produce a valid (2, 0) edge_index.
        edge_index = (
            torch.tensor(list(G.edges()), dtype=torch.long).reshape(-1, 2).t().contiguous()
        )
        # Store every edge in both directions (G is presumably undirected).
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        num_nodes = G.number_of_nodes()
        # ``num_nodes`` is the attribute PyG understands; the legacy ``num_node``
        # spelling is kept in case downstream code reads it.
        data = Data(x=x, edge_index=edge_index, num_nodes=num_nodes, num_node=num_nodes).to(device)
        return VAETrainer(model, train_cfg, data)

    raise ValueError(f"Unknown encoder_model: {config.encoder_model!r}")
    
def get_decoder_trainer(model, config, embeddings=None):
    """Build the decoder trainer selected by ``config.decoder_model``.

    Args:
        model: decoder model.
        config: experiment config; reads ``decoder_model`` and, for "mlp",
            ``decoder.mlp.train``.
        embeddings: node embeddings the MLP decoder consumes; required for "mlp".

    Returns:
        A ``DotProductTrainer`` ("dot_product") or an ``MLPDecoderTrainer`` ("mlp").

    Raises:
        ValueError: for an unknown ``decoder_model``, or "mlp" without embeddings.
    """
    if config.decoder_model == "dot_product":
        return DotProductTrainer(model)

    if config.decoder_model == "mlp":
        if embeddings is None:
            raise ValueError("The 'mlp' decoder requires node embeddings")
        return MLPDecoderTrainer(model, config.decoder.mlp.train, embeddings)

    raise ValueError(f"Unknown decoder_model: {config.decoder_model!r}")
    


class MLPDecoderTrainer:
    """Train an MLP link-prediction decoder on top of node embeddings.

    ``train`` splits the positive edges 80/20 (fixed seed) into train/test
    sets; ``CustomDataset`` (defined elsewhere) presumably pairs them with
    negative samples according to ``neg_ratio``.
    """

    # Supported optimizer names -> torch optimizer classes.
    _OPTIMIZERS = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}

    def __init__(self, model, train_config, embeddings):
        """Move model/embeddings to the device and build the optimizer and loss.

        Args:
            model: decoder called as ``model(edge_index_batch, embeddings)``.
                When ``train_config.tune_embeddings`` is truthy it must expose
                ``embeddings``, ``fc1`` and ``fc2``.
            train_config: namespace with ``epochs``, ``lr``, ``weight_decay``,
                ``batch_size``, ``neg_ratio``, ``log_each``, ``optimizer``,
                ``tune_embeddings`` and (when tuning) ``tune_lr``.
            embeddings: embedding tensor passed to every forward call.

        Raises:
            ValueError: if ``train_config.optimizer`` is not "adam" or "adamw".
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = model.to(self.device)
        self.embeddings = embeddings.to(self.device)
        self.epochs = train_config.epochs
        self.batch_size = train_config.batch_size
        self.neg_ratio = train_config.neg_ratio
        self.log_each = train_config.log_each

        try:
            optimizer_cls = self._OPTIMIZERS[train_config.optimizer]
        except KeyError:
            raise ValueError(
                f"Unsupported optimizer {train_config.optimizer!r}; "
                f"expected one of {sorted(self._OPTIMIZERS)}"
            ) from None

        lr = train_config.lr
        if train_config.tune_embeddings:
            # Fine-tune the embedding table with its own learning rate.
            params = [
                {"params": self.model.embeddings, "lr": train_config.tune_lr},
                {"params": self.model.fc1.parameters(), "lr": lr},
                {"params": self.model.fc2.parameters(), "lr": lr},
            ]
        else:
            params = self.model.parameters()
        self.optimizer = optimizer_cls(params, lr=lr, weight_decay=train_config.weight_decay)

        # The model output is treated as logits (see the 0 threshold in _evaluate).
        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.5]).to(self.device))

    def train(self, data):
        """Fit the decoder on ``data.edge_index``, printing test metrics every ``log_each`` epochs.

        Args:
            data: graph object exposing ``edge_index`` (2 x E) and ``num_nodes``.

        Returns:
            The trained model.
        """
        # train_test_split needs a host array; edge_index may live on the GPU.
        edges = data.edge_index.t().cpu().numpy()
        train_edges, test_edges = train_test_split(edges, test_size=0.2, random_state=42)
        train_edge_index = torch.tensor(train_edges).t().contiguous()
        test_edge_index = torch.tensor(test_edges).t().contiguous()

        for e in range(self.epochs):
            train_loss = self._train_epoch(train_edge_index, data.num_nodes)

            if e % self.log_each == 0:
                test_loss, acc, f1, precision, recall = self._evaluate(test_edge_index, data.num_nodes)
                print(
                    f"Epoch: {e}, Train Loss: {train_loss:.5f} | Test Loss: {test_loss:.4f} | "
                    f"Accuracy: {acc:.5f} | F1: {f1:.5f} | Precision: {precision:.5f} | Recall: {recall:.5f}"
                )

        return self.model

    def _train_epoch(self, train_edge_index, num_nodes):
        """Run one optimisation pass over the train pairs; return the mean batch loss."""
        self.model.train()
        # Rebuilt every epoch; presumably CustomDataset re-samples negatives on construction.
        loader = self.create_dataloader(
            train_edge_index, num_nodes, batch_size=self.batch_size, neg_ratio=self.neg_ratio
        )

        total_loss, num_batches = 0.0, 0
        for edge_index_batch, labels in loader:
            edge_index_batch = edge_index_batch.T.to(self.device)
            # BCEWithLogitsLoss requires float targets.
            labels = labels.float().to(self.device)

            self.optimizer.zero_grad()
            out = self.model(edge_index_batch, self.embeddings)
            loss = self.criterion(out, labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches if num_batches else float("nan")

    def _evaluate(self, test_edge_index, num_nodes):
        """Score the held-out pairs; return (mean loss, accuracy, F1, precision, recall)."""
        self.model.eval()
        loader = self.create_dataloader(
            test_edge_index, num_nodes, batch_size=self.batch_size, neg_ratio=self.neg_ratio, shuffle=False
        )

        all_preds, all_labels = [], []
        total_loss, num_batches = 0.0, 0
        with torch.no_grad():
            for edge_batch, labels in loader:
                edge_batch = edge_batch.T.to(self.device)
                labels = labels.float().to(self.device)
                logits = self.model(edge_batch, self.embeddings)

                total_loss += self.criterion(logits, labels).item()
                num_batches += 1

                # Outputs are logits, so sigmoid(x) > 0.5 is equivalent to x > 0.
                all_preds.extend((logits > 0).int().cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        mean_loss = total_loss / num_batches if num_batches else float("nan")
        acc = accuracy_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=np.nan)
        recall = recall_score(all_labels, all_preds)
        return mean_loss, acc, f1, precision, recall

    def create_dataloader(self, edge_index, num_nodes, batch_size=64, neg_ratio=1, shuffle=True):
        """Wrap ``CustomDataset`` in a DataLoader.

        ``shuffle=False`` marks evaluation: the whole split is returned as a
        single batch and ``batch_size`` is ignored.
        """
        dataset = CustomDataset(edge_index, num_nodes, neg_ratio)
        if not shuffle:
            # DataLoader rejects batch_size=0, so an empty split still gets 1.
            batch_size = max(len(dataset), 1)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

        

class DotProductTrainer:
    """Trainer stand-in for the dot-product decoder: it performs no optimisation.

    It mirrors the ``train(data)`` entry point of ``MLPDecoderTrainer`` so that
    whatever ``get_decoder_trainer`` returns can be driven the same way.
    """

    def __init__(self, model):
        """Remember the decoder model; nothing else is prepared."""
        self.model = model

    def train(self, data):
        """Hand back the stored model untouched.

        ``data`` is accepted only for interface parity and is not inspected.
        """
        trained = self.model
        return trained


class Node2VecTrainer:
    """Train a Node2Vec-style model with ``SparseAdam`` over its random-walk loader.

    The model must provide ``loader(batch_size=..., shuffle=...)`` yielding
    ``(pos_rw, neg_rw)`` batches, ``loss(pos_rw, neg_rw)`` and ``get_embeddings()``.
    """

    def __init__(self, model, config):
        """Move the model to the configured device and store hyper-parameters.

        Args:
            model: model exposing the interface described on the class.
            config: namespace with ``cuda`` (1 selects "cuda", anything else
                "cpu"), ``lr``, ``batch_size``, ``patience``, ``epochs`` and
                ``log_each``.
        """
        self.device = "cuda" if config.cuda == 1 else "cpu"
        self.model = model.to(self.device)
        self.lr = config.lr
        self.batch_size = config.batch_size
        self.patience = config.patience  # stored for callers; not read by this class
        self.epochs = config.epochs
        self.log_each = config.log_each

    def train(self):
        """Run ``epochs`` passes over the walk loader, printing the mean loss every ``log_each`` epochs."""
        print("Training is started..")
        # SparseAdam only accepts sparse gradients, so the model's parameters
        # must produce them (e.g. sparse embedding tables).
        optimizer = torch.optim.SparseAdam(list(self.model.parameters()), lr=self.lr)
        loader = self.model.loader(batch_size=self.batch_size, shuffle=True)

        for epoch in range(self.epochs):
            self.model.train()
            total_loss = 0.0
            num_batches = 0

            for pos_rw, neg_rw in loader:
                optimizer.zero_grad()
                pos_rw = pos_rw.to(self.device)
                neg_rw = neg_rw.to(self.device)
                loss = self.model.loss(pos_rw, neg_rw)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1

            # Counting batches avoids a ZeroDivisionError for an empty loader.
            epoch_loss = total_loss / num_batches if num_batches else float("nan")

            if epoch % self.log_each == 0:
                print(f"Epoch: {epoch}, Loss: {epoch_loss:.4f}")

        print("End of the training..")

    def get_final_embeddings(self):
        """Return the model's embeddings detached from the graph, on the CPU."""
        return self.model.get_embeddings().detach().cpu()
    

class VAETrainer:
    """Full-batch trainer for a variational graph auto-encoder.

    The model must provide ``encode(x, edge_index)``, ``recon_loss(z, edge_index)``,
    ``kl_loss()`` and ``get_embeddings(data)``.
    """

    # Supported optimizer names -> torch optimizer classes.
    _OPTIMIZERS = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}

    def __init__(self, model, config, data):
        """Place model and graph on the configured device and build the optimizer.

        Args:
            model: model exposing the interface described on the class.
            config: namespace with ``cuda`` (1 selects "cuda"), ``lr``,
                ``batch_size``, ``epochs``, ``log_each``, ``optimizer`` and
                ``weight_decay``.
            data: graph object with ``x``, ``edge_index``, ``num_nodes`` and ``to(device)``.

        Raises:
            ValueError: if ``config.optimizer`` is not "adam" or "adamw".
        """
        self.device = "cuda" if config.cuda == 1 else "cpu"
        self.model = model.to(self.device)
        self.lr = config.lr
        self.batch_size = config.batch_size  # stored for callers; training below is full-batch
        self.epochs = config.epochs
        self.log_each = config.log_each
        # Keep the graph on the model's device regardless of where the caller built it.
        self.data = data.to(self.device)

        try:
            optimizer_cls = self._OPTIMIZERS[config.optimizer]
        except KeyError:
            raise ValueError(
                f"Unsupported optimizer {config.optimizer!r}; "
                f"expected one of {sorted(self._OPTIMIZERS)}"
            ) from None
        self.optimizer = optimizer_cls(
            params=self.model.parameters(), lr=config.lr, weight_decay=config.weight_decay
        )

    def train(self):
        """Optimise reconstruction + (1/num_nodes)-weighted KL loss for ``epochs`` full-batch steps."""
        print("Training is started..")
        # The KL term is scaled by 1/num_nodes.
        kl_weight = 1 / self.data.num_nodes

        for epoch in range(self.epochs):
            self.model.train()
            self.optimizer.zero_grad()
            z = self.model.encode(self.data.x, self.data.edge_index)
            loss = self.model.recon_loss(z, self.data.edge_index) + kl_weight * self.model.kl_loss()
            loss.backward()
            self.optimizer.step()

            if epoch % self.log_each == 0:
                print(f"Epoch: {epoch}, Loss: {loss.item():.4f}")

        print("End of the training..")

    def get_final_embeddings(self):
        """Return ``model.get_embeddings(data)`` detached from the graph, on the CPU."""
        return self.model.get_embeddings(self.data).detach().cpu()