import os
import random
from typing import Optional, List

import torch
from tqdm import tqdm

# Small epsilon added to a denominator to avoid division by zero; in this file
# compute_covariance adds it to the product of standard deviations.
FP32_EPS = 1e-7
# Larger epsilon, presumably the bfloat16 counterpart of FP32_EPS. It is not
# referenced in the visible part of this file -- TODO confirm where it is used.
BF16_EPS = 1e-3

def generate_random_group_labels(
        num_experts: int,
        num_groups: int,
) -> torch.Tensor:
    """Assign every expert a random group id, using each group at least once.

    The first ``num_groups`` slots are seeded with the ids ``0 .. num_groups - 1``
    so that no group is empty; every remaining slot draws a uniformly random
    group id from Python's ``random`` module. The labels are then shuffled with
    ``torch.randperm`` so the seeded ids do not stay at the front.

    Args:
        num_experts: Number of labels to produce.
        num_groups: Number of distinct group ids.

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_experts``.
    """
    labels = torch.zeros(num_experts, dtype=torch.long)

    # Seed one expert per group so every group id appears at least once.
    for group_id in range(num_groups):
        labels[group_id] = group_id

    # Fill the leftover slots with uniformly random group ids.
    leftover_slots = range(num_groups, num_experts)
    for slot in leftover_slots:
        labels[slot] = random.randint(0, num_groups - 1)

    shuffled_order = torch.randperm(num_experts)
    return labels[shuffled_order]


def apply_mask(module, _mask):
    """Attach a forward pre-hook that multiplies the module's input by a mask.

    The hook returns ``_mask * inputs[0]``; PyTorch uses a pre-hook's return
    value as the new input to ``forward``, so only the masked first positional
    input is forwarded.

    Args:
        module: The ``torch.nn.Module`` to hook.
        _mask: Value multiplied with the first positional input (must be
            broadcastable against it).

    Returns:
        The hook handle; call ``.remove()`` on it to detach the mask.
    """

    def _mask_first_input(_module, positional_inputs):
        return _mask * positional_inputs[0]

    return module.register_forward_pre_hook(_mask_first_input)

def hijack(module, _list, _hijack_input, _stop_forward=False):
    """Record a module's inputs or outputs into ``_list`` on every forward.

    Args:
        module: The ``torch.nn.Module`` to observe.
        _list: List that captured tensors are appended to (mutated in place).
        _hijack_input: If truthy, capture the first positional input;
            otherwise capture the output (its first element when the output
            is a tuple).
        _stop_forward: Accepted for interface compatibility; it is not used
            by this function.

    Returns:
        The forward-hook handle; call ``.remove()`` on it to stop recording.
    """

    def _capture_input(_module, inputs, _outputs):
        _list.append(inputs[0].detach().cpu())

    def _capture_output(_module, _inputs, outputs):
        if isinstance(outputs, tuple):
            captured = outputs[0].detach().cpu()
        else:
            # NOTE: unlike the other paths, a non-tuple output is only
            # detached and stays on its current device.
            captured = outputs.detach()
        _list.append(captured)

    chosen_hook = _capture_input if _hijack_input else _capture_output
    return module.register_forward_hook(chosen_hook)

def remove_col(x, idx):
    """Return a new tensor equal to ``x`` with entry ``idx`` of dim 1 removed.

    Args:
        x: Tensor with at least two dimensions.
        idx: Index along dimension 1 to drop. Negative indices count from the
            end. An index outside ``[-size, size)`` removes nothing and a copy
            of ``x`` is returned.

    Returns:
        A new tensor; ``x`` itself is not modified.
    """
    size = x.shape[1]
    if -size <= idx < 0:
        # Normalise in-range negative indices: with idx == -1 the tail slice
        # ``idx + 1:`` would start at 0 and re-append the whole tensor.
        idx = idx + size
    # Concatenate along the same axis that was sliced (dim 1).
    return torch.cat([x[:, :idx], x[:, idx + 1:]], dim=1)

def remove_row(x, idx):
    """Return a new tensor equal to ``x`` with entry ``idx`` of dim 0 removed.

    Args:
        x: Tensor with at least one dimension.
        idx: Index along dimension 0 to drop. Negative indices count from the
            end. An index outside ``[-size, size)`` removes nothing and a copy
            of ``x`` is returned.

    Returns:
        A new tensor; ``x`` itself is not modified.
    """
    size = x.shape[0]
    if -size <= idx < 0:
        # Normalise in-range negative indices: with idx == -1 the tail slice
        # ``idx + 1:`` would start at 0 and re-append the whole tensor.
        idx = idx + size
    return torch.cat([x[:idx], x[idx + 1:]], dim=0)

@torch.no_grad()
def collect_act(data, weight1, weight3=None):
    """Compute the intermediate activations for ``data`` as a 2-D tensor.

    With ``weight3`` given the result is ``SiLU(data @ weight1.T) *
    (data @ weight3.T)``; without it the result is just ``data @ weight1.T``.
    All leading dimensions are flattened so that the output has shape
    ``(-1, hidden)`` where ``hidden`` is the last dimension of the product.

    Args:
        data: Input tensor whose last dimension matches ``weight1.shape[1]``.
        weight1: Weight matrix applied as ``data @ weight1.T``.
        weight3: Optional second weight matrix with the same shape as
            ``weight1``; enables the gated form when provided.

    Returns:
        A 2-D tensor of activations, computed without gradient tracking.
    """
    hidden = torch.matmul(data, weight1.T)
    if weight3 is not None:
        hidden = torch.nn.functional.silu(hidden) * torch.matmul(data, weight3.T)
    # Flatten leading dims directly; wrapping a single tensor in a list and
    # torch.cat-ing it would only make an extra copy.
    return hidden.reshape(-1, hidden.shape[-1])

@torch.no_grad()
def collect_feature(ingredient, data, weight1, weight2, weight3):
    """Build the feature(s) selected by ``ingredient``.

    Args:
        ingredient: ``"act"`` returns the activations from ``collect_act``;
            ``"weight"`` returns the stacked weights; any other value returns
            the tuple ``(activations, stacked_weights)``.
        data: Input passed to ``collect_act`` (unused for ``"weight"``).
        weight1: First weight matrix; stacked as ``weight1.T``.
        weight2: Second weight matrix; stacked as-is.
        weight3: Third weight matrix, stacked as ``weight3.T``. May be
            ``None``, in which case it is left out of the stack (matching
            ``collect_act``, which also accepts ``weight3=None``).

    Returns:
        A tensor, or a ``(activations, stacked_weights)`` tuple.
    """

    def _stacked_weights():
        # Concatenate along dim 0; weight1/weight3 are transposed so their
        # second dimension lines up with weight2's.
        parts = [weight1.T, weight2]
        if weight3 is not None:
            parts.append(weight3.T)
        return torch.cat(parts, dim=0)

    if ingredient == "act":
        return collect_act(data, weight1, weight3)
    if ingredient == "weight":
        return _stacked_weights()
    return collect_act(data, weight1, weight3), _stacked_weights()

@torch.no_grad()
def compute_covariance(act1, act2):
    """Return the column-wise correlation matrix between two sample matrices.

    Despite the name, the result is normalised: the cross-covariance of the
    columns of ``act1`` and ``act2`` (centred product divided by ``n - 1``)
    is divided by the outer product of the per-column standard deviations
    plus ``FP32_EPS``.

    Args:
        act1: 2-D tensor with samples along dim 0.
        act2: 2-D tensor with the same number of rows as ``act1``.

    Returns:
        Tensor of shape ``(act1.shape[1], act2.shape[1])``. At least two rows
        are needed; with a single row the ``n - 1`` divisor is zero.
    """
    print(f"compute covariance: {act1.shape}, {act2.shape}")
    num_samples = act1.shape[0]

    mean1 = act1.mean(dim=0, keepdim=True)
    mean2 = act2.mean(dim=0, keepdim=True)
    std1 = act1.std(dim=0, keepdim=True)
    std2 = act2.std(dim=0, keepdim=True)

    corr_matrix = torch.matmul((act1 - mean1).T, act2 - mean2) / (num_samples - 1)
    # Drop intermediates as soon as possible to keep peak device memory low.
    del mean1, mean2
    torch.cuda.empty_cache()

    # FP32_EPS guards against division by zero for constant columns.
    corr_matrix = corr_matrix / (std1.T * std2 + FP32_EPS)
    del std1, std2
    torch.cuda.empty_cache()
    return corr_matrix

@torch.no_grad()
def compute_feature_covariance(ingredient, data1, data2):
    """Apply ``compute_covariance`` to features selected by ``ingredient``.

    Args:
        ingredient: When equal to ``"act+weight"``, ``data1`` and ``data2``
            are indexed as pairs (presumably the ``(activations, weights)``
            tuples from ``collect_feature`` -- verify against the caller) and
            the two resulting matrices are summed. Otherwise the inputs are
            passed through as single tensors.
        data1: Feature tensor, or a pair of tensors for ``"act+weight"``.
        data2: Feature tensor, or a pair of tensors for ``"act+weight"``.

    Returns:
        The matrix returned by ``compute_covariance`` (or the sum of two).
    """
    if ingredient != "act+weight":
        return compute_covariance(data1, data2)

    first_corr = compute_covariance(data1[0], data2[0])
    second_corr = compute_covariance(data1[1], data2[1])
    return first_corr + second_corr

def get_coef(num_ffn, input_weight, average_coefs, d_ff=None):
    """Resolve the averaging coefficients, per FFN or per FFN unit.

    Without ``d_ff`` one coefficient per FFN is returned. With ``d_ff`` each
    per-FFN coefficient is repeated ``d_ff`` times in order, producing
    ``num_ffn * d_ff`` values.

    Args:
        num_ffn: Number of FFNs being combined.
        input_weight: Explicit per-FFN weights; takes precedence over
            ``average_coefs`` when not ``None``.
        average_coefs: Optional coefficients. Without ``d_ff`` they are used
            only if their length equals ``num_ffn`` (otherwise all-ones is
            returned). With ``d_ff`` their length must be ``num_ffn`` (each
            value is repeated) or ``num_ffn * d_ff`` (used as-is).
        d_ff: Optional number of units per FFN to expand the coefficients to.

    Returns:
        The coefficient sequence described above.

    Raises:
        ValueError: If ``d_ff`` is given and ``average_coefs`` has a length
            other than ``num_ffn`` or ``num_ffn * d_ff``.
    """
    if d_ff is None:
        if input_weight is not None:
            return input_weight
        if average_coefs is not None and len(average_coefs) == num_ffn:
            return average_coefs
        # No usable coefficients supplied: weight every FFN equally.
        return [1.0] * num_ffn

    if input_weight is not None:
        # Repeat each FFN's weight for every one of its d_ff units.
        return [w for w in input_weight for _ in range(d_ff)]
    if average_coefs is None:
        return [1.0] * num_ffn * d_ff
    if len(average_coefs) == num_ffn:
        return [c for c in average_coefs for _ in range(d_ff)]
    if len(average_coefs) == num_ffn * d_ff:
        # Already one coefficient per unit.
        return average_coefs
    raise ValueError(
        f"The length of average_coefs should be either {num_ffn} or {num_ffn * d_ff}, "
        f"but got {len(average_coefs)}."
    )
