Solution 02 — forbid-pickle-loads pre-commit hook reference

Read only after completing ../lab/02-write-the-pre-commit-hooks.md and committing your attempt.

Reference implementation

scripts/precommit/forbid_pickle_loads.py:

"""Local pre-commit hook: forbid pickle.load(s) outside tests/.

We use safetensors for every checkpoint (see security/THREATS.md §pickle-load).
Allowing pickle.load on untrusted data is remote code execution. Tests may use
it for legacy round-trip checks; an explicit `# nosec safe-source: <reason>`
comment also allows it.
"""

from __future__ import annotations

import re
import sys
from pathlib import Path

PATTERN = re.compile(r"\bpickle\.loads?\s*\(")
NOSEC_RE = re.compile(r"#\s*nosec\s+safe-source:\s*\S+")


def check_file(path: Path) -> list[str]:
    """Return a list of `path:line:col: msg` findings."""
    findings: list[str] = []
    # pre-commit passes repo-relative paths, so parts[0] is the top-level directory.
    if path.parts and path.parts[0] == "tests":
        return findings
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return findings
    for lineno, line in enumerate(text.splitlines(), start=1):
        match = PATTERN.search(line)
        if not match:
            continue
        if NOSEC_RE.search(line):
            continue
        col = match.start() + 1
        findings.append(
            f"{path}:{lineno}:{col}: forbidden `{match.group(0)}` outside tests/. "
            f"Use safetensors for checkpoints, or add `# nosec safe-source: <reason>`."
        )
    return findings


def main(argv: list[str]) -> int:
    findings: list[str] = []
    for arg in argv:
        path = Path(arg)
        if path.suffix != ".py":
            continue
        findings.extend(check_file(path))
    for f in findings:
        print(f, file=sys.stderr)
    return 1 if findings else 0


if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))

Decisions made in this reference

  • Regex over AST: a line-by-line regex scan is faster, doesn't break on files that fail to parse, and is honest about what we're enforcing (a lexical policy, not a semantic one). Someone who really wants to bypass it can write getattr(pickle, "loads")(...), but at that point we've moved from "easy mistake" to "deliberate evasion", and bandit catches the latter.
  • pickle.loads? with \b: matches both load and loads. The \b prevents matching unpickle.loads(.
  • Allow tests/: round-trip tests for legacy formats are legitimate.
  • Allow # nosec safe-source: <reason>: an explicit allow that must carry a reason. The reason can be anything non-empty (a URL, a CVE id, a colleague's name), and the comment must sit on the same line as the call. A bare # nosec doesn't allow.
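  • encoding='utf-8': UTF-8 is Python's default source encoding (PEP 3120), so this covers nearly every file the hook sees. A file in any other encoding raises UnicodeDecodeError, which the next point handles.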
  • Decode errors don't fail the hook: if a file can't be read or decoded, we can't check it, so we skip it instead of blocking the commit. The hook deliberately fails open here; the other layers still apply.
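A quick way to check the regex decisions above is to run the reference patterns (copied verbatim from the script) against a few sample lines. The sample strings and the helper is_flagged are illustrative only; is_flagged mirrors the per-line logic of check_file without the file handling.

```python
import re

# Copied from scripts/precommit/forbid_pickle_loads.py.
PATTERN = re.compile(r"\bpickle\.loads?\s*\(")
NOSEC_RE = re.compile(r"#\s*nosec\s+safe-source:\s*\S+")


def is_flagged(line: str) -> bool:
    """Hypothetical helper: match the call, then honour the nosec escape."""
    return bool(PATTERN.search(line)) and not NOSEC_RE.search(line)


# Each sample maps to whether the hook should report it.
samples = {
    "obj = pickle.load(fh)": True,            # load( is matched
    "obj = pickle.loads (blob)": True,        # \s* tolerates a space before "("
    "obj = unpickle.loads(blob)": False,      # \b rejects the "un" prefix
    "import pickle": False,                   # no call, no finding
    "obj = pickle.loads(b)  # nosec": True,   # a bare nosec is not enough
    "obj = pickle.loads(b)  # nosec safe-source: fixture": False,
}

for line, expected in samples.items():
    assert is_flagged(line) is expected, line
    print(f"{'FLAG' if expected else 'ok  '}  {line}")
```

The last two rows are the behaviour the fixed test_allows_with_nosec should pin down: the reason after safe-source: is what turns the comment into an allow.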

.pre-commit-config.yaml wiring

  - repo: local
    hooks:
      - id: forbid-pickle-loads
        name: Forbid pickle.load(s) outside tests/
        entry: uv run python scripts/precommit/forbid_pickle_loads.py
        language: system
        types: [python]
        exclude: ^scripts/precommit/forbid_pickle_loads\.py$

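The exclude line prevents the hook from flagging itself: the first line of the script's docstring ("forbid pickle.load(s) outside tests/") matches the pattern. The regex literal does not, because in the source text pickle is followed by a backslash rather than a dot.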
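To see which line of the script actually trips the pattern, run it over the docstring's first line and over the line that defines PATTERN, both copied as they appear in the source file. This is a minimal check, not part of the hook.

```python
import re

PATTERN = re.compile(r"\bpickle\.loads?\s*\(")

# First line of the hook's docstring, as written in the source file.
docstring_line = '"""Local pre-commit hook: forbid pickle.load(s) outside tests/.'
# The line that defines PATTERN, as raw source text.
regex_line = r'PATTERN = re.compile(r"\bpickle\.loads?\s*\(")'

print(PATTERN.search(docstring_line))  # a match for "pickle.load("
print(PATTERN.search(regex_line))      # None: a backslash follows "pickle", not "."
```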

Reference test

tests/test_forbid_pickle_loads.py:

from pathlib import Path
from scripts.precommit.forbid_pickle_loads import main


def test_finds_pickle_load(tmp_path: Path) -> None:
    f = tmp_path / "bad.py"
    f.write_text("import pickle\npickle.load(open('x', 'rb'))\n")
    assert main([str(f)]) == 1


def test_allows_in_tests(tmp_path: Path, monkeypatch) -> None:
    tests_dir = tmp_path / "tests"
    tests_dir.mkdir()
    f = tests_dir / "test_legacy.py"
    f.write_text("import pickle\npickle.load(open('x', 'rb'))\n")
    monkeypatch.chdir(tmp_path)
    assert main(["tests/test_legacy.py"]) == 0


def test_allows_with_nosec(tmp_path: Path) -> None:
    f = tmp_path / "ok.py"
    f.write_text(
        "import pickle\n"
        "pickle.loads(b'')  # nosec safe-source: deterministic-fixture\n"
    )
    # The nosec safe-source comment on the same line exempts the call.
    assert main([str(f)]) == 0


def test_pickle_loads_variant(tmp_path: Path) -> None:
    f = tmp_path / "bad2.py"
    f.write_text("import pickle\nx = pickle.loads(b'')\n")
    assert main([str(f)]) == 1


def test_import_only_is_fine(tmp_path: Path) -> None:
    f = tmp_path / "ok2.py"
    f.write_text("import pickle\n# not calling load\n")
    assert main([str(f)]) == 0

Subtleties to compare against your version

  • Did you check path.parts[0] == "tests" (correct), or "tests" in path.parts (allows src/tests/... too — probably not what you want), or str(path).startswith("tests/") (cross-platform separator footgun)?
  • Did you write to stderr? Hook output to stdout is also displayed, but stderr is the convention for diagnostics.
  • Did you handle -- argument separators that pre-commit might pass? (It doesn't by default.)
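  • Did you make the exit code 0 when given zero files? pre-commit normally skips a hook when no files match, but a hook with always_run: true can be invoked with an empty file list, and it must not fail then.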
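The three path checks from the first question can be compared directly with pathlib's pure path classes, which let you simulate a Windows path on any OS. The three function names are hypothetical; only first_part_is_tests matches the reference implementation.

```python
from pathlib import PurePath, PurePosixPath, PureWindowsPath


def first_part_is_tests(p: PurePath) -> bool:
    # The reference check: only a top-level tests/ directory is exempt.
    return bool(p.parts) and p.parts[0] == "tests"


def any_part_is_tests(p: PurePath) -> bool:
    # Too broad: also exempts src/tests/... and any other nested tests/ directory.
    return "tests" in p.parts


def string_prefix_is_tests(p: PurePath) -> bool:
    # Separator footgun: str() of a Windows path uses backslashes.
    return str(p).startswith("tests/")


cases = [
    PurePosixPath("tests/test_legacy.py"),
    PurePosixPath("src/tests/helpers.py"),
    PureWindowsPath("tests/test_legacy.py"),
]

for p in cases:
    print(
        f"{type(p).__name__:16} {str(p):22} "
        f"parts[0]={first_part_is_tests(p)!s:5} "
        f"in parts={any_part_is_tests(p)!s:5} "
        f"startswith={string_prefix_is_tests(p)}"
    )
```

Only the parts[0] check gives the intended answer on all three rows: the membership test wrongly exempts src/tests/helpers.py, and the Windows path slips past startswith("tests/"). The parts[0] check relies on receiving repo-relative paths, which is what pre-commit passes.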

What this hook does not cover (intentional)

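  • Doesn't catch dynamic imports bound to another name, such as pk = importlib.import_module("pickle"); pk.loads(...). That's deliberate evasion; bandit handles it. (Binding the module to a variable literally named pickle is still caught, because the line then contains pickle.loads(.) Plain aliases such as import pickle as pk or from pickle import loads slip through too; bandit flags the pickle import in those cases as well.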
  • Doesn't catch cPickle.loads (Python 2). We're 3.11+.
  • Doesn't catch dill.loads, joblib.load, etc. Extend if you start using them — but at that point you're better off with bandit rules.
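Running the pattern over a few illustrative lines makes these lexical limits concrete: only the literal pickle.load( / pickle.loads( spelling is reported. The lines below are made up for the demonstration.

```python
import re

PATTERN = re.compile(r"\bpickle\.loads?\s*\(")

# Illustrative lines: which ones would the lexical hook report?
lines = [
    "data = pickle.loads(blob)",                      # direct call
    "import pickle as pk; data = pk.loads(blob)",     # aliased module
    "from pickle import loads; data = loads(blob)",   # imported name
    'data = getattr(pickle, "loads")(blob)',          # dynamic lookup
    "data = dill.loads(blob)",                        # other deserializer
    "model = joblib.load(path)",                      # other deserializer
]

for line in lines:
    verdict = "caught" if PATTERN.search(line) else "missed"
    print(f"{verdict}: {line}")
```

Only the first line is caught; everything else needs bandit or a deliberately extended pattern.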

The hook is one layer. bandit is another. safetensors-only checkpoints are the real fix.