#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Aug 21 16:22:19 2025

@author: XXXX
"""

from matplotlib import pyplot as plt
import os
import json
import torch
import numpy as np

import run
import data


# Font sizes for every figure element, following https://stackoverflow.com/a/39566040
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

# Apply sizes and font settings to matplotlib's global configuration in one
# validated update (same keys, same order as individual plt.rc calls would set).
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': MEDIUM_SIZE,    # axes titles
    'axes.labelsize': SMALL_SIZE,     # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure suptitle
    'svg.fonttype': 'none',           # keep SVG text editable instead of converting glyphs to paths
    'font.family': ['Arial'],         # consistent font; Arial may need to be installed on Ubuntu
})

# Base size unit (in inches, since it feeds matplotlib's figsize) from which
# every figure's dimensions are derived; the assembled figure is 4x10 units.
unit_size = 0.65

# Collect data across all model runs. Each entry is one panel: a list of
# error matrices of shape (n_seeds, n_tasks), one per evaluation condition.
results = []

### PANEL A, B ###


def _rnn_task_loss(model, dataset):
    """Compute the per-task squared error of an RNN model on `dataset`.

    One sample is taken per task (index ``int(t * dataset.task_samples) + 1``
    for task ``t``), the samples are stacked into a batch along a new leading
    (task) axis and fed through ``model(input, task)``. The squared difference
    between the target ``output`` and the model's ``action_out`` is summed
    over the last axis and averaged over the timesteps whose ``context`` sums
    to a positive value.

    Parameters
    ----------
    model
        Model returned by ``run.get_model``, already in eval mode.
    dataset : data.TaskSequenceDataset
        Dataset exposing ``task_start``, ``task_stop`` and ``task_samples``.

    Returns
    -------
    numpy.ndarray
        One error value per task, of length ``task_stop - task_start``.
    """
    n_tasks = dataset.task_stop - dataset.task_start
    with torch.no_grad():
        samples = [dataset[int(t * dataset.task_samples) + 1] for t in range(n_tasks)]
        # Batch every field except 'id' across tasks
        model_in = {k: torch.stack([s[k] for s in samples])
                    for k in samples[0].keys() if k != 'id'}
        model_out = model(model_in['input'], model_in['task'])
    # Only timesteps with a positive context sum contribute to the loss
    valid_loss = torch.sum(model_in['context'], -1) > 0
    loss_mat = np.sum(((model_in['output'] - model_out['action_out'])
                       * valid_loss.unsqueeze(-1)).numpy() ** 2, axis=-1)
    # NOTE: a task with no valid timestep would divide by zero here
    return np.sum(loss_mat, -1) / np.sum(valid_loss.numpy(), -1)


seeds = [0, 1, 2, 3, 4]
base_dirs = ['./train/Final/rule_control', './train/Final/rule_id']
for base_dir in base_dirs:
    model_dirs = [base_dir + f'/ctrl_all_all_f_i{i}_v6' for i in seeds]

    # Collect the per-task error of each model (one row per seed after stacking)
    train_error = []
    test_error = []

    # Run through models and get their performance
    for model_dir in model_dirs:
        # Load arguments from model_dir json file
        with open(os.path.join(model_dir, 'args.json'), encoding='utf-8') as f:
            args = json.load(f)

        # First evaluate on training tasks
        args['do_test'] = False

        # Replace the loaded "models" module by the source file saved in training - if it exists
        run.load_module(model_dir + '/source')

        # Point data_dir at the directory of the same name under the current working directory
        args['data_dir'] = os.path.join(
            os.getcwd(), os.path.basename(os.path.normpath(args['data_dir'])))

        # Build the training-task dataset and the model that matches it
        dataset = data.TaskSequenceDataset(args, task_samples=1e4)
        model = run.get_model(args, dataset)

        # Keep model on cpu and load its trained parameters
        device = 'cpu'
        model.to(device)
        model.set_device(device)
        model.load_state_dict(torch.load(os.path.join(model_dir, 'model.zip'),
                                         map_location=torch.device(device), weights_only=True))
        # Set model to evaluation mode
        model.eval()

        train_error.append(_rnn_task_loss(model, dataset))

        # Then repeat the same for test tasks, reusing the loaded model
        args['do_test'] = True
        dataset = data.TaskSequenceDataset(args, task_samples=1e4)
        test_error.append(_rnn_task_loss(model, dataset))

    # Stack errors into (n_seeds, n_tasks) matrices and store them for this panel
    train_error = np.stack(train_error)
    test_error = np.stack(test_error)
    results.append([train_error, test_error])
        
### PANEL C, D ###


def _hmm_task_loss(model, dataset, n_particles=100):
    """Compute the per-task squared error of the model's backward trace on `dataset`.

    One sample is taken per task (index ``int(t * dataset.task_samples) + 1``
    for task ``t``) and the samples are stacked into a batch along a new
    leading (task) axis. ``model.nll`` is run with ``N=n_particles`` and a
    single non-sampled backward trace is taken with ``model.trace_back``.
    The squared difference between the target ``output`` and the trace's
    ``action_out`` is summed over the last axis and averaged over the
    timesteps whose ``context`` sums to a positive value.

    NOTE(review): if ``nll`` draws particles randomly, results vary between
    runs unless a random seed is fixed beforehand -- confirm.

    Parameters
    ----------
    model
        Model returned by ``run.get_model``, already in eval mode.
    dataset : data.TaskSequenceDataset
        Dataset exposing ``task_start``, ``task_stop`` and ``task_samples``.
    n_particles : int
        Value passed as ``N`` to ``model.nll`` (100 in all original calls).

    Returns
    -------
    numpy.ndarray
        One error value per task, of length ``task_stop - task_start``.
    """
    n_tasks = dataset.task_stop - dataset.task_start
    with torch.no_grad():
        samples = [dataset[int(t * dataset.task_samples) + 1] for t in range(n_tasks)]
        # Batch every field except 'id' across tasks
        model_in = {k: torch.stack([s[k] for s in samples])
                    for k in samples[0].keys() if k != 'id'}
        model_out = model.nll(model_in['input'], model_in['context'], model_in['output'],
                              N=n_particles)
        trace_back = model.trace_back(model_out, N=1, sample=False)
    # Squeeze particles in backward trace (N=1, so there's only one)
    trace_back = {k: v.squeeze(1) for k, v in trace_back.items()}
    # Only timesteps with a positive context sum contribute to the loss
    valid_loss = torch.sum(model_in['context'], -1) > 0
    loss_mat = np.sum(((model_in['output'] - trace_back['action_out'])
                       * valid_loss.unsqueeze(-1)).numpy() ** 2, axis=-1)
    # NOTE: a task with no valid timestep would divide by zero here
    return np.sum(loss_mat, -1) / np.sum(valid_loss.numpy(), -1)


seeds = [0, 1, 2, 3, 4]
base_dirs = ['./train/Final/rule_flat', './train/Final/rule_full']
for base_dir in base_dirs:
    model_dirs = [base_dir + f'/hmm_all_all_f_i{i}_v6' for i in seeds]

    # Collect the per-task error of each model (one row per seed after stacking)
    train_error = []
    test_error = []
    sparse_error = []

    # Run through models and get their performance
    for model_dir in model_dirs:
        # Load arguments from model_dir json file
        with open(os.path.join(model_dir, 'args.json'), encoding='utf-8') as f:
            args = json.load(f)

        # First evaluate on training tasks
        args['do_test'] = False

        # Replace the loaded "models" module by the source file saved in training - if it exists
        run.load_module(model_dir + '/source')

        # Point data_dir at the directory of the same name under the current working directory
        args['data_dir'] = os.path.join(
            os.getcwd(), os.path.basename(os.path.normpath(args['data_dir'])))

        # Build the training-task dataset and the model that matches it
        dataset = data.TaskSequenceDataset(args, task_samples=1e4)
        model = run.get_model(args, dataset)

        # Keep model on cpu and load its trained parameters
        device = 'cpu'
        model.to(device)
        model.set_device(device)
        model.load_state_dict(torch.load(os.path.join(model_dir, 'model.zip'),
                                         map_location=torch.device(device), weights_only=True))
        # Set model to evaluation mode
        model.eval()

        # Run the model on the training tasks
        train_error.append(_hmm_task_loss(model, dataset))

        # Then repeat the same for test tasks, reusing the loaded model
        args['do_test'] = True
        dataset = data.TaskSequenceDataset(args, task_samples=1e4)
        test_error.append(_hmm_task_loss(model, dataset))

        # Then repeat on the test tasks (do_test stays True) with sparse feedback
        args['feedback_density'] = 0.25
        dataset = data.TaskSequenceDataset(args, task_samples=1e4)
        sparse_error.append(_hmm_task_loss(model, dataset))

    # Stack errors into (n_seeds, n_tasks) matrices and store them for this panel
    train_error = np.stack(train_error)
    test_error = np.stack(test_error)
    sparse_error = np.stack(sparse_error)
    results.append([train_error, test_error, sparse_error])
    
### PLOT ALL PANELS ###

names = ['RNN, no task id', 'RNN with task id', 'Flat transitions', 'Full model']
labels = ['a', 'b', 'c', 'd']

# Set ylim so it's the same across figures
y_max = 30
y_inset = [0, 60]

# savefig below writes into figures/; create it so a fresh checkout doesn't crash there
os.makedirs('figures', exist_ok=True)

# Plot all panels
x_center = 0.5 + 2*np.arange(3)
x_ticks = ['Train', 'Test', 'Sparse']
for curr_plot, (dat, name, label) in enumerate(zip(results, names, labels)):
    # Each d in dat is an (n_seeds, n_tasks) error matrix. Per seed, take the
    # mean over tasks and its standard error over tasks (np.std default ddof=0).
    means = [np.mean(d, axis=-1) for d in dat]
    sems = [np.std(d, axis=-1)/np.sqrt(d.shape[-1]) for d in dat]
    # Bars show the average over seeds; grey dots are individual seeds with SEM bars
    fig, ax = plt.subplots(figsize=(unit_size*len(dat), unit_size*2))
    plt.bar(x_center[:len(dat)], [np.mean(m) for m in means], width=1.2, facecolor=[1,1,1], edgecolor=[0,0,0])
    for i, (m, s) in enumerate(zip(means, sems)):
        # Spread the seeds across [2i, 2i+1] so they sit over bar i
        plt.errorbar(i*2 + np.linspace(0,1,len(m)), m, s, marker='o', color=[0.7, 0.7, 0.7], linestyle='')
    plt.xticks(x_center[:len(dat)],labels=x_ticks[:len(dat)])
    plt.xlim([x_center[0]-1, x_center[len(dat)-1]+1])
    curr_lims = [0, y_max]
    plt.yticks(curr_lims, labels=[f'{y:.0f}' for y in curr_lims])
    plt.ylim(curr_lims)
    plt.ylabel('MSE', labelpad=-10)
    plt.title(name)
    # For the second plot, with huge errors, show the last bar in an inset with a larger y-range
    if curr_plot == 1:
        # Arrow at mid-height from the first bar's position toward the last
        # (test) bar at x_center[1], linking the inset to that bar
        ax.annotate("", xytext=(x_center[0], y_max/2), xy=(x_center[1], y_max/2),
                    arrowprops=dict(arrowstyle="->"))
        curr_bar = len(dat)-1
        m = means[curr_bar]; s = sems[curr_bar]
        # Create inset (position given in axes-fraction coordinates)
        axins = ax.inset_axes([0.25, 0.3, 0.3, 0.6],
                              xlim=[x_center[curr_bar] - 1, x_center[curr_bar]+1],
                              ylim=y_inset, xticks=[],
                              yticks=y_inset, yticklabels=[f'{y:.0f}' for y in y_inset])
        # Redraw that bar and its per-seed points inside the inset
        axins.bar(x_center[curr_bar], np.mean(m), width=1.2, facecolor=[1,1,1], edgecolor=[0,0,0])
        axins.errorbar(curr_bar*2 + np.linspace(0,1,len(m)), m, s, marker='o', color=[0.7, 0.7, 0.7], linestyle='')
    plt.subplots_adjust(top=0.834,
        bottom=0.164,
        left=0.196,
        right=0.908,
        hspace=0.2,
        wspace=0.2)
    plt.savefig(f"figures/3{label}.svg", format="svg")
    plt.savefig(f"figures/3{label}.png", format="png")
