import socket
import os
import time
import argparse
import sys
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer

from multiprocessing import Process, Barrier, set_start_method

# Drop the handlers loguru has registered so far, then register two INFO-level
# sinks: stdout and a UTF-8 log file that rotates at 100 MB.
# NOTE(review): with the 'spawn' start method used by this script, each child
# process re-imports this module, so every client presumably opens its own
# handle on client_debug.txt -- confirm rotation behaves with several writers.
logger.remove()
logger.add(sys.stdout, level="INFO")
logger.add("client_debug.txt", level="INFO", rotation="100 MB", encoding="utf-8")

def _parse_request(payload):
    """Split a decoded server request into its three fields.

    The fields are separated by the literal ``------`` and are, in order,
    the character name, the list of available moves and the game context.

    Raises:
        ValueError: if the payload does not contain exactly three fields.
    """
    fields = payload.split("------")
    if len(fields) != 3:
        raise ValueError(
            f"Malformed request: expected 3 fields separated by '------', got {len(fields)}"
        )
    character, move_list, context_prompt = fields
    return character, move_list, context_prompt


def _build_messages(character, move_list, context_prompt):
    """Return the system/user chat messages for one request."""
    background = f"You are the best and most aggressive Street Fighter III 3rd strike player in the world. Your character is {character}. Your goal is to beat the other opponent."
    hint = "if you are far from opponent, use Move Closer and Fireball more often. If you are close to opponent or already move closer, try to use Punch and Kick more often. Megapunch, Hurricane, and other combinations uses more time but are more powerful. Use them when you are close to opponent and you are getting positive scores or winning. If you are getting negative scores or losing, try to Move away and use Kick."

    prompt = f"""
The moves you can use are:
{move_list}
----
Example 1:
Context:
You are very far from the opponent. Move closer to the opponent. Your opponent is on the left.
Your last action was Medium Punch. The opponent's last action was Medium Punch.
Your current score is 108.0. You are winning. Keep attacking the opponent.

Your Response:
- Move closer
- Move closer
- Low Kick

Example 2:
Context:
You are close to the opponent. You should attack him.
Your last action was High Punch. The opponent's last action was High Punch.
Your current score is 37.0. You are winning. Keep attacking the opponent.

Your Response:
- High Punch
- Low Punch
- Hurricane

Example 3:
Context:
You are very far from the opponent. Move closer to the opponent. Your opponent is on the left.
Your last action was Low. The opponent's last action was Medium Punch.
Your current score is -75.0. You are losing. Continue to attack the opponent but don't get hit.
To increase your score, move toward the opponent and attack the opponent. To prevent your score from decreasing, don't get hit by the opponent.

Your Response:
- Move Away
- Low Punch
- Fireball

Now you are provided the following context, give your response using the same format as in the example.
Context: 
{context_prompt}
"""

    return [
        {"role": "system", "content": background + hint},
        {"role": "user", "content": prompt + "\nYour Response:\n"}
    ]


def run_client(host, port, model_name, gpu_id, iterative, delay_list, barrier, use_fp8):
    """Load a vLLM model on one GPU and answer prompts received over TCP.

    The function loads the model, waits on ``barrier`` until every sibling
    client has loaded its own model, then loops forever: connect to
    ``host:port``, read requests, generate a response and send it back.
    When the connection drops or an error occurs the socket is closed and a
    new connection is attempted.

    Args:
        host: Server hostname or IP address.
        port: Server TCP port; also used as the tag in log messages.
        model_name: Model identifier passed to vLLM and to the tokenizer.
        gpu_id: Value written to ``CUDA_VISIBLE_DEVICES`` for this process.
        iterative: When > 0 (and ``gpu_id == 0``) an artificial delay is
            applied after each response; the delay index advances whenever
            the number of successful connections is a multiple of this value.
        delay_list: Non-empty list of multipliers; the sleep after a response
            is ``multiplier * generation_time``.
        barrier: Barrier shared by all client processes.
        use_fp8: Load the model with ``quantization="fp8"`` when true.

    Any exception that escapes the connection loop is logged as fatal and the
    barrier is aborted so sibling processes do not wait forever.
    """
    try:
        # Set before vllm is imported below (presumably so CUDA initialisation
        # only ever sees the requested device).
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        print(f"[Client-{port}] Using GPU {gpu_id}")

        from vllm import LLM, SamplingParams
        sampling_params = SamplingParams(max_tokens=100)

        if use_fp8:
            model = LLM(model_name, quantization="fp8")
            print(f"[Client-{port}] Loaded model with FP8 quantization.")
        else:
            model = LLM(model_name)
            print(f"[Client-{port}] Loaded model with default FP16 precision.")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Do not contact the server until every client has its model ready.
        barrier.wait()
        print(f"[Client-{port}] All models loaded. Connecting to server...")

        iteration_id = 0
        connection_count = 0

        while True:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                print(f"[Client-{port}] Trying to connect to {host}:{port} ...")
                s.connect((host, port))
                connection_count += 1
                logger.info(f"[Client-{port}] Successful connections: {connection_count}")
                print(f"[Client-{port}] Connected to server.")

                while True:
                    # NOTE(review): one recv() is assumed to carry exactly one
                    # complete request; no framing protocol is visible here.
                    data = s.recv(4096)
                    if not data:
                        print(f"[Client-{port}] Connection closed by server.")
                        break

                    character, move_list, context_prompt = _parse_request(data.decode())
                    print(f"[Client-{port}] Received:\n{character}\n{move_list}\n{context_prompt}")

                    messages = _build_messages(character, move_list, context_prompt)
                    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    # vLLM inference
                    print(f"[Client-{port}]text:", text)
                    start_time = time.time()
                    outputs = model.generate(text, sampling_params=sampling_params)
                    end_time = time.time()

                    generated = outputs[0].outputs[0].text
                    lasting_time = end_time - start_time

                    s.sendall(generated.encode())
                    print(f"[Client-{port}] Sent response.")
                    print(f"[Client-{port}] Generated: {generated}")
                    logger.info(f"client{gpu_id} send response: {generated}")

                    # Only the client on GPU 0 is artificially slowed down.
                    # NOTE(review): this check runs per response, so while
                    # connection_count is a multiple of `iterative` the delay
                    # index advances after every response -- confirm intended.
                    if gpu_id == 0 and iterative > 0:
                        if connection_count % iterative == 0:
                            iteration_id = (iteration_id + 1) % len(delay_list)
                        current_delay = delay_list[iteration_id]
                        if current_delay > 0.0:
                            print(f"[Client-{port}] Sleeping for {current_delay:.2f}x of generation time (index {iteration_id})")
                            time.sleep(current_delay * lasting_time)

            except OSError as e:
                # socket.error is an alias of OSError and ConnectionRefusedError
                # is a subclass, so this single clause covers both.
                print(f"[Client-{port}] Connection failed: {e}. Retrying in 4s...")
                time.sleep(4)
            except Exception as e:
                # Keep the client alive, but record the traceback in the log.
                logger.exception(f"[Client-{port}] Unexpected error inside connection loop: {e}")
                print(f"[Client-{port}] Unexpected error inside connection loop: {e}")
            finally:
                s.close()
                print(f"[Client-{port}] Socket closed. Reconnecting...")

    except Exception as e:
        logger.exception(f"[Client-{port}] Fatal error: {e}")
        print(f"[Client-{port}] Fatal error: {e}")
        # If this client dies before reaching the barrier, its siblings would
        # block in barrier.wait() forever; breaking the barrier makes them
        # fail with BrokenBarrierError instead of hanging.
        barrier.abort()

def parse_args():
    """Parse the client's command-line options from ``sys.argv``.

    All list-valued options are returned as raw comma-separated strings.

    Returns:
        argparse.Namespace with ``model_names``, ``host``, ``ports``,
        ``devices``, ``iterative``, ``delay_list`` and ``use_fp8``.
    """
    # (flag, add_argument keyword options), in the order shown by --help.
    option_specs = (
        ("--model_names", dict(type=str, required=True, help="Comma-separated model names, e.g. model1,model2")),
        ("--host", dict(type=str, default="localhost", help="Server hostname or IP")),
        ("--ports", dict(type=str, required=True, help="Comma-separated ports, e.g. 38001,38002")),
        ("--devices", dict(type=str, required=True, help="Comma-separated device ids, e.g. 0,1")),
        ("--iterative", dict(type=int, default=-1, help="Change delay after every N iterations")),
        ("--delay_list", dict(type=str, required=True, help="Comma-separated float delays, e.g. 1.0,0.5,0.3")),
        ("--use_fp8", dict(type=str, required=True, help="Comma-separated flags, e.g. 1,0 to control per-process fp8 usage")),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def main():
    """Parse the CLI options and run one client process per configured port.

    ``--model_names``, ``--ports``, ``--devices`` and ``--use_fp8`` are
    parallel comma-separated lists; entry *i* of each configures client *i*.
    All clients share ``--host``, ``--iterative``, ``--delay_list`` and a
    barrier sized to the number of clients.

    Raises:
        ValueError: if the parallel lists differ in length, or an entry of
            ``--ports``, ``--devices``, ``--delay_list`` or ``--use_fp8``
            cannot be converted to a number.
    """
    args = parse_args()

    # Strip so that "a, b" does not produce a model name with a leading space
    # (int() and float() below already tolerate surrounding whitespace).
    model_names = [name.strip() for name in args.model_names.split(",")]
    ports = [int(p) for p in args.ports.split(",")]
    devices = [int(d) for d in args.devices.split(",")]
    delay_list = [float(x) for x in args.delay_list.split(",")]
    use_fp8_list = [bool(int(x)) for x in args.use_fp8.split(",")]

    if not (len(ports) == len(devices) == len(model_names) == len(use_fp8_list)):
        raise ValueError("Number of ports, devices, model_names, and use_fp8 flags must all match.")

    barrier = Barrier(len(ports))

    processes = []
    for port, model_name, device, use_fp8 in zip(ports, model_names, devices, use_fp8_list):
        p = Process(
            target=run_client,
            args=(args.host, port, model_name, device, args.iterative, delay_list, barrier, use_fp8)
        )
        p.start()
        processes.append(p)

    try:
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        print("Interrupted by user. Exiting...")
        # Children are non-daemonic, so the interpreter would otherwise wait
        # for them at exit. Give each a short grace period to stop on its
        # own, then terminate whatever is still running.
        for p in processes:
            p.join(timeout=5)
            if p.is_alive():
                p.terminate()
                p.join()

if __name__ == "__main__":
    # Use the 'spawn' start method so every client process begins in a fresh
    # interpreter (run_client sets CUDA_VISIBLE_DEVICES itself before it
    # imports vllm). The start method is set here, under the guard, so it
    # only runs when the file is executed as a script.
    set_start_method('spawn')
    main()
