
import copy
import json
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import pickle
from sklearn.metrics import average_precision_score
import sys
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset

from DataUtils import get_loader, ImageDataset
from ModelWrapper import ModelWrapper
from ResNet import get_features, get_lm, get_model, set_lm

###
# Run configuration
###

def load_data(ids, images, data_key = 'file', index = None):
    """Collect the inputs and labels for the requested entries.

    Args:
        ids: Iterable of keys into ``images``.
        images: Mapping from key to a record holding ``data_key`` and 'label'.
        data_key: Record field returned as the input (e.g. 'file' or 'rep').
        index: If given, only ``record['label'][index]`` is kept per record.

    Returns:
        ``(files, labels)`` where ``files`` is a list of the ``data_key``
        values and ``labels`` is a float32 array; a 1-D label array is
        reshaped to a single column.
    """
    # One pass over the records so fields are read in the same order per entry
    pairs = [
        (record[data_key], record['label'] if index is None else record['label'][index])
        for record in (images[key] for key in ids)
    ]
    files = [pair[0] for pair in pairs]
    labels = np.array([pair[1] for pair in pairs], dtype = np.float32)

    # Scalar labels become an (N, 1) column
    if labels.ndim == 1:
        labels = labels[:, np.newaxis]

    return files, labels

def load_phase(source, phase, index = None):
    """Load the file list and labels for one phase of a dataset.

    Reads ``<source>/<phase>/images.json`` and forwards every entry in it to
    ``load_data`` using the default 'file' field.

    Args:
        source: Root directory of the dataset.
        phase: Name of the phase sub-directory (e.g. 'train').
        index: Optional label index passed through to ``load_data``.

    Returns:
        The ``(files, labels)`` pair produced by ``load_data``.
    """
    manifest_path = '{}/{}/images.json'.format(source, phase)
    with open(manifest_path, 'r') as handle:
        images = json.load(handle)

    # Every key of the manifest is used, in the manifest's own order
    files, labels = load_data(list(images), images, index = index)
    return files, labels

def load_phase_rep(source, phase, index = None):
    """Load the saved representations and labels for one phase.

    Reads ``<source>/<phase>.pkl`` and forwards every entry in it to
    ``load_data`` using the 'rep' field.

    Args:
        source: Directory holding the pickled per-phase outputs.
        phase: Name of the phase (the pickle's file stem).
        index: Optional label index passed through to ``load_data``.

    Returns:
        ``(reps, labels)`` where ``reps`` is a float32 array; a 1-D array is
        promoted to a single row.
    """
    pickle_path = '{}/{}.pkl'.format(source, phase)
    # NOTE(review): pickle.load executes arbitrary code; only use on trusted files
    with open(pickle_path, 'rb') as handle:
        data = pickle.load(handle)

    raw_reps, labels = load_data(list(data), data, data_key = 'rep', index = index)

    reps = np.array(raw_reps, dtype = np.float32)
    if reps.ndim == 1:
        reps = reps[np.newaxis, :]

    return reps, labels

def _build_image_loaders(data_dir, batch_size):
    # Build the 'train' and 'val' image loaders from the on-disk manifests.
    loaders = {}
    for phase in ['train', 'val']:
        files_tmp, labels_tmp = load_phase(data_dir, phase)
        dataset_tmp = ImageDataset(files_tmp, labels_tmp)
        loaders[phase] = get_loader(dataset_tmp, batch_size = batch_size)
    return loaders

def run(mode, trial, dataset_config, base_dir = './Outputs', select_metric = 'acc'):
    """Run one trial of the pipeline selected by ``mode``.

    ``mode`` is either the literal 'pretrained' or a '-'-separated set of
    tokens drawn from 'initial', 'adv', 'transfer' and 'tune':

    * 'pretrained': build the pretrained model and dump its outputs.
    * 'initial' + 'transfer': train only the module returned by ``get_lm`` on
      the representations saved under ``<base_dir>/pretrained/trial0``.
    * ('initial' or 'adv') + 'tune': train the parameters returned by
      ``get_model(mode = 'tune')``, starting from
      ``<base_dir>/initial-transfer/trial<trial>/model.pt``; 'adv' enables
      adversarial training.

    The output directory ``<base_dir>/<mode>/trial<trial>`` is deleted and
    recreated. Trained weights are written to ``model.pt`` inside it, and for
    the 'pretrained' and 'tune' cases the result of
    ``ModelWrapper.predict_dataset`` is pickled to ``<phase>.pkl`` for the
    'train', 'val' and 'test' phases.

    Args:
        mode: Mode string as described above.
        trial: Trial identifier used in the output and parent paths.
        dataset_config: Indexable whose first three items are the dataset
            directory, the ``get_id`` callable handed to ``ModelWrapper``, and
            the number of model outputs.
        base_dir: Root directory for all outputs.
        select_metric: Model-selection metric forwarded to ``train_model``.

    Raises:
        SystemExit: With status 1 if ``mode`` does not describe a supported
            configuration. Nothing on disk is touched in that case.
    """
    # Local import so this block does not depend on the file-level imports
    import shutil

    # Get the config for the dataset
    data_dir = dataset_config[0]
    id_from_path = dataset_config[1]
    out_features = dataset_config[2]

    # Get configuration from mode
    mode_split = mode.split('-')

    PRE = (mode == 'pretrained')

    TRANS = 'transfer' in mode_split
    TUNE = 'tune' in mode_split

    INIT = 'initial' in mode_split
    ADV = 'adv' in mode_split

    # Load default parameters
    if TRANS:
        lr = 0.001
    elif TUNE:
        lr = 0.0001
    elif not PRE:
        print('Error: Could not determine which parameters are to be trained')
        sys.exit(1)

    # Reject token combinations that match no branch below (e.g. 'adv-transfer')
    # before the output directory is wiped; previously these crashed later on
    # an unbound ``model``.
    train_head = INIT and TRANS
    train_tune = (INIT or ADV) and TUNE
    if not (PRE or train_head or train_tune):
        print('Error: Unsupported mode: {}'.format(mode))
        sys.exit(1)

    # Setup the output directory (always starts empty)
    model_dir = '{}/{}/trial{}'.format(base_dir, mode, trial)
    shutil.rmtree(model_dir, ignore_errors = True)
    Path(model_dir).mkdir(parents = True)

    name = '{}/model'.format(model_dir)

    batch_size = 64
    select_cutoff = 5
    decay_max = 1

    parent_trans = '{}/initial-transfer/trial{}/model.pt'.format(base_dir, trial)

    # Setup for each mode and train
    if PRE:
        # Setup the data loaders
        # NOTE(review): these loaders are not read again on this path
        dataloaders = _build_image_loaders(data_dir, batch_size)

        # Setup the model
        # NOTE(review): the other modes unpack a pair from get_model; confirm
        # that mode = 'eval' returns the bare model
        model = get_model(mode = 'eval', parent = 'pretrained', out_features = out_features)
        model.cuda()

    elif train_head:
        # Setup the data loaders from the saved pretrained representations
        dataloaders = {}
        for phase in ['train', 'val']:
            reps_tmp, labels_tmp = load_phase_rep('{}/pretrained/trial0'.format(base_dir), phase)
            reps_tmp = torch.Tensor(reps_tmp)
            labels_tmp = torch.Tensor(labels_tmp)
            dataset_tmp = TensorDataset(reps_tmp, labels_tmp)
            dataloaders[phase] = get_loader(dataset_tmp, batch_size = batch_size)

        # Setup the model and optimization process
        model, _ = get_model(mode = 'transfer', parent = 'pretrained', out_features = out_features)
        lm = get_lm(model)
        optim_params = lm.parameters()
        model.cuda()
        lm.cuda()

        # Setup the loss
        metric_loss = torch.nn.BCEWithLogitsLoss()

        # Train only the module returned by get_lm, then put it back into the model
        lm = train_model(lm, optim_params, dataloaders, metric_loss, preds_batch, ap_agg, name = name,
                         lr_init = lr, decay_max = decay_max,
                         select_metric = select_metric, select_cutoff = select_cutoff,
                         mode = mode)

        set_lm(model, lm)
        torch.save(model.state_dict(), '{}.pt'.format(name))

        # Clean up the model history saved during training
        shutil.rmtree(name, ignore_errors = True)

    elif train_tune:
        # Setup the data loaders
        dataloaders = _build_image_loaders(data_dir, batch_size)

        # Setup the model and optimization process
        model, optim_params = get_model(mode = 'tune', parent = parent_trans, out_features = out_features)
        model.cuda()

        # Setup the loss
        metric_loss = torch.nn.BCEWithLogitsLoss()

        # Train
        model = train_model(model, optim_params, dataloaders, metric_loss, preds_batch, ap_agg, name = name,
                            lr_init = lr, decay_max = decay_max,
                            select_metric = select_metric, select_cutoff = select_cutoff,
                            mode = mode, adv_train = ADV)
        torch.save(model.state_dict(), '{}.pt'.format(name))

        # Clean up the model history saved during training
        shutil.rmtree(name, ignore_errors = True)

    # Setup for evaluation
    model.cuda()
    model.eval()
    if PRE or train_tune:
        feature_hook = get_features(model)
        wrapper = ModelWrapper(model, feature_hook = feature_hook, get_id = id_from_path)
        # Save Predictions and Representations
        for phase in ['train', 'val', 'test']:
            files_tmp, labels_tmp = load_phase(data_dir, phase)
            out_tmp = wrapper.predict_dataset(files_tmp, labels_tmp)
            with open('{}/{}.pkl'.format(model_dir, phase), 'wb') as f:
                pickle.dump(out_tmp, f)
            
###
# Training
###
    
def train_model(model, params, dataloaders, 
                # Metrics
                metric_loss, metric_acc_batch, metric_acc_agg, 
                # Learning rate configuration
                lr_init = 0.001, decay_phase = 'train', decay_metric = 'loss', decay_min = 0.001, decay_delay = 3, decay_rate = 0.1, decay_max = 2, 
                # Model selection configuration
                select_metric = 'acc', select_metric_index = 0, select_min = 0.001, select_cutoff = 5,
                # Mode configuration
                mode = None, mode_param = None, feature_hook = None, adv_train = False, 
                # Output configuration
                name = 'history', save_every_epoch = False):
    """Train ``model`` with Adam, plateau-driven LR decay and model selection.

    Each epoch runs a 'train' pass and a 'val' pass on CUDA. Before each
    epoch two stopping rules are checked:

    * more than ``select_cutoff`` epochs since the selection metric last
      improved on 'val' -> stop;
    * more than ``decay_delay`` epochs since the decay metric last improved
      (or since the last decay) -> multiply the learning rate by
      ``decay_rate``; once this has been triggered more than ``decay_max``
      times -> stop.

    On stop, the best weights seen on 'val' are loaded back into ``model``.
    After every epoch a progress plot is written to ``<name>.png`` and each
    newly selected state dict to ``<name>/<epoch>.pt``. The directory
    ``name`` is deleted and recreated at the start.

    Args:
        model: Module to train.
        params: Parameters handed to the Adam optimizer.
        dataloaders: Mapping with 'train' and 'val' iterables yielding
            batches whose first two items are the inputs and the targets.
        metric_loss: Callable ``(pred, y) -> loss`` used for optimization.
        metric_acc_batch: Callable ``(pred, y)`` whose per-batch results are
            collected and handed to ``metric_acc_agg``. It receives the raw
            model outputs.
        metric_acc_agg: Callable turning the collected batch results into a
            sequence of metric values; called with ``None`` it must return
            the matching metric names.
        lr_init: Initial learning rate.
        decay_phase: 'train' or 'val'; phase whose metric drives LR decay.
        decay_metric: 'acc' or 'loss'; metric that drives LR decay.
        decay_min: Minimum improvement that resets the decay timer.
        decay_delay: Epochs without improvement tolerated before a decay.
        decay_rate: Multiplicative LR decay factor.
        decay_max: Number of decays allowed before training stops.
        select_metric: 'acc' or 'loss'; 'val' metric used to pick weights.
        select_metric_index: Which aggregated metric counts as 'acc'.
        select_min: Minimum improvement needed to select new weights.
        select_cutoff: Epochs without selection tolerated before stopping.
        mode: Unused; kept for interface compatibility.
        mode_param: Unused; kept for interface compatibility.
        feature_hook: Unused; kept for interface compatibility.
        adv_train: If true, both phases run on inputs perturbed by
            ``FastGradientSignUntargeted``.
        name: Path stem for the history directory and the progress plot.
        save_every_epoch: If true, also save the weights after every train
            pass to ``<name>/train_<epoch>.pt``.

    Returns:
        ``model`` with the selected weights loaded.

    Raises:
        SystemExit: With status 1 if ``select_metric``, ``decay_phase`` or
            ``decay_metric`` has an unsupported value.
    """
    # Local import so this block does not depend on the file-level imports
    import shutil

    # Validate the configuration up front so nothing is built or deleted on bad input
    if select_metric not in ['acc', 'loss']:
        print('Bad Parameter: select_metric')
        sys.exit(1)
    if decay_phase not in ['train', 'val']:
        print('Bad Parameter: decay_phase')
        sys.exit(1)
    if decay_metric not in ['acc', 'loss']:
        print('Bad Parameter: decay_metric')
        sys.exit(1)

    # Setup the learning rate and optimizer
    lr = lr_init
    optimizer = optim.Adam(params, lr = lr_init)
    
    # Setup adversarial training
    if adv_train:
        from FGSM import FastGradientSignUntargeted
        
        attack = FastGradientSignUntargeted(model, 
                                    epsilon = 0.0157, 
                                    alpha = 0.00784, 
                                    min_val = 0, 
                                    max_val = 1, 
                                    max_iters = 10, 
                                    _type = 'linf')
    
    # Setup the data logging
    loss_history = {}
    loss_history['train'] = []
    loss_history['val'] = []
    
    acc_history = {}
    acc_history['train'] = []
    acc_history['val'] = []
    
    select_history = []
    
    decay_history = []
    
    # Setup the training tracking; 'acc' is maximized, 'loss' is minimized
    select_wts = copy.deepcopy(model.state_dict())
    if select_metric == 'acc':
        select_value = 0
    else:
        select_value = np.inf
    select_time = 0
    
    if decay_metric == 'acc':
        decay_value = 0
    else:
        decay_value = np.inf
    decay_time = 0
    decay_count = 0
        
    time = -1
    
    # Start from an empty history directory
    shutil.rmtree(name, ignore_errors = True)
    os.makedirs(name, exist_ok = True)
    
    while True:
        # Check convergence and update the learning rate accordingly
        time += 1
        
        # Stop once model selection has stalled for too long
        if time - select_time > select_cutoff:
            model.load_state_dict(select_wts)
            return model

        if time - decay_time > decay_delay:
            decay_count += 1
            if decay_count > decay_max:
                # Out of decays: restore the best weights and stop
                model.load_state_dict(select_wts)
                return model
            else:
                # Decay the learning rate
                lr = lr * decay_rate
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                decay_time = time
                decay_history.append(time)

        # Training and validation passes
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_acc = []
            running_counts = 0

            for data in dataloaders[phase]:

                # Load the data for this batch
                x = data[0]
                y = data[1]

                x = x.to('cuda')
                batch_size = x.size(0)
                y = y.to('cuda')

                # Forward pass
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    
                    if adv_train:
                        # NOTE(review): this also runs during 'val', where autograd is
                        # disabled by the enclosing context; presumably perturb
                        # re-enables gradients internally - confirm against FGSM
                        x_adv = attack.perturb(x, y, 'mean', True)
                        pred = model(x_adv)
                    else:
                        pred = model(x)
                    
                    # Main loss
                    loss_main = metric_loss(pred, y)
                    
                    # Total loss
                    loss = loss_main
                    
                    # Backward pass
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Calculate batch statistics (loss is weighted by batch size)
                running_loss += loss.item() * batch_size
                running_counts += batch_size
                running_acc.append(metric_acc_batch(pred, y))

            # Calculate epoch statistics
            epoch_loss = running_loss / running_counts
            epoch_acc_all = metric_acc_agg(running_acc)
            epoch_acc = epoch_acc_all[select_metric_index]
            
            # Update the history
            loss_history[phase].append(epoch_loss)
            acc_history[phase].append(epoch_acc_all)
            
            # Check for decay objective progress
            if phase == decay_phase:
                if decay_metric == 'acc':
                    if epoch_acc > decay_value + decay_min:
                        decay_value = epoch_acc
                        decay_time = time
                elif decay_metric == 'loss':
                    if epoch_loss < decay_value - decay_min:
                        decay_value = epoch_loss
                        decay_time = time
                
            # Model selection
            if phase == 'val':
                if select_metric == 'acc':
                    if epoch_acc > select_value + select_min:
                        select_value = epoch_acc
                        select_time = time
                        select_wts = copy.deepcopy(model.state_dict())
                        select_history.append(time)
                        torch.save(select_wts, '{}/{}.pt'.format(name, time))
                elif select_metric == 'loss':
                    if epoch_loss < select_value - select_min:
                        select_value = epoch_loss
                        select_time = time
                        select_wts = copy.deepcopy(model.state_dict())
                        select_history.append(time)
                        torch.save(select_wts, '{}/{}.pt'.format(name, time))
                        
            if phase == 'train' and save_every_epoch:
                torch.save(model.state_dict(), '{}/train_{}.pt'.format(name, time))

        # Plot process so far: one panel for the loss plus one per metric
        metrics_num = len(acc_history['val'][0])
        metrics_names = metric_acc_agg(None)
        
        num_plots = 1 + metrics_num
        count = 1
        
        fig = plt.figure(figsize=(5, num_plots * 5))
        fig.subplots_adjust(hspace=0.6, wspace=0.6)
        
        x = [i for i in range(time + 1)]
    
        plt.subplot(num_plots, 1, count)
        plt.scatter(x, loss_history['train'], label = 'Train')
        plt.scatter(x, loss_history['val'], label = 'Val')
        # Black lines mark LR decays, green lines mark model selections
        if decay_metric == 'loss':
            for t in decay_history:
                plt.axvline(t, color = 'black', linestyle = '--')
        if select_metric == 'loss':
            for t in select_history:
                plt.axvline(t, color = 'green', linestyle = '--')
        plt.ylabel('Loss - Total')
        plt.legend()
        count += 1
        
        for i in range(metrics_num):
            plt.subplot(num_plots, 1, count)
            plt.scatter(x, [v[i] for v in acc_history['train']], label = 'Train')
            plt.scatter(x, [v[i] for v in acc_history['val']], label = 'Val')
            if i == select_metric_index:
                if decay_metric == 'acc':
                    for t in decay_history:
                        plt.axvline(t, color = 'black', linestyle = '--')
                if select_metric == 'acc':
                    for t in select_history:
                        plt.axvline(t, color = 'green', linestyle = '--')
            plt.xlabel('Time')
            plt.ylabel(metrics_names[i])
            plt.legend()
            count += 1
        plt.savefig('{}.png'.format(name))
        plt.close()
        
###
# Metrics
###
    
def preds_batch(y_hat, y):
    """Return the batch's predictions and targets as a two-item list of NumPy arrays.

    Both tensors are moved to the CPU first; the predictions come first in
    the returned list.
    """
    return [tensor.cpu().data.numpy() for tensor in (y_hat, y)]

def ap_agg(preds_list):
    """Aggregate per-batch outputs into a one-item list holding the average precision.

    Args:
        preds_list: List of ``[predictions, targets]`` pairs, one per batch,
            or ``None`` to request the metric names instead.

    Returns:
        ``['Average Precision']`` when ``preds_list`` is ``None``; otherwise a
        one-item list with ``average_precision_score`` computed over the rows
        of all batches.
    """
    if preds_list is None:
        return ['Average Precision']

    # Flatten the batches row by row, keeping predictions and targets aligned
    scores = [row for batch in preds_list for row in batch[0]]
    targets = [row for batch in preds_list for row in batch[1]]
    return [average_precision_score(targets, scores)]
 