import torch
import random
import argparse
import numpy as np


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def args_parser():
    parser = argparse.ArgumentParser()
    # General
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--do_pretrain', action='store_true')
    # Dataset
    parser.add_argument('--data_dir', type=str, default='data')
    parser.add_argument('--folder_name', type=str, default='memmap')
    parser.add_argument('--dataset_info_path', type=str, default='dataset_info.json')
    # Dataloader
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--num_workers', type=int, default=16)
    # Model
    parser.add_argument('--embedding_dim', type=int, default=768)
    parser.add_argument('--max_orig_positional_len', type=int, default=2048)
    parser.add_argument('--vocab_size', type=int, default=30522)
    parser.add_argument('--hidden_size', type=int, default=768)
    parser.add_argument('--num_hidden_layers', type=int, default=6)
    parser.add_argument('--num_attention_heads', type=int, default=12)
    parser.add_argument('--intermediate_size', type=int, default=3072)
    parser.add_argument('--intermediate_size_expert', type=int, default=3072)
    parser.add_argument('--num_expert_heads', type=int, default=0)
    parser.add_argument('--pad_token_id', type=int, default=16)
    parser.add_argument('--hidden_act', type=str, default='gelu')
    parser.add_argument('--token_moe', type=bool, default=True)
    parser.add_argument('--moe_type', type=str, default='topk')
    parser.add_argument('--topk', type=int, default=1)
    parser.add_argument('--hash_list_path', type=str, default=None)
    parser.add_argument('--num_experts', type=int, default=1)
    parser.add_argument('--num_sparse_layers', type=int, default=3)
    parser.add_argument('--gradient_checkpointing', type=bool, default=False)
    parser.add_argument('--mean_pooling', type=bool, default=True)
    # Optimizer
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--adam_epsilon', type=float, default=1e-6)
    # Scheduler
    parser.add_argument('--warmup_steps', type=int, default=2000)
    # Loss
    parser.add_argument('--temperature', type=float, default=0.05)
    # Logging
    parser.add_argument('--project_name', type=str, default='moe')
    parser.add_argument('--run_name', type=str, default='run')
    parser.add_argument('--output_dir', type=str, default='output')
    parser.add_argument('--load_model', type=str, default=None)
    return parser.parse_args()


def get_router_stats(router_logits):
    router_probs = torch.zeros((len(router_logits), router_logits[0].shape[-1]))
    for i, router_logit in enumerate(router_logits):
        router_probs[i] = torch.nn.functional.softmax(router_logit.float().mean(dim=1), dim=-1).mean(0)
    return router_probs.detach().cpu().numpy()