pip-proxy/main.go
2025-07-14 10:22:59 -05:00

253 lines
6.2 KiB
Go

package main
import (
"bufio"
"bytes"
"compress/gzip"
"fmt"
"html"
"io"
"log"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"github.com/gobwas/glob"
"github.com/hashicorp/go-version"
"gopkg.in/yaml.v3"
)
// Config mirrors the settings read from config.yaml.
type Config struct {
	// HttpPort is the TCP port the proxy listens on. The key was previously
	// misspelled "hddp_port", which made a configured "http_port" silently
	// fall back to the default.
	HttpPort uint `yaml:"http_port"`
	// HttpBind is the address the proxy binds to; empty means all interfaces.
	HttpBind string `yaml:"http_bind"`
	// MaxPythonVersion is the Python version that each file's
	// data-requires-python specifier is evaluated against.
	MaxPythonVersion string `yaml:"max_python_version"`
	// PackageVersionLimits maps a package name, as it appears in the
	// /simple/<package>/ path, to the highest version that is still served.
	PackageVersionLimits map[string]string `yaml:"package_version_limits"`
}
// config is the process-wide configuration. main assigns it before the server
// starts; handleRequest reads the version limits from it on each request.
var config *Config
// Handle http proxy request.
func handleRequest(rw http.ResponseWriter, req *http.Request) {
// Replace the scheme and host to pypi.org.
url := req.URL
url.Scheme = "https"
url.Host = "pypi.org"
// Create a new request based on the original
outreq, err := http.NewRequest(req.Method, url.String(), req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
// Copy headers from original request, and correct the host header.
outreq.Header = req.Header
outreq.Header.Set("Host", url.Host)
// Make sure we close the request body after the function call is done.
if outreq.Body != nil {
defer outreq.Body.Close()
}
// Send the request.
res, err := http.DefaultTransport.RoundTrip(outreq)
if err != nil {
http.Error(rw, err.Error(), http.StatusServiceUnavailable)
return
}
defer res.Body.Close()
// Copy headers received to requester.
for k, vv := range res.Header {
for _, v := range vv {
rw.Header().Add(k, v)
}
}
// Log the request.
log.Println(req.Method, url, res.StatusCode)
// Verify this is `/simple/package/` request.
pathS := strings.Split(req.URL.Path, "/")
if len(pathS) == 4 {
// Get the package being requested, and check if there is a version constraint on it.
pkg := pathS[2]
pkgVerMax, hasVerConstraint := config.PackageVersionLimits[pkg]
var pkgVer *version.Version
if hasVerConstraint {
pkgVer, _ = version.NewVersion(pkgVerMax)
}
// Buffer to store the modified response.
bodyBuff := new(bytes.Buffer)
// Determine body, gzip encoded or plain text.
var bodyReader io.Reader = res.Body
gzipEncoded := res.Header.Get("Content-Encoding") == "gzip"
if gzipEncoded {
bodyReader, err = gzip.NewReader(res.Body)
}
// Setup scanner and matching variables.
scanner := bufio.NewScanner(bodyReader)
constraintsRx := regexp.MustCompile(`data-requires-python="([^"]+)"`)
pkgVersionRx := regexp.MustCompile(`<a[^>]+>[^<]+-([0-9]+.[0-9]+.[0-9]+)[-.][^<]+</a>`)
pyVersion, _ := version.NewVersion(config.MaxPythonVersion)
// Scan each line, and apply version constraints.
for scanner.Scan() {
line := scanner.Text()
// Check if this has the python version constraints.
matches := constraintsRx.FindAllStringSubmatch(line, 1)
if len(matches) == 1 {
// Compare and skip if the constraints are not matched.
rules := strings.Split(html.UnescapeString(matches[0][1]), ",")
versionMatch := true
for _, rule := range rules {
var cmpB strings.Builder
var verB strings.Builder
wildCard := false
for _, b := range rule {
if b == '<' || b == '>' || b == '=' || b == '!' || b == '~' {
cmpB.WriteRune(b)
} else if b == ' ' {
continue
} else {
if b == '*' {
wildCard = true
}
verB.WriteRune(b)
}
}
var g glob.Glob
var ver *version.Version
cmp := cmpB.String()
if wildCard && cmp == "==" || cmp == "!=" {
g = glob.MustCompile(verB.String())
} else {
verS := verB.String()
if wildCard {
verS = strings.ReplaceAll(verS, "*", "0")
}
ver, err = version.NewVersion(verS)
if err != nil {
log.Println(err)
}
}
switch cmp {
case "==":
if wildCard {
if !g.Match(config.MaxPythonVersion) {
versionMatch = false
}
} else if !pyVersion.Equal(ver) {
versionMatch = false
}
case "!=":
if wildCard {
if g.Match(config.MaxPythonVersion) {
versionMatch = false
}
} else if pyVersion.Equal(ver) {
versionMatch = false
}
case "=~":
g = glob.MustCompile(fmt.Sprintf("%s.*", config.MaxPythonVersion))
if !g.Match(ver.String()) {
versionMatch = false
}
case ">=":
if !pyVersion.GreaterThanOrEqual(ver) {
versionMatch = false
}
case ">":
if !pyVersion.GreaterThan(ver) {
versionMatch = false
}
case "<=":
if !pyVersion.LessThanOrEqual(ver) {
versionMatch = false
}
case "<":
if !pyVersion.LessThan(ver) {
versionMatch = false
}
}
}
if !versionMatch {
continue
}
}
// If there is a package version constraint, compare the package version and skip if not met.
if hasVerConstraint {
matches = pkgVersionRx.FindAllStringSubmatch(line, 1)
if len(matches) == 1 {
ver, err := version.NewVersion(matches[0][1])
if err != nil {
log.Println(err)
} else {
if pkgVer.LessThan(ver) {
continue
}
}
}
}
// Write the read line.
fmt.Fprintln(bodyBuff, line)
}
// Get the modified data and compress with gzip if needed.
bodyBytes := bodyBuff.Bytes()
if gzipEncoded {
bodyBuff = new(bytes.Buffer)
gzipWriter := gzip.NewWriter(bodyBuff)
gzipWriter.Write(bodyBytes)
gzipWriter.Close()
bodyBytes = bodyBuff.Bytes()
}
// Update the content length.
rw.Header().Set("Content-Length", strconv.Itoa(len(bodyBytes)))
// Send headers.
rw.WriteHeader(res.StatusCode)
// Send the modified body.
rw.Write(bodyBytes)
} else {
// Just copy the body as there is nothing to limit here.
rw.WriteHeader(res.StatusCode)
io.Copy(rw, res.Body)
}
}
// main loads config.yaml from the working directory on top of the built-in
// defaults and then serves the proxy until the listener fails.
func main() {
	raw, err := os.ReadFile("config.yaml")
	if err != nil {
		log.Fatal(err)
	}

	// Defaults first; values present in the file override them.
	config = &Config{HttpPort: 8080}
	if err := yaml.Unmarshal(raw, config); err != nil {
		log.Fatal(err)
	}

	addr := fmt.Sprintf("%s:%d", config.HttpBind, config.HttpPort)
	log.Println("Starting proxy server on", addr)

	proxy := http.HandlerFunc(handleRequest)
	if err := http.ListenAndServe(addr, proxy); err != nil {
		log.Fatal(err)
	}
}