package main import ( "bufio" "bytes" "compress/gzip" "fmt" "html" "io" "log" "net/http" "os" "regexp" "strconv" "strings" "github.com/gobwas/glob" "github.com/hashicorp/go-version" "gopkg.in/yaml.v3" ) // Config file structure. type Config struct { HttpPort uint `yaml:"hddp_port"` HttpBind string `yaml:"http_bind"` MaxPythonVersion string `yaml:"max_python_version"` PackageVersionLimits map[string]string `yaml:"package_version_limits"` } var config *Config // Handle http proxy request. func handleRequest(rw http.ResponseWriter, req *http.Request) { // Replace the scheme and host to pypi.org. url := req.URL url.Scheme = "https" url.Host = "pypi.org" // Create a new request based on the original outreq, err := http.NewRequest(req.Method, url.String(), req.Body) if err != nil { http.Error(rw, err.Error(), http.StatusInternalServerError) return } // Copy headers from original request, and correct the host header. outreq.Header = req.Header outreq.Header.Set("Host", url.Host) // Make sure we close the request body after the function call is done. if outreq.Body != nil { defer outreq.Body.Close() } // Send the request. res, err := http.DefaultTransport.RoundTrip(outreq) if err != nil { http.Error(rw, err.Error(), http.StatusServiceUnavailable) return } defer res.Body.Close() // Copy headers received to requester. for k, vv := range res.Header { for _, v := range vv { rw.Header().Add(k, v) } } // Log the request. log.Println(req.Method, url, res.StatusCode) // Verify this is `/simple/package/` request. pathS := strings.Split(req.URL.Path, "/") if len(pathS) == 4 { // Get the package being requested, and check if there is a version constraint on it. pkg := pathS[2] pkgVerMax, hasVerConstraint := config.PackageVersionLimits[pkg] var pkgVer *version.Version if hasVerConstraint { pkgVer, _ = version.NewVersion(pkgVerMax) } // Buffer to store the modified response. bodyBuff := new(bytes.Buffer) // Determine body, gzip encoded or plain text. var bodyReader io.Reader = res.Body gzipEncoded := res.Header.Get("Content-Encoding") == "gzip" if gzipEncoded { bodyReader, err = gzip.NewReader(res.Body) } // Setup scanner and matching variables. scanner := bufio.NewScanner(bodyReader) constraintsRx := regexp.MustCompile(`data-requires-python="([^"]+)"`) pkgVersionRx := regexp.MustCompile(`]+>[^<]+-([0-9]+.[0-9]+.[0-9]+)[-.][^<]+`) pyVersion, _ := version.NewVersion(config.MaxPythonVersion) // Scan each line, and apply version constraints. for scanner.Scan() { line := scanner.Text() // Check if this has the python version constraints. matches := constraintsRx.FindAllStringSubmatch(line, 1) if len(matches) == 1 { // Compare and skip if the constraints are not matched. rules := strings.Split(html.UnescapeString(matches[0][1]), ",") versionMatch := true for _, rule := range rules { var cmpB strings.Builder var verB strings.Builder wildCard := false for _, b := range rule { if b == '<' || b == '>' || b == '=' || b == '!' || b == '~' { cmpB.WriteRune(b) } else if b == ' ' { continue } else { if b == '*' { wildCard = true } verB.WriteRune(b) } } var g glob.Glob var ver *version.Version cmp := cmpB.String() if wildCard && cmp == "==" || cmp == "!=" { g = glob.MustCompile(verB.String()) } else { verS := verB.String() if wildCard { verS = strings.ReplaceAll(verS, "*", "0") } ver, err = version.NewVersion(verS) if err != nil { log.Println(err) } } switch cmp { case "==": if wildCard { if !g.Match(config.MaxPythonVersion) { versionMatch = false } } else if !pyVersion.Equal(ver) { versionMatch = false } case "!=": if wildCard { if g.Match(config.MaxPythonVersion) { versionMatch = false } } else if pyVersion.Equal(ver) { versionMatch = false } case "=~": g = glob.MustCompile(fmt.Sprintf("%s.*", config.MaxPythonVersion)) if !g.Match(ver.String()) { versionMatch = false } case ">=": if !pyVersion.GreaterThanOrEqual(ver) { versionMatch = false } case ">": if !pyVersion.GreaterThan(ver) { versionMatch = false } case "<=": if !pyVersion.LessThanOrEqual(ver) { versionMatch = false } case "<": if !pyVersion.LessThan(ver) { versionMatch = false } } } if !versionMatch { continue } } // If there is a package version constraint, compare the package version and skip if not met. if hasVerConstraint { matches = pkgVersionRx.FindAllStringSubmatch(line, 1) if len(matches) == 1 { ver, err := version.NewVersion(matches[0][1]) if err != nil { log.Println(err) } else { if pkgVer.LessThan(ver) { continue } } } } // Write the read line. fmt.Fprintln(bodyBuff, line) } // Get the modified data and compress with gzip if needed. bodyBytes := bodyBuff.Bytes() if gzipEncoded { bodyBuff = new(bytes.Buffer) gzipWriter := gzip.NewWriter(bodyBuff) gzipWriter.Write(bodyBytes) gzipWriter.Close() bodyBytes = bodyBuff.Bytes() } // Update the content length. rw.Header().Set("Content-Length", strconv.Itoa(len(bodyBytes))) // Send headers. rw.WriteHeader(res.StatusCode) // Send the modified body. rw.Write(bodyBytes) } else { // Just copy the body as there is nothing to limit here. rw.WriteHeader(res.StatusCode) io.Copy(rw, res.Body) } } func main() { // Read the yaml configuration. yamlD, err := os.ReadFile("config.yaml") if err != nil { log.Fatal(err) } config = &Config{ HttpPort: 8080, } err = yaml.Unmarshal(yamlD, config) if err != nil { log.Fatal(err) } // Start the server. bindAddr := fmt.Sprintf("%s:%d", config.HttpBind, config.HttpPort) log.Println("Starting proxy server on", bindAddr) err = http.ListenAndServe(bindAddr, http.HandlerFunc(handleRequest)) if err != nil { log.Fatal(err) } }