import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

from options import args_parser

from torchvision.models import resnet18

""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=(30, 80), gamma=0.5, T_max=10, lr_mul=0.001, d_model=10, n_warmup_steps=5):
    """Build a learning-rate scheduler selected by name.

    Args:
        scheduler: Name of the schedule: ``"step"``, ``"cosine"`` or
            ``"exponential"``.
        optimizer: Optimizer whose learning rate the scheduler controls.
        milestones: Accepted for call compatibility but not consulted by any
            branch. The default is an immutable tuple so it cannot be shared
            and mutated across calls.
        gamma: Decay factor passed to the ``"step"`` and ``"exponential"``
            schedulers.
        T_max: Passed to the ``"cosine"`` scheduler as its ``T_max``.
        lr_mul: Accepted for call compatibility; currently unused.
        d_model: Accepted for call compatibility; currently unused.
        n_warmup_steps: Accepted for call compatibility; currently unused.

    Returns:
        The constructed scheduler, or ``None`` when ``scheduler`` is not one
        of the recognized names.
    """
    lr_sched = torch.optim.lr_scheduler
    if scheduler == "step":
        # The step size is fixed at 30 scheduler steps; `milestones` is not used.
        return lr_sched.StepLR(optimizer, 30, gamma=gamma)
    if scheduler == "cosine":
        return lr_sched.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return lr_sched.ExponentialLR(optimizer, gamma=gamma)
    # Unknown names yield None, so callers must check before calling .step().
    return None


def main():
    """Evaluate a saved split ResNet-18 on CIFAR-10 and probe label leakage.

    Steps performed:
      1. Build a ResNet-18 with a ``num_class``-way final layer and split its
         child modules into ``encoder`` (all but the last four children) and
         ``clf`` (the next three children, a flatten step and the final
         linear layer).
      2. Restore ``clf`` and the MLP adversary ``clf_adv`` from a checkpoint.
      3. Print the train/validation accuracy of ``clf`` on top of ``encoder``.
      4. Fine-tune a copy of ``clf`` on detached encoder features for
         ``epoch_finetune`` epochs, printing mean loss and accuracy per epoch.

    Nothing is returned or saved; all results are written to stdout.
    """
    # The returned options are not consumed below; the call is kept because it
    # presumably parses/validates the command line (defined in `options`).
    args_parser()

    lr = 1e-2
    epoch_finetune = 10
    bs = 32
    num_class = 10
    ckpt_path = "/home/js905/code/vfl_quit/saved_models_adv/cifar10_resnet18_lr001_psi1_phi1_theta0.5.pt"

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    net = resnet18()
    net.fc = nn.Linear(net.fc.in_features, num_class)

    class Flatten(nn.Module):
        """Collapse every dimension after the batch dimension."""

        def forward(self, x):
            return x.view(x.size(0), -1)

    # Split the backbone: everything except the last four children feeds the
    # encoder; the remaining children (plus a flatten step) form the classifier.
    children = list(net.children())
    encoder = nn.Sequential(*children[:-4])
    clf = nn.Sequential(*children[-4:-1], Flatten(), children[-1])
    clf_adv = nn.Sequential(
        Flatten(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )

    # map_location lets a checkpoint saved from a GPU run be restored when
    # this script falls back to the CPU.
    ckp = torch.load(ckpt_path, map_location=device)
    # NOTE(review): the encoder weights are not restored from the checkpoint,
    # so the encoder keeps its freshly initialized parameters — confirm intended.
    encoder.to(device)
    clf.load_state_dict(ckp['clf'])
    clf.to(device)
    clf_adv.to(device)
    clf_adv.load_state_dict(ckp['clf_adv'])

    # Train and validation share the same preprocessing (no augmentation).
    transform_train = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    transform_valid = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    train_dataset = datasets.CIFAR10(
        root="../data/cifar10_vertical",
        train=True,
        download=True,
        transform=transform_train,
    )
    valid_dataset = datasets.CIFAR10(
        root="../data/cifar10_vertical",
        train=False,
        download=True,
        transform=transform_valid,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True, num_workers=2
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=bs, shuffle=False, num_workers=2
    )

    def valid(encoder, clf, data_loader, device):
        """Return the top-1 accuracy of ``clf(encoder(x))`` over ``data_loader``.

        Both modules are switched to eval mode and are left in that mode.
        """
        encoder.eval()
        clf.eval()

        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = clf(encoder(inputs))
                _, pred_label = torch.max(outputs, 1)

                total += inputs.size(0)
                correct += (pred_label == labels).sum().item()
        return correct / float(total)

    train_acc = valid(encoder, clf, train_loader, device)
    val_acc = valid(encoder, clf, valid_loader, device)
    print(f"origional adv train acc:{train_acc:.4f}, val acc:{val_acc:.4f}")

    # Fine-tune an adversary head on detached encoder features to check the
    # label leakage.
    # NOTE(review): the head being fine-tuned is a copy of `clf`; the restored
    # `clf_adv` is not used past this point — confirm which one was intended.
    clf_adv_ft = copy.deepcopy(clf)
    optim_clfadvft = optim.SGD(clf_adv_ft.parameters(), lr=lr, momentum=0.9)

    # Number of batches actually produced per epoch (includes the last,
    # possibly smaller, batch) so the reported mean loss is exact.
    epoch_len = len(train_loader)

    for e in range(epoch_finetune):
        # valid() leaves both modules in eval mode, so switch back each epoch.
        # NOTE(review): the encoder runs in train mode although it is never
        # optimized here; layers that track running statistics may still
        # update — confirm this is intended.
        encoder.train()
        clf_adv_ft.train()

        epoch_loss_adv = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optim_clfadvft.zero_grad()

            # Detach so no gradient reaches the encoder; only the head learns.
            feature = encoder(inputs)
            feature_sent = feature.detach().requires_grad_()
            pred_adv = clf_adv_ft(feature_sent)

            loss_adv = criterion(pred_adv, targets)

            # A fresh graph is built every batch, so it need not be retained.
            loss_adv.backward()
            optim_clfadvft.step()

            epoch_loss_adv += loss_adv.item()

        adv_train_acc = valid(encoder, clf_adv_ft, train_loader, device)
        adv_val_acc = valid(encoder, clf_adv_ft, valid_loader, device)

        print(f"Epoch: {e} Adversarial train loss: {epoch_loss_adv/epoch_len:.4f}, train acc: {adv_train_acc:.4f}, val acc: {adv_val_acc:.4f}")



# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger checkpoint loading or training.
if __name__ == "__main__":
    main()
