import tool
import opt_sw

import torch
import torch.nn as nn
import numpy as np 

import time
import logging

# Root directory of the optimizer-settings files; my_main opens
# SETTING_PATH + '<set_path>/<set_name>.txt'. The path is relative, so it is
# resolved against the current working directory (presumably the script is
# meant to be launched from its own folder — confirm).
SETTING_PATH = '../settings/'


"""
test model
"""
def test_model(net, test_loader, device):
    """Return the classification accuracy of ``net`` on ``test_loader``, in percent.

    The forward passes run under ``torch.no_grad()`` with every sub-module in
    evaluation mode. Dropout is therefore off, and batch-norm layers neither
    normalise with batch statistics nor fold the evaluation data into their
    running statistics. Each sub-module's original train/eval flag is
    restored on exit, because this function is called between training epochs.

    Args:
        net: classifier whose output holds one score per class along dim 1;
            the prediction is the arg-max over that dimension.
        test_loader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        float: ``100 * correct / total`` over all samples.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    correct = 0
    total = 0

    # Remember every module's own mode: restoring with ``net.train(flag)``
    # would overwrite sub-modules that were deliberately left in eval mode.
    saved_modes = [(module, module.training) for module in net.modules()]
    net.eval()
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                predicted = net(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if total == 0:
        raise ValueError('test_loader yielded no samples; accuracy is undefined')
    return 100 * correct / total


# The ``test_`` prefix is historical: keep pytest from collecting this helper
# as a test function when the module is imported into a test file.
test_model.__test__ = False

def test_model_loss(net, test_loader, device):
    """Return the cross-entropy loss of ``net`` over ``test_loader``.

    ``nn.CrossEntropyLoss`` is used with its default mean reduction, so the
    result is the *sum of per-batch mean losses*, not a per-sample average.
    Its magnitude therefore grows with the number of batches.

    The forward passes run under ``torch.no_grad()`` with every sub-module in
    evaluation mode. Dropout is off, and batch-norm running statistics are
    not updated by the evaluation data. Each sub-module's original train/eval
    flag is restored on exit, because this function is called between
    training epochs.

    Args:
        net: classifier producing class scores accepted by ``nn.CrossEntropyLoss``.
        test_loader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        A 0-dim tensor (on the loss's device) holding the summed loss, or the
        int ``0`` if ``test_loader`` yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    loss = 0

    # Remember every module's own mode: restoring with ``net.train(flag)``
    # would overwrite sub-modules that were deliberately left in eval mode.
    saved_modes = [(module, module.training) for module in net.modules()]
    net.eval()
    try:
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                loss += criterion(net(inputs), labels)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return loss


# The ``test_`` prefix is historical: keep pytest from collecting this helper
# as a test function when the module is imported into a test file.
test_model_loss.__test__ = False
                

"""
main
"""
def _read_optimizer_settings(set_path, set_name):
    """Load the optimizer-settings mapping from ``SETTING_PATH/<set_path>/<set_name>.txt``.

    The file must evaluate to a mapping of ``optimizer_name -> settings``;
    ``my_main`` iterates its keys and indexes it with them.
    """
    with open(SETTING_PATH + '{}/{}.txt'.format(set_path, set_name), 'r') as f:
        # NOTE(review): eval() executes arbitrary Python from the settings
        # file. Acceptable only for trusted, locally authored configs; if the
        # files contain plain literals, ast.literal_eval would be safer.
        return eval(f.read())


def _train_one_epoch(model, optimizer, criterion, train_loader, device, is_switcher):
    """Run one optimisation pass over every batch of ``train_loader``.

    When ``is_switcher`` is true, the optimizer is an ``opt_sw.OptSwitcher``.
    ``recommend_optimizer(model)`` is called before each step, and ``step``
    receives the scalar batch loss.
    """
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        if is_switcher:
            # presumably lets the switcher choose the active optimizer for
            # this step — see opt_sw for the exact contract.
            optimizer.recommend_optimizer(model)
        optimizer.zero_grad()

        loss = criterion(model(inputs), labels)
        loss.backward()
        if is_switcher:
            optimizer.step(loss.item())
        else:
            optimizer.step()


def _evaluate_splits(model, train_loader, val_loader, test_loader, device):
    """Return ``[train_acc, val_acc, test_acc, train_loss, val_loss, test_loss]``.

    Splits are evaluated in train, val, test order, computing accuracy and
    then loss for each. Accuracies come from ``test_model`` (percent);
    losses come from ``test_model_loss``.
    """
    train_acc, train_loss = test_model(model, train_loader, device), test_model_loss(model, train_loader, device)
    val_acc, val_loss = test_model(model, val_loader, device), test_model_loss(model, val_loader, device)
    test_acc, test_loss = test_model(model, test_loader, device), test_model_loss(model, test_loader, device)
    return [train_acc, val_acc, test_acc, train_loss, val_loss, test_loss]


def my_main():
    """Fine-tune a model once per optimizer configuration and record metrics.

    For every entry of the settings file named on the command line, this:

    - reseeds and loads the data, model and optimizer;
    - initialises an ``opt_sw.OptSwitcher`` if one is used;
    - trains for up to ``args.n_epoch`` epochs with patience-based early
      stopping on validation accuracy;
    - logs the final train/val/test accuracy.

    Per-epoch validation accuracies are written as ``.npy`` files under
    ``../result/<save_path>/<dataset_name>/``. With ``args.is_save``,
    full per-epoch train/val/test snapshots are written there too.
    """
    import os

    # 1. Read & Set Parameters
    args = tool.parse_arg()
    tool.setup_logging(args.dataset_name, args.save_path, args.save_name)

    optimizer_settings = _read_optimizer_settings(args.set_path, args.set_name)

    logging.warning('[START]')
    logging.warning('- dataset name: {}'.format(args.dataset_name))
    logging.warning('- random seed : {}'.format(args.random_seed))
    logging.warning('- setting path: {}/{}'.format(args.set_path, args.set_name))

    device = torch.device(args.device)

    # Early-stopping parameters: stop after `conv_patience` consecutive epochs
    # whose validation accuracy (in percent) did not rise by more than
    # `conv_mindelta` over the previous epoch.
    conv_patience = 10
    conv_mindelta = 0.001

    # Per-optimizer validation accuracy, one entry per epoch. It is saved
    # under the historical "-eval_loss.npy" file name.
    val_acc_history = {}
    save_result = {}

    for optimizer_name in optimizer_settings:
        val_acc_history[optimizer_name] = []
        save_result[optimizer_name] = []
        settings = optimizer_settings[optimizer_name]
        logging.warning(f'\n[Optimizer:{optimizer_name}]')

        # 2. Load Dataset, Model, Loss & Optimizer
        tool.set_seed(args.random_seed)
        train_loader, val_loader, test_loader, n_class = tool.load_dataset(args.dataset_name, args.model_name, args.batch_size, args.random_seed)

        model = tool.load_model(args.model_name, n_class, args.model_weights, args.is_fulltrain).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = tool.load_optimizer(model, settings, device)
        is_switcher = isinstance(optimizer, opt_sw.OptSwitcher)

        # 3. before Fine-Tuning
        start_time = time.time()
        if is_switcher:
            optimizer.init(model, args.model_name, train_loader, val_loader)
        init_time = time.time() - start_time
        logging.warning('- init time: {}'.format(init_time))

        # 4. Fine-Tuning
        start_time = time.time()
        prev_val_acc = float('-inf')  # guarantees the first epoch counts as progress
        stall_epochs = 0

        for epoch in range(args.n_epoch):
            _train_one_epoch(model, optimizer, criterion, train_loader, device, is_switcher)

            val_acc = test_model(model, val_loader, device)
            val_acc_history[optimizer_name].append(val_acc)

            # Snapshot before the early-stop check so the epoch that triggers
            # the stop is recorded too.
            if args.is_save:
                save_result[optimizer_name].append(
                    _evaluate_splits(model, train_loader, val_loader, test_loader, device))

            # test_model returns accuracy (higher is better), so progress is
            # an *increase* of more than conv_mindelta over the last epoch.
            if val_acc - prev_val_acc > conv_mindelta:
                stall_epochs = 0
            else:
                stall_epochs += 1
            prev_val_acc = val_acc

            if stall_epochs >= conv_patience or epoch == args.n_epoch - 1:
                logging.warning('- conv epoch: {}'.format(epoch+1))
                break

        ft_time = time.time()-start_time
        logging.warning('- ft time: {}'.format(ft_time))

        # 5. test the network
        logging.warning(f'- train acc.: {test_model(model, train_loader, device):.4f}%')
        logging.warning(f'- val   acc.: {test_model(model, val_loader, device):.4f}%')
        logging.warning(f'- test  acc.: {test_model(model, test_loader, device):.4f}%')

    # save per-epoch validation accuracies (file name kept for compatibility)
    result_dir = f'../result/{args.save_path}/{args.dataset_name}'
    os.makedirs(result_dir, exist_ok=True)  # avoid losing a finished run to a missing folder
    np.save(f'{result_dir}/{args.save_name}-eval_loss.npy', np.array(val_acc_history))

    if args.is_save:
        np.save(f'{result_dir}/{args.save_name}-save.npy', np.array(save_result))

if __name__ == '__main__':
    # Script entry point: install a catch-all "ignore" warnings filter for
    # the whole run, then launch the experiment.
    import warnings
    warnings.simplefilter("ignore")
    my_main()