import argparse
import torch
import numpy as np
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import os
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

# import sys
# sys.path.append('/home/is/sota-ku/ASVIB/fvib_reg/nonlinear_IB_PyTorch/src')

from network import ResNetForCifar, FCNet
from dataset import OccludedCifar

# Use the first CUDA device when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


# Command-line interface. Epoch count, batch size and learning rate are not
# exposed as options; they are fixed in the experiment code further down.
parser = argparse.ArgumentParser()

# Positional: root directory under which all outputs are written.
parser.add_argument('path', type=str, help="path where the files are saved")
# Number of independent repetitions of the experiment.
parser.add_argument('--n_trial', type=int, default=5, help="num of trials")
# Label-smoothing strength forwarded to the training loss.
parser.add_argument('--alpha', type=float, default=0., help="alpha in label smoothing")

args = parser.parse_args()

def train(n_epochs, net, train_loader, test_loader, alpha, path, optimizer, scheduler=None, data_processing=None, device=device):
    """Train ``net`` with cross-entropy, evaluate it after every epoch, and save the results.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        net: Model to optimise; called as ``net(inputs)``.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        test_loader: Sized iterable of ``(inputs, labels)`` evaluation batches.
        alpha: Label-smoothing factor given to ``nn.CrossEntropyLoss``. The
            same smoothed loss is used for the reported test loss.
        path: Output directory; created if it does not exist.
        optimizer: Optimizer that updates ``net``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        data_processing: Optional module applied to every input batch before
            ``net``. It is kept in eval mode and run without gradients, so it
            is never updated here.
        device: Device the batches are moved to.

    Side effects:
        Writes into ``path``:
        ``acc.npy`` (test accuracy in percent, one entry per epoch),
        ``train_ce_loss.npy`` and ``test_ce_loss.npy`` (mean batch loss per
        epoch), and ``weight.pth`` (``net.state_dict()`` after the last epoch).

    Returns:
        None.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=alpha)
    train_losses = []
    test_losses = []
    accuracies = []
    os.makedirs(path, exist_ok=True)

    # Loader lengths do not change between epochs; look them up once.
    n_train_batches = len(train_loader)
    n_test_batches = len(test_loader)

    for epoch in tqdm(range(n_epochs)):
        # ---- training phase ----
        net.train()
        if data_processing is not None:
            # Set after net.train() so the pre-processor stays in eval mode
            # even if it were to share sub-modules with ``net``.
            data_processing.eval()

        epoch_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if data_processing is not None:
                with torch.no_grad():
                    inputs = data_processing(inputs)

            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

            # .item() yields a plain float, so no autograd guard is needed.
            epoch_loss += loss.item() / n_train_batches

        if scheduler is not None:
            scheduler.step()
        train_losses.append(epoch_loss)

        # ---- evaluation phase ----
        net.eval()
        correct = 0
        total = 0
        epoch_test_loss = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                if data_processing is not None:
                    inputs = data_processing(inputs)
                out = net(inputs)

                epoch_test_loss += criterion(out, labels).item() / n_test_batches
                predicted = out.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Accuracy is computed per epoch so that ``total`` is never read
        # outside the loop (it was undefined there when n_epochs == 0).
        accuracies.append(100 * correct / total if total else float('nan'))
        test_losses.append(epoch_test_loss)

    # np.save appends the ".npy" suffix itself.
    np.save(os.path.join(path, "acc"), np.array(accuracies, dtype=float))
    np.save(os.path.join(path, "train_ce_loss"), np.array(train_losses, dtype=float))
    np.save(os.path.join(path, "test_ce_loss"), np.array(test_losses, dtype=float))
    torch.save(net.state_dict(), os.path.join(path, 'weight.pth'))
    
    
    
    
# Dataset roots shared by every loader.
CIFAR_ROOT = '/Volumes/csbdeep15/sota_ku/cifar10/'
MNIST_ROOT = '/Volumes/csbdeep15/sota_ku/mnist/'

batch_size = 128

n = 9 #corresponds to 56-layer resnet


def _make_loader(is_train, label_kind):
    """Build a DataLoader over ``OccludedCifar`` for one split and label type.

    Args:
        is_train: Passed as both the third and the fifth positional argument
            of ``OccludedCifar`` (they are always equal in this script).
            NOTE(review): presumably "train split" and a second train-only
            switch -- confirm against ``dataset.OccludedCifar``.
        label_kind: ``"cifar"`` or ``"mnist"``; forwarded unchanged to
            ``OccludedCifar``.
    """
    dataset = OccludedCifar(CIFAR_ROOT, MNIST_ROOT, is_train, label_kind, is_train)
    # NOTE(review): shuffling is not needed for the evaluation loaders; it is
    # left on so the behaviour of existing runs is not altered.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=4)


def main():
    """Run ``args.n_trial`` repetitions of the two-stage experiment.

    Stage 1 trains a ResNet on the "cifar" labels with label smoothing
    ``args.alpha``. Stage 2 reloads those weights, replaces the ``fc`` layer
    with an identity, and trains an ``FCNet`` on the "mnist" labels using the
    frozen network's output as input. Results are written below
    ``<args.path>/<alpha>/<trial>/{cifar,mnist}``.
    """
    # fix seed
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)

    c_train_loader = _make_loader(True, "cifar")
    c_test_loader = _make_loader(False, "cifar")
    m_train_loader = _make_loader(True, "mnist")
    m_test_loader = _make_loader(False, "mnist")

    for i in range(args.n_trial):
        trial_dir = os.path.join(args.path, str(args.alpha), str(i))

        # ---- stage 1: training with cifar label ----
        cifar_dir = os.path.join(trial_dir, "cifar")
        net = ResNetForCifar(layers=[2*n, 2*n, 2*n], num_classes=10).to(device)
        optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)
        train(160, net, c_train_loader, c_test_loader, args.alpha, cifar_dir, optimizer, scheduler)

        # ---- stage 2: training with mnist label ----
        # Reload the weights that stage 1 just saved into a fresh network.
        weight_file = os.path.join(cifar_dir, "weight.pth")
        data_processing = ResNetForCifar(layers=[2*n, 2*n, 2*n], num_classes=10).to(device)
        # map_location keeps the load working if the file was written on a
        # different device than the one in use now.
        data_processing.load_state_dict(torch.load(weight_file, map_location=device))
        # Drop the ``fc`` layer so the network passes through whatever feeds it
        # (assumes ``fc`` is the final classifier layer -- TODO confirm).
        data_processing.fc = nn.Identity().to(device)

        mnist_dir = os.path.join(trial_dir, "mnist")
        m_net = FCNet(hidden_dim=512).to(device)
        m_optimizer = optim.SGD(m_net.parameters(), lr=0.001, momentum=0.9)
        # No scheduler and no label smoothing (alpha = 0) for this stage.
        train(80, m_net, m_train_loader, m_test_loader, 0., mnist_dir, m_optimizer, None, data_processing)


# The guard is required because the loaders use worker processes: on platforms
# that start workers with "spawn" (Windows, macOS) every worker re-imports this
# file, and without the guard each of them would re-run the whole experiment.
if __name__ == "__main__":
    main()