import os
import json
import argparse
import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoConfig
from typing import List, Dict, Any

def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options for the random-vector generator.

    Returns:
        Namespace with ``paths`` (list of str), ``prefix`` (str), ``seed`` (int),
        ``n_rand`` (int) and ``scales`` (list of float).
    """
    parser = argparse.ArgumentParser(description='Generate random vectors based on model config')
    parser.add_argument('--paths', nargs='+', required=True, help='List of paths to process')
    parser.add_argument('--prefix', type=str, required=True, help='Prefix for metadata JSON files')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--n_rand', type=int, default=10, help='Number of random vectors to generate')
    parser.add_argument(
        '--scales',
        nargs='+',
        type=float,
        # With nargs='+' a user-supplied value is always a list, but argparse
        # passes a non-string default through unchanged. The default must
        # therefore be a list too, otherwise code iterating over ``scales``
        # fails whenever the flag is omitted.
        default=[1.0],
        help="List scale multipliers"
    )
    return parser.parse_args()

def read_metadata(path: str, prefix: str) -> Dict[str, Any]:
    """
    Load the JSON metadata stored at ``<path>/<prefix>_metadata.json``.

    Args:
        path: Directory containing the metadata file
        prefix: Leading part of the metadata file name

    Returns:
        The decoded JSON content of the file
    """
    file_name = f"{prefix}_metadata.json"
    with open(os.path.join(path, file_name), 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def get_model_dim(model_name: str) -> int:
    """
    Look up the hidden size of a model from its configuration.

    Args:
        model_name: Identifier passed to ``AutoConfig.from_pretrained``

    Returns:
        The ``hidden_size`` attribute of the loaded config
    """
    return AutoConfig.from_pretrained(model_name).hidden_size

def generate_random_vectors(d_model: int, n_rand: int, seed: int = 42) -> torch.Tensor:
    """
    Draw a seeded Gaussian matrix and rescale each column to unit L2 norm.

    Note that this reseeds the global NumPy and PyTorch random generators.

    Args:
        d_model: Number of rows (the model dimension)
        n_rand: Number of columns (one per random vector)
        seed: Seed applied to both NumPy and PyTorch before sampling

    Returns:
        Tensor of shape (d_model, n_rand) whose columns each have unit L2 norm
    """
    # Fix both RNGs so repeated calls with the same seed give the same matrix
    torch.manual_seed(seed)
    np.random.seed(seed)

    # One standard-normal sample per entry; columns are the individual vectors
    gaussian = torch.randn((d_model, n_rand))

    # dim=0 reduces over rows, i.e. each column is divided by its own norm
    return F.normalize(gaussian, p=2, dim=0)

def main():
    """
    Generate and save scaled random vectors for every requested path.

    For each path in ``--paths`` this reads ``<prefix>_metadata.json``, looks up
    the hidden size of the model named there, draws ``--n_rand`` unit-norm
    random columns, and concatenates one scaled copy per ``--scales``
    multiplier (multiplier * ``base_input_scale``) along the column axis. The
    result is saved as ``rand_V.pt`` in that path, together with
    ``rand_vector_info.json`` holding one ``scale_<multiplier>_idx_<idx>``
    label per column.
    """
    args = parse_args()

    # Multipliers come from the command line; the metadata's
    # 'scalar_multipliers' entry is not used.
    scalar_multipliers = args.scales
    # argparse does not wrap a scalar default in a list, so a bare number can
    # arrive here; normalize it so the loops below can always iterate.
    if isinstance(scalar_multipliers, (int, float)):
        scalar_multipliers = [scalar_multipliers]

    for path in args.paths:
        print(f"Processing {path}...")
        metadata = read_metadata(path, args.prefix)

        # Extract necessary data from metadata
        model_name = metadata['model_name']
        base_input_scale = metadata['base_input_scale']

        # Get model dimension
        d_model = get_model_dim(model_name)
        print(f"Using model {model_name} with dimension {d_model}")

        # Generate random vectors with normalized columns
        V = generate_random_vectors(d_model, args.n_rand, args.seed)
        print(f"Generated random vectors with shape {V.shape}")

        # One scaled copy of V per multiplier, in multiplier order
        all_scaled_vectors = [
            V * (multiplier * base_input_scale)
            for multiplier in scalar_multipliers
        ]

        # Labels follow the same order as the concatenated columns:
        # all indices for the first multiplier, then the second, and so on.
        vector_info = [
            f"scale_{multiplier}_idx_{idx}"
            for multiplier in scalar_multipliers
            for idx in range(args.n_rand)
        ]

        # Concatenate all scaled vectors (along dimension 1, which is columns)
        final_V = torch.cat(all_scaled_vectors, dim=1)
        print(f"Final tensor shape: {final_V.shape}")

        # Save output
        output_path = os.path.join(path, "rand_V.pt")
        torch.save(final_V, output_path)
        print(f"Saved scaled vectors to {output_path}")

        info_path = os.path.join(path, "rand_vector_info.json")
        with open(info_path, 'w') as f:
            json.dump(vector_info, f, indent=2)
        print(f"Saved vector info to {info_path}")

# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()