import argparse
import logging
import os
import sys

import os.path as osp
from ofa.imagenet_classification import networks
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import time
import torch.optim as optim
from ofa.model_zoo import proxylessnas_mobile, proxylessnas_net
from ofa.utils.layers import LinearLayer
from ofa.utils import init_models
from torchvision.models import *

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../")))
from data_preprocessing.utils.data_loader import *
# from utils.memory_cost_profiler import profile_memory_cost
import model.network as network


# Command-line interface
def _build_parser():
    """Assemble the argument parser for the source-domain finetuning script."""
    cli = argparse.ArgumentParser(description='Finetune compact model on source domain.')
    cli.add_argument(
        '--model', type=str, default='SHOTresnet18', help="compact model (resnet)",
        choices=['SHOTresnet18', 'SHOTresnet34'])
    cli.add_argument('--batchsize', type=int, default=64)
    cli.add_argument('--src', type=str, default='amazon')
    cli.add_argument('--tar', type=str, default='webcam')
    cli.add_argument(
        '--dset', type=str, default='office',
        choices=['office', 'office-home', 'office-caltech', 'imageCLEF'])
    cli.add_argument('--lr', type=float, default=1e-4)
    cli.add_argument('--n_epoch', type=int, default=100)
    cli.add_argument('--momentum', type=float, default=0.9)
    # NOTE(review): --decay and --dropout are parsed but not read anywhere in
    # the visible part of this script.
    cli.add_argument('--decay', type=float, default=5e-4)
    cli.add_argument('--data', type=str, default='../data/')
    cli.add_argument('--early_stop', type=int, default=30)
    cli.add_argument('--dropout', type=float, default=0.2)
    cli.add_argument('--gpu', type=int, default=0)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Derived run-time parameters
DEVICE = torch.device(f'cuda:{args.gpu}')
BATCH_SIZE = dict.fromkeys(('src', 'tar'), int(args.batchsize))

# Number of classes for each supported dataset; `--dset` is restricted by
# `choices`, so the lookup always succeeds.
args.n_class = {
    'office-home': 65,
    'office': 31,
    'office-caltech': 10,
    'imageCLEF': 12,
}[args.dset]


def image_train(resize_size=256, crop_size=224):
    """Build the training-time image transform.

    The pipeline resizes to a ``resize_size`` x ``resize_size`` square, takes
    a random ``crop_size`` crop, flips horizontally at random, converts to a
    tensor and normalizes each channel with fixed mean/std constants.

    Args:
        resize_size: side length of the square the image is resized to.
        crop_size: size passed to ``transforms.RandomCrop``.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


def image_test(resize_size=256, crop_size=224):
    """Build the evaluation-time image transform (no random augmentation).

    The pipeline resizes to a ``resize_size`` x ``resize_size`` square, takes
    the center ``crop_size`` crop, converts to a tensor and normalizes each
    channel with the same fixed mean/std constants as ``image_train``.

    Args:
        resize_size: side length of the square the image is resized to.
        crop_size: size passed to ``transforms.CenterCrop``.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    steps = [
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


def load_test(root_path, dir, batch_size):
    """Create a data loader over one image folder using the test transform.

    Args:
        root_path: directory containing the domain folders.
        dir: name of the domain sub-folder (joined onto ``root_path``).
        batch_size: number of samples per batch.

    Returns:
        A ``DataLoader`` over ``ImageFolder(root_path/dir)`` with
        ``image_test()`` applied; batches are shuffled, the last partial
        batch is kept, and 4 worker processes are used.
    """
    folder = os.path.join(root_path, dir)
    dataset = datasets.ImageFolder(root=folder, transform=image_test())
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=4,
    )

def load_train(root_path, dir, batch_size, seed=2020):
    """Create train/validation loaders from a 90/10 split of one image folder.

    The folder is wrapped twice: once with the training transform and once
    with the test transform. Both wrappers are split with ``random_split``
    after reseeding the global torch RNG with ``seed``; the training loader
    uses the first part of the augmented dataset and the validation loader
    uses the second part of the non-augmented dataset.

    Note that calling this function reseeds the global torch RNG.

    Args:
        root_path: directory containing the domain folders.
        dir: name of the domain sub-folder (joined onto ``root_path``).
        batch_size: number of samples per batch for both loaders.
        seed: value passed to ``torch.manual_seed`` before each split.

    Returns:
        ``(train_loader, val_loader)``.
    """
    folder = os.path.join(root_path, dir)
    augmented = datasets.ImageFolder(root=folder, transform=image_train())
    plain = datasets.ImageFolder(root=folder, transform=image_test())

    n_train = int(0.9 * len(augmented))
    n_val = len(augmented) - n_train

    # Reseed before each split so both calls start from the same RNG state;
    # presumably this makes the two partitions line up index-for-index.
    torch.manual_seed(seed)
    train_subset, _ = torch.utils.data.random_split(augmented, [n_train, n_val])

    torch.manual_seed(seed)
    _, val_subset = torch.utils.data.random_split(plain, [n_train, n_val])

    loader_options = dict(batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4)
    train_loader = torch.utils.data.DataLoader(train_subset, **loader_options)
    val_loader = torch.utils.data.DataLoader(val_subset, **loader_options)
    return train_loader, val_loader


# Schedule learning rate
def lr_schedule(optimizer, epoch):
    """Set every parameter group's learning rate for the given epoch.

    The decayed rate is ``args.lr / (1 + 10 * epoch / args.n_epoch) ** 0.75``.
    All parameter groups receive that rate except the last one, which
    receives ten times as much.

    Args:
        optimizer: object exposing a ``param_groups`` list of dicts.
        epoch: current epoch number used in the decay formula.
    """
    groups = optimizer.param_groups
    if not groups:
        return

    decayed = args.lr / (1 + 10 * epoch / args.n_epoch) ** 0.75
    for group in groups[:-1]:
        group['lr'] = decayed
    # The final group is trained with a 10x larger rate.
    groups[-1]['lr'] = decayed * 10

def test(model, target_test_loader):
    """Compute the classification accuracy of ``model`` over a data loader.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is moved to ``DEVICE``, the predicted class is the argmax of
    the model output along dim 1, and predictions are compared with the
    targets.

    Args:
        model: callable network; ``model(data)`` must return per-class scores
            with classes along dim 1.
        target_test_loader: iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` supporting ``len()``.

    Returns:
        A 0-dim double tensor on ``DEVICE`` equal to
        ``correct / len(target_test_loader.dataset)`` (0 for an empty
        dataset).
    """
    model.eval()
    n_samples = len(target_test_loader.dataset)
    # Start from a tensor so the result is a tensor even when the loader
    # yields no batches (an int has no ``.double()``).
    correct = torch.zeros((), dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        for data, target in target_test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            # A single forward pass is enough: only the predictions are used.
            pred = torch.max(model(data), 1)[1]
            correct += torch.sum(pred == target)
    # Guard the denominator so an empty dataset reports 0 accuracy.
    return correct.double() / max(n_samples, 1)

def finetune(model, dataloaders, optimizer, args, ckpt_path='model.pkl'):
    """Finetune ``model`` on the source split with validation early stopping.

    Every epoch runs three phases in order: ``'src'`` (training; weights are
    updated), then ``'val'`` and ``'tar'`` (evaluation only). Per-phase loss
    and accuracy are written to ``args.out_file`` and printed. Whenever the
    validation accuracy is at least as good as the best seen so far, the
    model's ``state_dict`` is saved to ``ckpt_path``. Training stops early
    once ``args.early_stop`` consecutive epochs pass without such a save.
    Afterwards the saved weights are restored and the model is evaluated on
    ``dataloaders['tar']`` with ``test``.

    Args:
        model: network to train; moved/used on ``DEVICE`` by the caller.
        dataloaders: dict with ``'src'``, ``'val'`` and ``'tar'`` loaders, each
            yielding ``(inputs, labels)`` and exposing ``.dataset``.
        optimizer: optimizer stepping the model's parameters.
        args: namespace providing ``n_epoch``, ``early_stop`` and a writable
            ``out_file``.
        ckpt_path: file the best validation weights are saved to and
            restored from.

    Returns:
        ``(model, acc_test)`` where ``acc_test`` is the target-loader
        accuracy returned by ``test`` for the restored weights.
    """
    since = time.time()
    best_acc = 0
    criterion = nn.CrossEntropyLoss()
    stop = 0
    # Tracks whether this run wrote a checkpoint, so a stale or missing file
    # is never loaded when no epoch was executed.
    saved_checkpoint = False
    for epoch in range(1, args.n_epoch + 1):
        stop += 1
        # You can uncomment this line for scheduling learning rate
        # lr_schedule(optimizer, epoch)
        for phase in ['src', 'val', 'tar']:
            is_train = phase == 'src'
            if is_train:
                model.train()
            else:
                model.eval()

            n_samples = len(dataloaders[phase].dataset)
            total_loss = 0.0
            # Tensor accumulator: stays valid even if the loader is empty.
            correct = torch.zeros((), dtype=torch.long, device=DEVICE)
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                preds = torch.max(outputs, 1)[1]
                if is_train:
                    loss.backward()
                    optimizer.step()
                # criterion averages over the batch; weight by batch size so
                # the epoch loss is a per-sample mean.
                total_loss += loss.item() * inputs.size(0)
                correct += torch.sum(preds == labels)

            denom = max(n_samples, 1)  # avoid dividing by zero on an empty split
            epoch_loss = total_loss / denom
            epoch_acc = correct.double() / denom

            log_str = 'Epoch: [{:02d}/{:02d}]---{}, loss: {:.6f}, acc: {:.4f}'.format(
                epoch, args.n_epoch, phase, epoch_loss, epoch_acc)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if phase == 'val' and epoch_acc >= best_acc:
                # New best (or tied) validation accuracy: reset the patience
                # counter and checkpoint the weights.
                stop = 0
                best_acc = epoch_acc
                torch.save(model.state_dict(), ckpt_path)
                saved_checkpoint = True
        if stop >= args.early_stop:
            break
        print()

    if saved_checkpoint:
        model.load_state_dict(torch.load(ckpt_path))
    acc_test = test(model, dataloaders['tar'])
    time_pass = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_pass // 60, time_pass % 60))
    return model, acc_test

if __name__ == '__main__':
    # Script entry point: build the network, evaluate it once on the target
    # domain, finetune on the source domain, then save the resulting model
    # and log the final target accuracy.
    torch.manual_seed(10)
    # Load data
    root_dir = args.data
    domain = {'src': str(args.src), 'tar': str(args.tar)}
    learning_rate = args.lr
    n_classes = args.n_class

    if args.model[:4] == 'SHOT':
        # 'SHOT<arch>' models are assembled from three project-defined parts.
        netF, netB, netC = network.gen_shot_model(args.model[4:], n_classes, 'bn', 256, 'wn')
        net = nn.Sequential(netF, netB, netC)
    else:
        # NOTE(review): eval() on a command-line string. The argparse
        # `choices` currently only allow 'SHOT*' names, so this branch is not
        # reachable from the CLI; if it is re-enabled, prefer an explicit
        # lookup of the constructor over eval().
        net = eval(args.model + "(pretrained=True)")
        net.fc = nn.Linear(net.fc.in_features, n_classes)  # re-assign a new linear layer as fc
        init_models(net.fc)

    net.to(DEVICE)

    # One optimizer group per parameter tensor; parameters whose name
    # contains 'classifier' get a 10x learning rate.
    param_group = []
    for name, parameters in net.named_parameters():
        if 'classifier' not in name:
            param_group.append({'params': parameters})
        else:
            param_group.append({'params': parameters, 'lr': learning_rate * 10})

    # NOTE(review): args.decay is parsed but not passed as weight_decay here.
    optimizer = optim.SGD(param_group, lr=learning_rate, momentum=args.momentum)

    dataloaders = {}
    dataloaders['src'], dataloaders['val'] = load_train(root_dir, domain['src'], BATCH_SIZE['src'])
    dataloaders['tar'] = load_test(root_dir, domain['tar'], BATCH_SIZE['tar'])

    # The context manager guarantees the log file is closed, even if
    # training raises.
    with open('log_' + args.model + '_acc' + '.txt', 'w') as out_file:
        args.out_file = out_file  # finetune() writes its per-epoch log here

        acc_test = test(net, dataloaders['tar'])
        log_str = 'Initial acc: {}'.format(acc_test)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

        # The second return value is the target accuracy of the restored
        # best-validation weights.
        model_best, best_acc = finetune(net, dataloaders, optimizer, args)

        save_model_name = './' + args.model + '_' + str(args.src) + '.pt'

        torch.save(model_best, save_model_name)

        log_str = 'Best acc: {}'.format(best_acc)
        args.out_file.write(log_str + '\n')
        args.out_file.flush()
        print(log_str + '\n')

