import os
import argparse
import numpy as np
import torch
from torch.utils import data
from guided_diffusion.image_datasets import *
from models import get_classifier
import torch.optim as optim
from tqdm import tqdm

import os
import yaml
import copy
import math
import random
import argparse
import itertools
import numpy as np
import os.path as osp
import matplotlib.pyplot as plt

from PIL import Image
from time import time
from os import path as osp
from multiprocessing import Pool

import torch

from torch.utils import data
from torch.nn import functional as F

from torchvision import transforms
from torchvision import datasets

# Diffusion Model imports
from guided_diffusion import dist_util
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    diffusion_defaults,
    create_model_and_diffusion,
    create_gaussian_diffusion,
    create_classifier,
    args_to_dict,
    add_dict_to_argparser,
)
from guided_diffusion.sample_utils import (
    get_DiME_iterative_sampling,
    clean_class_cond_fn,
    dist_cond_fn,
    ImageSaver,
    SlowSingleLabel,
    load_from_DDP_model,
    ChunkedDataset,
)
from guided_diffusion.gaussian_diffusion import _extract_into_tensor
from guided_diffusion.image_datasets_label0 import get_dataset, BINARYDATASET, MULTICLASSDATASETS

# core imports
from core.utils import print_dict, merge_all_chunks, generate_mask
from core.metrics import accuracy, get_prediction
from core.attacks_and_models import JointClassifierDDPM, get_attack

# model imports
from models import get_classifier
from utils_vae import *

import matplotlib
matplotlib.use('Agg')  # to disable display

def create_args():
    """Build the command-line parser for modelS training and parse ``sys.argv``.

    Each key below becomes a ``--key`` option (registered through
    ``add_dict_to_argparser``) whose default is the listed value.

    Defaults are layered, later layers winning:
      1. the script's own run/attack/dataset/modelS settings,
      2. ``model_and_diffusion_defaults()`` from guided_diffusion,
      3. the DDPM architecture overrides at the end (which, for example,
         force ``image_size`` back to 256).

    Returns:
        argparse.Namespace with one attribute per key.
    """
    script_defaults = {
        # general
        'clip_denoised': True,               # clip the denoised prediction
        'batch_size': 100,                   # batch size
        'gpu': '0',                          # GPU index; only a single GPU is expected
        'save_images': True,                 # save all generated images
        'num_samples': 9000,                 # handy for running on only a few examples
        'cudnn_deterministic': True,         # slower, but reproducible with checkpointed backward

        # paths
        'model_path': './celebahq-ddpm.pt',  # DDPM weights
        'classifier_path': './checkpoint.tar',  # classifier weights
        'modelS_path': '',
        'output_path': 'results',            # output root
        'output_dir': './modelS_output',     # where modelS checkpoints are written
        'exp_name': 'exp',                   # experiment name (results go under output_path/Results/exp_name)

        # attack
        'seed': 0,                           # random seed
        'attack_method': 'PGD',              # 'PGD', 'C&W', 'GD' or 'None'
        'attack_iterations': 50,             # number of attack updates
        'attack_epsilon': 255,               # L-inf bound (divided by 255)
        'attack_step': 1.0,                  # update step (divided by 255)
        'attack_joint': True,                # False -> plain adversarial attacks
        'attack_joint_checkpoint': False,    # checkpointed backward; substantially slower
        'attack_checkpoint_backward_steps': 1,  # DDPM iterations per backward pass
        'attack_joint_shortcut': False,      # DiME gradient shortcut (not recommended)

        # distance loss
        'dist_l1': 0.0,                      # l1 scaling factor
        'dist_l2': 0.0,                      # l2 scaling factor
        'dist_schedule': 'none',             # schedule for the distance loss

        # filtering
        'sampling_time_fraction': 0.1,       # fraction of noise steps to use
        'sampling_stochastic': True,         # False removes the noise when sampling

        # post-processing
        'sampling_inpaint': 0.0,             # inpainting threshold
        'sampling_dilation': 15,             # dilation size for mask generation

        # labels
        'label_query': 20,                   # query label to target
        'label_target': -1,                  # target label (multi-class datasets)

        # dataset
        'image_size': 256,                   # dataset image size
        'data_dir': "./CelebAMask-HQ",       # dataset location
        'dataset': 'CelebAHQ',               # ImageNet, CelebA, CelebAMV, CelebAHQ, BDDOIA, BDD100k
        'chunks': 1,                         # split the work across several GPUs
        'chunk': 0,                          # current chunk, in [0, chunks - 1]
        'merge_chunks': False,               # merge all chunked results

        # modelS training
        'ratio': 1.0,
        'alpha': 1.0,
        'num_classes_for_modelS': 2,
        'num_classes': 2,
        'adv_ratio': 1.0,
        'num_epochs': 31,
        'base_lr': 0.001,
        'decay_lr1': 60,
        'decay_lr2': 90,
        'momentum': 0.9,
        'weight_decay': 2e-4,
        'device': 'cuda:0',
    }

    # Architecture of the pretrained DDPM checkpoint; applied last so these win.
    ddpm_overrides = {
        'diffusion_steps': 500,
        'attention_resolutions': '32,16,8',
        'class_cond': False,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 128,
        'num_head_channels': 64,
        'num_heads': 4,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_fp16': True,
        'use_scale_shift_norm': True,
        'image_size': 256,
        'sampling_stochastic': True,
        'sampling_time_fraction': 0.1,
    }

    merged = {**script_defaults, **model_and_diffusion_defaults(), **ddpm_overrides}

    cli = argparse.ArgumentParser()
    add_dict_to_argparser(cli, merged)
    return cli.parse_args()

# args = create_args()

class PredYWithS(nn.Module):
    """Small MLP head that maps a ``feat_dim``-wide feature vector to class logits.

    ``pred`` is a bottleneck: ``feat_dim -> feat_dim // 2`` (BatchNorm1d + ReLU),
    then ``feat_dim // 2 -> feat_dim`` (ReLU). ``fc`` is a single linear layer
    ``feat_dim -> num_classes``.

    Args:
        feat_dim: width of the input features.
        num_classes: number of output logits (default 2).
    """

    def __init__(self, feat_dim, num_classes=2):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        hidden = feat_dim // 2

        # Submodule names/positions ("pred.0", "pred.1", "pred.3", "fc.0") are the
        # state_dict keys of saved checkpoints, and layer construction order
        # fixes RNG consumption during init -- keep both stable.
        self.pred = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, feat_dim),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(nn.Linear(feat_dim, num_classes, bias=True))

    def forward(self, x):
        """Return logits for ``x`` (expected shape (batch, feat_dim))."""
        return self.fc(self.pred(x))


def _cycle(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass when one ends.

    Raises:
        RuntimeError: if a full pass over ``loader`` yields no batches
            (otherwise this would loop forever).
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise RuntimeError("training loader yielded no batches")


def _load_encoder(args, latent_dim):
    """Build ``Linear_Encoder(1024, latent_dim)`` and load its VAE weights.

    Weights come from the ``'encoder'`` entry of ``./VAE_model/checkpoint_1000.pth``.
    The encoder is returned on ``args.device`` in eval mode.
    """
    encoder = Linear_Encoder(1024, latent_dim)
    checkpoint = torch.load('./VAE_model/checkpoint_1000.pth', map_location='cpu')
    encoder.load_state_dict(checkpoint['encoder'])
    encoder.to(args.device)
    encoder.eval()
    return encoder


def _load_joint_classifier(args):
    """Build the DDPM from ``args`` and load weights from ``args.model_path``.

    The DDPM is wrapped together with ``get_classifier(args)`` in
    ``JointClassifierDDPM``. The wrapper is returned on ``args.device`` in eval mode.

    NOTE(review): the training loop only uses
    ``.classifier.classifier.feat_extract``; the DDPM itself appears unused
    here -- confirm before dropping it.
    """
    # NOTE(review): the 50 is hard-coded and independent of args.diffusion_steps -- confirm.
    respaced_steps = int(args.sampling_time_fraction * 50)
    model, respaced_diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    classifier = get_classifier(args)
    joint_classifier = JointClassifierDDPM(classifier=classifier,
                                           ddpm=model, diffusion=respaced_diffusion,
                                           steps=respaced_steps,
                                           stochastic=args.sampling_stochastic)
    joint_classifier.to(args.device).eval()
    return joint_classifier


def main():
    """Train ``PredYWithS`` to predict labels from frozen VAE latents.

    For each training batch, classifier features are encoded by the VAE
    encoder. The mean ``mu`` is sliced to its trailing part ``mu[:, K:]``
    ("beta"), and only ``PredYWithS`` is trained on it with cross-entropy
    against ``y``.

    A checkpoint of ``modelS`` is written to ``args.output_dir`` every
    ``save_every`` steps and at the final step.
    """
    args = create_args()
    # Mask GPUs before anything (including seeding) can touch CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    set_deterministic(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    train_dataset = get_dataset(args, is_train=True)
    # drop_last: a trailing batch of size 1 would make BatchNorm1d in
    # PredYWithS raise in training mode.
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                   shuffle=True, drop_last=True,
                                   num_workers=4, pin_memory=True)

    K = 16    # leading latent dims, discarded (beta = mu[:, K:])
    L = 112   # trailing latent dims fed to modelS; encoder output is presumably K + L wide
    lr = 5e-4
    b1 = 0.9
    b2 = 0.999
    train_steps = 1000
    save_every = 200
    ce_loss = nn.CrossEntropyLoss().to(args.device)

    encoder = _load_encoder(args, K + L)
    joint_classifier = _load_joint_classifier(args)

    modelS = PredYWithS(L, num_classes=args.num_classes).to(args.device)
    opt = optim.Adam(modelS.parameters(), lr=lr, betas=(b1, b2))

    modelS.train()
    batches = _cycle(train_loader)
    for k in range(train_steps):
        x, y = next(batches)
        x, y = x.to(args.device), y.to(args.device)
        targets = y.long()

        # Only modelS is optimized; the classifier and encoder are frozen, so
        # run them without building an autograd graph.
        with torch.no_grad():
            f = joint_classifier.classifier.classifier.feat_extract(x)
            _, mu, _ = encoder(f)
            beta = mu[:, K:]

        zs = modelS(beta)
        loss = ce_loss(zs, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()

        acc = 100. * (zs.argmax(dim=1) == targets).sum().item() / y.size(0)
        print(f"Step: {k}  loss: {loss.item():.4f}  acc: {acc:.2f}  ")

        step = k + 1
        if step % save_every == 0 or step == train_steps:
            torch.save(modelS.state_dict(), osp.join(args.output_dir, f'modelS_{step}.pth'))
            print(f"✅ Saved modelS at step {step}")


if __name__ == '__main__':
    main()