package main

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"
)

// Defaults and limits for the proxy.
//
// The default* values seed the command-line flags in main (each can also be
// overridden through an environment variable there), defaultUserAgent is sent
// as the User-Agent header on every upstream request, and maxBodyBytes caps
// how much of an upstream response body is read into memory.
const (
	defaultListenAddr      = ":8080"
	defaultGitHubAPIBase   = "https://api.github.com"
	defaultUserAgent       = "RepoSurvey-GitHub-API-Proxy/1.0"
	defaultCacheMaxItems   = 100_000
	defaultCacheTTLSeconds = 30 * 24 * 60 * 60 // 30 days (best effort)
	maxBodyBytes           = 32 << 20          // 32MB
)

// main reads configuration from command-line flags (each defaulting to an
// environment variable, then to a built-in default), builds the caching proxy
// and serves it until the listener fails.
//
// Routes:
//   - /health answers with a static {"ok":true} document.
//   - everything else is handled by the proxy.
//
// The process exits immediately when no GitHub token is configured.
func main() {
	var (
		listen          string
		apiBase         string
		token           string
		cacheMax        int
		cacheTTLSeconds int
	)

	flag.StringVar(&listen, "listen", getEnv("PROXY_LISTEN", defaultListenAddr), "listen address")
	flag.StringVar(&apiBase, "github-api-base", getEnv("GITHUB_API_BASE", defaultGitHubAPIBase), "GitHub API base URL")
	flag.StringVar(&token, "token", getEnv("GITHUB_TOKEN", ""), "GitHub token (required)")
	flag.IntVar(&cacheMax, "cache-max", getEnvInt("PROXY_CACHE_MAX", defaultCacheMaxItems), "max LRU cache entries")
	flag.IntVar(&cacheTTLSeconds, "cache-ttl-seconds", getEnvInt("PROXY_CACHE_TTL_SECONDS", defaultCacheTTLSeconds), "cache TTL in seconds")
	flag.Parse()

	if token == "" {
		log.Fatalf("GITHUB_TOKEN is required")
	}

	proxy := &server{
		client: &http.Client{Timeout: 60 * time.Second},
		// Trailing slashes are removed so that base + request path never
		// produces a double slash.
		apiBase:  strings.TrimRight(apiBase, "/"),
		token:    token,
		cache:    newLRUCache(cacheMax),
		cacheTTL: time.Duration(cacheTTLSeconds) * time.Second,
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"ok":true}`))
	})
	mux.Handle("/", proxy)

	// Use an explicit http.Server instead of http.ListenAndServe so that slow
	// or stalled clients cannot hold connections open forever.
	srv := &http.Server{
		Addr:              listen,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		// Must comfortably exceed the 60s upstream client timeout plus the
		// time needed to stream a response of up to maxBodyBytes.
		WriteTimeout: 5 * time.Minute,
		IdleTimeout:  2 * time.Minute,
	}

	log.Printf("github_api_proxy: listen=%s github_api_base=%s cache_max=%d cache_ttl=%s", listen, proxy.apiBase, cacheMax, proxy.cacheTTL)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}

// server is the http.Handler that proxies GET requests to the GitHub API and
// caches the responses.
type server struct {
	// client performs the upstream HTTP requests.
	client *http.Client
	// apiBase is the upstream base URL that request paths are appended to.
	apiBase string
	// token is sent upstream in the Authorization header.
	token string
	// cache stores upstream responses keyed by a hash of the request URI.
	cache *lruCache
	// cacheTTL is the lifetime passed to the cache for every stored response.
	cacheTTL time.Duration
}

// ServeHTTP serves a GET request from the cache when an entry exists for the
// request URI; otherwise it fetches the response from upstream, stores it if
// its status is cacheable, and relays it to the caller.
//
// Non-GET methods get 405. Upstream failures get 502 with the error text,
// except when the caller's own context was canceled, in which case nothing
// is written.
func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The pipeline client only ever issues GET requests.
	if r.Method != http.MethodGet {
		writeJSONErr(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	// /health never reaches this handler; the mux routes it separately.
	uri := r.URL.RequestURI()
	key := s.cacheKey(uri)

	if hit, found := s.cache.Get(key); found {
		writeCached(w, hit)
		return
	}

	fresh, err := s.fetch(r.Context(), uri)
	if err != nil {
		// When the client has gone away there is nobody left to answer.
		if !errors.Is(err, context.Canceled) {
			writeJSONErr(w, http.StatusBadGateway, err.Error())
		}
		return
	}

	// Transient outcomes (rate limits, server errors) are relayed but not stored.
	if shouldCache(fresh.statusCode) {
		s.cache.Set(key, fresh, s.cacheTTL)
	}
	writeCached(w, fresh)
}

// fetch performs the upstream request for pathWithQuery and returns the
// response in cacheable form.
//
// Paths starting with /gql/pull_closing_issues are translated into a POST to
// the upstream /graphql endpoint; every other path is forwarded as a GET to
// apiBase + pathWithQuery. The response body is read up to maxBodyBytes and
// the status code is normalized by remapStatus.
//
// An error is returned when the GraphQL path is malformed, the request cannot
// be built or sent, or the body cannot be read within the size limit.
func (s *server) fetch(ctx context.Context, pathWithQuery string) (cachedResponse, error) {
	isGQL := strings.HasPrefix(pathWithQuery, "/gql/pull_closing_issues")

	// Defaults describe the plain REST passthrough.
	var (
		method  = http.MethodGet
		target  = s.apiBase + pathWithQuery
		accept  = "application/vnd.github.v3+json"
		reqBody io.Reader
	)
	if isGQL {
		payload, err := buildGraphQL(pathWithQuery)
		if err != nil {
			return cachedResponse{}, err
		}
		method = http.MethodPost
		target = s.apiBase + "/graphql"
		accept = "application/json"
		reqBody = bytes.NewBufferString(payload)
	}

	req, err := http.NewRequestWithContext(ctx, method, target, reqBody)
	if err != nil {
		return cachedResponse{}, err
	}
	if isGQL {
		req.Header.Set("Content-Type", "application/json")
	}
	req.Header.Set("Accept", accept)
	req.Header.Set("Authorization", "token "+s.token)
	req.Header.Set("User-Agent", defaultUserAgent)

	resp, err := s.client.Do(req)
	if err != nil {
		return cachedResponse{}, fmt.Errorf("upstream request failed: %w", err)
	}
	defer resp.Body.Close()

	payload, err := readUpTo(resp.Body, maxBodyBytes)
	if err != nil {
		return cachedResponse{}, err
	}

	code, retryAfter := remapStatus(isGQL, resp.StatusCode, resp.Header, payload)
	return cachedResponse{
		statusCode:        code,
		body:              payload,
		retryAfterSeconds: retryAfter,
		contentType:       resp.Header.Get("Content-Type"),
	}, nil
}

// shouldCache reports whether a response with the given status may be stored.
//
// Only 200 and the deterministic client errors 404, 410, 422 and 451 qualify;
// rate-limit and server-side statuses are never cached.
func shouldCache(statusCode int) bool {
	cacheable := [...]int{
		http.StatusOK,
		http.StatusNotFound,
		http.StatusGone,
		http.StatusUnprocessableEntity,
		http.StatusUnavailableForLegalReasons,
	}
	for _, code := range cacheable {
		if statusCode == code {
			return true
		}
	}
	return false
}

// writeCached writes cr to w: the stored Content-Type (application/json when
// none was recorded), a Retry-After header for 429/403 responses that carry a
// positive retry hint, then the status code and body.
func writeCached(w http.ResponseWriter, cr cachedResponse) {
	hdr := w.Header()

	contentType := cr.contentType
	if contentType == "" {
		contentType = "application/json"
	}
	hdr.Set("Content-Type", contentType)

	switch cr.statusCode {
	case http.StatusTooManyRequests, http.StatusForbidden:
		if cr.retryAfterSeconds > 0 {
			hdr.Set("Retry-After", strconv.Itoa(cr.retryAfterSeconds))
		}
	}

	w.WriteHeader(cr.statusCode)
	// A failed write means the client is gone; there is nothing to recover.
	_, _ = w.Write(cr.body)
}

// writeJSONErr writes a JSON error document of the form {"message": msg}
// with the given HTTP status.
//
// The message is encoded with encoding/json rather than fmt's %q: %q emits Go
// string syntax, whose escapes for control characters and invalid UTF-8
// (\x00, \a, \v, ...) are not valid JSON.
func writeJSONErr(w http.ResponseWriter, status int, msg string) {
	body, err := json.Marshal(struct {
		Message string `json:"message"`
	}{Message: msg})
	if err != nil {
		// Not expected for a plain string field; still answer with valid JSON.
		body = []byte(`{"message":"internal error"}`)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// A failed write means the client is gone; there is nothing to recover.
	_, _ = w.Write(body)
}

// cacheKey returns the hex-encoded SHA-256 digest of pathWithQuery, so every
// cache key has the same length regardless of how long the request URI is.
func (s *server) cacheKey(pathWithQuery string) string {
	h := sha256.New()
	// hash.Hash documents that Write never returns an error.
	_, _ = h.Write([]byte(pathWithQuery))
	return hex.EncodeToString(h.Sum(nil))
}

// readUpTo reads r to EOF and returns its contents, failing when r yields
// more than max bytes.
//
// At most max+1 bytes are consumed: the single extra byte is what reveals an
// oversized stream without buffering all of it.
func readUpTo(r io.Reader, max int64) ([]byte, error) {
	data, err := io.ReadAll(io.LimitReader(r, max+1))
	switch {
	case err != nil:
		return nil, err
	case int64(len(data)) > max:
		return nil, fmt.Errorf("response too large (%d bytes > %d)", len(data), max)
	}
	return data, nil
}

// buildGraphQL turns a proxy path of the form
// /gql/pull_closing_issues/{owner}/{repo}/{pr_number} into the JSON payload
// of the GraphQL query that lists the issues the pull request closes.
//
// The path may carry a query string (the caller passes the full request URI);
// it is ignored. Because the segments are interpolated into a GraphQL string
// inside a JSON document, they are validated first: owner and repo may only
// contain ASCII letters, digits, '.', '_' and '-', and pr_number must be a
// positive 32-bit integer. Anything else yields an error instead of a
// malformed or injectable payload.
func buildGraphQL(path string) (string, error) {
	// Drop the query string so it cannot leak into the last path segment.
	if i := strings.IndexByte(path, '?'); i >= 0 {
		path = path[:i]
	}

	parts := strings.Split(strings.Trim(path, "/"), "/")
	if len(parts) != 5 || parts[0] != "gql" || parts[1] != "pull_closing_issues" {
		// only /gql/pull_closing_issues/{owner}/{repo}/{pr_number} is supported
		return "", fmt.Errorf("invalid gql/pull_closing_issues path format, expected /gql/pull_closing_issues/{owner}/{repo}/{pr_number}")
	}

	owner, repo := parts[2], parts[3]
	if !isSafeGraphQLName(owner) || !isSafeGraphQLName(repo) {
		return "", fmt.Errorf("invalid owner or repository name in path %q", path)
	}

	// GraphQL Int is a signed 32-bit value; pull request numbers start at 1.
	prNumber, err := strconv.ParseInt(parts[4], 10, 32)
	if err != nil || prNumber <= 0 {
		return "", fmt.Errorf("invalid pull request number %q", parts[4])
	}

	// Keep query stable for caching.
	query := fmt.Sprintf(`{"query": "query { repository(owner: \"%s\", name: \"%s\") { pullRequest(number: %d) { closingIssuesReferences(first: 20) { nodes { number databaseId repository { databaseId name owner { login } } } } } } }"}`,
		owner, repo, prNumber)
	return query, nil
}

// isSafeGraphQLName reports whether s is non-empty and consists solely of
// ASCII letters, digits, '.', '_' and '-', i.e. needs no escaping when placed
// inside a quoted GraphQL/JSON string.
func isSafeGraphQLName(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9':
		case c == '.', c == '_', c == '-':
		default:
			return false
		}
	}
	return true
}

// remapStatus normalizes an upstream status code for the proxy's clients and
// computes how long a rate-limited caller should wait.
//
// Status mapping:
//   - 403 whose body contains "Repository access blocked" => 404, not rate-limited.
//   - any other 403/429 => 403 when X-RateLimit-Remaining is "0", else 429.
//   - GraphQL 200 whose body contains "errors" => 403 (rate-limited) when
//     X-RateLimit-Remaining is "0", else 404.
//   - everything else is passed through unchanged.
//
// retryAfterSeconds is 0 unless the result is rate-limited. Otherwise it is
// the first positive value among: the Retry-After header, the time until
// X-RateLimit-Reset (only when remaining is "0"), and a 60 second fallback.
func remapStatus(isGQL bool, upstreamStatus int, hdr http.Header, body []byte) (finalStatus int, retryAfterSeconds int) {
	quotaExhausted := hdr.Get("X-RateLimit-Remaining") == "0"

	status := upstreamStatus
	limited := false

	switch {
	case upstreamStatus == http.StatusForbidden && bytes.Contains(body, []byte("Repository access blocked")):
		// GitHub answers 403 for blocked repositories; treat them as missing.
		status = http.StatusNotFound

	case upstreamStatus == http.StatusForbidden || upstreamStatus == http.StatusTooManyRequests:
		// Mirrors the upstream reference logic: an exhausted quota is the
		// hourly limit (403), anything else is a secondary limit (429).
		limited = true
		if quotaExhausted {
			status = http.StatusForbidden
		} else {
			status = http.StatusTooManyRequests
		}

	case isGQL && upstreamStatus == http.StatusOK && bytes.Contains(body, []byte(`"errors"`)):
		// GraphQL reports failures with a 200; remap so they are not
		// mistaken for a successful response.
		if quotaExhausted {
			status = http.StatusForbidden
			limited = true
		} else {
			status = http.StatusNotFound
		}
	}

	if !limited {
		return status, 0
	}

	if secs := parseRetryAfterSeconds(hdr.Get("Retry-After")); secs > 0 {
		return status, secs
	}

	// Best effort: derive the wait from the quota reset time.
	if quotaExhausted {
		if resetUnix, _ := strconv.ParseInt(hdr.Get("X-RateLimit-Reset"), 10, 64); resetUnix > 0 {
			if wait := int(time.Until(time.Unix(resetUnix, 0)).Seconds()); wait > 0 {
				return status, wait
			}
		}
	}

	// Nothing usable upstream; fall back to a fixed delay.
	return status, 60
}

// parseRetryAfterSeconds converts a Retry-After header value into a number of
// seconds to wait.
//
// The value may be a non-negative integer of seconds or a date. Dates are
// accepted in any of the three HTTP-date formats understood by
// http.ParseTime, and also in time.RFC1123 with a zone abbreviation other
// than GMT. The result is 0 when the value is empty, unparsable, negative, or
// a date that is not in the future.
func parseRetryAfterSeconds(v string) int {
	v = strings.TrimSpace(v)
	if v == "" {
		return 0
	}

	if seconds, err := strconv.Atoi(v); err == nil {
		// A negative delay is meaningless; report "no hint" instead.
		if seconds < 0 {
			return 0
		}
		return seconds
	}

	at, err := http.ParseTime(v)
	if err != nil {
		// http.ParseTime insists on a literal "GMT"; stay lenient about
		// other zone abbreviations.
		at, err = time.Parse(time.RFC1123, v)
	}
	if err != nil {
		return 0
	}

	if wait := time.Until(at); wait > 0 {
		return int(wait.Seconds())
	}
	return 0
}

// getEnv returns the value of the environment variable key, or def when the
// variable is unset or empty.
func getEnv(key, def string) string {
	value := os.Getenv(key)
	if value == "" {
		return def
	}
	return value
}

// getEnvInt returns the environment variable key parsed as a decimal integer,
// or def when the variable is unset, empty, or not a valid integer.
func getEnvInt(key string, def int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return def
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		// Best effort: an unparsable value silently falls back to the default.
		return def
	}
	return parsed
}
