"""Utility functions for plotting Captum attribution results."""

import os
from typing import List
from matplotlib import pyplot as plt
from Code.post_processing_utils import plot_attr_vals

def get_plots(attr_info, input_tokens, fig_info, removed_indices: List[int],
              normalize: bool = True, save_path=None, fixed_width: bool = False):
    """
    Draw every model's per-logit attributions as stacked rows of one figure.

    One subplot row is created per (model, logit index) pair. Rows are ordered by
    model (in ``attr_info`` iteration order) and then by logit index. Each row is
    drawn by ``plot_attr_vals`` and labelled ``"<model>-Class=<j>"``.

    Args:
        attr_info (dict): Maps keys of the form ``"<model>_attr_info..."`` to inner
            dicts. Each inner dict must hold ``"model_attr_embeddings_logit_index_<j>"``
            for ``j = 0 .. l2-1``, where ``l2`` is the number of such entries found in
            the FIRST model's inner dict. Values must support ``.squeeze(0)``
            (presumably tensors/arrays with a leading axis of size 1 - confirm with
            the caller).
        input_tokens (list): Tokens being attributed; forwarded to ``plot_attr_vals``.
            When ``fixed_width`` is False the figure width is ``2.5 * len(input_tokens)``.
        fig_info (dict): Figure options. Reads ``width`` (default 20, only used when
            ``fixed_width``), ``height`` (default 3), ``font_size`` (default 14;
            y-labels use ``font_size - 3``) and ``dpi`` (default 300). The whole dict
            is also forwarded to ``plot_attr_vals``.
        removed_indices (List[int]): Forwarded to ``plot_attr_vals``.
        normalize (bool, optional): Forwarded to ``plot_attr_vals``. Defaults to True.
        save_path (str, optional): If given, the figure is saved there as a PDF.
            Defaults to None (not saved).
        fixed_width (bool, optional): Use ``fig_info['width']`` instead of a
            token-count based width. Defaults to False.

    Returns:
        tuple: ``(fig, axes)`` where ``axes`` is a 1-D array with one Axes per row.

    Raises:
        ValueError: If ``attr_info`` is empty or its first model has no
            ``model_attr_embeddings_logit_index_*`` entries.
        KeyError: If a model lacks one of the expected logit entries.
    """
    attr_info_keys = list(attr_info.keys())
    if not attr_info_keys:
        raise ValueError("attr_info is empty; there is nothing to plot.")

    # The model name is the part of the key before "_attr_info".
    attr_models = [key.split("_attr_info")[0] for key in attr_info_keys]

    inner_keys = [sorted(attr_info[key].keys()) for key in attr_info_keys]
    print(f'inner_keys: {inner_keys}')

    filtered_key = 'model_attr_embeddings_logit_index'
    inner_filtered_keys = [[key for key in inner_key if filtered_key in key] for inner_key in inner_keys]
    print(f'inner_foltered_keys: {inner_filtered_keys}')

    l1 = len(attr_info_keys)
    # The logit count is taken from the first model; all models are expected to match.
    l2 = len(inner_filtered_keys[0])

    print(f'l1: {l1}')
    print(f'l2: {l2}')

    num_plots = l1 * l2
    print(f'number of plots: {num_plots}')
    if num_plots == 0:
        raise ValueError(
            f"No '{filtered_key}_*' entries found for {attr_info_keys[0]!r}; there is nothing to plot."
        )

    if fixed_width:
        width = fig_info.get('width', 20)
    else:
        width = len(input_tokens) * 2.5

    height = fig_info.get('height', 3)
    font_size = fig_info.get('font_size', 14)
    y_font_size = font_size - 3
    dpi = fig_info.get("dpi", 300)

    # Use Arial as the default font for the figure text.
    plt.rcParams["font.family"] = "Arial"

    # squeeze=False keeps `axes` indexable when there is a single subplot
    # (plt.subplots(1, 1) otherwise returns a bare Axes). Taking the only column
    # yields the same 1-D array that plt.subplots(n, 1) returns for n > 1.
    fig, axes = plt.subplots(num_plots, 1, figsize=(width, 2.75 * num_plots), squeeze=False)
    axes = axes[:, 0]

    for i, (attr_info_key, attr_model_name) in enumerate(zip(attr_info_keys, attr_models)):
        print(f'attr_model_name: {attr_model_name}')

        for j in range(l2):
            attr_info_logit = f'model_attr_embeddings_logit_index_{j}'
            print(f'attr_info_logit: {attr_info_logit}')
            # Indexing (rather than .get) raises a KeyError naming the missing entry
            # instead of an AttributeError on None.
            attr_vals = attr_info[attr_info_key][attr_info_logit].squeeze(0)
            print(f'attr_vals is {attr_vals}')

            axis = axes[i * l2 + j]
            plot_attr_vals(attr_vals, input_tokens, fig_info, removed_indices, normalize, axis=axis)
            axis.set_ylabel(f'{attr_model_name}-Class={j}', fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

    # Let matplotlib fit the subplots, then enforce a fixed vertical gap between rows.
    # The same spacing is used for every plot type, so 'plot_type' is not consulted.
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0.5)

    # Width/height are divided by 2.54, i.e. they are treated as centimetres.
    fig.set_size_inches(width / 2.54, 1.25 * height * num_plots / 2.54)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')
    return fig, axes

def save_single_plots(attr_info, input_tokens, fig_info, removed_indices: List[int],
                      normalize: bool = True, save_dir=None, fixed_width: bool = False):
    """
    Save every (model, logit index) attribution plot as its own PDF file.

    For each model key in ``attr_info`` and each logit index ``j``, a fresh
    single-Axes figure is created and drawn by ``plot_attr_vals``, which receives
    the output path. The figure is then closed. Files are written to
    ``<save_dir>/<model>-Class=<j>_<plot_type>.pdf``, where ``plot_type`` is
    ``fig_info.get('plot_type', 'heatmap')``.

    Args:
        attr_info (dict): Maps keys of the form ``"<model>_attr_info..."`` to inner
            dicts. Each inner dict must hold ``"model_attr_embeddings_logit_index_<j>"``
            for ``j = 0 .. l2-1``, where ``l2`` is the number of such entries found in
            the FIRST model's inner dict. Values must support ``.squeeze(0)``
            (presumably tensors/arrays with a leading axis of size 1 - confirm).
        input_tokens (list): Forwarded to ``plot_attr_vals``.
        fig_info (dict): Forwarded to ``plot_attr_vals``; its ``plot_type`` is also
            used in the file names.
        removed_indices (List[int]): Forwarded to ``plot_attr_vals``.
        normalize (bool, optional): Forwarded to ``plot_attr_vals``. Defaults to True.
        save_dir (str): Output directory; created if it does not exist. Required
            despite the ``None`` default.
        fixed_width (bool, optional): Forwarded to ``plot_attr_vals``. Defaults to False.

    Returns:
        None

    Raises:
        ValueError: If ``save_dir`` is None or ``attr_info`` is empty.
        KeyError: If a model lacks one of the expected logit entries.
    """
    if save_dir is None:
        raise ValueError("save_dir must be provided to save the individual plots.")

    attr_info_keys = list(attr_info.keys())
    if not attr_info_keys:
        raise ValueError("attr_info is empty; there is nothing to plot.")

    # The model name is the part of the key before "_attr_info".
    attr_models = [key.split("_attr_info")[0] for key in attr_info_keys]

    inner_keys = [sorted(attr_info[key].keys()) for key in attr_info_keys]
    print(f'inner_keys: {inner_keys}')

    filtered_key = 'model_attr_embeddings_logit_index'
    inner_filtered_keys = [[key for key in inner_key if filtered_key in key] for inner_key in inner_keys]
    print(f'inner_foltered_keys: {inner_filtered_keys}')

    l1 = len(attr_info_keys)
    # The logit count is taken from the first model; all models are expected to match.
    l2 = len(inner_filtered_keys[0])

    print(f'l1: {l1}')
    print(f'l2: {l2}')

    num_plots = l1 * l2
    print(f'number of plots: {num_plots}')

    plot_type = fig_info.get('plot_type', 'heatmap')
    # Use Arial as the default font for the figure text.
    plt.rcParams["font.family"] = "Arial"

    os.makedirs(save_dir, exist_ok=True)

    for attr_info_key, attr_model_name in zip(attr_info_keys, attr_models):
        print(f'attr_model_name: {attr_model_name}')

        for j in range(l2):
            attr_info_logit = f'model_attr_embeddings_logit_index_{j}'
            print(f'attr_info_logit: {attr_info_logit}')

            attr_vals = attr_info[attr_info_key][attr_info_logit].squeeze(0)
            print(f'attr_vals is {attr_vals}')

            fig, axis = plt.subplots(1, 1)
            try:
                plot_attr_vals(attr_vals=attr_vals,
                               input_tokens=input_tokens,
                               fig_info=fig_info,
                               removed_indices=removed_indices,
                               normalize=normalize,
                               axis=axis,
                               save_path=os.path.join(save_dir, f'{attr_model_name}-Class={j}_{plot_type}.pdf'),
                               fixed_width=fixed_width)
            finally:
                # One figure is opened per file; close it so long runs do not
                # accumulate open figures and their memory.
                plt.close(fig)
            

def get_plots_qa(attr_info, input_tokens, fig_info, removed_indices: List[int],
                 normalize: bool = True, save_path=None, fixed_width: bool = False):
    """
    Draw start- and end-position attributions of every QA model in one figure.

    Each key of ``attr_info`` is split at the first ``"_attr_info"``. The text
    before it is the model name. The text after it decides whether the entry is
    a start entry (contains ``"start"``) or an end entry (contains ``"end"``).
    For each start model, in ``attr_info`` order, ``l2`` rows labelled
    ``"<model>-Start=<j>"`` are drawn. These are followed by ``l2`` rows labelled
    ``"<model>-End=<j>"`` taken from the end entry of the same model.

    Args:
        attr_info (dict): Maps start/end keys to inner dicts. Each inner dict must
            hold ``"model_attr_embeddings_logit_index_<j>"`` for ``j = 0 .. l2-1``,
            where ``l2`` is the number of such entries in the first start entry.
            Values must support ``.squeeze(0)`` (presumably tensors/arrays with a
            leading axis of size 1 - confirm).
        input_tokens (list): Forwarded to ``plot_attr_vals``. When ``fixed_width``
            is False the figure width is ``2.5 * len(input_tokens)``.
        fig_info (dict): Figure options. Reads ``width`` (default 20, only used when
            ``fixed_width``), ``height`` (default 3), ``font_size`` (default 14;
            y-labels use ``font_size - 3``) and ``dpi`` (default 300). The whole dict
            is also forwarded to ``plot_attr_vals``.
        removed_indices (List[int]): Forwarded to ``plot_attr_vals``.
        normalize (bool, optional): Forwarded to ``plot_attr_vals``. Defaults to True.
        save_path (str, optional): If given, the figure is saved there as a PDF.
            Defaults to None (not saved).
        fixed_width (bool, optional): Use ``fig_info['width']`` instead of a
            token-count based width. Defaults to False.

    Returns:
        tuple: ``(fig, axes)`` where ``axes`` is a 1-D array with one Axes per row.

    Raises:
        AssertionError: If start and end entries do not cover the same models, or
            differ in count or in number of logit entries.
        ValueError: If there are no start/end entries or no logit entries.
        KeyError: If an entry lacks one of the expected logit entries.
    """
    def _role(key):
        # Only the text after "_attr_info" identifies start vs. end. The model name
        # before it may itself contain "start" or "end" (e.g. "blenderbot").
        return key.split("_attr_info", 1)[-1]

    attr_info_keys = list(attr_info.keys())
    attr_info_keys_start = [key for key in attr_info_keys if 'start' in _role(key)]
    attr_info_keys_end = [key for key in attr_info_keys if 'end' in _role(key)]

    # The model name is the part of the key before "_attr_info".
    attr_models_start = [key.split("_attr_info")[0] for key in attr_info_keys_start]
    attr_models_end = [key.split("_attr_info")[0] for key in attr_info_keys_end]

    assert set(attr_models_start) == set(attr_models_end), "Start and end models should be the same."

    inner_keys_start = [sorted(attr_info[key].keys()) for key in attr_info_keys_start]
    print(f'inner_keys_start: {inner_keys_start}')

    inner_keys_end = [sorted(attr_info[key].keys()) for key in attr_info_keys_end]
    print(f'inner_keys_end: {inner_keys_end}')

    filtered_key = 'model_attr_embeddings_logit_index'
    inner_filtered_keys_start = [[key for key in inner_key if filtered_key in key] for inner_key in inner_keys_start]
    print(f'inner_filtered_keys_start: {inner_filtered_keys_start}')

    inner_filtered_keys_end = [[key for key in inner_key if filtered_key in key] for inner_key in inner_keys_end]
    print(f'inner_filtered_keys_end: {inner_filtered_keys_end}')

    assert len(inner_filtered_keys_start) == len(inner_filtered_keys_end), "Start and end keys should be the same."
    if not inner_filtered_keys_start:
        raise ValueError("attr_info has no start/end entries; there is nothing to plot.")
    assert len(inner_filtered_keys_start[0]) == len(inner_filtered_keys_end[0]), "Start and end keys should be the same."

    l1 = len(attr_info_keys_start)
    l2 = len(inner_filtered_keys_start[0])

    print(f'l1: {l1}')
    print(f'l2: {l2}')

    # Each model contributes l2 start rows followed by l2 end rows.
    num_plots = l1 * (2 * l2)
    print(f'number of plots: {num_plots}')
    if num_plots == 0:
        raise ValueError(
            f"No '{filtered_key}_*' entries found for {attr_info_keys_start[0]!r}; there is nothing to plot."
        )

    # Pair each start entry with the end entry of the same model, rather than
    # relying on both kinds of keys appearing in the same order in attr_info.
    end_key_by_model = dict(zip(attr_models_end, attr_info_keys_end))

    if fixed_width:
        width = fig_info.get('width', 20)
    else:
        width = len(input_tokens) * 2.5

    height = fig_info.get('height', 3)
    font_size = fig_info.get('font_size', 14)
    y_font_size = font_size - 3
    dpi = fig_info.get("dpi", 300)

    # Use Arial as the default font for the figure text.
    plt.rcParams["font.family"] = "Arial"

    fig, axes = plt.subplots(num_plots, 1)

    for i, (start_key, model_name) in enumerate(zip(attr_info_keys_start, attr_models_start)):
        end_key = end_key_by_model[model_name]
        row_offset = i * (2 * l2)

        for j in range(l2):
            attr_info_logit = f'model_attr_embeddings_logit_index_{j}'
            attr_vals = attr_info[start_key][attr_info_logit].squeeze(0)

            axis = axes[row_offset + j]
            plot_attr_vals(attr_vals, input_tokens, fig_info, removed_indices, normalize, axis=axis)
            axis.set_ylabel(f'{model_name}-Start={j}', fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

        for j in range(l2):
            attr_info_logit = f'model_attr_embeddings_logit_index_{j}'
            attr_vals = attr_info[end_key][attr_info_logit].squeeze(0)

            axis = axes[row_offset + l2 + j]
            plot_attr_vals(attr_vals, input_tokens, fig_info, removed_indices, normalize, axis=axis)
            axis.set_ylabel(f'{model_name}-End={j}', fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

    # Let matplotlib fit the subplots, then enforce a fixed vertical gap between rows.
    # The same spacing is used for every plot type, so 'plot_type' is not consulted.
    fig.tight_layout()
    fig.subplots_adjust(wspace=0, hspace=0.5)

    # Width/height are divided by 2.54, i.e. they are treated as centimetres.
    fig.set_size_inches(width / 2.54, 1.25 * height * num_plots / 2.54)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')
    return fig, axes


def save_single_plots_qa(attr_info, input_tokens, fig_info, removed_indices: List[int],
                         normalize: bool = True, save_dir=None, fixed_width: bool = False):
    """
    Save every start- and end-position QA attribution plot as its own PDF file.

    Each key of ``attr_info`` is split at the first ``"_attr_info"``. The text
    before it is the model name. The text after it decides whether the entry is
    a start entry (contains ``"start"``) or an end entry (contains ``"end"``).
    For each model and logit index ``j``, a fresh single-Axes figure is drawn by
    ``plot_attr_vals``, which receives the output path, and is then closed.
    Files are written to ``<save_dir>/<model>-Start=<j>_<plot_type>.pdf`` and
    ``<save_dir>/<model>-End=<j>_<plot_type>.pdf``, where ``plot_type`` is
    ``fig_info.get('plot_type', 'heatmap')``.

    Args:
        attr_info (dict): Maps start/end keys to inner dicts. Each inner dict must
            hold ``"model_attr_embeddings_logit_index_<j>"`` for ``j = 0 .. l2-1``,
            where ``l2`` is the number of such entries in the first start entry.
            Values must support ``.squeeze(0)`` (presumably tensors/arrays with a
            leading axis of size 1 - confirm).
        input_tokens (list): Forwarded to ``plot_attr_vals``.
        fig_info (dict): Forwarded to ``plot_attr_vals``; its ``plot_type`` is also
            used in the file names.
        removed_indices (List[int]): Forwarded to ``plot_attr_vals``.
        normalize (bool, optional): Forwarded to ``plot_attr_vals``. Defaults to True.
        save_dir (str): Output directory; created if it does not exist. Required
            despite the ``None`` default.
        fixed_width (bool, optional): Forwarded to ``plot_attr_vals``. Defaults to False.

    Returns:
        None

    Raises:
        ValueError: If ``save_dir`` is None or there are no start/end entries.
        AssertionError: If start and end entries do not cover the same models, or
            differ in count or in number of logit entries.
        KeyError: If an entry lacks one of the expected logit entries.
    """
    if save_dir is None:
        raise ValueError("save_dir must be provided to save the individual plots.")

    def _role(key):
        # Only the text after "_attr_info" identifies start vs. end. The model name
        # before it may itself contain "start" or "end" (e.g. "blenderbot").
        return key.split("_attr_info", 1)[-1]

    attr_info_keys = list(attr_info.keys())
    attr_info_keys_start = [key for key in attr_info_keys if 'start' in _role(key)]
    attr_info_keys_end = [key for key in attr_info_keys if 'end' in _role(key)]

    # The model name is the part of the key before "_attr_info".
    attr_models_start = [key.split("_attr_info")[0] for key in attr_info_keys_start]
    attr_models_end = [key.split("_attr_info")[0] for key in attr_info_keys_end]

    assert set(attr_models_start) == set(attr_models_end), "Start and end models should be the same."

    inner_keys_start = [sorted(attr_info[key].keys()) for key in attr_info_keys_start]
    print(f'inner_keys_start: {inner_keys_start}')

    inner_keys_end = [sorted(attr_info[key].keys()) for key in attr_info_keys_end]
    print(f'inner_keys_end: {inner_keys_end}')

    filtered_key = 'model_attr_embeddings_logit_index'
    inner_filtered_keys_start = [[key for key in inner_key if filtered_key in key] for inner_key in inner_keys_start]
    print(f'inner_foltered_keys_start: {inner_filtered_keys_start}')

    inner_filtered_keys_end = [[key for key in inner_key if filtered_key in key] for inner_key in inner_keys_end]
    print(f'inner_foltered_keys_end: {inner_filtered_keys_end}')

    assert len(inner_filtered_keys_start) == len(inner_filtered_keys_end), "Start and end keys should be the same."
    if not inner_filtered_keys_start:
        raise ValueError("attr_info has no start/end entries; there is nothing to plot.")
    assert len(inner_filtered_keys_start[0]) == len(inner_filtered_keys_end[0]), "Start and end keys should be the same."

    l1 = len(attr_info_keys_start)
    l2 = len(inner_filtered_keys_start[0])

    print(f'l1: {l1}')
    print(f'l2: {l2}')

    num_plots = l1 * (2 * l2)
    print(f'number of plots: {num_plots}')

    # Pair each start entry with the end entry of the same model, rather than
    # relying on both kinds of keys appearing in the same order in attr_info.
    end_key_by_model = dict(zip(attr_models_end, attr_info_keys_end))

    # Used in file names so that different plot types do not overwrite each other.
    plot_type = fig_info.get('plot_type', 'heatmap')

    # Use Arial as the default font for the figure text.
    plt.rcParams["font.family"] = "Arial"

    os.makedirs(save_dir, exist_ok=True)

    def _save_one(attr_vals, file_name):
        # Draw one attribution row into its own figure, save it, and always close
        # the figure so long runs do not accumulate open figures.
        fig, axis = plt.subplots(1, 1)
        try:
            plot_attr_vals(attr_vals=attr_vals,
                           input_tokens=input_tokens,
                           fig_info=fig_info,
                           removed_indices=removed_indices,
                           normalize=normalize,
                           axis=axis,
                           save_path=os.path.join(save_dir, file_name),
                           fixed_width=fixed_width)
        finally:
            plt.close(fig)

    for start_key, model_name in zip(attr_info_keys_start, attr_models_start):
        end_key = end_key_by_model[model_name]

        for j in range(l2):
            attr_info_logit = f'model_attr_embeddings_logit_index_{j}'
            attr_vals = attr_info[start_key][attr_info_logit].squeeze(0)
            _save_one(attr_vals, f'{model_name}-Start={j}_{plot_type}.pdf')

        for j in range(l2):
            attr_info_logit = f'model_attr_embeddings_logit_index_{j}'
            attr_vals = attr_info[end_key][attr_info_logit].squeeze(0)
            _save_one(attr_vals, f'{model_name}-End={j}_{plot_type}.pdf')