import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

import matplotlib.pyplot as plt

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

from options import args_parser
from model import vie

from torchvision.models import resnet18

from utils.cifar10_dataset import CIFAR10SubSet

""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=(30, 80), gamma=0.5, T_max=10,
                   lr_mul=0.001, d_model=10, n_warmup_steps=5, step_size=30):
    """Build a learning-rate scheduler for ``optimizer`` selected by name.

    Args:
        scheduler: One of ``"step"``, ``"cosine"`` or ``"exponential"``.
        optimizer: The optimizer whose learning rate is scheduled.
        milestones: Currently unused; kept for backward compatibility
            (an immutable tuple so the default cannot be mutated across calls).
        gamma: Multiplicative decay factor for ``"step"`` and ``"exponential"``.
        T_max: ``T_max`` passed to ``CosineAnnealingLR`` for ``"cosine"``.
        lr_mul, d_model, n_warmup_steps: Currently unused; kept for backward
            compatibility with existing callers.
        step_size: Period of the ``"step"`` schedule (defaults to the value
            that was previously hard-coded, 30).

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: If ``scheduler`` is not a recognised name. Previously the
            function silently returned ``None`` in that case.
    """
    if scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=gamma)
    if scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; expected 'step', 'cosine' or 'exponential'"
    )


def _apply_defense(rep, args):
    """Return the representation that is handed to the decoder.

    When ``args.defense_method == "laplacian_noise"``, independent Laplace
    noise with location 0 and scale ``args.laplacian_noise`` is added
    element-wise to ``rep``. A scale of 0 is treated as "no noise", because
    ``Laplace`` rejects a zero scale. For any other defense method ``rep`` is
    returned unchanged. The previous inline code left ``rep_send`` unbound in
    that case and crashed.
    """
    if args.defense_method == "laplacian_noise" and args.laplacian_noise != 0:
        noise = (
            torch.distributions.laplace.Laplace(0, args.laplacian_noise)
            .sample(rep.shape)
            .to(rep.device)
        )
        return rep + noise
    return rep


def _valid_rec(head, decoder, data_loader, device, args):
    """Return the mean squared reconstruction error over ``data_loader``.

    Each batch is passed through ``head``, then the defense, then
    ``decoder``, and compared against the loader's inputs. The error is
    averaged over all elements of all batches, so a smaller final batch
    is weighted correctly. Previously the per-batch means were averaged
    unweighted.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    total_sq_err = 0.0
    total_elems = 0
    with torch.no_grad():
        for inputs, _ in data_loader:
            inputs = inputs.to(device)
            outputs = decoder(_apply_defense(head(inputs), args))
            total_sq_err += F.mse_loss(outputs, inputs, reduction="sum").item()
            total_elems += inputs.numel()
    if total_elems == 0:
        raise ValueError("data_loader yielded no samples")
    return total_sq_err / total_elems


def _save_reconstruction_grid(head, decoder, data_loader, device, args,
                              to_image, path, n_rows=10):
    """Save a two-column figure: original inputs (left) and reconstructions (right).

    Only the first batch of ``data_loader`` is used. Up to ``n_rows`` samples
    are shown, and each image is converted for display with ``to_image``.
    The figure is closed after saving so repeated calls do not accumulate
    open figures.
    """
    inputs, _ = next(iter(data_loader))
    with torch.no_grad():  # inference only; no autograd graph needed
        inputs = inputs.to(device)
        rec_img = decoder(_apply_defense(head(inputs), args))
    inputs, rec_img = inputs.cpu(), rec_img.cpu()

    fig = plt.figure(figsize=(2, 12))
    try:
        for i in range(min(n_rows, inputs.shape[0])):
            plt.subplot(n_rows, 2, i * 2 + 1)
            plt.imshow(to_image(inputs[i]))
            plt.axis('off')
            plt.subplot(n_rows, 2, i * 2 + 2)
            plt.imshow(to_image(rec_img[i]))
            plt.axis('off')
        fig.savefig(path)
    finally:
        plt.close(fig)


def main():
    """Evaluate decoder reconstructions of CIFAR-10 images from head representations.

    Loads ``head`` (the first five child modules of a torchvision ResNet-18)
    and ``decoder`` from a checkpoint chosen by ``args.laplacian_noise``.
    It then prints the validation reconstruction MSE and saves
    original/reconstruction grids for one training batch and one
    validation batch.
    """
    args = args_parser()
    bs = 32

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host;
    # the modules are moved to ``device`` after their weights are loaded.
    ckp = torch.load(
        "/home/js905/code/cut_the_chain/saved_models_mi/3layer_decoder_NG_{:.0f}.pt".format(args.laplacian_noise),
        map_location="cpu",
    )

    # Only the first five children of ResNet-18 are used as the head; the
    # classifier (fc) is never part of it, so it is not replaced here.
    head = nn.Sequential(*list(resnet18().children())[:5])

    # Each ConvTranspose2d below doubles the spatial size:
    # (in - 1) * 2 - 2 * 1 + 4 == 2 * in  and  (in - 1) * 2 - 2 * 3 + 8 == 2 * in.
    # The 64 input channels must match the head's output channels.
    decoder = nn.Sequential(
        nn.ConvTranspose2d(64, 32, 4, 2, 1),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(),
        nn.ConvTranspose2d(32, 3, 8, 2, 3),
        nn.Sigmoid(),
    )

    head.load_state_dict(ckp["head"])
    decoder.load_state_dict(ckp["decoder"])
    head.to(device)
    decoder.to(device)
    # Both modules contain BatchNorm; eval mode uses the stored running
    # statistics and stops evaluation from mutating them.
    head.eval()
    decoder.eval()

    # Maps each channel from [0, 1] (ToTensor) to [-1, 1].
    normalize = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # Inverse of the normalisation above ([-1, 1] -> [0, 1]), then to PIL for display.
    # NOTE(review): the decoder ends in Sigmoid (output in [0, 1]) but is compared
    # against, and displayed like, [-1, 1]-normalised inputs. Confirm this matches
    # how the decoder was trained.
    tt = transforms.Compose(
        [transforms.Normalize((-1, -1, -1), (2, 2, 2)),
         transforms.ToPILImage()])

    train_dataset = datasets.CIFAR10(
        root="../data/cifar10", train=True, download=True, transform=normalize,
    )
    valid_dataset = datasets.CIFAR10(
        root="../data/cifar10", train=False, download=True, transform=normalize,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=2
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=bs, shuffle=False, num_workers=2
    )

    val_mse = _valid_rec(head, decoder, valid_loader, device, args)
    print(
        f"val_mse: {val_mse:.4f}"
    )

    _save_reconstruction_grid(
        head, decoder, train_loader, device, args, tt,
        "mi_1layer_train_{}_{}.jpg".format(args.defense_method, args.laplacian_noise),
    )
    _save_reconstruction_grid(
        head, decoder, valid_loader, device, args, tt,
        "mi_1layer_valid_{}_{}.jpg".format(args.defense_method, args.laplacian_noise),
    )









if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
