#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 22 12:58:32 2025

@author: XXXX
"""

from matplotlib import pyplot as plt
import os
import json
import torch
import numpy as np

import run
import maze


# Global matplotlib styling shared by every panel of the figure.
# Font sizes follow the recipe in https://stackoverflow.com/a/39566040
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

plt.rcParams.update({
    'font.size': SMALL_SIZE,             # default text size
    'axes.titlesize': MEDIUM_SIZE,       # axes title
    'axes.labelsize': SMALL_SIZE,        # x and y labels
    'xtick.labelsize': SMALL_SIZE,       # x tick labels
    'ytick.labelsize': SMALL_SIZE,       # y tick labels
    'legend.fontsize': SMALL_SIZE,       # legend text
    'figure.titlesize': BIGGER_SIZE,     # figure title
    'svg.fonttype': 'none',              # keep svg text editable
    'font.family': ["Arial"],            # consistent font; Arial must be installed on Ubuntu
})

# Unit size (in inches) from which all panel sizes are derived.
# Total figure size is 4x10 units
unit_size = 0.65

# Directory holding the trained model and the arguments it was trained with
model_dir = './train/Final/motor/maze_all_all_f_i4_v6'

# Read the training arguments saved next to the model
with open(os.path.join(model_dir, 'args.json')) as args_file:
    args = json.load(args_file)

# Override the arguments that should differ for plotting
args.update({'do_test': False, 'feedback_density': 1})

# This is a bit ugly, but avoids using a different model for training and analysis:
# if the training run saved its source files, load the "models" module from there
# instead of the current one; otherwise pass None.
source_dir = model_dir + '/source'
run.load_module(source_dir if os.path.isdir(source_dir) else None)

# Build the environment from the (overridden) arguments
env = maze.ContinuousMaze(args)

### PANEL A ###

# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("figures", exist_ok=True)

# Plot a bunch of tasks; plot_tasks draws on a figure that we then pick up as current
env.plot_tasks(5, 1, offset=1)
fig_a = plt.gcf()
fig_a.set_size_inches(unit_size * 1, unit_size * 4)

# Tighten the margins around the panel
plt.subplots_adjust(
    top=0.98,
    bottom=0.02,
    left=0.1,
    right=0.9,
    hspace=0.065,
    wspace=0.2)

# Save both a vector and a raster version of the panel
for ext in ("svg", "png"):
    plt.savefig(f"figures/4a.{ext}", format=ext)

### PANEL B ###

# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("figures", exist_ok=True)

# Get model
model = run.get_model(args, env)

# Keep model and environment on cpu
device = 'cpu'
model.to(device)
model.set_device(device)
env.set_device(device)

# Load model trained parameters (weights only, mapped onto the chosen device)
state_dict = torch.load(os.path.join(model_dir, 'model.zip'),
                        map_location=torch.device(device), weights_only=True)
model.load_state_dict(state_dict)

# Set model to evaluation mode
model.eval()
model.set_hard_sampling(True)

# Collect the actions of each module and of each true operation, then use their
# overlap to get the module order `reordered` used for plotting below
model_actions, model_states, data_actions = run.get_module_actions(model, env)
overlap = run.get_module_overlap(model, env, model_actions, data_actions)
reordered = run.get_module_order(model, overlap)

# Then plot the true operation and the learned trajectory, one subplot per module.
# Two rows; round the number of columns up so an odd number of modules still fits.
fig_b = plt.figure(figsize=(unit_size*3, unit_size*2))
n_cols = (int(model.n_modules) + 1) // 2
for i, o in enumerate(reordered):
    ax = plt.subplot(2, n_cols, i+1)
    # Get path from actions of the true operation
    data_path = env.get_path(torch.tensor(data_actions[i], device=env.device),
                             start=env.task_init).cpu()
    # Plot true module: every step of the path is labelled with operation i
    env.plot_task(ax, data_path, [i] * (len(env.operations[i]) + 1), plot_path=False)
    # Get the path produced by the matching model module
    model_path = env.get_path(torch.tensor(model_actions[o], device=env.device),
                              start=env.task_init).cpu()
    # Plot model module on top
    env.plot_path(ax, model_path)

plt.subplots_adjust(
    top=0.98,
    bottom=0.02,
    left=0.02,
    right=0.98,
    hspace=0.1,
    wspace=0.1)

# Save both a vector and a raster version of the panel
for ext in ("svg", "png"):
    plt.savefig(f"figures/4b.{ext}", format=ext)

### PANEL C ###

# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("figures", exist_ok=True)

# Now run the model on *all* training tasks
# Run one example of each task
tasks = np.arange(env.task_start, env.task_stop)
model_in = env.get_model_input(tasks)
with torch.no_grad():
    model_out = model.nll(env, model_in['context'], model_in['output'], N=50)
ancestors = model_out['ancestor'].detach().cpu().numpy()
activations = model_out['activation'].detach().cpu().numpy()
predictions = model_out['context_out'].detach().cpu().numpy()

# Collect state history for each particle at each timestep
# Dimensions particles x timesteps x modules x timesteps
# The second dimension indexes the history, the final dimension the timestep
# So history of particle at timestep for task: history[task][particle, :, :, timestep]
T = ancestors[0].shape[1]
history = []
for task_activations, task_ancestors in zip(activations, ancestors):
    task_history = np.zeros(list(task_activations.shape) + [T])
    for t in range(1, T):
        # Copy over history *up to* previous timestep from the ancestor of current particle
        task_history[:, :t, :, t] = task_history[task_ancestors[:, t-1], :t, :, t-1]
        # Copy over resampled activation from previous timestep for current particle
        task_history[:, t, :, t] = task_activations[:, t-1, :]
    history.append(task_history)

# The quantities below do not depend on the pattern or the operation that is
# being matched, so compute them once instead of inside the nested loops.

# Ground-truth contexts of the selected tasks, each with an all-zero step prepended
padded_contexts = [
    np.concatenate([np.zeros((1, env.n_operations)), c], axis=0)
    for c in env.data_context[env.task_start:env.task_stop].numpy()
]

# Per particle: binarised, module-reordered history and module-reordered predictions
particles = []
for task_history, task_predictions in zip(history, predictions):
    for particle_history, particle_predictions in zip(task_history, task_predictions):
        # Reorder modules
        particle_history = particle_history[:, reordered, :]
        particle_predictions = particle_predictions[:, reordered]
        # Binarise history: 1 where a module holds the (strictly positive) maximum
        # across modules, 0 elsewhere
        particle_bin = 1.0*((particle_history == np.max(particle_history, axis=1)[:, None, :])
                            & (particle_history > 0))
        particles.append((particle_bin, particle_predictions))

# Create figure to plot time-dependent transition matrix
fig_c = plt.figure(figsize=(unit_size*3, unit_size*2))
# Set pattern to look for
p = [0, 1]  # p = [0,1]: find pattern after switch. p = [1]: find pattern everywhere
for col in range(5):
    n_p = len(p)
    # Find average next step given current pattern
    next_steps_data = np.zeros((env.n_operations, env.n_operations))
    next_steps_model = np.zeros((env.n_operations, env.n_operations))
    for i in range(env.n_operations):
        # Collect transitions in ground truth context of data
        next_step_data = []
        for zero_padded in padded_contexts:
            # NOTE(review): the original bound was written `len(zero_padded-1)`, which
            # evaluates to `len(zero_padded)`; that behaviour is kept here. Confirm
            # whether `len(zero_padded)-1` (skipping the final step) was intended.
            for t in range(n_p, len(zero_padded)):
                # Pattern must match operation i over the preceding steps, and the
                # next step must be non-empty (all-zero steps are skipped)
                if np.all(zero_padded[(t-n_p):t, i] == p) and np.any(zero_padded[t]):
                    next_step_data.append(zero_padded[t])
        next_steps_data[i] = np.mean(np.stack(next_step_data, 0), 0) if len(next_step_data) > 0 else 0
        # Collect transitions in model output
        next_step_model = []
        for particle_bin, particle_predictions in particles:
            # Collect predictions for matching histories at each time step
            for t in range(n_p, len(particle_bin)):
                if (np.all(particle_bin[(t-n_p):t, i, t-1] == p)
                        and np.any(particle_predictions[t-1])):
                    next_step_model.append(particle_predictions[t-1])
        next_steps_model[i] = np.mean(np.stack(next_step_model, 0), 0) if len(next_step_model) > 0 else 0

    # Append another step to the pattern
    p = p + [1]
    # Plot ground truth transitions (top row)
    plt.subplot(2, 5, col+1)
    plt.imshow(next_steps_data)
    if col == 0:
        plt.title('True transitions', ha='left')
    plt.xticks([])
    plt.yticks([])
    # Plot model transitions (bottom row)
    plt.subplot(2, 5, 5+col+1)
    plt.imshow(next_steps_model)
    if col == 0:
        plt.title('Learned transitions', ha='left')
    plt.xticks([])
    plt.yticks([])
plt.subplots_adjust(
    top=0.85,
    bottom=0.025,
    left=0.02,
    right=0.985,
    hspace=0.485,
    wspace=0.095)

# Save both a vector and a raster version of the panel
for ext in ("svg", "png"):
    plt.savefig(f"figures/4c.{ext}", format=ext)

### PANEL D ###

# savefig does not create missing directories, so make sure the output folder exists
os.makedirs("figures", exist_ok=True)


def _run_inference(model, env, model_in):
    """Run the model on `model_in` and trace back a single trajectory.

    Calls `model.nll` (with 50 particles, without gradients) on the 'context'
    and 'output' entries of `model_in`, then `model.trace_back` with N=1 and
    sample=False.

    Returns:
        (model_out, trace_back): the output of `model.nll`, and the traced-back
        output with dimension 1 (the single particle) squeezed from every entry.
    """
    with torch.no_grad():
        model_out = model.nll(env, model_in['context'], model_in['output'], N=50)
    trace_back = model.trace_back(model_out, N=1, sample=False)
    # Squeeze particles in backward trace (there's only one)
    trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}
    return model_out, trace_back


def _plot_inference_panel(env, tasks, model_in, model_out, trace_back, reordered,
                          title, stem, t=1):
    """Plot one task with its particle hypotheses and traced-back path, and save it.

    The figure has a large top axis (task, particles, traced-back path) and two
    small bottom axes (provided context and mean module activation).

    Args:
        env: environment providing the task data and plotting helpers.
        tasks: task ids that `model_in` was generated for.
        model_in: model input; its 'context' entry is plotted.
        model_out: output of `model.nll` for `model_in`.
        trace_back: squeezed output of `model.trace_back` for `model_out`.
        reordered: module order used to index the module dimension.
        title: title of the top axis.
        stem: file name stem; the panel is saved to figures/<stem>.svg and .png.
        t: index into `tasks` of the task to plot.

    Returns:
        The created matplotlib figure.
    """
    task = tasks[t]
    true_context = env.data_context[task].detach().numpy()
    # Number of steps to plot: the total of the true context entries
    # (presumably one-hot per step, zero after the task ends - verify against maze)
    T = int(np.sum(true_context))

    # Create figure and axes: top row spans both columns
    fig = plt.figure(figsize=(unit_size*3, unit_size*4))
    gs = fig.add_gridspec(2, 2, height_ratios=[3, 1])
    ax0 = fig.add_subplot(gs[0, :])
    ax1 = fig.add_subplot(gs[1, 0])
    ax2 = fig.add_subplot(gs[1, 1])

    # Get feedback signal: steps where the provided context sums to exactly one
    feedback = (torch.sum(model_in['context'][t][:T], -1) == 1).numpy()
    # Plot the task trajectory with colour indicating the true module
    plt.sca(ax0)
    env.plot_task(ax0, env.data_coord[task][:T],
                  torch.argmax(env.data_context[task], -1).numpy()[:T],
                  feedback=feedback, plot_path=False)
    # Plot the particle hypotheses: proposed steps before feedback
    run.plot_particles(ax0, env, model_out['pred_state'][t][:, :T, :].numpy(),
                       model_out['ancestor'][t][:, :T].numpy(),
                       model_out['context_out'][t][:, :T, reordered].numpy())
    # Plot the smoothing trajectory: highest likelihood trajectory after task completion
    # Collect path given by output
    curr_path = torch.cat([env.task_init, trace_back['pred_state'][t][:T]]).numpy()
    # Plot the full path
    ax0.plot(curr_path[:, 0], curr_path[:, 1], 'w-', linewidth=3)
    # Most active (reordered) module at each step of the traced-back trajectory;
    # computed once as a numpy array and reused for the path and the scatter below
    traced_modules = np.argmax(trace_back['activation'][t][:T, reordered].numpy(), axis=-1)
    # Plot the path where each action is coloured by the module that generated it
    env.plot_path(ax0, curr_path, traced_modules.squeeze())
    plt.title(title)

    # Plot provided contexts, with the true context marked in red
    plt.sca(ax1)
    plt.imshow(model_in['context'][t][:T].numpy().transpose(),
               interpolation='none', aspect='auto')
    plt.scatter(np.arange(T), np.argmax(true_context, axis=-1)[:T], color=[1, 0, 0], s=1)
    plt.title('True skill')
    plt.xticks([0, T-1])
    plt.yticks([])
    plt.xlabel('Time', labelpad=-6)

    # Plot context posterior: activation averaged over particles, with the
    # traced-back module marked in red
    plt.sca(ax2)
    plt.imshow(np.mean(model_out['activation'][t].numpy(), axis=0)[:T].transpose()[reordered],
               interpolation='none', aspect='auto', vmin=0, vmax=1)
    plt.scatter(np.arange(T), traced_modules, color=[1, 0, 0], s=1)
    plt.title('Module posterior')
    plt.xticks([0, T-1])
    plt.yticks([])
    plt.xlabel('Time', labelpad=-6)

    plt.subplots_adjust(
        top=0.92,
        bottom=0.093,
        left=0.037,
        right=0.938,
        hspace=0.25,
        wspace=0.086)

    # Save both a vector and a raster version of the panel
    for ext in ("svg", "png"):
        plt.savefig(f"figures/{stem}.{ext}", format=ext)
    return fig


# Switch to test tasks
args['do_test'] = True
env = maze.ContinuousMaze(args)
env.set_device(device)

# Run one example of each task
tasks = np.arange(env.task_start, env.task_stop)
model_in = env.get_model_input(tasks)
model_out, trace_back = _run_inference(model, env, model_in)

# Plot the task, inferred path, and particle trajectories for one task
fig_d = _plot_inference_panel(env, tasks, model_in, model_out, trace_back, reordered,
                              title='Full feedback', stem='4d')

### PANEL E ###

# First reset seed so this is all reproducible
np.random.seed(0)
torch.manual_seed(0)

# Repeat for sparse feedback; set specific feedback pattern for clean point
feedback = torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0], dtype=torch.float)

# Run one example of each task
tasks = np.arange(env.task_start, env.task_stop)
model_in = env.get_model_input(tasks)
# Zero out the provided context on every step without feedback: the per-step mask
# is repeated for each operation and given a leading dimension of size one
model_in['context'] = model_in['context'] * torch.stack([feedback for _ in range(env.n_operations)], axis=-1).unsqueeze(0)
model_out, trace_back = _run_inference(model, env, model_in)

# Plot the task, inferred path, and particle trajectories for one task
fig_e = _plot_inference_panel(env, tasks, model_in, model_out, trace_back, reordered,
                              title='Sparse feedback', stem='4e')
