import os
import torch
import numpy as np
from tqdm import tqdm
import copy
import random
import argparse

def cut_n_layers_from_raw_routing_data(raw_routing_data, num_layers):
    """Keep only the samples that belong to the ``num_layers`` lowest layer ids.

    Args:
        raw_routing_data: Iterable of routing samples; each sample is read
            through its ``'layer_idx'`` key.
        num_layers: Number of distinct layers to keep.

    Returns:
        A new list with the retained samples (the same objects, not copies)
        in their original order.
    """
    # Sort before slicing: iteration order of a set is an implementation
    # detail, so slicing it directly would keep an arbitrary subset of layers.
    ordered_layer_ids = sorted({el['layer_idx'] for el in raw_routing_data})
    # A set gives O(1) membership checks while filtering.
    layers_to_stay = set(ordered_layer_ids[:num_layers])

    return [el for el in raw_routing_data if el['layer_idx'] in layers_to_stay]

def _clone_donor_layers(last_layer_activation, populate_mapping):
    """Build one synthetic sample per new layer by copying its donor's sample."""
    synthetic_samples = []
    for new_layer_idx, donor_layer_idx in populate_mapping.items():
        cloned = copy.deepcopy(last_layer_activation[donor_layer_idx])
        cloned['layer_idx'] = new_layer_idx
        synthetic_samples.append(cloned)
    return synthetic_samples


def add_n_layers_to_raw_routing_data(raw_routing_data, num_target_layers):
    """Pad every iteration of a routing trace with synthetic layers.

    New layer ids continue right after the highest existing id. Each new layer
    is assigned one randomly chosen existing ("donor") layer; that assignment
    is drawn once and reused for every iteration. Within an iteration, the
    synthetic sample is a deep copy of the last sample seen for the donor
    layer, with ``'layer_idx'`` overwritten.

    A new iteration is assumed to start whenever a sample's ``'layer_idx'`` is
    lower than that of the sample before it.

    Args:
        raw_routing_data: Sequence of routing samples (dict-like, with a
            ``'layer_idx'`` key), ordered iteration by iteration.
        num_target_layers: Desired number of distinct layers. If it does not
            exceed the current count, no layers are added.

    Returns:
        A new list of deep-copied samples with the synthetic samples appended
        at the end of each iteration. An empty input yields an empty list.

    Raises:
        KeyError: If an iteration contains no sample for a donor layer.
    """
    if not raw_routing_data:
        # Nothing to copy from, and max() below would fail on an empty set.
        return []

    # Sorted so that, for a given random seed, the donor assignment does not
    # depend on set iteration order.
    unique_layers_ids = sorted({el['layer_idx'] for el in raw_routing_data})
    num_new_layers = num_target_layers - len(unique_layers_ids)

    first_new_id = unique_layers_ids[-1] + 1
    added_layer_ids = range(first_new_id, first_new_id + num_new_layers)

    # new layer id -> existing layer id it copies from (insertion-ordered).
    populate_mapping = {
        new_id: random.choice(unique_layers_ids)
        for new_id in added_layer_ids
    }

    populated_raw_routing_data = []
    last_layer_activation = {}
    previous_layer_idx = -1

    for routing_sample in raw_routing_data:
        current_layer_idx = routing_sample['layer_idx']

        if previous_layer_idx > current_layer_idx:
            # Layer index went backwards: the previous iteration is complete,
            # so emit its synthetic layers before starting the next one.
            populated_raw_routing_data.extend(
                _clone_donor_layers(last_layer_activation, populate_mapping)
            )
            last_layer_activation = {}

        populated_raw_routing_data.append(copy.deepcopy(routing_sample))
        last_layer_activation[current_layer_idx] = routing_sample
        previous_layer_idx = current_layer_idx

    # The final iteration has no following sample to trigger the flush above.
    populated_raw_routing_data.extend(
        _clone_donor_layers(last_layer_activation, populate_mapping)
    )

    return populated_raw_routing_data

def populate_num_layers(raw_routing_data, num_target_layers):
    """Bring the routing trace to exactly ``num_target_layers`` distinct layers.

    Returns the input unchanged when the layer count already matches, trims
    layers when there are too many, and adds synthetic layers when there are
    too few.
    """
    layer_count = len({sample['layer_idx'] for sample in raw_routing_data})

    if layer_count > num_target_layers:
        return cut_n_layers_from_raw_routing_data(raw_routing_data, num_target_layers)
    if layer_count < num_target_layers:
        return add_n_layers_to_raw_routing_data(raw_routing_data, num_target_layers)
    return raw_routing_data
    
    

def scale_number_of_experts(raw_routing_data, populate_experts_to):
    """Remap expert ids so the trace looks like it has ``populate_experts_to`` experts.

    Each original expert ``i`` is given a band of target expert ids around
    ``i * populate_experts_to / current_experts``; every occurrence of ``i``
    in ``'topk_indices'`` is replaced by an id drawn uniformly at random (via
    ``np.random.choice``) from that band.

    Args:
        raw_routing_data: Sequence of routing samples. The expert count is
            read from the first sample's ``'num_experts'``; each sample's
            ``'topk_indices'`` is expected to be an integer tensor.
        populate_experts_to: Target number of experts.

    Returns:
        The input list itself when the expert count already matches;
        otherwise a new list of deep-copied samples with remapped
        ``'topk_indices'`` and ``'num_experts'`` set to the target. An empty
        input yields an empty list.
    """
    if not raw_routing_data:
        # No first sample to read the current expert count from.
        return []

    current_experts = raw_routing_data[0]['num_experts']
    if current_experts == populate_experts_to:
        return raw_routing_data

    half_width = populate_experts_to / current_experts / 2
    last_expert = populate_experts_to - 1

    # Centre of each original expert's band on the target expert axis.
    centers = (
        torch.arange(0, current_experts, dtype=torch.float32)
        * populate_experts_to / current_experts
    )
    lower = (centers - half_width).ceil().clamp(0, last_expert)
    upper = (centers + half_width).ceil().clamp(max=last_expert)

    # Build the candidate ids as plain integer arrays so the value handed back
    # to ``apply_`` is an int rather than a numpy float.
    sample_ranges = [
        np.arange(low, high + 1)
        for low, high in zip(lower.long().tolist(), upper.long().tolist())
    ]

    def _resample(expert_id):
        return int(np.random.choice(sample_ranges[expert_id]))

    widen_experts_raw_routing_data = []
    for sample in tqdm(raw_routing_data):
        sample_widen = copy.deepcopy(sample)
        # NOTE: Tensor.apply_ is documented as CPU-only and calls back into
        # Python once per element.
        sample_widen['topk_indices'].apply_(_resample)
        sample_widen["num_experts"] = populate_experts_to
        widen_experts_raw_routing_data.append(sample_widen)

    return widen_experts_raw_routing_data
    

def calculate_matrix_stats(raw_routing_data):
    """Count how often each expert id occurs in each layer.

    Args:
        raw_routing_data: Sequence of routing samples. The expert count is
            read from the first sample's ``'num_experts'``; every sample
            contributes all entries of its ``'topk_indices'`` tensor to the
            row of its ``'layer_idx'``.

    Returns:
        A float tensor of shape ``(num_distinct_layers, num_experts)``. Rows
        follow ascending layer id. An empty input yields a ``(0, 0)`` tensor.
    """
    if not raw_routing_data:
        return torch.zeros(0, 0)

    num_experts = raw_routing_data[0]['num_experts']

    # Map layer ids to rows explicitly: subtracting the minimum id would only
    # be correct when the ids are contiguous.
    layer_ids = sorted({el['layer_idx'] for el in raw_routing_data})
    row_of_layer = {layer_id: row for row, layer_id in enumerate(layer_ids)}

    stats_matrix = torch.zeros(len(layer_ids), num_experts)

    for sample in tqdm(raw_routing_data):
        expert_ids = sample['topk_indices'].flatten().long()
        # One histogram per sample instead of a Python loop per selected expert.
        counts = torch.bincount(expert_ids, minlength=num_experts)
        stats_matrix[row_of_layer[sample['layer_idx']]] += counts.to(stats_matrix)

    return stats_matrix

if __name__ == "__main__":
    # Load a raw routing trace, rescale it to the requested number of layers
    # and experts, and write the rescaled trace plus its layer/expert
    # selection counts to ``scaled_stats_l<layers>_e<experts>/``.
    parser = argparse.ArgumentParser()
    parser.add_argument("--raw_routing_data_path", type=str, default="16b_200/deepseek_ai_deepseek_moe_16b_chat_raw_routing_data.pt")
    # Accepted for command-line compatibility; this script does not read it.
    parser.add_argument("--activation_matrix_path", type=str, default="16b_200/deepseek_ai_deepseek_moe_16b_chat_layer_expert_matrix.pt")
    parser.add_argument("--target_layers", type=int, default=64)
    parser.add_argument("--target_experts", type=int, default=128)
    args = parser.parse_args()

    raw_routing_data_path = args.raw_routing_data_path
    target_layers = args.target_layers
    target_experts = args.target_experts

    prefix_output_dir = "scaled_stats"

    # NOTE(security): torch.load unpickles the file; only use trusted inputs.
    raw_routing_data = torch.load(raw_routing_data_path)

    # Layers are padded first so that the cloned samples of synthetic layers
    # get their expert ids re-drawn independently in the expert-scaling step.
    upscaled_layers = populate_num_layers(raw_routing_data, target_layers)
    upscaled_experts = scale_number_of_experts(upscaled_layers, target_experts)

    # Compute the statistics from the fully rescaled trace so the matrix has
    # ``target_experts`` columns and matches the routing data saved next to it.
    upscaled_stats = calculate_matrix_stats(upscaled_experts)

    output_dir = prefix_output_dir + f"_l{target_layers}_e{target_experts}"
    os.makedirs(output_dir, exist_ok=True)

    torch.save(upscaled_stats, os.path.join(output_dir, "stats_matrix.pt"))
    torch.save(upscaled_experts, os.path.join(output_dir, "raw_routing_data.pt"))
