package client

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync/atomic"
	"time"

	"golang.org/x/net/http2"
)

var (
	// ErrOutputTruncated signals that generation stopped at the max_tokens
	// limit (the server reported finish_reason "length"). ChatCompletion
	// returns it wrapped with %w, so detect it with errors.Is rather than ==.
	ErrOutputTruncated = errors.New("output truncated due to length limit")
)

// LLMClient sends chat completion requests to an OpenAI-compatible endpoint
// (vLLM/SGLang) and keeps running token-usage counters.
//
// Construct it with NewLLMClient. It embeds atomic counters, which must not be
// copied after first use, so always share an LLMClient by pointer.
// NOTE(review): the methods look safe for concurrent use (http.Client is
// concurrency-safe and the counters are atomic), provided nothing else in the
// package mutates the unexported fields after construction — confirm.
type LLMClient struct {
	httpClient *http.Client
	baseURL    string
	apiKey     string
	model      string

	// Throughput tracking: incremented by ChatCompletion only after a
	// completion is returned successfully (not truncated), read by GetStats.
	finishedRequests      atomic.Int64
	totalPromptTokens     atomic.Int64
	totalCompletionTokens atomic.Int64
}

// LLMConfig holds the settings NewLLMClient uses to build an LLMClient.
//
//   - BaseURL: root URL of the server. The client appends
//     "/v1/chat/completions" itself, so do not include "/v1" here.
//   - APIKey: sent as an "Authorization: Bearer <APIKey>" header when non-empty.
//   - Model: value of the "model" field in every request.
//   - TimeoutSeconds: end-to-end time limit for each HTTP request (connect,
//     headers and body). Zero or negative selects the 120-second default.
type LLMConfig struct {
	BaseURL        string
	APIKey         string
	Model          string
	TimeoutSeconds int
}

// ChatMessage is a single turn in an OpenAI-style chat conversation.
// Role names the speaker (the helpers in this file send "system" and "user").
// Content holds the turn's text.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatCompletionRequest is the JSON body POSTed to /v1/chat/completions.
//
// The field order here is the key order of the marshaled JSON. Stream has no
// omitempty, so it is always sent. ChatCompletion sets it to false and decodes
// the reply as a single JSON object.
//
// MaxTokens, Temperature and TopP use omitempty, so a zero value is dropped
// from the JSON entirely. Presumably the server then applies its own default
// (confirm for the target server).
// NOTE(review): as a consequence, Temperature: 0 (greedy decoding) cannot be
// expressed with this struct. A *float64 field would be needed for that.
type ChatCompletionRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	MaxTokens   int           `json:"max_tokens,omitempty"`
	Temperature float64       `json:"temperature,omitempty"`
	TopP        float64       `json:"top_p,omitempty"`
	Stream      bool          `json:"stream"`
}

// ChatCompletionResponse is the shape of a non-streaming chat completion reply
// as decoded by this client. JSON fields not listed here are ignored by
// encoding/json.
//
// ChatCompletion uses only Choices[0].Message.Content and
// Choices[0].FinishReason (it treats "length" as truncation). It also reads
// the Usage prompt and completion token counts for its throughput counters.
type ChatCompletionResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index   int `json:"index"`
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

// NewLLMClient returns an LLMClient for the vLLM/SGLang server described by cfg.
//
// The HTTP transport is tuned for heavy concurrency against a single host
// (around a thousand in-flight requests). It uses large idle pools, no
// per-host connection cap, TLS session resumption, and HTTP/2 with ping-based
// health checks. Each request is limited end-to-end by http.Client.Timeout.
// That timeout is cfg.TimeoutSeconds seconds, or 120 seconds when the field is
// zero or negative.
func NewLLMClient(cfg LLMConfig) *LLMClient {
	requestTimeout := 120 * time.Second
	if secs := cfg.TimeoutSeconds; secs > 0 {
		requestTimeout = time.Duration(secs) * time.Second
	}

	transport := &http.Transport{
		// Dialing: bounded connect time, TCP keep-alive probes on open sockets.
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		ForceAttemptHTTP2: true,
		DisableKeepAlives: false,

		// TLS 1.2+ with a shared session cache, so reconnects can resume a
		// session instead of doing a full handshake.
		TLSClientConfig: &tls.Config{
			MinVersion:         tls.VersionTLS12,
			ClientSessionCache: tls.NewLRUClientSessionCache(1024),
		},
		TLSHandshakeTimeout: 10 * time.Second,

		// Pool sizing for ~1000 concurrent requests against one host.
		// MaxConnsPerHost 0 means no per-host cap.
		MaxIdleConns:        1500,
		MaxIdleConnsPerHost: 1200,
		MaxConnsPerHost:     0,
		IdleConnTimeout:     90 * time.Second,

		// No separate deadline for response headers. The client-wide Timeout
		// set below still bounds the whole exchange.
		ResponseHeaderTimeout: 0,
		ExpectContinueTimeout: 1 * time.Second,

		// Do not add Accept-Encoding: gzip; bodies arrive uncompressed.
		DisableCompression: true,

		ReadBufferSize:  64 * 1024,
		WriteBufferSize: 64 * 1024,
	}

	// ConfigureTransports enables HTTP/2 on transport and returns the h2 layer
	// so its connection health checks can be tuned. On failure, transport is
	// used exactly as configured above.
	h2, err := http2.ConfigureTransports(transport)
	if err == nil && h2 != nil {
		h2.ReadIdleTimeout = 30 * time.Second // health-check ping after 30s without frames
		h2.PingTimeout = 15 * time.Second     // close the conn if that ping is unanswered in 15s
		h2.AllowHTTP = false                  // no cleartext h2c
	}

	return &LLMClient{
		httpClient: &http.Client{Timeout: requestTimeout, Transport: transport},
		baseURL:    cfg.BaseURL,
		apiKey:     cfg.APIKey,
		model:      cfg.Model,
	}
}

// Generate sends prompt as a single user message and returns the model's
// reply. It is shorthand for ChatCompletion with a one-message conversation,
// so it has the same sampling settings and error behavior.
func (c *LLMClient) Generate(ctx context.Context, prompt string, maxTokens int) (string, error) {
	return c.ChatCompletion(ctx, []ChatMessage{{Role: "user", Content: prompt}}, maxTokens)
}

// ChatCompletion sends messages to the server's /v1/chat/completions endpoint
// as one non-streaming request and returns the content of the first choice.
//
// Sampling is fixed at temperature 0.7 and top_p 0.9. maxTokens is forwarded
// as max_tokens and left out of the request when it is 0.
//
// Errors:
//   - Transport failures, non-200 statuses and non-JSON bodies are returned
//     as wrapped errors. Error messages carry a bounded, UTF-8-safe excerpt
//     of the response body.
//   - A response without choices is an error.
//   - finish_reason "length" yields an error wrapping ErrOutputTruncated
//     (check with errors.Is). The partial output is discarded.
//
// The client's throughput counters advance only when a completion is returned
// successfully.
func (c *LLMClient) ChatCompletion(ctx context.Context, messages []ChatMessage, maxTokens int) (string, error) {
	// Body excerpts embedded in errors are capped so that a large error page
	// (for example HTML from a proxy) does not flood logs.
	const (
		statusErrorPreview = 2048
		parseErrorPreview  = 500
	)
	preview := func(b []byte, limit int) string {
		if len(b) <= limit {
			return string(b)
		}
		// A byte-offset cut can split a multi-byte rune; drop the dangling
		// bytes so the error message stays valid UTF-8.
		return strings.ToValidUTF8(string(b[:limit]), "") + "..."
	}

	reqBody, err := json.Marshal(ChatCompletionRequest{
		Model:       c.model,
		Messages:    messages,
		MaxTokens:   maxTokens,
		Temperature: 0.7,
		TopP:        0.9,
		Stream:      false,
	})
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	// Tolerate a base URL configured with a trailing slash, which would
	// otherwise produce a "//v1/..." request path.
	endpoint := strings.TrimRight(c.baseURL, "/") + "/v1/chat/completions"

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if c.apiKey != "" {
		httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
	}

	resp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	// Read the full body before checking the status. Error responses then
	// still drain the connection, and their text is available for the message.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("LLM API error (status %d): %s", resp.StatusCode, preview(body, statusErrorPreview))
	}

	var chatResp ChatCompletionResponse
	if err := json.Unmarshal(body, &chatResp); err != nil {
		return "", fmt.Errorf("failed to parse response (not valid JSON): %w\nResponse preview: %s", err, preview(body, parseErrorPreview))
	}

	if len(chatResp.Choices) == 0 {
		return "", errors.New("no choices in response")
	}
	choice := chatResp.Choices[0]

	// The server stopped at max_tokens: report it rather than return a
	// silently incomplete answer.
	if choice.FinishReason == "length" {
		return "", fmt.Errorf("%w: finish_reason is 'length', increase max_tokens or reduce input size", ErrOutputTruncated)
	}

	c.finishedRequests.Add(1)
	c.totalPromptTokens.Add(int64(chatResp.Usage.PromptTokens))
	c.totalCompletionTokens.Add(int64(chatResp.Usage.CompletionTokens))

	return choice.Message.Content, nil
}

// GenerateWithSystem sends a two-turn conversation and returns the model's
// reply via ChatCompletion. systemPrompt is the "system" turn and userPrompt
// is the "user" turn that follows it. The system turn is sent even when
// systemPrompt is empty.
func (c *LLMClient) GenerateWithSystem(ctx context.Context, systemPrompt, userPrompt string, maxTokens int) (string, error) {
	conversation := make([]ChatMessage, 0, 2)
	conversation = append(conversation,
		ChatMessage{Role: "system", Content: systemPrompt},
		ChatMessage{Role: "user", Content: userPrompt},
	)
	return c.ChatCompletion(ctx, conversation, maxTokens)
}

// LLMStats holds a copy of an LLMClient's throughput counters, as returned by
// GetStats. The counters advance only when ChatCompletion returns a completion
// successfully; truncated or failed requests are not counted. Token counts
// come from the usage figures reported by the server.
//
// TotalTokens is TotalPromptTokens + TotalCompletionTokens.
type LLMStats struct {
	FinishedRequests      int64
	TotalPromptTokens     int64
	TotalCompletionTokens int64
	TotalTokens           int64
}

// GetStats returns a copy of the client's throughput counters. Each counter is
// loaded atomically but independently of the others. While requests are in
// flight, the fields may therefore reflect slightly different moments.
// TotalTokens is always the sum of the returned prompt and completion counts.
func (c *LLMClient) GetStats() LLMStats {
	stats := LLMStats{
		FinishedRequests:      c.finishedRequests.Load(),
		TotalPromptTokens:     c.totalPromptTokens.Load(),
		TotalCompletionTokens: c.totalCompletionTokens.Load(),
	}
	stats.TotalTokens = stats.TotalPromptTokens + stats.TotalCompletionTokens
	return stats
}
