import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from options import args_parser

from model.ResNet_vfl import resnet18

""" TESTING on CIFAR-10"""

def get_schedulers(
    scheduler,
    optimizer,
    milestones=(30, 80),
    gamma=0.5,
    T_max=10,
    lr_mul=0.001,
    d_model=10,
    n_warmup_steps=5,
    step_size=30,
):
    """Build a ``torch.optim.lr_scheduler`` for *optimizer* selected by name.

    Args:
        scheduler: One of ``"step"``, ``"multistep"``, ``"cosine"`` or
            ``"exponential"``.
        optimizer: The optimizer whose learning rate the scheduler adjusts.
        milestones: Step indices at which ``"multistep"`` multiplies the LR
            by *gamma*. A tuple default avoids a shared mutable default.
        gamma: Multiplicative decay factor used by ``"step"``,
            ``"multistep"`` and ``"exponential"``.
        T_max: ``T_max`` argument of ``CosineAnnealingLR`` (``"cosine"``).
        lr_mul, d_model, n_warmup_steps: Accepted only for backward
            compatibility; none of the schedulers built here uses them.
        step_size: Period of ``StepLR`` (``"step"``). It defaults to 30, the
            value that used to be hard-coded.

    Returns:
        The constructed scheduler instance.

    Raises:
        ValueError: If *scheduler* is not one of the recognised names. This
            previously returned ``None`` silently, which only failed later
            when ``.step()`` was called.
    """
    if scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=gamma)
    if scheduler == "multistep":
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer, list(milestones), gamma=gamma
        )
    if scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; expected one of "
        "'step', 'multistep', 'cosine', 'exponential'"
    )


def _accuracy(model, data_loader, device):
    """Return the top-1 accuracy of *model* over every batch in *data_loader*.

    Puts *model* in eval mode and leaves it there; training code must call
    ``model.train()`` again afterwards. Returns 0.0 for an empty loader
    instead of dividing by zero.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            pred_label = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (pred_label == labels).sum().item()
    return correct / total if total else 0.0


def _build_loaders(batch_size, root="../data/cifar10_vertical", num_workers=2):
    """Return ``(train_loader, valid_loader)`` for CIFAR-10 stored under *root*.

    Both splits use ToTensor followed by Normalize(mean=0.5, std=0.5) per
    channel. The training split gets no augmentation. Downloads the dataset
    if it is missing.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    train_dataset = datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform
    )
    valid_dataset = datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, valid_loader


def _train_one_epoch(net, loader, optimizer, criterion, device, epoch):
    """Run one optimisation pass of *net* over *loader*; return the summed batch loss.

    *epoch* is used only for the progress-bar label.
    """
    # _accuracy() leaves the network in eval mode, so switch back once per
    # epoch rather than on every batch.
    net.train()
    epoch_loss = 0.0
    pbar = tqdm(loader, desc="Epoch {}".format(epoch))
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        # Call the module, not .forward(), so registered hooks still run.
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss


def main():
    """Train ``resnet18(10)`` on CIFAR-10 and print per-epoch loss and accuracy.

    Despite the earlier "SplitNN" label, this script trains a single network
    with a single SGD optimizer; nothing is split between sites here.
    """
    # The result is currently unused. The call is kept so the script still
    # goes through the shared CLI parser (presumably validating flags / --help).
    args = args_parser()  # noqa: F841
    lr = 1e-2
    epoch_max = 1000
    bs = 32

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    net = resnet18(10).to(device)
    optim_net = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    # Halve the learning rate every 30 epochs (StepLR(optim_net, 30, gamma=0.5)).
    scheduler_net = get_schedulers("step", optim_net, gamma=0.5)

    train_loader, valid_loader = _build_loaders(bs)

    for e in range(epoch_max):
        epoch_loss = _train_one_epoch(
            net, train_loader, optim_net, criterion, device, e + 1
        )
        scheduler_net.step()

        # Average over the real number of batches. The last batch may be
        # partial, so this is ceil(N / bs) rather than int(50000 / 32) = 1562.
        epoch_len = max(len(train_loader), 1)
        # Note: this is a full extra pass over the training set every epoch.
        train_acc = _accuracy(net, train_loader, device)
        val_acc = _accuracy(net, valid_loader, device)
        print(
            f"Epoch {e + 1}/{epoch_max}. loss: {epoch_loss / epoch_len:.4f}, "
            f"train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}"
        )


if __name__ == "__main__":
    main()
