import argparse
import os
from transformers import AutoModelForCausalLM
import torch 
from tqdm import tqdm 
import pdb 
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import pdb 

def main():
    """Plot the singular-value spectra of selected projection layers of a causal LM.

    Command-line arguments:
        --model_name: Hugging Face model identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
        --save_dir: Directory that receives ``sv_distribution.pdf``; it is
            created if it does not exist.

    Raises:
        RuntimeError: If no ``torch.nn.Linear`` module matches the selected
            layer indices and layer types.
    """
    parser = argparse.ArgumentParser(description="Script to process model and save directory.")

    parser.add_argument("--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                        help="Name of the model (default: TinyLlama/TinyLlama-1.1B-Chat-v1.0)")

    parser.add_argument("--save_dir", type=str, default="saved_plots",
                        help="Directory to save plots (default: saved_plots)")

    args = parser.parse_args()

    # Create the save directory if it doesn't exist
    os.makedirs(args.save_dir, exist_ok=True)

    model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16)
    print('\nMode Loaded..')

    keep_layers = _select_keep_layers(len(model.model.layers))
    # Other projections (e.g. 'up_proj', 'v_proj') can be added here; the plot
    # grid grows by one row per entry.
    layer_types = ['k_proj', 'down_proj']

    spectra = _collect_singular_values(model, keep_layers, layer_types)
    if not spectra:
        raise RuntimeError(
            f"No Linear layers matched layer indices {keep_layers} "
            f"and layer types {layer_types} in model '{args.model_name}'."
        )

    # Honour --save_dir instead of a hard-coded output directory.
    out_path = os.path.join(args.save_dir, 'sv_distribution.pdf')
    _plot_singular_values(spectra, keep_layers, layer_types, out_path)


def _select_keep_layers(num_layers: int, fractions: tuple = (0.2, 0.4, 0.6, 0.8)) -> list:
    """Return layer indices (as strings) at the given depth fractions, without duplicates."""
    # dict.fromkeys keeps first-seen order while dropping the repeated indices
    # that very shallow models would otherwise produce.
    return list(dict.fromkeys(str(int(num_layers * frac)) for frac in fractions))


def _collect_singular_values(model: torch.nn.Module, keep_layers: list, layer_types: list) -> dict:
    """Compute the singular values of every selected Linear layer.

    A module is selected when its dotted name has more than three components,
    the third component is in ``keep_layers``, the last component is in
    ``layer_types`` and the module is a ``torch.nn.Linear``.

    Returns:
        Dict mapping ``(layer_name, layer_idx)`` to a 1-D numpy array of
        singular values, where ``layer_idx`` is a string.
    """
    targets = []
    for name, module in model.named_modules():
        if 'lm_head' in name:
            print('Ignored low-rank decomposition on logits layer')
            continue
        parts = name.split('.')
        if (
            len(parts) > 3
            and parts[2] in keep_layers
            and parts[-1] in layer_types
            and isinstance(module, torch.nn.Linear)
        ):
            targets.append((parts[-1], parts[2], module))

    spectra = {}
    # Only the values are needed, so skip building an autograd graph.
    with torch.no_grad():
        # Iterating over the pre-selected layers keeps the progress bar total
        # accurate. NOTE: the label text is historical; no layer is replaced.
        for layer_name, layer_idx, module in tqdm(targets, desc='Replacing Linear with Low-Rank Layers'):
            rank = min(module.in_features, module.out_features)
            # Weights are loaded as float16; upcast before the decomposition.
            _, sigma, _ = torch.svd_lowrank(module.weight.float(), q=rank, niter=2)
            # Keep the first spectrum seen for a given (name, index) pair.
            spectra.setdefault((layer_name, layer_idx), sigma.cpu().numpy())
    return spectra


def _plot_singular_values(spectra: dict, keep_layers: list, layer_types: list, out_path: str) -> None:
    """Draw one subplot per (layer type, layer index) and save the grid as a PDF."""
    sns.set_style("whitegrid")

    # Custom colormap from bright red to dark red
    cmap = LinearSegmentedColormap.from_list('bright_to_dark_red', ['#FF0000', '#8B0000'], N=100)

    n_rows, n_cols = len(layer_types), len(keep_layers)
    # 4x4 inches per panel reproduces the 16x8 figure of the default 2x4 grid;
    # squeeze=False keeps `axes` two-dimensional for any grid size.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    # Global maximum so that colours are comparable across subplots
    max_singular_value = max(float(values.max()) for values in spectra.values())

    # Row titles are anchored on the second column when there is one.
    title_col = min(1, n_cols - 1)

    for row, layer_name in enumerate(layer_types):
        layer_title = f'Layer Name: {layer_name}'
        axes[row, title_col].annotate(layer_title, xy=(0, 1.1), xycoords='axes fraction',
                                      fontsize=16, fontweight='bold')

        for col, layer_idx in enumerate(keep_layers):
            ax = axes[row, col]
            singular_values = spectra.get((layer_name, layer_idx))

            if singular_values is not None:
                positions = range(len(singular_values))

                # Points coloured by magnitude, joined by a faint line
                ax.scatter(positions, singular_values, c=singular_values, cmap=cmap,
                           vmin=0, vmax=max_singular_value, s=30)
                ax.plot(positions, singular_values, color='gray', alpha=0.5, linewidth=1)

                ax.set_title(f'Layer Number: {layer_idx}', fontsize=16)
                ax.tick_params(axis='both', which='major', labelsize=16)
                ax.set_xlabel('Index', fontsize=16)
                ax.grid(True, linestyle='--', alpha=0.7)

            if col == 0:
                ax.set_ylabel('Singular Values', fontsize=16)

    fig.tight_layout()
    fig.savefig(out_path, format='pdf', bbox_inches='tight')
    plt.close(fig)  # Close the figure after saving to free up memory

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()