import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

from options import args_parser
from model import vie
import matplotlib.pyplot as plt

from torchvision.models import resnet18
from torchmetrics import StructuralSimilarityIndexMeasure

from utils.cifar10_dataset import CIFAR10SubSet
from utils.cifar100_dataset import CIFAR100SubSet
from utils.attack import TV, l2loss

""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=(30, 80), gamma=0.5, T_max=10, lr_mul=0.001, d_model=10, n_warmup_steps=5, step_size=30):
    """Build a learning-rate scheduler for *optimizer* selected by name.

    Args:
        scheduler: One of "step", "cosine" or "exponential".
        optimizer: The optimizer whose learning rate is scheduled.
        milestones: Currently unused; kept for backward compatibility. The
            default is an immutable tuple so it can never be shared/mutated
            across calls.
        gamma: Multiplicative decay factor for the "step" and "exponential"
            schedulers.
        T_max: Maximum number of iterations for the "cosine" schedule
            (CosineAnnealingLR's ``T_max``).
        lr_mul, d_model, n_warmup_steps: Currently unused; kept for backward
            compatibility.
        step_size: Decay period of the "step" scheduler (previously fixed at 30).

    Returns:
        The constructed scheduler, or ``None`` when *scheduler* is not one of
        the recognized names.
    """
    if scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=gamma)
    if scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    # Unknown names fall through to None (same as before), so callers can treat
    # e.g. scheduler="none" as "no scheduling".
    return None


# Path template of the checkpoint holding the trained "head" and "vae" state
# dicts; formatted with (args.dataset, args.MID_rate).
_CKPT_TEMPLATE = (
    "/home/js905/code/cut_the_chain/saved_models_mi/{}/MID/"
    "cifar10_resnet18_head3layer_MID{}_res_softplus30.pt"
)


def _defense_level(args):
    """Return the strength parameter of the configured defense (0 if none)."""
    method = args.defense_method
    if method == "laplacian_noise":
        return args.laplacian_noise
    if method in ("compress", "soteria"):
        # soteria reuses the --compress value as its level.
        return args.compress
    if method == "adv":
        return args.l_adv
    if method == "MID":
        return args.MID_rate
    return 0


def _load_models(args, num_class, device):
    """Build the ResNet-18 head and `vie` bottleneck, load checkpoint weights,
    and return them as (head, vae) on *device*, in eval mode, frozen."""
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
    ckp = torch.load(
        _CKPT_TEMPLATE.format(args.dataset, args.MID_rate), map_location=device
    )

    net = resnet18()
    net.fc = nn.Linear(net.fc.in_features, num_class)
    # The head is the first `num_head_layer` top-level children of ResNet-18.
    head = nn.Sequential(*list(net.children())[: args.num_head_layer])
    head.load_state_dict(ckp["head"])

    vae = vie(feature_volume=64 * 8 * 8, z_size=64 * 8 * 8)
    vae.load_state_dict(ckp["vae"])

    for model in (head, vae):
        model.to(device)
        model.eval()
        # Only the dummy input is optimized; freezing the weights stops every
        # backward() from computing and accumulating unused parameter grads.
        model.requires_grad_(False)
    return head, vae


def _build_train_loader(dataset, bs):
    """Return a shuffled DataLoader over the CIFAR-100 (dataset == "cifar100")
    or CIFAR-10 training split, downloading it if necessary."""
    # ToTensor only (no normalization): pixels stay in [0, 1], matching the
    # SSIM data_range=1.0 and the [0, 1] display of reconstructions.
    transform = transforms.Compose([transforms.ToTensor()])
    if dataset == "cifar100":
        dataset_cls, root = datasets.CIFAR100, "../data/cifar100"
    else:
        dataset_cls, root = datasets.CIFAR10, "../data/cifar10"
    train_dataset = dataset_cls(root=root, train=True, download=True, transform=transform)
    return DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=2)


def _invert_batch(head, rep_gt, inputs, evaluater, lambda_tv, lambda_l2,
                  batch_idx=0, n_iters=5000, log_every=100):
    """Reconstruct *inputs* from the observed representation *rep_gt*.

    A dummy image drawn from N(0, 1) is optimized with L-BFGS to minimize
    mean((head(dummy) - rep_gt) ** 2) + lambda_tv * TV(dummy)
    + lambda_l2 * l2loss(dummy).

    Returns:
        (dummy, mse, ssim): the detached reconstruction, plus the image-space
        MSE and SSIM of relu(dummy) against *inputs* at the last logged step
        (None for both if n_iters == 0).
    """
    dummy = torch.randn_like(inputs).requires_grad_(True)
    optimizer = torch.optim.LBFGS([dummy])

    def closure():
        optimizer.zero_grad()
        rep_dummy = head(dummy)
        mse_loss = ((rep_dummy - rep_gt) ** 2).mean()
        loss = mse_loss + lambda_tv * TV(dummy) + lambda_l2 * l2loss(dummy)
        loss.backward()
        return loss

    mse_img = ssim = None
    for it in range(n_iters):
        optimizer.step(closure)
        # Log periodically and always on the final iteration, so the returned
        # metrics describe the final reconstruction.
        if (it + 1) % log_every == 0 or it + 1 == n_iters:
            with torch.no_grad():
                rec = F.relu(dummy)
                mse_img = F.mse_loss(rec, inputs)
                ssim = evaluater(rec, inputs)  # torchmetrics order: (preds, target)
            print("batch {}, iteration {}: mse {}, ssim {}".format(batch_idx, it, mse_img.item(), ssim.item()))
    return dummy.detach(), mse_img, ssim


def _to_display(img):
    """Return a detached CPU copy of *img* clamped to [0, 1]."""
    return img.detach().clamp(0, 1).cpu()


def _save_comparison(inputs, rec_img, path, n_show=10):
    """Save a two-column figure: originals (left) next to reconstructions (right)."""
    to_pil = transforms.ToPILImage()
    # ToPILImage converts float images via x * 255 -> uint8 without clamping,
    # so out-of-range reconstruction values would render as garbage pixels.
    originals = _to_display(inputs)
    recs = _to_display(rec_img)
    n_show = min(n_show, originals.size(0))

    fig = plt.figure(figsize=(2, 12))
    for i in range(n_show):
        for col, img in enumerate((originals[i], recs[i]), start=1):
            plt.subplot(n_show, 2, 2 * i + col)
            plt.imshow(to_pil(img))
            plt.axis("off")

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    fig.savefig(path)
    plt.close(fig)


def main():
    """Run the representation-inversion test on one training batch.

    Loads the checkpointed ResNet-18 head and `vie` bottleneck. Passes one
    shuffled CIFAR training batch through them to get the target
    representation, then inverts it with L-BFGS while printing MSE/SSIM.
    Finally saves an original-vs-reconstruction figure under
    img_mle/<dataset>/head3layer/.
    """
    args = args_parser()
    bs = 32
    num_class = 100 if args.dataset == "cifar100" else 10
    defense_level = _defense_level(args)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    head, vae = _load_models(args, num_class, device)
    train_loader = _build_train_loader(args.dataset, bs)
    # Metrics are modules with state tensors; keep them on the inputs' device.
    evaluater = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    # The test attacks a single shuffled training batch.
    batch_idx = 0
    inputs, _ = next(iter(train_loader))
    inputs = inputs.to(device)

    with torch.no_grad():
        # Target representation: head output passed through the vie bottleneck.
        _, rep_gt = vae(head(inputs))

    # NOTE(review): the attack matches head(dummy) directly against the vie
    # output (a vae pass on the dummy path was commented out) — confirm intended.
    rec_img, mse_img, ssim = _invert_batch(
        head, rep_gt, inputs, evaluater, args.lambda_TV, args.lambda_l2,
        batch_idx=batch_idx,
    )
    print("batch {}, final: mse {}, ssim {}".format(batch_idx, mse_img.item(), ssim.item()))

    _save_comparison(
        inputs,
        rec_img,
        "img_mle/{}/head3layer/mle_train_{}_{}.jpg".format(args.dataset, args.defense_method, defense_level),
    )



if __name__ == "__main__":
    main()
