import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import wandb

from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import load_dataset
from data_utils import preprocess
from model import ModelE
from setup_utils import load_train_yaml, set_seed

def _evaluate_E(model, batches, E_one_hot, X_one_hot_2d, Y):
    """Run one validation pass of the edge model and average its two terms.

    Args:
        model: the edge model; ``model.val_step`` is called once per batch and
            must return a pair ``(denoise_match, log_p_0)`` for that batch.
        batches: iterable of edge-index batches. Each batch is unpacked as
            ``batch_dst, batch_src = batch.T``, so it is expected to be a
            ``(B, 2)`` tensor of (dst, src) node-index pairs, as produced by
            the validation loader in ``main``.
        E_one_hot, X_one_hot_2d, Y: full-graph tensors forwarded unchanged to
            ``model.val_step``.

    Returns:
        ``(denoise_match_E, log_p_0_E)`` as Python floats: the unweighted mean
        of the per-batch values (a short final batch weighs the same as a full
        one).
    """
    denoise_match_E = []
    log_p_0_E = []
    # model.eval() only switches layer behaviour; it does not stop autograd
    # from recording a graph, so disable gradient tracking explicitly here.
    with torch.no_grad():
        for batch_edge_index in batches:
            batch_dst, batch_src = batch_edge_index.T
            batch_denoise_match_E, batch_log_p_0_E = model.val_step(
                E_one_hot,
                X_one_hot_2d,
                Y,
                batch_src,
                batch_dst,
                E_one_hot[batch_dst, batch_src])
            denoise_match_E.append(batch_denoise_match_E)
            log_p_0_E.append(batch_log_p_0_E)

    # np.mean returns np.float64; plain floats keep the saved checkpoint
    # loadable with torch.load(weights_only=True), which rejects numpy scalars.
    return float(np.mean(denoise_match_E)), float(np.mean(log_p_0_E))


def main(args):
    """Train the edge (E) model on ``args.dataset`` with early stopping.

    Side effects:
        * Opens a wandb run (project ``"<dataset>_E-0830"``) and logs the
          per-batch training loss and per-validation metrics; the run is
          finished at the end.
        * Whenever ``denoise_match_E + log_p_0_E`` on the validation pass
          improves, saves a checkpoint to
          ``<dataset>_cpts/<dataset>_E_model_T<T>.pth`` containing the
          dataset name, the training YAML, the best value, its epoch and the
          model ``state_dict``.

    Validation runs every ``train.val_every_epochs`` epochs. Training stops
    early after ``train.patient_epochs`` consecutive validation rounds in
    which neither ``log_p_0_E`` nor ``denoise_match_E`` reached a new best.

    Args:
        args: parsed CLI namespace; only ``args.dataset`` is read.
    """
    exp_name = f"{args.dataset}_E"
    yaml_data = load_train_yaml(exp_name)

    # Flatten the nested YAML into "a/b/c" keys for the wandb config.
    config_df = pd.json_normalize(yaml_data, sep='/')
    T = yaml_data['diffusion']['T']
    wandb.init(
        project=f"{exp_name}-0830",
        name=f"E T{T}",
        config=config_df.to_dict(orient='records')[0])

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_g = load_dataset(args.dataset)
    X_one_hot_3d, Y, E_one_hot,\
        X_marginal, Y_marginal, E_marginal, X_cond_Y_marginals = preprocess(train_g)

    # (|V|, d, 2)
    X_one_hot_2d = torch.transpose(X_one_hot_3d, 0, 1)
    # (|V|, 2 * d)
    X_one_hot_2d = X_one_hot_2d.reshape(X_one_hot_2d.size(0), -1)
    X_one_hot_2d = X_one_hot_2d.to(device)
    Y = Y.to(device)
    E_one_hot = E_one_hot.to(device)

    X_marginal = X_marginal.to(device)
    Y_marginal = Y_marginal.to(device)
    E_marginal = E_marginal.to(device)
    X_cond_Y_marginals = X_cond_Y_marginals.to(device)

    # Every node pair strictly above the diagonal (offset=1), stacked into a
    # (num_pairs, 2) tensor of (row, col) = (dst, src) indices.
    num_nodes = Y.size(0)
    dst, src = torch.triu_indices(num_nodes, num_nodes, offset=1, device=device)
    edge_index = torch.stack([dst, src], dim=1)

    # Set seed for reproducibility (presumably seeds the torch/numpy RNGs used
    # by model init and shuffling below — see setup_utils.set_seed).
    set_seed()

    train_config = yaml_data["train"]
    # Worker processes index the dataset, so the training loader gets a CPU copy.
    data_loader = DataLoader(edge_index.cpu(), batch_size=train_config["batch_size"],
                             shuffle=True, num_workers=4)
    val_data_loader = DataLoader(edge_index, batch_size=train_config["val_batch_size"],
                                 shuffle=False)

    model = ModelE(X_marginal=X_marginal,
                   Y_marginal=Y_marginal,
                   E_marginal=E_marginal,
                   num_nodes=num_nodes,
                   X_cond_Y_marginals=X_cond_Y_marginals,
                   gnn_E_config=yaml_data["gnn_E"],
                   **yaml_data["diffusion"]).to(device)

    optimizer = torch.optim.AdamW(model.parameters(),
                                  **yaml_data["optimizer_E"])
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', **yaml_data["lr_scheduler"])

    # The checkpoint location does not change across epochs.
    ckpt_dir = f"{args.dataset}_cpts"
    model_path = f"{ckpt_dir}/{exp_name}_model_T{T}.pth"

    best_epoch_E = 0
    best_val_nll_E = float('inf')
    best_log_p_0_E = float('inf')
    best_denoise_match_E = float('inf')

    num_patient_epochs = 0
    for epoch in range(train_config["num_epochs"]):
        model.train()

        for batch_edge_index in tqdm(data_loader):
            batch_edge_index = batch_edge_index.to(device)
            batch_dst, batch_src = batch_edge_index.T
            loss_E = model.log_p_t(E_one_hot,
                                   X_one_hot_2d,
                                   Y,
                                   batch_src,
                                   batch_dst,
                                   E_one_hot[batch_dst, batch_src])

            optimizer.zero_grad()
            loss_E.backward()
            nn.utils.clip_grad_norm_(
                model.parameters(), train_config["max_grad_norm"])
            optimizer.step()

            wandb.log({"train/loss_E": loss_E.item()})

        if (epoch + 1) % train_config["val_every_epochs"] != 0:
            continue

        model.eval()

        # Counts validation rounds since either term last improved; reset below.
        num_patient_epochs += 1

        denoise_match_E, log_p_0_E = _evaluate_E(
            model, tqdm(val_data_loader), E_one_hot, X_one_hot_2d, Y)
        val_E = denoise_match_E + log_p_0_E

        if val_E < best_val_nll_E:
            best_val_nll_E = val_E
            best_epoch_E = epoch
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save({
                "dataset": args.dataset,
                "train_yaml_data": yaml_data,
                "best_val_nll": best_val_nll_E,
                "best_epoch_E": best_epoch_E,
                "model_state_dict": model.state_dict()
            }, model_path)
            print(f'model saved to {model_path}')

        if log_p_0_E < best_log_p_0_E:
            best_log_p_0_E = log_p_0_E
            num_patient_epochs = 0

        if denoise_match_E < best_denoise_match_E:
            best_denoise_match_E = denoise_match_E
            num_patient_epochs = 0

        wandb.log({"epoch": epoch,
                   "val/denoise_match_kl_E": denoise_match_E,
                   "val/log_p_0_E": log_p_0_E,
                   "val/best_log_p_0_E": best_log_p_0_E,
                   "val/best_denoise_match_E": best_denoise_match_E,
                   "val/best_val_E": best_val_nll_E})

        print("Epoch {} | best val E nll {:.7f} | patience {}/{}".format(
            epoch, best_val_nll_E, num_patient_epochs,
            train_config["patient_epochs"]))

        if num_patient_epochs == train_config["patient_epochs"]:
            break

        # NOTE(review): the scheduler tracks log_p_0_E only, not the combined
        # val_E used for checkpointing — confirm this is intended.
        lr_scheduler.step(log_p_0_E)

    wandb.finish()

def build_arg_parser():
    """Build the command-line parser for this training script.

    ``--dataset`` is required. Without it ``args.dataset`` would be ``None``,
    and ``main`` would build the experiment name ``"None_E"`` and look up a
    training config that presumably does not exist. That failure would surface
    far from the actual mistake.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--dataset``, restricted to
        the supported dataset names.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True,
                        choices=["amazon_photo", "amazon_computer",
                                 "cora", "citeseer"],
                        help="Dataset name.")
    return parser


if __name__ == '__main__':
    main(build_arg_parser().parse_args())
