import torch
from typing import Dict, List, Optional, Set, Tuple
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
from filelock import FileLock
import os
import time
import socket
import pickle
from vllm_inject.utils import *
        
# Port of the external logits server this patch talks to (on the local host).
_LOGITS_SERVER_PORT = 11454
# Seconds to wait between failed attempts to reach the logits server.
_LOGITS_SERVER_RETRY_DELAY = 2.0


def _exchange_logits_with_server(
    model_name: str,
    logits: torch.Tensor,
    host: str,
    port: int,
    retry_delay: float = _LOGITS_SERVER_RETRY_DELAY,
) -> torch.Tensor:
    """Send ``(model_name, logits)`` to the logits server and return its reply.

    Retries until one complete connect/send/receive round-trip succeeds,
    sleeping ``retry_delay`` seconds between attempts. A brand-new socket is
    opened for every attempt: a socket whose ``connect`` or transfer failed
    must not be reused (POSIX leaves its state unspecified, and an already
    connected socket rejects a second ``connect`` with ``EISCONN``). The
    ``with`` block guarantees the socket is closed on both success and error.

    Args:
        model_name: Identifier sent alongside the logits.
        logits: Locally computed logits.
        host: Server host name.
        port: Server TCP port.
        retry_delay: Seconds to sleep after a failed attempt.

    Returns:
        The tensor received from the server, moved to ``logits.device``.
    """
    # Capture the device up front so the reply lands where the sampler
    # expects it, instead of on a hard-coded GPU.
    target_device = logits.device
    while True:
        try:
            with socket.socket(socket.AF_INET,
                               socket.SOCK_STREAM) as client_socket:
                client_socket.connect((host, port))
                send_large_tensor(client_socket, (model_name, logits))
                reply = receive_large_tensor(client_socket)
        except Exception as exc:  # any transport/serialization failure: retry
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still
            # terminate the process while we are waiting for the server.
            print(f"try to reconnect ({exc!r})")
            time.sleep(retry_delay)
            continue
        return reply.to(target_device)


@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    """Run one model step, letting an external server replace the logits.

    Replacement for ``ModelRunner.execute_model``. It prepares the input
    tensors, runs the model (through a CUDA-graph runner when the attention
    metadata asks for one) and computes logits. It then sends
    ``(self.model_config.model, logits)`` to a server at
    ``socket.gethostname():11454`` and samples from the logits the server
    returns instead of the locally computed ones.

    Args:
        seq_group_metadata_list: Forwarded unchanged to
            ``self.prepare_input_tensors``.
        kv_caches: KV-cache tensors passed to the model as ``kv_caches``.

    Returns:
        The result of ``self.model.sample``, or ``None`` when
        ``sampling_metadata.perform_sampling`` is false.
    """
    (input_tokens, input_positions, attn_metadata, sampling_metadata,
     lora_requests, lora_mapping, multi_modal_input
     ) = self.prepare_input_tensors(seq_group_metadata_list)

    if self.lora_config:
        self.set_active_loras(lora_requests, lora_mapping)

    # Execute the model. CUDA-graph runners are looked up by batch size.
    if attn_metadata.use_cuda_graph:
        model_executable = self.graph_runners[input_tokens.shape[0]]
    else:
        model_executable = self.model
    execute_model_kwargs = {
        "input_ids": input_tokens,
        "positions": input_positions,
        "kv_caches": kv_caches,
        "attn_metadata": attn_metadata,
    }
    if self.vision_language_config:
        execute_model_kwargs["image_input"] = multi_modal_input
    hidden_states = model_executable(**execute_model_kwargs)

    # Compute the logits.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)

    # Only perform sampling in the driver worker. Checking before the server
    # round-trip keeps non-sampling workers from sending their logits (which
    # may be None) to the server. `getattr` defaults to sampling, so
    # metadata without the flag behaves as before.
    if not getattr(sampling_metadata, "perform_sampling", True):
        return None

    # Let the external server combine/rewrite the logits before sampling.
    logits = _exchange_logits_with_server(
        self.model_config.model,
        logits,
        host=socket.gethostname(),
        port=_LOGITS_SERVER_PORT,
    )

    # Sample the next token.
    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )

# Install the patch: every vLLM ModelRunner instance now dispatches
# `execute_model` to the module-level function defined in this file.
from vllm.worker.model_runner import ModelRunner

ModelRunner.execute_model = execute_model

