import tempfile
from torch.nn.parallel.comm import broadcast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.cache_utils import DynamicCache
from utils import *
from utils import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_,
    send_list_to_next_pipeline_rank,
    _update_causal_mask,
    recv_list_from_prev_pipeline_rank,
    get_interval,
    copy_from_last_to_first_pipeline_stage,
    send_token_and_probs_to_first_pipeline_stage,
    recv_token_and_probs,
    broadcast_from_last_pipeline_stage,
    broadcast_from_first_pipeline_stage,
    broadcast_from_last_to_first_pipeline_stage
)
import transformers
import time
import os
import warnings
import shutil
import torch.distributed as dist
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch

import argparse
import math
def parse_args(argv=None):
    """Parse the command-line options of the pipeline early-exit generation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches the behaviour of the
            original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``shm_file``, ``model_path``,
        ``split_model_path``, ``headclass``, ``data_path``, ``ckpt_path``,
        ``stage`` and ``maxlen``.
    """
    parser = argparse.ArgumentParser(description='Run an evaluation task')
    # Only needed when the script is driven from opencompass.
    parser.add_argument('--shm_file', type=str, default="1", help='Path to shared memory file')

    parser.add_argument('--model_path', type=str, help='Path to model/config file')
    parser.add_argument('--split_model_path', type=str, help='Path to split model')
    parser.add_argument('--headclass', type=str, default="norm", help='class of training heads')

    parser.add_argument('--data_path', type=str, help='Path to datasets')
    # The script loads f"{ckpt_path}lm_head4.pth", so the value is used as a raw
    # string prefix and normally needs a trailing path separator.
    parser.add_argument('--ckpt_path', type=str,
                        help='Path prefix of the trained head checkpoint (lm_head4.pth)')

    # Integer options get integer defaults instead of relying on argparse
    # converting a string default through ``type``.
    parser.add_argument('--stage', type=int, default=3, help='stage indicator')
    parser.add_argument('--maxlen', type=int, default=64,
                        help='Maximum number of tokens to generate')

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: pipeline-parallel greedy generation with early exit.
    # Every rank runs MyModel on the incoming hidden states; rank 0 and the
    # middle ranks additionally run an early-exit head (myhead), and the last
    # rank runs the final norm + lm head. Rank 0 owns the tokenizer, the data
    # and the generated token sequence.

    # Initialise the process group, one process per GPU.
    # NOTE(review): LOCAL_RANK is passed as the global rank, so this presumably
    # assumes a single-node launch - confirm.
    LOCAL_RANK = int(os.environ['LOCAL_RANK'])
    WORLD_SIZE = int(os.environ['WORLD_SIZE'])
    dist.init_process_group(backend='nccl', world_size=WORLD_SIZE, rank=LOCAL_RANK)
    torch.cuda.set_device(LOCAL_RANK)
    device = torch.device(LOCAL_RANK)
    # model, tokenizer, dataset initialize
    args = parse_args()
    model_path = args.model_path
    # NOTE(review): args.split_model_path is never read; the split-model
    # directory is always derived from model_path and the world size.
    split_model_path = args.model_path +f"-{WORLD_SIZE}cut/"
    data_path = args.data_path
    config = transformers.AutoConfig.from_pretrained(
            model_path,
        )
    
    # If the model's native context is shorter than 2048 tokens, request linear
    # RoPE scaling with an integer factor large enough to reach 2048.
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and 2048 > orig_ctx_len:
        scaling_factor = float(math.ceil(2048 / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}

    # mysize holds the number of samples to process; rank 0 fills it in and it
    # is broadcast to all ranks below so everyone runs the same number of loops.
    mysize = torch.tensor([0],dtype=torch.int64,device=torch.cuda.current_device())
    batch_size=1 
    if dist.get_rank() == 0:
        # Longest generation seen so far (only consumed by the commented-out
        # save code at the end of the script).
        gen_max_len=0
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            model_max_length=2048,
            padding_side="right",
            use_fast=True,
        )

        # Use the unknown token as the padding token.
        if tokenizer.pad_token != tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token

        # make_supervised_data_module / DataCollatorForSupervisedDataset are not
        # imported explicitly here - presumably provided by `from utils import *`.
        data_module = make_supervised_data_module(tokenizer=tokenizer, data_path=data_path,lazy_preprocess=True)
        train_dataset = data_module['train_dataset']
    
        data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
        train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False,
                                                   drop_last=True,
                                                   collate_fn=data_collator)
    
        # Only the first batch of the dataloader is ever used.
        # assumes the collator yields an (input_ids, labels, attention_mask)
        # triple of tensors - TODO confirm against utils.
        input_ids, labels, attention_mask = next(iter(train_dataloader))
        mysize = torch.tensor([input_ids.size(0)],dtype=torch.int64,device=torch.cuda.current_device())
        
        # NOTE(review): lengths, min_prompt_length and max_sequence_length are
        # not read again below.
        lengths = input_ids.size(1)
        min_prompt_length = input_ids.size(1)
        max_sequence_length = input_ids.size(1) + 20
        # One row of generated token ids per sample, at most args.maxlen each.
        output_tensor = torch.zeros((input_ids.size(0),args.maxlen),dtype=torch.int64,device=torch.cuda.current_device())
    dist.broadcast(mysize,0)
    # Confidence threshold for early exit. With a value of 1 the confidence
    # test only passes when the top softmax probability reaches 1.0, so exits
    # are effectively driven by the forced `args.stage` rule further down.
    early_exit_thres = 1
    mymodel = MyModel(LOCAL_RANK,config,split_model_path,WORLD_SIZE).cuda()
    mymodel.eval()

    # Early-exit head implementations selectable through --headclass.
    MYHEAD_CLASSES = {
        "norm": Normhead,
        "mlp": MLPhead,
        "trm": Transformerhead,
    }

    # Split num_hidden_layers evenly across ranks, handing the remainder to
    # the last ranks; layer_sum is the cumulative layer count through this rank.
    # NOTE(review): layer_list and layer_sum are not read again below.
    layer_list = []
    interval = config.num_hidden_layers//WORLD_SIZE
    left = config.num_hidden_layers%WORLD_SIZE
    for i in range(WORLD_SIZE):
        layer_list.append(interval)
    for i in range(left):
        layer_list[-i-1]+=1
    
    layer_sum=0
    for i in range(LOCAL_RANK+1):
        layer_sum+=layer_list[i]

    # Build the early-exit head and load its trained weights. ckpt_path is
    # used as a raw string prefix, so it normally needs a trailing separator.
    myhead = MYHEAD_CLASSES[args.headclass](config, LOCAL_RANK, split_model_path,True)
    myhead.load_state_dict(torch.load(f"{args.ckpt_path}lm_head4.pth"))
    myhead.cuda()
        
    # Components stored separately in the split-model directory:
    # rotary embedding, final norm, lm head and token embedding.
    # LlamaRotaryEmbedding and nn are not imported explicitly in this file's
    # visible imports - presumably provided by `from utils import *`.
    rotary_emb = LlamaRotaryEmbedding(config=config)
    rotary_emb.load_state_dict(torch.load(f"{split_model_path}rotary_emb.pth"))

    # NOTE(review): the vocabulary size is hard-coded to 32000 here and for the
    # embedding below rather than taken from config.
    modelnorm = LlamaRMSNorm(config.hidden_size,config.rms_norm_eps)
    modelhead = nn.Linear(config.hidden_size,32000,bias=False)
    modelnorm.load_state_dict(torch.load(f"{split_model_path}norm.pth"))
    modelhead.load_state_dict(torch.load(f"{split_model_path}lmhead.pth"))
    modelnorm.cuda()
    modelhead.cuda()

    embed_tokens = nn.Embedding(32000,config.hidden_size,padding_idx=0)
    embed_tokens.load_state_dict(torch.load(f"{split_model_path}embed_tokens.pth"))
    
    
    embed_tokens.cuda()
    rotary_emb.cuda()
    maxlen = args.maxlen

    # Totals accumulated across samples (length and matches on rank 0, time on
    # the last rank).
    len_all=0
    time_all=0
    match_all =0
    # get_interval comes from utils; its result is used below as a threshold on
    # `count` before rank 0 starts receiving verification messages, and as the
    # number of verification messages drained after each sample.
    inter = get_interval(args.stage,WORLD_SIZE)

    # Process the samples of the batch one at a time (mysize was broadcast
    # from rank 0, so every rank runs the same number of iterations).
    for k in range(mysize.item()):
        # NOTE(review): neg is not read again below.
        neg = torch.tensor([-1]).to(device)
        # Prompt length; rank 0 sets the real value and broadcasts it below.
        seq_len = torch.tensor([1]).to(device)
        # initialize inputs and position_ids in GPU 0
        if dist.get_rank() == 0:
            generated_sequence = []
            micro_attention_mask = attention_mask[k,:].to(device).unsqueeze(0)
            # Number of positions whose attention-mask value is not 1.
            ne_len = micro_attention_mask.ne(1).sum().item()

            micro_input_ids = input_ids[k,:].to(device).unsqueeze(0)
            # NOTE(review): this drops the FIRST ne_len tokens, while the
            # tokenizer above is configured with padding_side="right"; the
            # attention mask is not sliced to match - confirm this is intended.
            if ne_len != 0:
                micro_input_ids = micro_input_ids[:,ne_len:]

            past_key_values=None
            use_cache = False
            inputs_embeds = embed_tokens(micro_input_ids)

            # use_cache is hard-coded to False above, so this legacy-cache
            # conversion branch never executes. (`Cache` and `logger` are not
            # imported explicitly in this file's visible imports.)
            return_legacy_cache = False
            if use_cache and not isinstance(past_key_values, Cache):
                return_legacy_cache = True
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                    logger.warning_once(
                        "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                        "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                        "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                    )

            # past_key_values is None here, so positions start at 0 and cover
            # the whole prompt.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )

            position_ids = cache_position.unsqueeze(0)
            # NOTE(review): causal_mask and position_embeddings are computed
            # here but are not passed to mymodel or read again below.
            causal_mask = _update_causal_mask(
                config,micro_attention_mask, inputs_embeds, cache_position, past_key_values, False
            )
            hidden_states = inputs_embeds
            position_embeddings = rotary_emb(hidden_states, position_ids)

            torch.cuda.empty_cache()
            seq_len[0] = inputs_embeds.size(1) 

            # input_tensor / input_position are re-sliced at the top of every
            # generation step; _inputs_embeds / _position_ids are not read again.
            input_tensor = inputs_embeds
            input_position = position_ids
            seq_len = torch.tensor([inputs_embeds.size(1)]).to(device)

            _inputs_embeds = inputs_embeds
            _position_ids = position_ids
            position_embeddings = tuple(t.to(device) for t in position_embeddings)
            torch.cuda.empty_cache()

            # [len1, len2) is the window of sequence positions fed to the model
            # on the next step; the first step covers the whole prompt.
            len1 = 0
            len2 = seq_len.item()

        
        # idx is the generation step counter (0 = prompt pass).
        idx = 0
        # Fresh KV cache on every rank (this replaces rank 0's None from above).
        past_key_values = DynamicCache()
        # Share the prompt length so every rank can size its buffers.
        dist.broadcast(seq_len, 0)
        # NOTE(review): flag is not read again below.
        flag = 0
        # fflag records the step index of the latest verification handled on rank 0.
        fflag=0
        start1 = time.time()
        # match counts verifications whose token agreed with the early-exit token.
        match = 0
        # count is the number of loop iterations executed for this sample.
        count = 0
        # initialize communication buffers 
        with ((torch.no_grad())):
            # Send-side buffers (every rank except the last): full-length
            # hidden states and position ids, filled in window by window.
            if dist.get_rank() != WORLD_SIZE-1:
                expanded_hidden_states = torch.zeros((1, seq_len + maxlen, config.hidden_size), dtype=torch.float,
                                                     device=torch.cuda.current_device())
                expanded_position = torch.zeros((1, seq_len + maxlen), dtype=torch.int64,
                                                device=torch.cuda.current_device())

            # Receive-side buffers (every rank except the first), in the same
            # order as the lists passed to send_list_to_next_pipeline_rank:
            # [hidden states, early-exit signal, positions, len1, len2, idx].
            if dist.get_rank() != 0:
                recv_buffers = [
                    torch.zeros((1, seq_len + maxlen, config.hidden_size), dtype=torch.float, device=torch.cuda.current_device()),
                    torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()),
                    torch.zeros((1, seq_len + maxlen), dtype=torch.int64, device=torch.cuda.current_device()),
                    torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()),
                    torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()),
                    torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()),
                    #torch.zeros((1, seq_len + maxlen), dtype=torch.int64, device=torch.cuda.current_device()),
                ]
            # Rank 0 buffers for the verification message sent by the last rank
            # (token, len1, len2, idx); check_idx == -1 means "nothing pending".
            if dist.get_rank() == 0:
                check_idx = torch.tensor([-1], dtype=torch.int64, device=torch.cuda.current_device())
                check_len1 = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
                check_len2 = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
                check_token = torch.tensor([-1],  dtype=torch.int64, device=torch.cuda.current_device())
            # KV caches for the 'trm' (transformer) early-exit head.
            new_head_kv = DynamicCache()
            head_kv = DynamicCache()
            #ready to generate
            while True:
                # Restart the timer when the step counter equals WORLD_SIZE-1;
                # start1 is only read on the last rank after the loop
                # (presumably to exclude pipeline warm-up - confirm).
                if idx == WORLD_SIZE-1:
                    start1 = time.time()
                has_early_exited = False
                prev_has_early_exited = False
                # Rank 0 builds input_tensor / input_position and drives termination
                # through sentinel idx values: 1e6 means end of sentence was generated,
                # 1e5 means the maximum generation length was reached.
                if dist.get_rank() == 0:
                    input_tensor = inputs_embeds[:, len1:len2, :]
                    input_position = position_ids[:, len1:len2]
                    # The tensors in these termination messages (signal_tensor,
                    # len1_tensor, len2_tensor) are left over from the previous
                    # iteration; only idx_tensor carries new information.
                    if idx >= 1e6:
                        idx_tensor = torch.tensor([1e6], dtype=torch.int64, device=torch.cuda.current_device())
                        send_list_to_next_pipeline_rank(
                            [expanded_hidden_states, signal_tensor, expanded_position, len1_tensor, len2_tensor, idx_tensor])
                        break
                    if len(generated_sequence) >= maxlen:
                        gen_max_len=max(gen_max_len,maxlen)
                        idx_tensor = torch.tensor([1e5], dtype=torch.int64, device=torch.cuda.current_device())
                        send_list_to_next_pipeline_rank(
                            [expanded_hidden_states, signal_tensor, expanded_position, len1_tensor, len2_tensor, idx_tensor])
                        break
                # Ranks > 0 receive the hidden states, positions and control values
                # from the previous rank.
                if not dist.get_rank() == 0:
                    recv_list_from_prev_pipeline_rank(recv_buffers)
                    prev_has_early_exited = bool(recv_buffers[1])
                    len1 = recv_buffers[3].item()
                    len2 = recv_buffers[4].item()
                    idx = recv_buffers[5].item()
                    input_tensor = recv_buffers[0][:, len1:len2, :]
                    input_position = recv_buffers[2][:, len1:len2]

                    # A sentinel idx (>= 1e5) means rank 0 has finished this
                    # sample: middle ranks forward the message, then everyone stops.
                    if idx >= 1e5 and dist.get_rank() <WORLD_SIZE-1:
                        idx_tensor = torch.tensor([idx], dtype=torch.int64, device=torch.cuda.current_device())
                        send_list_to_next_pipeline_rank(
                            [recv_buffers[0], signal_tensor, recv_buffers[2], recv_buffers[3], recv_buffers[4], idx_tensor])
                        break
                    if idx >= 1e5 and dist.get_rank() ==WORLD_SIZE-1:
                        break
                #KV cache manager and model runner
                if idx > 0:
                    # Rebuild the cache keeping only the first len1 positions,
                    # which discards any entries past the current window start
                    # (e.g. after rank 0 rolled len1 back).
                    new_kv = DynamicCache()
                    for i in range(len(past_key_values)):
                        new_kv.update(key_states=past_key_values[i][0][:, :, :len1, :],
                                      value_states=past_key_values[i][1][:, :, :len1, :], layer_idx=i)
                    # Same truncation for the transformer head's own cache.
                    if args.headclass == 'trm' and dist.get_rank()!=WORLD_SIZE-1:
                        new_head_kv = DynamicCache()
                        for i in range(len(head_kv)):
                            new_head_kv.update(key_states=head_kv[i][0][:, :, :len1, :],
                                          value_states=head_kv[i][1][:, :, :len1, :], layer_idx=i)
                    outputs = mymodel(input_tensor, None,input_position, new_kv,None)

                else:
                    outputs = mymodel(input_tensor, None,input_position, past_key_values,None)
                # outputs[0]: hidden states for the window; outputs[1]: updated KV cache.
                hidden_states = outputs[0]
                past_key_values = outputs[1]
                # Rank 0: run the early-exit head, forward to the next rank and
                # reconcile with tokens / verifications coming back from later ranks.
                if dist.get_rank() == 0:
                    if args.headclass == 'trm':
                        logits,head_kv = myhead(hidden_states,input_position,new_head_kv)
                    else:
                        logits = myhead(hidden_states)
                    last_token_logits = logits[:, -1, :]

                    # Despite the names, these are softmax probabilities, not
                    # log-probabilities; token_id is the greedy (arg-max) token.
                    log_probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
                    max_log_prob, token_id = torch.max(log_probs, dim=-1)

                    # Exit here when confident enough, or unconditionally when
                    # this rank is the configured exit stage (never on the prompt pass).
                    has_early_exited = max_log_prob[-1] >= early_exit_thres and idx>0
                    if args.stage == 0 and idx>0:
                        has_early_exited = 1

                    # prev_has_early_exited is always False on rank 0.
                    signal_tensor = torch.tensor(
                        [int(has_early_exited or prev_has_early_exited)],
                        dtype=torch.int64,
                        device=torch.cuda.current_device())

                    len1_tensor = torch.tensor([len1], dtype=torch.int64, device=torch.cuda.current_device())
                    len2_tensor = torch.tensor([len2], dtype=torch.int64, device=torch.cuda.current_device())
                    idx_tensor = torch.tensor([idx], dtype=torch.int64, device=torch.cuda.current_device())

                    # Write this window's outputs into the full-length send buffers.
                    expanded_hidden_states[:, len1:len2, :] = hidden_states
                    expanded_position[:, len1:len2] = input_position

                    send_list_to_next_pipeline_rank(
                        [expanded_hidden_states, signal_tensor, expanded_position, len1_tensor, len2_tensor, idx_tensor])

                    # Without a local exit, wait for a later stage to report its
                    # token. Presumably recv_token_and_probs overwrites token_id /
                    # max_log_prob in place and returns the index of the stage that
                    # produced the token - TODO confirm against utils.
                    if not has_early_exited:
                        ee_idx = recv_token_and_probs(has_early_exited=has_early_exited,
                                                  token_tensor_buffer=token_id,
                                                  prob_tensor_buffer=max_log_prob)
                    else:
                        ee_idx=0

                    # The token did not come from the last rank and more than
                    # `inter` iterations have run: block until the last rank's
                    # verification message (token, len1, len2, idx) arrives.
                    if  ee_idx != WORLD_SIZE-1 and count>inter:
                        
                        req1=dist.irecv(check_token, WORLD_SIZE-1)
                        req2=dist.irecv(check_len1, WORLD_SIZE-1)
                        req3=dist.irecv(check_len2, WORLD_SIZE-1)
                        req4=dist.irecv(check_idx, WORLD_SIZE-1)
                        
                        req1.wait()
                        req2.wait()
                        req3.wait()
                        req4.wait()



                    new_input_embeds = embed_tokens(token_id).to(inputs_embeds.device)

                    # New step: append the token, its embedding and the next
                    # position id. Otherwise (a step being redone after a
                    # rollback) overwrite the existing slot in place.
                    if idx >= len(generated_sequence):
                        inputs_embeds = torch.cat((inputs_embeds, new_input_embeds.unsqueeze(0)), dim=1)
                        generated_sequence.append(token_id.item()) 
                        l2 = position_ids[:,-1:]+1
                        position_ids = torch.cat((position_ids, l2), dim=1)
                    else:
                        inputs_embeds[:, len2, :] = new_input_embeds
                        generated_sequence[idx] = token_id.item()

                    # `copy` is not imported explicitly in this file's visible
                    # imports - presumably provided by `from utils import *`.
                    # NOTE(review): a verification for step 1 is adopted
                    # unconditionally (no token comparison) - confirm intent.
                    if check_idx.item() == 1:
                        len1 = copy.deepcopy(check_len1.item())
                        len2 = copy.deepcopy(check_len2.item())
                        idx = copy.deepcopy(check_idx.item())
                        generated_sequence[check_idx.item()] = copy.deepcopy(check_token.item())
                        inputs_embeds[:,len2,:]= copy.deepcopy(embed_tokens(check_token).to(inputs_embeds.device))
                    
                    # General verification: when the last rank's message refers
                    # to an earlier step, compare its token with the one stored
                    # for that step. On mismatch, roll len1 / len2 / idx back to
                    # the verified step and overwrite token and embedding;
                    # on agreement, count a match.
                    fflag=-1
                    if check_idx.item() != -1 and check_idx.item() < idx and idx<maxlen-2 :
                        if check_token.item()!=generated_sequence[check_idx.item()]:
                            len1 = copy.deepcopy(check_len1.item())
                            len2 = copy.deepcopy(check_len2.item())
                            idx = copy.deepcopy(check_idx.item())
                            generated_sequence[idx] = copy.deepcopy(check_token.item())
                            inputs_embeds[:,len2,:]= copy.deepcopy(embed_tokens(check_token).to(inputs_embeds.device))
                            fflag=copy.deepcopy(check_idx.item())
                        else:
                            fflag=copy.deepcopy(check_idx.item())
                            match+=1
                    # Mark the verification buffer as consumed.
                    check_idx[0] = -1

                    # Token id 2 is treated as end-of-sentence (hard-coded).
                    # Setting idx to 1e6 (a float) triggers the termination
                    # branch at the top of the next iteration.
                    if idx>0 and generated_sequence[idx]==2:
                        gen_max_len=max(len(generated_sequence),gen_max_len)
                        idx=1e6
                    elif ee_idx<WORLD_SIZE-1  and fflag>0 and generated_sequence[fflag]==2 :
                        # The verified step holds end-of-sentence: drop everything after it.
                        generated_sequence = generated_sequence[:fflag+1]
                        gen_max_len=max(len(generated_sequence),gen_max_len)
                        idx=1e6

                  

                    # Advance the window to the single newly generated position.
                    len1 = len2
                    idx += 1
                    len2 += 1
                # Middle ranks: run the early-exit head and forward to the next rank.
                elif dist.get_rank() != 0 and dist.get_rank()<WORLD_SIZE-1:
                    if args.headclass == 'trm':
                        logits,head_kv = myhead(hidden_states,input_position,new_head_kv)
                    else:
                        logits = myhead(hidden_states)
                    last_token_logits = logits[:, -1, :]

                    # Softmax probabilities (not log-probabilities) and the greedy token.
                    log_probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
                    max_log_prob, token_id = torch.max(log_probs, dim=-1)

                    # Exit when confident enough, or unconditionally when this
                    # rank is the configured exit stage (never on the prompt pass).
                    has_early_exited = max_log_prob[-1] >= early_exit_thres and idx>0
                    if args.stage == dist.get_rank() and idx>0:
                        has_early_exited = 1
                    # The forwarded signal is set if this rank or any earlier rank exited.
                    signal_tensor = torch.tensor(
                        [int(has_early_exited or prev_has_early_exited)],
                        dtype=torch.int64,
                        device=torch.cuda.current_device())

                    len1_tensor = torch.tensor([len1], dtype=torch.int64, device=torch.cuda.current_device())
                    len2_tensor = torch.tensor([len2], dtype=torch.int64, device=torch.cuda.current_device())
                    idx_tensor = torch.tensor([idx], dtype=torch.int64, device=torch.cuda.current_device())
                    expanded_hidden_states[:, len1:len2, :] = hidden_states
                    expanded_position[:, len1:len2] = input_position

                    send_list_to_next_pipeline_rank(
                        [expanded_hidden_states, signal_tensor, expanded_position, len1_tensor, len2_tensor, idx_tensor])

                    # Report to the first stage only if no earlier rank has
                    # already exited for this step.
                    if not prev_has_early_exited:
                        send_token_and_probs_to_first_pipeline_stage(has_early_exited=has_early_exited,
                                                                     token_tensor=token_id,
                                                                     prob_tensor=max_log_prob)

                # Last rank: apply the final norm and lm head, then either deliver the
                # token or send it to rank 0 as a verification.
                elif dist.get_rank() == WORLD_SIZE-1:

                    hidden_states= modelnorm(hidden_states)
                    logits = modelhead(hidden_states)
                    last_token_logits = logits[:, -1, :]
                    log_probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
                    max_log_prob, token_id = torch.max(log_probs, dim=-1)
                    has_early_exited = 1 
                    # Overwrites the token_id from torch.max with the arg-max of the raw logits.
                    token_id = torch.argmax(last_token_logits,dim=-1)

                    # No earlier rank exited: this token is the one rank 0 is waiting for.
                    if not prev_has_early_exited:

                        send_token_and_probs_to_first_pipeline_stage(has_early_exited=has_early_exited,
                                                                     token_tensor=token_id,
                                                                     prob_tensor=max_log_prob,
                                                                     is_final=True)
                    # An earlier rank already exited: send the full-model token
                    # to rank 0 as a verification message.
                    # NOTE(review): the isend request handles are discarded and
                    # never waited on.
                    elif idx>0 :
                        len1_tensor = torch.tensor([len1], dtype=torch.int64, device=torch.cuda.current_device())
                        len2_tensor = torch.tensor([len2], dtype=torch.int64, device=torch.cuda.current_device())
                        idx_tensor = torch.tensor([idx], dtype=torch.int64, device=torch.cuda.current_device())

                        dist.isend(token_id, 0)
                        dist.isend(len1_tensor, 0)
                        dist.isend(len2_tensor, 0)
                        dist.isend(idx_tensor, 0)

                count+=1
        # generation outputs
        if dist.get_rank() == 0:
            gen_tensor = torch.tensor(generated_sequence,dtype=int)

            #print("original outputs")
            #print(tokenizer.decode(myinput_ids[:,-args.maxlen:].squeeze(0)))
            #
            #print(gen_tensor)
            #print("my outputs")
            #pp = tokenizer.decode(gen_tensor)
            #print(pp)
            match_all += match
            len_cur=len(generated_sequence)
            #print("Tensor:",gen_tensor)
            # Store this sample's tokens in row k of the output tensor.
            output_tensor[k,:len_cur]=gen_tensor
            len_all += len_cur
            # Receive `inter` further verification messages from the last rank,
            # presumably to drain sends that were still outstanding when
            # generation stopped - TODO confirm the count always matches.
            if args.stage!=WORLD_SIZE-1:
                for i in range(inter):
                    req1=dist.irecv(check_token, WORLD_SIZE-1)
                    req2=dist.irecv(check_len1, WORLD_SIZE-1)
                    req3=dist.irecv(check_len2, WORLD_SIZE-1)
                    req4=dist.irecv(check_idx, WORLD_SIZE-1)
                    req1.wait()
                    req2.wait()
                    req3.wait()
                    req4.wait()


        # generation time, measured on the last rank from the most recent timer restart
        if dist.get_rank() == WORLD_SIZE-1:
            end1 = time.time()
            time_cur = end1-start1
            time_all += time_cur

        # Synchronise all ranks between samples (skipped after the final one).
        if k <mysize.item()-1:
            dist.barrier()

    # Summary: rank 0 reports total generated length and verification matches,
    # the last rank reports total generation time.
    if dist.get_rank()==0:
        print("Length:",len_all)
        #print("Tensor:",output_tensor)
        #print("outputs:",)
        print("Match:",match_all)
        #with tempfile.NamedTemporaryFile(delete=False) as f:
        #    torch.save(output_tensor[:,:gen_max_len+1], f.name)
        #    shm_file = f.name
        #print("Temp File:",f.name)

    if dist.get_rank()==WORLD_SIZE-1:
        print("Time:",time_all)

    dist.destroy_process_group()
    exit()

