import utils
import torch as t, torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision as tv, torchvision.transforms as tr
import os
import sys
import argparse
import numpy as np
import wideresnet
import pdb

from resnet import ResNet18
from utils import add_edge_batch

from tqdm import tqdm
# Sampling
from tqdm import tqdm
# Fixed RNG seed; main() feeds it to torch so repeated runs start identically.
seed = 1

# Data constants. The channel count and side length match the (3, 32, 32)
# shape hard-coded in init_random; n_classes matches the 10-way class head.
n_classes = 10
n_ch = 3
im_sz = 32

# Enable cuDNN and let it benchmark convolution algorithms for the fixed
# input shapes used here.
t.backends.cudnn.enabled = True
t.backends.cudnn.benchmark = True


class F(nn.Module):
    """Wide-ResNet feature extractor with an energy head and a class head.

    Args:
        depth: Depth forwarded to ``wideresnet.Wide_ResNet``.
        width: Widening factor forwarded to ``wideresnet.Wide_ResNet``.
        norm: Normalisation choice forwarded to ``wideresnet.Wide_ResNet``.
        n_classes: Output size of the classification head. Defaults to 10,
            the value that was previously hard-coded.
    """

    def __init__(self, depth=28, width=2, norm=None, n_classes=10):
        super().__init__()
        self.f = wideresnet.Wide_ResNet(depth, width, norm=norm)
        self.energy_output = nn.Linear(self.f.last_dim, 1)
        self.class_output = nn.Linear(self.f.last_dim, n_classes)

    def forward(self, x, y=None):
        """Return one scalar per input from the energy head.

        ``y`` is accepted for signature parity with subclasses and ignored.
        """
        penult_z = self.f(x)
        # Drop only the trailing unit dimension produced by the 1-unit linear
        # layer. A bare ``squeeze()`` would also remove the batch dimension
        # when the batch holds a single sample, returning a 0-d tensor.
        return self.energy_output(penult_z).squeeze(-1)

    def classify(self, x):
        """Return the class-head logits for ``x``."""
        penult_z = self.f(x)
        return self.class_output(penult_z)


class CCF(F):
    """Variant of ``F`` whose forward pass is derived from the class logits."""

    def __init__(self, depth=28, width=2, norm=None):
        super().__init__(depth, width, norm=norm)

    def forward(self, x, y=None):
        """Score ``x`` using the classification head.

        Without ``y`` the result is the log-sum-exp of the logits over
        dim 1. With ``y`` the result is the logit picked out by each label,
        gathered along dim 1 (one column per row, assuming ``y`` is a 1-D
        tensor of integer labels).
        """
        class_logits = self.classify(x)
        if y is not None:
            label_column = y.unsqueeze(1)
            return class_logits.gather(1, label_column)
        return t.logsumexp(class_logits, dim=1)


def init_random(bs):
    """Return ``bs`` CPU float32 images of shape (3, 32, 32) drawn uniformly from [-1, 1)."""
    noise = t.empty(bs, 3, 32, 32, dtype=t.float32, device="cpu")
    noise.uniform_(-1, 1)
    return noise


def _run_sgld(args, f, classifier, device, out_dir):
    """Run the SGLD chain and return the final (detached) batch of samples.

    Each step ascends ``f(x) - 1e-2 * relu(max_other_logit - target_logit)``,
    where the logits come from ``classifier`` evaluated on ``x`` rescaled from
    [-1, 1] to [0, 1]. After every step the current batch is written to
    ``out_dir`` as ``<step>.eps`` and ``<step>.png``.
    """
    bs = args.batch_size
    x_k = init_random(bs).to(device).requires_grad_(True)
    # Loop-invariant column index selecting the target logit of every row.
    target_idx = t.full((bs, 1), args.target, dtype=t.long, device=device)

    for k in tqdm(range(args.n_steps)):
        logits = classifier((x_k + 1.0) / 2.0)
        logsumexp = t.logsumexp(logits, dim=1)
        target_probs = t.exp(logits[:, args.target] - logsumexp)
        print(target_probs)
        success = t.max(logits, dim=1)[1] == args.target

        # Mask the target logit with -inf so the max runs over the other
        # classes only; the hinge is zero once the target logit is largest.
        other_best = t.scatter(logits, 1, target_idx, float("-inf")).max(dim=1)[0]
        hinge = t.nn.functional.relu(other_best - logits[:, args.target])
        objective = (f(x_k, y=None) - 1e-2 * hinge).sum()
        # One backward pass per step, so the graph need not be retained.
        f_prime = t.autograd.grad(objective, [x_k])[0]

        # Langevin update followed by projection back onto the valid range;
        # done without autograd tracking so x_k stays a leaf tensor.
        with t.no_grad():
            x_k.add_(args.sgld_lr * f_prime + args.sgld_std * t.randn_like(x_k))
            x_k.clamp_(-1, 1)

        # Build the annotated batch once and save it in both formats.
        # `success` reflects the predictions made before this step's update.
        annotated = add_edge_batch((x_k.detach().cpu() + 1.0) / 2.0, success)
        for ext in ("eps", "png"):
            tv.utils.save_image(annotated, os.path.join(out_dir, f"{k}.{ext}"), nrow=8, padding=1, normalize=True, value_range=(0, 1))

    return x_k.detach()


def main(args):
    """Load the energy model and classifier, then run classifier-guided SGLD.

    Reads from ``args``: ``save_dir``, ``origin``, ``target``, ``print_to_log``,
    ``depth``, ``width``, ``norm``, ``load_path``, ``batch_size``, ``n_steps``,
    ``sgld_lr`` and ``sgld_std``. ``origin`` only affects the output
    directory name.

    Raises:
        ValueError: if ``args.load_path`` is not set.
    """
    if args.load_path is None:
        raise ValueError("--load_path is required: no energy-model checkpoint was given")

    out_dir = os.path.join(args.save_dir, f"single{args.origin}vs{args.target}")
    utils.makedirs(out_dir)

    # Optionally redirect stdout to a log file; it is restored and the file
    # is closed on the way out, even if sampling fails.
    prev_stdout = sys.stdout
    log_file = open(f'{args.save_dir}/log.txt', 'w') if args.print_to_log else None
    if log_file is not None:
        sys.stdout = log_file

    try:
        t.manual_seed(seed)
        if t.cuda.is_available():
            t.cuda.manual_seed_all(seed)

        device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        f = F(args.depth, args.width, args.norm)
        classifier = ResNet18()
        print(f"loading model from {args.load_path}")

        # load em up
        # NOTE: t.load unpickles the file; only load checkpoints you trust.
        # map_location lets checkpoints saved on GPU load on a CPU-only host.
        ckpt_dict = t.load(args.load_path, map_location=device)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        classifier.load_state_dict(
            t.load(os.path.join("saved_models", "svhn_adv_cnn.pt"), map_location=device)
        )

        f = f.to(device)
        classifier = classifier.to(device)
        f.eval()
        classifier.eval()

        final_samples = _run_sgld(args, f, classifier, device, out_dir)
        print(final_samples.shape)
    finally:
        if log_file is not None:
            sys.stdout = prev_stdout
            log_file.close()
    



if __name__ == "__main__":
    # (flag, add_argument keyword arguments), registered in this exact order
    # so the generated --help listing is unchanged.
    cli_options = (
        # optimization
        ("--batch_size", {"type": int, "default": 64}),
        # regularization
        ("--sigma", {"type": float, "default": 3e-2}),
        ("--norm", {"type": str, "default": None,
                    "choices": [None, "norm", "batch", "instance", "layer", "act"]}),
        # EBM specific
        ("--n_steps", {"type": int, "default": 512}),
        ("--width", {"type": int, "default": 10}),
        ("--depth", {"type": int, "default": 28}),
        ("--origin", {"type": int, "default": 0}),
        ("--target", {"type": int, "default": 1}),
        ("--buffer_size", {"type": int, "default": 0}),
        ("--reinit_freq", {"type": float, "default": 0.05}),
        ("--sgld_lr", {"type": float, "default": 1.0}),
        ("--sgld_std", {"type": float, "default": 1e-2}),
        # logging + evaluation
        ("--save_dir", {"type": str, "default": "./YOUR_SAVE_PATH_BUDDDDDDYYYYYYY"}),
        ("--load_path", {"type": str, "default": None}),
        ("--print_to_log", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser("Energy Based Models")
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()
    main(args)

