import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from options import args_parser
from model import vie

from torchvision.models import resnet18
from torchmetrics import StructuralSimilarityIndexMeasure

from utils.cifar10_dataset import CIFAR10withRepMask, CIFAR10SubsetWithRepMask
from utils.defense import defense

""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=[30,80], gamma=0.5, T_max=10, lr_mul=0.001, d_model=10, n_warmup_steps=5):
    """Build a learning-rate scheduler for ``optimizer`` selected by name.

    Args:
        scheduler: One of ``"step"``, ``"cosine"`` or ``"exponential"``.
        optimizer: The optimizer whose learning rate is scheduled.
        gamma: Decay factor for the ``"step"`` and ``"exponential"`` schedulers.
        T_max: Passed to ``CosineAnnealingLR`` for the ``"cosine"`` scheduler.
        milestones, lr_mul, d_model, n_warmup_steps: Accepted but not used
            by any branch of this function.

    Returns:
        The scheduler instance, or ``None`` when ``scheduler`` matches none
        of the supported names.
    """
    lr_sched = torch.optim.lr_scheduler

    if scheduler == "step":
        # The step size is fixed at 30 epochs; ``milestones`` is not consulted.
        return lr_sched.StepLR(optimizer, 30, gamma=gamma)
    if scheduler == "cosine":
        return lr_sched.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return lr_sched.ExponentialLR(optimizer, gamma=gamma)
    return None


def main():
    """Train and evaluate a decoder that reconstructs CIFAR-10 images from head features.

    Steps, as implemented below:
      1. Load a ResNet-18 checkpoint and split the network into ``head``
         (first three children), ``encoder`` (next three children) and
         ``clf`` (remaining layers, flatten, and the final linear layer).
      2. Print the classification accuracy of the split model with the
         selected defense applied to the head output.
      3. For ``epoch_max`` epochs, train only ``decoder`` to minimise the MSE
         between its output and the un-normalised input image, then print
         validation MSE/SSIM and save example reconstructions as JPEGs.

    All configuration comes from ``args_parser()`` plus the constants at the
    top of this function; nothing is returned.
    """
    args = args_parser()
    lr = 1e-2
    epoch_max = 500
    bs = 32
    num_class = 10
    num_vis = 10  # number of (original, reconstruction) pairs per saved figure

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    # defense_level is only used to tag the saved figure filenames.
    if args.defense_method == "laplacian_noise":
        defense_level = args.laplacian_noise
    elif args.defense_method in ("compress", "soteria"):
        # NOTE(review): "soteria" reuses args.compress as its level, exactly as
        # before — confirm this is intended rather than a copy/paste slip.
        defense_level = args.compress
    else:
        defense_level = 0

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    ckp = torch.load(
        "/home/js905/code/cut_the_chain/saved_models_adv/cifar10_resnet18_lambda0_init.pt",
        map_location=device,
    )
    # Alternative checkpoint used in other experiments:
    # ckp = torch.load("/home/js905/code/cut_the_chain/saved_models_mi/NG/cifar10_resnet18_head3layer_NG{:.0f}.pt".format(args.laplacian_noise))

    net = resnet18()
    net.fc = nn.Linear(net.fc.in_features, num_class)

    # Attack model: one transposed convolution mapping the 64-channel head
    # output back to a 3-channel image squashed into [0, 1].
    decoder = nn.Sequential(
        nn.ConvTranspose2d(64, 3, 8, 2, 3),
        nn.Sigmoid(),
    )

    class Flatten(nn.Module):
        """Reshape a batch to (batch, -1) so it can feed the linear layer."""

        def forward(self, x):
            return x.view(x.size(0), -1)

    # The checkpoint stores the first six ResNet children under 'encoder', so
    # the weights are loaded BEFORE that stack is split into head / encoder
    # (splitting re-indexes the Sequential's state-dict keys).
    children = list(net.children())
    encoder = nn.Sequential(*children[:-4])
    encoder.load_state_dict(ckp['encoder'])

    front = list(encoder.children())
    head = nn.Sequential(*front[:3])
    encoder = nn.Sequential(*front[3:])

    clf = nn.Sequential(*children[-4:-1], Flatten(), children[-1])
    clf.load_state_dict(ckp['clf'])

    for module in (encoder, clf, head, decoder):
        module.to(device)

    # Only the decoder is ever stepped. optim_head exists solely so that the
    # gradients backward() deposits on the head can be cleared each iteration.
    optim_head = optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    optim_decoder = optim.SGD(decoder.parameters(), lr=lr, momentum=0.9)

    # Classification data is normalised to [-1, 1]; attack data stays in
    # [0, 1] so it is directly comparable with the decoder's sigmoid output.
    transform_clf = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    transform_attack = transforms.Compose([transforms.ToTensor()])
    to_pil = transforms.ToPILImage()

    train_dataset = CIFAR10withRepMask(
        args,
        root="../data/cifar10",
        dev_root="/home/js905/code/cut_the_chain/src/train_dataset_dev.npy",
        train=True,
        download=True,
        transform=transform_clf,
    )
    valid_dataset = CIFAR10withRepMask(
        args,
        root="../data/cifar10",
        dev_root="/home/js905/code/cut_the_chain/src/valid_dataset_dev.npy",
        train=False,
        download=True,
        transform=transform_clf,
    )
    train_dataset_attack = CIFAR10SubsetWithRepMask(
        args,
        root="../data/cifar10",
        dev_root="/home/js905/code/cut_the_chain/src/train_dataset_attack_dev.npy",
        train=True,
        download=True,
        transform=transform_attack,
        num_sample=args.num_mc_sample,
    )
    valid_dataset_attack = CIFAR10withRepMask(
        args,
        root="../data/cifar10",
        dev_root="/home/js905/code/cut_the_chain/src/valid_dataset_attack_dev.npy",
        train=False,
        download=True,
        transform=transform_attack,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=2
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=bs, shuffle=False, num_workers=2
    )
    train_loader_attack = torch.utils.data.DataLoader(
        train_dataset_attack, batch_size=bs, shuffle=True, num_workers=2
    )
    valid_loader_attack = torch.utils.data.DataLoader(
        valid_dataset_attack, batch_size=bs, shuffle=False, num_workers=2
    )

    # Built once and kept on the same device as the data it is fed.
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    def mean_or_nan(values):
        """Arithmetic mean of ``values``; NaN instead of a crash when empty."""
        return sum(values) / len(values) if values else float("nan")

    def defended_rep(inputs, masks):
        """Run ``inputs`` through the head and apply the configured defense."""
        if args.defense_method == "soteria":
            # Gradients w.r.t. the inputs are enabled for this defense, which
            # is also why the evaluation loops do not use torch.no_grad().
            inputs.requires_grad = True
        rep = head(inputs)
        return defense(args, rep, inputs, head, mask=masks)

    def valid_rec(data_loader):
        """Return (mean per-batch MSE, mean per-batch SSIM) of reconstructions."""
        mse_list, ssim_list = [], []
        for inputs, _, masks in data_loader:
            inputs, masks = inputs.to(device), masks.to(device)
            outputs = decoder(defended_rep(inputs, masks))
            mse_list.append(F.mse_loss(outputs, inputs).item())
            # Calling the metric returns the value for this batch; detach so
            # its accumulated state never holds on to the autograd graph.
            batch_ssim = ssim_metric(outputs.detach(), inputs.detach())
            ssim_list.append(batch_ssim.item())
        ssim_metric.reset()
        return mean_or_nan(mse_list), mean_or_nan(ssim_list)

    def valid(data_loader):
        """Return top-1 accuracy of head -> defense -> encoder -> clf."""
        correct, total = 0, 0
        for inputs, labels, masks in data_loader:
            inputs, labels, masks = inputs.to(device), labels.to(device), masks.to(device)
            outputs = clf(encoder(defended_rep(inputs, masks)))
            pred_label = outputs.detach().argmax(dim=1)
            total += inputs.size(0)
            correct += (pred_label == labels).sum().item()
        return correct / float(total)

    def save_reconstructions(data_loader, filename):
        """Save originals next to reconstructions for the loader's first batch."""
        batch = next(iter(data_loader), None)
        if batch is None:
            return
        inputs, _, masks = batch
        inputs, masks = inputs.to(device), masks.to(device)
        rec_img = decoder(defended_rep(inputs, masks)).detach().cpu()
        inputs = inputs.detach().cpu()

        fig = plt.figure(figsize=(2, 12))
        # Never index past the end of a batch smaller than num_vis.
        for i in range(min(num_vis, inputs.size(0))):
            plt.subplot(num_vis, 2, i * 2 + 1)
            plt.imshow(to_pil(inputs[i]))
            plt.axis('off')
            plt.subplot(num_vis, 2, i * 2 + 2)
            plt.imshow(to_pil(rec_img[i]))
            plt.axis('off')
        plt.savefig(filename)
        # Close explicitly: two figures per epoch would otherwise pile up.
        plt.close(fig)

    train_fig_path = "img/head3layer_subset/mi_1layer_train_{}_{}_{}.jpg".format(args.num_mc_sample, args.defense_method, defense_level)
    valid_fig_path = "img/head3layer_subset/mi_1layer_valid_{}_{}_{}.jpg".format(args.num_mc_sample, args.defense_method, defense_level)
    os.makedirs(os.path.dirname(train_fig_path), exist_ok=True)

    # The classifier pipeline is frozen and used in inference mode throughout.
    clf.eval()
    head.eval()
    encoder.eval()

    train_acc = valid(train_loader)
    val_acc = valid(valid_loader)
    print(f"train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}")

    for e in range(epoch_max):
        mse_list = []
        for inputs, _, masks in train_loader_attack:
            inputs, masks = inputs.to(device), masks.to(device)

            optim_decoder.zero_grad()
            optim_head.zero_grad()

            rec_img = decoder(defended_rep(inputs, masks))
            loss = F.mse_loss(rec_img, inputs)
            loss.backward()
            optim_decoder.step()

            mse_list.append(loss.item())

        # Both metrics use the un-normalised attack loader so that they are
        # measured on the same [0, 1] images the decoder is trained to produce.
        val_mse, val_ssim = valid_rec(valid_loader_attack)
        print(
            f"Epoch: {e}. train_mse: {mean_or_nan(mse_list):.4f}, val_mse: {val_mse:.4f}, val_ssim: {val_ssim:.4f}"
        )

        save_reconstructions(train_loader_attack, train_fig_path)
        save_reconstructions(valid_loader_attack, valid_fig_path)
        #torch.save({"head": head.state_dict(), "decoder": decoder.state_dict()}, args.save_models_filename)




if __name__ == "__main__":
    main()
