import argparse
import torch
import numpy as np
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

# import sys
# sys.path.append('/home/is/sota-ku/ASVIB/fvib_reg/nonlinear_IB_PyTorch/src')

from network import ResNetForCifar, FCNet
from dataset import OccludedCifar
from utils import calc_mutual_info, logits_labels

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


# Command-line interface. The number of epochs, batch size and learning rate
# are fixed in the script body rather than exposed as flags.
parser = argparse.ArgumentParser()
parser.add_argument('path', type=str,
                    help="path where the files are saved")
# NOTE(review): --n_trial is parsed but not read anywhere in the visible script.
parser.add_argument('--n_trial', type=int, default=5,
                    help="num of trials")
parser.add_argument('--alpha', type=float, default=0.0,
                    help="alpha in label smoothing")

args = parser.parse_args()

def _train_one_epoch(net, loader, criterion, optimizer, data_processing, device):
    """Run one optimisation pass of ``net`` over ``loader``.

    Returns the unweighted mean of the per-batch losses, accumulated as
    ``sum(loss_i / len(loader))``.
    """
    net.train()
    if data_processing is not None:
        # The preprocessing module is frozen: eval mode, no gradients.
        data_processing.eval()
    n_batches = len(loader)
    mean_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if data_processing is not None:
            with torch.no_grad():
                inputs = data_processing(inputs)
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()
        mean_loss += loss.item() / n_batches
    return mean_loss


def _evaluate(net, loader, criterion, data_processing, device):
    """Evaluate ``net`` on ``loader`` without tracking gradients.

    Returns ``(correct, total, mean_loss)``: the number of correct argmax
    predictions, the number of samples seen, and the unweighted mean of the
    per-batch losses under ``criterion``.
    """
    net.eval()
    if data_processing is not None:
        data_processing.eval()
    n_batches = len(loader)
    correct = total = 0
    mean_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if data_processing is not None:
                inputs = data_processing(inputs)
            out = net(inputs)
            mean_loss += criterion(out, labels).item() / n_batches
            correct += (out.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct, total, mean_loss


def train(n_epochs, net, train_loader, test_loader, alpha, path, optimizer, scheduler=None, data_processing=None, device=device):
    """Train ``net`` with (label-smoothed) cross-entropy and save the history.

    After every epoch the network is evaluated on ``test_loader``. When
    training finishes the following are written under ``path``, which is
    created if needed:

    - ``acc.npy``: test accuracy (percent) per epoch,
    - ``train_ce_loss.npy``: mean training loss per epoch,
    - ``test_ce_loss.npy``: mean test loss per epoch,
    - ``weight.pth``: the final ``state_dict`` of ``net``.

    Args:
        n_epochs: number of training epochs.
        net: model to train; it is moved in place by the optimizer.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        test_loader: iterable of ``(inputs, labels)`` evaluation batches.
        alpha: ``label_smoothing`` passed to ``nn.CrossEntropyLoss``. It is
            applied to the test loss as well.
        path: output directory.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once per epoch.
        data_processing: optional frozen module applied to the inputs
            (in eval mode, without gradients) before they reach ``net``.
        device: device the batches are moved to.

    An empty ``test_loader`` yields NaN accuracies instead of raising.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=alpha)
    os.makedirs(path, exist_ok=True)

    accs, train_ces, test_ces = [], [], []
    for _ in tqdm(range(n_epochs)):
        train_ces.append(
            _train_one_epoch(net, train_loader, criterion, optimizer, data_processing, device))
        if scheduler is not None:
            scheduler.step()

        correct, total, test_ce = _evaluate(net, test_loader, criterion, data_processing, device)
        test_ces.append(test_ce)
        # Per-epoch accuracy; guard the division so an empty loader cannot crash.
        accs.append(100 * correct / total if total else float('nan'))

    np.save(os.path.join(path, "acc"), np.array(accs, dtype=float))
    np.save(os.path.join(path, "train_ce_loss"), np.array(train_ces, dtype=float))
    np.save(os.path.join(path, "test_ce_loss"), np.array(test_ces, dtype=float))
    torch.save(net.state_dict(), os.path.join(path, 'weight.pth'))
    
    
    
    
# Fix the RNG seeds of numpy, Python's `random` and torch for reproducibility.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

batch_size = 128

# Root of the pre-downloaded CIFAR-10 data (download=False below). The
# CIFAR10_ROOT environment variable overrides the author's machine-specific
# default, so the script can run elsewhere without editing the source.
data_path = os.environ.get('CIFAR10_ROOT', '/Volumes/csbdeep15/sota_ku/cifar10/')

# Shared by the train and test pipelines: shifts/scales each channel by 0.5.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: padded random crop + horizontal flip augmentation.
train_transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.CIFAR10(root=data_path, train=True,
                                        download=False, transform=train_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True, num_workers=4)

# Evaluation pipeline: no augmentation.
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

testset = torchvision.datasets.CIFAR10(root=data_path, train=False,
                                       download=False, transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=4)

# Training images without augmentation and in a fixed order; used below for
# the mutual-information estimate on the training split.
no_aug_trainset = torchvision.datasets.CIFAR10(root=data_path, train=True,
                                               download=False, transform=test_transform)
no_aug_train_loader = torch.utils.data.DataLoader(no_aug_trainset, batch_size=batch_size,
                                                  shuffle=False, num_workers=4)

n = 9  # corresponds to a 56-layer ResNet

n_epochs = 160
lr = 0.1
net = ResNetForCifar(layers=[2*n, 2*n, 2*n], num_classes=10).to(device)
# NOTE(review): no weight decay is applied here; confirm this is intended.
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
# LR is multiplied by 0.1 at epochs 80 and 120.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)
train(n_epochs, net, train_loader, test_loader, args.alpha, args.path, optimizer, scheduler)


# Mutual-information estimates on the (un-augmented) training set and the test set.
train_logits, train_labels = logits_labels(no_aug_train_loader, net, device)
test_logits, test_labels = logits_labels(test_loader, net, device)

train_i_x_t, train_i_t_y = calc_mutual_info(train_logits, train_labels)
test_i_x_t, test_i_t_y = calc_mutual_info(test_logits, test_labels)

# Each file stores [i_x_t, i_t_y]; presumably I(X;T) and I(T;Y) going by the
# names — confirm against utils.calc_mutual_info.
train_mi_array = np.array([train_i_x_t.cpu().item(), train_i_t_y.cpu().item()])
test_mi_array = np.array([test_i_x_t.cpu().item(), test_i_t_y.cpu().item()])

np.save(os.path.join(args.path, "train_mi.npy"), train_mi_array)
np.save(os.path.join(args.path, "test_mi.npy"), test_mi_array)