import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from options import args_parser
from model import vie

from torchvision.models import resnet18

from utils.cifar10_dataset import CIFAR10SubSet

""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=(30, 80), gamma=0.5, T_max=10, lr_mul=0.001, d_model=10, n_warmup_steps=5, step_size=30):
    """Build a learning-rate scheduler for ``optimizer`` by name.

    Args:
        scheduler: One of ``"step"``, ``"cosine"`` or ``"exponential"``.
        optimizer: The optimizer whose learning rate is scheduled.
        milestones: Currently unused; accepted for call compatibility.
            The default is a tuple so no mutable object is shared
            between calls.
        gamma: Decay factor passed to ``StepLR`` / ``ExponentialLR``.
        T_max: Passed to ``CosineAnnealingLR``.
        lr_mul: Currently unused; accepted for call compatibility.
        d_model: Currently unused; accepted for call compatibility.
        n_warmup_steps: Currently unused; accepted for call compatibility.
        step_size: Period (in scheduler steps) of the ``StepLR`` decay.
            Defaults to 30, the value that used to be hard-coded.

    Returns:
        The requested ``torch.optim.lr_scheduler`` instance, or ``None``
        when ``scheduler`` is not one of the recognised names.
    """
    if scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=gamma)
    if scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    # Unknown name: keep the historical behaviour of returning None, but make
    # it explicit so callers know they must handle it.
    return None


def _encode(head, vae, inputs):
    """Return the representation fed to the decoder: the second output of ``vae(head(inputs))``."""
    (_mu, _std), rep = vae(head(inputs))
    return rep


def _reconstruction_mse(head, vae, decoder, data_loader, device):
    """Return the mean over batches of the per-batch reconstruction MSE."""
    batch_mse = []
    with torch.no_grad():
        for inputs, _ in data_loader:
            inputs = inputs.to(device)
            outputs = decoder(_encode(head, vae, inputs))
            batch_mse.append(F.mse_loss(outputs, inputs).item())
    return sum(batch_mse) / len(batch_mse)


def _save_reconstruction_grid(head, vae, decoder, data_loader, device, to_pil, path, num_images=10):
    """Save original/reconstruction image pairs from the first batch of ``data_loader`` to ``path``."""
    batch = next(iter(data_loader), None)
    if batch is None:
        # Empty loader: nothing to draw.
        return

    inputs = batch[0].to(device)
    with torch.no_grad():
        rec_img = decoder(_encode(head, vae, inputs)).cpu()
    inputs = inputs.cpu()

    # savefig does not create missing directories.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(2, 12))
    try:
        # Left column: original image, right column: its reconstruction.
        for i in range(min(num_images, inputs.size(0))):
            plt.subplot(num_images, 2, i * 2 + 1)
            plt.imshow(to_pil(inputs[i]))
            plt.axis('off')
            plt.subplot(num_images, 2, i * 2 + 2)
            plt.imshow(to_pil(rec_img[i]))
            plt.axis('off')
        plt.savefig(path)
    finally:
        # Figures stay registered with pyplot until closed; this function is
        # called twice per epoch, so close explicitly to avoid leaking them.
        plt.close(fig)


def main():
    """Train a decoder to reconstruct CIFAR-10 inputs from checkpointed head + vae features.

    The ``head``, ``vae`` and ``clf`` weights are loaded from a checkpoint and
    are never updated; only ``decoder`` is optimised (MSE against the input
    image). After every epoch the train/validation MSE is printed and a grid
    of original/reconstructed images is written for one train batch and one
    validation batch.
    """
    args = args_parser()
    lr = 1e-2
    epoch_max = 20
    bs = 32
    num_class = 10

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    # map_location lets a checkpoint saved on GPU load on the CPU fallback.
    ckp = torch.load(
        "/home/js905/code/cut_the_chain/saved_models_mi/MID/cifar10_resnet18_head3layer_MID0.0001_res_softplus30.pt",
        map_location=device,
    )
    #ckp = torch.load("/home/js905/code/cut_the_chain/saved_models_mi/NG/cifar10_resnet18_head5layer_NG10.pt")

    net = resnet18()
    net.fc = nn.Linear(net.fc.in_features, num_class)

    # Single transposed convolution (64 -> 3 channels, kernel 8, stride 2,
    # padding 3) followed by a sigmoid.
    decoder = nn.Sequential(
        nn.ConvTranspose2d(64, 3, 8, 2, 3),
        nn.Sigmoid(),
    )

    class Flatten(nn.Module):
        """Flatten every dimension after the batch dimension."""

        def forward(self, x):
            return x.view(x.size(0), -1)

    # Split the ResNet: the first three children form the head; the last four
    # children (with a flatten before the final fc) form the classifier tail.
    children = list(net.children())
    head = nn.Sequential(*children[:3])
    vae = vie(feature_volume=64 * 8 * 8, z_size=64 * 8 * 8)
    clf = nn.Sequential(*children[-4:-1], Flatten(), children[-1])

    head.load_state_dict(ckp["head"])
    vae.load_state_dict(ckp["vae"])
    # clf is not used by the reconstruction loop below; loading it still
    # checks that the checkpoint's 'clf' weights match this split.
    clf.load_state_dict(ckp['clf'])

    for module in (head, vae, clf, decoder):
        module.to(device)

    clf.eval()
    head.eval()
    # NOTE(review): vae is left in its default (train) mode, as before.
    # Confirm whether `vie` behaves differently under eval() and which mode
    # this evaluation is meant to use.

    optim_decoder = optim.SGD(decoder.parameters(), lr=lr, momentum=0.9)

    # Inputs are only converted with ToTensor (no Normalize), and are used
    # directly as the MSE target for the sigmoid-terminated decoder.
    transform_train = transforms.Compose([transforms.ToTensor()])
    transform_valid = transforms.Compose([transforms.ToTensor()])
    to_pil = transforms.ToPILImage()

    train_dataset = datasets.CIFAR10(
        root="../data/cifar10",
        train=True,
        download=True,
        transform=transform_train,
    )

    # NOTE(review): not used below; kept because it was constructed before
    # and its construction may have side effects — confirm and remove.
    mc_train_dataset = CIFAR10SubSet(
        root="../data/cifar10",
        train=True,
        download=True,
        transform=transform_train,
        returns="all",
        num_sample=args.num_mc_sample
    )

    valid_dataset = datasets.CIFAR10(
        root="../data/cifar10",
        train=False,
        download=True,
        transform=transform_valid,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=2
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=bs, shuffle=False, num_workers=2
    )

    for e in range(epoch_max):
        mse_list = []
        for inputs, _ in train_loader:
            inputs = inputs.to(device)

            optim_decoder.zero_grad()

            # head and vae are never stepped, so skip building an autograd
            # graph through them; only the decoder needs gradients.
            with torch.no_grad():
                rep = _encode(head, vae, inputs)

            rec_img = decoder(rep)
            loss = F.mse_loss(rec_img, inputs)
            loss.backward()
            optim_decoder.step()

            mse_list.append(loss.item())

        val_mse = _reconstruction_mse(head, vae, decoder, valid_loader, device)
        print(
            f"Epoch: {e}. train_mse: {sum(mse_list)/len(mse_list):.4f}, val_mse: {val_mse:.4f}"
        )

        _save_reconstruction_grid(
            head, vae, decoder, train_loader, device, to_pil,
            "img/head3layer/mi_1layer_train_MID1e-4_res_softplus30.jpg",
        )
        _save_reconstruction_grid(
            head, vae, decoder, valid_loader, device, to_pil,
            "img/head3layer/mi_1layer_valid_MID1e-4_res_softplus30.jpg",
        )



# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
