#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 25 14:28:15 2024

@author: XXXX
"""

import run
from matplotlib import pyplot as plt

# Modes of variation: these values are combined into the run directory names below
model = "gating"
tasks = "rand"  # 'rand' or 'nonm'
feedback = "f_"  # '' or 'f_'
# NOTE(review): the previous note listed 'n' or 'u' as options, but 'xor' is
# the value in use - confirm which data options are valid
data = "xor"
# Top-level folders that contain the run directories
folders = ["gate_xor"]
# Create all postfixes: every prefix variant for version tag 1, then the same
# prefix variants for version tag 6 (12 entries in total)
postfixes = [f'{prefix}i0_v{version}'
             for version in (1, 6)
             for prefix in ('h_', 'l_h_', '', 'l_', 'k_', 'l_k_')]
# And the names of those: one human-readable label per postfix, in the same
# order. Built as (single | multi) x (gating variant) x (freeze | relearn);
# the middle gating variant is 'gate neurons' for single and 'gate experts'
# for multi.
names = [f'{count} expert, {gating}, {training} experts'
         for count, gated_unit in (('single', 'neurons'), ('multi', 'experts'))
         for gating in ('no gating', f'gate {gated_unit}', 'context gating')
         for training in ('freeze', 'relearn')]
# Create dirs from all postfixes. Each stem is one run directory without the
# train/test suffix; dirs lists every train directory first (no suffix),
# followed by every test directory ('_2' suffix), so the test directory of
# dirs[i] is dirs[i + len(names)].
stems = [f'{folder}/{model}_{tasks}_{data}_{feedback}{postfix}'
         for folder in folders for postfix in postfixes]
dirs = [stem + suffix for suffix in ('', '_2') for stem in stems]
if len(folders) > 1:
    # With several folders, prefix each name with its folder so that names
    # stays aligned one-to-one with stems
    names = [folder + ', ' + name for folder in folders for name in names]

# Create dictionary for legends. The 'Train '/'Test ' label comes from the
# split that produced the directory rather than from the directory's last
# character, so a postfix that happens to end in '2' is not mislabelled, and
# a length mismatch between stems and names cannot shift the test labels.
legends = {stem + suffix: label + name
           for label, suffix in (('Train ', ''), ('Test ', '_2'))
           for stem, name in zip(stems, names)}
# Comparisons to include: each key names one set of figures, each value lists
# the indices into dirs of the training runs to plot together (the plotting
# loop adds the matching test runs at an offset of len(names))
comparisons = dict(
    XOR_Single=[0, 2, 4],
    XOR_SingleRelearn=[1, 3, 5],
    XOR_Multi=[6, 8, 10],
    XOR_MultiRelearn=[7, 9, 11],
)

# Values to plot; the figures returned by run.compare are paired with these
# entries in order when saving.
# Alternative selection kept for reference:
# ['Losses/all', 'Rewards/total', 'Accuracies/Boundaries', 'Accuracies/Activations']
vals = ['Losses/all', 'Accuracies/Activations', 'Accuracies/Modularity']
for label, indices in comparisons.items():
    # Training directories for this comparison, then the corresponding test
    # directories, which sit len(names) entries further along in dirs
    test_offset = len(names)
    train_dirs = [dirs[index] for index in indices]
    test_dirs = [dirs[index + test_offset] for index in indices]
    figures = run.compare(vals, dirs=train_dirs + test_dirs, legends=legends, zoom=4000)
    for position, (metric, figure) in enumerate(zip(vals, figures)):
        # File name: comparison label, figure position, then the part of the
        # metric name before the '/'; the position keeps names unique when
        # two metrics share that first part
        group = metric.split("/")[0]
        figure.savefig(f'/home/XXXX/Documents/Figs/MoE/{label}{position}_{group}.png')
        # Figures are left open after saving; to close them use plt.close(figure)