import os.path as osp
import wandb
import warnings
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch_geometric.utils import mask_feature, dropout_adj
from torch_geometric.loader import NeighborLoader

from dataset.process_datasets import get_pt_data, get_train_node_idx, WEIGHT
from model.encoder import Encoder, InnerProductDecoder
from model.pt_model import PretrainModel
from model.vq import VectorQuantize
from utils.args import get_args_pretrain
from utils.others import seed_everything, get_scheduler, get_device_from_model, check_path

warnings.filterwarnings("ignore")


def pretrain(model, loader, optimizer, params, scheduler):
    """Run one training pass of ``model`` over every mini-batch yielded by ``loader``.

    For each batch the node/edge text features are gathered and moved to the model's
    device. An augmented view is built with feature masking (``mask_feature``) and
    edge dropout that keeps the graph undirected (``dropout_adj(force_undirected=True)``).
    The weighted sum of the four losses returned by ``model`` is then minimised,
    with gradient-norm clipping at 1.0.

    Args:
        model: pretraining model. It is called as
            ``model(aug_graph, graph, topo_recon_ratio, batch_size)`` and must return a
            mapping with keys ``feat_recon_loss``, ``topo_recon_loss``, ``field_loss``
            and ``contrastive_loss``.
        loader: iterable of mini-batches exposing ``x``, ``node_text_feat``,
            ``edge_index``, ``edge_text_feat``, ``xe`` and ``field``.
        optimizer: optimizer over ``model``'s parameters.
        params: hyper-parameter dict. Reads ``feat_p``, ``edge_p``,
            ``topo_recon_ratio``, ``pretrain_batch_size``, ``feat_lambda``,
            ``topo_lambda``, ``field_lambda``, ``contrastive_lambda`` and ``wandb``.
        scheduler: LR scheduler, or a falsy value to disable scheduling. When set,
            it is stepped once per batch.
    """
    model.train()
    device = get_device_from_model(model)
    use_wandb = params['wandb']

    for data in tqdm(loader):
        # When the sizes differ, data.x / data.xe are treated as row indices into the
        # shared text-feature tables; otherwise the tables are used as-is.
        # NOTE(review): confirm the sizes can never coincide for an index-based batch.
        if data.x.size(0) != data.node_text_feat.size(0):
            x = data.node_text_feat[data.x].to(device)
        else:
            x = data.node_text_feat.to(device)

        edge_index = data.edge_index.to(device)
        if edge_index.size(1) != data.edge_text_feat.size(0):
            edge_attr = data.edge_text_feat[data.xe].to(device)
        else:
            edge_attr = data.edge_text_feat.to(device)

        field = data.field.to(device)

        graph = [x, edge_index, edge_attr, field]

        # Augmented view: masked features plus undirected edge dropout (no field entry).
        aug_x, _ = mask_feature(x, p=params['feat_p'])  # default mode = "col"
        aug_edge_index, aug_edge_attr = dropout_adj(edge_index, edge_attr, p=params['edge_p'], force_undirected=True)
        aug_graph = [aug_x, aug_edge_index, aug_edge_attr]

        raw_losses = model(aug_graph, graph, params['topo_recon_ratio'], params["pretrain_batch_size"])

        feat_recon_loss = params['feat_lambda'] * raw_losses['feat_recon_loss']
        topo_recon_loss = params['topo_lambda'] * raw_losses['topo_recon_loss']
        field_loss = params['field_lambda'] * raw_losses['field_loss']
        contrastive_loss = params['contrastive_lambda'] * raw_losses['contrastive_loss']

        loss = feat_recon_loss + topo_recon_loss + contrastive_loss + field_loss

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # NOTE(review): stepped per batch, while run() builds the scheduler from
        # params["pretrain_epochs"]; confirm get_scheduler expects per-batch steps.
        if scheduler:
            scheduler.step()

        # .item() forces a device->host sync, so only pay for it when values are logged.
        if use_wandb:
            wandb.log({
                'losses/feat_recon_loss': feat_recon_loss.item(),
                'losses/topo_recon_loss': topo_recon_loss.item(),
                'losses/contrastive_loss': contrastive_loss.item(),
                'losses/field_loss': field_loss.item(),
                'losses/loss': loss.item(),
            })


def run(params):
    """Build the pretraining model from ``params`` and train it epoch by epoch.

    Each epoch calls ``get_train_node_idx`` with the per-dataset sampling weights
    from ``WEIGHT``, wraps the resulting nodes in a fresh ``NeighborLoader`` and runs
    one pass of ``pretrain``. Encoder and VQ weights are checkpointed every
    ``params.get("save_every", 5)`` epochs and always after the final epoch. This
    guarantees that ``encoder_{pretrain_epochs}.pt`` / ``vq_{pretrain_epochs}.pt``
    exist even when the epoch count is not a multiple of the save interval.

    Args:
        params: configuration mapping (from ``get_args_pretrain``). Side effect: a
            string ``params['activation']`` is replaced in place by the matching
            ``nn`` activation class.

    Raises:
        TypeError: if ``params["pretrain_dataset"]`` is neither a ``str`` nor a
            ``list`` of dataset names.
    """
    seed_everything(params["seed"])

    device = torch.device(f"cuda:{params['gpu']}") if torch.cuda.is_available() else torch.device("cpu")

    # Only translate string names, so reusing an already-resolved dict keeps the chosen
    # class (a non-'relu' value, including nn.ReLU itself, used to map to nn.LeakyReLU).
    activation = params['activation']
    if isinstance(activation, str):
        activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
    params['activation'] = activation

    # Sampling weights depend only on the dataset name(s), so compute them once.
    dataset = params["pretrain_dataset"]
    if isinstance(dataset, str):
        weights = list(WEIGHT[dataset].values())
    elif isinstance(dataset, list):
        weights = [w for name in dataset for w in WEIGHT[name].values()]
    else:
        raise TypeError(
            f"pretrain_dataset must be a str or a list of str, got {type(dataset).__name__}"
        )

    pretrain_data = get_pt_data(
        params["data_path"],
        params["pretrain_dataset"],
        params["graph_llm_name"],
        params["llm_b_size"],
        params["root_path"]
    )

    encoder = Encoder(
        params["input_dim"],
        params["hidden_dim"],
        params["activation"],
        params["num_layers"],
        params["normalize"],
        params["dropout"],
    )

    vq = VectorQuantize(
        params["hidden_dim"],
        params["codebook_size"],
        params["num_expert"],
        params["codebook_heads"],
        params["topk"],
        params["kmeans_init"]
    )

    # To resume from a saved checkpoint, uncomment the lines below:
    # path = osp.join(params['model_path'], "codebook_size_{}_layer_{}_pretrain_on_{}_seed_{}".format(
    #     params["codebook_size"], params["num_layers"], params["pretrain_dataset"], params['seed']
    # ))
    # encoder.load_state_dict(torch.load(osp.join(path, f'encoder_{params["pretrain_epochs"]}.pt')))
    # vq.load_state_dict(torch.load(osp.join(path, f'vq_{params["pretrain_epochs"]}.pt')))

    feat_recon_decoder = nn.Linear(params["hidden_dim"], params["input_dim"])
    topo_recon_decoder = InnerProductDecoder(params["hidden_dim"], params["hidden_dim"])

    pretrain_model = PretrainModel(
        encoder,
        vq,
        feat_recon_decoder,
        topo_recon_decoder
    ).to(device)

    # Optimizer
    optimizer = AdamW(pretrain_model.parameters(), lr=params["lr"], weight_decay=params["pretrain_weight_decay"])
    scheduler = get_scheduler(optimizer, params["use_schedular"], params["pretrain_epochs"])

    # The checkpoint directory does not change between epochs; resolve and check it once.
    save_path = osp.join(params['model_path'], 'codebook_size_{}_layer_{}_pretrain_on_{}_seed_{}'.format(
        params["codebook_size"], params["num_layers"], params["pretrain_dataset"], params['seed']))
    check_path(save_path)

    num_epochs = params["pretrain_epochs"]
    save_every = params.get("save_every", 5)
    batch_size = params["pretrain_batch_size"]

    for i in range(1, num_epochs + 1):
        # The loader is rebuilt every epoch to enable weighted sampling:
        # get_train_node_idx is called anew each epoch.
        train_nodes = get_train_node_idx(pretrain_data, weights)

        loader = NeighborLoader(
            pretrain_data,
            input_nodes=train_nodes,
            num_neighbors=[params["num_neighbors"]] * params["num_layers"],
            batch_size=batch_size,
            shuffle=True,
        )

        print("Number of training nodes is {}".format(len(train_nodes)))
        print("Number of mini-batches is {} at epoch {}.".format(len(loader), i))

        # Pretrain
        pretrain(model=pretrain_model, loader=loader, optimizer=optimizer, params=params, scheduler=scheduler)

        if i % save_every == 0 or i == num_epochs:
            # Best-effort checkpointing: a failed save is reported but does not stop training.
            # Not a bare except, so KeyboardInterrupt / SystemExit still propagate.
            try:
                pretrain_model.save_encoder(osp.join(save_path, f"encoder_{i}.pt"))
                pretrain_model.save_vq(osp.join(save_path, f"vq_{i}.pt"))
                print("Save the model at epoch {}".format(i))
            except Exception as err:
                print("Failed to save the model at epoch {}: {}".format(i, err))

    if params['wandb']:
        wandb.finish()


if __name__ == "__main__":
    params = get_args_pretrain()

    # Optional Weights & Biases tracking; run() closes the run with wandb.finish().
    if params['wandb']:
        wandb.init(
            config=params,
            mode='online',
            name=f"Pretrain on {params['pretrain_dataset']}",
            project="GFM-Pretrain",
        )

    run(params)
