"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""

import argparse
import os

import numpy as np
import torch as th
import torch
import torch.distributed as dist
import torchvision
import scipy.stats


import sys
sys.path.append('code')

from architectures import get_architecture
from datasets import get_dataset

try:
    from diffusion.improved_diffusion import dist_util, logger
    from diffusion.improved_diffusion.script_util import (
        NUM_CLASSES,
        model_and_diffusion_defaults,
        create_model_and_diffusion,
        add_dict_to_argparser,
        args_to_dict,
    )
except ImportError:
    from improved_diffusion import dist_util, logger
    from improved_diffusion.script_util import (
        NUM_CLASSES,
        model_and_diffusion_defaults,
        create_model_and_diffusion,
        add_dict_to_argparser,
        args_to_dict,
    )
#from pycallcc import *
def wrap(x):
    """Return ``x`` unchanged (identity decorator).

    Presumably a stand-in for a decorator from the commented-out ``pycallcc``
    import, so that the ``@wrap`` annotations in this file remain valid.
    """
    return x

from PIL import Image

class Args:
    """Hyperparameters passed to ``create_model_and_diffusion``.

    An instance is fed through ``args_to_dict`` together with the keys of
    ``model_and_diffusion_defaults()``, so the attribute names here are
    expected to match those keys.
    """

    image_size = 32
    num_channels = 128
    num_res_blocks = 3
    num_heads = 4
    num_heads_upsample = -1
    attention_resolutions = "16,8"
    dropout = 0.0
    learn_sigma = True
    sigma_small = False
    class_cond = False
    diffusion_steps = 4000
    noise_schedule = "cosine"
    # Diffusion_Defense_Model may override this on its own Args instance.
    timestep_respacing = range(0, 600, 25)
    use_kl = False
    predict_xstart = False
    rescale_timesteps = True
    rescale_learned_sigmas = True
    use_checkpoint = False
    use_scale_shift_norm = True

@wrap
#@only_once
def setup():
    """Load the diffusion model, the classifier and the test set into globals.

    Populates the module-level names ``model``, ``diffusion``, ``classifier``
    and ``dataset`` that ``classify``, ``forward`` and ``defend`` read.

    Raises:
        RuntimeError: if the classifier checkpoint is not a dict carrying an
            ``"arch"`` entry, which is the only layout this function knows how
            to rebuild a network from.
    """
    global model, diffusion, classifier, dataset
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(Args(), model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        th.load("code/diffusion/cifar10_uncond_50M_500K.pt")
    )
    model.cuda()

    checkpoint_path = "models/cifar10/resnet110/noise_0.00/checkpoint.pth.tar"
    checkpoint = torch.load(checkpoint_path)
    # Fail loudly here: without this guard an unrecognised checkpoint left
    # ``classifier`` unassigned and the ``.eval()`` call below hit a NameError
    # (or silently reused a stale global from an earlier run).
    if not (isinstance(checkpoint, dict) and "arch" in checkpoint):
        raise RuntimeError(
            f"Unsupported classifier checkpoint format in {checkpoint_path!r}: "
            "expected a dict with an 'arch' entry"
        )

    # model is from original smoothing repo
    classifier = get_architecture(checkpoint["arch"], "cifar10")
    classifier.load_state_dict(checkpoint['state_dict'])

    # Alternative classifier source used previously:
    # classifier = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet56", pretrained=True)
    classifier.eval().cuda()

    # NOTE(review): this torchvision dataset is replaced immediately below, so
    # its only lasting effect is the download into /tmp. Kept in case
    # something relies on that download -- confirm and remove if not.
    dataset = torchvision.datasets.CIFAR10(
        root="/tmp", train=False, download=True
    )

    dataset = get_dataset('cifar10', 'test')

def img(x):
    """Turn one channel-first image with values in [-1, 1] into a PIL image.

    Values outside [-1, 1] are clipped, the result is mapped linearly onto
    0..255 and cast to ``uint8``, and axis 0 is moved last before the array is
    handed to ``Image.fromarray``.
    """
    clipped = np.clip(x, -1, 1)
    pixels = np.array((clipped + 1) * 127.5, dtype=np.uint8)
    channels_last = pixels.transpose((1, 2, 0))
    return Image.fromarray(channels_last)

def classify(x_start):
    """Noise a batch, denoise it with the diffusion model, then classify it.

    Uses the module-level ``diffusion``, ``model`` and ``classifier`` set up
    by ``setup()``. The batch is noised to the first timestep ``t`` at which
    ``sqrt_one_minus_alphas_cumprod / sqrt_alphas_cumprod`` exceeds 0.21, then
    ``p_sample`` is applied for timesteps ``t, t-1, ..., 0``. The
    ``pred_xstart`` of the final step is rescaled from [-1, 1] to [0, 1],
    normalized with the fixed per-channel mean/std below, and classified.

    Args:
        x_start: batch of images scaled to [-1, 1]. It must live on the same
            device as ``model`` (``setup()`` puts the model on CUDA).

    Returns:
        The output of ``classifier`` on the denoised batch.
    """
    ratio = diffusion.sqrt_one_minus_alphas_cumprod / diffusion.sqrt_alphas_cumprod
    t = int(np.where(ratio > .21)[0][0])
    # print("T is", t)

    batch_size = len(x_start)

    def timesteps(step):
        # One identical timestep index per sample, on the input's device.
        return th.full((batch_size,), step, dtype=th.long, device=x_start.device)

    noise = th.randn_like(x_start)
    x_t = diffusion.q_sample(x_start=x_start, t=timesteps(t), noise=noise)

    # print((x_start-x_t).std())
    # The reverse process runs without autograd, so no gradient reaches
    # ``x_start`` through this function.
    for i in reversed(range(t + 1)):
        with th.no_grad():
            out = diffusion.p_sample(
                model,
                x_t,
                timesteps(i),
                clip_denoised=False,
            )
            x_t = out["sample"]

    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.201]
    data = torchvision.transforms.Normalize(mean, std)((out['pred_xstart'] + 1) / 2)

    return classifier(data)

def forward(x): # takes [-1,1] images
    """Classify a batch directly, without any diffusion denoising.

    Sanity-checks that the batch looks like it is scaled to [-1, 1] (its
    minimum must lie in (-1.01, -0.8) and its maximum in (0.8, 1.01)), then
    rescales to [0, 1], applies the fixed per-channel normalization and runs
    the module-level ``classifier``.
    """
    lowest = x.min().item()
    assert -1.01 < lowest < -.8
    highest = x.max().item()
    assert .8 < highest < 1.01

    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.201]
    normalize = torchvision.transforms.Normalize(channel_mean, channel_std)
    unit_range = (x + 1) / 2
    return classifier(normalize(unit_range))

@wrap
def attack(x, y):
    """Search for adversarial inputs at L2 distance 1.0 from ``x``.

    Works in [0, 1] pixel scale. Each of the 100 iterations takes a step of
    length 0.06 along the negative, L2-normalized gradient of the ``y`` logit
    returned by ``classify``, rescales the total perturbation to L2 norm
    exactly 1.0 around the original image, and clips to [0, 1].

    Args:
        x: batch of images in [-1, 1] with four axes (norms are taken over
            axes 1-3).
        y: per-image class indices used to select the logit; indexed together
            with a CUDA ``arange``, so it is expected on CUDA.

    Returns:
        The perturbed batch as a CUDA tensor, rescaled back to [-1, 1].

    Raises:
        RuntimeError: if no gradient reaches the input. ``classify`` runs its
            denoising loop under ``th.no_grad()``, which cuts the graph, so
            gradients must be enabled there for this attack to run.
    """
    def unit_l2(v):
        # Per-sample L2 normalization; all-zero samples stay zero instead of
        # turning the whole batch into NaN through a 0/0 division.
        norm = np.sum(v**2, axis=(1, 2, 3), keepdims=True)**.5
        return v / np.maximum(norm, 1e-12)

    x = (x + 1) / 2
    xo = x.cpu().detach().numpy().copy()
    for _ in range(100):
        # Fresh leaf tensor each iteration (x is a tensor on the first pass
        # and a numpy array afterwards).
        x = torch.as_tensor(x).clone().detach().requires_grad_(True)
        logits = classify(x.cuda()*2-1)
        loss = logits[torch.arange(len(x)).cuda(), y]
        print(loss.sum())
        loss.sum().backward()
        if x.grad is None:
            raise RuntimeError(
                "attack(): no gradient reached the input; classify() must run "
                "its denoising loop with gradients enabled"
            )
        step = unit_l2(x.grad.cpu().detach().numpy())
        x = x.cpu().detach().numpy() - step * .06

        # Rescale the perturbation so it has L2 norm 1.0 around the original.
        direction = unit_l2(xo - x)
        x = xo - direction * 1.0

        #print(np.sum((x-xo)**2,axis=(1,2,3))**.5)

        x = np.clip(x, 0, 1)
    return torch.tensor(x).cuda()*2-1

@wrap
def defend():
    """Print the accuracy of ``classify`` on the first 2560 test images.

    Processes ten batches of 256 images from the module-level ``dataset``.
    Every batch is classified three times; two accuracies are printed per
    batch: one from the summed outputs of the three runs ("adding") and one
    from a per-image majority vote over the three predicted labels
    ("majority"). The mean of the per-batch summed-output accuracies is
    printed at the end.
    """
    BS = 256
    z = []
    for i in range(0, BS*10, BS):
        x_start = dataset.data[i:i+BS]
        x_start = th.tensor(x_start.transpose((0,3,1,2))).cuda()/127.5-1
        labels = np.asarray(dataset.targets[i:i+BS])

        # Uncomment to run attack
        #x_in = attack(x_start, torch.tensor(dataset.targets[i:i+BS]).cuda())

        # Here for clean accuracy
        x_in = x_start

        for n in [3]:
            # Stacked outputs of the n runs: (n, batch, classes).
            out = np.stack([classify(x_in).cpu().detach().numpy() for _ in range(n)])

            # Majority vote per image. Counting votes per class and taking
            # argmax breaks ties towards the smallest label and, unlike
            # indexing the result of scipy.stats.mode, does not depend on
            # the SciPy version's keepdims default.
            votes = out.argmax(2)
            counts = (votes[:, :, None] == np.arange(out.shape[2])).sum(0)
            guess = counts.argmax(1)
            maj = np.mean(guess == labels)

            summed = out.sum(0)
            asr = np.mean(summed.argmax(1) == labels)
            print('adding', asr, 'majority', maj)

        z.append(asr)
    print("==> overall cifar test accuracy", np.mean(z))
    
class Diffusion_Defense_Model(torch.nn.Module):
    """Classifier preceded by a diffusion noise-then-denoise step.

    ``forward`` takes images in [0, 1], noises them with the forward
    diffusion process, runs the reverse process back to timestep 0 and feeds
    the final ``pred_xstart`` (rescaled to [0, 1]) to the wrapped classifier.
    Gradients are not disabled, so the whole pipeline can be differentiated.
    """

    def __init__(self, classifier, timestep_respacing=None, noise_threshold=0.21):
        """Build the diffusion model and wrap ``classifier``.

        Args:
            classifier: module applied to the denoised images; it is put in
                eval mode and moved to CUDA.
            timestep_respacing: replaces ``Args.timestep_respacing`` when
                truthy; any falsy value (``None``, ``""``, empty sequence)
                keeps the default.
            noise_threshold: ``forward`` noises the input to the first
                timestep where ``sqrt_one_minus_alphas_cumprod /
                sqrt_alphas_cumprod`` exceeds this value.
        """
        super().__init__()
        args_class = Args()
        if timestep_respacing:
            args_class.timestep_respacing = timestep_respacing
        model, diffusion = create_model_and_diffusion(
            **args_to_dict(args_class, model_and_diffusion_defaults().keys())
        )
        model.load_state_dict(
            th.load("code/diffusion/cifar10_uncond_50M_500K.pt")
        )
        model.cuda()

        classifier.eval().cuda()

        self.model = model
        self.diffusion = diffusion
        self.classifier = classifier
        self.noise_threshold = noise_threshold

    def forward(self, x):
        """Denoise ``x`` (images in [0, 1]) and return the classifier output.

        ``x`` must be on the same device as the diffusion model, which
        ``__init__`` moves to CUDA.
        """
        ratio = (
            self.diffusion.sqrt_one_minus_alphas_cumprod
            / self.diffusion.sqrt_alphas_cumprod
        )
        t = int(np.where(ratio > self.noise_threshold)[0][0])
        # print("T is", t)
        x_start = x*2.-1.

        def timesteps(step):
            # One identical timestep index per sample, on the input's device.
            return th.full(
                (x_start.shape[0],), step, dtype=th.long, device=x_start.device
            )

        noise = th.randn_like(x_start)
        x_t = self.diffusion.q_sample(x_start=x_start, t=timesteps(t), noise=noise)

        # Deliberately no th.no_grad() here: callers can backpropagate
        # through the reverse process.
        for i in reversed(range(t + 1)):
            out = self.diffusion.p_sample(
                self.model,
                x_t,
                timesteps(i),
                clip_denoised=False,
            )
            x_t = out["sample"]

        # No mean/std normalization here: per the original author's note the
        # classifier is expected to contain its own normalize layer with the
        # CIFAR-10 statistics -- TODO confirm for every classifier passed in.
        data = (out['pred_xstart']+1.)/2.

        return self.classifier(data)
    
def _main():
    """Script entry point: load everything, then run the evaluation."""
    setup()
    defend()


if __name__ == "__main__":
    _main()
    
