import pickle
import numpy as np
import torch
import os, sys

import time
import random
from tqdm import tqdm
import argparse
import logging

from gflownet import get_GFlowNet
from synthetic.synthetic_utils import plot_heat, plot_samples,\
    float2bin, bin2float, get_binmap, get_true_samples, get_ebm_samples, EnergyModel
from synthetic.synthetic_data import inf_train_gen, OnlineToyDataset



def get_logger(path=None, level="DEBUG"):
    """Configure and return the root logger.

    Args:
        path: File to append log records to. When ``None`` a
            ``logging.StreamHandler`` is attached instead.
        level: Level handed to ``Logger.setLevel`` on the root logger.

    Returns:
        The root logger.

    Calling this function again with the same target does not attach a
    second, equivalent handler, so records are not emitted twice.
    """
    logger = logging.getLogger()
    logger.setLevel(level)
    if path is None:
        # FileHandler subclasses StreamHandler, so compare the exact type
        # rather than using isinstance().
        has_stream = any(type(h) is logging.StreamHandler for h in logger.handlers)
        if not has_stream:
            logger.addHandler(logging.StreamHandler())
    else:
        # FileHandler stores the absolute path of its file in baseFilename.
        target = os.path.abspath(path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not has_file:
            logger.addHandler(logging.FileHandler(path, mode="a"))
    return logger


def makedirs(path):
    """Create directory ``path`` (including parents) and report the outcome on stdout.

    Creation is attempted directly instead of checking ``os.path.exists``
    first, so a concurrent process creating the same directory between the
    check and the creation can no longer make this call fail.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If creation fails and ``path`` still does not exist.
    """
    try:
        os.makedirs(path)
    except OSError:
        # Only swallow the error when the path is really there; anything
        # else (permissions, invalid path, ...) is a genuine failure.
        if not os.path.exists(path):
            raise
        print(path, "already exist!")
    else:
        print('creating dir: {}'.format(path))


# Bernoulli(0.5) distribution exposed at module level; it is not used by the
# script body below.
unif_dist = torch.distributions.Bernoulli(probs=0.5)


def _draw_gfn_samples(gfn, n_samples, inv_bm, int_scale, discrete_dim):
    """Draw ``n_samples`` from ``gfn.sample`` and convert them with ``bin2float``."""
    samples = gfn.sample(n_samples).detach()
    return bin2float(samples.data.cpu().numpy().astype(int), inv_bm, int_scale, discrete_dim)


if __name__ == "__main__":
    # Script entry point: parse options, train the GFlowNet on a synthetic
    # dataset, periodically evaluate/plot it, and pickle the best and final
    # states into --save_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", "--d", default=0, type=int)
    # data
    parser.add_argument('--save_dir', type=str, default="./")
    parser.add_argument('--data', type=str, default='circles')  # 2spirals 8gaussians pinwheel circles moons swissroll checkerboard

    # training
    # argparse applies ``type`` only to string defaults, so the default must
    # already be an int (1e5 on its own would leave a float in args).
    parser.add_argument('--n_iters', "--ni", type=lambda x: int(float(x)), default=int(1e5))
    parser.add_argument('--batch_size', "--bs", type=int, default=128)
    parser.add_argument('--print_every', "--pe", type=int, default=100)
    parser.add_argument('--eval_every', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=.0001)
    parser.add_argument("--ebm_every", "--ee", type=int, default=1, help="EBM training frequency")

    # for GFN
    parser.add_argument("--type", type=str)
    parser.add_argument("--hid", type=int, default=512)
    parser.add_argument("--hid_layers", "--hl", type=int, default=3)
    parser.add_argument("--leaky", type=int, default=1, choices=[0, 1])
    parser.add_argument("--gfn_bn", "--gbn", type=int, default=0, choices=[0, 1])
    parser.add_argument("--init_zero", "--iz", type=int, default=0, choices=[0, 1], )
    parser.add_argument("--gmodel", "--gm", type=str,default="mlp")
    parser.add_argument("--train_steps", "--ts", type=int, default=1)
    parser.add_argument("--l1loss", "--l1l", type=int, default=0, choices=[0, 1], help="use soft l1 loss instead of l2")

    parser.add_argument("--with_mh", "--wm", type=int, default=0, choices=[0, 1])
    parser.add_argument("--rand_k", "--rk", type=int, default=0, choices=[0, 1])
    parser.add_argument("--lin_k", "--lk", type=int, default=0, choices=[0, 1])
    parser.add_argument("--warmup_k", "--wk", type=lambda x: int(float(x)), default=0, help="need to use w/ lin_k")
    parser.add_argument("--K", type=int, default=-1, help="for gfn back forth negative sample generation")

    parser.add_argument("--rand_coef", "--rc", type=float, default=0, help="for tb")
    parser.add_argument("--back_ratio", "--br", type=float, default=0.)
    parser.add_argument("--clip", type=float, default=-1., help="for gfn's linf gradient clipping")
    parser.add_argument("--temp", type=float, default=1)
    parser.add_argument("--opt", type=str, default="adam", choices=["adam", "sgd"])
    parser.add_argument("--glr", type=float, default=1e-3)
    parser.add_argument("--zlr", type=float, default=1)
    parser.add_argument("--momentum", "--mom", type=float, default=0.0)
    parser.add_argument("--gfn_weight_decay", "--gwd", type=float, default=0.0)

    parser.add_argument("--mixing_ratio", "--mr", type=float, default=0.5)
    parser.add_argument("--r_alpha", "--ra", type=float, default=1.0)
    args = parser.parse_args()

    train_start = time.time()

    # Only the sign of --device is used: negative selects the CPU, anything
    # else the default CUDA device. Set CUDA_VISIBLE_DEVICES outside the
    # script to pick a specific GPU.
    device = torch.device("cpu") if args.device < 0 else torch.device("cuda")

    makedirs(args.save_dir)
    logger = get_logger(os.path.join(args.save_dir, 'logs.txt'))
    logger.info(str(vars(args)))
    logger.info("Device:" + str(device))
    logger.info("Args:" + str(args))

    ############## Data
    discrete_dim = 32
    bm, inv_bm = get_binmap(discrete_dim, 'gray')

    db = OnlineToyDataset(args.data, discrete_dim)
    # The parser above defines neither ``int_scale`` nor ``plot_size``; the
    # attribute checks only matter if ``args`` is extended elsewhere.
    if not hasattr(args, "int_scale"):
        int_scale = db.int_scale
    else:
        int_scale = args.int_scale
    if not hasattr(args, "plot_size"):
        plot_size = db.f_scale
    else:
        db.f_scale = args.plot_size
        plot_size = args.plot_size

    # Round the batch size down to a multiple of the per-dataset factor.
    batch_size = args.batch_size
    multiples = {'pinwheel': 5, '2spirals': 2}
    batch_size = batch_size - batch_size % multiples.get(args.data, 1)

    ############## GFN
    xdim = discrete_dim
    assert args.gmodel == "mlp"
    gfn = get_GFlowNet(args.type, xdim, args, device)

    best_val_ll = -np.inf
    best_itr = -1
    mixing_ratio = args.mixing_ratio
    r_alpha = args.r_alpha
    # Stays None until the first evaluation produces samples; checked before
    # writing final.pkl so short runs do not hit an undefined name.
    gfn_samp_float = None

    for itr in range(args.n_iters):
        x = get_true_samples(db, batch_size, bm, int_scale, discrete_dim).to(device)

        gfn.model.train()
        # The returned loss / logZ values are currently not logged.
        train_loss, train_logZ = gfn.train_gil(batch_size, mixing_ratio, r_alpha, silent=True, data=x)

        if (itr + 1) % args.eval_every == 0:

            # samples of gfn (drawn before the model is switched to eval mode)
            gfn_samp_float = _draw_gfn_samples(gfn, 4000, inv_bm, int_scale, discrete_dim)
            plot_samples(gfn_samp_float, os.path.join(args.save_dir, f'gfn_samples_{itr}.pdf'), lim=plot_size)

            # GFN LL: mean of cal_logp over 10 batches of 1000 true samples
            gfn.model.eval()
            logps = []
            for _ in range(10):
                pos_samples_bs = get_true_samples(db, 1000, bm, int_scale, discrete_dim).to(device)
                logp = gfn.cal_logp(pos_samples_bs, 20)
                # Detach so the collected values do not keep autograd state alive.
                logps.append(logp.reshape(-1).detach())
            gfn_test_ll = torch.cat(logps).mean().item()

            logger.info(f"Test NLL ({itr}): GFN: {-gfn_test_ll:.5f},\t Time: {time.time() - train_start}")

            if gfn_test_ll > best_val_ll:
                best_val_ll = gfn_test_ll
                best_itr = itr
                with open(os.path.join(args.save_dir, "best.pkl"), 'wb') as f:
                    pickle.dump({'itr':itr, 'gfn':gfn, }, f)

    logger.info(f"Best NLL ({best_itr}): GFN: {-best_val_ll:.5f}")
    total_time = time.time() - train_start
    logger.info(f"Total time: {total_time}")

    if gfn_samp_float is None:
        # No evaluation ran (n_iters < eval_every): draw samples now so that
        # final.pkl always has the same layout.
        gfn_samp_float = _draw_gfn_samples(gfn, 4000, inv_bm, int_scale, discrete_dim)

    with open(os.path.join(args.save_dir, "final.pkl"), 'wb') as f:
        data = gfn, gfn_samp_float, inv_bm, int_scale, discrete_dim, plot_size
        pickle.dump(data, f)

    # ``quit`` is injected by the ``site`` module for interactive use;
    # scripts should exit through ``sys.exit``.
    sys.exit(0)
