#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 15 11:19:11 2024

@author: XXXX
"""

import models
import data
import maze

import numpy as np
import torch
import torch.nn as nn
# For obscure reasons this needs pip install tensorboard==2.11.0! 
# The latest version, 2.13.0, crashes on import
from torch.utils.tensorboard import SummaryWriter 
from torch.utils.data import DataLoader
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from datetime import datetime
import os
import sys
import importlib.util
import json
import glob
import itertools

import argparse

import run # Import self? Yep. Weird workaround to force reloading all functions below

# Parse arguments from command line interface.
# The module-level parser is read by train_rule / train_motor through
# prepare_args(vars(parser.parse_args()), arg_dict), so every option below can
# also be overridden by a key of the same name (dashes -> underscores) in arg_dict.
parser = argparse.ArgumentParser(description='Compositional-Metalearning')

# Logging parameters
# Number of training iterations: train_rule asks the sampler for this many
# batches (n_batches) and train_motor loops range(train_epochs), so each
# "epoch" is one optimiser step on one batch.
parser.add_argument('--train-epochs', type=int, default=int(1e5), 
                    help='number of training epochs (default: 1e5)')
# Iterations between writes of model.zip (and optimiser.zip in train_motor)
parser.add_argument('--save-interval', type=int, default=1000, 
                    help='interval for saving model parameters (default: 1000)')
# Iterations between log_training calls
parser.add_argument('--loss-interval', type=int, default=10, 
                    help='interval for loss reporting (default: 10)')

# File parameters
# Outputs of a run go to <base-dir>/<run-dir>/; prepare_args fills run-dir
# with a timestamp when it is None.
parser.add_argument('--base-dir', type=str, default='./train', 
                    help='base directory to store results in (default: ./train)')
parser.add_argument('--run-dir', type=str, default=None, 
                    help='directory for this particular run (default: [timestamp])')
# NOTE(review): resume-dir is not read in the visible code; presumably
# consumed by prepare_training — confirm.
parser.add_argument('--resume-dir', type=str, default=None, 
                    help='dir to resume training from (default: None)')
# analyse_rule re-roots this to the current working directory, keeping only
# its last path component.
parser.add_argument('--data-dir', type=str, default='./data', 
                    help='base directory to store data in (default: ./data)')

# Task parameters
# NOTE(review): these are not read directly in the visible code; presumably
# consumed by data.TaskSequenceDataset / maze.ContinuousMaze via args — confirm.
parser.add_argument('--n-contexts', type=int, default=3, 
                    help='number of contexts in each task (default: 3)')
parser.add_argument('--n-steps', type=int, default=-1, 
                    help='total number of steps (default: -1, meaning n_contexts * mean(durations))')
parser.add_argument('--n-dims', type=int, default=6, 
                    help='dimension of input and output vectors (default: 6)')
parser.add_argument('--task-train', type=float, default=0.6, 
                    help='fraction of tasks used for training (default: 0.6)')
parser.add_argument('--task-choose', type=int, default=-1, 
                    help='pick one particular task to train (default: -1, ignore)')
# The analysis functions set do_test explicitly to switch between train and test tasks
parser.add_argument('--do-test', dest='do_test', default=False, action='store_true',
                    help='train on test tasks (default: false)')
parser.add_argument('--sparse-feedback', dest='sparse_feedback', default=False, action='store_true',
                    help='train with sparse feedback (default: false)')

# Training parameters
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate (default: 1e-3)')
# The scheduler is stepped every int(train_epochs / lr_steps) iterations
parser.add_argument('--lr-steps', type=float, default=1,
                    help='number of steps for decreasing learning rate (default: 1)')
parser.add_argument('--batch-size', type=int, default=4096,
                    help='batch size (default: 4096)')
# NOTE(review): the freeze-* flags are not read in the visible code;
# presumably handled in prepare_training — confirm.
parser.add_argument('--freeze-modules', dest='freeze_modules', default=False, action='store_true',
                    help='freeze module weights (default: false)')
parser.add_argument('--freeze-gating', dest='freeze_gating', default=False, action='store_true',
                    help='freeze gating weights (default: false)')
parser.add_argument('--freeze-boundaries', dest='freeze_boundaries', default=False, action='store_true',
                    help='freeze boundary weights (default: false)')
# Passed as N to model.nll / model.nll_guided during training
parser.add_argument('--n-particles', type=int, default=100,
                    help='number of particles in particle filter (default: 100)')
# Max gradient norm for clip_grad_norm_; values <= 0 disable clipping
parser.add_argument('--grad-clip', type=float, default=-1,
                    help='gradient clipping (default: -1, no clipping)')

# Model parameters
# In train_rule, model 0 is trained on -likelihood from model.nll and
# model 2 on a summed MSE of its action output.
parser.add_argument('--model', type=int, default=0, 
                    help='model type (default: 0): ' + 
                    '0 = rule learner, ' + 
                    '1 = motor learner, ' +
                    '2 = rule control')
# layers / hidden / n-blocks form the network_params list built in get_model
parser.add_argument('--layers', type=int, default=1,
                    help='GRU layers (default: 1)')
parser.add_argument('--hidden', type=int, default=500,
                    help='GRU hidden units (default: 500)')
# Used for torch.manual_seed in get_model
parser.add_argument('--seed', type=int, default=0, 
                    help='random seed for reproducible model behaviour (default: 0)')
parser.add_argument('--n-blocks', type=int, default=6, 
                    help='number of independent RNNs for modular networks (default: 6)')
# rank / use-gru / weight-init / sigma-init / true-context / no-context are
# forwarded as keyword arguments to models.HMMMoE in get_model
parser.add_argument('--rank', type=int, default=0, 
                    help='rank of each independent RNN (default: 0 for full rank RNN)')
parser.add_argument('--use-gru', dest='use_gru', default=False, action='store_true',
                    help='use gru instead of rnn (default: false)')
parser.add_argument('--weight-init', type=float, default=1,
                    help='recurrent network weight initialisation factor (default: 1)')
parser.add_argument('--sigma-init', type=float, default=0.1,
                    help='initial value of likelihood sigma (default: 0.1)')
parser.add_argument('--true-context', dest='true_context', default=False, action='store_true',
                    help='provide ground truth context signal (default: false)')
parser.add_argument('--no-context', dest='no_context', default=False, action='store_true',
                    help='disable context network, for flat transitions (default: false)')
# NOTE(review): task-id is not read in the visible part of the file — confirm its consumer.
parser.add_argument('--task-id', dest='task_id', default=False, action='store_true',
                    help='provide task id as input for control model (default: false)')

# Run training
def train_rule(arg_dict=None):
    """Train a rule-learning (model 0) or rule-control (model 2) network.

    Command line arguments are parsed with the module-level ``parser``. To allow
    running from an interactive Python session, entries of ``arg_dict``
    overwrite the parsed values (see ``prepare_args``), which also fills in a
    timestamped ``run_dir`` when none is given.

    Each iteration draws one batch from ``data.TaskSequenceDataset`` through a
    ``data.UniformSampler`` configured for ``train_epochs`` batches. Model
    weights are written to ``<base_dir>/<run_dir>/model.zip`` every
    ``save_interval`` iterations and once more when training finishes.

    Parameters
    ----------
    arg_dict : dict, optional
        Overrides for the command line arguments.
    """
    args = prepare_args(vars(parser.parse_args()), arg_dict)

    # Reload models module, in case you've been analysing an existing model
    load_module()

    # Prepare dataset, and a sampler/dataloader that yields train_epochs batches
    dataset = data.TaskSequenceDataset(args)
    sampler = data.UniformSampler(dataset, n_batches=args['train_epochs'], batch_size=args['batch_size'])
    # Pin memory, combined with non-blocking below, will make data transfer to gpu quicker
    dataloader = DataLoader(dataset, batch_sampler=sampler, pin_memory=True)
    model = get_model(args, dataset)

    # Prepare the model for training and initialise all required objects
    device, optimiser, scheduler, writer = prepare_training(model, args)

    # Model 0 is trained on its own negative log likelihood; every other model
    # type uses a per-element squared error on its action output. The loss is
    # defined for every non-zero model type because it is used for all of them.
    loss = None if args['model'] == 0 else nn.MSELoss(reduction='none')

    # Iterations between learning rate decreases. Clamped to at least 1 so the
    # modulo below never divides by zero when lr_steps > train_epochs (in that
    # case the scheduler steps every iteration).
    lr_period = max(1, int(args['train_epochs'] / args['lr_steps']))
    model_path = os.path.join(args['base_dir'], args['run_dir'], 'model.zip')

    # Get start time for training, then get going
    train_start = datetime.now()
    for i, model_in in enumerate(dataloader):

        # Dataloader keeps everything on cpu. Send current batch to device;
        # the 'id' entry is kept on cpu as a numpy array
        for k, v in model_in.items():
            # For non-blocking rationale: see https://stackoverflow.com/a/55564072
            model_in[k] = v.numpy() if k == 'id' else v.to(device, non_blocking=True)

        # Run through model
        if args['model'] == 0:
            model_out = model.nll(model_in['input'], model_in['context'], model_in['output'],
                                  N=args['n_particles'])
        else:
            model_out = model(model_in['input'], model_in['task'])

        # Only time steps with a non-zero context signal contribute to the loss
        valid_loss = torch.sum(model_in['context'], axis=-1) > 0
        weight_loss = torch.ones_like(valid_loss)
        # Get loss for included timesteps
        if args['model'] == 0:
            current_loss = -model_out['likelihood']
        else:
            current_loss = torch.sum(loss(model_out['action_out'], model_in['output']), -1)
        current_loss = current_loss[valid_loss] * weight_loss[valid_loss]

        # Log performance before backprop
        if np.mod(i, args['loss_interval']) == 0:
            log_training(i, model, args, dataset,
                         writer, train_start,
                         model_in, model_out,
                         current_loss, valid_loss)

        # Do backprop on the mean loss over included timesteps
        current_loss = torch.mean(current_loss)
        optimiser.zero_grad()
        current_loss.backward()
        # Do gradient clipping
        if args['grad_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args['grad_clip'])
        optimiser.step()

        # Step learning rate scheduler
        if np.mod(i + 1, lr_period) == 0:
            scheduler.step()
            print('- Stepping learning rate to ' + str(scheduler.get_last_lr()))

        # Store model weights
        if np.mod(i, args['save_interval']) == 0:
            torch.save(model.state_dict(), model_path)

    # Store final model weights
    torch.save(model.state_dict(), model_path)

# Sample a batch of task indices for motor training
def sample_task_batch(task_start, task_stop, batch_size):
    """Return ``batch_size`` task indices from ``[task_start, task_stop)``, sampled near-uniformly.

    A random permutation of all tasks is tiled until it covers the batch and
    then truncated. Every task therefore appears either
    ``batch_size // n_tasks`` times or once more. Randomness comes from numpy's
    global random state (``np.random.permutation``).

    Raises
    ------
    ValueError
        If the task range is empty.
    """
    n_tasks = task_stop - task_start
    if n_tasks <= 0:
        raise ValueError('empty task range [' + str(task_start) + ', ' + str(task_stop) + ')')
    task_list = np.arange(task_start, task_stop)[np.random.permutation(n_tasks)]
    # Ceiling division: enough copies of the permutation to fill the whole batch
    n_tiles = int(-(-batch_size // n_tasks))
    return np.tile(task_list, n_tiles)[:batch_size]


# Run training
def train_motor(arg_dict=None):
    """Train a motor-learning model in the ``maze.ContinuousMaze`` environment.

    Command line arguments are parsed with the module-level ``parser``, and
    entries of ``arg_dict`` overwrite them (see ``prepare_args``). Each
    iteration samples ``batch_size`` tasks (see ``sample_task_batch``), builds
    the model input with ``env.get_model_input`` and minimises the negative
    likelihood from ``model.nll_guided`` on time steps with non-zero context.

    Model and optimiser state are written to ``<base_dir>/<run_dir>/model.zip``
    and ``optimiser.zip`` every ``save_interval`` iterations and once more when
    training finishes.

    Parameters
    ----------
    arg_dict : dict, optional
        Overrides for the command line arguments.
    """
    args = prepare_args(vars(parser.parse_args()), arg_dict)

    # Reload models module, in case you've been analysing an existing model
    load_module()

    # Set up environment and model
    env = maze.ContinuousMaze(args)
    model = get_model(args, env)

    # Do all the necessary training prep
    device, optimiser, scheduler, writer = prepare_training(model, args)

    # Iterations between learning rate decreases, clamped to at least 1 so the
    # modulo below never divides by zero when lr_steps > train_epochs
    lr_period = max(1, int(args['train_epochs'] / args['lr_steps']))
    run_path = os.path.join(args['base_dir'], args['run_dir'])
    model_path = os.path.join(run_path, 'model.zip')
    optimiser_path = os.path.join(run_path, 'optimiser.zip')

    # Get start time for training, then get going
    train_start = datetime.now()
    for i in range(args['train_epochs']):

        # Sample tasks approximately uniformly, then build the input dictionary
        tasks = sample_task_batch(env.task_start, env.task_stop, args['batch_size'])
        model_in = env.get_model_input(tasks)

        # Run model episode
        model_out = model.nll_guided(env, model_in['context'], model_in['output'], N=args['n_particles'])

        # Determine which time steps to include from context feedback signal
        valid_loss = torch.sum(model_in['context'], axis=-1) > 0
        weight_loss = torch.ones_like(valid_loss)
        # Get loss for included timesteps
        current_loss = -model_out['likelihood'][valid_loss] * weight_loss[valid_loss]

        # Log performance before backprop
        if np.mod(i, args['loss_interval']) == 0:
            log_training(i, model, args, env,
                         writer, train_start,
                         model_in, model_out,
                         current_loss, valid_loss)

        # Do backprop on the mean loss over included timesteps
        current_loss = torch.mean(current_loss)
        optimiser.zero_grad()
        current_loss.backward()
        # Do gradient clipping
        if args['grad_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args['grad_clip'])
        optimiser.step()

        # Step learning rate scheduler
        if np.mod(i + 1, lr_period) == 0:
            scheduler.step()
            print('- Stepping learning rate to ' + str(scheduler.get_last_lr()))

        # Store model and optimiser state
        if np.mod(i, args['save_interval']) == 0:
            torch.save(model.state_dict(), model_path)
            torch.save(optimiser.state_dict(), optimiser_path)

    # Store final model and optimiser state, so the optimiser checkpoint is not stale
    torch.save(model.state_dict(), model_path)
    torch.save(optimiser.state_dict(), optimiser_path)

# Analyse trained model
def analyse_rule(model_dir):
    """Plot what a trained rule-learning model has learned, then open pdb.

    Loads ``args.json`` and ``model.zip`` from ``model_dir``. If ``model_dir``
    contains a ``source`` directory, that saved copy of the models module is
    loaded. The data directory is then re-rooted to the current working
    directory, and three figures are drawn:

    1. per module, a cyclic-shift ground-truth matrix next to the module's
       learned response to each one-hot input;
    2. ground-truth versus learned transition statistics following a module
       pattern of increasing length, over all training tasks;
    3. for the first test task: target output, traced-back model output,
       provided/true contexts and the filtering module posterior.

    Finally calls ``pdb.set_trace()`` so figures and locals can be inspected.

    Parameters
    ----------
    model_dir : str
        Directory of a training run (containing ``args.json`` and ``model.zip``).
    """
    # Import matplotlib here to avoid cluster issues
    from matplotlib import pyplot as plt

    # Set unit size, which determines scale of all plots
    unit_size = 2

    # Load arguments from model_dir json file
    with open(os.path.join(model_dir, 'args.json')) as f:
        args = json.load(f)

    # First plot training tasks
    args['do_test'] = False

    # This is a bit ugly, but avoids using a different model for training and analysis:
    # Replace the loaded "models" module by the source file saved in training - if it exists
    run.load_module((model_dir + '/source') if os.path.isdir(model_dir + '/source') else None)

    # Look for the data in the current working directory, keeping only the last
    # path component of the data_dir used during training
    args['data_dir'] = os.path.join(
        os.getcwd(), os.path.basename(os.path.normpath(args['data_dir'])))

    # Specify data object to use
    dataset = data.TaskSequenceDataset(args)

    # Get model
    model = run.get_model(args, dataset)

    # Keep model on cpu
    device = 'cpu'
    model.to(device)
    model.set_device(device)
    # Load model trained parameters
    model.load_state_dict(torch.load(os.path.join(model_dir, 'model.zip'),
                                     map_location=torch.device(device), weights_only=True))
    # Set model to evaluation mode
    model.eval()
    model.set_hard_sampling(True)

    # Probe each module: call it on the first one-hot input from action_h0, call
    # it again on a zero input with that result as second argument, and read out
    # the action. (Assumes rnn(input, hidden) -> hidden; verify against models.)
    result = torch.concatenate([model.action_output(rnn(torch.zeros((1, dataset.n_dims_in)),
                                                        rnn(torch.eye(dataset.n_dims_in)[0].unsqueeze(0), model.action_h0.unsqueeze(0))))
                                for rnn in model.action_rnn])
    # Order modules by the output dimension where their probe response peaks,
    # so they line up with the true (shift) operations
    reordered = np.argsort(np.argmax(result.detach().cpu().numpy(), -1))

    # Plot the full matrix for each operation
    fig = plt.figure(figsize=(unit_size*4, unit_size*2))
    for i, o in enumerate(reordered):
        # Select current module
        rnn = model.action_rnn[o]
        # Plot ground truth operation: each one-hot vector cyclically shifted by i
        plt.subplot(2, model.n_modules, i + 1)
        plt.imshow(np.stack([np.roll(v, i) for v in np.eye(dataset.n_dims_in)]))
        plt.xticks([])
        plt.yticks([])
        if i == 0:
            plt.title('True operation', ha='left')
        # Plot matrix operation for this module: one row per one-hot input
        plt.subplot(2, model.n_modules, model.n_modules + i + 1)
        result = torch.concatenate([model.action_output(rnn(torch.zeros((1, dataset.n_dims_in)),
                                                            rnn(input_vec.unsqueeze(0), model.action_h0.unsqueeze(0))))
                                    for input_vec in torch.eye(dataset.n_dims_in)])
        plt.imshow(result.detach().cpu().numpy())
        if i == 0:
            plt.title('Learned module', ha='left')
        plt.xticks([])
        plt.yticks([])
    plt.subplots_adjust(top=0.89,
                        bottom=0.0,
                        left=0.02,
                        right=0.985,
                        hspace=0.2,
                        wspace=0.095)

    # Now run the model on *all* training tasks
    with torch.no_grad():
        model_in = dataset.get_model_input(np.arange(dataset.task_start, dataset.task_stop))
        model_out = model.nll(model_in[0], model_in[1], model_in[3], N=100)
    ancestors = model_out['ancestor'].detach().cpu().numpy()
    activations = model_out['activation'].detach().cpu().numpy()
    predictions = model_out['context_out'].detach().cpu().numpy()

    # Collect state history for each particle at each timestep
    # Dimensions particles x timesteps x modules x timesteps
    # The second dimension indexes the history, the final dimension the timestep
    # So history of particle at timestep for task: history[task][particle, :, :, timestep]
    T = ancestors[0].shape[1]
    history = []
    for task_activations, task_ancestors in zip(activations, ancestors):
        task_history = np.zeros(list(task_activations.shape) + [T])
        for t in range(1, T):
            # Copy over history *up to* previous timestep from the ancestor of current particle
            task_history[:, :t, :, t] = task_history[task_ancestors[:, t-1], :t, :, t-1]
            # Copy over resampled activation from previous timestep for current particle
            task_history[:, t, :, t] = task_activations[:, t-1, :]
        history.append(task_history)

    # Create figure to plot time-dependent transition matrix
    fig = plt.figure(figsize=(unit_size*4, unit_size*2))
    # Set pattern to look for
    p = [0, 1]  # p = [0,1]: find pattern after switch. p = [1]: find pattern everywhere
    for col in range(5):
        # Find average next step given current pattern
        next_steps_data = np.zeros((dataset.n_operations, dataset.n_operations))
        next_steps_model = np.zeros((dataset.n_operations, dataset.n_operations))
        for i in range(dataset.n_operations):
            # Collect transitions in ground truth context of data
            next_step_data = []
            for c in dataset.task_contexts[dataset.task_start:dataset.task_stop].numpy():
                zero_padded = np.concatenate([np.zeros((1, dataset.n_operations)), c], axis=0)
                # t indexes the step right after the pattern window [t-len(p), t)
                for t in range(len(p), len(zero_padded)):
                    if np.all(zero_padded[(t-len(p)):t, i] == p):
                        if np.any(zero_padded[t]):
                            next_step_data.append(zero_padded[t])
            next_steps_data[i] = np.mean(np.stack(next_step_data, 0), 0) if len(next_step_data) > 0 else 0
            # Collect transitions in model output
            next_step_model = []
            for task_history, task_predictions in zip(history, predictions):
                for particle_history, particle_predictions in zip(task_history, task_predictions):
                    # Reorder modules
                    particle_history = particle_history[:, reordered, :]
                    particle_predictions = particle_predictions[:, reordered]
                    # Binarise history: 1 for the most active module (if positive)
                    particle_bin = 1.0*((particle_history == np.max(particle_history, axis=1)[:, None, :])
                                        & (particle_history > 0))
                    # Collect predictions for matching histories at each time step
                    for t in range(len(p), len(particle_history)):
                        if np.all(particle_bin[(t-len(p)):t, i, t-1] == p):
                            if np.any(particle_predictions[t-1]):
                                next_step_model.append(particle_predictions[t-1])
            next_steps_model[i] = np.mean(np.stack(next_step_model, 0), 0) if len(next_step_model) > 0 else 0

        # Append another step to the pattern
        p = p + [1]
        # Plot ground truth transitions
        plt.subplot(2, 5, col+1)
        plt.imshow(next_steps_data)
        if col == 0:
            plt.title('True transitions', ha='left')
        plt.xticks([])
        plt.yticks([])
        # Plot model transitions
        plt.subplot(2, 5, 5+col+1)
        plt.imshow(next_steps_model)
        if col == 0:
            plt.title('Learned transitions', ha='left')
        plt.xticks([])
        plt.yticks([])
    plt.subplots_adjust(top=0.85,
                        bottom=0.025,
                        left=0.02,
                        right=0.985,
                        hspace=0.485,
                        wspace=0.095)

    # Then plot individual examples for test tasks
    args['do_test'] = True

    # Specify data object to use
    dataset = data.TaskSequenceDataset(args)

    # Now run the model on test tasks: one sample per task
    with torch.no_grad():
        model_in = [dataset.__getitem__(int(t*dataset.task_samples)+1) for t in range(dataset.task_stop - dataset.task_start)]
        model_in = {k: torch.stack([m[k] for m in model_in]) for k in model_in[0].keys() if k not in ['id']}
        model_out = model.nll(model_in['input'], model_in['context'], model_in['output'], N=100)
        trace_back = model.trace_back(model_out, N=1, sample=False)
    # Average particles in forward pass
    model_out = {k: v if k in ['ancestor', 'particle'] else torch.mean(v, 1) for k, v in model_out.items()}
    # Squeeze particles in backward trace (there's only one)
    trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}

    # Plot performance on individual tasks (at most one column)
    cols = min(dataset.task_stop - dataset.task_start, 1)
    rows = 4
    fig = plt.figure(figsize=(unit_size*2, unit_size*4))
    for t, task in enumerate(range(dataset.task_start, dataset.task_start + cols)):
        # Get length of this current task: sum of contexts in true data
        true_context = dataset.task_contexts[task].detach().numpy()
        T = int(np.sum(true_context))
        # Plot target output
        plt.subplot(rows, cols, 0*cols + t + 1)
        plt.imshow(model_in['output'][t][:T].numpy().transpose(), interpolation='none', aspect='auto')
        plt.xticks([])
        plt.yticks([])
        plt.ylabel('$y_t$', labelpad=-0.2)
        plt.title('True output')
        # Plot inferred output
        plt.subplot(rows, cols, 1*cols + t + 1)
        plt.imshow(trace_back['action_out'][t][:T].numpy().transpose(), interpolation='none', aspect='auto')
        plt.xticks([])
        plt.yticks([])
        # Raw string: '\m' is not a valid escape sequence
        plt.ylabel(r'$\mu_t$', labelpad=-0.2)
        plt.title('Model output')
        # Plot provided contexts
        plt.subplot(rows, cols, 2*cols + t + 1)
        plt.imshow(model_in['context'][t][:T].numpy().transpose(), interpolation='none', aspect='auto')
        plt.xticks([])
        plt.yticks([])
        plt.ylabel('Operations', labelpad=-0.2)
        # Plot true contexts
        plt.scatter(np.arange(T), np.argmax(true_context, axis=-1)[:T], color=[1, 0, 0], s=1)
        plt.title('True operation')
        # Plot filtering posterior
        plt.subplot(rows, cols, 3*cols + t + 1)
        plt.imshow(model_out['activation'][t][:T].numpy().transpose()[reordered], interpolation='none', aspect='auto')
        plt.xticks([0, T-1])
        plt.yticks([])
        plt.xlabel('Time', labelpad=-6)
        plt.ylabel('Modules', labelpad=-0.2)
        # Plot inferred contexts
        plt.scatter(np.arange(T), np.argmax(trace_back['activation'][t][:, reordered], axis=-1)[:T], color=[1, 0, 0], s=1)
        plt.title('Module posterior')
    plt.subplots_adjust(top=0.929,
                        bottom=0.081,
                        left=0.1,
                        right=0.958,
                        hspace=0.6,
                        wspace=0.2)

    # Deliberate breakpoint for interactive inspection of figures and locals
    import pdb; pdb.set_trace()
                      
# Analyse trained model
def analyse_motor(model_dir):
    """Plot what a trained motor-learning model has learned in the maze, then open pdb.

    Loads ``args.json`` and ``model.zip`` from ``model_dir``. If ``model_dir``
    contains a ``source`` directory, that saved copy of the models module is
    loaded. The following figures are drawn:

    a. example tasks from ``env.plot_tasks``;
    b. for each true operation, its path with the path of the matched learned
       module drawn on top;
    c. ground-truth versus learned transition statistics following a module
       pattern of increasing length, over all training tasks;
    d. for one test task: task path, particle hypotheses, traced-back path,
       provided/true contexts and the module posterior.

    Finally calls ``pdb.set_trace()`` so figures and locals can be inspected.

    Parameters
    ----------
    model_dir : str
        Directory of a training run (containing ``args.json`` and ``model.zip``).
    """
    # Import matplotlib here to avoid cluster issues
    from matplotlib import pyplot as plt

    # Set unit size, which determines scale of all plots
    unit_size = 2

    # Load arguments from model_dir json file
    with open(os.path.join(model_dir, 'args.json')) as f:
        args = json.load(f)

    # Overwrite some arguments with specific values for plotting
    args['do_test'] = False
    args['feedback_density'] = 1

    # This is a bit ugly, but avoids using a different model for training and analysis:
    # Replace the loaded "models" module by the source file saved in training - if it exists
    run.load_module((model_dir + '/source') if os.path.isdir(model_dir + '/source') else None)

    # Load the environment
    env = maze.ContinuousMaze(args)
    # Plot a bunch of tasks
    env.plot_tasks(5, 1, offset=1)
    fig_a = plt.gcf()
    fig_a.set_size_inches(unit_size*1, unit_size*4)
    plt.subplots_adjust(
        top=0.98,
        bottom=0.02,
        left=0.1,
        right=0.9,
        hspace=0.065,
        wspace=0.2)

    # Get model
    model = run.get_model(args, env)

    # Keep model on cpu
    device = 'cpu'
    model.to(device)
    model.set_device(device)
    env.set_device(device)
    # Load model trained parameters
    model.load_state_dict(torch.load(os.path.join(model_dir, 'model.zip'),
                                     map_location=torch.device(device), weights_only=True))

    # Set model to evaluation mode
    model.eval()
    model.set_hard_sampling(True)

    # Collect module and data actions, then match learned modules to true operations
    model_actions, model_states, data_actions = run.get_module_actions(model, env)
    overlap = run.get_module_overlap(model, env, model_actions, data_actions)
    reordered = run.get_module_order(model, overlap)

    # Then plot the true operation and the learned trajectory
    fig_b = plt.figure(figsize=(unit_size*3, unit_size*2))
    for i, o in enumerate(reordered):
        ax = plt.subplot(2, int(model.n_modules/2), i+1)
        # Get path from actions of true operation i
        data_path = env.get_path(torch.tensor(data_actions[i], device=env.device),
                                 start=env.task_init).cpu()
        # Plot true module
        env.plot_task(ax, data_path, [i for _ in range(len(env.operations[i]) + 1)], plot_path=False)
        # Plot path of the learned module matched to operation i on top
        model_path = env.get_path(torch.tensor(model_actions[o], device=env.device),
                                  start=env.task_init).cpu()
        env.plot_path(ax, model_path)
    plt.subplots_adjust(
        top=0.98,
        bottom=0.02,
        left=0.02,
        right=0.98,
        hspace=0.1,
        wspace=0.1)

    # Now run the model on *all* training tasks: one example of each task
    tasks = np.arange(env.task_start, env.task_stop)
    model_in = env.get_model_input(tasks)
    with torch.no_grad():
        model_out = model.nll(env, model_in['context'], model_in['output'], N=50)
    ancestors = model_out['ancestor'].detach().cpu().numpy()
    activations = model_out['activation'].detach().cpu().numpy()
    predictions = model_out['context_out'].detach().cpu().numpy()

    # Collect state history for each particle at each timestep
    # Dimensions particles x timesteps x modules x timesteps
    # The second dimension indexes the history, the final dimension the timestep
    # So history of particle at timestep for task: history[task][particle, :, :, timestep]
    T = ancestors[0].shape[1]
    history = []
    for task_activations, task_ancestors in zip(activations, ancestors):
        task_history = np.zeros(list(task_activations.shape) + [T])
        for t in range(1, T):
            # Copy over history *up to* previous timestep from the ancestor of current particle
            task_history[:, :t, :, t] = task_history[task_ancestors[:, t-1], :t, :, t-1]
            # Copy over resampled activation from previous timestep for current particle
            task_history[:, t, :, t] = task_activations[:, t-1, :]
        history.append(task_history)

    # Create figure to plot time-dependent transition matrix
    fig_c = plt.figure(figsize=(unit_size*3, unit_size*2))
    # Set pattern to look for
    p = [0, 1]  # p = [0,1]: find pattern after switch. p = [1]: find pattern everywhere
    for col in range(5):
        # Find average next step given current pattern
        next_steps_data = np.zeros((env.n_operations, env.n_operations))
        next_steps_model = np.zeros((env.n_operations, env.n_operations))
        for i in range(env.n_operations):
            # Collect transitions in ground truth context of data
            next_step_data = []
            for c in env.data_context[env.task_start:env.task_stop].numpy():
                zero_padded = np.concatenate([np.zeros((1, env.n_operations)), c], axis=0)
                # t indexes the step right after the pattern window [t-len(p), t)
                for t in range(len(p), len(zero_padded)):
                    if np.all(zero_padded[(t-len(p)):t, i] == p):
                        if np.any(zero_padded[t]):
                            next_step_data.append(zero_padded[t])
            next_steps_data[i] = np.mean(np.stack(next_step_data, 0), 0) if len(next_step_data) > 0 else 0
            # Collect transitions in model output
            next_step_model = []
            for task_history, task_predictions in zip(history, predictions):
                for particle_history, particle_predictions in zip(task_history, task_predictions):
                    # Reorder modules
                    particle_history = particle_history[:, reordered, :]
                    particle_predictions = particle_predictions[:, reordered]
                    # Binarise history: 1 for the most active module (if positive)
                    particle_bin = 1.0*((particle_history == np.max(particle_history, axis=1)[:, None, :])
                                        & (particle_history > 0))
                    # Collect predictions for matching histories at each time step
                    for t in range(len(p), len(particle_history)):
                        if np.all(particle_bin[(t-len(p)):t, i, t-1] == p):
                            if np.any(particle_predictions[t-1]):
                                next_step_model.append(particle_predictions[t-1])
            next_steps_model[i] = np.mean(np.stack(next_step_model, 0), 0) if len(next_step_model) > 0 else 0

        # Append another step to the pattern
        p = p + [1]
        # Plot ground truth transitions
        plt.subplot(2, 5, col+1)
        plt.imshow(next_steps_data)
        if col == 0:
            plt.title('True transitions', ha='left')
        plt.xticks([])
        plt.yticks([])
        # Plot model transitions
        plt.subplot(2, 5, 5+col+1)
        plt.imshow(next_steps_model)
        if col == 0:
            plt.title('Learned transitions', ha='left')
        plt.xticks([])
        plt.yticks([])
    plt.subplots_adjust(top=0.85,
                        bottom=0.025,
                        left=0.02,
                        right=0.985,
                        hspace=0.485,
                        wspace=0.095)

    # Switch to test tasks
    args['do_test'] = True
    env = maze.ContinuousMaze(args)
    env.set_device(device)

    # Run one example of each task
    tasks = np.arange(env.task_start, env.task_stop)
    model_in = env.get_model_input(tasks)
    with torch.no_grad():
        model_out = model.nll(env, model_in['context'], model_in['output'], N=50)
    trace_back = model.trace_back(model_out, N=1, sample=False)
    # Squeeze particles in backward trace (there's only one)
    trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}

    # Choose one task to plot (requires at least two test tasks)
    t = 1
    task = tasks[t]
    true_context = env.data_context[task].detach().numpy()
    T = int(np.sum(true_context))

    # Plot the task, inferred path, and particle trajectories
    fig_d = plt.figure(figsize=(unit_size*3, unit_size*4))
    # Create axes: one large panel on top, two small panels below
    gs = fig_d.add_gridspec(2, 2, height_ratios=[3, 1])
    ax0 = fig_d.add_subplot(gs[0, :])
    ax1 = fig_d.add_subplot(gs[1, 0])
    ax2 = fig_d.add_subplot(gs[1, 1])
    # Get feedback signal, if there's sparse feedback
    feedback = (torch.sum(model_in['context'][t][:T], -1) == 1).numpy()
    # Plot the task trajectory with colour indicating the true module
    plt.sca(ax0)
    env.plot_task(ax0, env.data_coord[task][:T],
                  torch.argmax(env.data_context[task], -1).numpy()[:T],
                  feedback=feedback, plot_path=False)
    # Plot the particle hypotheses: proposed steps before feedback
    run.plot_particles(ax0, env, model_out['pred_state'][t][:, :T, :].numpy(),
                       model_out['ancestor'][t][:, :T].numpy(), model_out['context_out'][t][:, :T, reordered].numpy())
    # Plot the smoothing trajectory: highest likelihood trajectory after task completion
    # Collect path given by output
    curr_path = torch.cat([env.task_init, trace_back['pred_state'][t][:T]]).numpy()
    # Plot the full path
    ax0.plot(curr_path[:, 0], curr_path[:, 1], 'w-', linewidth=3)
    # Plot the path where each action is coloured by the module that generated it
    env.plot_path(ax0, curr_path, np.argmax(trace_back['activation'][t][:T, reordered].numpy(), axis=-1).squeeze())
    plt.title('Full feedback')
    # Plot provided contexts
    plt.sca(ax1)
    plt.imshow(model_in['context'][t][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.scatter(np.arange(T), np.argmax(true_context, axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('True skill')
    plt.xticks([0, T-1])
    plt.yticks([])
    plt.xlabel('Time', labelpad=-6)
    # Plot context posterior, averaged over particles
    plt.sca(ax2)
    plt.imshow(np.mean(model_out['activation'][t].numpy(), axis=0)[:T].transpose()[reordered], interpolation='none', aspect='auto', vmin=0, vmax=1)
    plt.scatter(np.arange(T), np.argmax(trace_back['activation'][t][:, reordered], axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('Module posterior')
    plt.xticks([0, T-1])
    plt.yticks([])
    plt.xlabel('Time', labelpad=-6)
    plt.subplots_adjust(
        top=0.92,
        bottom=0.093,
        left=0.037,
        right=0.938,
        hspace=0.25,
        wspace=0.086)

    # Deliberate breakpoint for interactive inspection of figures and locals
    import pdb; pdb.set_trace()
            
# Little helper function for making a bunch of directories
def make_dirs(base_dir, run_dir, sub_dirs=None):
    """Create ``base_dir/run_dir`` and, inside it, every directory in ``sub_dirs``.

    Missing intermediate directories are created. Directories that already
    exist are left untouched, so calling this repeatedly is safe.

    Parameters
    ----------
    base_dir : str
        Parent directory.
    run_dir : str
        Directory for this run, relative to ``base_dir``.
    sub_dirs : iterable of str, optional
        Sub-directories to create inside ``base_dir/run_dir``.
    """
    run_path = os.path.join(base_dir, run_dir)
    # exist_ok avoids the race between checking os.path.exists and creating
    os.makedirs(run_path, exist_ok=True)
    for sub_dir in (() if sub_dirs is None else sub_dirs):
        os.makedirs(os.path.join(run_path, sub_dir), exist_ok=True)

def prepare_args(defaults, inputs=None):
    """Merge optional overrides into an argument dictionary.

    Parameters
    ----------
    defaults : dict
        Argument values (e.g. from the command line). Updated in place.
    inputs : dict, optional
        Overrides; every key in it replaces or adds the entry in ``defaults``.

    Returns
    -------
    dict
        The same ``defaults`` object, after the overrides are applied. If
        ``run_dir`` is missing or None after merging, it is set to the
        current timestamp (``YYYY-mm-dd_HH-MM-SS``).
    """
    if inputs is not None:
        defaults.update(inputs)
    # Fill in the timestamp only after merging, so an explicit run_dir=None
    # override still gets a usable directory name
    if defaults.get('run_dir') is None:
        defaults['run_dir'] = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    return defaults

def get_model(args, dataset):
    """Instantiate the network selected by ``args['model']``.

    ``args['model']`` chooses the architecture:
    0 -> models.HMMMoE, 1 -> models.HMMMoE_RL (task=1), 2 -> models.RNNControl.
    The torch random seed is set from ``args['seed']`` before construction.

    Parameters
    ----------
    args : dict
        Run settings. Reads 'seed' and 'model', plus, depending on the model:
        'hidden', 'layers', 'n_blocks', 'use_gru', 'rank', 'true_context',
        'no_context', 'weight_init', 'sigma_init', 'task_id'.
    dataset : object
        Provides n_tasks, n_contexts, n_steps, n_operations, n_dims_in and
        n_dims_out.

    Returns
    -------
    The constructed model.

    Raises
    ------
    ValueError
        If ``args['model']`` is not 0, 1 or 2.
    """
    # Print where the models module is taken from (it may have been swapped by load_module)
    print('Getting model from module', models)
    torch.manual_seed(args['seed'])
    # Task description shared by all model types
    task_kwargs = {'n_tasks': dataset.n_tasks, 'n_contexts': dataset.n_contexts,
                   'n_steps': dataset.n_steps, 'n_operations': dataset.n_operations}
    # Positional network parameters for the HMMMoE variants:
    # input dims, hidden size, layers, output dims, layers, n_blocks.
    # Their order must match the models.HMMMoE signature (presumably module layers,
    # then gating layers - confirm in models.py).
    network_params = [dataset.n_dims_in, args['hidden'],
                      args['layers'], dataset.n_dims_out,
                      args['layers'], args['n_blocks']]

    if args['model'] == 0:
        model = models.HMMMoE(*network_params,
                              use_gru=args['use_gru'], rank=args['rank'],
                              true_modules=False, true_gating=args['true_context'], flat_gating=args['no_context'],
                              weight_init=args['weight_init'], sigma_init=args['sigma_init'], **task_kwargs)
    elif args['model'] == 1:
        model = models.HMMMoE_RL(*network_params,
                                 use_gru=args['use_gru'], rank=args['rank'],
                                 true_modules=False, true_gating=args['true_context'],
                                 task=1, weight_init=args['weight_init'], sigma_init=args['sigma_init'], **task_kwargs)
    elif args['model'] == 2:
        model = models.RNNControl(dataset.n_dims_in + dataset.n_tasks,
                                  args['hidden'] * (dataset.n_operations + 1),
                                  dataset.n_dims_out,
                                  weight_init=args['weight_init'], task_id=args['task_id'], **task_kwargs)
    else:
        # Without this branch `model` would be unbound and fail with a confusing error
        raise ValueError(f"Unknown model type {args['model']!r}; expected 0, 1 or 2")
    return model

def load_module(path=None):
    """Execute ``models.py`` from ``path`` and make it the active ``models`` module.

    Useful for analysing a saved run with the exact model source it was trained
    with. The freshly executed module is installed as ``sys.modules['models']``,
    and this file (imported as ``run``) is then reloaded, so that its functions,
    such as get_model(), pick up the new ``models``.

    Parameters
    ----------
    path : str, optional
        Directory that contains ``models.py``. Defaults to the current working
        directory.
    """
    source_dir = os.getcwd() if path is None else path
    spec = importlib.util.spec_from_file_location("models", source_dir + '/models.py')
    loaded_models = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded_models)
    # Later "import models" statements now resolve to the module just loaded
    sys.modules['models'] = loaded_models
    # Reloading relies on the module-level "import run"; callers must then go
    # through run.<function>() to see the reloaded definitions
    importlib.reload(run)
    
    
def prepare_training(model, args):
    """Set up the device, optimiser, logging and output directories for a run.

    The function performs these steps in order:
    - moves ``model`` to the first CUDA device if one is available (CPU otherwise);
    - builds an Adam optimiser with a linearly decaying learning rate;
    - optionally restores weights from ``args['resume_dir']/model.zip``;
    - freezes parameter groups as requested;
    - creates the run directories;
    - opens a tensorboard writer;
    - stores ``args`` as ``args.json`` in the run directory.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train.
    args : dict
        Run settings. Reads 'lr', 'lr_steps', 'resume_dir', 'freeze_modules',
        'freeze_gating', 'base_dir' and 'run_dir'. Must be JSON-serialisable.

    Returns
    -------
    tuple
        (device, optimiser, scheduler, writer)
    """
    has_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if has_cuda else 'cpu')
    print('Is gpu available?', has_cuda, ': device = ', device)
    model.to(device)

    # The learning rate decays linearly from lr to 0.1 * lr over lr_steps scheduler steps
    optimiser = torch.optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimiser, start_factor=1.0, end_factor=0.1, total_iters=args['lr_steps'])

    if args['resume_dir'] is not None:
        checkpoint_path = os.path.join(args['resume_dir'], 'model.zip')
        state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
        model.load_state_dict(state_dict)

    # Parameters are frozen by name: 'action' ones follow freeze_modules and
    # 'context' ones follow freeze_gating (a name containing both follows freeze_gating)
    for param_name, param in model.named_parameters():
        if 'action' in param_name:
            param.requires_grad = not args['freeze_modules']
        if 'context' in param_name:
            param.requires_grad = not args['freeze_gating']

    run_path = os.path.join(args['base_dir'], args['run_dir'])
    make_dirs(args['base_dir'], args['run_dir'], ['tensorboard', 'test'])
    # Monitor progress with: tensorboard --logdir=train
    writer = SummaryWriter(os.path.join(run_path, 'tensorboard'))
    with open(os.path.join(run_path, 'args.json'), 'w') as fp:
        json.dump(args, fp)

    return device, optimiser, scheduler, writer
    
def log_training(i, model, args, dataset, writer, train_start, model_in, model_out, current_loss, valid_loss):
    """Write training diagnostics for iteration ``i`` to tensorboard and print progress.

    The 'Losses/all' scalar is always logged. For args['model'] != 2 the
    following are logged as well:
    - 'Losses/best': loss of the trajectory returned by model.trace_back;
    - 'Losses/likehood': mean of current_loss;
    - 'Losses/sigma': exp of model.log_sigma;
    - 'Accuracies/Activations';
    - 'Accuracies/Modules'.

    Parameters
    ----------
    i : int
        Training iteration; used as the tensorboard step and in the printout.
    model : the model being trained.
    args : dict
        Run settings. Reads 'model', 'batch_size', 'base_dir', 'run_dir' and
        'save_interval'.
    dataset : object
        Reads n_steps and n_dims_in. It is also passed as ``env`` to
        log_module_actions.
    writer : tensorboard SummaryWriter.
    train_start : datetime
        Start of training; the elapsed time is printed.
    model_in : dict of tensors
        Reads 'output' and 'context'.
    model_out : dict of tensors produced by the model.
        NOTE: for args['model'] != 2 this dict is MODIFIED IN PLACE. Every entry
        except 'ancestor' is replaced by its mean over dim 1, which is
        presumably the particle dimension.
    current_loss : tensor
        Its mean is logged as 'Losses/likehood'.
    valid_loss : tensor
        Used to index the per-element losses (``loss_mat[valid_loss]``).
        Presumably a boolean mask with the loss's shape - TODO confirm.
    """
    # Define a loss calculation separately for rule vs motor learning tasks
    # (model 1 compares against 'inf_state' with an L2 norm; other models use the
    # summed squared error against 'action_out')
    if args['model'] == 1:
        def loss_func(model_in, model_out):
            return torch.linalg.vector_norm(model_in['output'] - model_out['inf_state'], dim=-1).detach().cpu().numpy()
    else:
        def loss_func(model_in, model_out):
            return torch.sum(torch.square(model_in['output'] - model_out['action_out']), dim=-1).detach().cpu().numpy()
    # Get numpy variant of valid_loss
    valid_loss = valid_loss.detach().cpu().numpy()
    # Plot the loss for best particle, then average across particles - not for control, which doesn't have particles
    if args['model'] != 2:
        # Get highest likelihood trajectory to calculate best-case accuracy
        trace_back = {k: torch.squeeze(v) for k, v in model.trace_back(model_out).items()} 
        loss_best = loss_func(model_in, trace_back)
        writer.add_scalar('Losses/best', np.mean(loss_best[valid_loss]), i)    
        # Average across particles for all the logging analyses.
        # This overwrites the caller's model_out entries (in-place mutation).
        for k, v in model_out.items():
            model_out[k] = v if k in ['ancestor'] else torch.mean(v, 1)
    # Log losses
    loss_mat = loss_func(model_in, model_out)
    writer.add_scalar('Losses/all', np.mean(loss_mat[valid_loss]), i)
    # Log probabilistic-model specific quantities
    if args['model'] != 2:
        # NOTE: the tag is spelled 'likehood'; renaming it would split the series
        # across old and new tensorboard logs read by compare()
        writer.add_scalar('Losses/likehood', np.mean(current_loss.detach().cpu().numpy()), i)
        writer.add_scalar('Losses/sigma', np.exp(model.log_sigma.detach().cpu().numpy()), i)
        # Log gating accuracy: whether activated modules match true context.
        # For the first n_to_plot sequences, flatten (sequence, timestep) into rows.
        # Correlate every pair of rows, once for module activations and once for
        # true contexts. Then correlate those two pairwise-similarity vectors,
        # skipping pairs where either correlation is NaN (e.g. constant rows).
        # Assumes both tensors have dataset.n_steps timesteps per sequence (see reshape).
        n_to_plot = min([50, args['batch_size']]);
        module_mat = trace_back['activation'][:n_to_plot].detach().cpu().numpy().reshape((n_to_plot*dataset.n_steps,-1))
        context_mat = model_in['context'][:n_to_plot].detach().cpu().numpy().reshape((n_to_plot*dataset.n_steps,-1))
        module_corr = np.corrcoef(module_mat)[np.triu_indices(n_to_plot*dataset.n_steps,1)]
        context_corr = np.corrcoef(context_mat)[np.triu_indices(n_to_plot*dataset.n_steps,1)]
        include = np.logical_and(np.logical_not(np.isnan(module_corr)), np.logical_not(np.isnan(context_corr)))
        writer.add_scalar('Accuracies/Activations', np.corrcoef(module_corr[include], context_corr[include])[0,1], i)
        # Log module accuracy: whether module outputs match true components
        if args['model'] == 0:
            # Feed each module RNN the first unit input, then a zero input, starting
            # from action_h0. Stack the module outputs, sort the modules by argmax
            # output, and correlate the result with an identity matrix.
            # NOTE(review): here model.action_output is called directly, while
            # get_module_actions indexes it per module - confirm which applies to model 0.
            input_zeros = model.t(torch.zeros((1,dataset.n_dims_in)))
            input_eye = model.t(torch.eye(dataset.n_dims_in)[0].unsqueeze(0))                                      
            with torch.no_grad():
                result = torch.concatenate([model.action_output(rnn(input_zeros, 
                                                              rnn(input_eye, model.action_h0.unsqueeze(0))))
                                      for rnn in model.action_rnn]).cpu().numpy()
            reordered = np.argsort(np.argmax(result,-1))
            writer.add_scalar('Accuracies/Modules', np.corrcoef(
                result[reordered].reshape(-1), np.eye(model.n_modules).reshape(-1))[0,1], i)
        else:
            # Mean over operations of the best-matching module's overlap.
            # NOTE(review): np.mod(i, save_interval) is always < save_interval, so
            # this comparison with 1e10 is effectively never true. Module plots are
            # thus disabled; use `== 0` to save them every save_interval steps.
            overlap = log_module_actions(model, dataset, save_path=(
                os.path.join(args['base_dir'], args['run_dir'],f'ops_{i:05}.png')
                if np.mod(i,args['save_interval']) == 1e10 else None))
            writer.add_scalar('Accuracies/Modules', np.mean(np.max(overlap, axis=0)), i)
    # Print overall progress
    print(f'Step {i}, {datetime.now() - train_start}:' + f' loss {np.mean(loss_mat[valid_loss])}')

def plot_particles(ax, env, state, ancestor, context=None):
    """Draw every particle's proposed step at each timestep on ``ax``.

    Each particle has a state *before* resampling. ``ancestor`` says which
    particle each slot is resampled from afterwards. The next segment of every
    slot therefore starts from the state of its ancestor.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes that all segments are drawn on.
    env : environment
        Its ``task_init`` (a torch tensor) is the start state of every particle.
    state : np.ndarray
        Indexed as ``state[particle, t, :]``; the 2-D position each particle
        proposes at step t.
    ancestor : np.ndarray of int, shape (particles, T)
        Resampling indices applied after each step.
    context : np.ndarray, optional
        Shape (particles, T, modules). When given, each segment is drawn dotted
        in the tab10 colour of its argmax module. Otherwise each step is drawn
        in a grey that goes from black towards white as t increases.
    """
    import matplotlib
    tab10 = matplotlib.colormaps['tab10']
    n_particles, n_steps = ancestor.shape[0], ancestor.shape[1]
    s_prev = env.task_init.expand([n_particles, -1]).detach().cpu().numpy()
    for t in range(n_steps):
        # Store segments as (start, end, nan) triples so a single plot call
        # draws them all as disconnected line pieces
        segments = np.full([n_particles * 3, 2], np.nan)
        segments[0::3, :] = s_prev
        s_new = state[:, t, :]
        segments[1::3, :] = s_new
        # The next step of each slot starts from the state of its resampled ancestor
        s_prev = s_new[ancestor[:, t], :]
        if context is None:
            ax.plot(segments[:, 0], segments[:, 1], color=[t / n_steps] * 3)
        else:
            curr_context = np.argmax(context[:, t, :], axis=-1)
            for i in range(context.shape[-1]):
                # Select all three rows (start, end, nan) of every particle that used module i
                curr_seg = np.repeat(curr_context == i, 3)
                ax.plot(segments[curr_seg, 0], segments[curr_seg, 1], ':', color=tab10(i))
                            
def get_module_actions(model, env):
    """Roll out every learned module on its own, starting from the task's start state.

    Each module RNN is unrolled for as many steps as the longest true
    operation. At every step it is fed the latest state and its hidden state,
    the module's action is applied with ``env.transition``, and the rollout
    continues from the resulting state.

    Returns
    -------
    model_actions : list of np.ndarray
        One per module: the actions it produced, concatenated along dim 0.
    model_states : list of np.ndarray
        One per module: the visited states, including the start state.
    data_actions : list of np.ndarray
        One per true operation. For ``model.task == 0`` the operation is
        looked up in ``env.action_vecs``; otherwise the operation tensor itself
        is used.
    """
    max_t = max(len(op) for op in env.operations)
    model_actions = []
    model_states = []
    # Move tensors to the CPU before .numpy() in both branches, so this also works on GPU
    data_actions = [env.action_vecs[op].cpu().numpy() if model.task == 0 else op.cpu().numpy()
                    for op in env.operations]
    for i, rnn in enumerate(model.action_rnn):
        h = model.action_h0.unsqueeze(0)
        states = [env.task_init]
        actions = []
        for _ in range(max_t):
            with torch.no_grad():
                h = rnn(states[-1], h)
                actions.append(model.get_action(model.action_output[i](h)))
            states.append(env.transition(states[-1], actions[-1]))
            if model.task == 2:
                # Keep only the last entry along the second-to-last axis.
                # Presumably transition returns a sequence of states here - TODO confirm.
                states[-1] = states[-1][..., -1, :]
        model_actions.append(torch.cat(actions).cpu().numpy())
        model_states.append(torch.cat(states).cpu().numpy())
    return model_actions, model_states, data_actions
        
def get_module_order(model, overlap):
    """Find the permutation of model modules that best matches the true operations.

    Every permutation ``order`` of the modules is tried. Its score is the sum,
    over k, of ``overlap[order[k], k]`` plus 1 for each such overlap above 0.9.
    The first permutation with the highest score wins, and NaN scores are
    ignored.

    Parameters
    ----------
    model : object
        Provides ``n_modules``.
    overlap : np.ndarray, shape (n_modules, n_operations)
        Similarity between each module and each true operation.

    Returns
    -------
    list of int
        ``order[k]`` is the module matched to true operation k. This is the
        identity order if no permutation has a valid score.
    """
    best_order = list(range(model.n_modules))
    # Start at -inf so the best permutation is found even when every score is <= 0
    best_score = -np.inf
    for order in itertools.permutations(range(model.n_modules)):
        order = list(order)
        matched = np.diag(overlap[order])
        score = np.sum((matched > 0.9).astype(float) + matched)
        if score > best_score:
            best_order = order
            best_score = score
    return best_order

def get_module_overlap(model, env, model_actions, data_actions):
    """Compute a normalised similarity between every module and every true operation.

    Only the first ``len(env.operations[j])`` steps are compared for
    operation j. For 1-D actions the similarity is the Pearson correlation.
    Otherwise it is the mean per-step cosine similarity.

    Parameters
    ----------
    model : object
        Provides ``n_modules``.
    env : object
        Provides ``n_operations`` and ``operations`` (a sequence of operations
        with lengths).
    model_actions : list of np.ndarray
        Per module, an array of shape (T, action_dim).
    data_actions : list of np.ndarray
        Per true operation, an array of shape (T_j, action_dim).

    Returns
    -------
    np.ndarray, shape (n_modules, n_operations)
        The overlap matrix.
    """
    overlap = np.zeros((model.n_modules, env.n_operations))
    for i, model_a in enumerate(model_actions):
        for j, data_a in enumerate(data_actions):
            n_t = len(env.operations[j])
            m, d = model_a[:n_t], data_a[:n_t]
            if m.shape[-1] == 1:
                overlap[i, j] = np.corrcoef(m.squeeze(), d.squeeze())[0, 1]
            else:
                # Row-wise dot products; avoids building the full T x T matrix of diag(m @ d.T)
                cosine = np.sum(m * d, axis=-1) / (np.linalg.norm(m, axis=-1) * np.linalg.norm(d, axis=-1))
                overlap[i, j] = np.mean(cosine)
    return overlap
   
def plot_module_actions(model, env, model_actions, data_actions, reordered, save_path):
    """Save a figure that compares each true operation with its matched module.

    There is one subplot per entry of ``reordered``. Subplot k shows the path
    of true operation k (drawn with env.plot_task). On top of it, the path of
    model module ``reordered[k]`` is drawn with env.plot_path. All subplots
    share a symmetric axis range of 1.1 times the largest absolute coordinate
    reached by the cumulative true actions. The figure is written to
    ``save_path`` and then closed.
    """
    from matplotlib import pyplot as plt
    cumulative = [np.cumsum(acts, axis=0) for acts in data_actions]
    lim = np.max(np.abs(np.concatenate(cumulative)))
    fig = plt.figure(figsize=(6, 3))
    for op_idx, module_idx in enumerate(reordered):
        ax = plt.subplot(1, model.n_modules, op_idx + 1)
        # True operation: path rebuilt from its actions, labelled with its own index
        true_actions = torch.tensor(data_actions[op_idx], device=env.device)
        true_path = env.get_path(true_actions, start=env.task_init).cpu()
        labels = [op_idx] * (len(env.operations[op_idx]) + 1)
        env.plot_task(ax, true_path, labels, plot_path=False)
        # Learned module drawn on top
        learned_actions = torch.tensor(model_actions[module_idx], device=env.device)
        learned_path = env.get_path(learned_actions, start=env.task_init).cpu()
        env.plot_path(ax, learned_path)
        ax.axis('scaled')
        ax.set_xlim([-1.1 * lim, 1.1 * lim])
        ax.set_ylim([-1.1 * lim, 1.1 * lim])
        if model.task == 2:
            ax.invert_yaxis()
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)
                    
def log_module_actions(model, env, save_path=None):
    """Measure how well the learned modules match the true operations.

    Rolls out every module, computes the module-by-operation overlap matrix
    and, when ``save_path`` is given, also saves a figure of the
    best-matching module for each operation.

    Returns
    -------
    np.ndarray, shape (n_modules, n_operations)
        The overlap matrix from get_module_overlap.
    """
    model_actions, _, data_actions = get_module_actions(model, env)
    overlap = get_module_overlap(model, env, model_actions, data_actions)
    if save_path is None:
        return overlap
    best_order = get_module_order(model, overlap)
    plot_module_actions(model, env, model_actions, data_actions, best_order, save_path)
    return overlap
    
def load_tensorboard(log_dir, variables):
    """Read scalar series from the tensorboard log of one training run.

    Parameters
    ----------
    log_dir : str
        Run directory; the log is read from its 'tensorboard' sub-directory.
    variables : iterable of str
        Scalar tags to load. Tags missing from the log are skipped.

    Returns
    -------
    dict
        Maps each found tag to an array of shape (n, 3) with the columns
        (wall time, step, value). It is empty if the run has no tensorboard
        directory.

    If the log holds several runs appended after each other (a step that
    decreases marks the start of a new run), only the last run is kept.
    """
    values = {}
    tb_dir = os.path.join(log_dir, 'tensorboard')
    # If tensorboard directory does not exist (e.g. failed run): exit
    if not os.path.isdir(tb_dir):
        print(f'No tensorboard directory for {log_dir}')
        return values
    event_acc = EventAccumulator(tb_dir)
    event_acc.Reload()
    var_avail = event_acc.Tags()['scalars']
    for var in variables:
        if var not in var_avail:
            continue
        events = event_acc.Scalars(var)
        if not events:
            continue
        time = np.array([e.wall_time for e in events])
        step = np.array([e.step for e in events])
        val = np.array([e.value for e in events])
        # Use the same criterion (a strictly decreasing step) to detect restarts
        # and to locate the last one, so repeated steps inside the final run
        # don't truncate it
        restarts = np.flatnonzero(np.diff(step) < 0)
        start = restarts[-1] + 1 if restarts.size else 0
        values[var] = np.stack([time[start:], step[start:], val[start:]], -1)
    return values

def find_seed_matches(log_dir):
    """Return all run directories that differ from ``log_dir`` only in their seed tag.

    A seed tag is an underscore-separated component of the form ``i<digits>``
    (e.g. ``i3``).

    Returns
    -------
    list of str
        If ``log_dir`` contains at least one seed tag: every existing path
        whose components equal those of ``log_dir``, except that each seed
        position holds some seed tag. This includes ``log_dir`` itself if it
        exists, and the list is sorted. Otherwise: ``[log_dir]``.
    """
    def is_seed_tag(part):
        # startswith avoids IndexError on empty components (e.g. 'a__b')
        return part.startswith('i') and part[1:].isdigit()

    parts = log_dir.split('_')
    is_seed = [is_seed_tag(p) for p in parts]
    if not any(is_seed):
        return [log_dir]
    # Wildcard the seed positions; escape everything else so literal '[', '*' or '?' match
    pattern = '_'.join('i*' if seed else glob.escape(p) for p, seed in zip(parts, is_seed))
    matches = []
    for match in sorted(glob.glob(pattern)):
        match_parts = match.split('_')
        if len(match_parts) != len(parts):
            continue
        # Non-seed parts must be identical; seed parts must really be seed tags
        if all(is_seed_tag(m_p) if seed else m_p == p
               for p, m_p, seed in zip(parts, match_parts, is_seed)):
            matches.append(match)
    return matches
        
def _average_seeds(match_dirs, variables):
    """Average the requested scalars across runs that differ only in their seed.

    Runs with no loadable data are ignored. Runs shorter than the longest one
    (judged by their first loaded variable) are reported and discarded. A
    variable is averaged only if every retained seed logged it with the same
    length.

    Returns
    -------
    dict
        Maps each variable to an array with the columns (wall time, step,
        mean over seeds, standard error over seeds). Wall time and step come
        from the first retained seed. The dict is empty if no run has data.
    """
    loaded = [(m, load_tensorboard(m, variables)) for m in match_dirs]
    loaded = [(m, vals) for m, vals in loaded if len(vals) > 0]
    if not loaded:
        return {}
    lengths = [list(vals.values())[0].shape[0] for _, vals in loaded]
    longest = max(lengths)
    # Report the erroneous (shorter) runs, paired with their own directory
    for (m, _), n in zip(loaded, lengths):
        if n < longest:
            print(f'Run {m} discarded: {n} < {longest}')
    seed_vals = [vals for (_, vals), n in zip(loaded, lengths) if n == longest]
    reference = seed_vals[0]
    keys = [k for k in reference
            if all(k in vals and vals[k].shape[0] == reference[k].shape[0] for vals in seed_vals)]
    averaged = {}
    for k in keys:
        stacked = np.stack([vals[k][:, 2] for vals in seed_vals], axis=1)
        mean = np.nanmean(stacked, axis=1)[:, None]
        sem = np.nanstd(stacked, axis=1)[:, None] / np.sqrt(len(seed_vals))
        averaged[k] = np.concatenate([reference[k][:, :2], mean, sem], axis=1)
    return averaged


def compare(variables, dirs=('*',), legends=None, zoom=1500, match_seed=True,
            base_dir='./train'):
    """Plot tensorboard scalars of several training runs against each other.

    Run e.g. with arguments like
        variables: ['Losses/all', 'Boundaries/Correlation']
        dirs: ['n_hmrnn_rand_h*', 'n_hmrnn_rand_l_h*', 'n_hmrnn_rand_l_k_h*', 'n_no_context_rand*']

    Parameters
    ----------
    variables : list of str
        Tensorboard scalar tags; one figure is made per tag.
    dirs : str or sequence of str
        Glob pattern(s), relative to ``base_dir``, that select the run
        directories. Only directories with a 'tensorboard' sub-directory are
        used.
    legends : dict, optional
        Maps a run name to its legend label. Runs not listed use their name.
        The dict is not modified.
    zoom : int
        x-limit of the right-hand (zoomed-in) subplot.
    match_seed : bool
        If True, runs that differ only in an ``i<digits>`` seed tag are
        averaged and drawn with a standard-error band.
    base_dir : str
        Directory that holds the training runs.

    Returns
    -------
    list of matplotlib figures
        One per variable. Each has the full training run on the left and the
        first ``zoom`` iterations on the right.
    """
    # Import matplotlib here to avoid cluster issues
    from matplotlib import pyplot as plt
    from matplotlib import colormaps

    # A single pattern string must not be iterated character by character
    if isinstance(dirs, str):
        dirs = [dirs]
    found = [d for pattern in dirs for d in sorted(glob.glob(os.path.join(base_dir, pattern)))
             if os.path.isdir(os.path.join(d, 'tensorboard'))]
    # Remove duplicates while preserving order
    dirs = list(dict.fromkeys(found))
    names = [os.path.relpath(d, base_dir) for d in dirs]

    values = {}
    for i, (n, d) in enumerate(zip(names, dirs)):
        print(f'Now loading {n}, {i} / {len(names)}')
        match_dirs = find_seed_matches(d) if match_seed else [d]
        if len(match_dirs) > 1:
            values[n] = _average_seeds(match_dirs, variables)
        else:
            values[n] = load_tensorboard(d, variables)

    # Plot colours: one per run. A run ending in '_2' shares its counterpart's
    # colour (and is drawn dotted); if it has no counterpart it gets the next free colour.
    cmap = colormaps['tab10']
    cols = {}
    n_cols = 0
    for n in names:
        if not n.endswith('_2'):
            cols[n] = cmap(n_cols)
            n_cols += 1
    for n in names:
        if n.endswith('_2'):
            if n[:-2] in cols:
                cols[n] = cols[n[:-2]]
            else:
                cols[n] = cmap(n_cols)
                n_cols += 1

    # Copy, so the caller's legend dict is not modified
    legends = {} if legends is None else dict(legends)
    for name in names:
        legends.setdefault(name, name)

    figures = []
    for var in variables:
        fig = plt.figure(figsize=(8, 4))
        for i, xmax in enumerate([None, zoom]):
            ax = plt.subplot(1, 2, i + 1)
            for name, vals in values.items():
                if var not in vals:
                    continue
                data = vals[var]
                plt.plot(data[:, 1], data[:, 2], color=cols[name],
                         linestyle=(':' if name.endswith('_2') else '-'),
                         label=legends[name])
                # Seed-averaged runs carry a 4th column with the standard error
                if data.shape[1] == 4 and np.any(data[:, 3]):
                    plt.fill_between(data[:, 1], data[:, 2] - data[:, 3],
                                     data[:, 2] + data[:, 3], alpha=0.5, color=cols[name])
            if i == 0:
                plt.legend(loc='upper right', fontsize='6')
            plt.ylabel(var)
            plt.xlabel('Training iterations')
            if xmax is not None:
                plt.xlim([0, xmax])
            lines = ax.get_lines()
            if var.startswith('Losses'):
                # Losses: bottom at the lowest value, top at the median of the first logged values
                if lines:
                    first = np.array([line.get_ydata()[0] for line in lines])
                    first = first[~np.isnan(first)]
                    top = np.median(first) if len(first) > 0 else 1
                    bottom = np.nanmin([np.nanmin(line.get_ydata()) for line in lines])
                    plt.ylim([bottom, top])
            else:
                # Others are correlation accuracies, so set between -0.1 and 1.1
                plt.ylim([-0.1, 1.1])
        plt.tight_layout()
        figures.append(fig)
    return figures
        
if __name__ == '__main__':
    # Dispatch through the "run" module rather than calling train_*() directly.
    # load_module() reloads "run", and only calls made through it see the
    # reloaded functions (and the reloaded models module).
    model_type = parser.parse_args().model
    if model_type in (0, 2):
        run.train_rule()
    elif model_type == 1:
        run.train_motor()
    else:
        # Previously an unknown model type silently did nothing
        parser.error(f'unsupported --model value {model_type}; expected 0, 1 or 2')
        
