import argparse
import time
from pathlib import Path
from typing import List

import torch
from torch import Tensor

from attention_modules import attention_configdict, attention_moduledict
from configs import dataset_paths, column_labels
from util_funcs import plot_eigenvalues, plot_heads


def visualize_clusters(input_data: Tensor,
                       attention_module: torch.nn.ModuleList,
                       mask: Tensor = None,
                       eigen_plots: bool = True,
                       top_n_eigen_vals: int = 20,
                       elevation: int = 15,
                       azimuthal: int = -135,
                       axes_labels: List = None,
                       axes_ticks: bool = True,
                       save_pt_path: str = None,
                       save_png_path: str = None,
                       plot_attention_lines_interval: int = 20,
                       device: str = None,
                       ) -> None:
    """
    Produces a visualization of attention clusters at each iteration,
    with optional eigenvalue plots.

    Each layer in ``attention_module`` is applied in turn. Its output is
    standardized along dim 1 (mean 0 / std 1), added back to the layer input
    (residual update), and the result is plotted and fed to the next layer.
    The whole pass runs under ``torch.no_grad()``: nothing here needs
    gradients, and tracking them would grow the autograd graph with every layer.

    :param input_data: Tensor of shape (batch, length, dim). dim must be
        divisible by 3 (3d plot for each head)
    :param attention_module: Attention layers packed in a ModuleList. Each layer
        must accept arguments (hidden_states, attention_mask) and return tuple of
        (output, attention_matrix, ...) as in HuggingFace implementations.
    :param mask: Optional attention mask
    :param eigen_plots: Plots eigenvalues if True (must be exactly ``True``)
    :param top_n_eigen_vals: top n eigenvalues to plot
    :param elevation: Elevation angle for 3d plots
    :param azimuthal: Azimuthal angle for 3d plots
    :param axes_labels: Axes labels for data with known features
    :param axes_ticks: Shows axes ticks if True
    :param save_pt_path: Saves final outputs tensors to the specified path
    :param save_png_path: Saves png plots to this path
    :param plot_attention_lines_interval: plots attention lines every interval specified.
           ``None`` or ``0`` disables attention lines. (note: this slows down computation)
    :param device: 'cpu' or 'cuda'. If None, uses available device
    :return: None. Results are shown as plots and optionally written to disk.
    :raises AssertionError: if ``device`` is given and is not 'cpu' or 'cuda'
    """

    iter_start = 1
    attention_matrix = None
    eigen_vals, eigen_vecs = [], []

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        assert device in ['cpu', 'cuda'], 'device must be "cpu" or "cuda"'

    if save_png_path is not None:
        Path(save_png_path).mkdir(parents=True, exist_ok=True)

    print(f'running on {device}')

    # Inference only: disabling autograd keeps memory flat across layers
    # instead of retaining every intermediate for a backward pass that never runs.
    with torch.no_grad():
        # Device transfers are loop-invariant, so do them once up front.
        input_data = input_data.to(device)
        if mask is not None:
            mask = mask.to(device)

        for layer in attention_module:
            start = time.time()

            if eigen_plots is True:
                # (dim, dim) product of the current state with itself over the
                # length axis; the mean is not subtracted first.
                cov_matrix = torch.matmul(input_data.transpose(-1, -2), input_data)
                eig_val, eig_vec = torch.linalg.eig(cov_matrix)  # (b, dim)
                eigen_vals.append(eig_val.detach().cpu())
                eigen_vecs.append(eig_vec.detach().cpu())

            layer = layer.to(device)
            outputs = layer(hidden_states=input_data, attention_mask=mask)
            hidden_states = outputs[0]
            # Layers that do not return attention weights leave the previous
            # value (initially None) in place.
            if len(outputs) > 1:
                attention_matrix = outputs[1]

            # Standardize the layer output along dim 1, then apply the residual update.
            mean = hidden_states.mean(dim=1, keepdim=True)
            std = hidden_states.std(dim=1, keepdim=True)
            hidden_states = (hidden_states - mean) / std
            input_data = hidden_states + input_data

            print(f'iter: {iter_start} | processing time: {time.time() - start:.4f}seconds')

            # None or 0 disables attention lines (0 would otherwise divide by zero).
            draw_attention_lines = (bool(plot_attention_lines_interval)
                                    and iter_start % plot_attention_lines_interval == 0)

            plot_heads(hidden_states=input_data,
                       title=f't = {iter_start}',
                       attention_weights=attention_matrix,
                       mask=mask,
                       save_dir=f'{save_png_path}/cluster_fig_{iter_start}.png' if save_png_path else None,
                       elevation=elevation,
                       azimuthal=azimuthal,
                       axes_labels=axes_labels,
                       plot_attention_lines=draw_attention_lines,
                       ticks=axes_ticks,
                       show=True)

            iter_start += 1

    if save_pt_path is not None:
        Path(save_pt_path).mkdir(parents=True, exist_ok=True)
        # NOTE: iter_start is now one past the last completed iteration, so the
        # file suffix equals len(attention_module) + 1. Kept as-is so existing
        # consumers of these filenames keep working.
        torch.save(input_data.detach().cpu(), f'{save_pt_path}/clusters_iter_{iter_start}.pt')
        if attention_matrix is not None:
            torch.save(attention_matrix.detach().cpu(), f'{save_pt_path}/attention_mat_{iter_start}.pt')
        if len(eigen_vecs) > 0:
            # Saved as Python lists with one tensor per layer.
            torch.save(eigen_vals, f'{save_pt_path}/eigen_vals.pt')
            torch.save(eigen_vecs, f'{save_pt_path}/eigen_vecs.pt')

    if eigen_plots is True:
        plot_eigenvalues(
            eigenvalues_data=eigen_vals,
            n=top_n_eigen_vals,
            save_path=save_png_path
        )


if __name__ == "__main__":
    # Command-line entry point: load a data tensor, build a stack of attention
    # layers configured to match its shape, and run visualize_clusters on it.

    parser = argparse.ArgumentParser(description='Visualize attention clusters with various parameters')

    # Dataset arguments
    parser.add_argument('--dataset', type=str, default='image',
                        choices=['diabetes', 'heart', 'image', 'generated'],
                        help='Dataset to use: diabetes, heart, image, or generated')
    parser.add_argument('--data_path', type=str, default=None,
                        help='Custom path to data.pt file (overrides default dataset paths)')
    parser.add_argument('--attn_mask', type=str, default=None,
                        help='path to mask.pt file')
    parser.add_argument('--data_columns', type=str, default=None,
                        help='Column names of custom data as comma-separated values (e.g., "feat1,feat2,feat3")')

    # Model configuration
    parser.add_argument('--num_layers', type=int, default=20,
                        help='Number of attention layers')
    # NOTE: --hidden_size is accepted for compatibility but is not read below;
    # the hidden size is always taken from the last dimension of the data.
    parser.add_argument('--hidden_size', type=int, default=3,
                        help='Hidden size dimension')
    parser.add_argument('--attn_type', type=str, default='global',
                        help='attention type, must be in attention_modules.py dict')

    # Visualization parameters
    parser.add_argument('--elevation', type=int, default=15,
                        help='Elevation angle for 3D plots')
    parser.add_argument('--azimuthal', type=int, default=-135,
                        help='Azimuthal angle for 3D plots')
    parser.add_argument('--no_eigen_plots', action='store_false', dest='eigen_plots',
                        help='Disable eigenvalue plots')
    parser.add_argument('--top_n_eigen_vals', type=int, default=20,
                        help='Number of top eigenvalues to plot')
    parser.add_argument('--plot_attn', type=int, default=None,
                        help='Interval for plotting attention lines')
    parser.add_argument('--no_axes_ticks', action='store_false', dest='axes_ticks',
                        help='Hide axes ticks')

    # Output options
    parser.add_argument('--save_pt_path', type=str, default='outputs',
                        help='Path to save output tensors')
    parser.add_argument('--save_png_path', type=str, default='figures',
                        help='Path to save PNG plots')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default=None,
                        help='Device to use (cpu/cuda), auto-detects if not specified')

    args = parser.parse_args()

    # Validate the attention type before any file I/O so bad input fails fast.
    # parser.error is used instead of assert because asserts are stripped under -O.
    if args.attn_type not in attention_configdict:
        parser.error(f'--attn_type "{args.attn_type}" not defined in attention_modules.py config dict')
    if args.attn_type not in attention_moduledict:
        parser.error(f'--attn_type "{args.attn_type}" not defined in attention_modules.py modules dict')

    data_path = args.data_path if args.data_path is not None else dataset_paths[args.dataset]
    # SECURITY NOTE: torch.load unpickles the file; only load data from trusted sources.
    # map_location='cpu' lets tensors saved on a GPU load on CPU-only machines;
    # visualize_clusters moves the data to the selected device itself.
    input_tensor = torch.load(data_path, map_location='cpu')
    if not torch.is_tensor(input_tensor) or input_tensor.dim() != 3:
        parser.error(f'data at "{data_path}" must be a 3-D tensor of shape (batch, length, dim)')

    attention_mask = torch.load(args.attn_mask, map_location='cpu') if args.attn_mask is not None else None

    if args.data_columns:
        cols = [col.strip() for col in args.data_columns.split(',')]
    else:
        cols = column_labels.get(args.dataset, None)

    hidden_dim = input_tensor.shape[-1]
    if hidden_dim < 3:
        # hidden_dim // 3 would yield zero attention heads.
        parser.error(f'last dimension of the data must be at least 3 (got {hidden_dim})')

    model_config = attention_configdict[args.attn_type]
    attention = attention_moduledict[args.attn_type]

    # Adapt the (shared) config object to the loaded data.
    model_config.max_position_embeddings = input_tensor.shape[1]
    model_config.hidden_size = hidden_dim
    model_config.num_attention_heads = hidden_dim // 3  # each head should have 3 axes for 3d plots
    model_config.output_attentions = True
    model_config.attention_probs_dropout_prob = 0.0

    attention_layers = torch.nn.ModuleList(
        [attention(model_config) for _ in range(args.num_layers)]
    )

    visualize_clusters(
        input_data=input_tensor,
        mask=attention_mask,
        attention_module=attention_layers,
        eigen_plots=args.eigen_plots,
        top_n_eigen_vals=args.top_n_eigen_vals,
        elevation=args.elevation,
        azimuthal=args.azimuthal,
        axes_labels=cols,
        axes_ticks=args.axes_ticks,
        save_pt_path=args.save_pt_path,
        save_png_path=args.save_png_path,
        plot_attention_lines_interval=args.plot_attn,
        device=args.device
    )
