#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Generate figure panels 2a-2f and save each one to figures/ as SVG and PNG.

Created on Tue Aug 19 17:58:19 2025

@author: XXXX
"""

from matplotlib import pyplot as plt
import os
import json
import torch
import numpy as np

import run
import data


# Font sizes shared by every panel, applied globally through matplotlib's
# rc settings (approach from https://stackoverflow.com/a/39566040)
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': MEDIUM_SIZE,    # subplot titles
    'axes.labelsize': SMALL_SIZE,     # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure-level (suptitle) text
    'svg.fonttype': 'none',           # keep SVG text as editable text, not paths
    'font.family': ['Arial'],         # one font everywhere; Arial must be installed manually on Ubuntu
})

# Length of one layout unit (figsize is in inches); every figure below is
# sized as a multiple of it, and all panels together span 4x10 units
unit_size = 0.65

## PANEL A ###

# Independently trained runs (one per seed) whose training curves are shown
base_dir = './train/Final/rule_full'
seeds = [0, 1, 2, 3, 4]
model_dirs = [base_dir + f'/rule_all_i{i}_v6' for i in seeds]
# Run names: each directory relative to ./train, used as the keys of `runs`
names = [os.path.relpath(d, './train') for d in model_dirs]

# Tensorboard tags to plot, and the subplot title shown for each
variables = ['Losses/likelihood', 'Losses/best', 'Accuracies/Modules', 'Accuracies/Activations']
var_names = ['Loss', 'Error', 'Modules', 'Gating']

# Load the requested tags for every run
runs = {}
for i, (n, d) in enumerate(zip(names, model_dirs)):
    # 1-based counter, so progress ends at "5 / 5" rather than "4 / 5"
    print(f'Now loading {n}, {i + 1} / {len(names)}')
    runs[n] = run.load_tensorboard(d, variables)

# savefig does not create missing directories, so make sure the output exists
os.makedirs('figures', exist_ok=True)

# One subplot per tag: individual runs in grey, their mean in black
fig = plt.figure(figsize=(unit_size*2, unit_size*4))
last = len(variables) - 1
for i, (n, v) in enumerate(zip(var_names, variables)):
    plt.subplot(len(variables), 1, i + 1)
    # Columns 1 and 2 are used as x (iteration) and y (value); presumably the
    # (wall time, step, value) layout of tensorboard scalars - TODO confirm.
    # np.stack requires every run to have logged the same number of points.
    curr_xs = np.stack([val[v][:, 1] for val in runs.values()])
    curr_ys = np.stack([val[v][:, 2] for val in runs.values()])
    plt.plot(curr_xs.T, curr_ys.T, color=[0.7, 0.7, 0.7], linewidth=0.5)
    # Mean across runs, drawn against the iterations of the first run
    plt.plot(curr_xs[0], np.mean(curr_ys, axis=0), color=[0, 0, 0], linewidth=1)
    # Only the bottom subplot gets x ticks: first and last iteration, in thousands
    if i == last:
        curr_xticks = np.linspace(curr_xs[0][0], curr_xs[0][-1], 2)
        plt.xticks(curr_xticks, labels=[f'{x/1000:.0f}k' for x in curr_xticks])
        plt.xlabel('Iters', labelpad=-6)
    else:
        plt.xticks([])
    y_min, y_max = np.min(curr_ys), np.max(curr_ys)
    if i == 0:
        # Pad by 10% of the data range, like the other panels. Scaling the
        # limits themselves by 1.1 only works when the NLL straddles zero:
        # for all-positive values the lower limit moves above the minimum
        # (clipping the curves and hiding the lower tick), and for
        # all-negative values the same happens at the top.
        margin = 0.1 * (y_max - y_min)
        curr_ylim = [y_min - margin, y_max + margin]
        curr_yticks = [y_min, y_max]
        plt.ylabel('NLL', rotation=0, labelpad=-12, va='top')
    elif i == 1:
        # Error is shown from 0 up to its maximum, with a 10% margin
        curr_ylim = [-y_max*0.1, y_max*1.1]
        curr_yticks = [0, y_max]
        plt.ylabel('MSE', rotation=0, labelpad=-12, va='top')
    else:
        # Correlations use a fixed [0, 1] scale with a 10% margin
        curr_ylim = [-0.1, 1.1]
        curr_yticks = [0, 1]
        plt.ylabel('Corr', rotation=0, labelpad=-8, va='top')
    plt.ylim(curr_ylim)
    plt.yticks(curr_yticks, labels=[f'{y:.1f}' for y in curr_yticks])
    plt.title(n)
plt.subplots_adjust(top=0.92, bottom=0.085, left=0.27, right=0.96,
                    hspace=0.51, wspace=0.2)
plt.savefig("figures/2a.svg", format="svg")
plt.savefig("figures/2a.png", format="png")
        
### PANEL B ###

# Seed of the run that is analysed from here on
plot_seed = 0
# model_dirs is ordered like seeds, so look up the seed's position in that
# list. (seeds[plot_seed] would treat the seed as a position and the result as
# an index, which is only correct while seeds == [0, 1, 2, ...].)
model_dir = model_dirs[seeds.index(plot_seed)]

# Load the arguments this run was trained with
with open(os.path.join(model_dir, 'args.json'), encoding='utf-8') as f:
    args = json.load(f)

# First plot training tasks
args['do_test'] = False

# This is a bit ugly, but avoids using a different model for training and analysis:
# replace the loaded "models" module by the source file saved in training - if it exists
source_dir = model_dir + '/source'
run.load_module(source_dir if os.path.isdir(source_dir) else None)

# Point data_dir at the directory of the same name inside the current working
# directory (presumably because the stored path belongs to the training machine)
args['data_dir'] = os.path.join(
    os.getcwd(), os.path.basename(os.path.normpath(args['data_dir'])))

# Dataset defined on the training tasks (do_test is False)
dataset = data.TaskSequenceDataset(args, task_samples=1e4)

# Build the model for this dataset
model = run.get_model(args, dataset)

# Keep model on cpu
device = 'cpu'
model.to(device)
model.set_device(device)
# Load the trained parameters (weights_only avoids unpickling arbitrary objects)
model.load_state_dict(torch.load(os.path.join(model_dir, 'model.zip'),
                                 map_location=torch.device(device), weights_only=True))
# Evaluation mode, with hard sampling switched on
model.eval()
model.set_hard_sampling(True)

# Probe every module with the same two-step input: a one-hot vector on the
# first input unit, fed from the initial hidden state, followed by an all-zero
# input. Then read out the module's action output, giving one row per module.
first_unit = torch.eye(dataset.n_dims_in)[0].unsqueeze(0)
probe_out = torch.concatenate([
    model.action_output(
        module_rnn(torch.zeros((1, dataset.n_dims_in)),
                   module_rnn(first_unit, model.action_h0.unsqueeze(0))))
    for module_rnn in model.action_rnn
])
# Order the modules by the output unit their probe response peaks at, so that
# they line up with the true operations plotted below
reordered = np.argsort(np.argmax(probe_out.detach().cpu().numpy(), -1))

# Top row: ground-truth operation for each column (the identity matrix with
# every row circularly shifted by the column index).
# Bottom row: the full input-output matrix of the module placed in that column.
fig = plt.figure(figsize=(unit_size*4, unit_size*2))
n_modules = model.n_modules
for col, module_idx in enumerate(reordered):
    module_rnn = model.action_rnn[module_idx]

    plt.subplot(2, n_modules, col + 1)
    plt.imshow(np.stack([np.roll(row, col) for row in np.eye(dataset.n_dims_in)]))
    plt.xticks([])
    plt.yticks([])
    if col == 0:
        plt.title('True operation', ha='left')

    # One output row per one-hot input, each passed through the same two-step probe
    plt.subplot(2, n_modules, n_modules + col + 1)
    response = torch.concatenate([
        model.action_output(
            module_rnn(torch.zeros((1, dataset.n_dims_in)),
                       module_rnn(one_hot.unsqueeze(0), model.action_h0.unsqueeze(0))))
        for one_hot in torch.eye(dataset.n_dims_in)
    ])
    plt.imshow(response.detach().cpu().numpy())
    if col == 0:
        plt.title('Learned module', ha='left')
    plt.xticks([])
    plt.yticks([])
plt.subplots_adjust(top=0.89, bottom=0.0, left=0.02, right=0.985,
                    hspace=0.2, wspace=0.095)
plt.savefig("figures/2b.svg", format="svg")
plt.savefig("figures/2b.png", format="png")

### PANEL C ###

# Now run the model on *all* training tasks.
# get_model_input returns (input, context, task, output) - the same unpacking
# is used for the long task in panel F - so this passes input, context and
# output. N=100 is presumably the number of particles.
with torch.no_grad():
    model_in = dataset.get_model_input(np.arange(dataset.task_start, dataset.task_stop))
    model_out = model.nll(model_in[0], model_in[1], model_in[3],N=100)
# Per-task resampling ancestors, module activations and context outputs of the model
ancestors = model_out['ancestor'].detach().cpu().numpy()
activations = model_out['activation'].detach().cpu().numpy()
predictions = model_out['context_out'].detach().cpu().numpy()

# Reconstruct, for every particle, the activation history of its lineage by
# following the resampling ancestors back in time.
# Per task, history has the activation shape plus a trailing read-out axis:
# (particles, timesteps, modules, T). This assumes activations are
# (particles, timesteps, modules) per task and that T (taken from the
# ancestors) equals their number of timesteps - TODO confirm.
# The last axis selects the timestep at which the history is read; the second
# axis indexes entries of that history. At read-out time t, entry k
# (1 <= k <= t) holds the lineage's activation from step k-1. Entry 0 always
# stays zero, and entries after t are never filled.
# So the history of a particle at timestep t for a task is
# history[task][particle, :, :, t]
T = ancestors[0].shape[1]
history = []
for task_activations, task_ancestors in zip(activations, ancestors):
    task_history = np.zeros(list(task_activations.shape) + [T])
    for t in range(1,T):
        # Copy over history *up to* previous timestep from the ancestor of current particle
        # (task_ancestors[:, t-1] presumably gives, per particle, the index of the
        # particle at t-1 it descends from)
        task_history[:, :t, :, t] = task_history[task_ancestors[:, t-1], :t, :, t-1]
        # Copy over resampled activation from previous timestep for current particle
        task_history[:, t, :, t] = task_activations[:, t-1, :]
    history.append(task_history)
   
# Create figure to plot time-dependent transition matrices
fig = plt.figure(figsize=(unit_size*4, unit_size*2))
# Pattern to look for in the recent history of an operation (oldest step first).
# p = [0, 1]: find pattern right after a switch to the operation.
# p = [1]: find pattern everywhere.
# Each column extends the pattern by one more active step.
p = [0, 1]
n_cols = 5


def _mean_or_zero(rows):
    """Return the element-wise mean of the collected rows, or 0 if there are none."""
    return np.mean(np.stack(rows, 0), 0) if len(rows) > 0 else 0


# Ground-truth contexts of all training tasks with a zero row prepended, so
# that row k holds the context of step k-1 (the same layout as the model history)
data_padded = [np.concatenate([np.zeros((1, dataset.n_operations)), c], axis=0)
               for c in dataset.task_contexts[dataset.task_start:dataset.task_stop].numpy()]

# Model histories and predictions with modules put in the order of the true
# operations. Histories are binarised to the most active module per entry;
# entries without any positive activation count as inactive. None of this
# depends on the pattern or operation, so it is computed once, outside the loops.
model_bin = []
model_pred = []
for task_history, task_predictions in zip(history, predictions):
    ordered_history = task_history[:, :, reordered, :]
    model_bin.append((ordered_history == np.max(ordered_history, axis=2, keepdims=True))
                     & (ordered_history > 0))
    model_pred.append(task_predictions[:, :, reordered])

for col in range(n_cols):
    n_hist = len(p)
    # Average next step given current pattern, one row per operation / module
    next_steps_data = np.zeros((dataset.n_operations, dataset.n_operations))
    next_steps_model = np.zeros((dataset.n_operations, dataset.n_operations))
    for i in range(dataset.n_operations):
        # Ground truth: context that follows every occurrence of pattern p in operation i.
        # The range stops one row before the end, so it covers the same steps as
        # the model loop below (assuming contexts and model outputs have equal
        # length). The old bound len(zero_padded-1) subtracted 1 from the array
        # values rather than from its length, so the -1 had no effect.
        next_step_data = []
        for zero_padded in data_padded:
            for t in range(n_hist, len(zero_padded) - 1):
                if np.all(zero_padded[(t-n_hist):t, i] == p) and np.any(zero_padded[t]):
                    next_step_data.append(zero_padded[t])
        next_steps_data[i] = _mean_or_zero(next_step_data)
        # Model: prediction that follows every occurrence of pattern p in module i,
        # judged on the particle's history as read out at step t-1
        next_step_model = []
        for task_bin, task_pred in zip(model_bin, model_pred):
            for particle_bin, particle_predictions in zip(task_bin, task_pred):
                for t in range(n_hist, len(particle_bin)):
                    if (np.all(particle_bin[(t-n_hist):t, i, t-1] == p)
                            and np.any(particle_predictions[t-1])):
                        next_step_model.append(particle_predictions[t-1])
        next_steps_model[i] = _mean_or_zero(next_step_model)

    # Append another step to the pattern
    p = p + [1]
    # Plot ground truth transitions
    plt.subplot(2, n_cols, col + 1)
    plt.imshow(next_steps_data)
    if col == 0:
        plt.title('True transitions', ha='left')
    plt.xticks([])
    plt.yticks([])
    # Plot model transitions
    plt.subplot(2, n_cols, n_cols + col + 1)
    plt.imshow(next_steps_model)
    if col == 0:
        plt.title('Learned transitions', ha='left')
    plt.xticks([])
    plt.yticks([])
plt.subplots_adjust(top=0.85, bottom=0.025, left=0.02, right=0.985,
                    hspace=0.485, wspace=0.095)
plt.savefig("figures/2c.svg", format="svg")
plt.savefig("figures/2c.png", format="png")
    
### PANEL D ###

# Then plot individual examples for test tasks
args['do_test'] = True

# Specify data object to use (now built on the test tasks)
dataset = data.TaskSequenceDataset(args, task_samples=1e4)

# Run the model on one sample per test task, at index t*task_samples + 1
# (presumably sample 1 within task t's block of samples - TODO confirm layout)
with torch.no_grad():
    model_in = [dataset[int(t*dataset.task_samples)+1]
                for t in range(dataset.task_stop - dataset.task_start)]
    # Batch the per-sample dicts into one tensor per key, dropping 'id'
    model_in = {k: torch.stack([m[k] for m in model_in]) for k in model_in[0].keys() if k not in ['id']}
    model_out = model.nll(model_in['input'], model_in['context'], model_in['output'], N=100)
    # A single (N=1), non-sampled backward trace
    trace_back = model.trace_back(model_out, N=1, sample=False)
# Average over the particle axis (dim 1) of the forward pass, except for the
# 'ancestor' and 'particle' entries
model_out = {k: v if k in ['ancestor', 'particle'] else torch.mean(v, 1) for k, v in model_out.items()}
# Squeeze particles in backward trace (there's only one)
trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}

# Grab a nice test task (that switches multiple times)
test_tasks = [1]

# Plot performance on individual tasks: one column per task, four rows
cols = len(test_tasks)
rows = 4
fig = plt.figure(figsize=(unit_size*2, unit_size*4))
for t, task in enumerate(test_tasks):
    # Get length of this current task: sum of contexts in true data
    true_context = dataset.task_contexts[task + dataset.task_start].detach().numpy()
    T = int(np.sum(true_context))
    # Plot target output
    plt.subplot(rows, cols, 0*cols + t + 1)
    plt.imshow(model_in['output'][task][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('$y_t$', labelpad=-0.2)
    plt.title('True output')
    # Plot inferred output (from the backward trace)
    plt.subplot(rows, cols, 1*cols + t + 1)
    plt.imshow(trace_back['action_out'][task][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.xticks([])
    plt.yticks([])
    # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on Python 3.12+)
    plt.ylabel(r'$\mu_t$', labelpad=-0.2)
    plt.title('Model output')
    # Plot provided contexts
    plt.subplot(rows, cols, 2*cols + t + 1)
    plt.imshow(model_in['context'][task][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('Operations', labelpad=-0.2)
    # Overlay true contexts as red dots
    plt.scatter(np.arange(T), np.argmax(true_context, axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('True operation')
    # Plot filtering posterior, with modules in the order of the true operations
    plt.subplot(rows, cols, 3*cols + t + 1)
    plt.imshow(model_out['activation'][task][:T].numpy().transpose()[reordered], interpolation='none', aspect='auto')
    plt.xticks([0, T-1])
    plt.yticks([])
    plt.xlabel('Time', labelpad=-6)
    plt.ylabel('Modules', labelpad=-0.2)
    # Overlay inferred contexts (backward trace) as red dots
    plt.scatter(np.arange(T), np.argmax(trace_back['activation'][task][:, reordered], axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('Module posterior')
plt.subplots_adjust(top=0.929, bottom=0.081, left=0.1, right=0.958,
                    hspace=0.6, wspace=0.2)
plt.savefig("figures/2d.svg", format="svg")
plt.savefig("figures/2d.png", format="png")

### PANEL E ###

# Repeat the exact same thing, but now with partial feedback: contexts are only
# provided at the steps where this fixed feedback mask is 1, chosen because it
# gives a clean demonstration
feedback = torch.tensor([0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], dtype=torch.float)

# Specify data object to use (args['do_test'] is still True: test tasks)
dataset = data.TaskSequenceDataset(args, task_samples=1e4)

# Run the model on one sample per test task, at index t*task_samples + 1
with torch.no_grad():
    model_in = [dataset[int(t*dataset.task_samples)+1]
                for t in range(dataset.task_stop - dataset.task_start)]
    # Batch the per-sample dicts into one tensor per key, dropping 'id'
    model_in = {k: torch.stack([m[k] for m in model_in]) for k in model_in[0].keys() if k not in ['id']}
    # Zero the contexts at every step without feedback. feedback[None, :, None]
    # broadcasts the mask over tasks and operations; assumes context is
    # (tasks, steps, operations) with exactly len(feedback) steps.
    model_in['context'] = model_in['context'] * feedback[None, :, None]
    model_out = model.nll(model_in['input'], model_in['context'], model_in['output'], N=100)
    # A single (N=1), non-sampled backward trace
    trace_back = model.trace_back(model_out, N=1, sample=False)
# Average over the particle axis (dim 1) of the forward pass, except for the
# 'ancestor' and 'particle' entries
model_out = {k: v if k in ['ancestor', 'particle'] else torch.mean(v, 1) for k, v in model_out.items()}
# Squeeze particles in backward trace (there's only one)
trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}

# Grab a nice test task (that switches multiple times)
test_tasks = [1]

# Plot performance on individual tasks: one column per task, four rows
cols = len(test_tasks)
rows = 4
fig = plt.figure(figsize=(unit_size*2, unit_size*4))
for t, task in enumerate(test_tasks):
    # Get length of this current task: sum of contexts in true data
    true_context = dataset.task_contexts[task + dataset.task_start].detach().numpy()
    T = int(np.sum(true_context))
    # Plot target output
    plt.subplot(rows, cols, 0*cols + t + 1)
    plt.imshow(model_in['output'][task][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('$y_t$', labelpad=-0.2)
    plt.title('True output')
    # Plot inferred output (from the backward trace)
    plt.subplot(rows, cols, 1*cols + t + 1)
    plt.imshow(trace_back['action_out'][task][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.xticks([])
    plt.yticks([])
    # Raw string: '\m' is an invalid escape sequence (SyntaxWarning on Python 3.12+)
    plt.ylabel(r'$\mu_t$', labelpad=-0.2)
    plt.title('Model output')
    # Plot provided (partially masked) contexts
    plt.subplot(rows, cols, 2*cols + t + 1)
    plt.imshow(model_in['context'][task][:T].numpy().transpose(), interpolation='none', aspect='auto')
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('Operations', labelpad=-0.2)
    # Overlay true contexts as red dots
    plt.scatter(np.arange(T), np.argmax(true_context, axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('True operation')
    # Plot filtering posterior, with modules in the order of the true operations
    plt.subplot(rows, cols, 3*cols + t + 1)
    plt.imshow(model_out['activation'][task][:T].numpy().transpose()[reordered], interpolation='none', aspect='auto')
    plt.xticks([0, T-1])
    plt.yticks([])
    plt.xlabel('Time', labelpad=-6)
    plt.ylabel('Modules', labelpad=-0.2)
    # Overlay inferred contexts (backward trace) as red dots
    plt.scatter(np.arange(T), np.argmax(trace_back['activation'][task][:, reordered], axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('Module posterior')
plt.subplots_adjust(top=0.929, bottom=0.081, left=0.1, right=0.958,
                    hspace=0.6, wspace=0.2)
plt.savefig("figures/2e.svg", format="svg")
plt.savefig("figures/2e.png", format="png")

### PANEL F ###

# Now create one long task by chaining N_ops randomly drawn operations.
# NOTE(review): an earlier comment called it "3x longer than the previous ones",
# while the plot title below says "four times as long" - confirm the actual
# ratio; it depends on N_ops and the durations of the drawn operations.
# First reset seed so this is all reproducible. Statement order matters from
# here on, because every random draw below consumes the seeded generators.
np.random.seed(0)
torch.manual_seed(0)

# Create this new megatask: N_ops operation indices (consecutive draws may
# repeat, which merges them into one longer block)
N_ops = 12
task_ops = np.random.randint(0,dataset.n_operations,N_ops)
task = [dataset.operations[o] for o in task_ops]
# One-hot context per step: operation o is repeated for dataset.durations[o] steps
task_context = torch.stack([torch.eye(dataset.n_operations)[o] for o in task_ops for _ in range(dataset.durations[o])])
steps = task_context.shape[0]

# Then manipulate the current dataset object (and the model's step count) so
# that task 0 becomes this long task
dataset.n_steps = steps
dataset.tasks[0] = task
dataset.task_ops[0] = task_ops
dataset.task_contexts = task_context.unsqueeze(0)
model.n_steps = steps

# Then generate data from this task (slightly ugly: repeat task twice to get dimension right)
# Feedback mask: each step keeps its context with probability 0.4
feedback = torch.tensor(np.random.rand(steps)<0.4, dtype=torch.float)
input_data, input_context, input_task, output_data = dataset.get_model_input([0,0])
model_in = {'input': input_data, 'output': output_data, 'context': input_context, 'task': input_task}

# Now run the model on this task, with contexts zeroed wherever feedback is 0
with torch.no_grad():
    model_in['context'] = model_in['context'] * torch.stack([feedback for _ in range(dataset.n_operations)], axis=-1).unsqueeze(0)
    model_out = model.nll(model_in['input'], model_in['context'], model_in['output'], N=250)
    trace_back = model.trace_back(model_out, N=1, sample=False)
# Average over the particle axis (dim 1) of the forward pass, except for the
# 'ancestor' and 'particle' entries
model_out = {k: v if k in  ['ancestor', 'particle'] else torch.mean(v, 1) for k, v in model_out.items()}
# Squeeze particles in backward trace (there's only one)
trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}

# Plot result
# NOTE(review): rows is unused; the subplots below hard-code 2 rows
rows = 2
# Plot the first of the two identical copies in the batch
t = 0
fig = plt.figure(figsize=(unit_size*10,unit_size*2))
# Get length of this current task: sum of contexts in true data
true_context = dataset.task_contexts[0].detach().numpy()
T = int(np.sum(true_context))
# Plot provided contexts
plt.subplot(2, 1, 1);
plt.imshow(model_in['context'][t][:T].numpy().transpose(), interpolation='none', aspect='auto');
plt.xticks([]); plt.yticks([]);
plt.ylabel('Operations', labelpad=-0.2)
# Plot true contexts (red dots)
plt.scatter(np.arange(T), np.argmax(true_context,axis=-1)[:T], color=[1,0,0], s=1)
plt.title('Test task that is four times as long as all training tasks')
# Plot filtering posterior, with modules in the order of the true operations
plt.subplot(2, 1, 2);
plt.imshow(model_out['activation'][t][:T].numpy().transpose()[reordered], interpolation='none', aspect='auto');
plt.xticks([0, T-1]); plt.yticks([]);
plt.xlabel('Time', labelpad=-6);
plt.ylabel('Modules', labelpad=-0.2)
# Plot inferred contexts from the backward trace (red dots)
plt.scatter(np.arange(T), np.argmax(trace_back['activation'][t][:,reordered],axis=-1)[:T], color=[1,0,0], s=1)
plt.subplots_adjust(top=0.87,
bottom=0.155,
left=0.02,
right=0.99,
hspace=0.1,
wspace=0.2)
plt.savefig("figures/2f.svg", format="svg")
plt.savefig("figures/2f.png", format="png")