from transformers.models.opt import OPTConfig
from transformers import AutoTokenizer
from flash_attn.models.opt import opt_config_to_gpt2_config

import os
import torch
import argparse
from apex.transformer import parallel_state

from HybridTensor.utils.utils import arg_parser, _get_device
from HybridTensor.utils.activations import OPT_MODELS
from HybridTensor.models.opt import SparseConfig, build_sparse_opt, build_dense_opt


def initialize_distributed_environment():
    """Initialise the NCCL process group and bind this process to one GPU.

    The process group is created with ``init_method="env://"``, so the
    rendezvous settings are read from environment variables provided by the
    launcher.

    The GPU index is taken from ``LOCAL_RANK`` when the launcher exports it.
    Otherwise the global rank is wrapped by the number of visible CUDA
    devices, so that ranks on a second node do not select a device index that
    does not exist on that node.

    Returns:
        tuple: ``(device, world_size)`` where ``device`` is a string such as
        ``"cuda:0"`` and ``world_size`` is the number of processes in the
        default process group.
    """
    # NCCL-related environment settings, applied before the process group
    # is created.
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
    os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"

    # Initialize the distributed process group
    torch.distributed.init_process_group(backend="nccl", init_method="env://")

    world_size = torch.distributed.get_world_size()

    # Pick the GPU by node-local index rather than by global rank.
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        device_index = int(local_rank)
    else:
        # max(..., 1) keeps the modulo well-defined if no device is visible;
        # set_device below will then report the real problem.
        visible_devices = max(torch.cuda.device_count(), 1)
        device_index = torch.distributed.get_rank() % visible_devices
    device = f"cuda:{device_index}"

    # Set the current CUDA device to avoid operations being executed on the wrong GPU
    torch.cuda.set_device(device)

    return device, world_size

def _turn_bias_off(model, num_layers):
    """Set the ``fc1`` and ``fc2`` biases to ``None`` in the first ``num_layers`` MLPs.

    The layers are reached through ``model.transformer.layers[idx].mlp`` and
    are modified in place; nothing is returned.
    """
    layer_idx = 0
    while layer_idx < num_layers:
        mlp = model.transformer.layers[layer_idx].mlp
        mlp.fc1.bias = None
        mlp.fc2.bias = None
        layer_idx += 1

def _str2bool(value):
    """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    ``"False"`` would become ``True``. This converter interprets the text.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, matching what ``bool("")`` returned.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser(argv=None):
    """Parse the command-line options of the inference benchmark.

    Note that this definition replaces the ``arg_parser`` name imported from
    ``HybridTensor.utils.utils`` at the top of the file.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv`` by ``argparse``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=25)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.25)
    # Boolean options use _str2bool so that e.g. "--bias False" is False.
    parser.add_argument('--print_results', type=_str2bool, default=True)
    parser.add_argument('--iterations', type=int, default=2)
    parser.add_argument('--check_results', type=_str2bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--bias', type=_str2bool, default=False)
    parser.add_argument('--mlp_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mlp')
    parser.add_argument('--attn_ckpt_dir', type=str, default='/home/grads/s/<name>/nvme/HybridTensor/checkpoint/opt-6.7b-routers/mha_linear')

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Benchmark driver: build a tensor-parallel dense OPT model, run one
    # warm-up generation, then time `args.iterations` generations with CUDA
    # events and report the averages from rank 0.

    args = arg_parser()
    # Without at least one timed iteration `out` would never be assigned and
    # the per-iteration average would divide by zero.
    if args.iterations < 1:
        raise ValueError(f"--iterations must be >= 1, got {args.iterations}")

    # --model_index is 1-based on the command line.
    model_name = OPT_MODELS[args.model_index-1]

    device, world_size = initialize_distributed_environment()
    dtype = torch.float16

    # All processes form a single tensor-parallel group.
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Sparse variant, kept for reference:
    # model = build_sparse_opt(model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device = device, dtype=dtype, process_group = process_group, world_size = world_size, rank = rank)
    model = build_dense_opt(model_name, process_group = process_group, world_size = world_size, rank = rank, device = device, dtype=dtype)
    model.eval()

    # Alternative multi-prompt input:
    # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    input_texts = ["In a distant galaxy, a spaceship"]
    tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids = tokenized_inputs["input_ids"]

    max_length = args.seq_len
    eos_token_id = tokenizer.eos_token_id
    num_layers = model.config.n_layer

    # turn bias off for mlp layers
    if not args.bias:
        _turn_bias_off(model, num_layers)

    # Warm-up generation; its result is discarded and it is not timed.
    _ = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=False,
        )

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    for _ in range(args.iterations):
        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            )

    end_event.record()

    # Wait for the recorded events to complete before reading elapsed time.
    torch.cuda.synchronize()

    if rank == 0:
        elapsed_time = start_event.elapsed_time(end_event) / args.iterations
        print(f"Average time per generation : {elapsed_time} ms")

        # Compute throughput and latency per token
        num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
        print(f"Number of tokens generated: {num_tokens_generated}")

        if num_tokens_generated > 0:
            throughput = num_tokens_generated / (elapsed_time / 1000)  # tokens per second
            latency_per_token = elapsed_time / num_tokens_generated  # ms per token
            print(f"Throughput: {throughput} tokens/second")
            print(f"Latency per token: {latency_per_token} ms")
        else:
            # The output is no longer than the prompt, so per-token metrics
            # are undefined.
            print("No new tokens were generated; skipping throughput and latency.")

        print(tokenizer.batch_decode(out.sequences.tolist()))
        
