package tokenizer

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"sync/atomic"

	"github.com/vmihailenco/msgpack/v5"
)

// RustWorkerConfig holds configuration for the Rust tokenizer worker.
//
// NewRustTokenizerClient interprets the fields as follows:
//   - WorkerPath is the executable that is started.
//   - Model is passed as --model when non-empty.
//   - Workers is passed as --workers when > 0.
//   - ThreadsPerTokenizer is passed as --threads-per-tokenizer when > 0.
//   - MaxMessageSize is the per-message size limit enforced by the client;
//     a value <= 0 selects the 100MB default.
//
// NOTE(review): what --workers and --threads-per-tokenizer mean is defined by
// the Rust worker's CLI — confirm there.
type RustWorkerConfig struct {
	WorkerPath          string
	Model               string
	Workers             int
	ThreadsPerTokenizer int
	MaxMessageSize      int // Maximum message size in bytes (default: 100MB)
}

// RustTokenizerClient manages communication with the Rust tokenizer worker via IPC.
//
// Messages are length-prefixed MessagePack frames: a 4-byte big-endian body
// length followed by the MessagePack body. Requests are written to the
// worker's stdin; responses are read from its stdout.
//
// It follows an async message-queue pattern: producers call SendRequest to
// write requests, and a single harvester goroutine (started by
// NewRustTokenizerClient) reads responses and hands them to consumers through
// ReceiveResponse.
type RustTokenizerClient struct {
	cmd    *exec.Cmd
	stdin  io.WriteCloser // set to nil by CloseStdin after closing it
	stdout io.ReadCloser
	stderr io.ReadCloser

	// Async message queue pattern
	writeMu        sync.Mutex             // Protects stdin writes and the stdin field itself
	responseCh     chan *TokenizeResponse // Responses decoded by the harvester; closed when it exits
	closed         atomic.Bool
	harvesterWg    sync.WaitGroup
	maxMessageSize int // Maximum message size in bytes
}

// PRText represents a PR text to be tokenized.
//
// It is encoded with the msgpack keys given in the struct tags. RepoID,
// RepoName and PRID presumably identify the PR so that the matching
// TokenizedResult can be correlated with it (NOTE(review): confirm the worker
// echoes them back unchanged).
type PRText struct {
	RepoID   int64  `msgpack:"repo_id"`
	RepoName string `msgpack:"repo_name"`
	PRID     int64  `msgpack:"pr_id"`
	Text     string `msgpack:"text"`
}

// TokenizeRequest is a single request frame sent to the worker; one request
// may carry several PRs.
//
// SendRequest always sets Command to "tokenize". MaxTokens is forwarded as-is;
// presumably it is the per-PR token cap applied by the worker
// (NOTE(review): confirm semantics, e.g. whether 0 means "no limit").
type TokenizeRequest struct {
	Command   string   `msgpack:"command"`
	PRs       []PRText `msgpack:"prs"`
	MaxTokens int32    `msgpack:"max_tokens"`
}

// TokenizedResult represents the tokenization result for a single PR, as
// decoded from a worker response.
//
// NOTE(review): the following meanings are inferred from field names and
// must be confirmed against the worker: TokenCount presumably equals
// len(TokenIDs), ByteSize the size of the input text, and Discarded marks a
// PR the worker did not keep (e.g. because it exceeded MaxTokens).
type TokenizedResult struct {
	RepoID     int64   `msgpack:"repo_id"`
	RepoName   string  `msgpack:"repo_name"`
	PRID       int64   `msgpack:"pr_id"`
	TokenIDs   []int32 `msgpack:"token_ids"`
	TokenCount int32   `msgpack:"token_count"`
	ByteSize   int32   `msgpack:"byte_size"`
	Discarded  bool    `msgpack:"discarded"`
}

// TokenizeResponse is a single response frame decoded by the harvester; one
// response may carry results for several PRs.
//
// Error is a *string, so it can be nil. NOTE(review): the set of Status values
// and whether Error is populated only on failure are defined by the worker —
// confirm there.
type TokenizeResponse struct {
	Status  string            `msgpack:"status"`
	Results []TokenizedResult `msgpack:"results"`
	Error   *string           `msgpack:"error"`
}

// NewRustTokenizerClient starts the Rust tokenizer worker described by config
// and returns a client connected to its stdin and stdout.
//
// Optional CLI flags are passed only when set: --model when Model is
// non-empty, --workers and --threads-per-tokenizer when the corresponding
// values are > 0. A non-positive MaxMessageSize is replaced by a 100MB default.
//
// The worker's stderr is copied to this process's stderr, and a harvester
// goroutine is started that decodes responses from the worker's stdout into
// the channel read by ReceiveResponse. Callers must eventually call Close.
func NewRustTokenizerClient(config RustWorkerConfig) (*RustTokenizerClient, error) {
	var args []string
	if config.Model != "" {
		args = append(args, "--model", config.Model)
	}
	if config.Workers > 0 {
		args = append(args, "--workers", fmt.Sprintf("%d", config.Workers))
	}
	if config.ThreadsPerTokenizer > 0 {
		args = append(args, "--threads-per-tokenizer", fmt.Sprintf("%d", config.ThreadsPerTokenizer))
	}

	maxMsgSize := config.MaxMessageSize
	if maxMsgSize <= 0 {
		maxMsgSize = 100 * 1024 * 1024 // 100MB default
	}

	cmd := exec.Command(config.WorkerPath, args...)

	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
	}

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
	}

	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
	}

	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("failed to start Rust worker: %w", err)
	}

	client := &RustTokenizerClient{
		cmd:            cmd,
		stdin:          stdin,
		stdout:         stdout,
		stderr:         stderr,
		responseCh:     make(chan *TokenizeResponse, 1000), // Buffered channel for responses
		maxMessageSize: maxMsgSize,
	}

	// Forward the worker's diagnostics to our stderr (not stdout, which the
	// previous fmt.Printf-based loop wrote to). io.Copy returns once the pipe
	// is closed by worker exit or cmd.Wait; its error carries nothing useful
	// at that point, so it is deliberately ignored.
	go func() {
		_, _ = io.Copy(os.Stderr, stderr)
	}()

	// The harvester is the only reader of stdout; Close waits for it.
	client.harvesterWg.Add(1)
	go client.runHarvester()

	return client, nil
}

// writeMessage encodes msg with MessagePack and writes it to the worker's
// stdin as one frame: a 4-byte big-endian body length followed by the body.
//
// The caller must hold c.writeMu. A frame larger than PIPE_BUF is not written
// atomically by the OS, so it is the mutex — not the single Write call — that
// keeps frames from concurrent senders from interleaving.
//
// It returns an error if stdin has already been closed, if the encoded body
// exceeds c.maxMessageSize or what the 4-byte length prefix can represent, or
// if the write fails or is short.
func (c *RustTokenizerClient) writeMessage(msg interface{}) error {
	// CloseStdin sets stdin to nil; calling Write on a nil interface would panic.
	if c.stdin == nil {
		return fmt.Errorf("failed to write message: stdin is closed")
	}

	msgBytes, err := msgpack.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	if len(msgBytes) > c.maxMessageSize {
		return fmt.Errorf("message size %d bytes exceeds limit %d bytes", len(msgBytes), c.maxMessageSize)
	}

	// The length prefix is a uint32; a larger body would be silently truncated
	// and desynchronize the stream even if the configured limit allows it.
	const maxFrameBody = 1<<32 - 1
	if uint64(len(msgBytes)) > maxFrameBody {
		return fmt.Errorf("message size %d bytes exceeds 4-byte length prefix", len(msgBytes))
	}

	// Assemble [4-byte length][body] so the frame goes out in a single Write.
	totalLen := 4 + len(msgBytes)
	fullMsg := make([]byte, totalLen)
	binary.BigEndian.PutUint32(fullMsg[0:4], uint32(len(msgBytes)))
	copy(fullMsg[4:], msgBytes)

	n, err := c.stdin.Write(fullMsg)
	if err != nil {
		return fmt.Errorf("failed to write message: %w", err)
	}
	if n != totalLen {
		return fmt.Errorf("incomplete write: wrote %d bytes, expected %d", n, totalLen)
	}

	return nil
}

// readMessage reads one length-prefixed MessagePack frame from the worker's
// stdout (4-byte big-endian body length, then the body) and decodes it into msg.
//
// Underlying I/O errors are wrapped with %w, so callers can detect io.EOF with
// errors.Is. A frame whose declared length exceeds c.maxMessageSize is
// rejected before any body memory is allocated.
func (c *RustTokenizerClient) readMessage(msg interface{}) error {
	var lenBuf [4]byte
	if _, err := io.ReadFull(c.stdout, lenBuf[:]); err != nil {
		return fmt.Errorf("failed to read message length: %w", err)
	}
	msgLen := binary.BigEndian.Uint32(lenBuf[:])

	// Validate before allocating: a corrupt or desynchronized prefix could
	// otherwise request an allocation of up to 4GiB.
	if c.maxMessageSize > 0 && uint64(msgLen) > uint64(c.maxMessageSize) {
		return fmt.Errorf("message size %d bytes exceeds limit %d bytes", msgLen, c.maxMessageSize)
	}

	msgBuf := make([]byte, msgLen)
	if _, err := io.ReadFull(c.stdout, msgBuf); err != nil {
		return fmt.Errorf("failed to read message body: %w", err)
	}

	if err := msgpack.Unmarshal(msgBuf, msg); err != nil {
		return fmt.Errorf("failed to unmarshal message: %w", err)
	}

	return nil
}

// SendRequest writes a single "tokenize" request carrying prs to the worker.
// It does not wait for the result; responses arrive asynchronously through
// ReceiveResponse. PRs can be sent individually or in small batches.
//
// SendRequest is safe for concurrent use by multiple producer goroutines.
// It can still block while another goroutine holds the write lock, or while
// the pipe to the worker is full.
//
// It returns an error if the client is closed, if stdin has been closed by
// CloseStdin, or if encoding or writing the request fails.
func (c *RustTokenizerClient) SendRequest(prs []PRText, maxTokens int32) error {
	if c.closed.Load() {
		return fmt.Errorf("client is closed")
	}

	req := TokenizeRequest{
		Command:   "tokenize",
		PRs:       prs,
		MaxTokens: maxTokens,
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	// CloseStdin sets stdin to nil under writeMu; writing would then panic.
	if c.stdin == nil {
		return fmt.Errorf("failed to write request: stdin is closed")
	}

	if err := c.writeMessage(&req); err != nil {
		return fmt.Errorf("failed to write request: %w", err)
	}

	return nil
}

// ReceiveResponse blocks until the harvester goroutine delivers the next
// response from the Rust worker and returns it. It is called by consumers,
// not by the harvester. Once the harvester has exited and every buffered
// response has been consumed, it returns an error instead.
func (c *RustTokenizerClient) ReceiveResponse() (*TokenizeResponse, error) {
	if resp, ok := <-c.responseCh; ok {
		return resp, nil
	}
	return nil, fmt.Errorf("response channel closed")
}

// runHarvester is the single reader of the worker's stdout. It decodes
// responses one at a time and forwards them, in arrival order, to
// c.responseCh. On return it closes c.responseCh and marks harvesterWg done.
// It returns when the stream ends or when a read or decode error occurs.
//
// Sending to responseCh blocks while the buffer is full, so no response is
// ever dropped. Consumers must therefore keep calling ReceiveResponse until it
// reports the channel closed; otherwise the harvester, and Close (which waits
// for it), will stall.
func (c *RustTokenizerClient) runHarvester() {
	defer c.harvesterWg.Done()
	defer close(c.responseCh)

	for {
		var resp TokenizeResponse
		if err := c.readMessage(&resp); err != nil {
			// During Close, read errors are expected. A clean io.EOF means the
			// worker closed stdout (e.g. it exited after CloseStdin), which is
			// the normal end of the stream rather than a failure.
			if c.closed.Load() || errors.Is(err, io.EOF) {
				return
			}
			// A framing or decoding error leaves the stream unsynchronized,
			// so stop reading; consumers then observe the closed channel.
			fmt.Fprintf(os.Stderr, "[ERROR] Harvester failed to read response: %v\n", err)
			return
		}

		c.responseCh <- &resp
	}
}

// CloseStdin closes the worker's stdin to signal that no further requests
// will be sent, letting the Rust worker finish processing and send its
// remaining responses, which continue to arrive through ReceiveResponse.
// After a successful close, stdin is set to nil and later calls return nil.
func (c *RustTokenizerClient) CloseStdin() error {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	if c.stdin == nil {
		return nil
	}
	if err := c.stdin.Close(); err != nil {
		return fmt.Errorf("failed to close stdin: %w", err)
	}
	c.stdin = nil
	return nil
}

// Close shuts the client down in order:
//  1. Mark the client closed.
//  2. Close the worker's stdin, unless CloseStdin already did.
//  3. Wait for the harvester to finish reading the worker's stdout.
//  4. Reap the worker process.
//
// Only the first call does any work; later calls return nil. A non-zero exit
// status from the worker is not reported as an error. If closing stdin fails,
// Close still waits for the harvester and the process and then returns that
// error. Because it waits for the harvester, any responses that cannot fit in
// the response channel must be drained via ReceiveResponse for Close to return.
func (c *RustTokenizerClient) Close() error {
	if !c.closed.CompareAndSwap(false, true) {
		return nil // Already closed
	}

	var firstErr error

	// Hold writeMu so this neither races with CloseStdin (which nils the
	// field) nor closes the pipe in the middle of a SendRequest write.
	c.writeMu.Lock()
	if c.stdin != nil {
		if err := c.stdin.Close(); err != nil {
			firstErr = fmt.Errorf("failed to close stdin: %w", err)
		}
		c.stdin = nil
	}
	c.writeMu.Unlock()

	// Do not return early on a stdin error: closed is already set, so a retry
	// would be a no-op and the worker process would never be reaped.
	c.harvesterWg.Wait()

	if err := c.cmd.Wait(); err != nil {
		// Ignore exit errors as the process may have already exited.
		var exitErr *exec.ExitError
		if !errors.As(err, &exitErr) && firstErr == nil {
			firstErr = fmt.Errorf("failed to wait for worker: %w", err)
		}
	}

	return firstErr
}
