from scale_stats import calculate_matrix_stats
import torch
import argparse
import os

def split_stats(path_to_raw_routing_data, test_size=0.2):
    """Split saved routing records into train/test parts and compute stats.

    The records are loaded with ``torch.load`` and cut at roughly
    ``len(records) * (1 - test_size)``.  The cut is then moved backwards
    until it lands on a record whose ``"layer_idx"`` equals the smallest
    ``"layer_idx"`` in the data (presumably so that the test partition
    starts at the first layer of a forward pass -- confirm against the
    code that produces the records).

    Args:
        path_to_raw_routing_data: Location passed straight to ``torch.load``.
            The loaded object must support ``len`` and slicing, and each
            element must be indexable by ``"layer_idx"``.
        test_size: Fraction of records targeted for the test partition;
            must lie in ``[0, 1]``.  Because of the backward alignment the
            test partition can end up larger than this fraction.

    Returns:
        A tuple ``(train_routing_data, train_stats, test_routing_data,
        test_stats)`` where the stats are whatever
        ``calculate_matrix_stats`` returns for the respective partition.

    Raises:
        ValueError: If ``test_size`` is outside ``[0, 1]`` or the loaded
            data is empty.
    """
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be within [0, 1], got {test_size}")

    # NOTE: torch.load deserializes via pickle machinery; only point this at
    # files from a trusted source.
    raw_routing_data = torch.load(path_to_raw_routing_data)

    num_records = len(raw_routing_data)
    if num_records == 0:
        raise ValueError("raw routing data is empty; nothing to split")

    min_layer_idx = min(routing_data["layer_idx"] for routing_data in raw_routing_data)

    split_idx = int(num_records * (1 - test_size))
    # Walk back to the nearest record carrying the minimum layer index.
    # The bounds matter:
    #   * split_idx == num_records (test_size == 0) is already a valid cut
    #     with an empty test partition, and indexing it would raise
    #     IndexError;
    #   * stopping at 0 prevents the index from going negative, which would
    #     silently wrap around to the end of the sequence.
    # At either extreme one partition is empty and is still handed to
    # calculate_matrix_stats as-is.
    while (
        0 < split_idx < num_records
        and raw_routing_data[split_idx]["layer_idx"] != min_layer_idx
    ):
        split_idx -= 1

    train_routing_data = raw_routing_data[:split_idx]
    test_routing_data = raw_routing_data[split_idx:]

    train_stats = calculate_matrix_stats(train_routing_data)
    test_stats = calculate_matrix_stats(test_routing_data)

    return train_routing_data, train_stats, test_routing_data, test_stats
    
if __name__ == "__main__":
    # Command-line entry point: split the raw routing data, report the
    # partition sizes, and save both partitions together with their stats.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--path_to_raw_routing_data",
        type=str,
        default="16b_200/deepseek_ai_deepseek_moe_16b_chat_raw_routing_data.pt",
    )
    cli.add_argument("--output_dir_prefix", type=str, default="splited_stats")
    cli.add_argument("--test_size", type=float, default=0.2)
    args = cli.parse_args()

    train_routing_data, train_stats, test_routing_data, test_stats = split_stats(
        args.path_to_raw_routing_data,
        args.test_size,
    )
    print(f"Train routing data size: {len(train_routing_data)}")
    print(f"Test routing data size: {len(test_routing_data)}")

    # The output directory name encodes the requested test fraction.
    save_dir = f"{args.output_dir_prefix}_{args.test_size}"
    os.makedirs(save_dir, exist_ok=True)

    # File stem -> object to serialize; written in insertion order.
    artifacts = {
        "train_routing_data": train_routing_data,
        "test_routing_data": test_routing_data,
        "train_stats": train_stats,
        "test_stats": test_stats,
    }
    for stem, payload in artifacts.items():
        torch.save(payload, f"{save_dir}/{stem}.pt")
    