#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr  9 13:40:57 2024

@author: jXXXX
"""

import numpy as np
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
# For obscure reasons this needs pip install tensorboard==2.11.0! 
# The latest version, 2.13.0, crashes on import
from torch.utils.tensorboard import SummaryWriter 
from datetime import datetime
import os
import pickle
import json

import argparse

# Parse arguments from command line interface
parser = argparse.ArgumentParser(description='Task-Sequencing')

# Logging parameters
parser.add_argument('--train-epochs', type=int, default=int(1e10), 
                    help='number of training epochs (default: 1e10)')
# NOTE: run() currently does not enforce this limit (its time check is commented out)
parser.add_argument('--train-time', type=float, default=10, 
                    help='hours to train for (default: 10)')
parser.add_argument('--save-interval', type=int, default=1000, 
                    help='interval for saving model parameters (default: 1000)')
parser.add_argument('--loss-interval', type=int, default=10, 
                    help='interval for loss reporting (default: 10)')

# File parameters
parser.add_argument('--base-dir', type=str, default='./train', 
                    help='base directory to store results in (default: ./train)')
parser.add_argument('--run-dir', type=str, default=None, 
                    help='directory for this particular run (default: [timestamp])')
parser.add_argument('--resume-dir', type=str, default=None, 
                    help='dir to resume training from (default: None)')

# Task parameters
parser.add_argument('--n-tasks', type=int, default=30, 
                    help='total number of tasks (default: 30)')
parser.add_argument('--n-contexts', type=int, default=3, 
                    help='number of contexts in each task (default: 3)')
parser.add_argument('--n-steps', type=int, default=5, 
                    help='number of time steps within each context (default: 5)')
parser.add_argument('--seed', type=int, default=0, 
                    help='random seed for generating reproducible tasks (default: 0)')
parser.add_argument('--n-dims', type=int, default=10, 
                    help='dimension of input and output vectors (default: 10)')
parser.add_argument('--task-start', type=int, default=0, 
                    help='task index to start training from (default: 0)')
parser.add_argument('--task-stop', type=int, default=-1, 
                    help='task index to stop training before (default: ignore, -1)')

# Training parameters
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate (default: 1e-3)')
parser.add_argument('--do-full-feedback', dest='do_full_feedback', default=False, action='store_true',
                    help='give feedback on each timepoint (default: False)')
parser.add_argument('--batch-size', type=int, default=4096,
                    help='batch size (default: 4096)')

# Model parameters
parser.add_argument('--layers', type=int, default=3,
                    help='GRU hidden layers (default: 3)')
parser.add_argument('--do-context', type=int, default=0,
                    help='0: no context, 1: context network, 2: ground truth (default: 0)')
parser.add_argument('--do-categorical', dest='do_categorical', default=False, action='store_true',
                    help='force context net output to be categorical (default: false)')
parser.add_argument('--do-decode', dest='do_decode', default=False, action='store_true',
                    help='decode context from latent state (default: false)')
parser.add_argument('--test-decoder', dest='test_decoder', default=False, action='store_true',
                    help='switch off supervised training of context decoder (default: false)')

# Define all possible operations given input data and current state
def add(state, data):
    """Operation 'add': return the elementwise sum of the state and the new input."""
    total = state + data
    return total

def subtract(state, data):
    """Operation 'subtract': return the state minus the new input, elementwise."""
    difference = state - data
    return difference

def up(state, data):
    """Operation 'up': return the elementwise maximum of the state and the new input.

    np.maximum computes this directly, without allocating the temporary stacked
    array, and it also broadcasts (e.g. an array against a scalar).
    """
    return np.maximum(state, data)

def down(state, data):
    """Operation 'down': return the elementwise minimum of the state and the new input.

    np.minimum computes this directly, without allocating the temporary stacked
    array, and it also broadcasts (e.g. an array against a scalar).
    """
    return np.minimum(state, data)

def copy(state, data):
    """Operation 'copy': discard the state and return the new input unchanged."""
    result = data
    return result

def ignore(state, data):
    """Operation 'ignore': discard the new input and return the state unchanged."""
    result = state
    return result

# Create RNN class
class RNN(nn.Module):
    """Stacked GRU with a linear readout that also returns every hidden state.

    Parameters
    ----------
    input_dim : int
        Size of the last input dimension.
    hidden_dim : int
        Number of hidden units in each GRU layer.
    output_dim : int
        Size of the linear readout applied to the top GRU layer's output.
    n_layers : int
        Number of stacked GRU layers.
    do_softmax : bool, optional
        Accepted for backward compatibility but currently ignored: the softmax
        readout is commented out below, so ``fc1`` is a plain ``nn.Linear``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, do_softmax=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True)
        # Softmax readout is disabled, which is why do_softmax has no effect.
        # self.fc1 = nn.Sequential(nn.Linear(hidden_dim, output_dim), 
        #                          nn.Softmax(dim=-1) if do_softmax else nn.Identity())
        self.fc1 = nn.Linear(hidden_dim, output_dim)

        self.set_device()

    def set_device(self, device=None):
        """Select the device that forward() sends its inputs to.

        This only records ``self.device`` and ``self.gpu``; it does not move the
        parameters, so call ``.to(...)`` separately when needed.

        Parameters
        ----------
        device : None, str or torch.device, optional
            None picks 'cuda:0' when CUDA is available and 'cpu' otherwise.
        """
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        # Compare the device *type*: a torch.device does not compare equal to the
        # plain string 'cpu', so `device == 'cpu'` misreported torch.device('cpu')
        # as a gpu device.
        self.gpu = self.device.type != 'cpu'

    def forward(self, x, h0=None):
        """Run the GRU one timestep at a time and apply the readout to every step.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, timepoints, input_dim); moved to ``self.device``.
        h0 : torch.Tensor, optional
            Initial hidden state (n_layers, batch, hidden_dim); zeros if None.

        Returns
        -------
        o : torch.Tensor
            Readout of shape (batch, timepoints, output_dim).
        h : torch.Tensor
            Hidden states of all layers, shape (batch, timepoints, n_layers * hidden_dim),
            with layer l occupying columns l*hidden_dim:(l+1)*hidden_dim.
        """
        x = x.to(self.device)
        if h0 is None:
            h0 = torch.zeros((self.n_layers, x.shape[0], self.hidden_dim),
                             device=self.device)
        # Step manually: a single o, h = gru(x, h) call only returns the hidden
        # states after the last timepoint, but we want every layer at every step.
        h = [h0]
        o = []
        for t in range(x.shape[1]):
            # out: batch x 1 x hidden dim (top layer); hidden: layers x batch x hidden dim
            out, hidden = self.gru(x[:, t:t + 1, :], h[-1])
            o.append(out)
            h.append(hidden)
        # Concatenate outputs along time: batch x timepoints x hidden dim
        o = torch.cat(o, 1)
        # timepoints x layers x batch x hidden -> batch x timepoints x layers x hidden,
        # then flatten layers into one concatenated hidden dimension
        h = torch.stack(h[1:]).permute((2, 0, 1, 3)).flatten(2, 3)
        # Model output: batch x timepoints x output dim
        o = self.fc1(o)
        return o, h
    
# Little helper function for making a bunch of directories
def make_dirs(base_dir, run_dir, sub_dirs=None):
    """Create ``base_dir/run_dir`` and each ``base_dir/run_dir/<sub_dir>``.

    Directories that already exist are left alone, so repeated calls are safe.

    Parameters
    ----------
    base_dir : str
        Parent directory of all runs.
    run_dir : str
        Directory name for this run, inside base_dir.
    sub_dirs : iterable of str, optional
        Sub-directories to create inside the run directory.
    """
    run_path = os.path.join(base_dir, run_dir)
    # exist_ok=True replaces the racy "check, then create" pattern
    os.makedirs(run_path, exist_ok=True)
    for sub_dir in sub_dirs or []:
        os.makedirs(os.path.join(run_path, sub_dir), exist_ok=True)

# Little helper function for parsing arguments, either provided by command line or dictionary
def prepare_args(defaults, inputs=None):
    """Merge parsed arguments with optional overrides and fill in a run directory.

    Parameters
    ----------
    defaults : dict
        Parsed arguments (e.g. ``vars(parser.parse_args())``). Updated in place.
    inputs : mapping, optional
        Values that overwrite the entries in ``defaults``.

    Returns
    -------
    dict
        ``defaults`` after applying ``inputs``. If ``run_dir`` is missing or None,
        it is set to the current timestamp.
    """
    # Apply overrides first, so an explicit run_dir=None override still gets a timestamp
    if inputs is not None:
        defaults.update(inputs)
    if defaults.get('run_dir') is None:
        defaults['run_dir'] = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    return defaults

# Data sampling function, given an array of tasks
def get_model_input(tasks, batch_tasks, N_contexts, N_steps, N_dims):
    """Sample random input vectors and the targets each task's operations produce.

    Parameters
    ----------
    tasks : list of list of callables
        tasks[i] holds one operation(state, data) per context.
    batch_tasks : sequence of int
        Index into ``tasks`` for every batch entry.
    N_contexts, N_steps, N_dims : int
        Contexts per task, steps per context, and vector dimension.

    Returns
    -------
    (input_data, output_data) : tuple of np.ndarray
        Both of shape (len(batch_tasks), N_contexts * N_steps, N_dims). The target
        at each step is the context's operation applied to the previous target
        (zeros before the first step) and the current input.
    """
    n_batch = len(batch_tasks)
    input_data = np.random.randn(n_batch, N_contexts * N_steps, N_dims)
    output_data = np.zeros_like(input_data)
    for row, task_index in enumerate(batch_tasks):
        # The running state starts from zero once per sequence and then carries over contexts
        state = np.zeros(N_dims)
        for ctx, operation in enumerate(tasks[task_index]):
            first_step = ctx * N_steps
            for step in range(first_step, first_step + N_steps):
                output_data[row, step, :] = operation(state, input_data[row, step, :])
                state = output_data[row, step, :]
    return input_data, output_data
    
# Model output function, given models and model input
def get_model_output(contextRNN, actionRNN, input_data, output_data, tasks,
                     do_context, do_categorical, batch_tasks, N_contexts, N_steps, N_tasks):
    """Build the action-RNN input for a batch and run the action RNN on it.

    do_context selects the task/context signal concatenated in front of the data:
      0 -- one-hot task identity, repeated for every timepoint;
      1 -- contextRNN output for the one-hot task fed once per context, moved to
           the cpu and repeated over the N_steps of each context (zero-padded to
           N_tasks columns when do_categorical);
      2 -- ground-truth one-hot index (in ``operations``) of each context's
           operation, zero-padded to N_tasks columns and repeated over N_steps.

    Returns
    -------
    (model_in, model_target, model_out, model_hidden)
        model_in/model_target are float tensors built on the cpu; model_out and
        model_hidden are whatever ``actionRNN(model_in)`` returns.

    Raises
    ------
    ValueError
        If do_context is not 0, 1 or 2 (previously an UnboundLocalError).
    """
    N_timepoints = N_contexts * N_steps
    if do_context == 0:
        # One-hot task identity per batch entry (batch x N_tasks), repeated over time
        task_onehot = torch.eye(N_tasks)[batch_tasks]
        task_input = task_onehot.unsqueeze(1).repeat(1, N_timepoints, 1)
    elif do_context == 1:
        # Feed the one-hot task once per context into the context RNN
        task_onehot = torch.eye(N_tasks)[batch_tasks]
        context_task = task_onehot.unsqueeze(1).repeat(1, N_contexts, 1)
        context_input = contextRNN(context_task)[0].to('cpu')
        # Categorical output has one column per operation: pad with zeros to N_tasks
        if do_categorical:
            padding = torch.zeros((context_input.shape[0], context_input.shape[1],
                                   N_tasks - context_input.shape[2]))
            context_input = torch.cat([context_input, padding], dim=2)
        # Repeat across timepoints in each context (slow timescale, but also cheating)
        task_input = torch.repeat_interleave(context_input, N_steps, 1)
    elif do_context == 2:
        # Directly provide the correct operation of each context
        n_ops = len(operations)
        op_onehot = torch.eye(n_ops)
        padding = torch.zeros(N_contexts, N_tasks - n_ops)
        context_input = torch.stack([
            torch.cat([op_onehot[[operations.index(c) for c in tasks[b]]], padding], dim=-1)
            for b in batch_tasks])
        # Repeat across timepoints in each context
        task_input = torch.repeat_interleave(context_input, N_steps, 1)
    else:
        raise ValueError(f'do_context must be 0, 1 or 2, got {do_context!r}')

    # Concatenate task/context information and input data along the feature axis
    model_in = torch.cat([task_input, torch.tensor(input_data, dtype=torch.float)], dim=-1)
    model_target = torch.tensor(output_data, dtype=torch.float)
    model_out, model_hidden = actionRNN(model_in)
    return model_in, model_target, model_out, model_hidden

# Ordered list of all available operations. The order matters: operations.index(...)
# is used to turn an operation into a one-hot / integer code throughout this file.
operations = [add, subtract, up, down, copy, ignore]

# Multiple bar plots, from https://stackoverflow.com/a/60270421
def bar_plot(ax, data, colors=None, total_width=0.8, single_width=1, legend=True):
    """Draws a bar plot with multiple bars per data point.

    Parameters
    ----------
    ax : matplotlib.pyplot.axis
        The axis we want to draw our plot on.

    data: dictionary
        A dictionary containing the data we want to plot. Keys are the names of the
        data, the items is a list of the values.

        Example:
        data = {
            "x":[1,2,3],
            "y":[1,2,3],
            "z":[1,2,3],
        }

        Empty ``data`` draws nothing. Series with no values are skipped and left
        out of the legend.

    colors : array-like, optional
        A list of colors which are used for the bars. If None, the colors
        will be the standard matplotlib color cycle. (default: None)

    total_width : float, optional, default: 0.8
        The width of a bar group. 0.8 means that 80% of the x-axis is covered
        by bars and 20% will be spaces between the bars.

    single_width: float, optional, default: 1
        The relative width of a single bar within a group. 1 means the bars
        will touch each other within a group, values less than 1 will make
        these bars thinner.

    legend: bool, optional, default: True
        If this is set to true, a legend will be added to the axis.
    """
    # Nothing to draw (and n_bars would be 0, dividing by zero below)
    if not data:
        return

    # Check if colors were provided, otherwise use the default color cycle
    if colors is None:
        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # Number of bars per group
    n_bars = len(data)

    # The width of a single bar
    bar_width = total_width / n_bars

    # Handles and matching labels for the legend (one per non-empty series)
    handles = []
    labels = []

    for i, (name, values) in enumerate(data.items()):
        # The offset in x direction of that bar
        x_offset = (i - n_bars / 2) * bar_width + bar_width / 2

        # Draw a bar for every value of that type
        bar = None
        for x, y in enumerate(values):
            bar = ax.bar(x + x_offset, y, width=bar_width * single_width, color=colors[i % len(colors)])

        # Use the last drawn bar as the legend handle; skip series that drew nothing
        if bar is not None:
            handles.append(bar[0])
            labels.append(name)

    # Draw legend if we need
    if legend:
        ax.legend(handles, labels)

def run(arg_dict=None):
    """Train the action RNN (optionally with context RNN and decoder), then analyse.

    Arguments are parsed from the command line with ``parser``. Any key in
    ``arg_dict`` overrides the parsed value, so this can also be called from an
    interactive session. The run directory ``base_dir/run_dir`` receives the
    tensorboard logs, task.pkl, args.json and the weights actionRNN.zip,
    contextRNN.zip and contextDecoder.zip. ``analyse`` is called on it at the end.
    """
    # Parse command line arguments into a dict, overwrite with arg_dict values,
    # and fill in a timestamped run_dir if none was given
    args = prepare_args(vars(parser.parse_args()), arg_dict)

    # Set parameters for tasks to generate
    N_tasks = args['n_tasks']
    N_contexts = args['n_contexts']
    N_steps = args['n_steps']
    N_dims = args['n_dims']

    # Create a bunch of tasks, but fix the random seed so it's comparable across training
    np.random.seed(args['seed'])
    # I used to just random sample. But that may give overrepresentation of particular tasks
    tasks = [[np.random.choice(operations) for _ in range(N_contexts)] for _ in range(N_tasks)]
    # So instead, how about cycling through operations, then shuffling the result
    #tasks = [c for c in zip(
    #    *[np.random.permutation([operations[i % len(operations)] for i in range(N_tasks)]) 
    #      for _ in range(N_contexts)])]
    task_names = ['-'.join([c.__name__[0:2] for c in t]) for t in tasks]
    # task_ops[c][t]: index in `operations` of the operation in context c of task t
    task_ops = [[operations.index(t[c]) for t in tasks] for c in range(N_contexts)]

    # Determine whether to use the context RNN
    do_context = args['do_context']
    # Determine whether to give feedback on each timepoint, instead of just at the end
    do_full_feedback = args['do_full_feedback']
    # Determine whether to force context RNN output to be categorical
    do_categorical = args['do_categorical']
    # Determine whether to decode context from hidden state 
    do_decode = args['do_decode']

    # Create a super simple RNN that has to solve the task
    N_layers = args['layers']
    N_hidden = N_dims*len(operations)+N_tasks
    # Single RNN: operates on hidden state and input for a given task
    actionRNN = RNN(N_dims+N_tasks, N_hidden, N_dims, N_layers)
    # Context RNN: slowly tracks context for each task
    contextRNN = RNN(N_tasks, N_tasks, len(operations) if do_categorical else N_tasks, 2, do_softmax=do_categorical)
    # Context decoder: takes action hidden state and decodes context
    contextDecoder = nn.Sequential(nn.Linear(N_hidden*N_layers, N_hidden*N_layers*2), 
                                   nn.ReLU(), 
                                   nn.Linear(N_hidden*N_layers*2,len(operations)),
                                   nn.Softmax(dim=-1))
    # File name stem -> network; used for loading, saving and resuming
    models = {'actionRNN': actionRNN, 'contextRNN': contextRNN, 'contextDecoder': contextDecoder}

    # Find if gpu is available
    gpu = torch.cuda.is_available()
    device = torch.device('cuda:0' if gpu else 'cpu')
    print('Is gpu available?', gpu, ': device = ', device)
    # Send models to gpu if available (and if not, it stays on the cpu)
    for model in models.values():
        model.to(device)
    # Load trained parameters from directory to resume from, if provided
    if args['resume_dir'] is not None:
        for name, model in models.items():
            model.load_state_dict(torch.load(os.path.join(args['resume_dir'], name + '.zip'),
                                             map_location=device))

    # Specify loss and optimiser
    loss = nn.MSELoss(reduction='none')
    loss_decode = nn.CrossEntropyLoss()
    # Pass the --lr argument through (it was previously ignored; Adam's default is also 1e-3)
    optimiser = torch.optim.Adam(list(actionRNN.parameters()) + 
                                 (list(contextRNN.parameters()) if do_context==1 else []) +
                                 (list(contextDecoder.parameters()) if do_decode else []),
                                 lr=args['lr'])
    # Make directories for this training run
    base_dir = args['base_dir']
    run_dir = args['run_dir']
    make_dirs(base_dir, run_dir, ['tensorboard', 'test'])
    # Create a tensor board to stay updated on training progress. 
    # Start tensorboard with tensorboard --logdir=train
    writer = SummaryWriter(os.path.join(base_dir, run_dir, 'tensorboard'))
    # Save the current selection of tasks (overwrites any task.pkl already in run_dir)
    with open(os.path.join(base_dir, run_dir, 'task.pkl'), 'wb') as f:
        pickle.dump(tasks, f)
    # Save the current arguments to a text file
    with open(os.path.join(base_dir, run_dir, 'args.json'), 'w') as fp:
        json.dump(args, fp)

    def save_models():
        """Write the current weights of all three networks into the run directory."""
        for name, model in models.items():
            torch.save(model.state_dict(), os.path.join(base_dir, run_dir, name + '.zip'))

    # Train it
    N_train = args['train_epochs']
    N_batch = args['batch_size']
    # Tasks are sampled from [task_start, task_stop); task_stop == -1 means N_tasks
    task_stop = N_tasks if args['task_stop'] == -1 else args['task_stop']
    # Get start time for training, then get going
    train_start = datetime.now()
    for i in range(N_train):
        # NOTE: the --train-time limit is currently disabled; uncomment to enforce it
        # if (datetime.now() - train_start).total_seconds() > args['train_time'] * 60 * 60:
        #     break        
        # Sample current task for each batch
        batch_tasks = np.random.randint(args['task_start'], task_stop, size=N_batch)
        # Generate input and output data
        input_data, output_data = get_model_input(tasks, batch_tasks, N_contexts, N_steps, N_dims)

        # Run through model
        model_in, model_target, model_out, model_hidden = get_model_output(
            contextRNN, actionRNN, input_data, output_data, tasks,
            do_context, do_categorical, batch_tasks, N_contexts, N_steps, N_tasks)

        # Get loss
        if do_full_feedback:
            current_loss = loss(model_out, model_target.to(actionRNN.device))
        else:
            current_loss = loss(model_out[:,-1,:], model_target[:,-1,:].to(actionRNN.device))

        # Get decoding loss, if required
        current_loss_decode = torch.zeros(1, device=device)
        current_loss_context = torch.zeros(1, device=device)
        all_decoded = []
        if do_decode:
            for curr_context in range(N_contexts):
                context_steps = slice(curr_context * N_steps, (curr_context + 1) * N_steps)
                true_context = torch.eye(len(operations), device=device)[
                    [task_ops[curr_context][t] for t in batch_tasks]]
                curr_decoded = contextDecoder(torch.mean(model_hidden[:, context_steps, :], 1))
                # Task/context signal the action RNN received at the start of this context
                context_signal = model_in[:, curr_context * N_steps, :len(operations)].to(device)
                all_decoded.append(curr_decoded)
                if not args['test_decoder']:
                    current_loss_decode += loss_decode(curr_decoded, true_context)
                    current_loss_context += loss_decode(context_signal, curr_decoded.detach())

        # Log performance before backprop
        if np.mod(i,args['loss_interval']) == 0:
            loss_mat = np.mean(current_loss.detach().cpu().numpy(), axis=-1)
            writer.add_scalar('Losses/all', np.mean(loss_mat), i)
            writer.add_scalar('Losses/decode', current_loss_decode.detach().cpu().numpy(), i)
            writer.add_scalar('Losses/context', current_loss_context.detach().cpu().numpy(), i)
            for task_id in range(N_tasks):
                curr_task = batch_tasks==task_id
                if np.any(curr_task):
                    writer.add_scalar('Tasks/' + task_names[task_id], np.mean(loss_mat[curr_task]), i)
            if do_decode:
                for c in range(N_contexts):
                    curr_decoded = all_decoded[c].detach().cpu().numpy()
                    batch_ops = np.array(task_ops[c])[batch_tasks]
                    decoded_ops = np.argmax(curr_decoded, 1)
                    for op, operation in enumerate(operations):
                        is_op = batch_ops == op
                        # An operation may be absent from this context in the batch
                        # (e.g. with a narrow task range); 0/0 used to raise here
                        if not np.any(is_op):
                            continue
                        writer.add_scalar('Accuracies/' + str(c) + str(op) + '_' + operation.__name__, 
                                          np.mean(decoded_ops[is_op] == op), i)
                    writer.add_scalar('Accuracies/' + str(c) + '_' + 'All', 
                                      np.mean(decoded_ops == batch_ops), i)
            print(f'Step {i}, {datetime.now() - train_start}:' + f' loss {np.mean(loss_mat)}')

        # Do backprop
        current_loss = torch.mean(current_loss) + current_loss_decode + current_loss_context
        optimiser.zero_grad()
        current_loss.backward()
        optimiser.step()

        # Store model weights
        if np.mod(i,args['save_interval']) == 0:
            save_models()
    # Store final model weights
    save_models()

    # Do some analysis on the results
    analyse(os.path.join(base_dir, run_dir))
    
def analyse(model_dir):
    """Evaluate a trained run and save diagnostic figures.

    Reads ``args.json`` from ``model_dir`` and rebuilds the tasks (same seed and
    sampling as training) and the networks. Loads the saved weights on the cpu,
    evaluates 100 sequences per task, and writes figures to
    ``base_dir/run_dir/test``: Accuracy.png, Operations.png, Decoded.png (only if
    do_decode), 'Context embedding.png', 'Hidden state.png' and Trajectories.png.
    If the tensorboard CSV exports under ./train/training are present, it also
    writes ./train/training/LearningCurves.png.
    """
    # Load arguments from model_dir json file
    with open(os.path.join(model_dir, 'args.json')) as f:
        args = json.load(f)

    # Get directories for this training run; figures go into its 'test' sub-directory
    base_dir = args['base_dir']
    run_dir = args['run_dir']
    test_dir = os.path.join(base_dir, run_dir, 'test')

    # Get parameters for generated tasks
    N_tasks = args['n_tasks']
    N_contexts = args['n_contexts']
    N_steps = args['n_steps']
    N_dims = args['n_dims']

    # Recreate the training tasks: same seed and same sampling as in run()
    np.random.seed(args['seed'])
    tasks = [[np.random.choice(operations) for _ in range(N_contexts)] for _ in range(N_tasks)]
    task_names = ['-'.join([c.__name__[0:2] for c in t]) for t in tasks]

    # Determine whether to use the context RNN
    do_context = args['do_context']
    # Determine whether to force context RNN output to be categorical
    do_categorical = args['do_categorical']
    # Determine whether to decode context from hidden state 
    do_decode = args['do_decode']

    # Rebuild the networks with the same sizes as in run()
    N_layers = args['layers']
    N_hidden = N_dims*len(operations)+N_tasks
    # Single RNN: operates on hidden state and input for a given task
    actionRNN = RNN(N_dims+N_tasks, N_hidden, N_dims, N_layers)
    # Context RNN: slowly tracks context for each task
    contextRNN = RNN(N_tasks, N_tasks, len(operations) if do_categorical else N_tasks, 2, do_softmax=do_categorical)
    # Context decoder: takes action hidden state and decodes context
    contextDecoder = nn.Sequential(nn.Linear(N_hidden*N_layers, N_hidden*N_layers*2), 
                                   nn.ReLU(), 
                                   nn.Linear(N_hidden*N_layers*2,len(operations)),
                                   nn.Softmax(dim=-1))

    # Load trained weights onto the cpu
    cpu = torch.device('cpu')
    for model, file_name in [(actionRNN, 'actionRNN.zip'), (contextRNN, 'contextRNN.zip'),
                             (contextDecoder, 'contextDecoder.zip')]:
        model.load_state_dict(torch.load(os.path.join(base_dir, run_dir, file_name),
                                         map_location=cpu))

    # The RNNs default to cuda when available; make them send inputs to the cpu instead
    actionRNN.set_device('cpu')
    contextRNN.set_device('cpu')
    # Set them all to evaluation mode, instead of training
    actionRNN.eval()
    contextRNN.eval()
    contextDecoder.eval()

    # Get performance through time for each task
    error_through_time = np.zeros((N_tasks, N_contexts*N_steps))
    context_embedding = np.zeros((N_tasks, N_contexts, N_tasks))
    hidden_state = np.zeros((N_tasks, N_contexts, N_steps, N_layers*N_hidden))
    # decoded[context][operation]: list of decoder outputs for tasks with that operation there
    decoded = [[[] for _ in range(len(operations))] for _ in range(N_contexts)]
    for t, task in enumerate(tasks):
        # Create a batch of just the current task
        batch_tasks = [t] * 100
        # Generate input & output data
        input_data, output_data = get_model_input(tasks, batch_tasks, N_contexts, N_steps, N_dims)
        # Run model
        with torch.no_grad():
            model_in, model_target, model_out, model_hidden = get_model_output(
                contextRNN, actionRNN, input_data, output_data, tasks,
                do_context, do_categorical, batch_tasks, N_contexts, N_steps, N_tasks)
        # Mean absolute error at each timepoint, averaged over batch and dimensions
        error_through_time[t] = np.mean(np.mean(np.abs(output_data - model_out.numpy()), axis=0), axis=-1)
        # Hidden states averaged across the batch: timepoints x hidden
        mean_hidden = np.mean(model_hidden.numpy(), 0)
        # Get embedding of each context, and decode the context from model hidden state
        for curr_context in range(N_contexts):
            context_steps = slice(curr_context * N_steps, (curr_context + 1) * N_steps)
            # Collect the embedding in the current context (same in each timestep)
            context_embedding[t][curr_context] = model_in[0, curr_context * N_steps, :N_tasks]
            # Collect the batch-averaged hidden states in the current context
            hidden_state[t][curr_context] = mean_hidden[context_steps]
            # Decode current operation from model hidden state
            if do_decode:
                curr_decoded = contextDecoder(torch.mean(
                    model_hidden[:, context_steps, :], 1)).detach().cpu().numpy()
                decoded[curr_context][operations.index(task[curr_context])].append(curr_decoded)

    # Copy task to start and stop from inputs
    task_start = args['task_start']
    task_stop = args['task_stop'] if args['task_stop'] > -1 else N_tasks

    # Plot accuracy through time
    plt.figure(figsize=(6,5))
    cols = 6
    rows = int(np.ceil((task_stop - task_start) / cols))
    for row in range(rows):
        for col in range(cols):
            curr_plot = row*cols+col
            curr_task = curr_plot + task_start
            if curr_task < task_stop:
                plt.subplot(rows,cols,curr_plot+1)
                plt.plot(np.arange(N_contexts*N_steps), error_through_time[curr_task])
                plt.title(task_names[curr_task])
                plt.xlim([0,N_contexts*N_steps-1])
                plt.ylim([0,np.max(error_through_time[task_start:task_stop])])
    plt.tight_layout()
    plt.savefig(os.path.join(test_dir, 'Accuracy.png'))

    # Plot accuracy per context; the last column pools all contexts
    operation_error = np.full((len(operations), N_contexts + 1), np.nan)
    for o, operation in enumerate(operations):
        op_dat = [[] for _ in range(N_contexts + 1)]
        for c in range(N_contexts):
            for t in range(task_start, task_stop):
                if tasks[t][c] == operation:
                    op_dat[c].append(error_through_time[t][(c*N_steps):((c+1)*N_steps)])
                    op_dat[N_contexts].append(error_through_time[t][(c*N_steps):((c+1)*N_steps)])
        for c in range(N_contexts + 1):
            if len(op_dat[c]) > 0:
                operation_error[o, c] = np.mean(np.concatenate(op_dat[c]))
    plt.figure(figsize=(6,5))
    ax = plt.axes()
    data = {op.__name__ : op_dat for op, op_dat in zip(operations, operation_error)}
    bar_plot(ax, data, total_width=.8, single_width=.9)
    # One tick label per context plus 'All'; hard-coding three labels broke for n_contexts != 3
    ordinals = ['First', 'Second', 'Third']
    context_labels = [ordinals[c] if c < len(ordinals) else f'Context {c + 1}'
                      for c in range(N_contexts)]
    plt.xticks(list(range(N_contexts + 1)), context_labels + ['All'])
    plt.ylabel('Error')
    plt.xlabel('Operation context')
    plt.tight_layout()
    plt.savefig(os.path.join(test_dir, 'Operations.png'))

    # Plot decoding per operation per context
    if do_decode:
        plt.figure(figsize=(len(operations)*2, N_contexts*2))
        for dc, decoded_context in enumerate(decoded):
            for op_i, decoded_operation in enumerate(decoded_context):
                # Operation never occurs in this context: nothing to plot
                # (np.concatenate of an empty list raises)
                if len(decoded_operation) == 0:
                    continue
                plt.subplot(N_contexts, len(operations), dc*len(operations) + op_i + 1)
                curr_decoded = np.concatenate(decoded_operation,0)
                curr_class = np.argmax(curr_decoded,1)
                plt.errorbar(np.arange(len(operations)), np.mean(curr_decoded,0), np.std(curr_decoded,0)/np.sqrt(len(curr_decoded)))
                plt.title(operations[op_i].__name__ + str(dc) + ': ' + f'{sum(curr_class==op_i)/len(curr_class)*100:.0f} %')
                plt.xlim([-1,len(operations)])
                plt.ylim([0,1])
                plt.xticks([])
                plt.yticks([])
        plt.tight_layout()
        plt.savefig(os.path.join(test_dir, 'Decoded.png'))

    # Plot the similarities of the contexts across tasks
    for activations, name in zip([context_embedding, np.mean(hidden_state, 3)], ['Context embedding', 'Hidden state']):
        plt.figure(figsize=(12,10))
        for curr_context in range(N_contexts):
            context_operation = np.array([operations.index(t[curr_context]) for t in tasks[task_start:task_stop]])
            embedding_sim = np.corrcoef(activations[task_start:task_stop,curr_context,:])
            operation_sim = np.reshape(context_operation,(-1,1)) == np.reshape(context_operation,(1,-1))

            plt.subplot(2,N_contexts,curr_context + 1)
            plt.imshow(embedding_sim,vmin=-1, vmax=1)
            plt.xticks(np.arange(task_stop-task_start), task_names[task_start:task_stop],rotation=90)
            plt.yticks(np.arange(task_stop-task_start), task_names[task_start:task_stop])

            plt.title(name + ' step ' + str(curr_context))

            plt.subplot(2,N_contexts,N_contexts + curr_context + 1)
            plt.imshow(operation_sim,vmin=0, vmax=1)
            plt.xticks(np.arange(task_stop-task_start), task_names[task_start:task_stop],rotation=90)
            plt.yticks(np.arange(task_stop-task_start), task_names[task_start:task_stop])

            # Keep upper triangles of similarity matrices. Round to 8 decimals to avoid numerical errors
            sim_vals = np.stack([m[np.triu_indices(task_stop-task_start,1)] for m in [np.around(embedding_sim,8), operation_sim]])
            plt.title('Correlation: ' + '{0:.2f}'.format(np.corrcoef(sim_vals)[0,1]))
        plt.tight_layout()
        plt.savefig(os.path.join(test_dir, name + '.png'))

    # Plot trajectories (tasks x timepoints x hidden dims) on first three principal components
    trajectories = np.reshape(hidden_state[task_start:task_stop], 
                              [task_stop-task_start, -1, N_layers*N_hidden])
    # Make a big matrix of all trajectories: (tasks * timepoints) x hidden dims
    pc_dat = np.reshape(trajectories, [-1, N_layers*N_hidden])
    # Do PCA. cov is symmetric, so eigh gives real eigenpairs in ascending order;
    # reorder to descending variance (np.linalg.eig does not sort, so its first
    # three columns were not necessarily the top three components)
    y = pc_dat - np.mean(pc_dat,axis=0)
    cov = np.matmul(y.transpose(), y)
    w, v = np.linalg.eigh(cov)
    v = v[:, np.argsort(w)[::-1]]
    # Project all trajectories on eigenvectors
    pc_trajectories = np.matmul(trajectories, v)
    # Plot only the first three dims of each
    plt.figure(figsize=(10,10))
    ax = plt.axes(projection ='3d')
    cm = plt.get_cmap('gist_rainbow')
    ax.set_prop_cycle(color=[cm(1.*i/(task_stop-task_start)) for i in range(task_stop-task_start)])
    for t_i, t in enumerate(pc_trajectories):
        curr_line = ax.plot(t[:, 0], t[:, 1], t[:, 2])
        ax.text(t[-1,0], t[-1, 1], t[-1, 2], task_names[task_start + t_i], zdir=None, color=curr_line[0].get_c())
        for c_i, c in enumerate(tasks[task_start + t_i]):
            ax.text(t[c_i*N_steps,0], t[c_i*N_steps, 1], t[c_i*N_steps, 2], c.__name__[:2], zdir=None, color=curr_line[0].get_c())
    plt.tight_layout()
    plt.savefig(os.path.join(test_dir, 'Trajectories.png'))

    # Plot the training trajectories, exported from tensorboard. These are hard-coded
    # exports of specific earlier runs; skip (instead of crashing) when they are absent.
    curve_names = ['no_context', 'do_context', 'true_context',
                   'no_context_2', 'do_context_2', 'true_context_2']
    curve_files = {curve: f"./train/training/{curve}_tensorboard.csv" for curve in curve_names}
    if not all(os.path.exists(path) for path in curve_files.values()):
        print('Tensorboard CSV exports not found in ./train/training; skipping learning curves')
        return
    curves = {curve: np.loadtxt(path, delimiter=',', skiprows=1) for curve, path in curve_files.items()}
    plt.figure(figsize=(12,4))
    for i, ylim in enumerate([2.5, 0.25, 0.05]):
        plt.subplot(1,3,i+1)
        # Original tasks: solid lines
        for curve in curve_names[:3]:
            plt.plot(curves[curve][:,1], curves[curve][:,2])
        # New tasks: restart the colour cycle so each condition keeps its colour, dotted
        plt.gca().set_prop_cycle(None)
        for curve in curve_names[3:]:
            plt.plot(curves[curve][:,1], curves[curve][:,2], ':')
        plt.ylim([0, ylim])
        plt.xlabel('Training iteration')
        if i == 0:
            plt.legend(['No context', 'Context RNN', 'Ground truth context',
                        'No context new tasks', 'Context RNN new tasks', 'Ground truth context new tasks'])
            plt.ylabel('Error')
    plt.tight_layout()
    plt.savefig("./train/training/LearningCurves.png")
    
def decode(model_dir, *, n_train=100, n_validation=10, train_fraction=0.6, n_eval=100):
    """Train decoders that read each context's operation from the action RNN hidden state.

    Loads ``args.json`` from ``model_dir``, rebuilds the task set with the
    stored seed, and loads the trained ``actionRNN``/``contextRNN`` weights
    from ``<base_dir>/<run_dir>`` as recorded in those args. Then, for each of
    ``n_validation`` random train/test splits of the available tasks, a fresh
    two-layer MLP (``contextDecoder``) is trained to predict the operation of
    every context from the action RNN hidden state averaged over that
    context's time steps. Only the decoder is optimised; the RNN weights are
    left untouched.

    Outputs are written to a new timestamped ``decode...`` subdirectory of the
    run directory: tensorboard logs, ``decoder_output.npy`` and the figures
    ``Decoded.png`` and ``Confusion.png``.

    Parameters
    ----------
    model_dir : str
        Directory containing the ``args.json`` of the training run.
    n_train : int, optional
        Decoder training iterations per split (default 100; the decoder
        trains quickly).
    n_validation : int, optional
        Number of random train/test task splits (default 10).
    train_fraction : float, optional
        Fraction of available tasks used to train the decoder (default 0.6).
    n_eval : int, optional
        Batch size generated per task when evaluating the decoder (default 100).

    Returns
    -------
    numpy.ndarray
        ``decoder_output`` of shape
        ``(n_validation, 2, n_contexts, n_operations, n_operations)``.
        Axis 1 is train (0) / test (1), axis 3 the true operation, and
        axis 4 the mean decoded probability of each operation. NaN marks
        operations that never occurred in that context within that split.
    """
    # Load arguments from model_dir json file
    with open(os.path.join(model_dir, 'args.json')) as f:
        args = json.load(f)

    # Directories of the training run that produced the trained model
    base_dir = args['base_dir']
    run_dir = args['run_dir']

    # Get parameters for generated tasks
    N_tasks = args['n_tasks']
    N_contexts = args['n_contexts']
    N_steps = args['n_steps']
    N_dims = args['n_dims']
    N_ops = len(operations)

    # Re-create the tasks with the training seed so they match the ones the model was trained on
    np.random.seed(args['seed'])
    tasks = [[np.random.choice(operations) for _ in range(N_contexts)] for _ in range(N_tasks)]
    # task_ops[c, t]: index into operations of the operation in context c of task t
    task_ops = np.array([[operations.index(t[c]) for t in tasks] for c in range(N_contexts)],
                        dtype=int)
    tasks_avail = np.arange(args['task_start'],
                            N_tasks if args['task_stop'] == -1 else args['task_stop'])

    # Determine whether to use the context RNN
    do_context = args['do_context']
    # Whether the context RNN output is categorical (False when absent from args.json)
    do_categorical = args.get('do_categorical', False)

    # Network sizes must match the training run for the saved weights to load
    N_layers = args['layers']
    N_hidden = N_dims*N_ops + N_tasks

    # Find if gpu is available
    gpu = torch.cuda.is_available()
    device = torch.device('cuda:0' if gpu else 'cpu')
    print('Is gpu available?', gpu, ': device = ', device)

    # Action RNN: operates on hidden state and input for a given task
    actionRNN = RNN(N_dims+N_tasks, N_hidden, N_dims, N_layers)
    # Context RNN: slowly tracks context for each task
    contextRNN = RNN(N_tasks, N_tasks, N_ops if do_categorical else N_tasks, 2,
                     do_softmax=do_categorical)

    # Send models to gpu if available (and if not, it stays on the cpu)
    actionRNN.to(device)
    contextRNN.to(device)

    # Load trained weights on the selected device
    actionRNN.load_state_dict(torch.load(os.path.join(base_dir, run_dir, 'actionRNN.zip'),
                                         map_location=device))
    contextRNN.load_state_dict(torch.load(os.path.join(base_dir, run_dir, 'contextRNN.zip'),
                                          map_location=device))

    # Outputs of this decoding run go in a timestamped subdirectory of the training run
    run_dir = os.path.join(run_dir, 'decode' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    make_dirs(base_dir, run_dir, ['tensorboard', 'test'])
    # Start tensorboard with tensorboard --logdir=train
    writer = SummaryWriter(os.path.join(base_dir, run_dir, 'tensorboard'))

    N_batch = args['batch_size']

    # Store context decoder outputs (second dimension is train vs test)
    decoder_output = np.full((n_validation, 2, N_contexts, N_ops, N_ops), np.nan)

    # Resample train/test tasks several times and retrain a fresh decoder for each split
    for val in range(n_validation):
        print(f'--- START VAL {val} / {n_validation} ---')
        tasks_train = np.random.choice(tasks_avail,
                                       size=int(train_fraction*len(tasks_avail)), replace=False)
        tasks_test = [i for i in tasks_avail if i not in tasks_train]

        # Context decoder: action hidden state -> operation LOGITS. No softmax layer here,
        # because nn.CrossEntropyLoss applies log-softmax itself; probabilities are
        # obtained with torch.softmax where they are stored.
        contextDecoder = nn.Sequential(nn.Linear(N_hidden*N_layers, N_hidden*N_layers*2),
                                       nn.ReLU(),
                                       nn.Linear(N_hidden*N_layers*2, N_ops))
        contextDecoder.to(device)
        loss_decode = nn.CrossEntropyLoss()
        optimiser = torch.optim.Adam(contextDecoder.parameters())

        train_start = datetime.now()
        for i in range(n_train):
            # Sample current task for each batch
            batch_tasks = np.random.choice(tasks_train, size=N_batch, replace=True)
            input_data, output_data = get_model_input(tasks, batch_tasks, N_contexts, N_steps, N_dims)

            # The RNNs are never optimised here, so don't build their autograd graph
            # (otherwise their .grad buffers accumulate on every backward pass)
            with torch.no_grad():
                _, _, _, model_hidden = get_model_output(
                    contextRNN, actionRNN, input_data, output_data, tasks,
                    do_context, do_categorical, batch_tasks, N_contexts, N_steps, N_tasks)

            current_loss_decode = torch.zeros(1, device=device)
            all_logits = []
            for curr_context in range(N_contexts):
                # Class-index target: operation of this context for every batch task
                true_ops = torch.as_tensor(task_ops[curr_context, batch_tasks],
                                           dtype=torch.long, device=device)
                # Decode from the hidden state averaged over this context's time steps
                curr_logits = contextDecoder(torch.mean(model_hidden[
                    :, (curr_context*N_steps):((curr_context+1)*N_steps), :], 1))
                all_logits.append(curr_logits)
                current_loss_decode += loss_decode(curr_logits, true_ops)

            # Log performance before backprop
            if i % args['loss_interval'] == 0:
                writer.add_scalar('Losses/decode', current_loss_decode.item(), i)
                for c in range(N_contexts):
                    curr_ops = task_ops[c, batch_tasks]
                    curr_pred = np.argmax(all_logits[c].detach().cpu().numpy(), 1)
                    for op, operation in enumerate(operations):
                        is_op = curr_ops == op
                        if np.any(is_op):
                            writer.add_scalar('Accuracies/' + str(c) + str(op) + '_' + operation.__name__,
                                              np.mean(curr_pred[is_op] == op), i)
                    writer.add_scalar('Accuracies/' + str(c) + '_' + 'All',
                                      np.mean(curr_pred == curr_ops), i)
                print(f'Step {i}, {datetime.now() - train_start}:'
                      + f' loss {current_loss_decode.item()}')

            # Do backprop (only the decoder parameters are in the optimiser)
            optimiser.zero_grad()
            current_loss_decode.backward()
            optimiser.step()

        # At the end of training, collect decoder performance on train and test tasks
        for is_test, curr_tasks in enumerate([tasks_train, tasks_test]):
            # decoded[c][op]: decoded probabilities for context c of tasks whose true op is op
            decoded = [[[] for _ in range(N_ops)] for _ in range(N_contexts)]
            for t in curr_tasks:
                # Create a batch of just the current task
                batch_tasks = [t] * n_eval
                input_data, output_data = get_model_input(tasks, batch_tasks, N_contexts, N_steps, N_dims)
                with torch.no_grad():
                    _, _, _, model_hidden = get_model_output(
                        contextRNN, actionRNN, input_data, output_data, tasks,
                        do_context, do_categorical, batch_tasks, N_contexts, N_steps, N_tasks)
                    for curr_context in range(N_contexts):
                        curr_logits = contextDecoder(torch.mean(model_hidden[
                            :, (curr_context*N_steps):((curr_context+1)*N_steps), :], 1))
                        decoded[curr_context][task_ops[curr_context, t]].append(
                            torch.softmax(curr_logits, dim=-1).cpu().numpy())
            # Average across tasks and store
            for dc, decoded_context in enumerate(decoded):
                for do, decoded_operation in enumerate(decoded_context):
                    if decoded_operation:
                        decoder_output[val, is_test, dc, do, :] = np.mean(
                            np.concatenate(decoded_operation, 0), 0)

    # Flush and release the tensorboard event file
    writer.close()

    # Save all decoder outputs
    np.save(os.path.join(base_dir, run_dir, 'decoder_output.npy'), decoder_output)

    op_names = [op.__name__ for op in operations]
    _decode_plot_per_operation(decoder_output, op_names,
                               os.path.join(base_dir, run_dir, 'Decoded.png'))
    _decode_plot_confusion(decoder_output, op_names,
                           os.path.join(base_dir, run_dir, 'Confusion.png'))
    return decoder_output


def _decode_plot_per_operation(decoder_output, op_names, save_path):
    """Plot decoded operation probabilities per (train/test, context, true operation).

    decoder_output : array of shape (n_validation, 2, n_contexts, n_ops, n_ops);
        NaN rows (operation absent in a split) are excluded.
    op_names : list of str, operation names in decoder output order.
    save_path : str, file the figure is saved to.

    Each subplot shows the mean +- standard error across included validation
    splits. The title gives the percentage of those splits whose highest
    decoded probability is the true operation.
    """
    N_contexts, N_ops = decoder_output.shape[2], decoder_output.shape[3]
    plt.figure(figsize=(N_ops*2, N_contexts*2))
    for is_test in [0, 1]:
        for dc in range(N_contexts):
            for do in range(N_ops):
                curr_decoded = decoder_output[:, is_test, dc, do, :]
                # Keep only splits in which this operation occurred in this context
                curr_include = ~np.any(np.isnan(curr_decoded), axis=1)
                if not np.any(curr_include):
                    continue
                curr_decoded = curr_decoded[curr_include]
                curr_class = np.argmax(curr_decoded, 1)
                plt.subplot(N_contexts*2, N_ops, N_contexts*N_ops*is_test + dc*N_ops + do + 1)
                plt.errorbar(np.arange(N_ops), np.mean(curr_decoded, 0),
                             np.std(curr_decoded, 0)/np.sqrt(len(curr_decoded)))
                plt.title(['Train', 'Test'][is_test] + ', ' + op_names[do] + str(dc)
                          + ': ' + f'{np.mean(curr_class == do)*100:.0f} %')
                plt.xlim([-1, N_ops])
                plt.ylim([0, 1])
                plt.xticks([])
                plt.yticks([])
    plt.tight_layout()
    plt.savefig(save_path)


def _decode_plot_confusion(decoder_output, op_names, save_path):
    """Plot true-vs-decoded operation confusion matrices, averaged over validation splits.

    One panel per (train/test, context); rows are true operations, columns
    decoded operations. NaN entries (operation absent in a split) are ignored
    in the average.
    """
    N_contexts, N_ops = decoder_output.shape[2], decoder_output.shape[3]
    plt.figure(figsize=(N_ops*2, N_contexts*2))
    for is_test in [0, 1]:
        for dc in range(N_contexts):
            plt.subplot(2, N_contexts, N_contexts*is_test + dc + 1)
            curr_map = np.nanmean(decoder_output[:, is_test, dc, :, :], axis=0)
            plt.imshow(curr_map, vmin=0, vmax=1)
            plt.xticks(list(range(N_ops)), op_names)
            plt.xlabel('Decoded operation')
            plt.yticks(list(range(N_ops)), op_names)
            plt.ylabel('True operation')
            plt.title(['Train', 'Test'][is_test] + ', context ' + str(dc))
    plt.tight_layout()
    plt.savefig(save_path)

if __name__ == '__main__':
    # Script entry point: delegate to run(), defined elsewhere in this file.
    # decode() is not invoked here; call it explicitly with a model directory.
    run()
