import sys, utils, argparse, wideresnet
import torch as t, torch.nn as nn, torchvision as tv
from tqdm import tqdm

# Seed handed to torch's random number generators in main().
seed = 1
# Presumably image channel count, image side length and number of labels;
# none of these three is read by the code visible in this file -- TODO confirm.
in_ch = 3
m_sz = 32
n_classes = 10
# Number of particle groups per chain; read by leader_pulling().
num_groups = 1
class F(nn.Module):
    """Feature backbone with a one-unit linear head and a ten-unit linear head.

    The backbone is ``wideresnet.Wide_ResNet(depth, width, norm=norm)``; both
    heads take ``last_dim`` input features as reported by that backbone.
    """

    def __init__(self, depth=28, width=2, norm=None):
        super().__init__()
        backbone = wideresnet.Wide_ResNet(depth, width, norm=norm)
        feature_dim = backbone.last_dim
        self.f = backbone
        # One-unit head first, then the ten-unit head.
        self.energy_output = nn.Linear(feature_dim, 1)
        self.class_output = nn.Linear(feature_dim, 10)

    def forward(self, x, y=None):
        """Return the one-unit head's output for ``x`` with size-1 dims removed.

        ``y`` is accepted but not used here.
        """
        features = self.f(x)
        energy = self.energy_output(features)
        return energy.squeeze()

    def classify(self, x):
        """Return the ten-unit head's output for ``x``."""
        features = self.f(x)
        logits = self.class_output(features)
        return logits

class CCF(F):
    """Subclass of ``F`` whose forward pass returns ``classify(x)``."""

    def __init__(self, depth=28, width=2, norm=None):
        super().__init__(depth, width, norm=norm)

    def forward(self, x, y=None):
        """Return the class-head output for ``x``; ``y`` is accepted but unused."""
        return self.classify(x)

def init_random(bs):
    """Return a float32 CPU tensor of shape (bs, 3, 32, 32) filled by ``uniform_(-1, 1)``."""
    noise = t.empty(bs, 3, 32, 32, dtype=t.float32, device="cpu")
    noise.uniform_(-1, 1)
    return noise

def leader_pulling(args, x, output, lambda_p=0.1):
    """Compute the per-particle pull towards the best particle of each chain.

    The first dimension of ``x`` is split into chains of
    ``num_groups * args.num_particles_per_group`` consecutive particles. In
    every chain the particle with the largest ``output`` value is the leader.

    Args:
        args: object exposing ``num_particles_per_group``.
        x: particle tensor; its first dimension indexes particles.
        output: one score per particle (same ordering as ``x``).
        lambda_p: global multiplier applied to the pulling strength.

    Returns:
        A pair ``(pulling_strength, displacement)``:
        ``pulling_strength`` has shape ``(x.size(0), 1, ..., 1)`` and equals
        ``lambda_p * (w_best - w)``, where ``w`` is the softmax of the scores
        inside each group and ``w_best`` the largest weight of that group;
        ``displacement`` is ``leader - x`` and has the shape of ``x``.

    Raises:
        ValueError: if ``x.size(0)`` is not a multiple of the chain size.
    """
    particles = args.num_particles_per_group
    chain_size = num_groups * particles
    total = x.size(0)
    if chain_size <= 0 or total % chain_size != 0:
        raise ValueError(
            "number of particles ({}) must be a positive multiple of "
            "num_groups * num_particles_per_group ({})".format(total, chain_size)
        )
    batch_size = total // chain_size

    log_energy = output.reshape(batch_size, num_groups, particles)

    # The best particle of a chain is the arg-max over all of its groups and
    # particles at once. Taking it on the flattened scores picks the same
    # element as "best of each group, then best group" (first maximum wins in
    # both cases) without building CPU-side index grids.
    leader_idx = log_energy.reshape(batch_size, chain_size).argmax(dim=1)
    chains = x.reshape(batch_size, chain_size, -1)
    rows = t.arange(batch_size, device=x.device)
    leaders = chains[rows, leader_idx].reshape(batch_size, *x.shape[1:])
    # Repeat every leader once per particle of its chain so it lines up with x.
    leaders = leaders.repeat_interleave(chain_size, dim=0)

    # Softmax within each group. torch's softmax subtracts the maximum
    # internally, so large scores no longer overflow to inf / inf = NaN.
    weights = t.softmax(log_energy, dim=2)
    leader_weight = weights.max(dim=2, keepdim=True).values
    pulling_scale = leader_weight - weights
    pulling_strength = (lambda_p * pulling_scale).reshape(
        [total] + [1] * (x.dim() - 1)
    )
    return pulling_strength, leaders - x

def l_sample_q(i, x_k, args, device, f, y=None):
    """Run ``args.n_steps`` gradient-plus-noise updates with leader pulling.

    Every step evaluates ``f`` on the current samples, picks the output column
    given by ``y`` for each sample, and differentiates the sum of those values
    with respect to the samples. The samples are then moved in place: first by
    the leader-pulling term, then by ``args.sgld_lr`` times the gradient plus
    Gaussian noise scaled by ``args.sgld_std``.

    ``i`` and ``device`` are accepted but not used. ``f`` is switched to eval
    mode. Returns the detached samples.
    """
    f.eval()
    x_k = t.autograd.Variable(x_k, requires_grad=True)

    for _step in range(args.n_steps):
        logits = f(x_k, y=y)
        # One value per sample: the column of ``logits`` selected by its label.
        picked = logits.gather(1, y.unsqueeze(1)).view(-1)
        score_grad = t.autograd.grad(picked.sum(), [x_k], retain_graph=True)[0]

        strength, offset = leader_pulling(args, x_k.data, picked, lambda_p=args.lambda_p)
        x_k.data += strength * offset

        drift = args.sgld_lr * score_grad
        noise = args.sgld_std * t.randn_like(x_k)
        x_k.data += drift + noise

    return x_k.detach()

def lhmc_sample_q(i, x_k, m_k, args, device, f, y=None):
    """Run ``args.n_steps`` propose/accept updates on samples ``x_k`` with momentum ``m_k``.

    Per step:
      1. Evaluate ``f`` at ``x_k``, take the output column selected by ``y`` for
         each sample, and differentiate its sum with respect to ``x_k``.
      2. Add the leader-pulling term to that gradient and build a proposal
         ``x_k_star`` from the current samples, the combined gradient and ``m_k``.
      3. Repeat the evaluation at ``x_k_star`` and build a momentum proposal
         ``m_k_star`` from the average of both combined gradients.
      4. Accept each sample's proposal when a uniform draw is below
         ``exp(logp_p - logp_pp)``; both terms are just the selected outputs
         (the momentum contributions are commented out).
      5. Scale the momentum by ``args.partial_refresh`` and add Gaussian noise
         scaled by ``args.sgld_std``; clamp the updated samples to [-1, 1].

    ``m_k`` is updated in place through ``.data``, so the caller's tensor
    changes. ``i`` and ``device`` are accepted but not used. ``f`` is switched
    to eval mode. Returns the detached samples.
    """
    f.eval()
    x_k = t.autograd.Variable(x_k, requires_grad=True)
    for _ in range(args.n_steps):
        # Value and gradient at the current position.
        output     = f(x_k, y = y)
        prob       = t.gather(output, 1, y[:, None]).view(-1) # - output.sum(dim=1)
        f_prime    = t.autograd.grad(prob.sum(), [x_k], retain_graph=True)[0]
        p_strength, p_distance = leader_pulling(args, x_k.data, prob, lambda_p = args.lambda_p)
        grad = p_strength * p_distance + f_prime
        # Acceptance term for the current state.
        logp_pp = prob.clone() # -0.5 * (m_k ** 2).sum(dim=[1, 2, 3])

        # Proposed position; gradients are needed with respect to it below.
        x_k_star   = x_k.data + (args.sgld_lr ** 2) * grad + args.sgld_lr * m_k
        x_k_star.requires_grad_()

        # Value and gradient at the proposed position.
        output     = f(x_k_star, y = y)
        prob       = t.gather(output, 1, y[:, None]).view(-1) # - output.sum(dim=1)
        f_prime    = t.autograd.grad(prob.sum(), [x_k_star], retain_graph=True)[0]
        p_strength, p_distance = leader_pulling(args, x_k_star.data, prob, lambda_p = args.lambda_p)
        grad_star  = p_strength * p_distance + f_prime
        
        # update momentum by gradient
        m_k_star = m_k.data + 0.5 * args.sgld_lr * grad + 0.5 * args.sgld_lr * grad_star
        # Acceptance term for the proposed state.
        logp_p = prob.clone() # -0.5 * (m_k_star ** 2).sum(dim=[1, 2, 3])

        # 1.0 where the proposal is accepted, 0.0 otherwise. The view to
        # (-1, 1, 1, 1) assumes 4-D samples so the mask broadcasts per sample.
        mask = (t.rand_like(logp_p) < t.exp(logp_p - logp_pp)).float().view(-1, 1, 1, 1)
        m_k.data = (1 - mask) * m_k + mask * m_k_star
        # print(mask.sum())
        
        # update momentum with resampling
        m_k.data = args.partial_refresh * m_k.data + args.sgld_std * t.randn_like(m_k.data)
        
        # bounded x_k
        # Uses the same accept mask as the momentum update above.
        x_k.data = ((1 - mask) * x_k + mask * x_k_star).clamp(-1., 1.)

    final_samples = x_k.detach()
    return final_samples

def plot(args, p, x):
    """Save ``x`` as an image grid at path ``p`` and as a tensor file in ``args.save_dir``.

    The grid has ``int(sqrt(x.size(0)))`` images per row and shows ``x``
    clamped to [-1, 1]; the tensor file stores ``x`` itself.
    """
    n_images = x.size(0)
    images_per_row = int(t.sqrt(t.Tensor([n_images])))
    clipped = t.clamp(x, -1, 1)
    tv.utils.save_image(clipped, p, normalize=True, nrow=images_per_row)

    tensor_path = '{}/{}_{}.pth'.format(args.save_dir, args.eval, classes[args.m_class])
    t.save(x, tensor_path)

def lmc_cond_samples(f, args, device, save=True):
    """Sample with ``l_sample_q`` for ``args.n_sample_steps + 1`` rounds.

    Starts from ``init_random`` samples, all labelled ``args.m_class``. Every
    ``args.print_every`` rounds (when ``save`` is true) the samples are passed
    to ``plot``; if ``args.lambda_p > 0`` only every
    ``args.num_particles_per_group``-th sample is plotted. Returns the final
    samples.
    """
    n_total = args.total_batch_size
    x_k = init_random(n_total).to(device)
    y = t.zeros(n_total, device=device).long() + int(args.m_class)

    for i in tqdm(range(args.n_sample_steps + 1)):
        x_k = l_sample_q(i, x_k, args, device, f, y)
        if i % args.print_every != 0 or not save:
            continue

        image_path = f"{args.save_dir}/{args.eval}_{classes[args.m_class]}_{i}.png"
        if args.lambda_p > 0.0:
            # Keep the first particle of every group of consecutive particles.
            keep = [idx for idx in range(len(x_k)) if idx % args.num_particles_per_group == 0]
            snapshot = x_k[keep]
        else:
            snapshot = x_k
        plot(args, image_path, snapshot)

    return x_k

def lhmc_cond_samples(f, args, device, save=True):
    """Sample with ``lhmc_sample_q`` for ``args.n_sample_steps + 1`` rounds.

    Starts from ``init_random`` samples, all labelled ``args.m_class``, and a
    momentum tensor drawn by ``init_random`` and scaled by 0.01. The same
    momentum tensor is handed to every round. Every ``args.print_every``
    rounds (when ``save`` is true) the samples are passed to ``plot``; if
    ``args.lambda_p > 0`` only every ``args.num_particles_per_group``-th sample
    is plotted. Returns the final samples.
    """
    n_total = args.total_batch_size
    x_k = init_random(n_total).to(device)
    m_k = init_random(n_total).to(device) * 0.01
    y = t.zeros(n_total, device=device).long() + int(args.m_class)

    for i in tqdm(range(args.n_sample_steps + 1)):
        x_k = lhmc_sample_q(i, x_k, m_k, args, device, f, y)
        if i % args.print_every != 0 or not save:
            continue

        image_path = f"{args.save_dir}/{args.eval}_{classes[args.m_class]}_{i}.png"
        if args.lambda_p > 0.0:
            # Keep the first particle of every group of consecutive particles.
            keep = [idx for idx in range(len(x_k)) if idx % args.num_particles_per_group == 0]
            snapshot = x_k[keep]
        else:
            snapshot = x_k
        plot(args, image_path, snapshot)

    return x_k

def main(args):
    """Load the checkpoint at ``args.load_path`` and run the sampler named by ``args.eval``.

    Supported values of ``args.eval``:
      * ``"mc_cond_samples"``   -- ``lmc_cond_samples`` with ``args.lambda_p`` forced to 0
      * ``"lmc_cond_samples"``  -- ``lmc_cond_samples``
      * ``"hmc_cond_samples"``  -- ``lhmc_cond_samples`` with ``args.lambda_p`` forced to 0
      * ``"lhmc_cond_samples"`` -- ``lhmc_cond_samples``

    When ``args.print_to_log`` is set, standard output is redirected to
    ``<save_dir>/log.txt`` for the duration of the call.

    Raises:
        ValueError: if ``args.eval`` is not one of the supported modes.
    """
    no_pulling_modes = ("mc_cond_samples", "hmc_cond_samples")
    momentum_modes = ("hmc_cond_samples", "lhmc_cond_samples")
    supported_modes = (
        "mc_cond_samples",
        "lmc_cond_samples",
        "hmc_cond_samples",
        "lhmc_cond_samples",
    )
    # Fail before any expensive work instead of silently doing nothing.
    if args.eval not in supported_modes:
        raise ValueError(
            "unsupported --eval value {!r}; expected one of: {}".format(
                args.eval, ", ".join(supported_modes)
            )
        )

    utils.makedirs(args.save_dir)

    previous_stdout = sys.stdout
    log_file = None
    if args.print_to_log:
        log_file = open(f'{args.save_dir}/log.txt', 'w')
        sys.stdout = log_file

    try:
        t.manual_seed(seed)
        if t.cuda.is_available():
            t.cuda.manual_seed_all(seed)

        device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        model_cls = F if args.uncond else CCF
        f = model_cls(args.depth, args.width, args.norm)

        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only host, matching the device fallback above.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt_dict = t.load(args.load_path, map_location=device)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        f.to(device)
        print('Now proceeding to:', args.eval, classes[args.m_class])

        if args.eval in no_pulling_modes:
            args.lambda_p = 0.0
        sampler = lhmc_cond_samples if args.eval in momentum_modes else lmc_cond_samples
        sampler(f, args, device)
    finally:
        # Restore stdout and flush/close the log even if sampling fails.
        if log_file is not None:
            sys.stdout = previous_stdout
            log_file.close()


if __name__ == "__main__":
    import os

    # Label names indexed by --m_class; plot() and main() read this module-level name.
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    parser = argparse.ArgumentParser("Energy Based Models and Shit")
    parser.add_argument("--eval", type=str, default="OOD")
    parser.add_argument(
        "--score_fn",
        type=str,
        default="px",
        choices=["px", "py", "pxgrad"],
        help="For OODAUC, chooses what score function we use.",
    )
    parser.add_argument(
        "--ood_dataset",
        type=str,
        default="svhn",
        choices=["svhn", "cifar_interp", "cifar_100", "celeba"],
        help="Chooses which dataset to compare against for OOD",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar_test",
        choices=["cifar_train", "cifar_test", "svhn_test", "svhn_train"],
        help="Dataset to use when running test_clf for classification accuracy",
    )
    parser.add_argument(
        "--datasets",
        nargs="+",
        type=str,
        default=[],
        help="The datasets you wanna use to generate a log p(x) histogram",
    )
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument(
        "--norm",
        type=str,
        default=None,
        choices=[None, "norm", "batch", "instance", "layer", "act"],
    )

    # EBM specific
    for flag, kind, default in (
        ("--n_steps", int, 0),
        ("--m_class", int, 0),
        ("--width", int, 10),
        ("--depth", int, 28),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--uncond", action="store_true")
    for flag, kind, default in (
        ("--buffer_size", int, 16),
        ("--reinit_freq", float, .05),
        ("--lambda_p", float, 0.0),
        ("--sgld_lr", float, 1.0),
        ("--sgld_std", float, 1e-2),
        ("--partial_refresh", float, 0.9),
    ):
        parser.add_argument(flag, type=kind, default=default)

    # Logging + Evaluation
    for flag, kind, default in (
        ("--save_dir", str, 'YOUR_SAVE_PATH'),
        ("--print_every", int, 100),
        ("--n_sample_steps", int, 100),
        ("--load_path", str, 'CIFAR10_MODEL.pt'),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--print_to_log", action="store_true")
    parser.add_argument("--num_particles_per_group", type=int, default=4)

    args = parser.parse_args()

    # Results go to a per-class directory derived from the requested save_dir.
    class_name = classes[args.m_class]
    args.save_dir = f"results-{class_name}/{args.save_dir}-{class_name}"
    os.makedirs(args.save_dir, exist_ok=True)

    # With leader pulling enabled, each requested sample gets a whole group of particles.
    particles_per_sample = args.num_particles_per_group if args.lambda_p > 0.0 else 1
    args.total_batch_size = args.batch_size * particles_per_sample

    main(args)