import numpy as np
import torch
import torch.nn as nn
import torchvision
import os, sys
import copy
import time
import random
import ipdb
from tqdm import tqdm
import argparse
import network


from gflownet import get_GFlowNet
import utils_data


def makedirs(path):
    """Create directory ``path`` (with any missing parents) and report the outcome.

    Prints ``creating dir: <path>`` when the directory has to be created and
    ``<path> already exist!`` when something already exists at ``path``.

    Args:
        path: directory path to create.
    """
    if os.path.exists(path):
        print(path, "already exist!")
        return

    print('creating dir: {}'.format(path))
    # exist_ok=True: another process (e.g. a parallel run sharing the same
    # save directory) may create the directory between the check above and
    # this call; that must not abort the run.
    os.makedirs(path, exist_ok=True)

class EBM(nn.Module):
    """Energy-based model: a scoring network plus an optional Bernoulli base term.

    ``forward(x)`` returns ``net(x).squeeze()``. When ``mean`` was supplied,
    the log-probability of ``x`` under an independent Bernoulli with
    probabilities ``mean``, summed over the last axis, is added to it.

    Args:
        net: module applied to the input batch; its output is squeezed.
        mean: optional tensor of Bernoulli probabilities. It is stored as a
            frozen parameter (``requires_grad=False``), so it is part of the
            ``state_dict`` and follows the module across devices.
    """

    def __init__(self, net, mean=None):
        super().__init__()
        self.net = net
        if mean is None:
            self.mean = None
        else:
            self.mean = nn.Parameter(mean, requires_grad=False)
            # Build the distribution once so invalid probabilities are still
            # rejected at construction time rather than on the first forward.
            _ = self.base_dist

    @property
    def base_dist(self):
        """Bernoulli base distribution built from the current ``mean``.

        It is rebuilt on every access instead of being cached:
        ``torch.distributions`` caches derived quantities (``logits``) on the
        distribution object, so a cached instance keeps using outdated values
        after ``mean`` is modified in place (for example by
        ``load_state_dict``).

        Raises:
            AttributeError: if the model was built without ``mean``.
        """
        if self.mean is None:
            raise AttributeError("EBM was constructed without a base distribution")
        return torch.distributions.Bernoulli(probs=self.mean)

    def forward(self, x):
        """Return ``net(x).squeeze()`` plus the base log-probability (0 if none)."""
        if self.mean is None:
            bd = 0.
        else:
            bd = self.base_dist.log_prob(x).sum(-1)

        logp = self.net(x).squeeze()
        return logp + bd


if __name__ == "__main__":
    # Script entry point: parse options, build the EBM and the GFlowNet, then
    # alternate GFlowNet updates and EBM updates, periodically evaluating the
    # GFlowNet and writing checkpoints under <save_dir>/test.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", "--d", default=0, type=int)

    # Seed
    parser.add_argument("--seed", default=10, type=int, help="seed")
    # data
    parser.add_argument('--save_dir', type=str, default="./")
    parser.add_argument('--data', type=str, default='dynamic_mnist')
    parser.add_argument("--down_sample", "--ds", default=0, type=int, choices=[0, 1])
    parser.add_argument('--ckpt_path', type=str, default=None)
    # models
    parser.add_argument('--model', type=str, default='mlp-256')
    parser.add_argument('--base_dist', "--bd", type=int, default=1, choices=[0, 1])
    parser.add_argument('--gradnorm', "--gn", type=float, default=0.0)
    parser.add_argument('--l2', type=float, default=0.0)
    # argparse only applies `type` to string defaults, so the default is given
    # as an int directly (a bare 5e4 would leave n_iters as a float).
    parser.add_argument('--n_iters', "--ni", type=lambda x: int(float(x)), default=int(5e4))
    parser.add_argument('--batch_size', "--bs", type=int, default=100)
    parser.add_argument('--test_batch_size', type=int, default=100)
    parser.add_argument('--print_every', "--pe", type=int, default=100)
    parser.add_argument('--viz_every', "--ve", type=int, default=2000)
    parser.add_argument('--eval_every', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=.0001)
    parser.add_argument("--ebm_every", "--ee", type=int, default=1, help="EBM training frequency")

    # for GFN
    parser.add_argument("--type", type=str)
    parser.add_argument("--hid", type=int, default=256)
    parser.add_argument("--hid_layers", "--hl", type=int, default=5)
    parser.add_argument("--leaky", type=int, default=1, choices=[0, 1])
    parser.add_argument("--gfn_bn", "--gbn", type=int, default=0, choices=[0, 1])
    parser.add_argument("--init_zero", "--iz", type=int, default=0, choices=[0, 1])
    parser.add_argument("--gmodel", "--gm", type=str, default="mlp")
    parser.add_argument("--train_steps", "--ts", type=int, default=1)
    parser.add_argument("--l1loss", "--l1l", type=int, default=0, choices=[0, 1], help="use soft l1 loss instead of l2")

    parser.add_argument("--with_mh", "--wm", type=int, default=0, choices=[0, 1])
    parser.add_argument("--rand_k", "--rk", type=int, default=0, choices=[0, 1])
    parser.add_argument("--lin_k", "--lk", type=int, default=0, choices=[0, 1])
    parser.add_argument("--warmup_k", "--wk", type=lambda x: int(float(x)), default=0, help="need to use w/ lin_k")
    parser.add_argument("--K", type=int, default=-1, help="for gfn back forth negative sample generation")

    parser.add_argument("--rand_coef", "--rc", type=float, default=0, help="for tb")
    parser.add_argument("--back_ratio", "--br", type=float, default=0.)
    parser.add_argument("--clip", type=float, default=-1., help="for gfn's linf gradient clipping")
    parser.add_argument("--temp", type=float, default=1)
    parser.add_argument("--opt", type=str, default="adam", choices=["adam", "sgd"])
    parser.add_argument("--glr", type=float, default=1e-3)
    parser.add_argument("--zlr", type=float, default=1e-1)
    parser.add_argument("--momentum", "--mom", type=float, default=0.0)
    parser.add_argument("--gfn_weight_decay", "--gwd", type=float, default=0.0)
    parser.add_argument('--mc_num', "--mcn", type=int, default=100)

    #OT Regularization
    parser.add_argument("--reg_coef", default=0.001, type=float, help="Coefficient for regularisation term for main objective loss")
    args = parser.parse_args()

    # Fail fast on option combinations that would otherwise crash only after
    # the dataset has been loaded (explicit checks also survive `python -O`,
    # unlike assert statements).
    if args.type is None or "tb" not in args.type:
        parser.error("--type is required and must contain 'tb'")
    if args.gmodel != "mlp":
        parser.error("--gmodel must be 'mlp'")
    if args.lin_k and not args.rand_k and args.warmup_k <= 0:
        # K is scaled by (itr + 1) / warmup_k below; 0 would divide by zero.
        parser.error("--lin_k requires --warmup_k > 0")

    # seed
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    os.environ['CUDA_VISIBLE_DEVICES'] = "{:}".format(args.device)
    device = torch.device("cpu") if args.device < 0 else torch.device("cuda")

    # From here on args.device holds the torch.device, not the integer index.
    args.device = device
    args.save_dir = os.path.join(args.save_dir, "test")
    makedirs(args.save_dir)

    print("Device:" + str(device))
    print("Args:" + str(args))

    before_load = time.time()
    # load_dataset returns an updated args; args.input_size and
    # args.dynamic_binarization are read below.
    train_loader, val_loader, test_loader, args = utils_data.load_dataset(args)
    # Helper that saves a batch as an image grid at path p (not called in this block).
    plot = lambda p, x: torchvision.utils.save_image(x.view(x.size(0), args.input_size[0],
        args.input_size[1], args.input_size[2]), p, normalize=True, nrow=int(x.size(0) ** .5))
    print(f"It takes {time.time() - before_load:.3f}s to load {args.data} dataset.")

    def preprocess(data):
        """Sample a binary batch from `data` when dynamic binarization is on; otherwise return it unchanged."""
        if args.dynamic_binarization:
            return torch.bernoulli(data)
        else:
            return data

    if args.down_sample:
        assert args.model.startswith("mlp-")

    # --model is "<family>-<int>"; the integer is passed to the network constructor.
    if args.model.startswith("mlp-"):
        nint = int(args.model.split('-')[1])
        net = network.mlp_ebm(np.prod(args.input_size), nint)
    elif args.model.startswith("cnn-"):
        nint = int(args.model.split('-')[1])
        net = network.MNISTConvNet(nint)
    elif args.model.startswith("resnet-"):
        nint = int(args.model.split('-')[1])
        net = network.ResNetEBM(nint)
    else:
        raise ValueError("invalid model definition")

    # Per-dimension mean of the (preprocessed) training set, squashed into
    # [eps, 1 - eps] so it is a valid, non-degenerate Bernoulli probability.
    init_batch = torch.cat([preprocess(x) for x, _ in train_loader], 0)
    eps = 1e-2
    init_mean = init_batch.mean(0) * (1. - 2 * eps) + eps

    if args.base_dist:
        model = EBM(net, init_mean)
    else:
        model = EBM(net)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    xdim = np.prod(args.input_size)
    gfn = get_GFlowNet(args.type, xdim, args, device)
    model.to(device)
    print("model: {:}".format(model))

    itr = 0
    best_val_ll = -np.inf
    while itr < args.n_iters:
        for x in train_loader:
            st = time.time()
            x = preprocess(x[0].to(device))  #  -> (bs, 784)

            if args.gradnorm > 0:
                # Needed so the gradient of log p(x) w.r.t. x can be penalised below.
                x.requires_grad_()

            update_success_rate = -1.
            # GFlowNet update; the EBM is used as a fixed (detached) scorer.
            train_loss, train_logZ = gfn.train(args.batch_size, scorer=lambda inp: model(inp).detach(),
                   silent=itr % args.print_every != 0, data=x, back_ratio=args.back_ratio)

            # Negative samples: either a K-step back-and-forth move starting
            # from the data batch, or fresh samples from the GFlowNet.
            if args.rand_k or args.lin_k or (args.K > 0):
                if args.rand_k:
                    K = random.randrange(xdim) + 1
                elif args.lin_k:
                    # Linear ramp from 1 up to xdim over warmup_k iterations.
                    K = min(xdim, int(xdim * float(itr + 1) / args.warmup_k))
                    K = max(K, 1)
                elif args.K > 0:
                    K = args.K
                else:
                    raise ValueError

                gfn.model.eval()
                x_fake, delta_logp_traj = gfn.backforth_sample(x, K)

                delta_logp_traj = delta_logp_traj.detach()
                if args.with_mh:
                    # MH step, calculate log p(x') - log p(x)
                    lp_update = model(x_fake).squeeze() - model(x).squeeze()
                    update_dist = torch.distributions.Bernoulli(logits=lp_update + delta_logp_traj)
                    updates = update_dist.sample()
                    # Keep the proposal where accepted, the original sample otherwise.
                    x_fake = x_fake * updates[:, None] + x * (1. - updates[:, None])
                    update_success_rate = updates.mean().item()

            else:
                x_fake = gfn.sample(args.batch_size)

            # Time spent on the GFlowNet update and negative sampling. Computed
            # on every iteration so the log line below never prints a raw
            # timestamp when the EBM update is skipped (--ebm_every > 1).
            st = time.time() - st

            if itr % args.ebm_every == 0:
                model.train()
                logp_real = model(x).squeeze()
                if args.gradnorm > 0:
                    grad_ld = torch.autograd.grad(logp_real.sum(), x,
                                      create_graph=True)[0].flatten(start_dim=1).norm(2, 1)
                    grad_reg = (grad_ld ** 2. / 2.).mean()
                else:
                    grad_reg = torch.tensor(0.).to(device)

                # Raise log p on data, lower it on negatives, plus optional
                # gradient-norm and squared-output penalties.
                logp_fake = model(x_fake).squeeze()
                obj = logp_real.mean() - logp_fake.mean()
                l2_reg = (logp_real ** 2.).mean() + (logp_fake ** 2.).mean()
                loss = -obj + grad_reg * args.gradnorm + args.l2 * l2_reg

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # NOTE: when the EBM update was skipped this iteration, the EBM
            # statistics printed here are those of the most recent update.
            if itr % args.print_every == 0 or itr == args.n_iters - 1:
                print("({:5d}) | ({:.3f}s/iter) |log p(real)={:.2e}, "
                     "log p(fake)={:.2e}, diff={:.2e}, grad_reg={:.2e}, l2_reg={:.2e} update_rate={:.1f}".format(itr, st,
                     logp_real.mean().item(), logp_fake.mean().item(), obj.item(), grad_reg.item(), l2_reg.item(), update_success_rate))

            if (itr + 1) % args.eval_every == 0:
                model.eval()
                print("GFN TEST")
                gfn.model.eval()

                gfn_val_ll = gfn.evaluate(val_loader, preprocess, args.mc_num)
                print("GFN Valid log-likelihood ({}) with {} samples: {}".format(itr, args.mc_num, gfn_val_ll.item()))

                gfn_test_ll = gfn.evaluate(test_loader, preprocess, args.mc_num)
                print("GFN Test log-likelihood ({}) with {} samples: {}".format(itr, args.mc_num, gfn_test_ll.item()))

                # Move the EBM to CPU before taking its state_dict; it is
                # moved back to `device` after saving.
                model.cpu()
                d = {}
                d['model'] = model.state_dict()
                d['optimizer'] = optimizer.state_dict()
                gfn_ckpt = {"model": gfn.model.state_dict(), "optimizer": gfn.optimizer.state_dict(),}
                gfn_ckpt["logZ"] = gfn.logZ.detach().cpu()
                # The latest checkpoint is always written (once) ...
                torch.save(d, "{}/ckpt.pt".format(args.save_dir))
                torch.save(gfn_ckpt, "{}/gfn_ckpt.pt".format(args.save_dir))

                # ... and additionally kept as "best" when validation improves.
                if gfn_val_ll.item() > best_val_ll:
                    best_val_ll = gfn_val_ll.item()
                    print("Best valid likelihood({}): {}".format(itr, gfn_val_ll.item()))
                    print("Test likelihood({}): {}".format(itr, gfn_test_ll.item()))
                    torch.save(d, "{}/best_ckpt.pt".format(args.save_dir))
                    torch.save(gfn_ckpt, "{}/best_gfn_ckpt.pt".format(args.save_dir))

                model.to(device)

            itr += 1
            # Stop after exactly n_iters iterations (itr = 0 .. n_iters - 1).
            if itr >= args.n_iters:
                print("Training finished!")
                sys.exit(0)
