import os
import sys
sys.path.insert(0, './')
import json
import time
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from util.Eval import AverageCalculator, calc_accuracy

def epoch_pass(model, criterion, optimizer, data_loader, is_train, epoch_idx, lr_func, label, device, tosave, mia=None):
    '''
    Run one pass over `data_loader`, either updating `model` (training) or
    only measuring loss / accuracy (evaluation), and log the epoch statistics.

     model: the model for training or inference
     criterion: the criterion to calculate the loss, called as criterion(logits, labels)
     optimizer: the optimizer; only used when is_train is truthy
     data_loader: data loader yielding 3-tuples (data, label, idx) or
         4-tuples (data, label, true_label, idx); only data and label are used here
     is_train: whether or not it is training (interpreted by truthiness)
     epoch_idx: the epoch idx
     lr_func: the learning rate scheduler, called with the fractional epoch
         `epoch_idx + batch_idx / len(data_loader)`; may be None to leave the lr untouched
     label: label of this pass, e.g. train, validate, test; used in the printed
         messages and in the keys '<label>_loss' / '<label>_acc' of tosave
     device: the device of the model and the data (torch.device or device string)
     tosave: the information to be saved. Must already contain '<label>_loss' and
         '<label>_acc' entries that can be indexed by epoch_idx. When the learning
         rate is scheduled, 'lr_per_epoch' is overwritten with the first batch's lr.
     mia: optional object; when given, mia.fit(model, data_batch, label_batch) is
         called for every training batch after backward() and before optimizer.step()

    Returns (model, tosave, loss_this_epoch, acc_this_epoch).
    Raises ValueError if a batch is not a 3- or 4-element tuple.
    '''

    # A device string such as 'cpu' does not compare equal to torch.device('cpu'),
    # so normalise strings first; otherwise 'cpu' would be mistaken for a GPU device.
    if isinstance(device, str):
        device = torch.device(device)
    use_gpu = device != torch.device('cpu') and torch.cuda.is_available()

    # Use one truthiness-based flag everywhere so that the train/eval mode of the
    # model always agrees with whether optimizer steps are taken.
    training = bool(is_train)

    acc_calculator = AverageCalculator()
    loss_calculator = AverageCalculator()

    if training:
        model.train()
    else:
        model.eval()

    for idx, packed_data in enumerate(data_loader):

        if training and lr_func is not None:
            # The scheduler is queried per batch with a fractional epoch index.
            lr_this_batch = lr_func(epoch_idx + idx / len(data_loader))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_batch
            if idx == 0:
                tosave['lr_per_epoch'] = lr_this_batch
                print('Learing rate = %1.2e' % lr_this_batch)

        # 3 elements: (data, label, idx); 4 elements: mislabeled data setting with
        # an extra true label. Only data and label are consumed in this function.
        if len(packed_data) not in (3, 4):
            raise ValueError('The input data is not valid!')
        data_batch, label_batch = packed_data[0], packed_data[1]

        if use_gpu:
            data_batch = data_batch.cuda(device)
            label_batch = label_batch.cuda(device)

        if training:

            optimizer.zero_grad()
            logits = model(data_batch)
            loss = criterion(logits, label_batch)
            loss.backward()
            if mia is not None:
                # Called between backward() and step(), i.e. while the gradients
                # of this batch are populated and the weights are not yet updated.
                mia.fit(model, data_batch, label_batch)
            optimizer.step()

        else:

            # No gradients are needed for evaluation; skipping the autograd graph
            # saves memory and time without changing the computed values.
            with torch.no_grad():
                logits = model(data_batch)
                loss = criterion(logits, label_batch)

        acc_this_batch = calc_accuracy(logits.detach(), label_batch)

        num_samples = data_batch.size(0)
        loss_calculator.update(loss.item(), num_samples)
        acc_calculator.update(acc_this_batch.item(), num_samples)

        # '\r' keeps the progress report on a single console line.
        sys.stdout.write('%s Batch Idx: %d - Accuracy: %.2f%%\r' % (
            label.upper(), idx, acc_calculator.average * 100.))

    loss_this_epoch = loss_calculator.average
    acc_this_epoch = acc_calculator.average

    print('%s Loss / Accuracy after Epoch %d: %.4f / %.2f%%' % (
        label.upper(), epoch_idx, loss_this_epoch, acc_this_epoch * 100.))
    tosave['%s_loss' % label][epoch_idx] = loss_this_epoch
    tosave['%s_acc' % label][epoch_idx] = acc_this_epoch

    return model, tosave, loss_this_epoch, acc_this_epoch

def _dump_tosave(tosave, out_folder, model_name):
    '''Pickle `tosave` to <out_folder>/<model_name>.pickle and close the file.'''
    with open(os.path.join(out_folder, '%s.pickle' % model_name), 'wb') as f:
        pickle.dump(tosave, f)


def train(model, train_loader, valid_loader, test_loader, epoch_num, epoch2save, optimizer,
    lr_func, out_folder, model_name, device, criterion, tosave, mia=None):
    '''
    Train `model` for `epoch_num` epochs, evaluating after every epoch and
    writing checkpoints plus the logged statistics to `out_folder`.

     model: the model to train
     train_loader, valid_loader, test_loader: the data loaders; valid_loader may be
         None, in which case validation and the best-validation checkpoint are skipped
     epoch_num: the number of epochs
     epoch2save: the list of epoch indices (1-based, compared against epoch_idx + 1)
         where we save checkpoints
     optimizer: the optimizer
     lr_func: the learning rate scheduler
     out_folder: the output folder
     model_name: the name of the model, used as the file name stem of all outputs
     device: the device where model and data are stored
     criterion: the loss criterion
     tosave: the information to save; tosave['runtime'][epoch_idx] receives the
         training time of each epoch in minutes (rounded to 3 decimals)
     mia: optional object forwarded to epoch_pass for the training pass only

    Files written to out_folder:
     <model_name>_bestvalid.ckpt: state dict at the best validation accuracy so far
     <model_name>_<k>.ckpt: state dict after epoch k, for every k in epoch2save
     <model_name>.ckpt: final state dict
     <model_name>.pickle: pickled tosave (at every epoch in epoch2save and at the end)

    Returns (model, tosave).
    '''

    best_valid_acc = 0.

    for epoch_idx in range(epoch_num):

        # Training (mia defaults to None in epoch_pass, so it can always be forwarded)
        t0 = time.time()
        model, tosave, _, _ = epoch_pass(model=model, criterion=criterion, optimizer=optimizer,
            data_loader=train_loader, is_train=True, epoch_idx=epoch_idx, lr_func=lr_func,
            label='train', device=device, tosave=tosave, mia=mia)
        t1 = time.time()

        tosave['runtime'][epoch_idx] = round((t1 - t0) / 60, 3)

        # Validation
        if valid_loader is not None:

            model, tosave, _, valid_acc = epoch_pass(model=model, criterion=criterion, optimizer=optimizer,
                data_loader=valid_loader, is_train=False, epoch_idx=epoch_idx, lr_func=lr_func,
                label='valid', device=device, tosave=tosave)
            # Strict '>' keeps the earliest epoch among ties.
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                torch.save(model.state_dict(), os.path.join(out_folder, '%s_bestvalid.ckpt' % model_name))

        # Test (results are only logged into tosave by epoch_pass)
        model, tosave, _, _ = epoch_pass(model=model, criterion=criterion, optimizer=optimizer,
            data_loader=test_loader, is_train=False, epoch_idx=epoch_idx, lr_func=lr_func,
            label='test', device=device, tosave=tosave)

        if (epoch_idx + 1) in epoch2save:
            torch.save(model.state_dict(), os.path.join(out_folder, '%s_%d.ckpt' % (model_name, epoch_idx + 1)))
            _dump_tosave(tosave, out_folder, model_name)

    torch.save(model.state_dict(), os.path.join(out_folder, '%s.ckpt' % model_name))
    _dump_tosave(tosave, out_folder, model_name)

    return model, tosave
