# Script that collects the few-shot learning results of all methods
# into a LaTeX table file.

import os
import sys

import numpy as np

sys.path.append("../")
from utils import *
from res_all import ACCURACIES_ALL



# Dataset keys used in the results dict, paired with the column titles shown
# in the table. The order here is the column order of the table.
DATASETS = (
    ("smallnorb", "SmallNORB"),
    ("shapes3d", "Shapes3D"),
    ("causal3d", "Causal3D"),
    ("mpi3deasy", "MPI3D-Easy"),
    ("mpi3dhard", "MPI3D-Hard"),
)

# Shot-setting keys used in the results dict, paired with the sub-column
# titles shown under every dataset.
SHOTS = (
    ("five-shot", "5-Shot"),
    ("ten-shot", "10-Shot"),
)

# Row labels for methods whose name is abbreviated, split over two lines, or
# emphasised; any method not listed here is printed verbatim.
METHOD_LABELS = {
    "Few-Shot Direct Adaptation": "FSDA",
    "Pre-Training and Fine-Tuning": "PTFT",
    "Supervised-Original": r"\makecell[r]{Supervised-\\ \hspace{8pt} Original}",
    "Supervised-All": r"\makecell[r]{Supervised-\\ \hspace{8pt} All}",
    "Supervised-Oracle": r"\makecell[r]{Supervised-\\ \hspace{8pt} Oracle}",
    "Meta-GMVAE": r"\makecell[r]{Meta-\\ \hspace{8pt} GMVAE}",
    "CACTUS-DC": r"\makecell[r]{CACTUS-\\ \hspace{8pt} DC}",
    "CACTUS-DINO": r"\makecell[r]{CACTUS-\\ \hspace{8pt} DINO}",
    "DRESS": r"\textbf{DRESS}",
}

# Hand-picked (method -> dataset keys) whose cells are typeset in bold, for
# both shot settings. Update by hand when the results change.
BOLD_CELLS = {
    "DRESS": frozenset({"smallnorb", "causal3d", "mpi3deasy", "mpi3dhard"}),
    "PsCo": frozenset({"shapes3d"}),
}

# Methods whose row is followed by an \hline.
RULE_AFTER_METHODS = frozenset(
    {"Supervised-Oracle", "Few-Shot Direct Adaptation", "PsCo"}
)


def _format_cell(res_vals, bold=False):
    """Return the LaTeX cell for one (method, dataset, shot) list of results.

    The first line of the cell is the mean; the second line, in a tiny font,
    is ``np.std(res_vals) / sqrt(n)`` (numpy's default ddof=0 std). Both are
    printed with one decimal and a percent sign; no scaling is applied.
    An empty list yields the placeholder ``TODO``.
    """
    # len() rather than truthiness so this also works if values are a numpy array.
    if len(res_vals) == 0:
        return "TODO"
    res_avg = np.mean(res_vals)
    res_sem = np.std(res_vals) / np.sqrt(len(res_vals))
    if bold:
        # The trailing \% sits inside the \tiny group, matching the plain cell.
        return rf"\makecell[l]{{\textbf{{{res_avg:.1f}}}\% \\ {{\tiny $\pm$ \textbf{{{res_sem:.1f}}}\%}}}}"
    return rf"\makecell[l]{{${res_avg:.1f}\%$ \\ {{\tiny $\pm {res_sem:.1f}\%$}}}}"


def build_latex_table(accuracies, datasets=DATASETS, shots=SHOTS):
    r"""Render few-shot results as the rows of a booktabs LaTeX table.

    Args:
        accuracies: mapping ``method -> {shot_key: {dataset_key: values}}``,
            where ``values`` is a sequence of accuracies. Methods whose name
            starts with ``"Ablate"`` are skipped.
        datasets: ``(key, title)`` pairs giving the dataset column order.
        shots: ``(key, title)`` pairs giving the sub-columns of each dataset.

    Returns:
        The table from ``\toprule`` to ``\bottomrule`` (without the
        ``tabular`` environment), one row per line, ending in a newline.

    Raises:
        KeyError: if a non-skipped method lacks a shot or dataset entry.
    """
    n_shots = len(shots)
    # Header columns are derived from the same lists as the body cells, so
    # the two can never disagree.
    header_top = " & ".join(
        [r"\multirow{2}{*}{Method}"]
        + [rf"\multicolumn{{{n_shots}}}{{c}}{{{title}}}" for _, title in datasets]
    )
    header_sub = " & ".join([""] + [title for _ in datasets for _, title in shots])
    lines = [
        r"\toprule ",
        header_top + r" \\ ",
        header_sub + r" \\ ",
        r"\midrule ",
    ]

    for method, res_dict in accuracies.items():
        if method.startswith("Ablate"):
            continue
        bold_datasets = BOLD_CELLS.get(method, frozenset())
        cells = [
            _format_cell(res_dict[shot_key][ds_key], bold=ds_key in bold_datasets)
            for ds_key, _ in datasets
            for shot_key, _ in shots
        ]
        lines.append(" & ".join([METHOD_LABELS.get(method, method)] + cells) + r"\\ ")
        if method in RULE_AFTER_METHODS:
            lines.append(r"\hline ")

    # finishing the table by an underline
    lines.append(r"\bottomrule ")
    return "\n".join(lines) + "\n"


def main():
    """Build the table from ACCURACIES_ALL and write it next to this script."""
    print("Processing results...")

    latex_table = build_latex_table(ACCURACIES_ALL)

    # write the results to a file
    latex_table_filename = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "res_table_simpleds.tex"
    )
    with open(latex_table_filename, "w", encoding="utf-8") as f:
        f.write(latex_table)

    print(f"Latex table generated and saved to {latex_table_filename} file!")

    print("Script finished!")


if __name__ == "__main__":
    main()