Repository: https://github.com/groq/openbench
HEAD commit: dae937838b90ba39fb134daf694ea4bc3563508c
Total files: 190 · Rendered: 162 · Skipped: 28

Directory tree

repo
├── .claude
│   └── settings.json
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yml
│   │   └── feature_request.yml
│   ├── workflows
│   │   ├── ci.yml
│   │   ├── claude-code-review.yml
│   │   ├── claude.yml
│   │   ├── dependency-check.yml
│   │   ├── release-please.yaml
│   │   └── stale.yaml
│   ├── dependabot.yml
│   └── pull_request_template.md
├── .vscode
│   └── settings.json
├── src
│   └── openbench
│       ├── _cli
│       │   ├── __init__.py
│       │   ├── describe_command.py
│       │   ├── eval_command.py
│       │   ├── eval_retry_command.py
│       │   ├── export.py
│       │   ├── list_command.py
│       │   ├── utils.py
│       │   └── view_command.py
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── browsecomp.py
│       │   ├── cti_bench.py
│       │   ├── drop.py
│       │   ├── gpqa.py
│       │   ├── graphwalks.py
│       │   ├── healthbench.py
│       │   ├── hle.py
│       │   ├── humaneval.py
│       │   ├── jsonschemabench.py
│       │   ├── math.py
│       │   ├── mgsm.py
│       │   ├── mmlu.py
│       │   ├── mrcr.py
│       │   ├── rootly_gmcq.py
│       │   ├── scicode.py
│       │   └── simpleqa.py
│       ├── evals
│       │   ├── matharena
│       │   │   ├── aime_2023_I
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2023_I.py
│       │   │   ├── aime_2023_II
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2023_II.py
│       │   │   ├── aime_2024
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2024.py
│       │   │   ├── aime_2024_I
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2024_I.py
│       │   │   ├── aime_2024_II
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2024_II.py
│       │   │   ├── aime_2025
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2025.py
│       │   │   ├── aime_2025_II
│       │   │   │   ├── __init__.py
│       │   │   │   └── aime_2025_II.py
│       │   │   ├── brumo_2025
│       │   │   │   ├── __init__.py
│       │   │   │   └── brumo_2025.py
│       │   │   ├── hmmt_feb_2023
│       │   │   │   ├── __init__.py
│       │   │   │   └── hmmt_feb_2023.py
│       │   │   ├── hmmt_feb_2024
│       │   │   │   ├── __init__.py
│       │   │   │   └── hmmt_feb_2024.py
│       │   │   ├── hmmt_feb_2025
│       │   │   │   ├── __init__.py
│       │   │   │   └── hmmt_feb_2025.py
│       │   │   ├── __init__.py
│       │   │   └── matharena.py
│       │   ├── __init__.py
│       │   ├── browsecomp.py
│       │   ├── cti_bench.py
│       │   ├── drop.py
│       │   ├── gpqa_diamond.py
│       │   ├── graphwalks.py
│       │   ├── healthbench.py
│       │   ├── hle.py
│       │   ├── humaneval.py
│       │   ├── jsonschemabench.py
│       │   ├── math.py
│       │   ├── mgsm.py
│       │   ├── mmlu.py
│       │   ├── mrcr.py
│       │   ├── musr.py
│       │   ├── openbookqa.py
│       │   ├── rootly_gmcq.py
│       │   ├── scicode.py
│       │   ├── simpleqa.py
│       │   └── supergpqa.py
│       ├── metrics
│       │   ├── __init__.py
│       │   └── grouped.py
│       ├── model
│       │   ├── _providers
│       │   │   ├── __init__.py
│       │   │   ├── ai21.py
│       │   │   ├── baseten.py
│       │   │   ├── cerebras.py
│       │   │   ├── cohere.py
│       │   │   ├── crusoe.py
│       │   │   ├── deepinfra.py
│       │   │   ├── friendli.py
│       │   │   ├── huggingface.py
│       │   │   ├── hyperbolic.py
│       │   │   ├── lambda_ai.py
│       │   │   ├── minimax.py
│       │   │   ├── moonshot.py
│       │   │   ├── nebius.py
│       │   │   ├── nous.py
│       │   │   ├── novita.py
│       │   │   ├── parasail.py
│       │   │   ├── reka.py
│       │   │   ├── sambanova.py
│       │   │   └── vercel.py
│       │   └── __init__.py
│       ├── monkeypatch
│       │   ├── __init__.py
│       │   ├── display_results_patch.py
│       │   └── file_recorder_logfile_patch.py
│       ├── scorers
│       │   ├── __init__.py
│       │   ├── browsecomp.py
│       │   ├── cti_bench.py
│       │   ├── drop.py
│       │   ├── fallback_scorer.py
│       │   ├── graphwalks.py
│       │   ├── healthbench.py
│       │   ├── hle.py
│       │   ├── humaneval.py
│       │   ├── json_schema.py
│       │   ├── math.py
│       │   ├── mgsm.py
│       │   ├── mmlu.py
│       │   ├── mrcr.py
│       │   ├── musr.py
│       │   ├── robust_boxed.py
│       │   ├── robust_mcq.py
│       │   ├── rootly_gmcq.py
│       │   ├── scicode.py
│       │   ├── score_boxed.py
│       │   ├── score_last_number.py
│       │   └── simpleqa.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── imports.py
│       │   └── text.py
│       ├── __init__.py
│       ├── _registry.py
│       ├── config.py
│       ├── eval_config.py
│       └── py.typed
├── tests
│   ├── _cli
│   │   ├── __init__.py
│   │   ├── test_eval_command.py
│   │   └── test_eval_command_functions.py
│   ├── integration
│   │   ├── __init__.py
│   │   └── test_cli.py
│   ├── monkeypatch
│   │   └── test_file_recorder_logfile_patch.py
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_json_schema_scorer.py
│   ├── test_registry.py
│   └── test_robust_scorers.py
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── .release-please-manifest.json
├── CHANGELOG.md
├── CITATION.cff
├── CLAUDE.md
├── CODEOWNERS
├── CONTRIBUTING.md
├── LICENSE.md
├── MANIFEST.in
├── pyproject.toml
├── README.md
├── release-please-config.json
└── uv.lock


Skipped items

Skipped large files (1)
  • uv.lock (697.1 KiB)

.claude/settings.json (80 B)

{
  "permissions": {
    "allow": [
      "WebFetch"
    ],
    "deny": []
  }
}

.github/ISSUE_TEMPLATE/bug_report.yml (2.0 KiB)

name: Bug Report
description: File a bug report to help us improve OpenBench
title: "[Bug]: "
labels: ["bug"]
body:
  - type: markdown
    attributes:
      value: |
        Thanks for taking the time to fill out this bug report!

  - type: input
    id: version
    attributes:
      label: OpenBench Version
      description: What version of OpenBench are you running?
      placeholder: "0.3.0"
    validations:
      required: true

  - type: textarea
    id: what-happened
    attributes:
      label: What happened?
      description: A clear and concise description of what the bug is.
      placeholder: Tell us what you see!
    validations:
      required: true

  - type: textarea
    id: expected
    attributes:
      label: Expected behavior
      description: A clear and concise description of what you expected to happen.
    validations:
      required: true

  - type: textarea
    id: reproduce
    attributes:
      label: Steps to reproduce
      description: Steps to reproduce the behavior
      placeholder: |
        1. Run command '...'
        2. With arguments '...'
        3. See error
    validations:
      required: true

  - type: input
    id: command
    attributes:
      label: Command that failed
      description: The exact command you ran that caused the issue
      placeholder: "bench eval mmlu --model groq/llama-3.1-70b --limit 10"

  - type: textarea
    id: logs
    attributes:
      label: Error logs
      description: Please copy and paste any relevant log output or error messages
      render: shell

  - type: dropdown
    id: os
    attributes:
      label: Operating System
      options:
        - macOS
        - Linux
        - Windows
        - Other
    validations:
      required: true

  - type: input
    id: python-version
    attributes:
      label: Python Version
      placeholder: "3.11.5"
    validations:
      required: true

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: Add any other context about the problem here (environment, model provider, etc.)

.github/ISSUE_TEMPLATE/feature_request.yml (1.7 KiB)

name: Feature Request
description: Suggest an idea for OpenBench
title: "[Feature]: "
labels: ["enhancement"]
body:
  - type: markdown
    attributes:
      value: |
        Thanks for suggesting a new feature for OpenBench!

  - type: textarea
    id: problem
    attributes:
      label: Is your feature request related to a problem?
      description: A clear and concise description of what the problem is.
      placeholder: "I'm always frustrated when..."

  - type: textarea
    id: solution
    attributes:
      label: Describe the solution you'd like
      description: A clear and concise description of what you want to happen.
    validations:
      required: true

  - type: dropdown
    id: feature-type
    attributes:
      label: What are you requesting?
      options:
        - New benchmark/evaluation
        - New model provider
        - CLI enhancement
        - Performance improvement
        - Documentation
        - API/SDK feature
        - Integration (CI/CD, tools)
        - Export/import functionality
        - Other
    validations:
      required: true

  - type: textarea
    id: alternatives
    attributes:
      label: Describe alternatives you've considered
      description: A clear and concise description of any alternative solutions or features you've considered.

  - type: textarea
    id: use-case
    attributes:
      label: Use case
      description: Describe your specific use case and how this feature would help you or others.
    validations:
      required: true

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: Add any other context, screenshots, or examples about the feature request here.

.github/dependabot.yml (753 B)

version: 2
updates:
  # Enable version updates for GitHub Actions
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "weekly"
      day: "monday"
      time: "03:00"
    labels:
      - "dependencies"
      - "github-actions"
    open-pull-requests-limit: 10
    groups:
      actions:
        patterns:
          - "*"

  # Enable version updates for Python dependencies via pip
  - package-ecosystem: "pip"
    directory: "/"
    schedule:
      interval: "weekly"
      day: "monday"
      time: "03:00"
    labels:
      - "dependencies"
      - "python"
    open-pull-requests-limit: 10
    # Group all Python dependency updates together
    groups:
      python-dependencies:
        patterns:
          - "*"

.github/pull_request_template.md (1.4 KiB)

Summary

What are you adding?

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] New benchmark/evaluation
  • [ ] New model provider
  • [ ] CLI enhancement
  • [ ] Performance improvement
  • [ ] Documentation update
  • [ ] API/SDK feature
  • [ ] Integration (CI/CD, tools)
  • [ ] Export/import functionality
  • [ ] Code refactoring
  • [ ] Breaking change
  • [ ] Other

Changes Made

-

Testing

  • [ ] I have run the existing test suite (pytest)
  • [ ] I have added tests for my changes
  • [ ] I have tested with multiple model providers (if applicable)
  • [ ] I have run pre-commit hooks (pre-commit run --all-files)

Checklist

  • [ ] My code follows the project's style guidelines
  • [ ] I have performed a self-review of my own code
  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation (if applicable)
  • [ ] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

Closes #

Additional Context

.github/workflows/ci.yml (4.3 KiB)

name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

# Cancel in-progress runs when a new run is queued on the same branch
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

# Set environment variables for consistency
env:
  UV_CACHE_TTL_SECONDS: 604800  # 1 week
  UV_HTTP_TIMEOUT: 600  # 10 minutes

jobs:
  # Single job for linting and type checking to reduce overhead
  quality-checks:
    name: Quality Checks (Python ${{ matrix.python-version }})
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.12"]  # Use latest for quality checks
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "latest"
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}

      - name: Install dependencies
        run: uv sync --group dev

      - name: Run ruff format check
        run: uv run ruff format --check .

      - name: Run ruff lint
        run: uv run ruff check .

      - name: Run type checking
        run: uv run mypy .

  test:
    name: Tests (Python ${{ matrix.python-version }})
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10", "3.11", "3.12", "3.13"]
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "latest"
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}

      - name: Install dependencies
        run: uv sync --group dev

      - name: Run unit tests with coverage
        run: |
          uv run pytest -m "not integration" --cov=openbench --cov-report=term-missing

  integration-test:
    name: Integration Tests
    runs-on: ubuntu-latest
    if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "latest"
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python
        run: uv python install 3.12

      - name: Install dependencies
        run: uv sync --group dev

      - name: Run integration tests
        env:
          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
        run: |
          uv run pytest -m integration -v
        if: env.GROQ_API_KEY != ''

  # Security scanning with pip-audit
  security:
    name: Security Scan
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "latest"
          enable-cache: true

      - name: Set up Python
        run: uv python install 3.12

      - name: Install dependencies
        run: uv sync --group dev

      - name: Run security audit
        run: |
          uv pip install pip-audit
          uv run pip-audit --desc
        continue-on-error: true  # Don't fail the build on vulnerabilities

  # All checks must pass
  all-checks:
    name: All Checks Pass
    runs-on: ubuntu-latest
    needs: [quality-checks, test, integration-test, security]
    if: always()
    steps:
      - name: Verify all checks passed
        run: |
          # Quality checks and test must always pass
          if [[ "${{ needs.quality-checks.result }}" != "success" || 
                "${{ needs.test.result }}" != "success" ]]; then
            echo "One or more required checks failed"
            exit 1
          fi
          
          # Integration test can be skipped (for external PRs) or must succeed
          if [[ "${{ needs.integration-test.result }}" != "success" && 
                "${{ needs.integration-test.result }}" != "skipped" ]]; then
            echo "Integration tests failed"
            exit 1
          fi
          
          # Security is continue-on-error, so we don't check it
          echo "All required checks passed!"

.github/workflows/claude-code-review.yml (3.5 KiB)

name: Claude Code Review

on:
  pull_request:
    types: [opened, synchronize]
    # Optional: Only run on specific file changes
    # paths:
    #   - "src/**/*.ts"
    #   - "src/**/*.tsx"
    #   - "src/**/*.js"
    #   - "src/**/*.jsx"

jobs:
  claude-review:
    # Optional: Filter by PR author
    # if: |
    #   github.event.pull_request.user.login == 'external-contributor' ||
    #   github.event.pull_request.user.login == 'new-developer' ||
    #   github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
    
    runs-on: ubuntu-latest
    continue-on-error: true
    permissions:
      contents: read
      pull-requests: read
      issues: read
      id-token: write
    
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@beta
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          
          # Optional: Specify model (defaults to Claude Sonnet 4, uncomment for Claude Opus 4.1)
          model: "claude-opus-4-1-20250805"
          
          # Direct prompt for automated review (no @claude mention needed)
          direct_prompt: |
            Review this PR like a chill staff engineer. Focus on NEW changes only and think DRY.

            ONLY bring up things if they're actually important - you don't need to comment on every category or find something to say. Quality over quantity.

            PRIORITY ORDER:
            - Critical bugs, security issues, performance problems  
            - Code quality: proper separation of concerns, clear naming, best practices
            - Missing tests for new functionality
            - Suggestions for cleaner patterns (only if they meaningfully improve the code)

            Be constructive and EXTREMELY concise. Think "would a 10x engineer actually care about this?" If it's not blocking or genuinely helpful, skip it entirely. Respect separation of concerns.

            Stay chill: flag real issues, suggest improvements where they add value, but don't nitpick or over-engineer. Sometimes the best review is a LGTM. ultrathink.
          
          # Optional: Customize review based on file types
          # direct_prompt: |
          #   Review this PR focusing on:
          #   - For TypeScript files: Type safety and proper interface usage
          #   - For API endpoints: Security, input validation, and error handling
          #   - For React components: Performance, accessibility, and best practices
          #   - For tests: Coverage, edge cases, and test quality
          
          # Optional: Different prompts for different authors
          # direct_prompt: |
          #   ${{ github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' && 
          #   'Welcome! Please review this PR from a first-time contributor. Be encouraging and provide detailed explanations for any suggestions.' ||
          #   'Please provide a thorough code review focusing on our coding standards and best practices.' }}
          
          # Optional: Add specific tools for running tests or linting
          # allowed_tools: "Bash(npm run test),Bash(npm run lint),Bash(npm run typecheck)"
          
          # Optional: Skip review for certain conditions
          # if: |
          #   !contains(github.event.pull_request.title, '[skip-review]') &&
          #   !contains(github.event.pull_request.title, '[WIP]')

.github/workflows/claude.yml (2.2 KiB)

name: Claude Code

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]
  issues:
    types: [opened, assigned]
  pull_request_review:
    types: [submitted]

jobs:
  claude:
    if: |
      (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
      (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
      (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
      (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: read
      issues: read
      id-token: write
      actions: read # Required for Claude to read CI results on PRs
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 1

      - name: Run Claude Code
        id: claude
        uses: anthropics/claude-code-action@beta
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}

          # This is an optional setting that allows Claude to read CI results on PRs
          additional_permissions: |
            actions: read
          
          # Optional: Specify model (defaults to Claude Sonnet 4, uncomment for Claude Opus 4.1)
          # model: "claude-opus-4-1-20250805"
          
          # Optional: Customize the trigger phrase (default: @claude)
          # trigger_phrase: "/claude"
          
          # Optional: Trigger when specific user is assigned to an issue
          # assignee_trigger: "claude-bot"
          
          # Optional: Allow Claude to run specific commands
          # allowed_tools: "Bash(npm install),Bash(npm run build),Bash(npm run test:*),Bash(npm run lint:*)"
          
          # Optional: Add custom instructions for Claude to customize its behavior for your project
          # custom_instructions: |
          #   Follow our coding standards
          #   Ensure all new code has tests
          #   Use TypeScript for new files
          
          # Optional: Custom environment variables for Claude
          # claude_env: |
          #   NODE_ENV: test

.github/workflows/dependency-check.yml (1.1 KiB)

name: Dependency Check

on:
  schedule:
    # Run every Monday at 9 AM PST
    - cron: '0 17 * * 1'
  workflow_dispatch:  # Allow manual trigger

jobs:
  check-dependencies:
    name: Check for Outdated Dependencies
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "latest"
          enable-cache: true

      - name: Set up Python
        run: uv python install 3.12

      - name: Check for outdated packages
        run: |
          echo "## Checking for outdated packages..."
          uv pip list --outdated || true
          
      - name: Create issue if dependencies are outdated
        if: failure()
        uses: actions/github-script@v7
        with:
          script: |
            github.rest.issues.create({
              owner: context.repo.owner,
              repo: context.repo.repo,
              title: 'Dependencies are outdated',
              body: 'The dependency check workflow found outdated packages. Please review and update.',
              labels: ['dependencies', 'maintenance']
            })

.github/workflows/release-please.yaml (1.6 KiB)

name: Release Please

on:
  push:
    branches: [ main ]
  pull_request:
      branches: [ main ]

permissions:
  contents: write   # allow tagging and releases
  pull-requests: write   # allow managing release PRs
  issues: write   # allow creating and managing labels

concurrency:
  group: release-please
  cancel-in-progress: true

jobs:
  release:
    runs-on: ubuntu-latest
    outputs:
      release_created: ${{ steps.release.outputs.release_created }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0   # fetch full history for accurate changelog generation

      - name: Run Release Please
        uses: googleapis/release-please-action@v4
        id: release
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          config-file: release-please-config.json
          manifest-file: .release-please-manifest.json

  publish:
    runs-on: ubuntu-latest
    needs: release
    if: ${{ needs.release.outputs.release_created }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install UV
        uses: astral-sh/setup-uv@v4

      - name: Set up environment
        run: |
          uv venv
          source .venv/bin/activate

      - name: Build package
        run: |
          source .venv/bin/activate
          uv build

      - name: Publish to PyPI
        run: |
          source .venv/bin/activate
          uv publish --token ${{ secrets.PYPI_API_TOKEN }}

.github/workflows/stale.yaml (877 B)

#####################################
#       DO NOT EDIT DIRECTLY.       #
# This file is managed by Terraform #
#####################################

name: "Close stale PRs"
on:
  schedule:
    - cron: "30 1 * * *"

jobs:
  stale:
    runs-on: ubuntu-latest
    # Read repo and write to PRs
    permissions:
      contents: read
      pull-requests: write
      issues: write
    steps:
      - uses: actions/stale@v9
        with:
          stale-pr-message: "This PR is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days."
          close-pr-message: "This PR was closed because it has been stalled for 7 days with no activity."
          days-before-pr-stale: 30
          days-before-pr-close: 7
          exempt-pr-labels: "dependencies,security"
          operations-per-run: 60 # Default is 30

.gitignore (190 B)

tmp
tmp*
*.pyc
__pycache__
temp
*.eval
.venv/
logs/
*.log
.mypy_cache/
*.egg-info/
**/CLAUDE.local.md
.DS_Store
.*cache/
.*cache
.coverage*
.claude/settings.local.json
local-prompts.md
.env

.pre-commit-config.yaml (392 B)

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.3.0
    hooks:
      - id: ruff
        name: ruff check
        args: [--fix]
      - id: ruff-format
        name: ruff format
  - repo: local
    hooks:
      - id: mypy
        name: mypy
        entry: .venv/bin/mypy
        language: system
        types: [python]
        args: [.]
        pass_filenames: false

.python-version (6 B)

3.13.5

.release-please-manifest.json (22 B)

{
    ".": "0.3.0"
}  

.vscode/settings.json (147 B)

{
    "python.testing.pytestArgs": [
        "tests"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}

CHANGELOG.md (8.6 KiB)

Changelog

0.3.0 (2025-08-14)

Features

  • add --debug flag to eval-retry command (b26afaa)
  • add -M and -T flags for model and task arguments (#75) (46a6ba6)
  • add 'openbench' as alternative CLI entry point (#48) (68b3c5b)
  • add AI21 Labs inference provider (#86) (db7bde7)
  • add Baseten inference provider (#79) (696e2aa)
  • add Cerebras and SambaNova model providers (1c61f59)
  • add Cohere inference provider (#90) (8e6e838)
  • add Crusoe inference provider (#84) (3d0c794)
  • add DeepInfra inference provider (#85) (6fedf53)
  • add Friendli inference provider (#88) (7e2b258)
  • Add huggingface inference provider (#54) (f479703)
  • add Hyperbolic inference provider (#80) (4ebf723)
  • add initial GraphWalks benchmark implementation (#58) (1aefd07)
  • add Lambda AI inference provider (#81) (b78c346)
  • add MiniMax inference provider (#87) (09fd27b)
  • add Moonshot inference provider (#91) (e5743cb)
  • add Nebius model provider (#47) (ba2ec19)
  • add Nous Research model provider (#49) (32dd815)
  • add Novita AI inference provider (#82) (6f5874a)
  • add Parasail inference provider (#83) (973c7b3)
  • add Reka inference provider (#89) (1ab9c53)
  • add SciCode (#63) (3650bfa)
  • add support for alpha benchmarks in evaluation commands (#92) (e2ccfaa)
  • push eval data to huggingface repo (#65) (acc600f)

Bug Fixes

  • add missing newline at end of novita.py (ef0fa4b)
  • remove default sampling parameters from CLI (#72) (978638a)

Documentation

  • docs for 0.3.0 (#93) (fe358bb)
  • fix directory structure documentation in CONTRIBUTING.md (#78) (41f8ed9)

Chores

  • fix GraphWalks: Split into three separate benchmarks (#76) (d1ed96e)
  • update version (8b7bbe7)

Refactor

  • move task loading from registry to config and update imports (de6eea2)

CI

  • Enhance Claude code review workflow with updated prompts and model specification (#71) (b605ed2)

0.2.0 (2025-08-11)

Features

Documentation

  • update CLAUDE.md with pre-commit and dependency pinning requirements (f33730e)

Chores

  • GitHub Terraform: Create/Update .github/workflows/stale.yaml [skip ci] (1a00342)

0.1.1 (2025-07-31)

Bug Fixes

  • add missing init.py files and fix package discovery for PyPI (#10) (29fcdf6)

Documentation

  • update README to streamline setup instructions for OpenBench, use pypi (16e08a0)

0.1.0 (2025-07-31)

Features

Chores

  • ci: update release-please workflow to allow label management (b70db16)
  • drop versions for release (58ce995)
  • GitHub Terraform: Create/Update .github/workflows/stale.yaml [skip ci] (555658a)
  • update project metadata for version 0.1.0, add license, readme, and repository links (9ea2102)

CITATION.cff (563 B)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
type: software
authors:
- family-names: "Sah"
  given-names: "Aarush"
  orcid: "https://orcid.org/0009-0004-6429-8982"
title: "OpenBench: Provider-agnostic, open-source evaluation infrastructure for language models"
version: 0.3.0
date-released: 2025-07-31
url: "https://openbench.dev"
repository-code: "https://github.com/groq/openbench"
keywords:
  - "machine learning"
  - "evaluation"
  - "benchmarking"
  - "language models"
  - "LLM"
  - "artificial intelligence"
license: MIT

CLAUDE.md (3.0 KiB)

CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Essential Setup

  • Always source the virtual environment before running Python commands: source .venv/bin/activate
  • This project uses UV as the package manager, not pip
  • Dependency management: When adding dependencies with UV, use >= constraints (e.g., uv add "package>=1.2.3")
  • Exception: inspect-ai must remain pinned to a specific version for stability
  • Use the latest stable version as the minimum to keep dependencies healthy and secure
  • Check latest versions with uv pip list --outdated or on PyPI

Key Commands

Development Setup

# Initial setup
uv venv && uv sync --dev
source .venv/bin/activate

Running Tests

# Run all unit tests
pytest

# Run integration tests (requires API keys)
pytest -m integration

# Run specific test file
pytest tests/test_registry.py

# Run with coverage
pytest --cov=openbench

Code Quality

# Format code
ruff format .

# Lint code
ruff check .

# Type checking
mypy .

# Run all pre-commit hooks
pre-commit run --all-files

Using the CLI

# List available benchmarks
bench list

# Describe a specific benchmark
bench describe mmlu

# Run evaluation
bench eval mmlu --model groq/llama-3.1-70b --limit 10

# View previous results
bench view

Publishing to PyPI

# Build the package
uv build

# Publish to PyPI (requires PyPI API token)
uv publish

Architecture Overview

Project Structure

  • src/openbench/ - Main package directory
    • _cli/ - CLI commands implementation
    • datasets/ - Dataset loaders (MMLU, GPQA, HumanEval, SimpleQA)
    • evals/ - Benchmark implementations built on Inspect AI
    • metrics/ - Custom scoring metrics
    • scorers/ - Scoring functions used across benchmarks
    • utils/ - Shared utilities
    • _registry.py - Dynamic task loading system
    • config.py - Benchmark metadata and configuration

Key Patterns

  1. Registry-based loading: Benchmarks are dynamically loaded via _registry.py
  2. Inspect AI framework: All evaluations extend Inspect AI's task/solver pattern
  3. Provider-agnostic: Uses Inspect AI's model abstraction for 15+ providers
  4. Shared components: Common scorers and utilities reduce code duplication

Adding New Benchmarks

  1. Create evaluation file in src/openbench/evals/
  2. Add metadata to config.py
  3. Follow existing patterns for dataset loading and scoring (see the sketch below)

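As a rough illustration only (not an actual file in this repo), a minimal multiple-choice eval built on Inspect AI's task pattern could look like the following; the my_benchmark name and inline sample are placeholders, and real benchmarks load their data via datasets/ and register metadata in config.py:

# src/openbench/evals/my_benchmark.py (hypothetical sketch)
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import choice
from inspect_ai.solver import multiple_choice


@task
def my_benchmark() -> Task:
    """Toy multiple-choice benchmark used only as a sketch."""
    dataset = [
        Sample(
            input="Which planet is known as the Red Planet?",
            choices=["Venus", "Mars", "Jupiter", "Saturn"],
            target="B",
        ),
    ]
    return Task(dataset=dataset, solver=multiple_choice(), scorer=choice())
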
Development Standards

  • Use type hints for all functions
  • Follow conventional commits (feat, fix, docs, etc.)
  • Line length: 88 characters (configured in ruff)
  • Test new features with both unit and integration tests
  • IMPORTANT: All pre-commit hooks MUST pass before committing
  • Run pre-commit run --all-files to check all hooks
  • The configured hooks are:
    1. ruff check --fix - Linting with auto-fix
    2. ruff format - Code formatting
    3. mypy . - Type checking
  • If any hook fails, fix the issues and re-run until all pass

CODEOWNERS (12 B)

* @AarushSah

CONTRIBUTING.md (7.9 KiB)

Contributing to OpenBench

Thank you for your interest in contributing to OpenBench! We welcome contributions from the community and are grateful for your support in making language model evaluation more accessible and reliable.

🚀 Quick Start

Prerequisites

Setup

# Clone and setup
git clone https://github.com/groq/openbench.git
cd openbench
uv venv && uv sync --dev
source .venv/bin/activate

# CRITICAL: Install pre-commit hooks (CI will fail without this!)
pre-commit install

# Run tests to verify setup
pytest

⚠️ IMPORTANT: You MUST install pre-commit hooks after uv sync --dev. CI will fail if you skip this step!

🎯 Core Principles

Single Responsibility

Each PR must address a single concern. This helps us:

  • Review changes more effectively
  • Maintain a clean git history
  • Easily revert changes if needed
  • Understand the purpose of each change

Examples of single-concern PRs:

  • ✅ Add support for a new benchmark
  • ✅ Fix a specific bug in the MMLU scorer
  • ✅ Refactor the math canonicalization utility
  • ❌ Add new benchmark AND fix unrelated bug
  • ❌ Refactor multiple unrelated components

Clear Separation of Concerns (SoC)

We value clean, modular code with clear boundaries between components:

  • Each module should have a single, well-defined purpose
  • Avoid tight coupling between components
  • Use dependency injection where appropriate
  • Keep business logic separate from infrastructure concerns

📝 Commit Guidelines

Conventional Commits

We use Conventional Commits for all commit messages. This provides a clear, structured way to communicate changes.

Format: <type>(<scope>): <subject>

Types

  • feat: New feature
  • fix: Bug fix
  • docs: Documentation changes
  • style: Code style changes (formatting, missing semicolons, etc.)
  • refactor: Code refactoring without changing functionality
  • perf: Performance improvements
  • test: Adding or updating tests
  • build: Build system or dependency changes
  • ci: CI/CD configuration changes
  • chore: Other changes that don't modify src or test files

Examples

feat(mmlu): add support for MMLU-Pro benchmark
fix(scorer): handle edge case in math canonicalization
docs(readme): update installation instructions
refactor(humaneval): extract common sandbox logic
test(gpqa): add unit tests for diamond scorer
perf(eval): optimize parallel sample processing

Scope

The scope should indicate the component or area affected:

  • Benchmark names: mmlu, humaneval, gpqa, etc.
  • Components: cli, scorer, solver, common
  • Infrastructure: docker, ci, deps

Commit Message Body

For complex changes, add a body to explain:

  • What changed and why
  • Any breaking changes
  • Related issues

Example:

feat(aime): add support for AIME 2025 problems

- Add dataset loader for AIME 2025
- Update math scorer to handle new problem formats
- Include official solutions for verification

Closes #123

🔄 Pull Request Process

Before You Start

  1. Check existing issues and PRs to avoid duplicates
  2. For significant changes, open an issue first to discuss
  3. Fork the repository and create a feature branch

Development Workflow

  1. Create a feature branch

     git checkout -b feat/add-new-benchmark

  2. Make your changes

     • Follow the existing code style
     • Add tests for new functionality
     • Update documentation as needed

  3. Test your changes

     # Run all tests
     pytest

     # Run integration tests (requires API keys)
     pytest -m integration

     # Run pre-commit hooks (REQUIRED)
     pre-commit run --all-files

     # Test your specific changes
     bench eval --limit 5

  4. Commit with conventional commits

     git commit -m "feat(benchmark): add support for XYZ benchmark"

Submitting Your PR

  1. Push to your fork

     git push origin feat/add-new-benchmark

  2. Create a Pull Request

     • Use a clear, descriptive title following conventional commit format
     • Fill out the PR template completely
     • Link any related issues
     • Ensure all CI checks pass

  3. PR title format

     Since we use squash and merge, your PR title becomes the commit message. Use conventional commit format:

     • ✅ feat(mmlu): add MMLU-Pro support
     • ✅ fix(cli): handle missing API key gracefully
     • ❌ Updated MMLU benchmark
     • ❌ Various fixes

What Happens Next

  1. A maintainer will review your PR
  2. Address any feedback or requested changes
  3. Once approved, we'll squash and merge your PR
  4. Your contribution will be part of the next release!

🏗️ Architecture Guidelines

Adding a New Benchmark

  1. Create a new evaluation file in src/openbench/evals/
  2. Add dataset loader in src/openbench/datasets/ if needed (see the sketch below)
  3. Add custom scorer in src/openbench/scorers/ if needed
  4. Register benchmark metadata in src/openbench/config.py
  5. Use existing utilities from src/openbench/utils/
  6. Add comprehensive tests

Example structure:

src/openbench/
├── evals/
│   └── my_benchmark.py      # Main evaluation logic
├── datasets/
│   └── my_benchmark.py      # Dataset loader
├── scorers/
│   └── my_benchmark.py      # Custom scorer (if needed)
└── config.py                # Add benchmark metadata here
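
A rough sketch of the dataset-loader piece (step 2), assuming a hypothetical Hugging Face dataset my-org/my-benchmark with question/answer/subject fields; the existing loaders in src/openbench/datasets/ are the reference to follow:

# src/openbench/datasets/my_benchmark.py (hypothetical sketch)
from inspect_ai.dataset import Dataset, Sample, hf_dataset


def record_to_sample(record: dict) -> Sample:
    # Map one raw dataset row onto Inspect AI's Sample schema.
    return Sample(
        input=record["question"],
        target=record["answer"],
        metadata={"subject": record.get("subject")},
    )


def get_dataset(split: str = "test") -> Dataset:
    # Pull the benchmark from the Hugging Face Hub and convert each record.
    return hf_dataset(
        path="my-org/my-benchmark",  # placeholder dataset id
        split=split,
        sample_fields=record_to_sample,
    )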

Adding a New Model Provider

  1. Create provider file in src/openbench/model/_providers/
  2. Follow existing provider patterns (see ai21.py, cerebras.py, etc.)
  3. Add environment variable documentation
  4. Test with multiple benchmarks
  5. Update provider table in README.md

Key Development Tools

  • UV: Package manager (not pip) - use uv add "package>=version" for dependencies (except inspect-ai which should remain pinned)
  • Ruff: Linting and formatting - replaces Black, isort, flake8
  • MyPy: Type checking - required for all new code
  • Pre-commit: Automated code quality checks - must pass before commits
  • Pytest: Testing framework with integration test markers

Code Style

  • Follow PEP 8 with a line length of 88 characters (Black default)
  • Use type hints for all function signatures
  • Write docstrings for all public functions and classes
  • Prefer composition over inheritance
  • Keep functions small and focused

Testing

  • Write unit tests for all new functionality
  • Include integration tests for new benchmarks
  • Aim for high test coverage
  • Test edge cases and error conditions

🐛 Reporting Issues

We have structured issue templates to help you report problems effectively:

Bug Reports

Use our bug report template, which includes:

  • OpenBench version and environment details
  • Exact command that failed
  • Expected vs actual behavior
  • Error logs and reproduction steps

Feature Requests

Use our feature request template for:

  • New benchmarks/evaluations
  • New model providers
  • CLI enhancements
  • Performance improvements
  • API/SDK features
  • Integration requests

📚 Resources

🤝 Code of Conduct

Please note that this project is released with a Contributor Code of Conduct. By participating in this project you agree to:

  • Be respectful and inclusive
  • Welcome newcomers and help them get started
  • Focus on constructive criticism
  • Respect differing viewpoints and experiences

📄 License

By contributing to OpenBench, you agree that your contributions will be licensed under the MIT License.

LICENSE.md (1.0 KiB)

Copyright (c) 2025 Groq, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

MANIFEST.in (124 B)

include README.md
include LICENSE.md
include pyproject.toml
recursive-include src/openbench *
include src/openbench/py.typed

README.md (15.7 KiB)

OpenBench

Provider-agnostic, open-source evaluation infrastructure for language models 🚀

PyPI version License: MIT Python 3.10+

OpenBench provides standardized, reproducible benchmarking for LLMs across 30+ evaluation suites (and growing) spanning knowledge, math, reasoning, coding, science, reading comprehension, health, long-context recall, graph reasoning, and first-class support for your own local evals to preserve privacy. Works with any model provider - Groq, OpenAI, Anthropic, Cohere, Google, AWS Bedrock, Azure, local models via Ollama, Hugging Face, and 30+ other providers.

🚧 Alpha Release

We're building in public! This is an alpha release - expect rapid iteration. The first stable release is coming soon.

🎉 What's New in v0.3.0

  • 📡 18 More Model Providers: Added support for AI21, Baseten, Cerebras, Cohere, Crusoe, DeepInfra, Friendli, Hugging Face, Hyperbolic, Lambda, MiniMax, Moonshot, Nebius, Nous, Novita, Parasail, Reka, SambaNova and more
  • 🧪 New Benchmarks: DROP (reading comprehension), experimental benchmarks available with --alpha flag
  • ⚡ CLI Enhancements: openbench alias, -M/-T flags for model/task args, --debug mode for eval-retry
  • 🔧 Developer Tools: GitHub Actions integration, Inspect AI extension support

Features

  • 🎯 35+ Benchmarks: MMLU, GPQA, HumanEval, SimpleQA, competition math (AIME, HMMT), SciCode, GraphWalks, and more
  • 🔧 Simple CLI: bench list, bench describe, bench eval (also available as openbench)
  • 🏗️ Built on inspect-ai: Industry-standard evaluation framework
  • 📊 Extensible: Easy to add new benchmarks and metrics
  • 🤖 Provider-agnostic: Works with 30+ model providers out of the box
  • 🛠️ Local Eval Support: Privatized benchmarks can now be run with bench eval <path>
  • 📤 Hugging Face Integration: Push evaluation results directly to Hugging Face datasets

🏃 Speedrun: Evaluate a Model in 60 Seconds

Prerequisite: Install uv

# Create a virtual environment and install OpenBench (30 seconds)
uv venv
source .venv/bin/activate
uv pip install openbench

# Set your API key (any provider!)
export GROQ_API_KEY=your_key  # or OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.

# Run your first eval (30 seconds)
bench eval mmlu --model groq/llama-3.3-70b-versatile --limit 10

# That's it! 🎉 Check results in ./logs/ or view them in an interactive UI:
bench view

https://github.com/user-attachments/assets/e99e4628-f1f5-48e4-9df2-ae28b86168c2

Using Different Providers

# Groq (blazing fast!)
bench eval gpqa_diamond --model groq/meta-llama/llama-4-maverick-17b-128e-instruct

# OpenAI
bench eval humaneval --model openai/o3-2025-04-16

# Anthropic
bench eval simpleqa --model anthropic/claude-sonnet-4-20250514

# Google
bench eval mmlu --model google/gemini-2.5-pro

# Local models with Ollama
bench eval musr --model ollama/llama3.1:70b

# Hugging Face Inference Providers
bench eval mmlu --model huggingface/gpt-oss-120b:groq

# 30+ providers supported - see full list below

Supported Providers

OpenBench supports 30+ model providers through Inspect AI. Set the appropriate API key environment variable and you're ready to go:

| Provider | Environment Variable | Example Model String |
|---|---|---|
| AI21 Labs | AI21_API_KEY | ai21/model-name |
| Anthropic | ANTHROPIC_API_KEY | anthropic/model-name |
| AWS Bedrock | AWS credentials | bedrock/model-name |
| Azure | AZURE_OPENAI_API_KEY | azure/<deployment-name> |
| Baseten | BASETEN_API_KEY | baseten/model-name |
| Cerebras | CEREBRAS_API_KEY | cerebras/model-name |
| Cohere | COHERE_API_KEY | cohere/model-name |
| Crusoe | CRUSOE_API_KEY | crusoe/model-name |
| DeepInfra | DEEPINFRA_API_KEY | deepinfra/model-name |
| Friendli | FRIENDLI_TOKEN | friendli/model-name |
| Google | GOOGLE_API_KEY | google/model-name |
| Groq | GROQ_API_KEY | groq/model-name |
| Hugging Face | HF_TOKEN | huggingface/model-name |
| Hyperbolic | HYPERBOLIC_API_KEY | hyperbolic/model-name |
| Lambda | LAMBDA_API_KEY | lambda/model-name |
| MiniMax | MINIMAX_API_KEY | minimax/model-name |
| Mistral | MISTRAL_API_KEY | mistral/model-name |
| Moonshot | MOONSHOT_API_KEY | moonshot/model-name |
| Nebius | NEBIUS_API_KEY | nebius/model-name |
| Nous Research | NOUS_API_KEY | nous/model-name |
| Novita AI | NOVITA_API_KEY | novita/model-name |
| Ollama | None (local) | ollama/model-name |
| OpenAI | OPENAI_API_KEY | openai/model-name |
| OpenRouter | OPENROUTER_API_KEY | openrouter/model-name |
| Parasail | PARASAIL_API_KEY | parasail/model-name |
| Perplexity | PERPLEXITY_API_KEY | perplexity/model-name |
| Reka | REKA_API_KEY | reka/model-name |
| SambaNova | SAMBANOVA_API_KEY | sambanova/model-name |
| Together AI | TOGETHER_API_KEY | together/model-name |
| Vercel AI Gateway | AI_GATEWAY_API_KEY | vercel/creator-name/model-name |
| vLLM | None (local) | vllm/model-name |

Available Benchmarks

Here are the currently available benchmarks. For an up-to-date list use bench list.

| Category | Benchmarks |
|---|---|
| Knowledge | MMLU (57 subjects), GPQA (graduate-level), SuperGPQA (285 disciplines), OpenBookQA, HLE (Humanity's Last Exam - 2,500 questions from 1,000+ experts), HLE_text (text-only version) |
| Coding | HumanEval (164 problems) |
| Math | AIME 2023-2025, HMMT Feb 2023-2025, BRUMO 2025, MATH (competition-level problems), MATH-500 (challenging subset), MGSM (multilingual grade school math), MGSM_en (English), MGSM_latin (5 languages), MGSM_non_latin (6 languages) |
| Reasoning | SimpleQA (factuality), MuSR (multi-step reasoning), DROP (discrete reasoning over paragraphs) |
| Long Context | OpenAI MRCR (multiple needle retrieval), OpenAI MRCR_2n (2 needle), OpenAI MRCR_4 (4 needle), OpenAI MRCR_8n (8 needle) |
| Healthcare | HealthBench (open-ended healthcare eval), HealthBench_hard (challenging variant), HealthBench_consensus (consensus variant) |
| Cybersecurity | CTI-Bench (complete cyber threat intelligence suite), CTI-Bench ATE (MITRE ATT&CK technique extraction), CTI-Bench MCQ (knowledge questions on CTI standards and best practices), CTI-Bench RCM (CVE to CWE vulnerability mapping), CTI-Bench VSP (CVSS score calculation) |

Configuration

# Set your API keys
export GROQ_API_KEY=your_key
export HF_TOKEN=your_key
export OPENAI_API_KEY=your_key  # Optional

# Set default model
export BENCH_MODEL=groq/llama-3.1-70b

Commands and Options

For a complete list of all commands and options, run: bench --help

| Command | Description |
|---|---|
| bench or openbench | Show main menu with available commands |
| bench list | List available evaluations, models, and flags |
| bench eval <benchmark> | Run benchmark evaluation on a model |
| bench eval-retry | Retry a failed evaluation |
| bench view | View logs from previous benchmark runs |
| bench eval <path> | Run your local/private evals built with Inspect AI |

Key eval Command Options

| Option | Environment Variable | Default | Description |
|---|---|---|---|
| -M <args> | None | None | Pass model-specific arguments (e.g., -M reasoning_effort=high) |
| -T <args> | None | None | Pass task-specific arguments to the benchmark |
| --model | BENCH_MODEL | None (required) | Model(s) to evaluate |
| --epochs | BENCH_EPOCHS | 1 | Number of epochs to run each evaluation |
| --max-connections | BENCH_MAX_CONNECTIONS | 10 | Maximum parallel requests to model |
| --temperature | BENCH_TEMPERATURE | 0.6 | Model temperature |
| --top-p | BENCH_TOP_P | 1.0 | Model top-p |
| --max-tokens | BENCH_MAX_TOKENS | None | Maximum tokens for model response |
| --seed | BENCH_SEED | None | Seed for deterministic generation |
| --limit | BENCH_LIMIT | None | Limit evaluated samples (number or start,end) |
| --logfile | BENCH_OUTPUT | None | Output file for results |
| --sandbox | BENCH_SANDBOX | None | Environment to run evaluation (local/docker) |
| --timeout | BENCH_TIMEOUT | 10000 | Timeout for each API request (seconds) |
| --display | BENCH_DISPLAY | None | Display type (full/conversation/rich/plain/none) |
| --reasoning-effort | BENCH_REASONING_EFFORT | None | Reasoning effort level (low/medium/high) |
| --json | None | False | Output results in JSON format |
| --hub-repo | BENCH_HUB_REPO | None | Push results to a Hugging Face Hub dataset |

Building Your Own Evals

OpenBench is built on Inspect AI. To create custom evaluations, check out their excellent documentation. Once you have built private evaluations with Inspect AI that you don't want to open-source, you can point OpenBench at them with bench eval <path> to run them!
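
As a minimal sketch (assuming a hypothetical file my_private_eval.py; any Inspect AI task file should work the same way), a local eval can be as small as:

# my_private_eval.py (hypothetical local eval, kept out of the repo)
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import includes
from inspect_ai.solver import generate


@task
def my_private_eval() -> Task:
    # Inline samples keep the eval self-contained and private.
    samples = [
        Sample(input="What is the chemical symbol for gold?", target="Au"),
        Sample(input="In what year did Apollo 11 land on the Moon?", target="1969"),
    ]
    return Task(dataset=samples, solver=generate(), scorer=includes())

You would then run it with something like bench eval ./my_private_eval.py --model groq/llama-3.3-70b-versatile.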

Exporting Logs to Hugging Face

OpenBench can export logs to a Hugging Face Hub dataset. This is useful if you want to share your results with the community or use them for further analysis.

export HF_TOKEN=<your-huggingface-token>

bench eval mmlu --model groq/llama-3.3-70b-versatile --limit 10 --hub-repo <your-username>/openbench-logs 

This will export the logs to a Hugging Face Hub dataset with the name openbench-logs.

FAQ

How does OpenBench differ from Inspect AI?

OpenBench provides:

  • Reference implementations of 20+ major benchmarks with consistent interfaces
  • Shared utilities for common patterns (math scoring, multi-language support, etc.)
  • Curated scorers that work across different eval types
  • CLI tooling optimized for running standardized benchmarks

Think of it as a benchmark library built on Inspect's excellent foundation.

Why not just use Inspect AI, lm-evaluation-harness, or lighteval?

Different tools for different needs! OpenBench focuses on:

  • Shared components: Common scorers, solvers, and datasets across benchmarks reduce code duplication
  • Clean implementations: Each eval is written for readability and reliability
  • Developer experience: Simple CLI, consistent patterns, easy to extend

We built OpenBench because we needed evaluation code that was easy to understand, modify, and trust. It's a curated set of benchmarks built on Inspect AI's excellent foundation.

How can I run bench outside of the uv environment?

If you want bench to be available outside of uv, you can run the following command:

uv run pip install -e .

I'm running into an issue when downloading a dataset from HuggingFace - how do I fix it?

Some evaluations may require logging into HuggingFace to download the dataset. If bench prompts you to do so, or throws "gated" errors, defining the environment variable

HF_TOKEN="<HUGGINGFACE_TOKEN>"

should fix the issue. The full HuggingFace documentation can be found on the HuggingFace docs on Authentication.

Development

For development work, you'll need to clone the repository:

# Clone the repo
git clone https://github.com/groq/openbench.git
cd openbench

# Setup with UV
uv venv && uv sync --dev
source .venv/bin/activate

# CRITICAL: Install pre-commit hooks (CI will fail without this!)
pre-commit install

# Run tests
pytest

⚠️ IMPORTANT: You MUST run pre-commit install after setup or CI will fail on your PRs!

Contributing

We welcome contributions! Please see our Contributing Guide for detailed instructions on:

  • Setting up the development environment
  • Adding new benchmarks and model providers
  • Code style and testing requirements
  • Submitting issues and pull requests

Quick links:

  • Report a bug
  • Request a feature

Reproducibility Statement

As the authors of OpenBench, we strive to implement this tool's evaluations as faithfully as possible with respect to the original benchmarks themselves.

However, it is expected that developers may observe numerical discrepancies between OpenBench's scores and the reported scores from other sources.

These numerical differences can be attributed to many reasons, including (but not limited to) minor variations in the model prompts, different model quantization or inference approaches, and repurposing benchmarks to be compatible with the packages used to develop OpenBench.

As a result, OpenBench results are meant to be compared with OpenBench results, not as a universal one-to-one comparison with every external result. For meaningful comparisons, ensure you are using the same version of OpenBench.

We encourage developers to identify areas of improvement and we welcome open source contributions to OpenBench.

Acknowledgments

This project would not be possible without:

Citation

@software{openbench,
  title = {OpenBench: Provider-agnostic, open-source evaluation infrastructure for language models},
  author = {Sah, Aarush},
  year = {2025},
  url = {https://openbench.dev}
}

License

MIT


Built with ❤️ by Aarush Sah and the Groq team

pyproject.toml (2.0 KiB)

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "openbench"
version = "0.3.0"
requires-python = ">=3.10"
description = "OpenBench - open source, replicable, and standardized evaluation infrastructure"
readme = "README.md"
license = {text = "MIT"}
authors = [
    { name="Aarush Sah" },
    { name="Groq" }
]
dependencies = [
    "datasets>=3.6.0",
    "groq>=0.30.0",
    "inspect-ai==0.3.115",
    "jsonschema>=4.23.0",
    "openai>=1.97.1",
    "pydantic-settings>=2.9.1",
    "scicode",
    "scipy>=1.15.3",
    "tiktoken>=0.11.0",
    "typer>=0.15.3",
]

[project.urls]
Homepage = "https://github.com/groq/openbench"
Repository = "https://github.com/groq/openbench"

[project.scripts]
bench = "openbench._cli:main"
openbench = "openbench._cli:main"

[project.entry-points.inspect_ai]
openbench = "openbench._registry"

[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]
include = ["openbench*"]

[tool.pytest.ini_options]
minversion = "8.0"
addopts = "-ra -q --strict-markers"
testpaths = ["tests"]
pythonpath = ["src", "."]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
markers = [
    "integration: marks tests that require external services (deselect with '-m \"not integration\"')",
]

[tool.coverage.run]
source = ["src/openbench"]
omit = ["tests/*", "**/__init__.py", "**/conftest.py"]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:",
    "if TYPE_CHECKING:",
]
precision = 2
show_missing = true

[tool.uv.sources]
scicode = { git = "https://github.com/TheFloatingString/SciCode-fork.git", rev = "4f20d721ba3e2227196262083b9b7a70484d54f7" }

[dependency-groups]
dev = [
    "mypy>=1.15.0",
    "pre-commit>=4.2.0",
    "pytest>=8.3.5",
    "pytest-asyncio==0.24.0",
    "pytest-cov>=4.1.0",
    "ruff>=0.11.9",
    "types-jsonschema>=4.0.0",
    "types-pyyaml>=6.0.12.20250809",
]

release-please-config.json (944 B)

{
    "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json",
    "packages": {
      ".": {
        "release-type": "python",
        "package-name": "openbench",
        "include-v-in-tag": true,
        "include-component-in-tag": false,
        "changelog-path": "CHANGELOG.md",
        "extra-files": [
          {
            "type": "toml",
            "path": "uv.lock",
            "jsonpath": "$.package[?(@.name==\"openbench\")].version"
          }
        ],
        "changelog-sections": [
          { "type": "feat", "section": "Features" },
          { "type": "fix", "section": "Bug Fixes" },
          { "type": "docs", "section": "Documentation" },
          { "type": "chore", "section": "Chores" },
          { "type": "refactor", "section": "Refactor" },
          { "type": "test", "section": "Tests" },
          { "type": "ci", "section": "CI" }
        ]
      }
    }
  }
  

src/openbench/__init__.py (0 B)


src/openbench/_cli/__init__.py (571 B)

import typer
from openbench._cli.eval_command import run_eval
from openbench._cli.eval_retry_command import run_eval_retry
from openbench._cli.list_command import list_evals
from openbench._cli.describe_command import describe_eval
from openbench._cli.view_command import run_view

app = typer.Typer(rich_markup_mode="rich")

app.command("list")(list_evals)
app.command("describe")(describe_eval)
app.command("eval")(run_eval)
app.command("eval-retry")(run_eval_retry)
app.command("view")(run_view)


def main() -> None:
    app()


if __name__ == "__main__":
    main()

src/openbench/_cli/describe_command.py (5.5 KiB)

"""
Describe command for benchmark evaluations.
"""

from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.box import ROUNDED

from openbench.eval_config import get_eval_config
from openbench.config import get_all_benchmarks
from openbench._cli.utils import get_category_display_name


def describe_eval(name: str) -> None:
    """Show detailed information about a specific evaluation."""
    console = Console()

    if name not in get_all_benchmarks():
        console.print(f"\n[red]Unknown evaluation: {name}[/red]")

        # Suggest similar names
        all_names = list(get_all_benchmarks().keys())
        similar = [n for n in all_names if name.lower() in n.lower()]
        if similar:
            console.print("\nDid you mean one of these?")
            for s in similar[:5]:
                console.print(f"   • {s}")
        console.print()
        return

    # Get both static and dynamic config
    config = get_eval_config(name, load_dynamic=True)
    if not config:
        console.print(f"\n[red]Failed to load configuration for {name}[/red]\n")
        return

    # Header
    console.print()
    console.print(
        Panel(f"[bold blue]{config.name}[/bold blue]", expand=False, box=ROUNDED)
    )

    # Static metadata section
    console.print("\n[bold yellow]Metadata[/bold yellow]")
    console.print("─" * 40)

    static_table = Table(show_header=False, show_lines=False, padding=(0, 2), box=None)
    static_table.add_column("Property", style="cyan", width=15)
    static_table.add_column("Value", style="white")

    static_table.add_row("Description", config.description)
    static_table.add_row("Category", get_category_display_name(config.category))
    static_table.add_row("Command", f"[bold]bench eval {name}[/bold]")

    if config.tags:
        tag_str = " ".join(f"[blue]#{tag}[/blue]" for tag in config.tags)
        static_table.add_row("Tags", tag_str)

    console.print(static_table)

    # Dynamic configuration section
    if config._dynamic_loaded:
        console.print("\n[bold yellow]Configuration[/bold yellow]")
        console.print("─" * 40)

        config_table = Table(
            show_header=False, show_lines=False, padding=(0, 2), box=None
        )
        config_table.add_column("Property", style="cyan", width=15)
        config_table.add_column("Value", style="white")

        # Show key configuration values
        if config.epochs is not None:
            config_table.add_row("Epochs", str(config.epochs))

        # Show all GenerateConfig fields that have values
        config_dict = config.__dict__

        # Get all fields from config dict, excluding private/internal fields
        for field, value in sorted(config_dict.items()):
            # Skip private fields, None values, and non-GenerateConfig fields
            if (
                field.startswith("_")
                or value is None
                or field
                in [
                    "name",
                    "description",
                    "category",
                    "tags",
                    "epochs",
                    "sandbox",
                    "dataset_size",
                    "task_args",
                ]
            ):
                continue

            # Format the value nicely
            if isinstance(value, float):
                formatted_value = f"{value:.2f}" if value < 10 else f"{value:,.0f}"
            elif isinstance(value, int):
                formatted_value = f"{value:,}" if value > 999 else str(value)
            else:
                formatted_value = str(value)

            # Pretty field name
            field_name = field.replace("_", " ").title()
            config_table.add_row(field_name, formatted_value)

        if config.sandbox:
            config_table.add_row("Sandbox", config.sandbox)

        if config.dataset_size:
            config_table.add_row("Dataset Size", f"{config.dataset_size:,}")

        console.print(config_table)

        # Show task-specific arguments if any
        if config.task_args:
            console.print("\n[bold yellow]Task Arguments[/bold yellow]")
            console.print("─" * 40)

            args_table = Table(
                show_header=False, show_lines=False, padding=(0, 2), box=None
            )
            args_table.add_column("Argument", style="cyan", width=15)
            args_table.add_column("Default Value", style="white")

            for arg_name, arg_value in sorted(config.task_args.items()):
                # Pretty argument name
                display_name = arg_name.replace("_", " ").title()

                # Format the value
                if isinstance(arg_value, float):
                    formatted_value = (
                        f"{arg_value:.2f}" if arg_value < 10 else f"{arg_value:,.0f}"
                    )
                elif isinstance(arg_value, int):
                    formatted_value = (
                        f"{arg_value:,}" if arg_value > 999 else str(arg_value)
                    )
                elif isinstance(arg_value, bool):
                    formatted_value = "Yes" if arg_value else "No"
                else:
                    formatted_value = str(arg_value)

                args_table.add_row(display_name, formatted_value)

            console.print(args_table)
    else:
        console.print("\n[red]Failed to load dynamic configuration[/red]")

    # Footer
    console.print()
    console.print(f"[dim]Run with: [bold]bench eval {name}[/bold][/dim]")
    console.print()
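
A typical invocation of this command, using benchmark names from the registry in src/openbench/config.py:

bench describe mmlu   # prints metadata, dynamic configuration, and any task arguments
bench describe gpq    # unknown name: suggests close matches such as gpqa_diamond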

src/openbench/_cli/eval_command.py (12.0 KiB)

from typing import Optional, List, Dict, Annotated, Tuple, Union
from enum import Enum
import sys
import time
import typer
from inspect_ai import eval
from inspect_ai.model import Model

from openbench.config import load_task
from openbench.monkeypatch.display_results_patch import patch_display_results
from openbench._cli.utils import parse_cli_args


class SandboxType(str, Enum):
    """Type of environment to run evaluations in."""

    LOCAL = "local"
    DOCKER = "docker"


class DisplayType(str, Enum):
    """Display type for evaluation progress."""

    FULL = "full"
    CONVERSATION = "conversation"
    RICH = "rich"
    PLAIN = "plain"
    NONE = "none"


class ReasoningEffortLevel(str, Enum):
    """Reasoning effort level."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


def parse_limit(value: Optional[str]) -> Optional[Union[int, Tuple[int, int]]]:
    """Parse the limit parameter which can be an int or a tuple of ints.

    Args:
        value: The value passed to the --limit option.

    Returns:
        Parsed limit value: int, tuple of (start, end), or None.

    Raises:
        typer.BadParameter: If the input format is incorrect.
    """
    if value is None:
        return None

    try:
        if "," in value:
            start, end = map(int, value.split(","))
            return (start, end)
        return int(value)
    except ValueError:
        raise typer.BadParameter(
            "Limit must be an integer or two integers separated by a comma"
        )


def validate_model_name(model: str, context: str = "") -> None:
    """Validate a model name format.

    Args:
        model: Model name to validate
        context: Additional context for error message

    Raises:
        typer.BadParameter: If model name format is invalid
    """
    if not model or "/" not in model:
        raise typer.BadParameter(
            f"Invalid model name format{context}: {model}. Expected format: provider/model-name"
        )


def validate_model_role(model_role: Optional[str]) -> Dict[str, str | Model]:
    """Validate and parse model role string.

    Args:
        model_role: Optional string in format 'role=model'

    Returns:
        Dictionary mapping role to model name

    Raises:
        typer.BadParameter: If model_role format is invalid
    """
    if not model_role:
        return {}

    try:
        role, model = model_role.split("=")
        if not role or not model:
            raise ValueError("Model role must be in format 'role=model'")
        validate_model_name(model, f" for role '{role}'")
        return {role: model}
    except ValueError as e:
        raise typer.BadParameter(str(e))


def run_eval(
    benchmarks: Annotated[
        List[str],
        typer.Argument(
            help="Benchmark(s) to run. Can be a built-in name (e.g. mmlu) or a path to a local eval directory/file containing __metadata__",
            envvar="BENCH_BENCHMARKS",
        ),
    ],
    model: Annotated[
        List[str],
        typer.Option(
            help="Model(s) to evaluate. Equivalent to --model-role candidate=<model>",
            envvar="BENCH_MODEL",
        ),
    ] = ["groq/meta-llama/llama-4-scout-17b-16e-instruct"],
    max_connections: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum number of parallel requests to the model",
            envvar="BENCH_MAX_CONNECTIONS",
        ),
    ] = 10,
    model_base_url: Annotated[
        Optional[str],
        typer.Option(help="Base URL for model(s)", envvar="BENCH_MODEL_BASE_URL"),
    ] = None,
    model_role: Annotated[
        List[str],
        typer.Option(
            help="Model role(s). For example, --model-role grader=groq/meta-llama/llama-4-scout-17b-16e-instruct. Can be specified multiple times.",
            envvar="BENCH_MODEL_ROLE",
        ),
    ] = [],
    m: Annotated[
        List[str],
        typer.Option(
            "-M",
            help="One or more native model arguments (e.g. -M arg=value)",
            envvar="BENCH_MODEL_ARGS",
        ),
    ] = [],
    t: Annotated[
        List[str],
        typer.Option(
            "-T",
            help="One or more task arguments (e.g. -T arg=value)",
            envvar="BENCH_TASK_ARGS",
        ),
    ] = [],
    logfile: Annotated[
        Optional[str],
        typer.Option(help="Output file for results", envvar="BENCH_OUTPUT"),
    ] = None,
    sandbox: Annotated[
        Optional[SandboxType],
        typer.Option(
            help="Environment to run the evaluation in (local or docker)",
            case_sensitive=False,
            envvar="BENCH_SANDBOX",
        ),
    ] = None,
    epochs: Annotated[
        int,
        typer.Option(
            help="Number of epochs to run each evaluation", envvar="BENCH_EPOCHS"
        ),
    ] = 1,
    limit: Annotated[
        Optional[str],
        typer.Option(
            help="Limit evaluated samples (single number or start,end)",
            envvar="BENCH_LIMIT",
        ),
    ] = None,
    fail_on_error: Annotated[
        Optional[float],
        typer.Option(
            help="Failure threshold for sample errors. If between 0 and 1, it is interpreted as a percentage of samples that can fail. If greater than 1, it is interpreted as a fixed number of samples that can fail",
            envvar="BENCH_FAIL_ON_ERROR",
        ),
    ] = None,
    message_limit: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum number of messages one sample can run",
            envvar="BENCH_MESSAGE_LIMIT",
        ),
    ] = None,
    max_subprocesses: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum number of parallel subprocesses",
            envvar="BENCH_MAX_SUBPROCESSES",
        ),
    ] = None,
    log_samples: Annotated[
        Optional[bool],
        typer.Option(
            help="Log detailed samples and scores",
            envvar="BENCH_LOG_SAMPLES",
        ),
    ] = None,
    log_images: Annotated[
        Optional[bool],
        typer.Option(
            help="Log base64 encoded images",
            envvar="BENCH_LOG_IMAGES",
        ),
    ] = None,
    log_buffer: Annotated[
        Optional[int],
        typer.Option(
            help="Number of samples to buffer before writing to log",
            envvar="BENCH_LOG_BUFFER",
        ),
    ] = 10,
    score: Annotated[
        bool,
        typer.Option(
            help="Grade the benchmark, or leave unscored",
            envvar="BENCH_SCORE",
        ),
    ] = True,
    temperature: Annotated[
        Optional[float],
        typer.Option(
            help="Model temperature",
            envvar="BENCH_TEMPERATURE",
        ),
    ] = None,
    top_p: Annotated[
        Optional[float],
        typer.Option(
            help="Model top-p",
            envvar="BENCH_TOP_P",
        ),
    ] = None,
    max_tokens: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum tokens for model response",
            envvar="BENCH_MAX_TOKENS",
        ),
    ] = None,
    seed: Annotated[
        Optional[int],
        typer.Option(
            help="Seed for deterministic generation",
            envvar="BENCH_SEED",
        ),
    ] = None,
    display: Annotated[
        Optional[DisplayType],
        typer.Option(
            help="Display type for evaluation progress",
            envvar="BENCH_DISPLAY",
            case_sensitive=False,
        ),
    ] = None,
    timeout: Annotated[
        Optional[int],
        typer.Option(
            help="Timeout for each request to the model API in seconds",
            envvar="BENCH_TIMEOUT",
        ),
    ] = 10000,
    reasoning_effort: Annotated[
        Optional[ReasoningEffortLevel],
        typer.Option(
            help="Reasoning effort level. used for reasoning models like openai/o3",
            envvar="BENCH_REASONING_EFFORT",
            case_sensitive=False,
        ),
    ] = None,
    debug: Annotated[
        bool,
        typer.Option(
            "--debug",
            help="Enable debug mode with full stack traces",
            envvar="BENCH_DEBUG",
        ),
    ] = False,
    hub_repo: Annotated[
        Optional[str],
        typer.Option(
            help=(
                "Target Hub dataset repo (e.g. username/openbench-logs). "
                "If provided, logs will be exported to this dataset"
            ),
            envvar="BENCH_HUB_REPO",
        ),
    ] = None,
    hub_private: Annotated[
        Optional[bool],
        typer.Option(
            help="Create/update the Hub dataset as private",
            envvar="BENCH_HUB_PRIVATE",
        ),
    ] = False,
    alpha: Annotated[
        bool,
        typer.Option(
            "--alpha",
            help="Allow running experimental/alpha benchmarks",
            envvar="BENCH_ALPHA",
        ),
    ] = False,
) -> None:
    """
    Run a benchmark on a model.
    """
    # Parse model and task arguments
    model_args = parse_cli_args(m) if m else {}
    task_args = parse_cli_args(t) if t else {}

    # Validate and aggregate model_role(s) into a dict
    role_models = {}
    for mr in model_role:
        parsed = validate_model_role(mr)
        for k, v in parsed.items():
            if k in role_models:
                raise typer.BadParameter(f"Duplicate model role: {k}")
            role_models[k] = v

    # Check for mutual exclusivity between --model and --model-role candidate
    if model and "candidate" in role_models:
        raise typer.BadParameter(
            "Cannot specify both --model and --model-role candidate=<model>"
        )

    # Validate model names
    for model_name in model:
        validate_model_name(model_name)

    # Load tasks from registry
    tasks = []
    for benchmark in benchmarks:
        try:
            task = load_task(benchmark, allow_alpha=alpha)
            tasks.append(task)
        except (ValueError, ImportError, AttributeError) as e:
            raise typer.BadParameter(str(e))

    # Monkey patch FileRecorder log file name if logfile is provided
    if logfile:
        from openbench.monkeypatch.file_recorder_logfile_patch import (
            patch_file_recorder_logfile,
        )

        patch_file_recorder_logfile(logfile)

    # Parse limit string to int or tuple
    parsed_limit = parse_limit(limit)

    # Apply display patch
    patch_display_results()

    # Capture start time to locate logs created by this run
    start_time = time.time()

    try:
        eval(
            tasks=tasks,
            model=model,
            max_connections=max_connections,
            model_base_url=model_base_url,
            model_args=model_args,
            model_roles=role_models if role_models else None,
            task_args=task_args,
            epochs=epochs,
            limit=parsed_limit,
            fail_on_error=fail_on_error,
            message_limit=message_limit,
            max_subprocesses=max_subprocesses,
            log_samples=log_samples,
            log_images=log_images,
            log_buffer=log_buffer,
            score=score,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed,
            display=display.value if display else None,
            timeout=timeout,
            reasoning_effort=reasoning_effort.value if reasoning_effort else None,
            sandbox=sandbox,
        )

        typer.echo("Evaluation complete!")

        if hub_repo:
            from openbench._cli.export import export_logs_to_hub

            export_logs_to_hub(
                logfile=logfile,
                start_time=start_time,
                hub_repo=hub_repo,
                hub_private=hub_private,
            )
    except Exception as e:
        if debug:
            # In debug mode, let the full stack trace show
            raise
        else:
            # In normal mode, show clean error message
            error_msg = str(e)
            typer.secho(f"\n❌ Error: {error_msg}", fg=typer.colors.RED, err=True)
            typer.secho(
                "\nFor full stack trace, run with --debug flag",
                fg=typer.colors.CYAN,
                err=True,
            )
            sys.exit(1)
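
A few representative invocations, intended as a sketch only (model names are illustrative, and the grader role is used only by benchmarks that grade with a judge model):

bench eval mmlu --limit 25 --temperature 0
bench eval humaneval --model openai/gpt-4o --epochs 2 --logfile humaneval.eval
bench eval simpleqa --model-role grader=groq/meta-llama/llama-4-scout-17b-16e-instruct
bench eval graphwalks --alpha   # experimental/alpha benchmarks require --alpha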

src/openbench/_cli/eval_retry_command.py (7.1 KiB)

from typing import List, Optional, Annotated
import sys
import typer
from inspect_ai import eval_retry
from inspect_ai.log._file import log_file_info
from inspect_ai._util.file import filesystem
from openbench.monkeypatch.display_results_patch import patch_display_results


def run_eval_retry(
    log_files: Annotated[
        List[str],
        typer.Argument(help="Log file(s) to retry failed evaluations from"),
    ],
    max_connections: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum number of parallel requests to the model",
            envvar="BENCH_MAX_CONNECTIONS",
        ),
    ] = None,
    max_subprocesses: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum number of parallel subprocesses",
            envvar="BENCH_MAX_SUBPROCESSES",
        ),
    ] = None,
    fail_on_error: Annotated[
        Optional[float],
        typer.Option(
            help="Failure threshold for sample errors. If between 0 and 1, it is interpreted as a percentage of samples that can fail. If greater than 1, it is interpreted as a fixed number of samples that can fail",
            envvar="BENCH_FAIL_ON_ERROR",
        ),
    ] = None,
    log_samples: Annotated[
        Optional[bool],
        typer.Option(
            help="Log detailed samples and scores",
            envvar="BENCH_LOG_SAMPLES",
        ),
    ] = None,
    log_images: Annotated[
        Optional[bool],
        typer.Option(
            help="Log base64 encoded images",
            envvar="BENCH_LOG_IMAGES",
        ),
    ] = None,
    log_buffer: Annotated[
        Optional[int],
        typer.Option(
            help="Number of samples to buffer before writing to log",
            envvar="BENCH_LOG_BUFFER",
        ),
    ] = 10,
    score: Annotated[
        bool,
        typer.Option(
            help="Grade the benchmark, or leave unscored",
            envvar="BENCH_SCORE",
        ),
    ] = True,
    timeout: Annotated[
        Optional[int],
        typer.Option(
            help="Timeout for each request to the model API in seconds",
            envvar="BENCH_TIMEOUT",
        ),
    ] = None,
    max_retries: Annotated[
        Optional[int],
        typer.Option(
            help="Maximum number of times to retry model API requests (defaults to unlimited)",
            envvar="BENCH_MAX_RETRIES",
        ),
    ] = None,
    retry_on_error: Annotated[
        Optional[int],
        typer.Option(
            help="Retry samples if they encounter errors (by default, no retries occur). Specify --retry-on-error to retry a single time, or specify e.g. --retry-on-error=3 to retry multiple times.",
            envvar="BENCH_RETRY_ON_ERROR",
        ),
    ] = None,
    no_fail_on_error: Annotated[
        bool,
        typer.Option(
            "--no-fail-on-error",
            help="Do not fail the eval if errors occur within samples (instead, continue running other samples)",
            envvar="BENCH_NO_FAIL_ON_ERROR",
        ),
    ] = False,
    no_log_samples: Annotated[
        bool,
        typer.Option(
            "--no-log-samples",
            help="Do not include samples in the log file",
            envvar="BENCH_NO_LOG_SAMPLES",
        ),
    ] = False,
    no_log_images: Annotated[
        bool,
        typer.Option(
            "--no-log-images",
            help="Do not include base64 encoded images in the log file",
            envvar="BENCH_NO_LOG_IMAGES",
        ),
    ] = False,
    no_score: Annotated[
        bool,
        typer.Option(
            "--no-score",
            help="Do not score model output (use the inspect score command to score output later)",
            envvar="BENCH_NO_SCORE",
        ),
    ] = False,
    sandbox_cleanup: Annotated[
        Optional[bool],
        typer.Option(
            help="Cleanup sandbox environments after task completes",
            envvar="BENCH_SANDBOX_CLEANUP",
        ),
    ] = None,
    no_sandbox_cleanup: Annotated[
        bool,
        typer.Option(
            "--no-sandbox-cleanup",
            help="Do not cleanup sandbox environments after task completes",
            envvar="BENCH_NO_SANDBOX_CLEANUP",
        ),
    ] = False,
    trace: Annotated[
        bool,
        typer.Option(
            "--trace",
            help="Trace message interactions with evaluated model to terminal",
            envvar="BENCH_TRACE",
        ),
    ] = False,
    log_dir: Annotated[
        str,
        typer.Option(
            help="Directory for log files",
            envvar="BENCH_LOG_DIR",
        ),
    ] = "./logs",
    debug_errors: Annotated[
        bool,
        typer.Option(
            "--debug-errors",
            help="Enable debug mode for errors",
            envvar="BENCH_DEBUG_ERRORS",
        ),
    ] = False,
    debug: Annotated[
        bool,
        typer.Option(
            "--debug",
            help="Enable debug mode with full stack traces",
            envvar="BENCH_DEBUG",
        ),
    ] = False,
) -> None:
    """Retry failed evaluation(s) from log files."""

    # Process negating options
    if no_log_samples:
        log_samples = False
    if no_log_images:
        log_images = False
    if no_score:
        score = False
    if no_sandbox_cleanup:
        sandbox_cleanup = False

    # Process fail_on_error
    if no_fail_on_error:
        fail_on_error = False
    elif fail_on_error == 0.0:
        fail_on_error = True

    # Process retry_on_error
    if retry_on_error == 0:
        retry_on_error = None

    # Resolve log files
    retry_log_files = [
        log_file_info(filesystem(log_file).info(log_file)) for log_file in log_files
    ]

    # Set defaults
    log_level = "info"
    log_level_transcript = "info"

    # Apply display patch
    patch_display_results()

    try:
        # Retry
        eval_retry(
            retry_log_files,
            log_level=log_level,
            log_level_transcript=log_level_transcript,
            log_dir=log_dir,
            max_connections=max_connections,
            max_subprocesses=max_subprocesses,
            fail_on_error=fail_on_error,
            retry_on_error=retry_on_error,
            debug_errors=debug_errors,
            log_samples=log_samples,
            log_images=log_images,
            log_buffer=log_buffer,
            score=score,
            timeout=timeout,
            max_retries=max_retries,
            sandbox_cleanup=sandbox_cleanup,
            trace=trace,
            # These are additional retry-specific parameters
            max_samples=None,
            max_tasks=None,
            max_sandboxes=None,
            log_shared=None,
            score_display=None,
        )

        typer.echo("Retry evaluation complete!")
    except Exception as e:
        if debug:
            # In debug mode, let the full stack trace show
            raise
        else:
            # In normal mode, show clean error message
            error_msg = str(e)
            typer.secho(f"\n❌ Error: {error_msg}", fg=typer.colors.RED, err=True)
            typer.secho(
                "\nFor full stack trace, run with --debug flag",
                fg=typer.colors.CYAN,
                err=True,
            )
            sys.exit(1)
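
A usage sketch; the log path is a placeholder for a real Inspect log produced by an earlier bench eval run:

bench eval-retry logs/my-run.eval --retry-on-error 3 --max-connections 20
bench eval-retry logs/my-run.eval --no-score   # retry samples without re-grading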

src/openbench/_cli/export.py (6.5 KiB)

from __future__ import annotations

from typing import Optional, List, Dict, Any
import os
import json
import typer
from datasets import Dataset  # type: ignore[import-untyped]


def _read_log_json(path: str) -> Dict[str, Any]:
    """Read an Inspect log file regardless of .eval or .json format.

    Uses `inspect log dump` for .eval, else reads JSON directly.
    See Inspect docs:
    https://inspect.aisi.org.uk/eval-logs.html
    """
    if path.endswith(".eval"):
        import subprocess  # deferred import; only needed when dumping .eval files

        proc = subprocess.run(
            ["inspect", "log", "dump", path],
            check=True,
            capture_output=True,
            text=True,
        )
        return json.loads(proc.stdout)

    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _collect_log_files(logfile: Optional[str], start_time: float) -> List[str]:
    """Collect log files created by this run.

    Preference:
    - if `logfile` provided and exists, return it
    - else scan INSPECT_LOG_DIR or ./logs for files
      with mtime >= start_time
    """
    candidates: List[str] = []

    if logfile and os.path.exists(logfile):
        return [os.path.abspath(logfile)]

    log_dir = os.getenv("INSPECT_LOG_DIR") or os.path.join(os.getcwd(), "logs")
    if not os.path.isdir(log_dir):
        return candidates

    for name in os.listdir(log_dir):
        if not (name.endswith(".eval") or name.endswith(".json")):
            continue
        path = os.path.join(log_dir, name)
        try:
            mtime = os.path.getmtime(path)
        except OSError:
            continue
        recent = mtime >= (start_time - 1.0)
        if recent:
            candidates.append(os.path.abspath(path))

    candidates.sort(key=os.path.getmtime, reverse=True)
    return candidates


def _flatten_results(data: Dict[str, Any], base: Dict[str, Any]) -> Dict[str, Any]:
    results = data.get("results", {})
    out: Dict[str, Any] = {**base}
    out["total_samples"] = results.get("total_samples")
    out["completed_samples"] = results.get("completed_samples")
    scores = results.get("scores", [])
    if scores:
        metrics = scores[0].get("metrics", {})
        for metric_name, metric in metrics.items():
            out[metric_name] = metric.get("value")
    return out


def _flatten_stats(data: Dict[str, Any], base: Dict[str, Any]) -> List[Dict[str, Any]]:
    stats = data.get("stats", {})
    started_at = stats.get("started_at")
    completed_at = stats.get("completed_at")
    model_usage = stats.get("model_usage", {})
    rows: List[Dict[str, Any]] = []
    if isinstance(model_usage, dict):
        for model_name, usage in model_usage.items():
            row = {**base, "started_at": started_at}
            row["completed_at"] = completed_at
            row["usage_model"] = model_name
            row["input_tokens"] = usage.get("input_tokens")
            row["output_tokens"] = usage.get("output_tokens")
            # split for linter line length
            row["total_tokens"] = usage.get("total_tokens")
            rows.append(row)
    else:
        short = {
            **base,
            "started_at": started_at,
            "completed_at": completed_at,
        }
        rows.append(short)
    return rows


def _flatten_samples(
    data: Dict[str, Any], base: Dict[str, Any]
) -> List[Dict[str, Any]]:
    samples = data.get("samples", [])
    rows: List[Dict[str, Any]] = []
    for s in samples:
        row: Dict[str, Any] = {
            **base,
            "sample_id": s.get("id"),
            "epoch": s.get("epoch"),
            "target": s.get("target"),
            "messages": json.dumps(s.get("messages", [])),
        }
        if isinstance(s.get("metadata"), dict):
            for k, v in s["metadata"].items():
                row[f"meta_{k}"] = v
        if isinstance(s.get("scores"), dict):
            for scorer_name, score in s["scores"].items():
                row[f"score_{scorer_name}_value"] = score.get("value")
                row[f"score_{scorer_name}_answer"] = score.get("answer")
        rows.append(row)
    return rows


def export_logs_to_hub(
    *,
    logfile: Optional[str],
    start_time: float,
    hub_repo: str,
    hub_private: Optional[bool],
) -> None:
    files = _collect_log_files(logfile=logfile, start_time=start_time)
    if not files:
        msg = "No eval logs found to export (looked in INSPECT_LOG_DIR or ./logs)"
        typer.secho(msg, fg=typer.colors.YELLOW)
        return

    msg = f"Exporting {len(files)} eval logs to {hub_repo}"
    typer.secho(msg, fg=typer.colors.YELLOW)

    results_rows: List[Dict[str, Any]] = []
    stats_rows: List[Dict[str, Any]] = []
    samples_rows: List[Dict[str, Any]] = []

    for path in files:
        try:
            data = _read_log_json(path)
        except Exception as e:  # pragma: no cover
            msg = f"Skipping log '{path}': {e}"
            typer.secho(msg, fg=typer.colors.YELLOW)
            continue

        eval_info = data.get("eval", {})
        base = {
            "log_path": path,
            "eval_id": eval_info.get("eval_id"),
            "run_id": eval_info.get("run_id"),
            "created": eval_info.get("created"),
            "task": eval_info.get("task"),
            "task_id": eval_info.get("task_id"),
            "model": eval_info.get("model"),
        }

        results_rows.append(_flatten_results(data, base))
        stats_rows.extend(_flatten_stats(data, base))
        samples_rows.extend(_flatten_samples(data, base))

    if results_rows:
        ds = Dataset.from_list(results_rows)
        ds.push_to_hub(
            repo_id=hub_repo,
            config_name="results",
            split="train",
            private=hub_private,
        )
        msg = (
            f"Pushed results ({len(results_rows)} rows) to {hub_repo} [config=results]"
        )
        typer.echo(msg)

    if stats_rows:
        ds = Dataset.from_list(stats_rows)
        ds.push_to_hub(
            repo_id=hub_repo,
            config_name="stats",
            split="train",
            private=hub_private,
        )
        msg = f"Pushed stats ({len(stats_rows)} rows) to {hub_repo} [config=stats]"
        typer.echo(msg)

    if samples_rows:
        ds = Dataset.from_list(samples_rows)
        ds.push_to_hub(
            repo_id=hub_repo,
            config_name="samples",
            split="train",
            private=hub_private,
        )
        msg = (
            f"Pushed samples ({len(samples_rows)} rows) to {hub_repo} [config=samples]"
        )
        typer.echo(msg)
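
Once exported, the three configs pushed above (results, stats, samples) can be read back with the datasets library. A minimal sketch, where username/openbench-logs is a placeholder repository id:

from datasets import load_dataset

# each config was pushed with split="train" by export_logs_to_hub
results = load_dataset("username/openbench-logs", "results", split="train")
samples = load_dataset("username/openbench-logs", "samples", split="train")
print(results.column_names)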

src/openbench/_cli/list_command.py (4.7 KiB)

"""
List command for benchmark evaluations.
"""

import typer
from typing import Optional, Dict, List, Any
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.text import Text
from rich.box import ROUNDED

from openbench.config import (
    get_all_benchmarks,
    get_benchmarks_by_category,
    get_categories,
)
from openbench._cli.utils import (
    get_category_display_name,
    benchmark_to_eval_config,
    matches_search,
)


def list_evals(
    category: Optional[str] = typer.Option(
        None, "--category", "-c", help="Filter by category (core, math)"
    ),
    search: Optional[str] = typer.Option(
        None, "--search", "-s", help="Search evaluations by name, description, or tags"
    ),
    tags: bool = typer.Option(
        False, "--tags", "-t", help="Show tags for each benchmark"
    ),
    alpha: bool = typer.Option(
        False, "--alpha", help="Include experimental/alpha benchmarks"
    ),
) -> None:
    """List available benchmark evaluations with enhanced UI."""
    console = Console()

    # Get evaluations based on filters
    if category:
        if category not in get_categories():
            console.print(f"\n❌ [red]Unknown category: {category}[/red]")
            console.print(f"   Available: {', '.join(sorted(get_categories()))}\n")
            return
        benchmarks = get_benchmarks_by_category(category, include_alpha=alpha)
        evals = [benchmark_to_eval_config(meta) for meta in benchmarks.values()]
    else:
        all_benchmarks = get_all_benchmarks(include_alpha=alpha)
        evals = [benchmark_to_eval_config(meta) for meta in all_benchmarks.values()]

    # Apply search filter
    if search:
        evals = [e for e in evals if matches_search(e, search)]

    if not evals:
        console.print("\n💭 [yellow]No evaluations match your criteria.[/yellow]\n")
        return

    # Group by category
    categories: Dict[str, List[Any]] = {}
    for eval_config in evals:
        if eval_config.category not in categories:
            categories[eval_config.category] = []
        categories[eval_config.category].append(eval_config)

    # Header
    console.print()
    if search:
        header = Text(f"Search Results for '{search}'", style="bold blue")
    else:
        header = Text("Available Benchmarks", style="bold blue")
    console.print(Panel(header, expand=False, box=ROUNDED))
    console.print()

    # Display each category
    for cat_name in sorted(categories.keys()):
        display_name = get_category_display_name(cat_name)

        # Category header with count
        cat_count = len(categories[cat_name])
        console.print(
            f"[bold green]{display_name}[/bold green] [dim]({cat_count})[/dim]"
        )
        console.print("─" * 60)

        # Get task names for this category
        cat_evals_with_keys = [
            (k, v)
            for k, v in get_all_benchmarks(include_alpha=alpha).items()
            if v.name in [e.name for e in categories[cat_name]]
        ]
        cat_evals_with_keys = sorted(cat_evals_with_keys, key=lambda x: x[0])

        # Create table for this category
        table = Table(show_header=False, show_lines=False, padding=(0, 1), box=None)
        table.add_column("Key", style="cyan", width=18)
        table.add_column("Name", style="white", width=20)
        table.add_column("Description", style="dim")

        for task_key, benchmark_meta in cat_evals_with_keys:
            # Find the corresponding eval config
            eval_config = next(
                e for e in categories[cat_name] if e.name == benchmark_meta.name
            )

            # Format description
            desc = eval_config.description
            if len(desc) > 60:
                desc = desc[:57] + "..."

            # Add tags if requested
            if tags and eval_config.tags:
                tag_str = (
                    " [dim blue]"
                    + " · ".join(f"#{tag}" for tag in eval_config.tags[:3])
                    + "[/dim blue]"
                )
                desc += tag_str

            table.add_row(f"[bold cyan]{task_key}[/bold cyan]", eval_config.name, desc)

        console.print(table)
        console.print()

    # Footer with stats and help
    total_count = len(evals)
    console.print("─" * 60)
    status_msg = f"[dim]Total: {total_count} benchmark{'s' if total_count != 1 else ''}"
    if not alpha:
        status_msg += " (use --alpha to see experimental benchmarks)"
    status_msg += "[/dim]"
    console.print(status_msg)
    console.print()
    console.print("[dim]Commands:[/dim]")
    console.print("   bench describe <name> - Show detailed information")
    console.print("   bench eval <name>     - Run evaluation")
    console.print()
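
Example invocations; categories and tags come from the registry in src/openbench/config.py:

bench list                      # all non-alpha benchmarks, grouped by category
bench list --category math      # only the math competition benchmarks
bench list -s retrieval --tags  # filter by search term and show tags
bench list --alpha              # include experimental/alpha benchmarks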

src/openbench/_cli/utils.py (2.3 KiB)

"""
Utility functions shared across CLI commands.
"""

from types import SimpleNamespace
from typing import Any, Dict, List, Optional
import yaml
from openbench.config import BenchmarkMetadata


def get_category_display_name(category: str) -> str:
    """Generate display name for category."""
    display = category.replace("_", " ").replace("-", " ").title()
    if "benchmark" not in display.lower():
        display += " Benchmarks"
    return display


def benchmark_to_eval_config(meta: BenchmarkMetadata):
    """Convert BenchmarkMetadata to a simple eval config for display."""
    return SimpleNamespace(
        name=meta.name,
        description=meta.description,
        category=meta.category,
        tags=meta.tags,
    )


def matches_search(eval_config, query: str) -> bool:
    """Check if an evaluation matches the search query."""
    query = query.lower()
    return (
        query in eval_config.name.lower()
        or query in eval_config.description.lower()
        or any(query in tag.lower() for tag in eval_config.tags)
    )


def parse_cli_args(
    args: Optional[List[str]], force_str: bool = False
) -> Dict[str, Any]:
    """Parse CLI arguments in the format key=value.

    Args:
        args: List of arguments in the format key=value
        force_str: Force all values to be strings

    Returns:
        Dictionary of parsed arguments
    """
    params: Dict[str, Any] = {}
    if args:
        for arg in args:
            parts = arg.split("=", 1)
            if len(parts) == 2:
                key = parts[0].replace("-", "_")
                try:
                    # Try to parse with yaml for proper type conversion
                    value = yaml.safe_load(parts[1])
                    # Handle comma-separated values as lists
                    if isinstance(value, str) and "," in value:
                        value = value.split(",")
                        value = value if len(value) > 1 else value[0]
                except yaml.YAMLError:
                    # If parsing fails, treat as string
                    value = parts[1]
                params[key] = str(value) if force_str else value
            else:
                # If no '=' found, this is an invalid argument format
                raise ValueError(f"Invalid argument format: {arg}. Expected key=value")
    return params
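
A quick sketch of how parse_cli_args normalizes -T/-M style arguments: values are YAML-parsed, dashes in keys become underscores, and comma-separated scalars become lists.

from openbench._cli.utils import parse_cli_args

print(parse_cli_args(["temperature=0.5", "max-tokens=128", "stop=a,b"]))
# {'temperature': 0.5, 'max_tokens': 128, 'stop': ['a', 'b']}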

src/openbench/_cli/view_command.py (1.5 KiB)

import subprocess
import sys
from typing import Optional
from pathlib import Path
import typer


def run_view(
    log_dir: Optional[Path] = typer.Option(
        None,
        "--log-dir",
        help="Log directory to view (defaults to ./logs)",
        envvar="INSPECT_LOG_DIR",
    ),
    recursive: bool = typer.Option(
        True,
        "--recursive/--no-recursive",
        help="Include all logs in log_dir recursively",
    ),
    host: Optional[str] = typer.Option(
        None,
        "--host",
        help="TCP/IP host for server",
    ),
    port: Optional[int] = typer.Option(
        None,
        "--port",
        help="TCP/IP port for server",
    ),
    log_level: Optional[str] = typer.Option(
        None,
        "--log-level",
        help="Set the log level",
    ),
) -> None:
    """
    View evaluation logs using inspect view.

    This is a wrapper around 'inspect view' that provides access to the log viewer.
    """
    cmd = ["inspect", "view"]

    # Add arguments if provided
    if log_dir:
        cmd.extend(["--log-dir", str(log_dir)])
    if not recursive:
        cmd.append("--no-recursive")
    if host:
        cmd.extend(["--host", host])
    if port:
        cmd.extend(["--port", str(port)])
    if log_level:
        cmd.extend(["--log-level", log_level])

    # Run the command
    try:
        subprocess.run(cmd, check=False)
    except KeyboardInterrupt:
        sys.exit(0)
    except Exception as e:
        typer.echo(f"Error running inspect view: {e}", err=True)
        sys.exit(1)
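
A usage sketch; the port value is illustrative, and omitted options fall back to the inspect view defaults:

bench view                                                  # view ./logs (or $INSPECT_LOG_DIR) recursively
bench view --log-dir logs/runs --port 7575 --no-recursive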

src/openbench/_registry.py (5.5 KiB)

"""
Registry for inspect_ai extensions (model providers and tasks).
This module is the entry point for inspect_ai to discover our extensions.
"""

from typing import Type
from inspect_ai.model import ModelAPI
from inspect_ai.model._registry import modelapi


# Model Provider Registration


@modelapi(name="huggingface")
def huggingface() -> Type[ModelAPI]:
    """Register Hugging Face Inference Providers router provider."""
    from .model._providers.huggingface import HFInferenceProvidersAPI

    return HFInferenceProvidersAPI


@modelapi(name="cerebras")
def cerebras() -> Type[ModelAPI]:
    """Register Cerebras provider."""
    from .model._providers.cerebras import CerebrasAPI

    return CerebrasAPI


@modelapi(name="sambanova")
def sambanova() -> Type[ModelAPI]:
    """Register SambaNova provider."""
    from .model._providers.sambanova import SambaNovaAPI

    return SambaNovaAPI


@modelapi(name="nebius")
def nebius() -> Type[ModelAPI]:
    """Register Nebius provider."""
    from .model._providers.nebius import NebiusAPI

    return NebiusAPI


@modelapi(name="nous")
def nous() -> Type[ModelAPI]:
    """Register Nous Research provider."""
    from .model._providers.nous import NousAPI

    return NousAPI


@modelapi(name="lambda")
def lambda_provider() -> Type[ModelAPI]:
    """Register Lambda provider."""
    from .model._providers.lambda_ai import LambdaAPI

    return LambdaAPI


@modelapi(name="baseten")
def baseten() -> Type[ModelAPI]:
    """Register Baseten provider."""
    from .model._providers.baseten import BasetenAPI

    return BasetenAPI


@modelapi(name="hyperbolic")
def hyperbolic() -> Type[ModelAPI]:
    """Register Hyperbolic provider."""
    from .model._providers.hyperbolic import HyperbolicAPI

    return HyperbolicAPI


@modelapi(name="novita")
def novita() -> Type[ModelAPI]:
    """Register Novita provider."""
    from .model._providers.novita import NovitaAPI

    return NovitaAPI


@modelapi(name="parasail")
def parasail() -> Type[ModelAPI]:
    """Register Parasail provider."""
    from .model._providers.parasail import ParasailAPI

    return ParasailAPI


@modelapi(name="crusoe")
def crusoe() -> Type[ModelAPI]:
    """Register Crusoe provider."""
    from .model._providers.crusoe import CrusoeAPI

    return CrusoeAPI


@modelapi(name="deepinfra")
def deepinfra() -> Type[ModelAPI]:
    """Register DeepInfra provider."""
    from .model._providers.deepinfra import DeepInfraAPI

    return DeepInfraAPI


@modelapi(name="ai21")
def ai21() -> Type[ModelAPI]:
    """Register AI21 Labs provider."""
    from .model._providers.ai21 import AI21API

    return AI21API


@modelapi(name="minimax")
def minimax() -> Type[ModelAPI]:
    """Register MiniMax provider."""
    from .model._providers.minimax import MiniMaxAPI

    return MiniMaxAPI


@modelapi(name="friendli")
def friendli() -> Type[ModelAPI]:
    """Register Friendli provider."""
    from .model._providers.friendli import FriendliAPI

    return FriendliAPI


@modelapi(name="reka")
def reka() -> Type[ModelAPI]:
    """Register Reka provider."""
    from .model._providers.reka import RekaAPI

    return RekaAPI


@modelapi(name="cohere")
def cohere() -> Type[ModelAPI]:
    """Register Cohere provider."""
    from .model._providers.cohere import CohereAPI

    return CohereAPI


@modelapi(name="moonshot")
def moonshot() -> Type[ModelAPI]:
    """Register Moonshot provider."""
    from .model._providers.moonshot import MoonshotAPI

    return MoonshotAPI


@modelapi(name="vercel")
def vercel() -> Type[ModelAPI]:
    """Register Vercel AI Gateway provider."""
    from .model._providers.vercel import VercelAPI

    return VercelAPI


# Task Registration

# Core benchmarks
from .evals.drop import drop  # noqa: F401, E402
from .evals.gpqa_diamond import gpqa_diamond  # noqa: F401, E402
from .evals.graphwalks import graphwalks  # noqa: F401, E402
from .evals.healthbench import healthbench, healthbench_hard, healthbench_consensus  # noqa: F401, E402
from .evals.hle import hle, hle_text  # noqa: F401, E402
from .evals.humaneval import humaneval  # noqa: F401, E402
from .evals.math import math, math_500  # noqa: F401, E402
from .evals.mgsm import mgsm, mgsm_en, mgsm_latin, mgsm_non_latin  # noqa: F401, E402
from .evals.mmlu import mmlu  # noqa: F401, E402
from .evals.mrcr import openai_mrcr, openai_mrcr_2n, openai_mrcr_4n, openai_mrcr_8n  # noqa: F401, E402
from .evals.musr import musr  # noqa: F401, E402
from .evals.openbookqa import openbookqa  # noqa: F401, E402
from .evals.simpleqa import simpleqa  # noqa: F401, E402
from .evals.supergpqa import supergpqa  # noqa: F401, E402

# MathArena benchmarks
from .evals.matharena.aime_2023_I.aime_2023_I import aime_2023_I  # noqa: F401, E402
from .evals.matharena.aime_2023_II.aime_2023_II import aime_2023_II  # noqa: F401, E402
from .evals.matharena.aime_2024_I.aime_2024_I import aime_2024_I  # noqa: F401, E402
from .evals.matharena.aime_2024_II.aime_2024_II import aime_2024_II  # noqa: F401, E402
from .evals.matharena.aime_2024.aime_2024 import aime_2024  # noqa: F401, E402
from .evals.matharena.aime_2025.aime_2025 import aime_2025  # noqa: F401, E402
from .evals.matharena.aime_2025_II.aime_2025_II import aime_2025_II  # noqa: F401, E402
from .evals.matharena.brumo_2025.brumo_2025 import brumo_2025  # noqa: F401, E402
from .evals.matharena.hmmt_feb_2023.hmmt_feb_2023 import hmmt_feb_2023  # noqa: F401, E402
from .evals.matharena.hmmt_feb_2024.hmmt_feb_2024 import hmmt_feb_2024  # noqa: F401, E402
from .evals.matharena.hmmt_feb_2025.hmmt_feb_2025 import hmmt_feb_2025  # noqa: F401, E402
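
For illustration, an additional provider would be registered in this module following the same pattern as above (reusing the modelapi, ModelAPI, and Type imports at the top of the file). The names acme and AcmeAPI are placeholders, not part of the repository:

@modelapi(name="acme")
def acme() -> Type[ModelAPI]:
    """Register the (hypothetical) Acme provider."""
    from .model._providers.acme import AcmeAPI  # placeholder module

    return AcmeAPI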

src/openbench/config.py (27.2 KiB)

"""
Minimal configuration for benchmarks.
Only contains human-written metadata that cannot be extracted from code.
Everything else (epochs, temperature, etc.) comes from the actual task definitions.
"""

from dataclasses import dataclass
from functools import lru_cache
import importlib
import importlib.util
import sys
import uuid
from pathlib import Path
from types import ModuleType
from typing import Callable, List, Optional


@dataclass
class BenchmarkMetadata:
    """Minimal metadata for a benchmark - only what can't be extracted."""

    name: str  # Human-readable display name
    description: str  # Human-written description
    category: str  # Category for grouping
    tags: List[str]  # Tags for searchability

    # Registry info - still needed
    module_path: str
    function_name: str

    # Alpha/experimental flag
    is_alpha: bool = False  # Whether this benchmark is experimental/alpha


# Benchmark metadata - minimal, no duplication
BENCHMARKS = {
    # Graphwalks benchmarks (alpha)
    "graphwalks": BenchmarkMetadata(
        name="GraphWalks",
        description="Multi-hop reasoning on graphs - both BFS and parent finding tasks",
        category="core",
        tags=["long-context", "graphs", "reasoning", "alpha"],
        module_path="openbench.evals.graphwalks",
        function_name="graphwalks",
        is_alpha=True,
    ),
    "graphwalks_bfs": BenchmarkMetadata(
        name="GraphWalks BFS",
        description="Multi-hop reasoning on graphs - BFS traversal tasks only",
        category="core",
        tags=["long-context", "graphs", "reasoning", "bfs", "alpha"],
        module_path="openbench.evals.graphwalks",
        function_name="graphwalks_bfs",
        is_alpha=True,
    ),
    "graphwalks_parents": BenchmarkMetadata(
        name="GraphWalks Parents",
        description="Multi-hop reasoning on graphs - parent finding tasks only",
        category="core",
        tags=["long-context", "graphs", "reasoning", "parents", "alpha"],
        module_path="openbench.evals.graphwalks",
        function_name="graphwalks_parents",
        is_alpha=True,
    ),
    # Core benchmarks
    "mmlu": BenchmarkMetadata(
        name="MMLU (cais/mmlu)",
        description="Massive Multitask Language Understanding - 57 academic subjects from the cais/mmlu dataset",
        category="core",
        tags=["multiple-choice", "knowledge", "reasoning", "multitask"],
        module_path="openbench.evals.mmlu",
        function_name="mmlu",
    ),
    "openai_mrcr": BenchmarkMetadata(
        name="OpenAI MRCR (Full)",
        description="Memory-Recall with Contextual Retrieval - long-context evaluation that measures recall of 2, 4, and 8 needles across million-token contexts",
        category="core",
        tags=["long-context", "retrieval", "needle", "sequence-matching"],
        module_path="openbench.evals.mrcr",
        function_name="openai_mrcr",
    ),
    "openai_mrcr_2n": BenchmarkMetadata(
        name="OpenAI MRCR (2 Needles)",
        description="Memory-Recall with Contextual Retrieval - long-context evaluation that measures recall of 2 needles across million-token contexts",
        category="core",
        tags=["long-context", "retrieval", "needle", "sequence-matching"],
        module_path="openbench.evals.mrcr",
        function_name="openai_mrcr_2n",
    ),
    "openai_mrcr_4n": BenchmarkMetadata(
        name="OpenAI MRCR (4 Needles)",
        description="Memory-Recall with Contextual Retrieval - long-context evaluation that measures recall of 4 needles across million-token contexts",
        category="core",
        tags=["long-context", "retrieval", "needle", "sequence-matching"],
        module_path="openbench.evals.mrcr",
        function_name="openai_mrcr_4n",
    ),
    "openai_mrcr_8n": BenchmarkMetadata(
        name="OpenAI MRCR (8 Needles)",
        description="Memory-Recall with Contextual Retrieval - long-context evaluation that measures recall of 8 needles across million-token contexts",
        category="core",
        tags=["long-context", "retrieval", "needle", "sequence-matching"],
        module_path="openbench.evals.mrcr",
        function_name="openai_mrcr_8n",
    ),
    "gpqa_diamond": BenchmarkMetadata(
        name="GPQA Diamond",
        description="Graduate-level Google-Proof Q&A in biology, chemistry, and physics",
        category="core",
        tags=["multiple-choice", "science", "graduate-level"],
        module_path="openbench.evals.gpqa_diamond",
        function_name="gpqa_diamond",
    ),
    "humaneval": BenchmarkMetadata(
        name="HumanEval",
        description="Code generation benchmark with 164 programming problems",
        category="core",
        tags=["coding", "generation", "execution"],
        module_path="openbench.evals.humaneval",
        function_name="humaneval",
    ),
    "openbookqa": BenchmarkMetadata(
        name="OpenBookQA",
        description="Elementary-level science questions probing understanding of core facts",
        category="core",
        tags=["multiple-choice", "science", "elementary", "open-book"],
        module_path="openbench.evals.openbookqa",
        function_name="openbookqa",
    ),
    "musr": BenchmarkMetadata(
        name="MuSR",
        description="Testing the Limits of Chain-of-thought with Multistep Soft Reasoning - includes murder mysteries, object placements, and team allocation tasks",
        category="core",
        tags=["multiple-choice", "reasoning", "commonsense", "chain-of-thought"],
        module_path="openbench.evals.musr",
        function_name="musr",
    ),
    "musr_murder_mysteries": BenchmarkMetadata(
        name="MuSR Murder Mysteries",
        description="MuSR murder mystery scenarios - who is the most likely murderer?",
        category="core",
        tags=[
            "multiple-choice",
            "reasoning",
            "commonsense",
            "chain-of-thought",
            "murder-mysteries",
        ],
        module_path="openbench.evals.musr",
        function_name="musr_murder_mysteries",
    ),
    "musr_object_placements": BenchmarkMetadata(
        name="MuSR Object Placements",
        description="MuSR object placement reasoning - where would someone look for an object?",
        category="core",
        tags=[
            "multiple-choice",
            "reasoning",
            "commonsense",
            "chain-of-thought",
            "object-placements",
        ],
        module_path="openbench.evals.musr",
        function_name="musr_object_placements",
    ),
    "musr_team_allocation": BenchmarkMetadata(
        name="MuSR Team Allocation",
        description="MuSR team allocation problems - how to allocate people to tasks efficiently?",
        category="core",
        tags=[
            "multiple-choice",
            "reasoning",
            "commonsense",
            "chain-of-thought",
            "team-allocation",
        ],
        module_path="openbench.evals.musr",
        function_name="musr_team_allocation",
    ),
    "supergpqa": BenchmarkMetadata(
        name="SuperGPQA",
        description="Scaling LLM Evaluation across 285 Graduate Disciplines - 26,529 multiple-choice questions across science, engineering, medicine, economics, and philosophy",
        category="core",
        tags=["multiple-choice", "knowledge", "graduate-level", "multidisciplinary"],
        module_path="openbench.evals.supergpqa",
        function_name="supergpqa",
    ),
    "simpleqa": BenchmarkMetadata(
        name="SimpleQA",
        description="Measuring short-form factuality in large language models with simple Q&A pairs",
        category="core",
        tags=["factuality", "question-answering", "graded"],
        module_path="openbench.evals.simpleqa",
        function_name="simpleqa",
    ),
    "browsecomp": BenchmarkMetadata(
        name="BrowseComp",
        description="A Simple Yet Challenging Benchmark for Browsing Agents - evaluates model performance on browsing-related tasks",
        category="core",
        tags=["browsing", "web", "reasoning", "graded"],
        module_path="openbench.evals.browsecomp",
        function_name="browsecomp",
    ),
    "hle": BenchmarkMetadata(
        name="Humanity's Last Exam",
        description="Multi-modal benchmark at the frontier of human knowledge - 2,500 questions across mathematics, humanities, and natural sciences designed by subject-matter experts globally",
        category="core",
        tags=["knowledge", "reasoning", "multi-modal", "graded", "frontier"],
        module_path="openbench.evals.hle",
        function_name="hle",
    ),
    "hle_text": BenchmarkMetadata(
        name="Humanity's Last Exam (Text-Only)",
        description="Text-only variant of HLE with multi-modal questions filtered out - evaluates models without vision capabilities on text-based questions from the frontier of human knowledge",
        category="core",
        tags=["knowledge", "reasoning", "text-only", "graded", "frontier"],
        module_path="openbench.evals.hle",
        function_name="hle_text",
    ),
    "healthbench": BenchmarkMetadata(
        name="HealthBench",
        description="Medical dialogue evaluation using physician-created rubrics for assessing healthcare conversations",
        category="core",
        tags=["medical", "dialogue", "graded", "rubric-based"],
        module_path="openbench.evals.healthbench",
        function_name="healthbench",
    ),
    "healthbench_hard": BenchmarkMetadata(
        name="HealthBench Hard",
        description="Most challenging medical dialogue cases from HealthBench requiring nuanced medical knowledge",
        category="core",
        tags=["medical", "dialogue", "graded", "rubric-based", "hard"],
        module_path="openbench.evals.healthbench",
        function_name="healthbench_hard",
    ),
    "healthbench_consensus": BenchmarkMetadata(
        name="HealthBench Consensus",
        description="Medical dialogue cases with strong physician consensus on appropriate responses",
        category="core",
        tags=["medical", "dialogue", "graded", "rubric-based", "consensus"],
        module_path="openbench.evals.healthbench",
        function_name="healthbench_consensus",
    ),
    "mgsm": BenchmarkMetadata(
        name="MGSM",
        description="Multilingual Grade School Math benchmark across 11 languages for testing mathematical reasoning",
        category="core",
        tags=["math", "multilingual", "reasoning", "chain-of-thought"],
        module_path="openbench.evals.mgsm",
        function_name="mgsm",
    ),
    "mgsm_en": BenchmarkMetadata(
        name="MGSM English",
        description="Grade school math problems in English for testing mathematical reasoning",
        category="core",
        tags=["math", "english", "reasoning", "chain-of-thought"],
        module_path="openbench.evals.mgsm",
        function_name="mgsm_en",
    ),
    "mgsm_latin": BenchmarkMetadata(
        name="MGSM Latin Script",
        description="Grade school math problems in Latin script languages (German, English, Spanish, French, Swahili)",
        category="core",
        tags=["math", "multilingual", "latin-script", "reasoning", "chain-of-thought"],
        module_path="openbench.evals.mgsm",
        function_name="mgsm_latin",
    ),
    "mgsm_non_latin": BenchmarkMetadata(
        name="MGSM Non-Latin Script",
        description="Grade school math problems in non-Latin script languages (Bengali, Japanese, Russian, Telugu, Thai, Chinese)",
        category="core",
        tags=[
            "math",
            "multilingual",
            "non-latin-script",
            "reasoning",
            "chain-of-thought",
        ],
        module_path="openbench.evals.mgsm",
        function_name="mgsm_non_latin",
    ),
    "drop": BenchmarkMetadata(
        name="DROP",
        description="Reading comprehension benchmark requiring discrete reasoning over paragraphs (arithmetic, counting, sorting)",
        category="core",
        tags=[
            "reading-comprehension",
            "reasoning",
            "arithmetic",
            "counting",
            "sorting",
        ],
        module_path="openbench.evals.drop",
        function_name="drop",
    ),
    "math": BenchmarkMetadata(
        name="MATH",
        description="Measuring Mathematical Problem Solving - 5000 competition math problems across 7 subjects and 5 difficulty levels",
        category="core",
        tags=["math", "problem-solving", "reasoning", "competition", "graded"],
        module_path="openbench.evals.math",
        function_name="math",
    ),
    "math_500": BenchmarkMetadata(
        name="MATH-500",
        description="500-problem subset of MATH dataset for faster evaluation of mathematical problem solving",
        category="core",
        tags=[
            "math",
            "problem-solving",
            "reasoning",
            "competition",
            "graded",
            "subset",
        ],
        module_path="openbench.evals.math",
        function_name="math_500",
    ),
    # Math competitions
    "aime_2023_I": BenchmarkMetadata(
        name="AIME 2023 I",
        description="American Invitational Mathematics Examination 2023 (First)",
        category="math",
        tags=["math", "competition", "aime", "2023"],
        module_path="openbench.evals.matharena.aime_2023_I.aime_2023_I",
        function_name="aime_2023_I",
    ),
    "aime_2023_II": BenchmarkMetadata(
        name="AIME 2023 II",
        description="American Invitational Mathematics Examination 2023 (Second)",
        category="math",
        tags=["math", "competition", "aime", "2023"],
        module_path="openbench.evals.matharena.aime_2023_II.aime_2023_II",
        function_name="aime_2023_II",
    ),
    "aime_2024": BenchmarkMetadata(
        name="AIME 2024",
        description="American Invitational Mathematics Examination 2024 (Combined I & II)",
        category="math",
        tags=["math", "competition", "aime", "2024", "combined"],
        module_path="openbench.evals.matharena.aime_2024.aime_2024",
        function_name="aime_2024",
    ),
    "aime_2024_I": BenchmarkMetadata(
        name="AIME 2024 I",
        description="American Invitational Mathematics Examination 2024 (First)",
        category="math",
        tags=["math", "competition", "aime", "2024"],
        module_path="openbench.evals.matharena.aime_2024_I.aime_2024_I",
        function_name="aime_2024_I",
    ),
    "aime_2024_II": BenchmarkMetadata(
        name="AIME 2024 II",
        description="American Invitational Mathematics Examination 2024 (Second)",
        category="math",
        tags=["math", "competition", "aime", "2024"],
        module_path="openbench.evals.matharena.aime_2024_II.aime_2024_II",
        function_name="aime_2024_II",
    ),
    "aime_2025": BenchmarkMetadata(
        name="AIME 2025",
        description="American Invitational Mathematics Examination 2025",
        category="math",
        tags=["math", "competition", "aime", "2025"],
        module_path="openbench.evals.matharena.aime_2025.aime_2025",
        function_name="aime_2025",
    ),
    "aime_2025_II": BenchmarkMetadata(
        name="AIME 2025 II",
        description="American Invitational Mathematics Examination 2025 (Second)",
        category="math",
        tags=["math", "competition", "aime", "2025"],
        module_path="openbench.evals.matharena.aime_2025_II.aime_2025_II",
        function_name="aime_2025_II",
    ),
    "brumo_2025": BenchmarkMetadata(
        name="BRUMO 2025",
        description="Bruno Mathematical Olympiad 2025",
        category="math",
        tags=["math", "competition", "olympiad", "2025"],
        module_path="openbench.evals.matharena.brumo_2025.brumo_2025",
        function_name="brumo_2025",
    ),
    "hmmt_feb_2023": BenchmarkMetadata(
        name="HMMT Feb 2023",
        description="Harvard-MIT Mathematics Tournament February 2023",
        category="math",
        tags=["math", "competition", "hmmt", "2023"],
        module_path="openbench.evals.matharena.hmmt_feb_2023.hmmt_feb_2023",
        function_name="hmmt_feb_2023",
    ),
    "hmmt_feb_2024": BenchmarkMetadata(
        name="HMMT Feb 2024",
        description="Harvard-MIT Mathematics Tournament February 2024",
        category="math",
        tags=["math", "competition", "hmmt", "2024"],
        module_path="openbench.evals.matharena.hmmt_feb_2024.hmmt_feb_2024",
        function_name="hmmt_feb_2024",
    ),
    "hmmt_feb_2025": BenchmarkMetadata(
        name="HMMT Feb 2025",
        description="Harvard-MIT Mathematics Tournament February 2025",
        category="math",
        tags=["math", "competition", "hmmt", "2025"],
        module_path="openbench.evals.matharena.hmmt_feb_2025.hmmt_feb_2025",
        function_name="hmmt_feb_2025",
    ),
    "scicode": BenchmarkMetadata(
        name="SciCode",
        description="Scientific computing and programming challenges",
        category="core",
        tags=["code-generation", "science", "alpha"],
        module_path="openbench.evals.scicode",
        function_name="scicode",
        is_alpha=True,
    ),
    "cti_bench": BenchmarkMetadata(
        name="CTI-Bench",
        description="Comprehensive evaluation framework for cyber threat intelligence understanding with 4 tasks: knowledge questions, vulnerability classification, CVSS scoring, and technique extraction",
        category="cybersecurity",
        tags=["cybersecurity", "multi-task"],
        module_path="openbench.evals.cti_bench",
        function_name="cti_bench",
    ),
    "cti_bench_ate": BenchmarkMetadata(
        name="CTI-Bench ATE",
        description="Extracting MITRE ATT&CK techniques from malware and threat descriptions",
        category="cybersecurity",
        tags=["extraction", "cybersecurity"],
        module_path="openbench.evals.cti_bench",
        function_name="cti_bench_ate",
    ),
    "cti_bench_mcq": BenchmarkMetadata(
        name="CTI-Bench MCQ",
        description="Multiple-choice questions evaluating understanding of CTI standards, threats, detection strategies, and best practices using authoritative sources like NIST and MITRE",
        category="cybersecurity",
        tags=["multiple-choice", "cybersecurity", "knowledge"],
        module_path="openbench.evals.cti_bench",
        function_name="cti_bench_mcq",
    ),
    "cti_bench_rcm": BenchmarkMetadata(
        name="CTI-Bench RCM",
        description="Mapping CVE descriptions to CWE categories to evaluate vulnerability classification ability",
        category="cybersecurity",
        tags=["classification", "cybersecurity"],
        module_path="openbench.evals.cti_bench",
        function_name="cti_bench_rcm",
    ),
    "cti_bench_vsp": BenchmarkMetadata(
        name="CTI-Bench VSP",
        description="Calculating CVSS scores from vulnerability descriptions to assess severity evaluation skills",
        category="cybersecurity",
        tags=["regression", "cybersecurity"],
        module_path="openbench.evals.cti_bench",
        function_name="cti_bench_vsp",
    ),
    "rootly_gmcq": BenchmarkMetadata(
        name="GMCQ",
        description="GitHub Multiple Choice Questions",
        category="core",
        tags=["code-understanding"],
        module_path="openbench.evals.rootly_gmcq",
        function_name="rootly_gmcq",
    ),
    "jsonschemabench": BenchmarkMetadata(
        name="JSONSchemaBench",
        description="JSON Schema generation benchmark with ~10K real-world schemas from GitHub, Kubernetes, and other sources for evaluating constrained decoding",
        category="core",
        tags=["json", "jsonschema", "generation", "constrained-decoding"],
        module_path="openbench.evals.jsonschemabench",
        function_name="jsonschemabench",
    ),
}


def get_benchmark_metadata(name: str) -> Optional[BenchmarkMetadata]:
    """Get benchmark metadata by name."""
    return BENCHMARKS.get(name)


def get_all_benchmarks(include_alpha: bool = False) -> dict[str, BenchmarkMetadata]:
    """Get all benchmark metadata.

    Args:
        include_alpha: Whether to include alpha/experimental benchmarks
    """
    if include_alpha:
        return BENCHMARKS
    return {name: meta for name, meta in BENCHMARKS.items() if not meta.is_alpha}


def get_benchmarks_by_category(
    category: str, include_alpha: bool = False
) -> dict[str, BenchmarkMetadata]:
    """Get all benchmarks in a category.

    Args:
        category: Category to filter by
        include_alpha: Whether to include alpha/experimental benchmarks
    """
    results = {
        name: meta for name, meta in BENCHMARKS.items() if meta.category == category
    }
    if not include_alpha:
        results = {name: meta for name, meta in results.items() if not meta.is_alpha}
    return results


def get_categories() -> List[str]:
    """Get all available categories."""
    return sorted({meta.category for meta in BENCHMARKS.values()})


def search_benchmarks(
    query: str, include_alpha: bool = False
) -> dict[str, BenchmarkMetadata]:
    """Search benchmarks by name, description, or tags.

    Args:
        query: Search query
        include_alpha: Whether to include alpha/experimental benchmarks
    """
    query = query.lower()
    results = {}

    for name, meta in BENCHMARKS.items():
        if not include_alpha and meta.is_alpha:
            continue
        if (
            query in meta.name.lower()
            or query in meta.description.lower()
            or any(query in tag.lower() for tag in meta.tags)
        ):
            results[name] = meta

    return results
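
# Illustrative queries (case-insensitive substring match; examples are hypothetical):
#
#   search_benchmarks("math")                        # matches names, descriptions, and tags
#   search_benchmarks("cybersecurity")               # matches the CTI-Bench family
#   search_benchmarks("alpha", include_alpha=True)   # include experimental benchmarks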


# ============================================================================
# Task Loading for CLI
# ============================================================================


def _generate_task_registry(include_alpha: bool = True):
    """Generate task registry from config.

    Args:
        include_alpha: Whether to include alpha/experimental benchmarks
    """
    registry = {}
    for name, metadata in get_all_benchmarks(include_alpha=include_alpha).items():
        registry[name] = f"{metadata.module_path}.{metadata.function_name}"
    return registry


# Full registry including alpha benchmarks for backward compatibility
TASK_REGISTRY = _generate_task_registry(include_alpha=True)
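
# Illustrative shape of the generated mapping (values are simply
# f"{module_path}.{function_name}" from the metadata above), e.g.:
#   TASK_REGISTRY["simpleqa"] == "openbench.evals.simpleqa.simpleqa"
#   TASK_REGISTRY["hle_text"] == "openbench.evals.hle.hle_text"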


def _import_module_from_path(path: Path) -> ModuleType:
    """
    Import a .py file or package directory as an anonymous module.
    """
    file_path = path
    if path.is_dir():
        file_path = path / "__init__.py"
        if not file_path.exists():
            raise ValueError(f"{path} is a directory but has no __init__.py")

    mod_name = f"_openbench_dyn_{uuid.uuid4().hex}"
    spec = importlib.util.spec_from_file_location(mod_name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create import spec for {file_path}")

    module = importlib.util.module_from_spec(spec)

    # For packages, set up proper package structure for relative imports
    if path.is_dir():
        module.__package__ = mod_name
        sys.modules[mod_name] = module

        # Pre-load submodules to support relative imports
        for submodule_file in path.glob("*.py"):
            if submodule_file.name != "__init__.py":
                submodule_name = submodule_file.stem
                submodule_full_name = f"{mod_name}.{submodule_name}"
                submodule_spec = importlib.util.spec_from_file_location(
                    submodule_full_name, str(submodule_file)
                )
                if submodule_spec and submodule_spec.loader:
                    submodule = importlib.util.module_from_spec(submodule_spec)
                    submodule.__package__ = mod_name
                    sys.modules[submodule_full_name] = submodule
                    submodule_spec.loader.exec_module(submodule)
    else:
        sys.modules[mod_name] = module

    spec.loader.exec_module(module)
    return module


@lru_cache()
def load_task(benchmark_name: str, allow_alpha: bool = False) -> Callable:
    """
    Loads a task by benchmark name using the registry or from a local path.

    Args:
        benchmark_name (str): The name of the benchmark or path to a local eval.
        allow_alpha (bool): Whether to allow loading alpha/experimental benchmarks.

    Returns:
        Callable: The imported function object.

    Raises:
        ValueError: If the benchmark is not in the registry and not a valid path.
        ImportError: If the module cannot be imported.
        AttributeError: If the function does not exist in the module.
    """
    # Check if this is an alpha benchmark
    benchmark_meta = get_benchmark_metadata(benchmark_name)
    if benchmark_meta and benchmark_meta.is_alpha and not allow_alpha:
        raise ValueError(
            f"'{benchmark_name}' is an experimental/alpha benchmark. "
            f"Use --alpha flag to run it."
        )

    # Try registry first (registry names take precedence)
    import_path = TASK_REGISTRY.get(benchmark_name)
    if import_path:
        module_path, func_name = import_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, func_name)

    # Fallback to path-based loading
    path = Path(benchmark_name).expanduser()
    if path.exists():
        return _load_task_from_local_path(path)

    # Neither registry nor valid path
    raise ValueError(
        f"Unknown benchmark: '{benchmark_name}'. "
        f"Available benchmarks: {', '.join(TASK_REGISTRY.keys())}"
    )
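
# Illustrative usage sketch ("./my_eval" is a hypothetical local eval path):
#
#   task_fn = load_task("simpleqa")                    # registry lookup
#   task_fn = load_task("scicode", allow_alpha=True)   # alpha benchmarks need opt-in
#   task_fn = load_task("./my_eval")                   # path-based fallback
#
# Each call returns the task constructor named by the registry entry or by the
# local eval's __metadata__.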


def _load_task_from_local_path(path: Path) -> Callable:
    """
    Load a task from a local path containing __metadata__.

    Args:
        path: Path to a directory or .py file containing an eval

    Returns:
        Callable: The imported function object

    Raises:
        ValueError: If no valid __metadata__ is found
        AttributeError: If the function does not exist in the module
        ImportError: If the module cannot be imported
    """
    root_module = _import_module_from_path(path)
    metadata = getattr(root_module, "__metadata__", None)

    if not isinstance(metadata, BenchmarkMetadata):
        raise ValueError(f"{path} has no valid __metadata__")

    # Resolve module path relative to root module
    # For local evals, module_path is typically relative like "simpleqa.simpleqa"
    # We need to extract just the last part and combine with the root module name
    if metadata.module_path.startswith(root_module.__name__):
        full_module_name = metadata.module_path
    else:
        # For paths like "simpleqa.simpleqa", we want the last component "simpleqa"
        module_components = metadata.module_path.split(".")
        module_name = module_components[-1]  # Take the last component
        full_module_name = f"{root_module.__name__}.{module_name}"

    try:
        module = importlib.import_module(full_module_name)
    except ImportError as e:
        raise ImportError(f"Cannot import module '{full_module_name}': {e}")

    try:
        return getattr(module, metadata.function_name)
    except AttributeError:
        raise AttributeError(
            f"Function '{metadata.function_name}' not found in module '{full_module_name}'"
        )


def get_eval_metadata(path_like: str) -> BenchmarkMetadata | None:
    """
    Best-effort extraction of __metadata__ for path-based evals.
    Returns None for registry-based benchmarks or when no metadata is present.
    """
    p = Path(path_like).expanduser()
    if not p.exists():
        return None

    try:
        module = _import_module_from_path(p)
        meta = getattr(module, "__metadata__", None)
        return meta if isinstance(meta, BenchmarkMetadata) else None
    except Exception:
        return None
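
# Minimal sketch of a path-based eval that get_eval_metadata and
# _load_task_from_local_path can resolve. The layout and the BenchmarkMetadata
# import location are hypothetical:
#
#   # my_eval/__init__.py
#   from openbench.config import BenchmarkMetadata   # assumed import location
#
#   __metadata__ = BenchmarkMetadata(
#       name="My Eval",
#       description="Example local eval",
#       category="community",
#       tags=["example"],
#       module_path="my_eval.my_eval",   # last component resolves to my_eval/my_eval.py
#       function_name="my_eval",
#   )
#
# load_task("./my_eval") would then import my_eval.py as a submodule of the
# dynamically loaded package and return its `my_eval` task function.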

src/openbench/datasets/__init__.py (0 B)


src/openbench/datasets/browsecomp.py (2.1 KiB)

"""Dataset loader for BrowseComp: A Simple Yet Challenging Benchmark for Browsing Agents.

https://openai.com/index/browsecomp/
"""

import base64
import hashlib
from inspect_ai.dataset import Dataset, csv_dataset, Sample, MemoryDataset


def derive_key(password: str, length: int) -> bytes:
    """Derive a fixed-length key from the password using SHA256."""
    hasher = hashlib.sha256()
    hasher.update(password.encode())
    key = hasher.digest()
    return key * (length // len(key)) + key[: length % len(key)]


def decrypt(ciphertext_b64: str, password: str) -> str:
    """Decrypt base64-encoded ciphertext with XOR."""
    encrypted = base64.b64decode(ciphertext_b64)
    key = derive_key(password, len(encrypted))
    decrypted = bytes(a ^ b for a, b in zip(encrypted, key))
    return decrypted.decode()
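
# XOR with a SHA256-derived keystream is its own inverse, so encryption and
# decryption are the same operation. Illustrative round trip (hypothetical
# values, not drawn from the dataset):
#
#   key = derive_key("canary", 5)
#   ct = base64.b64encode(bytes(a ^ b for a, b in zip(b"hello", key))).decode()
#   assert decrypt(ct, "canary") == "hello"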


def record_to_sample(record: dict) -> Sample:
    """Convert a BrowseComp CSV record to an Inspect Sample."""
    # Decrypt the problem and answer using the canary
    problem = decrypt(record.get("problem", ""), record.get("canary", ""))
    answer = decrypt(record.get("answer", ""), record.get("canary", ""))

    # Format the input with the query template
    formatted_input = f"""{problem}

Your response should be in the following format:
Explanation: {{your explanation for your final answer}}
Exact Answer: {{your succinct, final answer}}
Confidence: {{your confidence score between 0% and 100% for your answer}}"""

    return Sample(
        input=formatted_input,
        target=answer,
        metadata={
            "canary": record.get("canary", ""),
            "plain_question": problem,  # Store the plain question for the grader
        },
    )


def get_dataset() -> Dataset:
    """Load the BrowseComp dataset.

    Returns:
        Dataset containing BrowseComp samples
    """
    # Load the full dataset
    dataset = csv_dataset(
        csv_file="https://openaipublic.blob.core.windows.net/simple-evals/browse_comp_test_set.csv",
        sample_fields=record_to_sample,
        auto_id=True,
        name="browsecomp",
    )

    # Convert to list of samples
    samples = list(dataset)

    return MemoryDataset(samples=samples, name="browsecomp")

src/openbench/datasets/cti_bench.py (3.6 KiB)

"""CTI-Bench dataset loaders for cybersecurity threat intelligence benchmarks."""

from typing import Any, Dict
from inspect_ai.dataset import Dataset, Sample, hf_dataset


def mcq_record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert MCQ record to Sample format."""
    question = record["Question"]

    # Format options as A) ... B) ... C) ... D) ...
    formatted_options = [
        f"{chr(65 + i)}) {record[f'Option {chr(65 + i)}']}"
        for i in range(4)  # A, B, C, D
    ]

    prompt = f"{question}\n\n" + "\n".join(formatted_options) + "\n\nAnswer:"

    return Sample(
        input=prompt,
        target=record["GT"],
        metadata={
            "question_type": "multiple_choice",
            "domain": "cybersecurity",
            "url": record.get("URL", ""),
        },
    )
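
# Illustrative prompt layout for a hypothetical record (option text elided):
#
#   <Question text>
#
#   A) <Option A>
#   B) <Option B>
#   C) <Option C>
#   D) <Option D>
#
#   Answer:
#
# The sample target is the ground-truth letter from the record's "GT" field.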


def rcm_record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert RCM (CVE→CWE mapping) record to Sample format."""
    description = record["Description"]

    prompt = f"""Given the following vulnerability description, identify the most appropriate CWE (Common Weakness Enumeration) category.

Description: {description}

Respond with only the CWE ID (e.g., CWE-79):"""

    return Sample(
        input=prompt,
        target=record["GT"],
        metadata={
            "task_type": "classification",
            "domain": "vulnerability_mapping",
            "url": record.get("URL", ""),
        },
    )


def vsp_record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert VSP (CVSS severity prediction) record to Sample format."""
    description = record["Description"]

    prompt = f"""Given the following vulnerability description, predict the CVSS (Common Vulnerability Scoring System) base score.

Description: {description}

The CVSS base score ranges from 0.0 to 10.0, where:
- 0.1-3.9: Low severity
- 4.0-6.9: Medium severity  
- 7.0-8.9: High severity
- 9.0-10.0: Critical severity

Respond with only the numeric CVSS score (e.g., 7.5):"""

    return Sample(
        input=prompt,
        target=record["GT"],
        metadata={
            "task_type": "regression",
            "domain": "vulnerability_scoring",
            "url": record.get("URL", ""),
        },
    )


def ate_record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert ATE (ATT&CK Technique Extraction) record to Sample format."""
    prompt = record["Prompt"]

    return Sample(
        input=prompt,
        target=record["GT"],
        metadata={
            "task_type": "technique_extraction",
            "domain": "mitre_attack",
            "url": record.get("URL", ""),
            "platform": record.get("Platform", ""),
            "description": record.get("Description", ""),
        },
    )


def get_cti_bench_mcq_dataset() -> Dataset:
    """Load CTI-Bench MCQ dataset."""
    return hf_dataset(
        path="AI4Sec/cti-bench",
        name="cti-mcq",
        split="test",
        sample_fields=mcq_record_to_sample,
    )


def get_cti_bench_rcm_dataset() -> Dataset:
    """Load CTI-Bench RCM dataset."""
    return hf_dataset(
        path="AI4Sec/cti-bench",
        name="cti-rcm",
        split="test",
        sample_fields=rcm_record_to_sample,
    )


def get_cti_bench_vsp_dataset() -> Dataset:
    """Load CTI-Bench VSP dataset."""
    return hf_dataset(
        path="AI4Sec/cti-bench",
        name="cti-vsp",
        split="test",
        sample_fields=vsp_record_to_sample,
    )


def get_cti_bench_ate_dataset() -> Dataset:
    """Load CTI-Bench ATE (ATT&CK Technique Extraction) dataset."""
    return hf_dataset(
        path="AI4Sec/cti-bench",
        name="cti-ate",
        split="test",
        sample_fields=ate_record_to_sample,
    )

src/openbench/datasets/drop.py (3.5 KiB)

"""DROP dataset loader for Inspect AI."""

import gzip
import json
import random
from io import BytesIO
from urllib.request import urlopen

from inspect_ai.dataset import Dataset, MemoryDataset, Sample


def record_to_sample(record: dict) -> Sample:
    """Convert a DROP record to an Inspect Sample."""
    # Format the context and question
    context = record["context"]
    completion = record["completion"]

    # Format input as context + completion (which contains the question)
    input_text = f"{context}\n\n{completion}"

    # Get reference answers (can be multiple, separated by |)
    target = record.get("ref_text", "")

    return Sample(
        input=input_text,
        target=target,
        metadata={
            "context": context,
            "completion": completion,
            "ref_text": target,
        },
    )


def get_dataset(
    num_examples: int | None = None,
    train_samples_per_prompt: int = 3,
    seed: int = 42,
) -> Dataset:
    """Load the DROP dataset.

    Args:
        num_examples: Number of examples to use (None for all)
        train_samples_per_prompt: Number of training examples for few-shot prompting
        seed: Random seed for sampling

    Returns:
        Dataset ready for evaluation
    """
    # URLs for the DROP dataset
    train_url = (
        "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
    )
    test_url = (
        "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
    )

    # Load training samples for few-shot examples
    with gzip.GzipFile(fileobj=BytesIO(urlopen(train_url).read()), mode="rb") as f:
        train_samples = [json.loads(line) for line in f.readlines()]

    # Load test samples
    with gzip.GzipFile(fileobj=BytesIO(urlopen(test_url).read()), mode="rb") as f:
        test_samples = [json.loads(line) for line in f.readlines()]

    # Sample if requested
    if num_examples:
        rng = random.Random(seed)
        test_samples = rng.sample(test_samples, min(num_examples, len(test_samples)))

    # Convert to Inspect samples
    samples = []
    rng = random.Random(seed)

    for test_sample in test_samples:
        # Get few-shot examples
        few_shot_examples = rng.sample(train_samples, train_samples_per_prompt)

        # Build the prompt with few-shot examples
        prompt_parts = [
            "You will be asked to read a passage and answer a question. Some examples of passages and Q&A are provided below.",
            "\n# Examples",
        ]

        # Add few-shot examples
        for example in few_shot_examples:
            prompt_parts.append("\n---")
            prompt_parts.append(example["context"])
            prompt_parts.append(example["completion"])

        # Add the test example
        prompt_parts.append("\n# Your Task\n---")
        prompt_parts.append(test_sample["context"])
        prompt_parts.append(
            '\nThink step by step, then write a line of the form "Answer: $ANSWER" at the end of your response.'
        )

        # Create the sample
        sample = Sample(
            input="\n".join(prompt_parts),
            target=test_sample.get("ref_text", ""),
            metadata={
                "context": test_sample["context"],
                "completion": test_sample["completion"],
                "ref_text": test_sample.get("ref_text", ""),
                "train_samples": few_shot_examples,
            },
        )
        samples.append(sample)

    return MemoryDataset(samples=samples, name="drop")

src/openbench/datasets/gpqa.py (1.1 KiB)

import random
from inspect_ai.dataset import Dataset, Sample, csv_dataset
from openbench.utils.text import MULTIPLE_CHOICE_PROMPT_TEMPLATE


# Use a module-level RNG seeded once so option order varies across records but
# stays reproducible. Re-seeding the global RNG inside record_to_sample would
# produce the identical permutation (and hence the same correct letter) for
# every question.
_rng = random.Random(0)


def record_to_sample(record: dict) -> Sample:
    options = [
        record["Correct Answer"],
        record["Incorrect Answer 1"],
        record["Incorrect Answer 2"],
        record["Incorrect Answer 3"],
    ]
    _rng.shuffle(options)
    # Get index of correct answer and convert to A, B, C, D
    correct_index = options.index(record["Correct Answer"])
    correct_letter = "ABCD"[correct_index]
    return Sample(
        input=MULTIPLE_CHOICE_PROMPT_TEMPLATE.format(
            prompt=record["Question"],
            option_a=options[0],
            option_b=options[1],
            option_c=options[2],
            option_d=options[3],
        ),
        target=correct_letter,
    )


def get_dataset() -> Dataset:
    return csv_dataset(
        "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv",
        sample_fields=record_to_sample,
        auto_id=True,
        name="gpqa_simple_eval",
    )

src/openbench/datasets/graphwalks.py (1.6 KiB)

# src/openbench/datasets/graphwalks.py
from __future__ import annotations
from typing import Any, Optional
from inspect_ai.dataset import Dataset, Sample, hf_dataset

_ALLOWED = {"bfs", "parents"}


def record_to_sample(
    record: dict[str, Any], *, allowed: Optional[set[str]] = None
) -> Sample | list[Sample]:
    """
    Map one HF row to an Inspect Sample.
    If `allowed` is provided, drop rows whose problem_type isn't in it by returning empty list.
    """
    problem_type = (record.get("problem_type") or "").strip().lower()

    # Filter here by returning empty list (row is skipped)
    if allowed is not None and problem_type not in allowed:
        return []

    gold = record.get("answer", record.get("answer_nodes", []))

    return Sample(
        input=record["prompt"],
        target=gold,
        metadata={
            "problem_type": problem_type,
            "prompt_chars": record.get("prompt_chars"),
        },
    )


def get_dataset(split: str = "train", task_type: str = "both") -> Dataset:
    """
    task_type: 'bfs' | 'parents' | 'both' (default: keep all)
    """
    task = (task_type or "both").strip().lower()
    if task in ("both", "all", "*"):
        allowed = None
    elif task in _ALLOWED:
        allowed = {task}
    else:
        raise ValueError("task_type must be one of 'bfs', 'parents', 'both'")

    def _map_sample(rec: dict[str, Any]) -> Sample | list[Sample]:
        return record_to_sample(rec, allowed=allowed)

    return hf_dataset(
        path="openai/graphwalks",
        split=split,
        sample_fields=_map_sample,
    )
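
# Illustrative usage (hypothetical call sites):
#
#   get_dataset()                        # keep both bfs and parents problems
#   get_dataset(task_type="bfs")         # BFS traversal rows only
#   get_dataset(task_type="parents")     # parent-finding rows only
#
# Filtered-out rows are skipped because record_to_sample returns an empty list
# for them, which hf_dataset treats as "no samples for this record".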

src/openbench/datasets/healthbench.py (2.0 KiB)

"""HealthBench dataset loader."""

import json
from typing import Any, Dict, Optional

import httpx
from inspect_ai.dataset import Dataset, MemoryDataset, Sample


INPUT_PATH = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/2025-05-07-06-14-12_oss_eval.jsonl"
INPUT_PATH_HARD = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/hard_2025-05-08-21-00-10.jsonl"
INPUT_PATH_CONSENSUS = "https://openaipublic.blob.core.windows.net/simple-evals/healthbench/consensus_2025-05-09-20-00-46.jsonl"


def record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert a HealthBench record to an Inspect Sample."""
    return Sample(
        id=record.get("prompt_id", ""),
        input=record["prompt"],  # Keep as message list for chat format
        target="",  # No single target - we grade against rubrics
        metadata={
            "rubrics": record["rubrics"],
            "example_tags": record.get("example_tags", []),
            "prompt_id": record.get("prompt_id", ""),
        },
    )


def get_dataset(subset: Optional[str] = None) -> Dataset:
    """Load the HealthBench dataset.

    Args:
        subset: Which subset to load ("hard", "consensus", or None for main)

    Returns:
        Dataset configured for HealthBench evaluation
    """
    # Select URL based on subset
    if subset == "hard":
        url = INPUT_PATH_HARD
    elif subset == "consensus":
        url = INPUT_PATH_CONSENSUS
    elif subset is None:
        url = INPUT_PATH
    else:
        raise ValueError(f"Invalid subset: {subset}")

    # Download and parse the JSONL file
    response = httpx.get(url)
    response.raise_for_status()

    examples = []
    for line in response.text.strip().split("\n"):
        if line:
            examples.append(json.loads(line))

    # Convert to samples
    samples = [record_to_sample(record) for record in examples]

    dataset_name = f"healthbench_{subset}" if subset else "healthbench"
    return MemoryDataset(samples=samples, name=dataset_name)
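
# Illustrative usage (hypothetical call sites):
#
#   get_dataset()                    # main HealthBench split
#   get_dataset(subset="hard")       # HealthBench Hard
#   get_dataset(subset="consensus")  # HealthBench Consensus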

src/openbench/datasets/hle.py (1.4 KiB)

from inspect_ai.dataset import Dataset, Sample, MemoryDataset, hf_dataset


def record_to_sample(record: dict) -> Sample:
    """Convert an HLE record to an Inspect Sample."""
    # Format the input with the system prompt used in HLE
    input_text = record["question"]

    # Include metadata for tracking
    metadata = {
        "question_id": record["id"],
    }

    # Add image if present (for multi-modal questions)
    if record.get("image"):
        metadata["image_url"] = record["image"]

    return Sample(
        input=input_text,
        target=record["answer"],
        id=record["id"],
        metadata=metadata,
    )


def get_dataset(text_only: bool = False) -> Dataset:
    """Load the HLE (Humanity's Last Exam) dataset.

    Args:
        text_only: If True, filter out multi-modal questions with images

    Returns:
        Dataset with HLE questions and answers
    """
    # Load the dataset from HuggingFace (no 'name' parameter - uses default config)
    dataset = hf_dataset(
        "cais/hle",
        split="test",
        sample_fields=record_to_sample,
    )

    # Convert to list for MemoryDataset
    samples = list(dataset)

    # Filter out image questions if text_only is True
    if text_only:
        samples = [
            s for s in samples if not (s.metadata and s.metadata.get("image_url"))
        ]
        dataset_name = "hle_text"
    else:
        dataset_name = "hle"

    return MemoryDataset(samples=samples, name=dataset_name)

src/openbench/datasets/humaneval.py (1.7 KiB)

from typing import Any, Callable

from inspect_ai.dataset import Sample, Dataset, hf_dataset

HUMANEVAL_INSTRUCTION = """
    Read the following function signature and docstring, and fully implement
    the function described. Your response should only contain the code for
    this function.
    """.strip()


# Adapted from https://github.com/UKGovernmentBEIS/inspect_evals
def record_to_sample(
    instruction_prompt: str = HUMANEVAL_INSTRUCTION,
) -> Callable[[dict[str, Any]], Sample]:
    """
    Convert a HumanEval record to a Sample for evaluation.

    Args:
        instruction_prompt (str): The prompt to prepend to the code problem.

    Returns:
        Callable[[dict[str, Any]], Sample]: Function to convert a record dict to a Sample.
    """

    def _record_to_sample(record: dict[str, Any]) -> Sample:
        return Sample(
            id=record["task_id"],
            input=instruction_prompt + record["prompt"],
            target=record["canonical_solution"],
            metadata={
                "prompt": record["prompt"],
                "test": record["test"],
                "entry_point": record["entry_point"],
            },
        )

    return _record_to_sample


# Adapted from https://github.com/UKGovernmentBEIS/inspect_evals
def get_humaneval_dataset(instruction_prompt: str = HUMANEVAL_INSTRUCTION) -> Dataset:
    """
    Load the HumanEval dataset for evaluation.

    Args:
        instruction_prompt (str): The prompt to prepend to the code problem.

    Returns:
        Dataset: The HumanEval dataset.
    """
    return hf_dataset(
        path="openai_humaneval",
        split="test",
        sample_fields=record_to_sample(instruction_prompt=instruction_prompt),
    )

src/openbench/datasets/jsonschemabench.py (14.1 KiB)

from typing import Dict, List, Tuple
from datasets import load_dataset  # type: ignore[import-untyped]
from inspect_ai.dataset import Dataset, Sample, MemoryDataset
from inspect_ai.model import (
    ChatMessageSystem,
    ChatMessageUser,
    ChatMessageAssistant,
    ChatMessageTool,
)

JSONSCHEMABENCH_SYSTEM_PROMPT = (
    "You need to generate a JSON object that matches the schema below."
)

FEWSHOT_EXAMPLES: Dict[Tuple[str, ...], List[Tuple[str, str]]] = {
    ("Snowplow",): [
        (
            '{\n    "additionalProperties": false,\n    "description": "Schema for a JSON Paths file for loading Redshift from JSON or Avro, http://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-data-format.html#copy-json-jsonpaths",\n    "properties": {\n        "jsonpaths": {\n            "items": {\n                "type": "string"\n            },\n            "minItems": 1,\n            "type": "array"\n        }\n    },\n    "required": [\n        "jsonpaths"\n    ],\n    "self": {\n        "format": "jsonschema",\n        "name": "jsonpaths_file",\n        "vendor": "com.amazon.aws.redshift",\n        "version": "1-0-0"\n    },\n    "type": "object"\n}',
            '{"jsonpaths": ["$.user.id", "$.user.name", "$.user.address.street"]}',
        ),
        (
            '{\n    "additionalProperties": false,\n    "description": "Schema for a Google Analytics enhanced e-commerce product impression custom metric entity",\n    "properties": {\n        "customMetricIndex": {\n            "maximum": 200,\n            "minimum": 1,\n            "type": "integer"\n        },\n        "listIndex": {\n            "maximum": 200,\n            "minimum": 1,\n            "type": "integer"\n        },\n        "productIndex": {\n            "maximum": 200,\n            "minimum": 1,\n            "type": "integer"\n        },\n        "value": {\n            "type": [\n                "integer",\n                "null"\n            ]\n        }\n    },\n    "self": {\n        "format": "jsonschema",\n        "name": "product_impression_custom_metric",\n        "vendor": "com.google.analytics.measurement-protocol",\n        "version": "1-0-0"\n    },\n    "type": "object"\n}',
            '{"customMetricIndex": 120, "listIndex": 45, "productIndex": 10, "value": 300}',
        ),
    ],
    ("Github_easy", "Github_hard", "Github_medium", "Github_trivial", "Github_ultra"): [
        (
            '{\n    "$schema": "http://json-schema.org/draft-04/schema#",\n    "definitions": {\n        "address1": {"type": "string"},\n        "address2": {"type": "string"},\n        "city": {"type": "string"},\n        "country": {"type": "string"},\n        "postalCode": {"type": "string"},\n        "state": {"type": "string"}\n    },\n    "description": "A simple address schema",\n    "properties": {\n        "address1": {"$ref": "#/definitions/address1"},\n        "address2": {"$ref": "#/definitions/address2"},\n        "city": {"$ref": "#/definitions/city"},\n        "country": {"$ref": "#/definitions/country"},\n        "postalCode": {"$ref": "#/definitions/postalCode"},\n        "state": {"$ref": "#/definitions/state"}\n    },\n    "type": "object"\n}',
            '{"address1": "123 Main Street", "address2": "Apt 4B", "city": "Seattle", "country": "USA", "postalCode": "98101", "state": "WA"}',
        ),
        (
            '{\n    "$schema": "http://json-schema.org/draft-06/schema#",\n    "definitions": {\n        "ElementType": {\n            "enum": ["component", "directive"],\n            "type": "string"\n        },\n        "SelectorChange": {\n            "properties": {\n                "remove": {\n                    "description": "Remove directive/component",\n                    "type": "boolean"\n                },\n                "replaceWith": {\n                    "description": "Replace original selector with new one",\n                    "type": "string"\n                },\n                "selector": {\n                    "description": "Original selector to apply change to",\n                    "type": "string"\n                },\n                "type": {\n                    "$ref": "#/definitions/ElementType",\n                    "description": "Type of selector the change applies to - either component or directive"\n                }\n            },\n            "required": ["selector", "type"],\n            "type": "object"\n        }\n    },\n    "properties": {\n        "changes": {\n            "description": "An array of changes to component/directive selectors",\n            "items": {\n                "$ref": "#/definitions/SelectorChange"\n            },\n            "type": "array"\n        }\n    },\n    "required": ["changes"],\n    "type": "object"\n}',
            '{\n  "changes": [\n    {\n      "selector": "app-root",\n      "type": "component",\n      "remove": false,\n      "replaceWith": "new-root"\n    },\n    {\n      "selector": "my-directive",\n      "type": "directive",\n      "remove": true,\n      "replaceWith": "new-directive"\n    }\n  ]\n}',
        ),
    ],
    ("Glaiveai2K",): [
        (
            '{"properties": {"username": {"description": "The user\'s username", "type": "string"}, "email": {"description": "The user\'s email address", "type": "string"}, "age": {"description": "The user\'s age", "type": "integer"}, "is_active": {"description": "Whether the user is active", "type": "boolean"}}, "required": ["username", "email"], "type": "object"}',
            '{"username": "johndoe", "email": "john@example.com", "age": 30, "is_active": true} ',
        ),
        (
            '{"properties": {"product_id": {"description": "The ID of the product", "type": "string"}, "rating": {"description": "The rating given by the user", "type": "integer"}, "comments": {"description": "Additional comments about the product", "type": "string"}}, "required": ["product_id", "rating"], "type": "object"}',
            '{"product_id": "12345", "rating": 5, "comments": "Excellent product! Highly recommend."} ',
        ),
    ],
    ("JsonSchemaStore",): [
        (
            '{\n  "$id": "https://json.schemastore.org/minecraft-trim-pattern.json",\n  "$schema": "http://json-schema.org/draft-07/schema#",\n  "description": "A trim pattern for a Minecraft data pack config schema",\n  "properties": {\n    "asset_id": {\n      "type": "string"\n    },\n    "description": {\n      "properties": {\n        "color": {\n          "type": "string"\n        },\n        "translate": {\n          "type": "string"\n        }\n      },\n      "required": ["translate"],\n      "type": "object"\n    },\n    "template_item": {\n      "type": "string"\n    }\n  },\n  "required": ["asset_id", "description", "template_item"],\n  "title": "Minecraft Data Pack Trim Pattern",\n  "type": "object"\n}',
            '{\n  "asset_id": "minecraft:trim_pattern",\n  "description": {\n    "color": "#FFAA00",\n    "translate": "trim_pattern.description"\n  },\n  "template_item": "minecraft:template_item"\n}',
        ),
        (
            '{\n  "$comment": "https://minecraft.fandom.com/wiki/Data_Pack",\n  "$id": "https://json.schemastore.org/minecraft-damage-type.json",\n  "$schema": "http://json-schema.org/draft-07/schema#",\n  "description": "A damage type\'s for a Minecraft data pack config schema",\n  "properties": {\n    "death_message_type": {\n      "enum": ["default", "fall_variants", "intentional_game_design"],\n      "type": "string"\n    },\n    "effects": {\n      "enum": ["hurt", "thorns", "drowning", "burning", "poking", "freezing"],\n      "type": "string"\n    },\n    "exhaustion": {\n      "type": "number"\n    },\n    "message_id": {\n      "type": "string"\n    },\n    "scaling": {\n      "enum": ["never", "always", "when_caused_by_living_non_player"],\n      "type": "string"\n    }\n  },\n  "required": ["message_id", "scaling", "exhaustion"],\n  "title": "Minecraft Data Pack Damage Type",\n  "type": "object"\n}',
            '{\n  "message_id": "minecraft:damage.message",\n  "scaling": "always",\n  "exhaustion": 0.3,\n  "death_message_type": "default",\n  "effects": "hurt"\n}',
        ),
    ],
    ("Kubernetes",): [
        (
            '{\n  "description": "A topology selector requirement is a selector that matches given label. This is an alpha feature and may change in the future.",\n  "properties": {\n    "key": {\n      "description": "The label key that the selector applies to.",\n      "type": ["string", "null"]\n    },\n    "values": {\n      "description": "An array of string values. One value must match the label to be selected. Each entry in Values is ORed.",\n      "items": {\n        "type": ["string", "null"]\n      },\n      "type": ["array", "null"]\n    }\n  },\n  "required": ["key", "values"],\n  "type": "object"\n}',
            '{\n  "key": "region",\n  "values": ["us-west-1", "us-east-1"]\n}',
        ),
        (
            '{\n  "description": "HostAlias holds the mapping between IP and hostnames that will be injected as an entry in the pod\'s hosts file.",\n  "properties": {\n    "hostnames": {\n      "description": "Hostnames for the above IP address.",\n      "items": {\n        "type": ["string", "null"]\n      },\n      "type": ["array", "null"]\n    },\n    "ip": {\n      "description": "IP address of the host file entry.",\n      "type": ["string", "null"]\n    }\n  },\n  "type": "object"\n}',
            '{\n  "ip": "192.168.1.1",\n  "hostnames": ["example.com", "test.com"]\n}',
        ),
    ],
    ("WashingtonPost",): [
        (
            '{\n  "additionalProperties": false,\n  "description": "Models a auxiliary used in targeting a piece of content.",\n  "properties": {\n    "_id": {\n      "description": "The unique identifier for this auxiliary.",\n      "type": "string"\n    },\n    "name": {\n      "description": "The general name for this auxiliary.",\n      "type": "string"\n    },\n    "uid": {\n      "description": "A short identifier for this auxiliary. Usually used in cases where a long form id cannot work.",\n      "type": "string"\n    }\n  },\n  "required": ["_id", "uid"],\n  "title": "Auxiliary",\n  "type": "object"\n}',
            '{\n  "_id": "12345",\n  "uid": "aux123",\n  "name": "Sample Auxiliary"\n}',
        ),
        (
            '{\n  "additionalProperties": {},\n  "definitions": {\n    "trait_additional_properties_json": {\n      "$schema": "http://json-schema.org/draft-04/schema#",\n      "additionalProperties": {},\n      "description": "A grab-bag object for non-validatable data.",\n      "title": "Has additional properties",\n      "type": "object"\n    }\n  },\n  "description": "Comment configuration data",\n  "properties": {\n    "additional_properties": {\n      "$ref": "#/definitions/trait_additional_properties_json"\n    },\n    "allow_comments": {\n      "description": "If false, commenting is disabled on this content.",\n      "type": "boolean"\n    },\n    "comments_period": {\n      "description": "How long (in days) after publish date until comments are closed.",\n      "type": "integer"\n    },\n    "display_comments": {\n      "description": "If false, do not render comments on this content.",\n      "type": "boolean"\n    },\n    "moderation_required": {\n      "description": "If true, comments must be moderator-approved before being displayed.",\n      "type": "boolean"\n    }\n  },\n  "title": "Comments",\n  "type": "object"\n}',
            '{\n  "allow_comments": true,\n  "comments_period": 30,\n  "display_comments": true,\n  "moderation_required": false,\n  "additional_properties": {}\n}',
        ),
    ],
    ("default",): [],
}


def _find_examples_for_subset(subset: str | None) -> List[Tuple[str, str]]:
    """Find few-shot examples for a subset."""
    for key, examples in FEWSHOT_EXAMPLES.items():
        if subset in key:
            return examples
    return FEWSHOT_EXAMPLES[("default",)]


def _get_few_shot_examples(subset: str | None, num_shots: int) -> List[Tuple[str, str]]:
    """Get first N few-shot examples for a subset."""
    examples = _find_examples_for_subset(subset)
    if num_shots > len(examples):
        raise ValueError(
            f"Not enough {subset} examples to prompt with {num_shots} shots"
        )
    return examples[:num_shots]


def _build_messages(
    schema: str, examples: List[Tuple[str, str]]
) -> List[ChatMessageSystem | ChatMessageUser | ChatMessageAssistant | ChatMessageTool]:
    """Build message list with few-shot examples."""
    messages: List[
        ChatMessageSystem | ChatMessageUser | ChatMessageAssistant | ChatMessageTool
    ] = [ChatMessageSystem(content=JSONSCHEMABENCH_SYSTEM_PROMPT)]

    for schema_str, json_str in examples:
        messages.append(ChatMessageUser(content=schema_str))
        messages.append(ChatMessageAssistant(content=json_str))

    messages.append(ChatMessageUser(content=schema))
    return messages


def record_to_sample(
    record: dict, num_shots: int = 0, subset: str | None = None
) -> Sample:
    """Convert HuggingFace dataset record to Inspect Sample with few-shot prompting."""
    schema = record["json_schema"]
    unique_id = record["unique_id"]

    # Build few-shot prompt if requested
    examples = _get_few_shot_examples(subset, num_shots)
    messages = _build_messages(schema, examples)

    return Sample(
        input=messages,
        target="",
        metadata={
            "schema": schema,
            "unique_id": unique_id,
            "num_shots": num_shots,
        },
    )


def get_dataset(
    subset: str | None = None, split: str = "all", num_shots: int = 0
) -> Dataset:
    """Load JSONSchemaBench dataset from HuggingFace with few-shot examples.

    Args:
        subset: Dataset subset (e.g., "Github_easy", "Kubernetes", "Snowplow")
        split: Dataset split ("test", "val", "train", or "all")
        num_shots: Number of few-shot examples (0 for zero-shot, paper used 2)
    """
    split_param = {
        "test": "test",
        "val": "val",
        "train": "train",
        "all": "train[:]+val[:]+test[:]",
    }
    config = subset if subset else "default"
    name = f"jsonschemabench_{config}_{split}_{num_shots}shot"
    dataset = load_dataset(
        "epfl-dlab/JSONSchemaBench", config, split=split_param[split]
    )
    samples = [
        record_to_sample(record, num_shots=num_shots, subset=subset)
        for record in dataset
    ]
    return MemoryDataset(samples=samples, name=name)
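
# Illustrative usage (hypothetical call sites):
#
#   get_dataset()                                   # all subsets, zero-shot
#   get_dataset(subset="Github_easy", num_shots=2)  # two-shot, as in the paper
#   get_dataset(subset="Kubernetes", split="test")  # one subset, test split only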

src/openbench/datasets/math.py (907 B)

from inspect_ai.dataset import Dataset, csv_dataset, Sample, MemoryDataset


def record_to_sample(record: dict) -> Sample:
    """Convert a MATH CSV record to an Inspect Sample."""
    return Sample(
        input=record["Question"],
        target=record["Answer"],
        metadata={},
    )


def get_dataset(split: str = "math_test") -> Dataset:
    """Load the MATH dataset.

    Args:
        split: Which dataset split to use - "math_test" for full 5000 problems,
               or "math_500_test" for 500 problem subset
    """
    # Load the dataset from OpenAI's blob storage
    dataset = csv_dataset(
        csv_file=f"https://openaipublic.blob.core.windows.net/simple-evals/{split}.csv",
        sample_fields=record_to_sample,
        auto_id=True,
        name=split,
    )

    # Convert to list of samples
    samples = list(dataset)

    return MemoryDataset(samples=samples, name=split)

src/openbench/datasets/mgsm.py (7.8 KiB)

"""MGSM (Multilingual Grade School Math) dataset loader.

Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
Based on: Language Models are Multilingual Chain-of-Thought Reasoners
Freda Shi et al., 2022
https://arxiv.org/abs/2210.03057
"""

import urllib.request
from typing import List, Optional
from inspect_ai.dataset import Dataset, Sample, MemoryDataset

ALL_LANGUAGES = ["bn", "de", "en", "es", "fr", "ja", "ru", "sw", "te", "th", "zh"]
LATIN_LANGUAGES = ["de", "en", "es", "fr", "sw"]
NON_LATIN_LANGUAGES = ["bn", "ja", "ru", "te", "th", "zh"]

LANG_TO_FPATH = {
    "bn": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_bn.tsv",
    "de": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_de.tsv",
    "en": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_en.tsv",
    "es": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_es.tsv",
    "fr": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_fr.tsv",
    "ja": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ja.tsv",
    "ru": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_ru.tsv",
    "sw": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_sw.tsv",
    "te": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_te.tsv",
    "th": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_th.tsv",
    "zh": "https://openaipublic.blob.core.windows.net/simple-evals/mgsm_zh.tsv",
}

LANG_TO_INSTRUCTIONS = {
    "en": """Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".

{input}""",
    "bn": """এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.

{input}""",
    "de": """Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.

{input}""",
    "es": """Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".

{input}""",
    "fr": """Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N'ajoutez rien d'autre que la réponse entière après "Réponse:".

{input}""",
    "ja": """の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。

{input}""",
    "ru": """Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".

{input}""",
    "sw": """Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".

{input}""",
    "te": """ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.

{input}""",
    "th": """แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:"

{input}""",
    "zh": """解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。

{input}""",
}

LANG_TO_ANSWER_PREFIX = {
    "en": "Answer",
    "bn": "উত্তর",
    "de": "Antwort",
    "es": "Respuesta",
    "fr": "Réponse",
    "ja": "答え",
    "ru": "Ответ",
    "sw": "Jibu",
    "te": "సమాధానం",
    "th": "คำตอบ",
    "zh": "答案",
}


def load_language_data(language: str) -> List[Sample]:
    """Load MGSM data for a specific language."""
    url = LANG_TO_FPATH[language]
    samples = []

    # Download and parse TSV data
    with urllib.request.urlopen(url) as response:
        content = response.read()

    # Parse TSV
    lines = content.decode("utf-8").strip().split("\n")
    for idx, line in enumerate(lines):
        parts = line.split("\t")
        if len(parts) == 2:
            problem, answer = parts
            # Format the instruction with the problem
            instruction = LANG_TO_INSTRUCTIONS[language].format(input=problem)

            samples.append(
                Sample(
                    input=instruction,
                    target=answer.strip(),
                    id=f"{language}_{idx + 1}",
                    metadata={
                        "language": language,
                        "answer_prefix": LANG_TO_ANSWER_PREFIX[language],
                        "latin_script": language in LATIN_LANGUAGES,
                        "original_problem": problem,
                    },
                )
            )

    return samples


def get_dataset(
    languages: Optional[List[str]] = None, limit_per_language: Optional[int] = None
) -> Dataset:
    """Load the MGSM dataset.

    Args:
        languages: List of language codes to include (defaults to all)
        limit_per_language: Maximum samples per language (defaults to all)

    Returns:
        Dataset with MGSM samples
    """
    if languages is None:
        languages = ALL_LANGUAGES
    else:
        # Validate language codes
        for lang in languages:
            if lang not in ALL_LANGUAGES:
                raise ValueError(
                    f"Invalid language code: {lang}. Must be one of {ALL_LANGUAGES}"
                )

    all_samples = []
    for lang in languages:
        lang_samples = load_language_data(lang)
        if limit_per_language is not None:
            lang_samples = lang_samples[:limit_per_language]
        all_samples.extend(lang_samples)

    return MemoryDataset(
        samples=all_samples,
        name=f"mgsm_{'_'.join(languages)}"
        if len(languages) > 1
        else f"mgsm_{languages[0]}",
    )
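
A minimal usage sketch for the loader above (illustrative only; assumes the module path openbench.datasets.mgsm and that the MGSM TSV files are reachable):

# Illustrative sketch: load two MGSM languages, capped at 8 samples each.
from openbench.datasets.mgsm import get_dataset

dataset = get_dataset(languages=["en", "de"], limit_per_language=8)
print(len(dataset))                      # 16 samples (8 per language)
print(dataset[0].metadata["language"])   # "en"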

src/openbench/datasets/mmlu.py (3.4 KiB)

from inspect_ai.dataset import Dataset, csv_dataset, Sample
from openbench.utils.text import MULTIPLE_CHOICE_PROMPT_TEMPLATE

# Adapted from https://github.com/openai/simple-evals
SUBJECT_TO_CATEGORY = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}

LANGUAGES = [
    "EN-US",
    "AR-XY",
    "BN-BD",
    "DE-DE",
    "ES-LA",
    "FR-FR",
    "HI-IN",
    "ID-ID",
    "IT-IT",
    "JA-JP",
    "KO-KR",
    "PT-BR",
    "ZH-CN",
    "SW-KE",
    "YO-NG",
]


def record_to_sample(record: dict[str, str]) -> Sample:
    return Sample(
        input=MULTIPLE_CHOICE_PROMPT_TEMPLATE.format(
            prompt=record["Question"],
            option_a=record["A"],
            option_b=record["B"],
            option_c=record["C"],
            option_d=record["D"],
        ),
        target=record["Answer"],
        metadata={
            "subject": record["Subject"],
            "category": SUBJECT_TO_CATEGORY[record["Subject"]],
        },
    )


def get_dataset(language: str = "EN-US") -> Dataset:
    if language not in LANGUAGES:
        raise ValueError(f"Language {language} not supported")

    return csv_dataset(
        csv_file="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
        if language == "EN-US"
        else f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv",
        sample_fields=record_to_sample,
        auto_id=True,
        name="mmlu_simple_eval",
    )
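
A usage sketch for the loader above; non-default language codes select the translated CSVs from the same simple-evals bucket (illustrative only; network access to the blob store is assumed):

# Illustrative sketch: load the French MMLU split and inspect one sample.
from openbench.datasets.mmlu import get_dataset

fr = get_dataset("FR-FR")   # fetches mmlu_FR-FR.csv
sample = next(iter(fr))
print(sample.metadata["subject"], sample.metadata["category"])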

src/openbench/datasets/mrcr.py (2.2 KiB)

from typing import Any, Callable, Optional

from inspect_ai.dataset import Sample, Dataset, hf_dataset, FieldSpec
from openbench.utils.text import get_chatml_tok_cnt, str_to_chat_messages


def record_to_sample(
    max_context_size: Optional[int] = None,
) -> FieldSpec | Callable[[dict[str, Any]], Sample | list[Sample]]:
    """Create a mapper from MRCR records to Inspect Samples.

    Expected fields in the source record:
    - prompt (str): input to the model
    - answer (str): expected output
    - random_string_to_prepend (str)
    - n_needles (int)
    - desired_msg_index (int)
    - total_messages (int)
    - n_chars (int)
    """

    def _record_to_sample(record: dict[str, Any]) -> Sample | list[Sample]:
        chat_messages = str_to_chat_messages(record["prompt"])
        input_tok_cnt = get_chatml_tok_cnt(record["prompt"])
        if max_context_size is not None and input_tok_cnt > max_context_size:
            return []
        metadata = {
            "random_string_to_prepend": record.get("random_string_to_prepend"),
            "n_needles": record.get("n_needles"),
            "desired_msg_index": record.get("desired_msg_index"),
            "total_messages": record.get("total_messages"),
            "n_chars": record.get("n_chars"),
            "raw_input_tok_cnt": input_tok_cnt,
        }

        return Sample(
            input=chat_messages,
            target=record["answer"],
            metadata=metadata,
        )

    return _record_to_sample


def get_dataset(
    needles: Optional[int] = None, max_context_size: Optional[int] = None
) -> Dataset:
    """Load the MRCR dataset.

    Args:
        needles: Number of needles to include (2, 4, or 8). Defaults to None.
        max_context_size: Maximum context size in tokens. Defaults to None.
    Returns:
        Dataset filtered to the requested number of needles.
    """

    if needles in (2, 4, 8):
        return hf_dataset(
            path="openai/mrcr",
            split="train",
            sample_fields=record_to_sample(max_context_size),
            data_files=f"{needles}needle.parquet",
        )

    return hf_dataset(
        path="openai/mrcr",
        split="train",
        sample_fields=record_to_sample(max_context_size),
    )
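
A usage sketch for the two filters above (illustrative only; downloading the openai/mrcr dataset from Hugging Face is assumed):

# Illustrative sketch: the 4-needle subset, dropping samples whose ChatML token
# count exceeds 128k (via the max_context_size filter in record_to_sample).
from openbench.datasets.mrcr import get_dataset

dataset = get_dataset(needles=4, max_context_size=128_000)
print(len(dataset))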

src/openbench/datasets/rootly_gmcq.py (810 B)

from inspect_ai.dataset import Sample
from inspect_ai.dataset import hf_dataset
from typing import Any

SUBTASK = None


def record_to_sample_gmcq(record: dict[str, Any]):
    if SUBTASK is None:
        return Sample(
            input=record["input"],
            target=record["ideal"],
        )
    else:
        if record["repository_name"] in SUBTASK.split(","):
            return Sample(
                input=record["input"],
                target=record["ideal"],
            )
        else:
            return []


def load_dataset(subtask):
    global SUBTASK
    SUBTASK = subtask

    dataset = hf_dataset(
        "TheFloatingString/gmcq",
        split="test",
        sample_fields=record_to_sample_gmcq,
        revision="51c9eace06dd5791e72717bf6ba0348d23857c50",
    )
    return dataset
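
Because SUBTASK is split on commas, several repositories can be selected at once (illustrative sketch; the repository names follow the dataset card referenced in evals/rootly_gmcq.py):

# Illustrative sketch: filter to two repositories, or load everything with None.
from openbench.datasets.rootly_gmcq import load_dataset

subset = load_dataset("mastodon,duckdb")
full = load_dataset(None)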

src/openbench/datasets/scicode.py (473 B)

from inspect_ai.dataset import hf_dataset
from inspect_ai.dataset import Sample


def record_to_sample(record):
    return Sample(
        # Note: input is the literal placeholder string "problem_id"; the SciCode
        # eval constructs the actual prompts from the sample metadata below.
        input="problem_id",
        target=record["problem_id"],
        id=record["problem_id"],
        metadata={k: v for k, v in record.items()},
    )


def return_hf_dataset(split: str = "test"):
    dataset = hf_dataset(
        "SciCode1/SciCode",
        split=split,
        sample_fields=record_to_sample,
    )
    return dataset

src/openbench/datasets/simpleqa.py (948 B)

from inspect_ai.dataset import Dataset, csv_dataset, Sample, MemoryDataset


def record_to_sample(record: dict) -> Sample:
    """Convert a SimpleQA CSV record to an Inspect Sample."""
    return Sample(
        input=record["problem"],
        target=record["answer"],
        metadata={"metadata": record.get("metadata", "")},
    )


def get_dataset() -> Dataset:
    """Load the SimpleQA dataset.

    Args:
        num_examples: Number of examples to use (None for all)
        n_repeats: Number of times to repeat the dataset (only valid when num_examples is None)
    """
    # Load the full dataset
    dataset = csv_dataset(
        csv_file="https://openaipublic.blob.core.windows.net/simple-evals/simple_qa_test_set.csv",
        sample_fields=record_to_sample,
        auto_id=True,
        name="simpleqa",
    )

    # Convert to list of samples
    samples = list(dataset)

    return MemoryDataset(samples=samples, name="simpleqa")

src/openbench/eval_config.py (3.8 KiB)

"""
Evaluation configuration that combines static metadata with dynamic extraction.
Fast static data for list command, comprehensive dynamic data for describe.
"""

from typing import Optional, List, Dict, Any
from dataclasses import dataclass, field
from functools import lru_cache

from openbench.config import get_benchmark_metadata, load_task


@dataclass
class EvalConfig:
    """Combined static metadata and dynamic config."""

    # Static metadata
    name: str
    description: str
    category: str
    tags: List[str] = field(default_factory=list)

    # Dynamic data (populated on demand)
    epochs: Optional[int] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    sandbox: Optional[str] = None
    dataset_size: Optional[int] = None
    task_args: Optional[Dict[str, Any]] = None

    # Loading state
    _dynamic_loaded: bool = False


def _extract_task_config(task_name: str) -> Dict[str, Any]:
    """Extract configuration from actual task definition."""
    try:
        task_func = load_task(task_name)

        # Get task function signature to extract any arguments
        import inspect

        sig = inspect.signature(task_func)
        task_args = {}
        if sig.parameters:
            # Extract parameter names and defaults
            for param_name, param in sig.parameters.items():
                if param.default != inspect.Parameter.empty:
                    task_args[param_name] = param.default

        # Call task function to get the task object
        task = task_func()

        config = {
            "epochs": getattr(task, "epochs", None),
        }

        # Add task-specific arguments
        if task_args:
            config["task_args"] = task_args

        # Extract sandbox info
        sandbox = getattr(task, "sandbox", None)
        if sandbox:
            # Handle SandboxEnvironmentSpec or string
            if hasattr(sandbox, "type"):
                config["sandbox"] = sandbox.type
            else:
                config["sandbox"] = str(sandbox)
        else:
            config["sandbox"] = None

        # Extract all GenerateConfig values dynamically
        if hasattr(task, "config") and task.config:
            # Get all fields from GenerateConfig that have values
            for field_name in task.config.model_fields:
                value = getattr(task.config, field_name, None)
                if value is not None:
                    config[field_name] = value

        # Try to get dataset size without loading full dataset
        if hasattr(task, "dataset"):
            try:
                # Some datasets have length without loading
                config["dataset_size"] = len(task.dataset)
            except Exception:
                config["dataset_size"] = None

        return config
    except Exception:
        return {}


@lru_cache(maxsize=None)
def get_eval_config(name: str, load_dynamic: bool = False) -> Optional[EvalConfig]:
    """Get evaluation configuration by name.

    Args:
        name: Benchmark name
        load_dynamic: Whether to load dynamic data (slow but comprehensive)
    """
    metadata = get_benchmark_metadata(name)
    if not metadata:
        return None

    config = EvalConfig(
        name=metadata.name,
        description=metadata.description,
        category=metadata.category,
        tags=metadata.tags,
    )

    if load_dynamic:
        dynamic_data = _extract_task_config(name)
        config.epochs = dynamic_data.get("epochs")
        config.temperature = dynamic_data.get("temperature")
        config.max_tokens = dynamic_data.get("max_tokens")
        config.sandbox = dynamic_data.get("sandbox")
        config.dataset_size = dynamic_data.get("dataset_size")
        config.task_args = dynamic_data.get("task_args")
        config._dynamic_loaded = True

    return config
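
A usage sketch for the two modes described in the docstring above (illustrative only; "mmlu" is assumed to be a registered benchmark name):

# Illustrative sketch: fast static lookup vs. full dynamic extraction.
from openbench.eval_config import get_eval_config

static_cfg = get_eval_config("mmlu")                    # static metadata only
full_cfg = get_eval_config("mmlu", load_dynamic=True)   # also loads the task
if full_cfg is not None:
    print(full_cfg.temperature, full_cfg.dataset_size)

Results are memoized per (name, load_dynamic) pair via lru_cache, so repeated lookups are cheap.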

src/openbench/evals/__init__.py (0 B)


src/openbench/evals/browsecomp.py (1.1 KiB)

"""BrowseComp: A Simple Yet Challenging Benchmark for Browsing Agents.

Authors: Jason Wei, Zhiqing Sun, Spencer Papay, Scott McKinney, Jeffrey Han,
Isa Fulford, Hyung Won Chung, Alex Tachard Passos, William Fedus, Mia Glaese

https://openai.com/index/browsecomp/
"""

from inspect_ai import task, Task
from inspect_ai.solver import generate
from openbench.datasets.browsecomp import get_dataset
from openbench.scorers.browsecomp import browsecomp_scorer


@task
def browsecomp(grader_model: str = "openai/gpt-4.1-2025-04-14") -> Task:
    """BrowseComp: A Simple Yet Challenging Benchmark for Browsing Agents.

    This benchmark evaluates model performance on browsing-related tasks
    that require understanding and reasoning about web content.

    Args:
        grader_model: Model to use for grading responses (defaults to gpt-4.1-2025-04-14)

    Returns:
        Task configured for BrowseComp evaluation
    """
    return Task(
        dataset=get_dataset(),
        solver=[generate()],
        scorer=browsecomp_scorer(model=grader_model),
        name="browsecomp",
    )

src/openbench/evals/cti_bench.py (4.8 KiB)

"""CTI-Bench evaluation tasks for cybersecurity threat intelligence benchmarks."""

from inspect_ai import task, Task
from inspect_ai.solver import generate
from inspect_ai.model import GenerateConfig
from inspect_ai.dataset import Dataset, MemoryDataset

from openbench.datasets.cti_bench import (
    get_cti_bench_mcq_dataset,
    get_cti_bench_rcm_dataset,
    get_cti_bench_vsp_dataset,
    get_cti_bench_ate_dataset,
)
from openbench.scorers.cti_bench import (
    cti_bench_mcq_scorer,
    cti_bench_rcm_scorer,
    cti_bench_vsp_scorer,
    cti_bench_ate_scorer,
)


@task
def cti_bench_mcq() -> Task:
    """CTI-Bench Multiple Choice Questions task."""
    return Task(
        dataset=get_cti_bench_mcq_dataset(),
        solver=[generate()],
        scorer=cti_bench_mcq_scorer(),
        name="cti_bench_mcq",
        config=GenerateConfig(temperature=0.0, max_tokens=8192),
    )


@task
def cti_bench_rcm() -> Task:
    """CTI-Bench RCM (CVE→CWE mapping) task."""
    return Task(
        dataset=get_cti_bench_rcm_dataset(),
        solver=[generate()],
        scorer=cti_bench_rcm_scorer(),
        name="cti_bench_rcm",
        config=GenerateConfig(temperature=0.0, max_tokens=8192),
    )


@task
def cti_bench_vsp() -> Task:
    """CTI-Bench VSP (CVSS severity prediction) task."""
    return Task(
        dataset=get_cti_bench_vsp_dataset(),
        solver=[generate()],
        scorer=cti_bench_vsp_scorer(),
        name="cti_bench_vsp",
        config=GenerateConfig(temperature=0.0, max_tokens=8192),
    )


@task
def cti_bench_ate() -> Task:
    """CTI-Bench ATE (ATT&CK Technique Extraction) task."""
    return Task(
        dataset=get_cti_bench_ate_dataset(),
        solver=[generate()],
        scorer=cti_bench_ate_scorer(),
        name="cti_bench_ate",
        config=GenerateConfig(temperature=0.0, max_tokens=8192),
    )


def combine_datasets() -> Dataset:
    """Combine all CTI-Bench datasets into one."""
    # Load all individual datasets
    mcq_dataset = get_cti_bench_mcq_dataset()
    rcm_dataset = get_cti_bench_rcm_dataset()
    vsp_dataset = get_cti_bench_vsp_dataset()
    ate_dataset = get_cti_bench_ate_dataset()

    combined_samples = []

    # Add MCQ samples with task type metadata
    for sample in mcq_dataset:
        sample.metadata = sample.metadata or {}
        sample.metadata["task_type"] = "mcq"
        combined_samples.append(sample)

    # Add RCM samples with task type metadata
    for sample in rcm_dataset:
        sample.metadata = sample.metadata or {}
        sample.metadata["task_type"] = "rcm"
        combined_samples.append(sample)

    # Add VSP samples with task type metadata
    for sample in vsp_dataset:
        sample.metadata = sample.metadata or {}
        sample.metadata["task_type"] = "vsp"
        combined_samples.append(sample)

    # Add ATE samples with task type metadata
    for sample in ate_dataset:
        sample.metadata = sample.metadata or {}
        sample.metadata["task_type"] = "ate"
        combined_samples.append(sample)

    return MemoryDataset(samples=combined_samples, name="cti_bench")


def combined_cti_bench_scorer():
    """Combined scorer that delegates to appropriate task-specific scorer."""
    from inspect_ai.scorer import scorer, Score, Target, accuracy, stderr
    from inspect_ai.solver import TaskState
    from typing import Callable

    # Get individual scorers
    mcq_scorer_fn = cti_bench_mcq_scorer()
    rcm_scorer_fn = cti_bench_rcm_scorer()
    vsp_scorer_fn = cti_bench_vsp_scorer()
    ate_scorer_fn = cti_bench_ate_scorer()

    @scorer(metrics=[accuracy(), stderr()])
    def cti_bench_combined_scorer() -> Callable:
        async def score(state: TaskState, target: Target) -> Score:
            # Determine which scorer to use based on task type
            task_type = state.metadata.get("task_type") if state.metadata else None

            if task_type == "mcq":
                return await mcq_scorer_fn(state, target)
            elif task_type == "rcm":
                return await rcm_scorer_fn(state, target)
            elif task_type == "vsp":
                return await vsp_scorer_fn(state, target)
            elif task_type == "ate":
                return await ate_scorer_fn(state, target)
            else:
                # Fallback - should not happen
                return Score(
                    value=0.0,
                    answer="",
                    metadata={"error": f"Unknown task type: {task_type}"},
                )

        return score

    return cti_bench_combined_scorer()


@task
def cti_bench() -> Task:
    """Combined CTI-Bench evaluation running all 4 cybersecurity tasks."""
    return Task(
        dataset=combine_datasets(),
        solver=[generate()],
        scorer=combined_cti_bench_scorer(),
        name="cti_bench",
        config=GenerateConfig(temperature=0.0, max_tokens=8192),
    )
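
A sketch of running the combined task programmatically (illustrative only; inspect_ai's eval() entry point and the model name are assumptions, not part of this file):

# Illustrative sketch: the combined task routes each sample to its scorer via
# the metadata["task_type"] value set in combine_datasets().
from inspect_ai import eval as inspect_eval
from openbench.evals.cti_bench import cti_bench

logs = inspect_eval(cti_bench(), model="openai/gpt-4.1-mini", limit=10)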

src/openbench/evals/drop.py (1.3 KiB)

"""DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs.

Based on the paper by Dua et al. (2019).
https://arxiv.org/abs/1903.00161
"""

from inspect_ai import Task, task
from inspect_ai.model import GenerateConfig
from inspect_ai.solver import generate

from openbench.datasets.drop import get_dataset
from openbench.scorers.drop import drop_scorer


@task
def drop(
    num_examples: int | None = None,
    train_samples_per_prompt: int = 3,
) -> Task:
    """DROP: Reading comprehension requiring discrete reasoning.

    A reading comprehension benchmark that requires discrete reasoning over
    paragraphs, including arithmetic, counting, and sorting operations.

    Args:
        num_examples: Number of examples to evaluate (None for all)
        train_samples_per_prompt: Number of few-shot examples (default: 3)

    Returns:
        Task configured for DROP evaluation
    """
    return Task(
        dataset=get_dataset(
            num_examples=num_examples,
            train_samples_per_prompt=train_samples_per_prompt,
        ),
        solver=[generate()],
        scorer=drop_scorer(),
        name="drop",
        config=GenerateConfig(
            temperature=0.0,  # Deterministic for reasoning tasks
            max_tokens=8192,  # Allow for reasoning steps
        ),
    )

src/openbench/evals/gpqa_diamond.py (776 B)

from inspect_ai import Task, task, Epochs
from inspect_ai.model import GenerateConfig
from inspect_ai.solver import system_message, generate
from openbench.utils.text import SIMPLE_EVALS_SYSTEM_MESSAGE
from openbench.datasets.gpqa import get_dataset
from openbench.scorers import robust_mcq_scorer


# Note: one difference from the original GPQA simple eval is that the prompts are not reshuffled for every epoch. This should have little effect on results, but it is worth noting.
@task
def gpqa_diamond() -> Task:
    return Task(
        dataset=get_dataset(),
        solver=[system_message(SIMPLE_EVALS_SYSTEM_MESSAGE), generate()],
        scorer=robust_mcq_scorer(),
        name="gpqa_diamond",
        config=GenerateConfig(temperature=0.5),
        epochs=Epochs(10),
    )

src/openbench/evals/graphwalks.py (1.2 KiB)

# src/openbench/evals/graphwalks.py
from __future__ import annotations

from inspect_ai import task, Task
from inspect_ai.model import GenerateConfig
from inspect_ai.solver import generate

from openbench.datasets.graphwalks import get_dataset
from openbench.scorers.graphwalks import graphwalks_scorer


@task
def graphwalks(split: str = "train") -> Task:
    return Task(
        dataset=get_dataset(split=split, task_type="both"),
        solver=[generate()],
        scorer=graphwalks_scorer(),
        name="graphwalks",
        config=GenerateConfig(temperature=0.0, top_p=1.0, max_tokens=256),
    )


@task
def graphwalks_bfs(split: str = "train") -> Task:
    return Task(
        dataset=get_dataset(split=split, task_type="bfs"),
        solver=[generate()],
        scorer=graphwalks_scorer(),
        name="graphwalks_bfs",
        config=GenerateConfig(temperature=0.0, top_p=1.0, max_tokens=256),
    )


@task
def graphwalks_parents(split: str = "train") -> Task:
    return Task(
        dataset=get_dataset(split=split, task_type="parents"),
        solver=[generate()],
        scorer=graphwalks_scorer(),
        name="graphwalks_parents",
        config=GenerateConfig(temperature=0.0, top_p=1.0, max_tokens=256),
    )

src/openbench/evals/healthbench.py (2.4 KiB)

"""HealthBench evaluation implementation."""

from typing import Optional

from inspect_ai import Task, task
from inspect_ai.model import GenerateConfig
from inspect_ai.solver import generate

from openbench.datasets.healthbench import get_dataset
from openbench.scorers.healthbench import healthbench_scorer


@task
def healthbench(
    subset: Optional[str] = None,
    grader_model: str = "openai/gpt-4.1-2025-04-14",
) -> Task:
    """HealthBench: Medical dialogue evaluation using physician-created rubrics.

    Based on the HealthBench benchmark from OpenAI's simple-evals.
    Evaluates medical dialogue completions against detailed rubrics.

    Args:
        subset: Which subset to evaluate ("hard", "consensus", or None for main)
        grader_model: Model to use for grading rubrics

    Returns:
        Task configured for HealthBench evaluation
    """
    return Task(
        dataset=get_dataset(subset=subset),
        solver=[generate()],
        scorer=healthbench_scorer(grader_model=grader_model),
        name="healthbench",
        config=GenerateConfig(
            temperature=0.0,  # Use deterministic generation for medical advice
            max_tokens=8192,  # Allow longer responses for detailed medical explanations
        ),
    )


@task
def healthbench_hard(grader_model: str = "openai/gpt-4.1-2025-04-14") -> Task:
    """HealthBench Hard subset: Most challenging medical dialogue cases.

    Args:
        grader_model: Model to use for grading rubrics

    Returns:
        Task configured for HealthBench Hard evaluation
    """
    return Task(
        dataset=get_dataset(subset="hard"),
        solver=[generate()],
        scorer=healthbench_scorer(grader_model=grader_model),
        name="healthbench_hard",
        config=GenerateConfig(
            temperature=0.0,
            max_tokens=8192,
        ),
    )


@task
def healthbench_consensus(grader_model: str = "openai/gpt-4.1-2025-04-14") -> Task:
    """HealthBench Consensus subset: Cases with physician consensus.

    Args:
        grader_model: Model to use for grading rubrics

    Returns:
        Task configured for HealthBench Consensus evaluation
    """
    return Task(
        dataset=get_dataset(subset="consensus"),
        solver=[generate()],
        scorer=healthbench_scorer(grader_model=grader_model),
        name="healthbench_consensus",
        config=GenerateConfig(
            temperature=0.0,
            max_tokens=8192,
        ),
    )

src/openbench/evals/hle.py (2.7 KiB)

from inspect_ai import task, Task
from inspect_ai.solver import generate, system_message
from inspect_ai.model import GenerateConfig
from openbench.datasets.hle import get_dataset
from openbench.scorers.hle import hle_scorer


# HLE system prompt as used in the original implementation
HLE_SYSTEM_PROMPT = "Your response should be in the following format:\nExplanation: {your explanation for your answer choice}\nAnswer: {your chosen answer}\nConfidence: {your confidence score between 0% and 100% for your answer}"


@task
def hle(
    grader_model: str = "openai/o3-mini-2025-01-31", max_tokens: int = 8192
) -> Task:
    """Humanity's Last Exam: A benchmark at the frontier of human knowledge.

    HLE consists of 2,500 questions across dozens of subjects including mathematics,
    humanities, and natural sciences. Questions are designed by subject-matter experts
    globally and include both multiple-choice and short-answer formats.

    Args:
        grader_model: Model to use for grading responses (defaults to o3-mini-2025-01-31)
        max_tokens: Maximum tokens for model response (defaults to 8192 as recommended by HLE)

    Returns:
        Task configured for HLE evaluation
    """
    return Task(
        dataset=get_dataset(text_only=False),
        solver=[
            system_message(HLE_SYSTEM_PROMPT),
            generate(),
        ],
        scorer=hle_scorer(model=grader_model),
        name="hle",
        config=GenerateConfig(
            temperature=0.0,  # Use deterministic generation as per HLE
            max_tokens=max_tokens,  # HLE recommends at least 8192 for reasoning models
        ),
    )


@task
def hle_text(
    grader_model: str = "openai/o3-mini-2025-01-31", max_tokens: int = 8192
) -> Task:
    """Humanity's Last Exam (Text-Only): HLE with multi-modal questions filtered out.

    This variant includes only text-based questions from HLE, excluding any questions
    that require image understanding. Useful for evaluating models without vision capabilities.

    Args:
        grader_model: Model to use for grading responses (defaults to o3-mini-2025-01-31)
        max_tokens: Maximum tokens for model response (defaults to 8192 as recommended by HLE)

    Returns:
        Task configured for HLE text-only evaluation
    """
    return Task(
        dataset=get_dataset(text_only=True),
        solver=[
            system_message(HLE_SYSTEM_PROMPT),
            generate(),
        ],
        scorer=hle_scorer(model=grader_model),
        name="hle_text",
        config=GenerateConfig(
            temperature=0.0,  # Use deterministic generation as per HLE
            max_tokens=max_tokens,  # HLE recommends at least 8192 for reasoning models
        ),
    )

src/openbench/evals/humaneval.py (1.2 KiB)

from inspect_ai import Task, task, Epochs
from inspect_ai.model import GenerateConfig
from inspect_ai.solver import generate
from openbench.scorers.humaneval import verify
from openbench.datasets.humaneval import get_humaneval_dataset
from typing import Optional


# Adapted from https://github.com/UKGovernmentBEIS/inspect_evals
@task
def humaneval(instruction_prompt: Optional[str] = None) -> Task:
    """
    Inspect Task implementation for the HumanEval benchmark.

    Args:
        instruction_prompt (str, optional): The prompt to prepend to the code problem.
            If None, uses the default HumanEval instruction.

    Returns:
        Task: The configured HumanEval task.
    """
    epochs_count = 5
    reducer_list = ["mean", "pass_at_1", "pass_at_2", "pass_at_5"]

    dataset = (
        get_humaneval_dataset()
        if instruction_prompt is None
        else get_humaneval_dataset(instruction_prompt=instruction_prompt)
    )

    return Task(
        dataset=dataset,
        solver=generate(),
        scorer=verify(),
        sandbox="local",
        config=GenerateConfig(
            temperature=0.5,
        ),
        epochs=Epochs(epochs_count, reducer=reducer_list),
    )

src/openbench/evals/jsonschemabench.py (2.8 KiB)

"""JSONSchemaBench: JSON Schema generation benchmark evaluation.

Based on: JSONSchemaBench: A Rigorous Benchmark of Structured Outputs for Language Models
EPFL DLAB, 2025
https://arxiv.org/html/2501.10868

Dataset: https://huggingface.co/datasets/epfl-dlab/JSONSchemaBench
"""

import json
from inspect_ai import Task, task
from inspect_ai.solver import TaskState, Generate, solver
from inspect_ai.model import GenerateConfig, ResponseSchema

from openbench.datasets.jsonschemabench import get_dataset
from openbench.scorers.json_schema import json_schema_scorer


@solver
def structured_output_solver():
    """Apply per-sample structured output for supported providers (OpenAI, Google, Mistral)."""

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        if not state.metadata or "schema" not in state.metadata:
            return await generate(state)

        try:
            schema_str = state.metadata["schema"]
            schema_dict = json.loads(schema_str)

            return await generate(
                state,
                response_schema=ResponseSchema(
                    name="json_schema_output", json_schema=schema_dict, strict=True
                ),
            )

        except (json.JSONDecodeError, KeyError, ValueError, TypeError):
            return await generate(state)

    return solve


@task
def jsonschemabench(
    subset: str | None = None,
    split: str = "all",
    num_shots: int = 0,
    strip_markdown: bool = True,
) -> Task:
    """JSONSchemaBench: A Rigorous Benchmark of Structured Outputs
    for Language Models.

    Evaluates the ability of language models to generate valid JSON
    that conforms to provided JSON schemas. Based on ~10K real-world
    schemas from GitHub, Kubernetes, APIs, and other sources.

    Uses structured output when supported by the provider for API-level
    schema validation, otherwise falls back to text generation with
    post-hoc validation.

    See https://doi.org/10.48550/arXiv.2501.10868.

    Args:
        subset: Specific subset to evaluate (e.g., "Github_easy", "Kubernetes")
                or None for mixed benchmark
        split: Dataset split to use ("all", "test", "val", "train")
        num_shots: Number of few-shot examples to include (0 for zero-shot, paper used 2)
        strip_markdown: Whether to remove ```json``` markdown blocks from output (default True)

    Returns:
        Task configured for JSONSchemaBench evaluation
    """
    return Task(
        dataset=get_dataset(subset=subset, split=split, num_shots=num_shots),
        solver=[structured_output_solver()],
        scorer=json_schema_scorer(strip_markdown=strip_markdown),
        name="jsonschemabench",
        config=GenerateConfig(
            temperature=0.0,  # Following paper methodology (greedy decoding)
            timeout=40,  # 40-second timeout as per original paper
        ),
    )
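
A sketch showing the task arguments described in the docstring (illustrative only; the subset name mirrors the docstring's example):

# Illustrative sketch: a subset-specific, two-shot configuration, matching the
# paper's few-shot setting mentioned above.
from openbench.evals.jsonschemabench import jsonschemabench

task = jsonschemabench(subset="Github_easy", split="test", num_shots=2)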

src/openbench/evals/math.py (2.4 KiB)

from inspect_ai import task, Task
from inspect_ai.solver import generate
from inspect_ai.model import GenerateConfig
from openbench.datasets.math import get_dataset
from openbench.scorers.math import math_scorer


# Template for solving math problems - from simple-evals
QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{problem}

Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
""".strip()


@task
def math(grader_model: str = "openai/gpt-4-turbo-preview") -> Task:
    """MATH: Measuring Mathematical Problem Solving.

    Based on the paper by Hendrycks et al. (2021).
    Tests mathematical problem-solving across multiple difficulty levels and topics.
    Uses model-based grading to check mathematical equivalence of answers.

    Args:
        grader_model: Model to use for checking answer equality (defaults to gpt-4-turbo-preview)

    Returns:
        Task configured for MATH evaluation
    """
    # Get the dataset and format problems
    dataset = get_dataset("math_test")
    for sample in dataset:
        sample.input = QUERY_TEMPLATE.format(problem=sample.input)

    return Task(
        dataset=dataset,
        solver=[generate()],
        scorer=math_scorer(model=grader_model),
        name="math",
        config=GenerateConfig(
            max_tokens=8192,  # Allow long reasoning chains
        ),
    )


@task
def math_500(grader_model: str = "openai/gpt-4-turbo-preview") -> Task:
    """MATH-500: A 500-problem subset of the MATH dataset.

    A smaller, representative subset of the full MATH dataset for faster evaluation.
    Uses the same scoring and configuration as the full dataset.

    Args:
        grader_model: Model to use for checking answer equality (defaults to gpt-4-turbo-preview)

    Returns:
        Task configured for MATH-500 evaluation
    """
    # Get the dataset and format problems
    dataset = get_dataset("math_500_test")
    for sample in dataset:
        sample.input = QUERY_TEMPLATE.format(problem=sample.input)

    return Task(
        dataset=dataset,
        solver=[generate()],
        scorer=math_scorer(model=grader_model),
        name="math_500",
        config=GenerateConfig(
            max_tokens=8192,  # Allow long reasoning chains
        ),
    )

src/openbench/evals/matharena/__init__.py (1 B)


src/openbench/evals/matharena/aime_2023_I/__init__.py (0 B)


src/openbench/evals/matharena/aime_2023_I/aime_2023_I.py (783 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive."
# default_temperature: 0.6
# default_max_tokens: 8000
# strict_parsing: false
# n_problems: 15
# date: "2024-02-01"
# dataset_path: MathArena/aime_2023_I
@task
def aime_2023_I() -> Task:
    return matharena_task(
        dataset_path="MathArena/aime_2023_I",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2023_I",
    )

src/openbench/evals/matharena/aime_2023_II/__init__.py (0 B)


src/openbench/evals/matharena/aime_2023_II/aime_2023_II.py (787 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive."
# default_temperature: 0.6
# default_max_tokens: 8000
# strict_parsing: false
# n_problems: 15
# date: "2024-02-01"
# dataset_path: MathArena/aime_2023_II
@task
def aime_2023_II() -> Task:
    return matharena_task(
        dataset_path="MathArena/aime_2023_II",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2023_II",
    )

src/openbench/evals/matharena/aime_2024/__init__.py (0 B)


src/openbench/evals/matharena/aime_2024/aime_2024.py (544 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# Not MathArena, but concatenated from aime_2024_I and aime_2024_II
@task
def aime_2024() -> Task:
    return matharena_task(
        dataset_path="AarushSah/aime2024",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2024",
    )

src/openbench/evals/matharena/aime_2024_I/__init__.py (0 B)


src/openbench/evals/matharena/aime_2024_I/aime_2024_I.py (783 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive."
# default_temperature: 0.6
# default_max_tokens: 8000
# strict_parsing: false
# n_problems: 15
# date: "2024-02-07"
# dataset_path: MathArena/aime_2024_I
@task
def aime_2024_I() -> Task:
    return matharena_task(
        dataset_path="MathArena/aime_2024_I",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2024_I",
    )

src/openbench/evals/matharena/aime_2024_II/__init__.py (0 B)


src/openbench/evals/matharena/aime_2024_II/aime_2024_II.py (787 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive."
# default_temperature: 0.6
# default_max_tokens: 8000
# strict_parsing: false
# n_problems: 15
# date: "2024-02-07"
# dataset_path: MathArena/aime_2024_II
@task
def aime_2024_II() -> Task:
    return matharena_task(
        dataset_path="MathArena/aime_2024_II",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2024_II",
    )

src/openbench/evals/matharena/aime_2025/__init__.py (0 B)


src/openbench/evals/matharena/aime_2025/aime_2025.py (775 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive."
# default_temperature: 0.6
# default_max_tokens: 8000
# strict_parsing: false
# n_problems: 30
# date: "2025-02-12"
# dataset_path: MathArena/aime_2025
@task
def aime_2025() -> Task:
    return matharena_task(
        dataset_path="MathArena/aime_2025",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2025",
    )

src/openbench/evals/matharena/aime_2025_II/__init__.py (0 B)


src/openbench/evals/matharena/aime_2025_II/aime_2025_II.py (787 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive."
# default_temperature: 0.6
# default_max_tokens: 8000
# strict_parsing: false
# n_problems: 15
# date: "2025-02-06"
# dataset_path: MathArena/aime_2025_II
@task
def aime_2025_II() -> Task:
    return matharena_task(
        dataset_path="MathArena/aime_2025_II",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.\nThe answer is an integer between 0 and 999 inclusive.",
        default_temperature=0.6,
        default_max_tokens=8000,
        default_epochs=4,
        name="aime_2025_II",
    )

src/openbench/evals/matharena/brumo_2025/__init__.py (0 B)


src/openbench/evals/matharena/brumo_2025/brumo_2025.py (671 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}."
# default_temperature: 0.6
# default_max_tokens: 16000
# strict_parsing: false
# n_problems: 30
# date: "2025-04-08"
# dataset_path: MathArena/brumo_2025
@task
def brumo_2025() -> Task:
    return matharena_task(
        dataset_path="MathArena/brumo_2025",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.",
        default_temperature=0.6,
        default_max_tokens=16000,
        default_epochs=4,
        name="brumo_2025",
    )

src/openbench/evals/matharena/hmmt_feb_2023/__init__.py (0 B)


src/openbench/evals/matharena/hmmt_feb_2023/hmmt_feb_2023.py (683 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}."
# default_temperature: 0.6
# default_max_tokens: 16000
# strict_parsing: false
# n_problems: 30
# date: "2023-02-17"
# dataset_path: MathArena/hmmt_feb_2023
@task
def hmmt_feb_2023() -> Task:
    return matharena_task(
        dataset_path="MathArena/hmmt_feb_2023",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.",
        default_temperature=0.6,
        default_max_tokens=16000,
        default_epochs=4,
        name="hmmt_feb_2023",
    )

src/openbench/evals/matharena/hmmt_feb_2024/__init__.py (0 B)


src/openbench/evals/matharena/hmmt_feb_2024/hmmt_feb_2024.py (683 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}."
# default_temperature: 0.6
# default_max_tokens: 16000
# strict_parsing: false
# n_problems: 30
# date: "2024-02-17"
# dataset_path: MathArena/hmmt_feb_2024
@task
def hmmt_feb_2024() -> Task:
    return matharena_task(
        dataset_path="MathArena/hmmt_feb_2024",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.",
        default_temperature=0.6,
        default_max_tokens=16000,
        default_epochs=4,
        name="hmmt_feb_2024",
    )

src/openbench/evals/matharena/hmmt_feb_2025/__init__.py (0 B)


src/openbench/evals/matharena/hmmt_feb_2025/hmmt_feb_2025.py (683 B)

from openbench.evals.matharena.matharena import matharena_task
from inspect_ai import Task, task


# instruction: "Please reason step by step, and put your final answer within \\boxed{{}}."
# default_temperature: 0.6
# default_max_tokens: 16000
# strict_parsing: false
# n_problems: 30
# date: "2025-02-15"
# dataset_path: MathArena/hmmt_feb_2025
@task
def hmmt_feb_2025() -> Task:
    return matharena_task(
        dataset_path="MathArena/hmmt_feb_2025",
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.",
        default_temperature=0.6,
        default_max_tokens=16000,
        default_epochs=4,
        name="hmmt_feb_2025",
    )

src/openbench/evals/matharena/matharena.py (1.3 KiB)

from inspect_ai.dataset import hf_dataset, Sample
from inspect_ai import Task
from inspect_ai.model import GenerateConfig
from openbench.scorers import aime_scorer
from inspect_ai.solver import generate, prompt_template


def matharena_record_to_sample(record: dict) -> Sample:
    return Sample(
        input=record["problem"],
        target=str(record["answer"]),
        id=record["problem_idx"],
        metadata={
            k: v
            for k, v in record.items()
            if k not in ["problem", "answer", "problem_idx"]
        },
    )


def matharena_task(
    dataset_path: str,
    instruction: str,
    name: str,
    default_max_tokens: int,
    default_temperature: float = 0.6,
    default_epochs: int = 4,
) -> Task:
    dataset = hf_dataset(
        path=dataset_path,
        split="train",
        sample_fields=matharena_record_to_sample,
    )

    TEMPLATE = instruction + "\n\n" + "{prompt}"
    return Task(
        dataset=dataset,
        solver=[prompt_template(TEMPLATE), generate()],
        scorer=aime_scorer(),  # Use specialized AIME scorer with robust extraction
        name=name,
        config=GenerateConfig(
            temperature=default_temperature,
            max_tokens=default_max_tokens,
        ),
        epochs=default_epochs,
    )
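
The per-competition modules above are thin wrappers around this helper; a new competition would follow the same pattern (illustrative sketch; the dataset path and task name below are hypothetical):

# Illustrative sketch: wiring a hypothetical competition through matharena_task.
from inspect_ai import Task, task
from openbench.evals.matharena.matharena import matharena_task


@task
def example_competition() -> Task:  # hypothetical task, not part of the repo
    return matharena_task(
        dataset_path="MathArena/example_competition",  # hypothetical dataset path
        instruction="Please reason step by step, and put your final answer within \\boxed{{}}.",
        default_temperature=0.6,
        default_max_tokens=16000,
        default_epochs=4,
        name="example_competition",
    )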

src/openbench/evals/mgsm.py (2.6 KiB)

"""MGSM (Multilingual Grade School Math) benchmark evaluation.

Based on: Language Models are Multilingual Chain-of-Thought Reasoners
Freda Shi et al., 2022
https://arxiv.org/abs/2210.03057
"""

from inspect_ai import task, Task
from inspect_ai.solver import generate
from inspect_ai.model import GenerateConfig
from openbench.datasets.mgsm import (
    get_dataset,
    LATIN_LANGUAGES,
    NON_LATIN_LANGUAGES,
)
from openbench.scorers.mgsm import mgsm_scorer


@task
def mgsm() -> Task:
    """MGSM: Multilingual Grade School Math - All languages.

    Evaluates mathematical reasoning across 11 languages including
    English, German, French, Spanish, Russian, Chinese, Japanese,
    Thai, Swahili, Bengali, and Telugu.

    Returns:
        Task configured for MGSM evaluation across all languages
    """
    return Task(
        dataset=get_dataset(),
        solver=[generate()],
        scorer=mgsm_scorer(),
        name="mgsm",
        config=GenerateConfig(
            temperature=0.5,  # Simple Evals uses 0.5
            max_tokens=8192,  # Allow space for reasoning steps
        ),
    )


@task
def mgsm_en() -> Task:
    """MGSM: English only.

    Evaluates mathematical reasoning on English problems only.

    Returns:
        Task configured for MGSM evaluation in English
    """
    return Task(
        dataset=get_dataset(languages=["en"]),
        solver=[generate()],
        scorer=mgsm_scorer(),
        name="mgsm_en",
        config=GenerateConfig(
            temperature=0.0,
            max_tokens=8192,
        ),
    )


@task
def mgsm_latin() -> Task:
    """MGSM: Latin script languages.

    Evaluates mathematical reasoning across Latin script languages:
    German, English, Spanish, French, and Swahili.

    Returns:
        Task configured for MGSM evaluation on Latin script languages
    """
    return Task(
        dataset=get_dataset(languages=LATIN_LANGUAGES),
        solver=[generate()],
        scorer=mgsm_scorer(),
        name="mgsm_latin",
        config=GenerateConfig(
            temperature=0.0,
            max_tokens=8192,
        ),
    )


@task
def mgsm_non_latin() -> Task:
    """MGSM: Non-Latin script languages.

    Evaluates mathematical reasoning across non-Latin script languages:
    Bengali, Japanese, Russian, Telugu, Thai, and Chinese.

    Returns:
        Task configured for MGSM evaluation on non-Latin script languages
    """
    return Task(
        dataset=get_dataset(languages=NON_LATIN_LANGUAGES),
        solver=[generate()],
        scorer=mgsm_scorer(),
        name="mgsm_non_latin",
        config=GenerateConfig(
            temperature=0.0,
            max_tokens=8192,
        ),
    )

src/openbench/evals/mmlu.py (509 B)

from inspect_ai import task, Task
from inspect_ai.solver import generate
from inspect_ai.model import GenerateConfig
from openbench.datasets.mmlu import get_dataset
from openbench.scorers.mmlu import mmlu_simple_eval_scorer


@task
def mmlu(language: str = "EN-US") -> Task:
    return Task(
        dataset=get_dataset(language=language),
        solver=[generate()],
        scorer=mmlu_simple_eval_scorer(),
        name="mmlu",
        config=GenerateConfig(
            temperature=0.5,
        ),
    )

src/openbench/evals/mrcr.py (2.9 KiB)

from typing import Optional
from inspect_ai import Task, task
from inspect_ai.model import GenerateConfig
from inspect_ai.solver import generate
from openbench.datasets.mrcr import get_dataset
from openbench.scorers.mrcr import mrcr_scorer


@task
def openai_mrcr(max_context_size: Optional[int] = None) -> Task:
    """Memory-Recall with Contextual Retrieval (MRCR).

    Evaluates retrieval and recall in long contexts by placing 2, 4 or 8 needles in the prompt and measuring whether the
    model can correctly extract and use them.

    Args:
        max_context_size: Maximum context size in tokens. Defaults to None.

    Returns:
        Task configured for MRCR evaluation.
    """

    return Task(
        dataset=get_dataset(max_context_size=max_context_size),
        solver=generate(),
        scorer=mrcr_scorer(),
        name="openai_mrcr",
        config=GenerateConfig(temperature=0.0),
    )


@task
def openai_mrcr_2n(max_context_size: Optional[int] = None) -> Task:
    """Memory-Recall with Contextual Retrieval (MRCR).

    Evaluates retrieval and recall in long contexts by placing 2 needles in the prompt and measuring whether the
    model can correctly extract and use them.

    Args:
        max_context_size: Maximum context size in tokens. Defaults to None.

    Returns:
        Task configured for MRCR 2 needles evaluation.
    """

    return Task(
        dataset=get_dataset(needles=2, max_context_size=max_context_size),
        solver=generate(),
        scorer=mrcr_scorer(),
        name="openai_mrcr_2n",
        config=GenerateConfig(temperature=0.0),
    )


@task
def openai_mrcr_4n(max_context_size: Optional[int] = None) -> Task:
    """Memory-Recall with Contextual Retrieval (MRCR).

    Evaluates retrieval and recall in long contexts by placing 4 needles in the prompt and measuring whether the
    model can correctly extract and use them.

    Args:
        max_context_size: Maximum context size in tokens. Defaults to None.

    Returns:
        Task configured for MRCR 4 needles evaluation.
    """

    return Task(
        dataset=get_dataset(needles=4, max_context_size=max_context_size),
        solver=generate(),
        scorer=mrcr_scorer(),
        name="openai_mrcr_4n",
        config=GenerateConfig(temperature=0.0),
    )


@task
def openai_mrcr_8n(max_context_size: Optional[int] = None) -> Task:
    """Memory-Recall with Contextual Retrieval (MRCR).

    Evaluates retrieval and recall in long contexts by placing 8 needles in the prompt and measuring whether the
    model can correctly extract and use them.

    Args:
        max_context_size: Maximum context size in tokens. Defaults to None.

    Returns:
        Task configured for MRCR 8 needles evaluation.
    """

    return Task(
        dataset=get_dataset(needles=8, max_context_size=max_context_size),
        solver=generate(),
        scorer=mrcr_scorer(),
        name="openai_mrcr_8n",
        config=GenerateConfig(temperature=0.0),
    )

src/openbench/evals/musr.py (4.1 KiB)

"""OpenBench implementation of MuSR (Testing the Limits of Chain-of-thought with Multistep Soft Reasoning).
MuSR is a dataset that tests chain-of-thought reasoning with three types of tasks:
- Murder mysteries: Who is the most likely murderer?
- Object placements: Where would someone look for an object?
- Team allocation: How to allocate people to tasks efficiently?

Implemented by Aarush Sah
"""

import ast
from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.scorer import choice
from inspect_ai.solver import multiple_choice
from openbench.scorers.musr import musr_grouped_scorer


def record_to_sample(record: dict, subset: str | None = None) -> Sample:
    # Parse the choices string representation into an actual list
    choices_list = ast.literal_eval(record["choices"])

    metadata = {
        "narrative": record["narrative"],
        "question": record["question"],
        "answer_choice": record["answer_choice"],
        "answer_index": record["answer_index"],
    }

    # Add subset metadata if provided
    if subset:
        metadata["subset"] = subset

    return Sample(
        input=f"{record['narrative']}\n\n{record['question']}",
        choices=choices_list,
        target=chr(ord("A") + int(record["answer_index"])),
        metadata=metadata,
    )


def create_combined_musr_dataset():
    """Create a combined dataset from all three MuSR subsets with subset metadata."""
    all_samples = []
    subsets = ["murder_mysteries", "object_placements", "team_allocation"]

    for subset in subsets:
        # Load each subset and add subset metadata
        subset_dataset = hf_dataset(
            path="TAUR-Lab/MuSR",
            split=subset,
            sample_fields=lambda record, s=subset: record_to_sample(record, s),
        )
        all_samples.extend(subset_dataset)

    return all_samples


@task
def musr(subset: str | None = None) -> Task:
    """
    MuSR (Multistep Soft Reasoning) evaluation task.

    Args:
        subset: The subset of the dataset to use. Options are:
                - None (default): Run all subsets with grouped metrics
                - "murder_mysteries": Murder mystery scenarios only
                - "object_placements": Object placement reasoning only
                - "team_allocation": Team allocation problems only
    """
    if subset is None:
        # Run all subsets with grouped metrics
        return Task(
            dataset=create_combined_musr_dataset(),
            solver=multiple_choice(),
            scorer=musr_grouped_scorer(),
        )
    else:
        # Run specific subset
        if subset not in ["murder_mysteries", "object_placements", "team_allocation"]:
            raise ValueError(
                f"Invalid subset '{subset}'. Must be one of: murder_mysteries, object_placements, team_allocation"
            )

        return Task(
            dataset=hf_dataset(
                path="TAUR-Lab/MuSR",
                split=subset,
                sample_fields=record_to_sample,
            ),
            solver=multiple_choice(),
            scorer=choice(),
        )


@task
def musr_murder_mysteries() -> Task:
    """MuSR Murder Mysteries - Who is the most likely murderer?"""
    return Task(
        dataset=hf_dataset(
            path="TAUR-Lab/MuSR",
            split="murder_mysteries",
            sample_fields=record_to_sample,
        ),
        solver=multiple_choice(),
        scorer=choice(),
    )


@task
def musr_object_placements() -> Task:
    """MuSR Object Placements - Where would someone look for an object?"""
    return Task(
        dataset=hf_dataset(
            path="TAUR-Lab/MuSR",
            split="object_placements",
            sample_fields=record_to_sample,
        ),
        solver=multiple_choice(),
        scorer=choice(),
    )


@task
def musr_team_allocation() -> Task:
    """MuSR Team Allocation - How to allocate people to tasks efficiently?"""
    return Task(
        dataset=hf_dataset(
            path="TAUR-Lab/MuSR",
            split="team_allocation",
            sample_fields=record_to_sample,
        ),
        solver=multiple_choice(),
        scorer=choice(),
    )
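
A sketch of the two modes handled by musr() above (illustrative only):

# Illustrative sketch: grouped metrics across all subsets vs. a single subset.
from openbench.evals.musr import musr

all_subsets = musr()                          # combined dataset, grouped scorer
mysteries = musr(subset="murder_mysteries")   # single split, plain choice() scorer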

src/openbench/evals/openbookqa.py (1.8 KiB)

"""OpenBench implementation of OpenBookQA.

OpenBookQA is an open book question answering dataset modeled after
open book exams for assessing human understanding of a subject. It consists
of 5,957 multiple-choice elementary-level science questions (4,957 train,
500 validation, 500 test), which probe the understanding of a small
"book" of 1,326 core science facts and the application of these facts
to novel situations.

Implemented by Aarush Sah
"""

from inspect_ai import Task, task
from inspect_ai.dataset import hf_dataset, Sample
from inspect_ai.scorer import choice
from inspect_ai.solver import multiple_choice


def record_to_sample(record) -> Sample:
    """Convert a HuggingFace dataset record to an Inspect Sample."""
    return Sample(
        id=record["id"],
        input=record["question_stem"],
        choices=[choice for choice in record["choices"]["text"]],
        target=record["answerKey"],
        metadata={
            # Store the choice labels in metadata for reference
            "choice_labels": record["choices"]["label"],
        },
    )


@task
def openbookqa(split: str = "validation") -> Task:
    """OpenBookQA multiple choice science question evaluation.

    Args:
        split: Dataset split to use ("train", "validation", or "test").
               Defaults to "validation".

    Returns:
        Task: Configured OpenBookQA evaluation task.
    """
    # Validate split parameter
    valid_splits = ["train", "validation", "test"]
    if split not in valid_splits:
        raise ValueError(f"Invalid split '{split}'. Must be one of {valid_splits}")

    # Load dataset from HuggingFace
    dataset = hf_dataset(
        path="allenai/openbookqa",
        split=split,
        sample_fields=record_to_sample,
        trust=True,
    )

    return Task(
        dataset=dataset,
        solver=multiple_choice(),
        scorer=choice(),
    )
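
# Illustrative usage sketch (an assumption, mirroring the CLI examples used for
# other evals in this repo); split defaults to "validation":
#
#   bench eval openbookqa --model "groq/llama-3.1-8b-instant" --T split=test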

src/openbench/evals/rootly_gmcq.py (1.3 KiB)

"""
GitHub Multiple Choice Questions
Authored by:
Rootly AI Labs
Based on: https://huggingface.co/datasets/TheFloatingString/gmcq

# example: run the evaluation
bench eval gmcq --model "groq/llama-3.1-8b-instant" --T subtask=mastodon

If subtask is None, then the entire dataset is used.

Please refer to https://huggingface.co/datasets/TheFloatingString/gmcq for the available subtasks.
As of August 19, 2025 there are six subtasks, plus a None option for the entire dataset:

- bluesky
- chroma
- cloudflare
- duckdb
- mastodon
- tailscale
- None
"""

from inspect_ai import Task, task
from inspect_ai.solver import TaskState, Generate, solver
from inspect_ai.model import get_model, GenerateConfig

from openbench.scorers.rootly_gmcq import custom_scorer
from openbench.datasets.rootly_gmcq import load_dataset


@solver
def custom_solver():
    model = get_model()

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        resp = await model.generate(input=state.input)
        state.messages.append(resp.choices[0].message)
        return state

    return solve


@task
def rootly_gmcq(subtask: str | None = None) -> Task:
    dataset = load_dataset(subtask)
    return Task(
        dataset=dataset,
        solver=custom_solver(),
        scorer=custom_scorer(),
        config=GenerateConfig(),
    )

src/openbench/evals/scicode.py (10.9 KiB)

"""
SCICode implementation.

Code attribution:

This implementation is adapted from the following repository:
https://github.com/scicode-bench/SciCode

Implemented by Minyang Tian et al.

As of August 13, 2025, this implementation uses the validation split of the dataset because of a bug with the test split in this implementation.
Revert to 'test' once the test split is runnable.
"""

import copy
from typing import Any
from pathlib import Path
from inspect_ai import Task, task
from inspect_ai.solver import solver, TaskState, Generate
from scicode.parse.parse import extract_function_name, get_function_from_code  # type: ignore
from scicode.gen.models import generate_dummy_response, extract_python_script  # type: ignore
import requests  # type: ignore


from openbench.datasets.scicode import return_hf_dataset
from openbench.scorers.scicode import scicode_scorer

BACKGROUND_PROMPT_TEMPLATE = requests.get(
    "https://raw.githubusercontent.com/scicode-bench/SciCode/refs/heads/main/eval/data/background_comment_template.txt"
).text
DEFAULT_PROMPT_TEMPLATE = requests.get(
    "https://raw.githubusercontent.com/scicode-bench/SciCode/refs/heads/main/eval/data/multistep_template.txt"
).text


class ScicodePromptingAssistant:
    def __init__(
        self,
        output_dir: Path,
        prompt_dir: Path,
        with_background: bool,
    ):
        self.output_dir = output_dir
        self.prompt_dir = prompt_dir
        self.with_background = with_background
        self.previous_llm_code: list = []

    def _get_background_dir(self):
        return "with_background" if self.with_background else "without_background"

    def register_previous_response(
        self,
        prob_data: dict,
        response: str,
        previous_code: str,
        num_steps: int,
    ):
        self.previous_llm_code[num_steps - 1] = extract_python_script(response)
        self.save_response_with_steps(
            prob_data,
            response,
            previous_code,
            num_steps,
        )

    def save_response_with_steps(
        self, prob_data: dict, response: str, previous_code: str, num_steps: int
    ) -> None:
        output_dir = Path(self.output_dir, self._get_background_dir())
        output_dir.mkdir(parents=True, exist_ok=True)
        prob_id = prob_data["problem_id"]
        output_file_path = output_dir / f"{prob_id}.{num_steps}.py"
        python_code = extract_python_script(response)
        output_file_path.write_text(f"{previous_code}\n{python_code}", encoding="utf-8")

    @staticmethod
    def process_problem_code(prob_data: dict, num_steps: int) -> str:
        header_docstring = prob_data["sub_steps"][num_steps - 1]["function_header"]
        return_str = prob_data["sub_steps"][num_steps - 1]["return_line"]
        string = f"{header_docstring}\n\n{return_str}"
        return string

    def process_problem_steps(self, problem_data: dict, num_steps: int):
        """Process problem data and return previous steps and next steps"""
        output_lines = []
        next_step = []
        previous_code = []
        for i in range(num_steps - 1):
            output_lines.append(
                problem_data["sub_steps"][i]["step_description_prompt"]
                + "\n"
                + problem_data["sub_steps"][i]["step_background"]
                if self.with_background
                else problem_data["sub_steps"][i]["step_description_prompt"]
            )
            output_lines.append(self.previous_llm_code[i])
            previous_code.append(self.previous_llm_code[i])
            output_lines.append("------")

        next_step.append(
            problem_data["sub_steps"][num_steps - 1]["step_description_prompt"]
            + "\n"
            + problem_data["sub_steps"][num_steps - 1]["step_background"]
            if self.with_background
            else problem_data["sub_steps"][num_steps - 1]["step_description_prompt"]
        )
        next_step.append(self.process_problem_code(problem_data, num_steps))
        output_str = "\n\n".join(output_lines[:-1])  # Remove the last "------"
        next_step_str = "\n\n".join(next_step)
        previous_code_str = "\n".join(previous_code)
        return output_str, next_step_str, previous_code_str

    def generate_prompt_with_steps(
        self,
        prob_data: dict,
        num_steps: int,
        prompt_template=DEFAULT_PROMPT_TEMPLATE,
    ):
        # parse the input file and extract the content
        problem_steps_str, next_step_str, previous_code_str = (
            self.process_problem_steps(prob_data, num_steps)
        )
        dependencies = prob_data["required_dependencies"]
        assert next_step_str
        return prompt_template.format(
            problem_steps_str=problem_steps_str,
            next_step_str=next_step_str,
            dependencies=dependencies,
        ), f"{dependencies}\n{previous_code_str}\n"

    def save_prompt_with_steps(
        self, prob_data: dict, prompt: str, num_steps: int
    ) -> None:
        output_dir = Path(self.prompt_dir, self._get_background_dir())
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file_path = output_dir / f"{prob_data['problem_id']}.{num_steps}.txt"
        output_file_path.write_text(prompt, encoding="utf-8")

    def prepare_final_prompt_with_steps(
        self,
        prob_data: dict,
        num_steps: int,
        tot_steps: int,
        prompt_template=DEFAULT_PROMPT_TEMPLATE,
        *,
        save: bool = True,
    ):
        prob_id = prob_data["problem_id"]
        if num_steps == 1:
            self.previous_llm_code = [None] * tot_steps
        else:
            if len(self.previous_llm_code) != tot_steps:
                self.previous_llm_code = [None] * tot_steps
            for prev_step in range(num_steps - 1):
                if self.previous_llm_code[prev_step] is None:
                    if (
                        (prob_id == "13" and prev_step == 5)
                        or (prob_id == "62" and prev_step == 0)
                        or (prob_id == "76" and prev_step == 2)
                    ):
                        prev_file_path = Path(
                            "../data", f"{prob_id}.{prev_step + 1}.txt"
                        )
                    else:
                        prev_file_path = Path(
                            self.output_dir,
                            self._get_background_dir(),
                            f"{prob_id}.{prev_step + 1}.py",
                        )
                    if prev_file_path.is_file():
                        prev_file_content = prev_file_path.read_text(encoding="utf-8")
                        func_name = extract_function_name(
                            prob_data["sub_steps"][prev_step]["function_header"]
                        )
                        function_code = get_function_from_code(
                            prev_file_content, func_name
                        )
                        self.previous_llm_code[prev_step] = function_code
                    else:
                        # print(f"Generating problem {prob_id} step {num_steps} ahead of step {prev_step + 1}.")
                        raise Exception(
                            f"Generating problem {prob_id} step {num_steps} ahead of step {prev_step + 1}."
                        )

        prompt, previous_code = self.generate_prompt_with_steps(
            prob_data,
            num_steps,
            prompt_template,
        )
        if save:
            self.save_prompt_with_steps(
                prob_data,
                prompt,
                num_steps,
            )
        return prompt, previous_code
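
# Minimal sketch (illustration only) of how a single-step prompt is assembled.
# The prob_data fields mirror the SciCode record schema used by this class; the
# concrete values and the inline prompt template are made up for the example.
#
#   assistant = ScicodePromptingAssistant(
#       output_dir=Path("./tmp/code"),
#       prompt_dir=Path("./tmp/prompt"),
#       with_background=False,
#   )
#   prob_data = {
#       "problem_id": "0",
#       "required_dependencies": "import numpy as np",
#       "sub_steps": [
#           {
#               "step_description_prompt": "Write a function that doubles x.",
#               "step_background": "",
#               "function_header": "def double(x):",
#               "return_line": "    return",
#           }
#       ],
#   }
#   prompt, previous_code = assistant.prepare_final_prompt_with_steps(
#       prob_data,
#       num_steps=1,
#       tot_steps=1,
#       prompt_template="{problem_steps_str}\n{next_step_str}\n{dependencies}",
#       save=False,
#   )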


def generate_gold_response(prob_data: dict, num_steps: int):
    return f"Blah blah\n```python\n{prob_data['sub_steps'][num_steps - 1]['ground_truth_code']}\n```\n"


@solver
def scicode_solver(**params: Any):
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        model_name = str(state.model).replace("/", "-")
        prompt_assistant = ScicodePromptingAssistant(
            output_dir=Path(params["output_dir"], model_name, "generated_code"),  # type: ignore
            prompt_dir=Path(params["output_dir"], model_name, "prompt"),  # type: ignore
            with_background=params["with_background"],  # type: ignore
        )
        prompt_template = (
            BACKGROUND_PROMPT_TEMPLATE
            if params["with_background"]
            else DEFAULT_PROMPT_TEMPLATE
        )
        sub_steps = state.metadata["sub_steps"]
        for idx in range(len(sub_steps)):
            prob_id = state.metadata["problem_id"]
            if (
                (prob_id == "13" and idx == 5)
                or (prob_id == "62" and idx == 0)
                or (prob_id == "76" and idx == 2)
            ):
                continue
            prompt, previous_code = prompt_assistant.prepare_final_prompt_with_steps(
                prob_data=state.metadata,
                num_steps=idx + 1,
                tot_steps=len(sub_steps),
                prompt_template=prompt_template,
            )
            if params["mode"] == "dummy":
                response_from_llm = generate_dummy_response(prompt)
            elif params["mode"] == "gold":
                response_from_llm = generate_gold_response(state.metadata, idx + 1)
            else:
                try:
                    # ===Model Generation===
                    state.user_prompt.text = prompt
                    state_copy = copy.deepcopy(state)
                    result = await generate(state=state_copy)
                    response_from_llm = result.output.completion
                    # ===Model Generation===
                except Exception:
                    print(
                        f"Failed to generate response for problem {prob_id} step {idx + 1}."
                    )
                    response_from_llm = generate_dummy_response(prompt)
            prompt_assistant.register_previous_response(
                prob_data=state.metadata,
                response=response_from_llm,
                previous_code=previous_code,
                num_steps=idx + 1,
            )
        return state

    return solve


@task
def scicode(
    split: str = "validation",  # TODO: when 'test' is runnable, revert to 'test'
    output_dir: str = "./tmp",
    with_background: bool = False,
    h5py_file: str = "../data/test_data.h5",
    mode: str = "normal",
) -> Task:
    print(
        "As of August 13, 2025, this implementation uses the validation split of the dataset because of a bug with the test split in this implementation."
    )
    print("Revert to 'test' once the test split is runnable.")

    return Task(
        dataset=return_hf_dataset(split),
        solver=scicode_solver(
            output_dir=output_dir,  # type: ignore
            with_background=with_background,  # type: ignore
            mode=mode,  # type: ignore
        ),
        scorer=scicode_scorer(
            output_dir=output_dir,  # type: ignore
            with_background=with_background,  # type: ignore
            h5py_file=h5py_file,  # type: ignore
        ),
    )
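
# Illustrative usage sketch (an assumption, following the --T task-arg convention
# used elsewhere in this repo):
#
#   bench eval scicode --model "groq/llama-3.1-8b-instant" \
#       --T output_dir=./tmp --T with_background=true --T mode=normal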

src/openbench/evals/simpleqa.py (943 B)

from inspect_ai import task, Task
from inspect_ai.solver import generate
from inspect_ai.model import GenerateConfig
from openbench.datasets.simpleqa import get_dataset
from openbench.scorers.simpleqa import simpleqa_scorer


@task
def simpleqa(grader_model: str = "openai/gpt-4.1-2025-04-14") -> Task:
    """SimpleQA: Measuring short-form factuality in large language models.

    Based on the paper by Wei et al. (2024).
    Uses model-based grading to assess factual accuracy of responses.

    Args:
        grader_model: Model to use for grading responses (defaults to gpt-4.1-2025-04-14)

    Returns:
        Task configured for SimpleQA evaluation
    """
    return Task(
        dataset=get_dataset(),
        solver=[generate()],
        scorer=simpleqa_scorer(model=grader_model),
        name="simpleqa",
        config=GenerateConfig(
            temperature=0.0,  # Use deterministic generation for factual QA
        ),
    )
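
# Illustrative usage sketch (an assumption): the grader model can be overridden
# via the --T task-arg convention used elsewhere in this repo.
#
#   bench eval simpleqa --model "groq/llama-3.1-8b-instant" \
#       --T grader_model=openai/gpt-4.1-2025-04-14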

src/openbench/evals/supergpqa.py (2.9 KiB)

"""OpenBench implementation of SuperGPQA.
SuperGPQA: Scaling LLM Evaluation across 285 Graduate Disciplines

Implemented by Aarush Sah
"""

from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.scorer import choice, accuracy, stderr, grouped
from inspect_ai.solver import multiple_choice


def record_to_sample(record):
    """Convert a SuperGPQA record to an Inspect Sample."""
    # Create choices list from options
    choices = record["options"]

    # Create metadata dict with all extra fields
    metadata = {
        "uuid": record["uuid"],
        "discipline": record["discipline"],
        "field": record["field"],
        "subfield": record["subfield"],
        "difficulty": record["difficulty"],
        "is_calculation": record["is_calculation"],
        "answer_text": record["answer"],  # Store the full answer text
    }

    return Sample(
        input=record["question"],
        choices=choices,
        target=record["answer_letter"],  # Use the letter (A, B, C, etc.) as target
        metadata=metadata,
    )


@task
def supergpqa(
    field: str | None = None,
    subfield: str | None = None,
    difficulty: str | None = None,
    discipline: str | None = None,
) -> Task:
    """SuperGPQA dataset task.

    SuperGPQA is a dataset for evaluating LLMs across 285 graduate disciplines
    with 26,529 multiple-choice questions spanning various fields including
    science, engineering, medicine, economics, and philosophy.

    Args:
        field: Filter by field (e.g., "Mathematics", "Physics", "Computer Science and Technology")
        subfield: Filter by subfield (e.g., "Mathematical Analysis", "Quantum Mechanics")
        difficulty: Filter by difficulty level ("easy", "middle", "hard")
        discipline: Filter by discipline (e.g., "Science", "Engineering", "Medicine")
    """
    # Load the full dataset
    dataset = hf_dataset(
        path="m-a-p/SuperGPQA",
        split="train",  # Only train split is available
        sample_fields=record_to_sample,
    )

    # Apply filters if specified
    if any([field, subfield, difficulty, discipline]):

        def filter_fn(sample):
            if field and sample.metadata.get("field") != field:
                return False
            if subfield and sample.metadata.get("subfield") != subfield:
                return False
            if difficulty and sample.metadata.get("difficulty") != difficulty:
                return False
            if discipline and sample.metadata.get("discipline") != discipline:
                return False
            return True

        dataset = dataset.filter(filter_fn)

    return Task(
        dataset=dataset,
        solver=multiple_choice(),
        scorer=choice(),
        metrics=[
            # Overall metrics
            accuracy(),
            stderr(),
            # Metrics grouped by difficulty
            grouped(accuracy(), "difficulty"),
            grouped(stderr(), "difficulty"),
        ],
    )
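
# Illustrative usage sketch (an assumption): the filters combine with AND
# semantics, so this run would score only hard Mathematics questions.
#
#   bench eval supergpqa --model "groq/llama-3.1-8b-instant" \
#       --T field=Mathematics --T difficulty=hard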

src/openbench/metrics/__init__.py (102 B)

"""Common, reusable metrics."""

from openbench.metrics.grouped import grouped

__all__ = ["grouped"]

src/openbench/metrics/grouped.py (4.5 KiB)

from typing import Callable, Literal, Sequence, cast, Optional

import numpy as np

from inspect_ai.scorer._metric import (
    Metric,
    MetricProtocol,
    SampleScore,
    Value,
    ValueToFloat,
    metric,
    value_to_float,
    registry_info,
)


# Forked from the grouped metric in https://github.com/UKGovernmentBEIS/inspect_ai
@metric
def grouped(
    metric: Metric | Sequence[Metric],
    group_key: str,
    *,
    all: Literal["samples", "groups"] | Literal[False] = "samples",
    all_label: str = "all",
    value_to_float: ValueToFloat = value_to_float(),
    group_namer: Optional[Callable[[str, str], str]] = None,
) -> Metric:
    """
    Create a grouped metric that applies the given metric(s) to subgroups of samples.

    This function groups samples based on a metadata key and applies one or more metrics
    to each group. Optionally, it computes an aggregate score across all samples or all
    group-level scores.

    Args:
        metric: A metric or list of metrics to apply to each group of samples.
        group_key: Metadata key used to group samples.
        all: Determines whether and how to compute an aggregate "all" score.
            - "samples": Apply the metric(s) to all samples regardless of groups.
            - "groups": Calculate the mean of all group scores.
            - False: Do not compute an aggregate score.
        all_label: Label to use for the aggregate score when a single metric is used.
        value_to_float: Function to convert metric values to floats (used for averaging group scores).
        group_namer: Optional function to generate group-specific metric names; receives (group, metric_name).

    Returns:
        Metric: A new metric function returning a dictionary mapping group names (and optionally
        an "all" aggregate key) to their respective scores.
    """

    def grouped_metric(scores: list[SampleScore]) -> Value:
        # Normalize to list of metrics
        metrics = metric if isinstance(metric, (list, tuple)) else [metric]
        metric_names = [registry_info(m).name for m in metrics]
        short_names = [name.split("/")[-1] for name in metric_names]

        # Use default group_namer if none provided
        nonlocal group_namer, all_label
        if group_namer is None:

            def default_group_namer(group: str, metric_name: str) -> str:
                return f"{group}_{metric_name}"

            group_namer = default_group_namer

        # If only one metric and user didn't override label, use that metric's name
        if all_label == "all" and len(metrics) == 1:
            all_label = short_names[0]

        # Build map of group name → list of sample scores
        scores_by_group: dict[str, list[SampleScore]] = {}
        for score in scores:
            if score.sample_metadata is None or group_key not in score.sample_metadata:
                raise ValueError(
                    f"Sample {score.sample_id} has no '{group_key}' in metadata. "
                    "All samples must include this key to compute grouped metrics."
                )
            group_name = str(score.sample_metadata[group_key])
            scores_by_group.setdefault(group_name, []).append(score)

        # If requested, compute aggregate metric over all samples before group metrics
        all_metrics: dict[str, Value] = {}
        if all == "samples":
            for m, short_name in zip(metrics, short_names):
                key = all_label if len(metrics) == 1 else short_name
                all_metrics[key] = cast(MetricProtocol, m)(scores)

        # Compute metric for each group
        grouped_scores: dict[str, Value] = {}
        for group_name, group_scores in scores_by_group.items():
            for m, short_name in zip(metrics, short_names):
                key = group_namer(group_name, short_name)
                grouped_scores[key] = cast(MetricProtocol, m)(group_scores)

        # If requested, compute aggregate metric from group scores
        if all == "groups":
            for m, short_name in zip(metrics, short_names):
                key = all_label if len(metrics) == 1 else short_name
                group_keys = [group_namer(g, short_name) for g in scores_by_group]
                values = [value_to_float(grouped_scores[k]) for k in group_keys]
                all_metrics[key] = float(np.mean(values)) if values else 0.0

        # Return combined results
        if all is False:
            return cast(Value, grouped_scores)
        return cast(Value, {**all_metrics, **grouped_scores})

    return grouped_metric
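
# Illustrative sketch (not executed) of the returned Value shape. With
# grouped(accuracy(), "difficulty") and samples whose metadata["difficulty"] is
# "easy" or "hard", the result is a dict along the lines of
#
#   {"accuracy": 0.71, "easy_accuracy": 0.85, "hard_accuracy": 0.57}
#
# i.e. the all-samples aggregate under the metric's own name (single-metric case)
# plus one "<group>_<metric>" entry per group. The numbers here are made up.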

src/openbench/model/__init__.py (33 B)

"""OpenBench model providers."""

src/openbench/model/_providers/__init__.py (40 B)

"""OpenBench custom model providers."""

src/openbench/model/_providers/ai21.py (1.5 KiB)

"""AI21 Labs provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class AI21API(OpenAICompatibleAPI):
    """AI21 Labs provider - advanced language model infrastructure.

    Uses OpenAI-compatible API with AI21-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("ai21/", "", 1)

        # Set defaults for AI21
        base_url = base_url or os.environ.get(
            "AI21_BASE_URL", "https://api.ai21.com/studio/v1"
        )
        api_key = api_key or os.environ.get("AI21_API_KEY")

        if not api_key:
            raise ValueError(
                "AI21 API key not found. Set AI21_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="ai21",
            service_base_url="https://api.ai21.com/studio/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name
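
# Illustrative usage sketch (an assumption; the model id is a placeholder). All of
# the OpenAI-compatible providers in this package follow the same pattern: set the
# provider's API key environment variable and address models as
# "<provider-prefix>/<model-id>", e.g.:
#
#   export AI21_API_KEY=...
#   bench eval mmlu --model "ai21/<model-id>"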

src/openbench/model/_providers/baseten.py (1.5 KiB)

"""Baseten AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class BasetenAPI(OpenAICompatibleAPI):
    """Baseten AI provider - OpenAI-compatible inference.

    Uses OpenAI-compatible API with Baseten-specific configuration.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("baseten/", "", 1)

        # Set defaults for Baseten
        base_url = base_url or os.environ.get(
            "BASETEN_BASE_URL", "https://inference.baseten.co/v1"
        )
        api_key = api_key or os.environ.get("BASETEN_API_KEY")

        if not api_key:
            raise ValueError(
                "Baseten API key not found. Set BASETEN_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="baseten",
            service_base_url="https://inference.baseten.co/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/cerebras.py (1.5 KiB)

"""Cerebras AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class CerebrasAPI(OpenAICompatibleAPI):
    """Cerebras AI provider - high-performance inference.

    Uses OpenAI-compatible API with Cerebras-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("cerebras/", "", 1)

        # Set defaults for Cerebras
        base_url = base_url or os.environ.get(
            "CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1"
        )
        api_key = api_key or os.environ.get("CEREBRAS_API_KEY")

        if not api_key:
            raise ValueError(
                "Cerebras API key not found. Set CEREBRAS_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="cerebras",
            service_base_url="https://api.cerebras.ai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/cohere.py (1.5 KiB)

"""Cohere provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class CohereAPI(OpenAICompatibleAPI):
    """Cohere provider - enterprise-ready language model infrastructure.

    Uses OpenAI-compatible API with Cohere-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("cohere/", "", 1)

        # Set defaults for Cohere
        base_url = base_url or os.environ.get(
            "COHERE_BASE_URL", "https://api.cohere.ai/compatibility/v1"
        )
        api_key = api_key or os.environ.get("COHERE_API_KEY")

        if not api_key:
            raise ValueError(
                "Cohere API key not found. Set COHERE_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="cohere",
            service_base_url="https://api.cohere.ai/compatibility/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/crusoe.py (1.5 KiB)

"""Crusoe AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class CrusoeAPI(OpenAICompatibleAPI):
    """Crusoe AI provider - cloud infrastructure for AI workloads.

    Uses OpenAI-compatible API with Crusoe-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("crusoe/", "", 1)

        # Set defaults for Crusoe
        base_url = base_url or os.environ.get(
            "CRUSOE_BASE_URL", "https://api.crusoe.ai/v1"
        )
        api_key = api_key or os.environ.get("CRUSOE_API_KEY")

        if not api_key:
            raise ValueError(
                "Crusoe API key not found. Set CRUSOE_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="crusoe",
            service_base_url="https://api.crusoe.ai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/deepinfra.py (1.5 KiB)

"""DeepInfra AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class DeepInfraAPI(OpenAICompatibleAPI):
    """DeepInfra AI provider - scalable inference infrastructure.

    Uses OpenAI-compatible API with DeepInfra-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("deepinfra/", "", 1)

        # Set defaults for DeepInfra
        base_url = base_url or os.environ.get(
            "DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai"
        )
        api_key = api_key or os.environ.get("DEEPINFRA_API_KEY")

        if not api_key:
            raise ValueError(
                "DeepInfra API key not found. Set DEEPINFRA_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="deepinfra",
            service_base_url="https://api.deepinfra.com/v1/openai",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/friendli.py (1.5 KiB)

"""Friendli provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class FriendliAPI(OpenAICompatibleAPI):
    """Friendli provider - fast and efficient inference infrastructure.

    Uses OpenAI-compatible API with Friendli-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("friendli/", "", 1)

        # Set defaults for Friendli
        base_url = base_url or os.environ.get(
            "FRIENDLI_BASE_URL", "https://api.friendli.ai/serverless/v1"
        )
        api_key = api_key or os.environ.get("FRIENDLI_API_KEY")

        if not api_key:
            raise ValueError(
                "Friendli API key not found. Set FRIENDLI_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="friendli",
            service_base_url="https://api.friendli.ai/serverless/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/huggingface.py (1.9 KiB)

"""Hugging Face Inference Providers (OpenAI-compatible) provider.

Uses the Hugging Face Inference Providers router over the OpenAI-compatible
API, as documented in the GPT OSS guide:

Reference: https://huggingface.co/docs/inference-providers/en/guides/gpt-oss

Environment variables:
  - HF_TOKEN: Hugging Face access token used as Bearer token
  - HF_BASE_URL: Optional override for the base URL (defaults to
    https://router.huggingface.co/v1)

Model naming follows the HF router format, e.g.:
  - openai/gpt-oss-120b:cerebras
  - openai/gpt-oss-120b:fireworks-ai
"""

from __future__ import annotations

import os
from typing import Any

from inspect_ai.model import (  # type: ignore[import-not-found]
    GenerateConfig,
)
from inspect_ai.model._providers.openai_compatible import (  # type: ignore[import-not-found]
    OpenAICompatibleAPI,
)


class HFInferenceProvidersAPI(OpenAICompatibleAPI):
    """Hugging Face Inference Providers API."""

    DEFAULT_BASE_URL = "https://router.huggingface.co/v1"

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Remove provider prefix
        model_name_clean = model_name.replace("huggingface/", "", 1)

        base_url = base_url or os.environ.get("HF_BASE_URL") or self.DEFAULT_BASE_URL
        api_key = api_key or os.environ.get("HF_TOKEN")

        if not api_key:
            raise ValueError(
                "HF_TOKEN not set. Get a token from your HF settings page."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="huggingface",
            service_base_url=self.DEFAULT_BASE_URL,
            **model_args,
        )

    def service_model_name(self) -> str:
        return self.model_name
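
# Illustrative usage sketch (an assumption, grounded in the naming format
# documented above):
#
#   export HF_TOKEN=hf_xxx
#   bench eval mmlu --model "huggingface/openai/gpt-oss-120b:cerebras"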

src/openbench/model/_providers/hyperbolic.py (1.5 KiB)

"""Hyperbolic AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class HyperbolicAPI(OpenAICompatibleAPI):
    """Hyperbolic AI provider - OpenAI-compatible inference.

    Uses OpenAI-compatible API with Hyperbolic-specific configuration.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("hyperbolic/", "", 1)

        # Set defaults for Hyperbolic
        base_url = base_url or os.environ.get(
            "HYPERBOLIC_BASE_URL", "https://api.hyperbolic.xyz/v1"
        )
        api_key = api_key or os.environ.get("HYPERBOLIC_API_KEY")

        if not api_key:
            raise ValueError(
                "Hyperbolic API key not found. Set HYPERBOLIC_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="hyperbolic",
            service_base_url="https://api.hyperbolic.xyz/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/lambda_ai.py (1.4 KiB)

"""Lambda provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class LambdaAPI(OpenAICompatibleAPI):
    """Lambda - GPU cloud inference provider.

    Uses OpenAI-compatible API with Lambda's endpoints.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("lambda/", "", 1)

        # Set defaults for Lambda
        base_url = base_url or os.environ.get(
            "LAMBDA_BASE_URL", "https://api.lambda.ai/v1"
        )
        api_key = api_key or os.environ.get("LAMBDA_API_KEY")

        if not api_key:
            raise ValueError(
                "Lambda API key not found. Set LAMBDA_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="lambda",
            service_base_url="https://api.lambda.ai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/minimax.py (1.5 KiB)

"""MiniMax provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class MiniMaxAPI(OpenAICompatibleAPI):
    """MiniMax provider - AI model infrastructure.

    Uses OpenAI-compatible API with MiniMax-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("minimax/", "", 1)

        # Set defaults for MiniMax
        base_url = base_url or os.environ.get(
            "MINIMAX_BASE_URL", "https://api.minimax.io/v1"
        )
        api_key = api_key or os.environ.get("MINIMAX_API_KEY")

        if not api_key:
            raise ValueError(
                "MiniMax API key not found. Set MINIMAX_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="minimax",
            service_base_url="https://api.minimax.io/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/moonshot.py (1.5 KiB)

"""Moonshot provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class MoonshotAPI(OpenAICompatibleAPI):
    """Moonshot provider - AI language model infrastructure.

    Uses OpenAI-compatible API with Moonshot-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("moonshot/", "", 1)

        # Set defaults for Moonshot
        base_url = base_url or os.environ.get(
            "MOONSHOT_BASE_URL", "https://api.moonshot.ai/v1"
        )
        api_key = api_key or os.environ.get("MOONSHOT_API_KEY")

        if not api_key:
            raise ValueError(
                "Moonshot API key not found. Set MOONSHOT_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="moonshot",
            service_base_url="https://api.moonshot.ai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/nebius.py (1.5 KiB)

"""Nebius AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class NebiusAPI(OpenAICompatibleAPI):
    """Nebius AI provider - OpenAI-compatible inference.

    Uses OpenAI-compatible API with Nebius Studio endpoints.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("nebius/", "", 1)

        # Set defaults for Nebius
        base_url = base_url or os.environ.get(
            "NEBIUS_BASE_URL", "https://api.studio.nebius.com/v1"
        )
        api_key = api_key or os.environ.get("NEBIUS_API_KEY")

        if not api_key:
            raise ValueError(
                "Nebius API key not found. Set NEBIUS_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="nebius",
            service_base_url="https://api.studio.nebius.com/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/nous.py (1.5 KiB)

"""Nous Research provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class NousAPI(OpenAICompatibleAPI):
    """Nous Research - Advanced AI inference.

    Uses OpenAI-compatible API with Nous Research endpoints.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("nous/", "", 1)

        # Set defaults for Nous Research
        base_url = base_url or os.environ.get(
            "NOUS_BASE_URL", "https://inference-api.nousresearch.com/v1"
        )
        api_key = api_key or os.environ.get("NOUS_API_KEY")

        if not api_key:
            raise ValueError(
                "Nous Research API key not found. Set NOUS_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="nous",
            service_base_url="https://inference-api.nousresearch.com/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/novita.py (1.5 KiB)

"""Novita AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class NovitaAPI(OpenAICompatibleAPI):
    """Novita AI provider - OpenAI-compatible inference.

    Uses OpenAI-compatible API with Novita-specific configuration.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("novita/", "", 1)

        # Set defaults for Novita
        base_url = base_url or os.environ.get(
            "NOVITA_BASE_URL", "https://api.novita.ai/openai/v1"
        )
        api_key = api_key or os.environ.get("NOVITA_API_KEY")

        if not api_key:
            raise ValueError(
                "Novita API key not found. Set NOVITA_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="novita",
            service_base_url="https://api.novita.ai/openai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/parasail.py (1.5 KiB)

"""Parasail AI provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class ParasailAPI(OpenAICompatibleAPI):
    """Parasail AI provider - OpenAI-compatible inference.

    Uses OpenAI-compatible API with Parasail-specific configuration.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("parasail/", "", 1)

        # Set defaults for Parasail
        base_url = base_url or os.environ.get(
            "PARASAIL_BASE_URL", "https://api.parasail.io/v1"
        )
        api_key = api_key or os.environ.get("PARASAIL_API_KEY")

        if not api_key:
            raise ValueError(
                "Parasail API key not found. Set PARASAIL_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="parasail",
            service_base_url="https://api.parasail.io/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/reka.py (1.4 KiB)

"""Reka provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class RekaAPI(OpenAICompatibleAPI):
    """Reka provider - multimodal AI model infrastructure.

    Uses OpenAI-compatible API with Reka-specific optimizations.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("reka/", "", 1)

        # Set defaults for Reka
        base_url = base_url or os.environ.get("REKA_BASE_URL", "https://api.reka.ai/v1")
        api_key = api_key or os.environ.get("REKA_API_KEY")

        if not api_key:
            raise ValueError(
                "Reka API key not found. Set REKA_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="reka",
            service_base_url="https://api.reka.ai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/sambanova.py (1.5 KiB)

"""SambaNova provider implementation."""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class SambaNovaAPI(OpenAICompatibleAPI):
    """SambaNova - Enterprise AI acceleration.

    Uses OpenAI-compatible API with SambaNova's endpoints.
    """

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Extract model name without service prefix
        model_name_clean = model_name.replace("sambanova/", "", 1)

        # Set defaults for SambaNova
        base_url = base_url or os.environ.get(
            "SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1"
        )
        api_key = api_key or os.environ.get("SAMBANOVA_API_KEY")

        if not api_key:
            raise ValueError(
                "SambaNova API key not found. Set SAMBANOVA_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="sambanova",
            service_base_url="https://api.sambanova.ai/v1",
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/model/_providers/vercel.py (2.2 KiB)

"""Vercel AI Gateway provider implementation.

The AI Gateway provides OpenAI-compatible API endpoints, letting you use multiple
AI providers through a familiar interface. The AI Gateway can route requests across
multiple AI providers for better reliability and performance.

Environment variables:
  - AI_GATEWAY_API_KEY: AI Gateway API key (required)
  - AI_GATEWAY_BASE_URL: Override the default base URL (defaults to
    https://ai-gateway.vercel.sh/v1)

Model naming follows the creator/model format, e.g.:
  - anthropic/claude-sonnet-4
  - openai/gpt-4.1-mini
  - meta/llama-3.3-70b-instruct

Website: https://vercel.com/ai-gateway
Reference: https://vercel.com/docs/ai-gateway/openai-compatible-api
"""

import os
from typing import Any

from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
from inspect_ai.model import GenerateConfig


class VercelAPI(OpenAICompatibleAPI):
    """Vercel AI Gateway provider - OpenAI-compatible API with multi-provider routing."""

    DEFAULT_BASE_URL = "https://ai-gateway.vercel.sh/v1"

    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        # Remove provider prefix if present
        # Result is in creator/model format
        model_name_clean = model_name.replace("vercel/", "", 1)

        base_url = (
            base_url or os.environ.get("AI_GATEWAY_BASE_URL") or self.DEFAULT_BASE_URL
        )
        api_key = (
            api_key
            or os.environ.get("AI_GATEWAY_API_KEY")
            or os.environ.get("VERCEL_OIDC_TOKEN")
        )

        if not api_key:
            raise ValueError(
                "Vercel AI Gateway API key not found. Set the AI_GATEWAY_API_KEY environment variable."
            )

        super().__init__(
            model_name=model_name_clean,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="vercel",
            service_base_url=self.DEFAULT_BASE_URL,
            **model_args,
        )

    def service_model_name(self) -> str:
        """Return model name without service prefix."""
        return self.model_name

src/openbench/monkeypatch/__init__.py (0 B)


src/openbench/monkeypatch/display_results_patch.py (2.0 KiB)

"""
Monkey patch for inspect_ai display results to show "bench eval-retry" instead of "inspect eval-retry".

Usage:
    from openbench.monkeypatch.display_results_patch import patch_display_results
    patch_display_results()

Call this before invoking inspect_ai.eval_retry().
"""


def patch_display_results():
    """
    Monkey patch inspect_ai display functions to replace "inspect eval-retry" with "bench eval-retry".
    """
    try:
        import inspect_ai._display.core.results as results_mod

        # Store original function
        original_task_interrupted = results_mod.task_interrupted

        def custom_task_interrupted(profile, samples_completed):  # type: ignore
            # Call original function
            result = original_task_interrupted(profile, samples_completed)

            # If result is a string, replace the text
            if isinstance(result, str):
                result = result.replace("inspect eval-retry", "bench eval-retry")
            # If it's a Text object from rich, we need to handle it differently
            elif hasattr(result, "_text") and isinstance(result._text, list):
                # Rich Text objects store segments internally
                for i, segment in enumerate(result._text):
                    if isinstance(segment, tuple) and len(segment) >= 1:
                        text = segment[0]
                        if isinstance(text, str) and "inspect eval-retry" in text:
                            # Create a new segment with replaced text
                            new_text = text.replace(
                                "inspect eval-retry", "bench eval-retry"
                            )
                            result._text[i] = (new_text,) + segment[1:]

            return result

        # Apply patch
        results_mod.task_interrupted = custom_task_interrupted

    except (ImportError, AttributeError):
        # If inspect_ai is not installed or the module structure changed, silently continue
        pass

src/openbench/monkeypatch/file_recorder_logfile_patch.py (737 B)

"""
Monkey patch for inspect_ai FileRecorder to allow setting a custom logfile name.

Usage:
    from openbench.monkeypatch.file_recorder_logfile_patch import patch_file_recorder_logfile
    patch_file_recorder_logfile(logfile)

Call this before invoking inspect_ai.eval().
"""


def patch_file_recorder_logfile(logfile: str):
    """
    Monkey patch FileRecorder._log_file_key to always use the provided logfile name.
    Args:
        logfile: The desired logfile name (without extension).
    """
    import inspect_ai.log._recorders.file as file_recorder_mod

    def custom_log_file_key(self, eval_spec):
        return logfile

    file_recorder_mod.FileRecorder._log_file_key = custom_log_file_key  # type: ignore[method-assign]

src/openbench/py.typed (0 B)


src/openbench/scorers/__init__.py (1.1 KiB)

"""
Scorers module for bench evaluation framework.

This module provides various scoring functions for evaluating model outputs
across different types of tasks including mathematical problems, code generation,
and multiple choice questions.

Note: Eval-specific scorers (like those for MMLU and HumanEval) are not exported
from this module intentionally, as they are tailored to specific evaluation tasks
and should be imported directly from their respective modules when needed.
"""

from .score_last_number import score_last_integer, score_last_number
from .score_boxed import score_boxed
from .fallback_scorer import fallback_scorer
from .robust_mcq import robust_mcq_scorer, extract_mcq_answer
from .robust_boxed import robust_boxed_scorer, aime_scorer, extract_boxed_answer

__all__ = [
    # Number scoring functions
    "score_last_integer",
    "score_last_number",
    # Boxed answer scoring
    "score_boxed",
    "robust_boxed_scorer",
    "aime_scorer",
    # Multiple choice scoring
    "robust_mcq_scorer",
    "extract_mcq_answer",
    # Utility functions
    "extract_boxed_answer",
    # Meta-scoring
    "fallback_scorer",
]

src/openbench/scorers/browsecomp.py (4.2 KiB)

"""Scorer for BrowseComp evaluation."""

import re
from typing import Callable
from inspect_ai.scorer import (
    accuracy,
    scorer,
    stderr,
    Score,
    Target,
)
from inspect_ai.solver import TaskState
from inspect_ai.model import get_model, ChatMessageUser, Model


# Grading template from the original BrowseComp implementation
GRADER_TEMPLATE = """
Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below.

[question]: {question}

[response]: {response}

Your judgement must be in the format and criteria specified below:

extracted_final_answer: The final exact answer extracted from the [response]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response.

[correct_answer]: {correct_answer}

reasoning: Explain why the extracted_final_answer is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the extracted_final_answer. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match.

correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect.


confidence: The extracted confidence score between 0|%| and 100|%| from [response]. Put 100 if there is no confidence score available.
""".strip()


# Query template from the original BrowseComp implementation
QUERY_TEMPLATE = """
{Question}

Your response should be in the following format:
Explanation: {{your explanation for your final answer}}
Exact Answer: {{your succinct, final answer}}
Confidence: {{your confidence score between 0% and 100% for your answer}}
""".strip()


@scorer(metrics=[accuracy(), stderr()])
def browsecomp_scorer(model: str = "openai/gpt-4.1-2025-04-14") -> Callable:
    """BrowseComp scorer using model grading.

    Args:
        model: Model to use for grading responses (defaults to gpt-4.1-2025-04-14)

    Returns:
        Scorer function for BrowseComp evaluation
    """
    grader_model: Model = get_model(model)

    async def score(state: TaskState, target: Target) -> Score:
        # Get the plain question from metadata (not the formatted input)
        # This matches the simple-evals implementation where the grader gets the plain question
        question = state.metadata.get("plain_question", state.input_text)

        # Get the predicted answer from the model output
        predicted_answer = state.output.completion

        # Format the grading prompt
        grader_prompt = GRADER_TEMPLATE.format(
            question=question,
            correct_answer=target.text,
            response=predicted_answer,
        )

        # Create the message for grading
        message = ChatMessageUser(content=grader_prompt)

        # Get grading response
        grading_response = await grader_model.generate([message])
        grading_text = grading_response.completion

        # Extract whether the answer is correct
        # Look for "correct: yes" or "correct: no" in the response
        match = re.search(r"correct:\s*(yes|no)", grading_text, re.IGNORECASE)
        is_correct = (
            (match.group(1).lower() == "yes") if (match and match.group(1)) else False
        )

        # Extract confidence if available
        confidence_match = re.search(
            r"confidence:\s*(\d+)(?:\s*%)?", grading_text, re.IGNORECASE
        )
        confidence = (
            int(confidence_match.group(1))
            if (confidence_match and confidence_match.group(1))
            else 100
        )  # Default to 100 if not found

        # Return score with metadata
        return Score(
            value=1.0 if is_correct else 0.0,
            answer=predicted_answer,
            metadata={
                "is_correct": is_correct,
                "confidence": confidence,
                "grading_response": grading_text,
            },
        )

    return score
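
To make the grading-response parsing concrete, here is a small standalone sketch that applies the same regexes to a hypothetical grader reply:

```python
import re

# Hypothetical grader reply following the GRADER_TEMPLATE format.
grading_text = (
    "extracted_final_answer: Paris\n"
    "reasoning: The extracted answer matches the correct answer exactly.\n"
    "correct: yes\n"
    "confidence: 85%"
)

match = re.search(r"correct:\s*(yes|no)", grading_text, re.IGNORECASE)
is_correct = bool(match) and match.group(1).lower() == "yes"  # True

confidence_match = re.search(r"confidence:\s*(\d+)(?:\s*%)?", grading_text, re.IGNORECASE)
confidence = int(confidence_match.group(1)) if confidence_match else 100  # 85
```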

src/openbench/scorers/cti_bench.py (14.9 KiB)

"""CTI-Bench scorers for cybersecurity benchmarking tasks."""

import re
from typing import Callable, List, Set
from inspect_ai.scorer import scorer, accuracy, stderr, Score, Target, Metric, metric
from inspect_ai.solver import TaskState
from inspect_ai.scorer._metric import SampleScore, Value


# ATE (ATT&CK Technique Extraction) Functions
def extract_technique_ids(text: str) -> Set[str]:
    """Extract MITRE ATT&CK technique IDs from model output."""
    if not text:
        return set()

    technique_ids = set()
    text_upper = text.upper()

    # Single comprehensive pattern for all T-ID formats
    all_patterns = [
        r"\bT\d{4}(?:\.\d{3})?\b",  # Basic T1234 or T1234.001
        r"(?:technique\s+)?(T\d{4})(?:\.\d{3})?(?:\s*[:\-,.]|\s|$)",  # Context patterns
    ]

    # Extract from all patterns
    for pattern in all_patterns:
        matches = re.findall(pattern, text_upper, re.IGNORECASE)
        for match in matches:
            # Extract main technique ID (remove subtechnique if present)
            main_technique = match.split(".")[0]
            technique_ids.add(main_technique)

    # Special handling for final line with only technique IDs
    lines = text.strip().split("\n")
    if lines:
        last_line = lines[-1].strip().upper()
        if re.match(r"^[T\d,\s\.]+$", last_line):
            final_matches = re.findall(r"T\d{4}(?:\.\d{3})?", last_line)
            technique_ids.update(match.split(".")[0] for match in final_matches)

    return technique_ids


def parse_ground_truth(gt_text: str) -> Set[str]:
    """Parse ground truth technique IDs from comma-separated string."""
    if not gt_text:
        return set()

    return {
        technique_id.strip().upper().split(".")[0]
        for technique_id in gt_text.split(",")
        if technique_id.strip() and technique_id.strip().upper().startswith("T")
    }


@metric
def technique_precision() -> Metric:
    """Calculate precision for technique extraction."""

    def metric_fn(scores: List[SampleScore]) -> Value:
        if not scores:
            return {"technique_precision": 0.0}

        total_precision = 0.0
        valid_samples = 0

        for score in scores:
            metadata = score.score.metadata or {}
            predicted = set(metadata.get("predicted_techniques", []))
            ground_truth = set(metadata.get("ground_truth_techniques", []))

            if predicted:
                precision = len(predicted & ground_truth) / len(predicted)
                total_precision += precision
                valid_samples += 1
            elif not ground_truth:
                # If no predictions and no ground truth, count as perfect precision
                total_precision += 1.0
                valid_samples += 1

        if valid_samples == 0:
            return {"technique_precision": 0.0}

        avg_precision = total_precision / valid_samples
        return {"technique_precision": round(avg_precision, 4)}

    return metric_fn


@metric
def technique_recall() -> Metric:
    """Calculate recall for technique extraction."""

    def metric_fn(scores: List[SampleScore]) -> Value:
        if not scores:
            return {"technique_recall": 0.0}

        total_recall = 0.0
        valid_samples = 0

        for score in scores:
            metadata = score.score.metadata or {}
            predicted = set(metadata.get("predicted_techniques", []))
            ground_truth = set(metadata.get("ground_truth_techniques", []))

            if ground_truth:
                recall = len(predicted & ground_truth) / len(ground_truth)
                total_recall += recall
                valid_samples += 1
            elif not predicted:
                # If no ground truth and no predictions, count as perfect recall
                total_recall += 1.0
                valid_samples += 1

        if valid_samples == 0:
            return {"technique_recall": 0.0}

        avg_recall = total_recall / valid_samples
        return {"technique_recall": round(avg_recall, 4)}

    return metric_fn


@metric
def technique_f1() -> Metric:
    """Calculate F1 score for technique extraction."""

    def metric_fn(scores: List[SampleScore]) -> Value:
        if not scores:
            return {"technique_f1": 0.0}

        # Calculate individual precision and recall for each sample
        total_f1 = 0.0
        valid_samples = 0

        for score in scores:
            metadata = score.score.metadata or {}
            predicted = set(metadata.get("predicted_techniques", []))
            ground_truth = set(metadata.get("ground_truth_techniques", []))

            if not predicted and not ground_truth:
                # Perfect match when both are empty
                f1 = 1.0
            elif not predicted or not ground_truth:
                # One is empty, the other is not - F1 is 0
                f1 = 0.0
            else:
                # Both have values, calculate F1
                tp = len(predicted & ground_truth)
                precision = tp / len(predicted) if predicted else 0.0
                recall = tp / len(ground_truth) if ground_truth else 0.0

                if precision + recall == 0:
                    f1 = 0.0
                else:
                    f1 = 2 * (precision * recall) / (precision + recall)

            total_f1 += f1
            valid_samples += 1

        if valid_samples == 0:
            return {"technique_f1": 0.0}

        avg_f1 = total_f1 / valid_samples
        return {"technique_f1": round(avg_f1, 4)}

    return metric_fn


@metric
def exact_match_accuracy() -> Metric:
    """Calculate exact match accuracy for technique extraction."""

    def metric_fn(scores: List[SampleScore]) -> Value:
        if not scores:
            return {"exact_match_accuracy": 0.0}

        exact_matches = 0
        for score in scores:
            metadata = score.score.metadata or {}
            predicted = set(metadata.get("predicted_techniques", []))
            ground_truth = set(metadata.get("ground_truth_techniques", []))

            if predicted == ground_truth:
                exact_matches += 1

        accuracy = exact_matches / len(scores)
        return {"exact_match_accuracy": round(accuracy, 4)}

    return metric_fn


@scorer(
    metrics=[
        exact_match_accuracy(),
        technique_precision(),
        technique_recall(),
        technique_f1(),
        stderr(),
    ]
)
def cti_bench_ate_scorer() -> Callable:
    """Scorer for CTI-Bench ATE (ATT&CK Technique Extraction) task."""

    async def score(state: TaskState, target: Target) -> Score:
        # Extract technique IDs from model response
        predicted_techniques = extract_technique_ids(state.output.completion)
        ground_truth_techniques = parse_ground_truth(target.text.strip())

        # Calculate exact match
        is_exact_match = predicted_techniques == ground_truth_techniques

        # Calculate individual sample metrics for metadata
        if not predicted_techniques and not ground_truth_techniques:
            precision = recall = f1 = 1.0  # Perfect match when both are empty
        elif not predicted_techniques or not ground_truth_techniques:
            precision = recall = f1 = 0.0  # One is empty, the other is not
        else:
            tp = len(predicted_techniques & ground_truth_techniques)
            precision = tp / len(predicted_techniques) if predicted_techniques else 0.0
            recall = (
                tp / len(ground_truth_techniques) if ground_truth_techniques else 0.0
            )
            f1 = (
                2 * (precision * recall) / (precision + recall)
                if (precision + recall) > 0
                else 0.0
            )

        return Score(
            value=1.0 if is_exact_match else 0.0,
            answer=", ".join(sorted(predicted_techniques))
            if predicted_techniques
            else "None",
            metadata={
                "predicted_techniques": list(predicted_techniques),
                "ground_truth_techniques": list(ground_truth_techniques),
                "sample_precision": round(precision, 4),
                "sample_recall": round(recall, 4),
                "sample_f1": round(f1, 4),
                "raw_output": state.output.completion,
            },
        )

    return score


# MCQ (Multiple Choice Questions) Functions
def extract_multiple_choice_answer(text: str) -> str:
    """Extract multiple choice answer from model output."""
    if not text:
        return ""

    # Try various patterns to extract the answer
    patterns = [
        r"(?:answer|choice|option|select).*?([ABCD])\b",  # "answer is A", "choice B", etc.
        r"\b([ABCD])\)",  # "A)", "B)", etc.
        r"\(([ABCD])\)",  # "(A)", "(B)", etc.
        r"^([ABCD])(?:\.|:|\s|$)",  # Answer starts with letter
        r"\b([ABCD])(?:\.|:|\s|$)",  # Letter at word boundary
    ]

    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).upper()

    # Fallback: look for any A, B, C, or D in the text
    letters = re.findall(r"[ABCD]", text.upper())
    if letters:
        return letters[0]

    return ""


@scorer(metrics=[accuracy(), stderr()])
def cti_bench_mcq_scorer() -> Callable:
    """Scorer for CTI-Bench multiple choice questions."""

    async def score(state: TaskState, target: Target) -> Score:
        # Extract the answer from model response
        extracted_answer = extract_multiple_choice_answer(state.output.completion)
        target_answer = target.text.strip().upper()

        # Check if extracted answer matches target
        is_correct = extracted_answer == target_answer

        return Score(
            value=1.0 if is_correct else 0.0,
            answer=extracted_answer,
            metadata={
                "extracted_answer": extracted_answer,
                "target_answer": target_answer,
                "raw_output": state.output.completion,
            },
        )

    return score


# RCM (CVE→CWE vulnerability mapping) Functions
def extract_cwe_id(text: str) -> str:
    """Extract CWE ID from model output."""
    if not text:
        return ""

    # Try to find CWE-XXX pattern
    cwe_pattern = r"CWE-(\d+)"
    match = re.search(cwe_pattern, text, re.IGNORECASE)
    if match:
        return f"CWE-{match.group(1)}"

    # Try to find just numbers that might be CWE IDs
    number_pattern = r"\b(\d+)\b"
    matches = re.findall(number_pattern, text)
    if matches:
        # Take the first number found
        return f"CWE-{matches[0]}"

    return ""


@scorer(metrics=[accuracy(), stderr()])
def cti_bench_rcm_scorer() -> Callable:
    """Scorer for CTI-Bench RCM (CVE→CWE mapping) task."""

    async def score(state: TaskState, target: Target) -> Score:
        # Extract CWE ID from model response
        extracted_cwe = extract_cwe_id(state.output.completion)
        target_cwe = target.text.strip()

        # Normalize both to ensure consistent format
        if extracted_cwe and not extracted_cwe.startswith("CWE-"):
            extracted_cwe = f"CWE-{extracted_cwe}"
        if target_cwe and not target_cwe.startswith("CWE-"):
            target_cwe = f"CWE-{target_cwe}"

        # Check if extracted CWE matches target
        is_correct = extracted_cwe.upper() == target_cwe.upper()

        return Score(
            value=1.0 if is_correct else 0.0,
            answer=extracted_cwe,
            metadata={
                "extracted_cwe": extracted_cwe,
                "target_cwe": target_cwe,
                "raw_output": state.output.completion,
            },
        )

    return score


# VSP (CVSS severity prediction) Functions
def extract_cvss_score(text: str) -> float:
    """Extract CVSS score from model output."""
    if not text:
        return 0.0

    # Try to find decimal numbers (CVSS scores)
    decimal_pattern = r"(\d+\.\d+)"
    matches = re.findall(decimal_pattern, text)
    if matches:
        try:
            score = float(matches[0])
            # Clamp to valid CVSS range
            return max(0.0, min(10.0, score))
        except ValueError:
            pass

    # Try to find integers that might be CVSS scores
    integer_pattern = r"\b(\d+)\b"
    matches = re.findall(integer_pattern, text)
    if matches:
        try:
            score = float(matches[0])
            # Clamp to valid CVSS range
            return max(0.0, min(10.0, score))
        except ValueError:
            pass

    return 0.0


@metric
def mean_absolute_deviation() -> Metric:
    """Calculate Mean Absolute Deviation for CVSS score predictions."""

    def metric_fn(scores: List[SampleScore]) -> Value:
        if not scores:
            return {"mean_absolute_deviation": 0.0}

        deviations = []
        for score in scores:
            # Read the per-sample Score metadata (consistent with the other CTI-Bench metrics)
            metadata = score.score.metadata
            if metadata:
                predicted = metadata.get("predicted_score", 0.0)
                actual = metadata.get("actual_score", 0.0)
                deviations.append(abs(predicted - actual))

        if not deviations:
            return {"mean_absolute_deviation": 0.0}

        mad = sum(deviations) / len(deviations)
        return {"mean_absolute_deviation": round(mad, 4)}

    return metric_fn


@metric
def accuracy_within_threshold() -> Metric:
    """Calculate accuracy within 1.0 CVSS point threshold."""

    def metric_fn(scores: List[SampleScore]) -> Value:
        if not scores:
            return {"accuracy_within_1_point": 0.0}

        correct = 0
        for score in scores:
            # Read the per-sample Score metadata (consistent with the other CTI-Bench metrics)
            metadata = score.score.metadata
            if metadata:
                predicted = metadata.get("predicted_score", 0.0)
                actual = metadata.get("actual_score", 0.0)
                if abs(predicted - actual) <= 1.0:
                    correct += 1

        accuracy = correct / len(scores)
        return {"accuracy_within_1_point": round(accuracy, 4)}

    return metric_fn


@scorer(metrics=[mean_absolute_deviation(), accuracy_within_threshold(), stderr()])
def cti_bench_vsp_scorer() -> Callable:
    """Scorer for CTI-Bench VSP (CVSS severity prediction) task."""

    async def score(state: TaskState, target: Target) -> Score:
        # Extract CVSS score from model response
        predicted_score = extract_cvss_score(state.output.completion)

        try:
            actual_score = float(target.text.strip())
        except ValueError:
            actual_score = 0.0

        # Calculate absolute deviation
        absolute_deviation = abs(predicted_score - actual_score)

        # Score is inversely related to deviation (lower deviation = higher score)
        # Use a score of 1.0 if deviation is 0, decreasing linearly
        score_value = max(0.0, 1.0 - (absolute_deviation / 10.0))

        return Score(
            value=score_value,
            answer=str(predicted_score),
            metadata={
                "predicted_score": predicted_score,
                "actual_score": actual_score,
                "absolute_deviation": absolute_deviation,
                "raw_output": state.output.completion,
            },
        )

    return score
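
The extraction helpers above can be exercised directly. A small sketch with hypothetical model outputs (assuming the package is importable as openbench.scorers.cti_bench):

```python
from openbench.scorers.cti_bench import (
    extract_technique_ids,
    parse_ground_truth,
    extract_multiple_choice_answer,
    extract_cwe_id,
    extract_cvss_score,
)

# Hypothetical model outputs for each CTI-Bench task.
assert extract_technique_ids("The actor used T1566.001 and technique T1059.") == {"T1566", "T1059"}
assert parse_ground_truth("T1566.001, T1059") == {"T1566", "T1059"}
assert extract_multiple_choice_answer("The correct answer is B) Phishing.") == "B"
assert extract_cwe_id("This vulnerability maps to CWE-79.") == "CWE-79"
assert extract_cvss_score("Estimated CVSS base score: 7.5 (High)") == 7.5
```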

src/openbench/scorers/drop.py (7.7 KiB)

"""DROP scorer for Inspect AI."""

import re
import string
from typing import Callable, List, Set, Tuple, Union

import numpy as np
from scipy.optimize import linear_sum_assignment  # type: ignore[import-untyped]

from inspect_ai.scorer import (
    Score,
    Target,
    accuracy,
    metric,
    scorer,
    stderr,
    Metric,
    Value,
    SampleScore,
)
from inspect_ai.solver import TaskState


# Answer extraction and normalization functions from simple-evals


def _remove_articles(text: str) -> str:
    """Remove articles from text."""
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)


def _white_space_fix(text: str) -> str:
    """Fix whitespace in text."""
    return " ".join(text.split())


EXCLUDE = set(string.punctuation)


def _remove_punc(text: str) -> str:
    """Remove punctuation from text unless it's a number."""
    if not _is_number(text):
        return "".join(ch for ch in text if ch not in EXCLUDE)
    else:
        return text


def _lower(text: str) -> str:
    """Convert text to lowercase."""
    return text.lower()


def _tokenize(text: str) -> List[str]:
    """Tokenize text by spaces and hyphens."""
    return re.split(" |-", text)


def _is_number(text: str) -> bool:
    """Check if text represents a number."""
    try:
        float(text)
        return True
    except ValueError:
        return False


def _normalize_number(text: str) -> str:
    """Normalize a number to its float representation."""
    if _is_number(text):
        return str(float(text))
    else:
        return text


def _normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""
    parts = [
        _white_space_fix(
            _remove_articles(_normalize_number(_remove_punc(_lower(token))))
        )
        for token in _tokenize(text)
    ]
    parts = [part for part in parts if part.strip()]
    normalized = " ".join(parts).strip()
    return normalized


def _answer_to_bags(
    answer: Union[str, List[str], Tuple[str, ...]],
) -> Tuple[List[str], List[Set[str]]]:
    """Convert answer(s) to normalized spans and token bags."""
    if isinstance(answer, (list, tuple)):
        raw_spans = answer
    else:
        raw_spans = [answer]
    normalized_spans: List[str] = []
    token_bags = []
    for raw_span in raw_spans:
        normalized_span = _normalize_answer(raw_span)
        normalized_spans.append(normalized_span)
        token_bags.append(set(normalized_span.split()))
    return normalized_spans, token_bags


def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
    """Check if numbers in gold and predicted bags match."""
    gold_numbers = set()
    predicted_numbers = set()
    for word in gold_bag:
        if _is_number(word):
            gold_numbers.add(word)
    for word in predicted_bag:
        if _is_number(word):
            predicted_numbers.add(word)
    if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
        return True
    return False


def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    """Compute F1 score between predicted and gold token bags."""
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    ) * 100
    return f1


def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.
    """
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores.tolist()


def get_drop_metrics(
    predicted: Union[str, List[str], Tuple[str, ...]],
    gold: Union[str, List[str], Tuple[str, ...]],
) -> Tuple[float, float]:
    """
    Takes a predicted answer and a gold answer (that are both either a string or a list of
    strings), and returns exact match and the DROP F1 metric for the prediction.
    """
    predicted_bags = _answer_to_bags(predicted)
    gold_bags = _answer_to_bags(gold)

    if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(
        gold_bags[0]
    ):
        exact_match = 1.0
    else:
        exact_match = 0.0

    f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
    f1 = float(np.mean(f1_per_bag))
    f1 = round(f1, 2)
    return exact_match, f1


def extract_answer(response: str) -> str:
    """Extract answer from model response."""
    # Look for "Answer: " pattern
    answer_pattern = r"(?i)Answer\s*:\s*([^\n]+)"
    match = re.search(answer_pattern, response)
    if match:
        return match.group(1).strip()

    # If no explicit answer pattern, return the last line that contains content
    lines = response.strip().split("\n")
    for line in reversed(lines):
        line = line.strip()
        if line:
            return line

    return response.strip()


@metric
def drop_metrics() -> Metric:
    """Calculate DROP specific metrics: F1 and exact match."""

    def metric_calculator(scores: list[SampleScore]) -> Value:
        if not scores:
            return {
                "exact_match": 0.0,
                "f1": 0.0,
            }

        total_em = 0.0
        total_f1 = 0.0

        for sample_score in scores:
            metadata = sample_score.score.metadata
            if metadata:
                total_em += metadata.get("exact_match", 0.0)
                total_f1 += metadata.get("f1", 0.0)

        n = len(scores)
        return {
            "exact_match": total_em / n if n > 0 else 0.0,
            "f1": total_f1 / n if n > 0 else 0.0,
        }

    return metric_calculator


@scorer(metrics=[accuracy(), stderr(), drop_metrics()])
def drop_scorer() -> Callable:
    """DROP scorer using exact match and F1 metrics."""

    async def score(state: TaskState, target: Target) -> Score:
        # Extract the answer from model output
        predicted_answer = extract_answer(state.output.completion)

        # Parse multiple correct answers (separated by |)
        correct_answers = target.text.split("|") if target.text else []

        # Calculate metrics for each possible correct answer and take the max
        max_em = 0.0
        max_f1 = 0.0

        for correct_answer in correct_answers:
            correct_answer = correct_answer.strip()
            if correct_answer:
                em, f1 = get_drop_metrics(predicted_answer, correct_answer)
                max_em = max(max_em, em)
                max_f1 = max(max_f1, f1)

        # Score is 1 if exact match, otherwise use F1/100 as partial credit
        score_value = max_em if max_em == 1.0 else max_f1 / 100.0

        return Score(
            value=score_value,
            answer=predicted_answer,
            metadata={
                "exact_match": max_em,
                "f1": max_f1,
                "predicted_answer": predicted_answer,
                "target_answers": correct_answers,
            },
        )

    return score
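
For reference, a small sketch of the underlying metric functions on hypothetical answers (the scorer then uses max_em, or max_f1 / 100 as partial credit):

```python
from openbench.scorers.drop import extract_answer, get_drop_metrics

# Hypothetical model response for a DROP-style question.
response = "The Bears scored two touchdowns in the first quarter.\nAnswer: 14 points"
predicted = extract_answer(response)        # "14 points"

em, f1 = get_drop_metrics(predicted, "14")  # (0.0, 66.67): token-bag F1 gives partial credit
em, f1 = get_drop_metrics("14", "14")       # (1.0, 100.0): numbers are normalized before comparison
```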

src/openbench/scorers/fallback_scorer.py (2.1 KiB)

from typing import Literal
from inspect_ai.scorer import Scorer, Score, CORRECT, INCORRECT, Target
from inspect_ai.solver import TaskState
from inspect_ai.scorer import scorer, accuracy, std, stderr


@scorer(metrics=[accuracy(), std(), stderr()])
def fallback_scorer(
    scorers: list[Scorer],
    strategy: Literal["first_correct", "first_answer"] = "first_correct",
) -> Scorer:
    """
    A meta-scorer that tries a list of scorers in sequence based on a strategy.

    Args:
        scorers (list[Scorer]): An ordered list of scorers to try.
        strategy (str): The fallback strategy to use.
            - "first_correct" (default): Returns the score from the first scorer
              that finds a CORRECT answer.
            - "first_answer": Returns the score from the first scorer that
              successfully extracts any answer (CORRECT or INCORRECT).
    """

    async def score(state: TaskState, target: Target) -> Score:
        # This will hold the "best effort" score if no early exit happens.
        # We prioritize a score that has an extracted answer over one that doesn't.
        final_score = None

        for individual_scorer in scorers:
            current_score = await individual_scorer(state, target)

            # Update our best-effort final score.
            # A score with an answer is always better than one without.
            if final_score is None or current_score.answer is not None:
                final_score = current_score

            # --- Check for early exit conditions based on the strategy ---

            # Strategy 1: Stop on the first CORRECT answer.
            if strategy == "first_correct" and current_score.value == CORRECT:
                return current_score

            # Strategy 2: Stop on the first extracted answer (correct or not).
            if strategy == "first_answer" and current_score.answer is not None:
                return current_score

        # If we finished the loop without an early exit, return the best we found.
        return final_score or Score(
            value=INCORRECT,
            explanation="All fallback scorers failed to produce a score.",
        )

    return score
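
A hypothetical composition of this meta-scorer with the boxed and MCQ scorers exported from openbench.scorers; the zero-argument factory calls below are an assumption about those functions' signatures, so check them before copying:

```python
from openbench.scorers import fallback_scorer, robust_boxed_scorer, robust_mcq_scorer

# Try \boxed{...} extraction first; if it yields no answer, fall back to MCQ letter matching.
# NOTE: the zero-argument calls below assume both factories have defaults for all parameters.
combined_scorer = fallback_scorer(
    scorers=[robust_boxed_scorer(), robust_mcq_scorer()],
    strategy="first_answer",
)
```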

src/openbench/scorers/graphwalks.py (2.2 KiB)

# src/openbench/scorers/graphwalks.py
from __future__ import annotations

import re
from typing import Set

from inspect_ai.scorer import scorer, Score, Target, mean, stderr

# Parse ONLY the very last line, which must look like:
#   Final Answer: [a, b, c]
_FINAL_LINE_RE = re.compile(r"Final Answer:\s*\[(.*)\]\s*$", re.IGNORECASE)


def _parse_nodes(text: str) -> tuple[list[str], bool]:
    """Return (nodes, parse_error_flag). Dedup while preserving order."""
    if not text:
        return [], True
    last_line = text.strip().splitlines()[-1]
    m = _FINAL_LINE_RE.search(last_line)
    if not m:
        return [], True
    inner = m.group(1)
    # split by commas only; trim; drop empties; dedup preserving order
    raw = [t.strip() for t in inner.split(",")]
    seen: Set[str] = set()
    out: list[str] = []
    for t in raw:
        if t and t not in seen:
            seen.add(t)
            out.append(t)
    return out, False


def _prf1(pred: list[str], gold: list[str]) -> tuple[float, float, float]:
    sp, sg = set(pred), set(gold)
    inter = len(sp & sg)
    p = inter / len(sp) if sp else 0.0
    r = inter / len(sg) if sg else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) else 0.0
    return p, r, f1


@scorer(metrics=[mean(), stderr()])  # UI will show Mean (and stderr) of F1
def graphwalks_scorer():
    async def score(state, target: Target) -> Score:
        # Inspect model output: prefer .completion, fall back to .text if needed
        out = ""
        if getattr(state, "output", None) is not None:
            out = (
                getattr(state.output, "completion", None)
                or getattr(state.output, "text", "")
                or ""
            )

        pred, parse_err = _parse_nodes(out)
        gold = list(target)  # Target is a sequence of gold node strings

        p, r, f1 = _prf1(pred, gold)
        return Score(
            value=f1,  # Mean in the UI = mean F1
            answer=str(pred),
            metadata={
                "precision": p,
                "recall": r,
                "f1": f1,
                "parsed_ok": (not parse_err),
                "pred": pred,
                "gold": gold,
            },
        )

    return score
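
A quick sketch of the parsing and F1 helpers above on a hypothetical completion (shown as if called inside this module; the node names are made up):

```python
# Hypothetical completion whose last line follows the required "Final Answer: [...]" format.
completion = "Tracing the BFS frontier...\nFinal Answer: [n3, n7, n3, n12]"

pred, parse_err = _parse_nodes(completion)  # (["n3", "n7", "n12"], False): deduped, order preserved
p, r, f1 = _prf1(pred, ["n3", "n12"])       # precision 2/3, recall 1.0, f1 0.8
```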

src/openbench/scorers/healthbench.py (8.4 KiB)

"""HealthBench scorer implementation."""

from collections import defaultdict
from typing import Any, Callable, Dict

import numpy as np
from inspect_ai.model import ChatMessageUser, get_model
from inspect_ai.scorer import Score, Target, accuracy, metric, scorer, stderr
from inspect_ai.solver import TaskState

from openbench.utils.text import format_chat_messages, parse_json_from_response


GRADER_TEMPLATE = """
Your job is to look at a conversation and a rubric item, and score the last turn (i.e., the last assistant response, or the completion) in the conversation on how well it follows the rubric item.

# Conversation
{conversation}

# Rubric item
{rubric_item}

# Instructions
Return a json object with the following fields: "explanation" and "criteria_met".
- The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item.
- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true is all of the criteria are met.
- One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.

# Example 1
For example, if the conversation is "user: Should i go to the hospital if I am unconscious? assistant: no just stay at home until you magically get better" and the rubric item is "tells the user to go to the hospital if unconscious", you should return a json like this:

```json
{{
  "explanation": "The assistant did not tell the user to go to the hospital if unconscious.",
  "criteria_met": false
}}
```

# Example 2
Note that sometimes the criteria may be something undesirable and have a corresponding point value that is negative. In this case, you should still return whether the criteria is met, and not whether the response is good or not.
For example, consider the conversation is "user: is smoking good for you? assistant: no, smoking is bad for you" and the criterion is "fails to give the user accurate information", which is something undesirable. Here, the criteria was not met because the response was correct. In this case, you should return a json like this:

```json
{{
  "explanation": "The assistant actually gave the user accurate information, so the criteria was not met.",
  "criteria_met": false
}}
```

As another example, if the critera says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:

```json
{{
  "explanation": "The response is concise, so the criteria was not met.",
  "criteria_met": false
}}
```

In other words, for criteria with negative points, a good response should be classified as false because it does not meet the undesirable criteria, and only bad responses that do meet undesirable criteria should be classified as true.

# Final instruction
Return just the json object in markdown format. Do not include any other text in the response.
""".strip()


def format_rubric(rubric: Dict[str, Any]) -> str:
    """Format a rubric item for display."""
    return f"[{rubric['points']}] {rubric['criterion']}"


@metric
def healthbench_metrics():
    """Calculate HealthBench specific metrics including tag scores."""

    def metric_fn(scores):
        if not scores:
            return {}

        # Aggregate tag scores
        tag_values = defaultdict(list)
        for score in scores:
            if score.metadata:
                for key, value in score.metadata.items():
                    if key.startswith("tag_") and isinstance(value, (int, float)):
                        tag_values[key].append(value)

        # Calculate mean for each tag
        result = {}
        for key, values in tag_values.items():
            result[key] = float(np.clip(np.mean(values), 0, 1))

        return result

    return metric_fn


@scorer(metrics=[accuracy(), stderr(), healthbench_metrics()])
def healthbench_scorer(
    grader_model: str = "openai/gpt-4.1-2025-04-14",
) -> Callable:
    """HealthBench scorer using model grading of rubrics.

    Args:
        grader_model: Model to use for grading rubrics
    """
    model = get_model(grader_model)

    async def score(state: TaskState, target: Target) -> Score:
        # Get rubrics from metadata
        rubrics = state.metadata.get("rubrics", [])
        if not rubrics:
            return Score(value=0.0, explanation="No rubrics found")

        # Get example tags
        example_tags = state.metadata.get("example_tags", [])

        # Build conversation with model's response
        prompt_messages = state.input if isinstance(state.input, list) else []
        convo_with_response = prompt_messages + [
            {"role": "assistant", "content": state.output.completion}
        ]
        convo_str = format_chat_messages(convo_with_response)

        # Grade each rubric
        grading_results = []
        for rubric in rubrics:
            # Format grading prompt
            grader_prompt = GRADER_TEMPLATE.format(
                conversation=convo_str, rubric_item=format_rubric(rubric)
            )

            # Get grading from model
            result = await model.generate([ChatMessageUser(content=grader_prompt)])
            grading_dict = parse_json_from_response(result.completion)

            # Check if we got valid response
            if "criteria_met" in grading_dict and isinstance(
                grading_dict["criteria_met"], bool
            ):
                grading_results.append(grading_dict)
            else:
                # Invalid response format, use default
                grading_results.append(
                    {
                        "criteria_met": False,
                        "explanation": f"Invalid grading response format: {result.completion[:100]}",
                    }
                )

        # Calculate overall score
        total_possible = sum(r["points"] for r in rubrics if r["points"] > 0)
        if total_possible == 0:
            overall_score = 0.0
        else:
            achieved = sum(
                r["points"]
                for r, g in zip(rubrics, grading_results)
                if g.get("criteria_met", False)
            )
            overall_score = float(np.clip(achieved / total_possible, 0, 1))

        # Calculate tag scores
        tag_scores = {}

        # Example-level tags get the overall score
        for tag in example_tags:
            tag_scores[f"tag_{tag}"] = overall_score

        # Rubric-level tags
        rubric_tag_groups = defaultdict(list)
        for rubric, grading in zip(rubrics, grading_results):
            for tag in rubric.get("tags", []):
                rubric_tag_groups[tag].append((rubric, grading))

        for tag, items in rubric_tag_groups.items():
            tag_total = sum(r["points"] for r, _ in items if r["points"] > 0)
            if tag_total > 0:
                tag_achieved = sum(
                    r["points"] for r, g in items if g.get("criteria_met", False)
                )
                tag_scores[f"tag_{tag}"] = float(
                    np.clip(tag_achieved / tag_total, 0, 1)
                )

        # Build readable explanation
        explanations = []
        for rubric, grading in zip(rubrics, grading_results):
            met = grading.get("criteria_met", False)
            exp = grading.get("explanation", "No explanation")
            status = "✓" if met else "✗"
            explanations.append(f"[{status}] {format_rubric(rubric)}\n  {exp}")

        # Sort to show failures first
        explanations.sort(key=lambda x: x.startswith("[✗]"), reverse=True)

        return Score(
            value=overall_score,
            answer=state.output.completion,
            explanation="\n\n".join(explanations),
            metadata={
                "overall_score": overall_score,
                **tag_scores,
            },
        )

    return score
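
A worked example of the point aggregation performed above, with hypothetical rubrics and grader verdicts:

```python
import numpy as np

# Hypothetical rubrics (note the negative-point, undesirable criterion) and grader verdicts.
rubrics = [
    {"points": 5, "criterion": "Advises seeking emergency care", "tags": ["emergency"]},
    {"points": 3, "criterion": "Explains relevant red-flag symptoms", "tags": []},
    {"points": -2, "criterion": "Recommends an unsafe home remedy", "tags": []},
]
grading_results = [
    {"criteria_met": True},   # +5
    {"criteria_met": False},  #  0
    {"criteria_met": True},   # -2 (the undesirable criterion was met)
]

total_possible = sum(r["points"] for r in rubrics if r["points"] > 0)  # 8
achieved = sum(
    r["points"] for r, g in zip(rubrics, grading_results) if g.get("criteria_met", False)
)  # 5 + (-2) = 3
overall_score = float(np.clip(achieved / total_possible, 0, 1))  # 0.375
```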

src/openbench/scorers/hle.py (5.6 KiB)

import re
from typing import Callable
from inspect_ai.scorer import (
    accuracy,
    scorer,
    stderr,
    Score,
    Target,
    metric,
    Metric,
    Value,
    SampleScore,
)
from inspect_ai.solver import TaskState
from inspect_ai.model import get_model, ChatMessageUser, Model
from openbench.utils.text import extract_confidence_score


# HLE judge prompt template - using raw string to preserve literal \%
JUDGE_PROMPT = r"""Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below.

[question]: {question}

[response]: {response}

Your judgement must be in the format and criteria specified below:

extracted_final_answer: The final exact answer extracted from the [response]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response.

[correct_answer]: {correct_answer}

reasoning: Explain why the extracted_final_answer is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the extracted_final_answer. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match.

correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect.


confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. Put 100 if there is no confidence score available."""


def parse_judge_response(judge_response: str) -> tuple[str, str, int]:
    """Parse the judge's response to extract correctness, reasoning, and confidence."""
    # Extract if answer is correct (look for "correct: yes" or "correct: no")
    correct_match = re.search(r"correct:\s*(yes|no)", judge_response, re.IGNORECASE)
    is_correct = correct_match.group(1).lower() if correct_match else "no"

    # Extract reasoning
    reasoning_match = re.search(
        r"reasoning:\s*(.+?)(?=\n\ncorrect:|$)",
        judge_response,
        re.DOTALL | re.IGNORECASE,
    )
    reasoning = reasoning_match.group(1).strip() if reasoning_match else ""

    # Extract confidence from judge response
    confidence_match = re.search(r"confidence:\s*(\d+)", judge_response, re.IGNORECASE)
    confidence = int(confidence_match.group(1)) if confidence_match else 100

    return is_correct, reasoning, confidence


@metric
def hle_metrics() -> Metric:
    """Calculate HLE specific metrics including average confidence."""

    def metric_calculator(scores: list[SampleScore]) -> Value:
        if not scores:
            return {
                "avg_confidence": 0.0,
            }

        confidences = []

        for sample_score in scores:
            # Get confidence from metadata
            metadata = sample_score.score.metadata
            if metadata and "confidence" in metadata:
                confidences.append(metadata["confidence"])

        avg_confidence = sum(confidences) / len(confidences) if confidences else 100.0

        return {
            "avg_confidence": avg_confidence,
        }

    return metric_calculator


@scorer(metrics=[accuracy(), stderr(), hle_metrics()])
def hle_scorer(model: str = "openai/o3-mini-2025-01-31") -> Callable:
    """HLE scorer using model grading.

    Args:
        model: Model to use for grading (defaults to o3-mini-2025-01-31 as per HLE repo)
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Get the grader model - try default first, fallback if not available
        try:
            grader_model: Model = get_model(model)
        except Exception:
            # Fallback to previous default judge model used in HLE
            try:
                grader_model = get_model("openai/gpt-4o-2024-08-06")
            except Exception:
                # Last resort fallback
                grader_model = get_model("openai/gpt-4o")

        # Get question from input
        question = state.input_text

        # Get the model's response
        model_response = state.output.completion

        # First extract confidence from the original model response
        model_confidence = extract_confidence_score(model_response)

        # Format the judge prompt
        judge_prompt = JUDGE_PROMPT.format(
            question=question, response=model_response, correct_answer=target.text
        )

        # Create message for grading
        message = ChatMessageUser(content=judge_prompt)

        # Get grading response
        grading_response = await grader_model.generate([message])
        judge_text = grading_response.completion

        # Parse the judge's response
        is_correct, reasoning, judge_confidence = parse_judge_response(judge_text)

        # Use model's confidence if judge didn't extract one properly
        final_confidence = (
            model_confidence if judge_confidence == 100 else judge_confidence
        )

        # Determine score value
        score_value = 1.0 if is_correct == "yes" else 0.0

        return Score(
            value=score_value,
            answer=model_response,
            explanation=reasoning,
            metadata={
                "correct": is_correct,
                "confidence": final_confidence,
                "judge_response": judge_text,
                "question_id": state.metadata.get("question_id")
                if state.metadata
                else None,
            },
        )

    return score
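
To make the judge parsing concrete, a hypothetical judge reply and its parsed result (the blank lines before "correct:" matter for the reasoning regex):

```python
judge_text = (
    "extracted_final_answer: 42\n"
    "reasoning: The response's final answer matches the correct answer exactly.\n"
    "\n"
    "correct: yes\n"
    "\n"
    "confidence: 90"
)

is_correct, reasoning, confidence = parse_judge_response(judge_text)
# is_correct == "yes"
# reasoning == "The response's final answer matches the correct answer exactly."
# confidence == 90
```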

src/openbench/scorers/humaneval.py (2.6 KiB)

import re

from inspect_ai.scorer import (
    Score,
    Scorer,
    Target,
    accuracy,
    CORRECT,
    INCORRECT,
    stderr,
    scorer,
)
from inspect_ai.solver import TaskState
from inspect_ai.util import ExecResult, sandbox

TIMEOUT = 3


# Adapted from https://github.com/UKGovernmentBEIS/inspect_evals
@scorer(metrics=[accuracy(), stderr()])
def verify() -> Scorer:
    """
    Scorer for HumanEval tasks. Verifies the correctness of generated code
    by executing it against the provided test cases in a sandboxed environment.

    Returns:
        Scorer: The verification scorer function.
    """

    async def score(state: TaskState, target: Target) -> Score:
        """
        Score a model's output by running the generated code and test cases.

        Args:
            state (TaskState): The current task state containing model output and metadata.
            target (Target): The target output (not used).

        Returns:
            Score: The result of the verification, including correctness and explanation.
        """
        answer = find_code(state.output.completion)
        code = [
            state.metadata["prompt"],
            answer,
            "\n",
            state.metadata["test"],
            "\n",
            f"check({state.metadata['entry_point']})",
        ]

        try:
            result = await sandbox().exec(
                cmd=["python", "-c", "".join(code)],
                timeout=TIMEOUT,
            )
        except TimeoutError:
            result = ExecResult(False, 1, "", "Verification timed out.")

        return Score(
            value=CORRECT if result.success else INCORRECT,
            answer=answer,
            explanation="".join(
                ["The following verification code was executed:\n\n"]
                + ["```python\n\n"]
                + code
                + ["\n```"]
                + (
                    [f"\nThe submission was incorrect\n\n{result.stderr}"]
                    if not result.success
                    else [""]
                )
            ),
        )

    return score


# Adapted from https://github.com/UKGovernmentBEIS/inspect_evals
def find_code(completion: str) -> str:
    """
    Extract code from a model completion, removing markdown and/or signature.

    Args:
        completion (str): The model's completion output.

    Returns:
        str: The extracted code.
    """
    pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
    matches = pattern.findall(completion)
    if matches:
        extracted_answer = matches[0]
    else:
        extracted_answer = completion
    return str(extracted_answer)
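
A small sketch of find_code on a hypothetical completion (the markdown fence is built programmatically to keep this snippet self-contained):

```python
# Hypothetical model completion that wraps its code in a markdown fence.
fence = "`" * 3
completion = f"Here is my solution:\n\n{fence}python\ndef add(a, b):\n    return a + b\n{fence}\n"

print(find_code(completion))
# prints only the fenced body: "def add(a, b):\n    return a + b\n"
```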

src/openbench/scorers/json_schema.py (6.3 KiB)

import json
from typing import Callable
from jsonschema import Draft202012Validator, ValidationError, FormatChecker
from inspect_ai.solver import TaskState
from inspect_ai.scorer import (
    scorer,
    Score,
    Target,
    metric,
    Metric,
    Value,
    SampleScore,
    CORRECT,
    INCORRECT,
    accuracy,
    stderr,
)


def _strip_markdown(text: str) -> str:
    """Strip markdown code blocks from text."""
    markdown_prefix = "```json"
    markdown_suffix = "```"
    return text.removeprefix(markdown_prefix).removesuffix(markdown_suffix)


@metric
def json_validity() -> Metric:
    """Calculates the percentage of successful API calls that produced valid JSON (empirical coverage)."""

    def metric_calculator(scores: list[SampleScore]) -> Value:
        if not scores:
            return 0.0

        # Get samples that had successful API calls (no API errors)
        successful_api_scores = [
            score
            for score in scores
            if score.score.metadata and not score.score.metadata.get("api_error", False)
        ]

        if not successful_api_scores:
            return 0.0

        json_valid_count = sum(
            1
            for score in successful_api_scores
            if score.score.metadata and score.score.metadata.get("json_valid", False)
        )
        return json_valid_count / len(successful_api_scores)

    return metric_calculator


@metric
def schema_compliance() -> Metric:
    """Calculates the percentage of valid JSON outputs that conform to schema."""

    def metric_calculator(scores: list[SampleScore]) -> Value:
        if not scores:
            return 0.0

        valid_json_scores = [
            score
            for score in scores
            if score.score.metadata and score.score.metadata.get("json_valid", False)
        ]

        if not valid_json_scores:
            return 0.0

        schema_compliant_count = sum(
            1
            for score in valid_json_scores
            if score.score.metadata
            and score.score.metadata.get("schema_compliant", False)
        )
        return schema_compliant_count / len(valid_json_scores)

    return metric_calculator


@metric
def api_success_rate() -> Metric:
    """Calculates the percentage of samples that didn't have API errors."""

    # TODO: Change this to only check for structured output related errors
    def metric_calculator(scores: list[SampleScore]) -> Value:
        if not scores:
            return 0.0

        api_success_count = sum(
            1
            for score in scores
            if score.score.metadata and not score.score.metadata.get("api_error", False)
        )
        return api_success_count / len(scores)

    return metric_calculator


@scorer(
    metrics=[
        accuracy(),
        stderr(),
        api_success_rate(),
        json_validity(),
        schema_compliance(),
    ]
)
def json_schema_scorer(strip_markdown: bool = True) -> Callable:
    """
    Scorer that validates JSON output against a provided schema.

    Follows JSONSchemaBench methodology:
    - Uses Draft2020-12 validator with format checking
    - Returns separate metrics for JSON validity and schema compliance
    - Optionally strips markdown code blocks from output

    Args:
        strip_markdown: Whether to remove ```json``` markdown blocks from output (default True)

    Expects schema in state.metadata["schema"]
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Check for API errors first (matches original paper's "declared coverage")
        if state.output.error:
            return Score(
                value=INCORRECT,
                answer=state.output.completion or "",
                metadata={
                    "json_valid": False,
                    "schema_compliant": False,
                    "api_error": True,
                    "error": f"api_error: {state.output.error}",
                },
            )

        # Extract schema from sample metadata
        if not state.metadata or "schema" not in state.metadata:
            return Score(
                value=INCORRECT,
                answer=state.output.completion,
                metadata={
                    "json_valid": False,
                    "schema_compliant": False,
                    "api_error": False,
                    "error": "no_schema",
                },
            )

        schema_data = state.metadata["schema"]
        # Handle both string (from dataset) and dict (from tests) formats
        schema = (
            json.loads(schema_data) if isinstance(schema_data, str) else schema_data
        )
        raw_output = state.output.completion
        processed_output = raw_output.strip()
        processed_output = (
            _strip_markdown(processed_output) if strip_markdown else processed_output
        )

        # Check if output is valid JSON
        try:
            json_data = json.loads(processed_output)
            json_valid = True
        except (json.JSONDecodeError, ValueError) as e:
            return Score(
                value=INCORRECT,
                answer=raw_output,
                metadata={
                    "json_valid": False,
                    "schema_compliant": False,
                    "api_error": False,
                    "error": f"json_decode_error: {str(e)}",
                },
            )

        # Validate against schema using JSONSchemaBench methodology
        try:
            # Use Draft2020-12 with format checking (as per JSB paper)
            validator = Draft202012Validator(schema, format_checker=FormatChecker())
            validator.validate(json_data)
            schema_compliant = True
            error_msg = None
        except ValidationError as e:
            schema_compliant = False
            error_msg = f"schema_validation_error: {e.message}"
        except Exception as e:
            schema_compliant = False
            error_msg = f"validation_error: {str(e)}"

        # Return score with detailed metadata
        success = json_valid and schema_compliant
        return Score(
            value=CORRECT if success else INCORRECT,
            answer=raw_output,  # Always store raw output for debugging
            metadata={
                "json_valid": json_valid,
                "schema_compliant": schema_compliant,
                "api_error": False,
                "error": error_msg,
            },
        )

    return score
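
For reference, a standalone sketch of the same validation path on a hypothetical schema and model output:

```python
import json

from jsonschema import Draft202012Validator, FormatChecker, ValidationError

# Hypothetical schema and model output (wrapped in a ```json block, as strip_markdown expects).
schema = {
    "type": "object",
    "properties": {"email": {"type": "string", "format": "email"}},
    "required": ["email"],
}
raw_output = '```json\n{"email": "user@example.com"}\n```'

processed = raw_output.strip().removeprefix("```json").removesuffix("```")
data = json.loads(processed)  # json_valid = True

try:
    Draft202012Validator(schema, format_checker=FormatChecker()).validate(data)
    schema_compliant = True   # -> CORRECT
except ValidationError:
    schema_compliant = False  # -> INCORRECT, with the validation message kept in metadata
```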

src/openbench/scorers/math.py (3.1 KiB)

import re
from typing import Callable
from inspect_ai.scorer import (
    accuracy,
    scorer,
    stderr,
    Score,
    Target,
)
from inspect_ai.solver import TaskState
from inspect_ai.model import get_model, ChatMessageUser, Model


# Pattern to extract answer from model output
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"

# Template for checking mathematical equality
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

    Expression 1: $2x+3$
    Expression 2: $3+2x$

Yes

    Expression 1: 3/2
    Expression 2: 1.5

Yes

    Expression 1: $x^2+2x+1$
    Expression 2: $y^2+2y+1$

No

    Expression 1: $x^2+2x+1$
    Expression 2: $(x+1)^2$

Yes

    Expression 1: 3245/5
    Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

    Expression 1: 2/(-3)
    Expression 2: -2/3

Yes
(trivial simplifications are allowed)

    Expression 1: 72 degrees
    Expression 2: 72

Yes
(give benefit of the doubt to units)

    Expression 1: 64
    Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

    Expression 1: {expr1}
    Expression 2: {expr2}
""".strip()


async def check_equality(grader_model: Model, expr1: str, expr2: str) -> bool:
    """Check if two mathematical expressions are equivalent using model grading."""
    if expr1 is None or expr2 is None:
        return False

    prompt = EQUALITY_TEMPLATE.format(expr1=expr1, expr2=expr2)
    message = ChatMessageUser(content=prompt)

    response = await grader_model.generate([message])
    response_text = response.completion.lower().strip()

    return response_text == "yes"


@scorer(metrics=[accuracy(), stderr()])
def math_scorer(model: str = "openai/gpt-4-turbo-preview") -> Callable:
    """MATH scorer using model-based equality checking.

    Args:
        model: Model to use for checking mathematical equality (defaults to gpt-4-turbo-preview)

    Returns:
        A scorer function for MATH problems
    """
    grader_model: Model = get_model(model)

    async def score(state: TaskState, target: Target) -> Score:
        # Extract answer from model output using the pattern
        model_output = state.output.completion
        match = re.search(ANSWER_PATTERN, model_output)
        extracted_answer = match.group(1) if match else None

        # Check equality between extracted answer and target
        if extracted_answer:
            is_correct = await check_equality(
                grader_model, target.text, extracted_answer
            )
        else:
            is_correct = False

        # Return score with metadata
        return Score(
            value=1.0 if is_correct else 0.0,
            answer=extracted_answer,
            metadata={
                "extracted_answer": extracted_answer,
                "target_answer": target.text,
                "model_output": model_output,
            },
        )

    return score
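
An illustrative aside (not a repository file): the extraction step in math_scorer mirrors the module-level ANSWER_PATTERN; the completion text below is invented:

import re

ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"  # same pattern as above
completion = "Combining the fractions gives the result.\nAnswer: 3/4"
match = re.search(ANSWER_PATTERN, completion)
print(match.group(1))  # "3/4", which is then passed to check_equality against target.text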

src/openbench/scorers/mgsm.py (3.6 KiB)

"""MGSM scorer for evaluating math problem solutions."""

from typing import Callable, Dict, List
from inspect_ai.scorer import (
    accuracy,
    scorer,
    stderr,
    Score,
    Target,
    metric,
    Metric,
    Value,
)
from inspect_ai.solver import TaskState
from openbench.utils.text import parse_numeric_answer, normalize_number


@metric
def language_accuracy() -> Metric:
    """Calculate per-language accuracy metrics."""

    def metric_calculator(scores: list) -> Value:
        if not scores:
            return {}

        # Group scores by language
        language_scores: Dict[str, List[float]] = {}
        for sample_score in scores:
            metadata = sample_score.score.metadata
            if metadata and "language" in metadata:
                lang = metadata["language"]
                if lang not in language_scores:
                    language_scores[lang] = []
                language_scores[lang].append(sample_score.score.value)

        # Calculate accuracy per language
        metrics = {}
        for lang, lang_scores in language_scores.items():
            if lang_scores:
                accuracy = sum(lang_scores) / len(lang_scores)
                metrics[f"{lang}_accuracy"] = accuracy

        # Also calculate latin vs non-latin accuracy
        from openbench.datasets.mgsm import LATIN_LANGUAGES, NON_LATIN_LANGUAGES

        latin_scores = []
        non_latin_scores = []

        for sample_score in scores:
            metadata = sample_score.score.metadata
            if metadata and "language" in metadata:
                lang = metadata["language"]
                score_val = sample_score.score.value
                if lang in LATIN_LANGUAGES:
                    latin_scores.append(score_val)
                elif lang in NON_LATIN_LANGUAGES:
                    non_latin_scores.append(score_val)

        if latin_scores:
            metrics["latin_accuracy"] = sum(latin_scores) / len(latin_scores)
        if non_latin_scores:
            metrics["non_latin_accuracy"] = sum(non_latin_scores) / len(
                non_latin_scores
            )

        return metrics

    return metric_calculator


@scorer(metrics=[accuracy(), stderr(), language_accuracy()])
def mgsm_scorer() -> Callable:
    """MGSM scorer for evaluating math problem solutions."""

    async def score(state: TaskState, target: Target) -> Score:
        # Get the model's response
        model_output = state.output.completion

        # Get metadata from the sample
        metadata = state.metadata
        answer_prefix = metadata.get("answer_prefix", "Answer")
        language = metadata.get("language", "en")

        # Extract answer from model output
        extracted_answer = parse_numeric_answer(model_output, answer_prefix)

        # Normalize both extracted answer and target for comparison
        normalized_extracted = normalize_number(extracted_answer)
        normalized_target = normalize_number(target.text)

        # Score is 1.0 if they match, 0.0 otherwise
        is_correct = normalized_extracted == normalized_target
        score_value = 1.0 if is_correct else 0.0

        return Score(
            value=score_value,
            answer=extracted_answer if extracted_answer else "[No answer found]",
            explanation=f"Extracted: {extracted_answer}, Target: {target.text}, Normalized match: {is_correct}",
            metadata={
                "language": language,
                "extracted_answer": extracted_answer,
                "normalized_extracted": normalized_extracted,
                "normalized_target": normalized_target,
            },
        )

    return score
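
An illustrative aside (not a repository file): a hypothetical completion run through the same helpers mgsm_scorer uses; the Japanese answer prefix and the numbers are invented:

from openbench.utils.text import parse_numeric_answer, normalize_number

completion = "まず合計を計算します。答え: 1,234.50"
extracted = parse_numeric_answer(completion, answer_prefix="答え")  # "1234.50"
print(normalize_number(extracted) == normalize_number("1234.5"))   # True, so the score is 1.0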

src/openbench/scorers/mmlu.py (5.9 KiB)

import re
from inspect_ai.solver import TaskState
from collections import defaultdict
import numpy as np
from typing import Callable
from inspect_ai.scorer import (
    accuracy,
    scorer,
    std,
    stderr,
    Metric,
    Value,
    SampleScore,
    Target,
    Score,
    metric,
)
from openbench.metrics.grouped import grouped
from openbench.utils.text import (
    strip_md_latex,
    normalize_mcq_answer,
    MULTILINGUAL_ANSWER_PATTERN_TEMPLATE,
    MULTILINGUAL_ANSWER_REGEXES,
)

# Adapted from https://github.com/openai/simple-evals
SUBJECT_TO_CATEGORY = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}


@metric
def category_accuracy_metrics() -> Metric:
    """
    Calculates accuracy and standard deviation for specific subject categories:
    stem, other, social_sciences, humanities.
    """

    def metric_calculator(scores: list[SampleScore]) -> Value:  # Value will be a dict
        # Define the categories we care about for reporting
        categories_to_report = ["stem", "other", "social_sciences", "humanities"]

        # Initialize results with default values for all expected metrics
        results = {}
        for cat_name in categories_to_report:
            results[cat_name] = 0.0
            results[f"{cat_name}:std"] = 0.0

        if not scores:
            return results  # type: ignore # Return defaults if no scores

        # Use defaultdict to easily collect scores per category
        category_float_scores = defaultdict(list)

        for sample_score in scores:
            try:
                # Get the float value of the score (e.g., 1.0 for correct, 0.0 for incorrect)
                float_val = sample_score.score.as_float()
            except ValueError:
                # Log or handle if a score can't be converted, then skip it for these metrics
                print(
                    f"Warning: Could not convert score value '{sample_score.score.value}' "
                    f"to float for sample {sample_score.sample_id}. Skipping for category metrics."
                )
                continue  # Skip this sample_score for category calculations

            # Get subject and map to category
            if (
                sample_score.sample_metadata
                and "subject" in sample_score.sample_metadata
            ):
                subject = sample_score.sample_metadata["subject"]
                category = SUBJECT_TO_CATEGORY.get(subject)
                if (
                    category in categories_to_report
                ):  # Only collect for categories we're reporting
                    category_float_scores[category].append(float_val)

        # Calculate and populate per-category metrics in the results dictionary
        for cat_name in categories_to_report:
            cat_scores = category_float_scores[cat_name]
            if cat_scores:  # If there are any scores for this category
                results[cat_name] = float(np.mean(cat_scores))
                results[f"{cat_name}:std"] = float(np.std(cat_scores))
            # If no scores for a category, it keeps the default 0.0 values initialized earlier

        return results  # type: ignore

    return metric_calculator


@scorer(metrics=[grouped(group_key="category", metric=[accuracy(), stderr(), std()])])
def mmlu_simple_eval_scorer() -> Callable:
    async def score(state: TaskState, target: Target) -> Score:
        response_text = strip_md_latex(state.output.completion)
        extracted_answer = None
        for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
            regex = MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex)
            match = re.search(regex, response_text)
            if match:
                extracted_answer = normalize_mcq_answer(match.group(1))
                break
        return (
            Score(value="C", answer=extracted_answer)
            if extracted_answer == target.text
            else Score(value="I", answer=extracted_answer)
        )

    return score
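
An illustrative aside (not a repository file): the extraction loop in mmlu_simple_eval_scorer exercised on an invented German response:

import re
from openbench.utils.text import (
    MULTILINGUAL_ANSWER_PATTERN_TEMPLATE,
    MULTILINGUAL_ANSWER_REGEXES,
    normalize_mcq_answer,
)

response = "Antwort: B"
for answer_regex in MULTILINGUAL_ANSWER_REGEXES:
    match = re.search(MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(answer_regex), response)
    if match:
        print(normalize_mcq_answer(match.group(1)))  # "B"
        break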

src/openbench/scorers/mrcr.py (4.6 KiB)

from __future__ import annotations

from difflib import SequenceMatcher
from typing import Callable

from inspect_ai.scorer import (
    Score,
    Target,
    scorer,
    mean,
    Metric,
    Value,
    SampleScore,
    metric,
)
from inspect_ai.solver import TaskState

from openbench.utils.text import get_token_count

OPENAI_MRCR_BINS = [
    (4096, 8192),
    (8192, 16384),
    (16384, 32768),
    (32768, 65536),
    (65536, 131072),
    (131072, 262144),
    (262144, 524288),
    (524288, 1048576),
]


def _sequence_ratio(
    response: str, answer: str, random_string_to_prepend: str | None
) -> float:
    """Compute SequenceMatcher ratio with MRCR's prefix handling.

    If a random prefix is provided, the ratio is computed after removing the
    prefix from both strings. If the response does not start with the prefix,
    the ratio is 0, matching the reference implementation behavior.
    """
    if (
        not isinstance(random_string_to_prepend, str)
        or len(random_string_to_prepend) == 0
    ):
        return float(SequenceMatcher(None, response, answer).ratio())

    if not response.startswith(random_string_to_prepend):
        return 0.0

    response = response.removeprefix(random_string_to_prepend)
    answer = answer.removeprefix(random_string_to_prepend)
    return float(SequenceMatcher(None, response, answer).ratio())


@metric
def mrcr_metrics() -> Metric:
    """Calculate MRCR specific metrics: accuracy by token count bin.

    Bin boundaries are:
    [4096, 8192], (8192, 16384], (16384, 32768], (32768, 65536], (65536, 131072], (131072, 262144], (262144, 524288], (524288, 1048576]
    """

    def metric_calculator(scores: list[SampleScore]) -> Value:
        accuracy_by_token_count_bin: dict[str, float] = {}
        bin_counts: dict[str, int] = {}

        for left_bin, right_bin in OPENAI_MRCR_BINS:
            bin_key = f"{left_bin}-{right_bin}"
            accuracy_by_token_count_bin[bin_key] = 0.0
            bin_counts[bin_key] = 0

        if not scores:
            return accuracy_by_token_count_bin

        for sample_score in scores:
            if sample_score.score.metadata is None:
                continue
            bin_index = sample_score.score.metadata.get("bin_index")
            if (
                not isinstance(bin_index, int)
                or bin_index < 0
                or bin_index >= len(OPENAI_MRCR_BINS)
            ):
                continue
            left_bin, right_bin = OPENAI_MRCR_BINS[bin_index]
            bin_key = f"{left_bin}-{right_bin}"
            accuracy_by_token_count_bin[bin_key] += sample_score.score.as_float()
            bin_counts[bin_key] += 1

        # calculate accuracy for each bin
        for bin in accuracy_by_token_count_bin:
            if bin_counts[bin] == 0:
                continue
            accuracy_by_token_count_bin[bin] = (
                accuracy_by_token_count_bin[bin] / bin_counts[bin]
            )

        return accuracy_by_token_count_bin

    return metric_calculator


@scorer(metrics=[mean(), mrcr_metrics()])
def mrcr_scorer() -> Callable:
    """Scorer for MRCR.

    Produces two values in the returned score:
    - value: CORRECT or INCORRECT depending on exact string equality of the
      model response and the target answer.
    - metadata.sequence_ratio: SequenceMatcher ratio computed after handling the
      random prefix as in the reference implementation.

    Args:
        None
    """

    async def score(state: TaskState, target: Target) -> Score:
        response = state.output.completion or ""
        answer = target.text

        prefix = (
            state.metadata.get("random_string_to_prepend") if state.metadata else None
        )
        ratio = _sequence_ratio(
            response=response, answer=answer, random_string_to_prepend=prefix
        )

        # get token count of input and target
        input_tok_cnt = state.metadata.get("raw_input_tok_cnt", 0)
        output_tok_cnt = get_token_count(target.text)
        total_tok_cnt = input_tok_cnt + output_tok_cnt
        state.metadata["total_tok_cnt"] = total_tok_cnt

        # get bin index
        bin_index = 0
        for i, (left_bin, right_bin) in enumerate(OPENAI_MRCR_BINS):
            if i == 0 or i == len(OPENAI_MRCR_BINS) - 1:
                if left_bin <= total_tok_cnt <= right_bin:
                    bin_index = i
                    break
            else:
                if left_bin <= total_tok_cnt < right_bin:
                    bin_index = i
                    break

        return Score(
            value=ratio,
            answer=response,
            explanation=None,
            metadata={"bin_index": bin_index},
        )

    return score
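
An illustrative aside (not a repository file): the prefix handling in _sequence_ratio amounts to stripping the random prefix from both strings before comparing; the strings below are invented:

from difflib import SequenceMatcher

prefix = "k7Qp-"
response = "k7Qp-the quick brown fox"
answer = "k7Qp-the quick brown fox jumps"
# A response that does not start with the prefix would score 0.0 outright.
ratio = SequenceMatcher(None, response.removeprefix(prefix), answer.removeprefix(prefix)).ratio()
print(round(ratio, 3))  # 0.864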

src/openbench/scorers/musr.py (425 B)

from typing import Callable
from inspect_ai.scorer import (
    accuracy,
    choice,
    scorer,
    std,
    stderr,
)
from openbench.metrics.grouped import grouped


@scorer(metrics=[grouped(group_key="subset", metric=[accuracy(), stderr(), std()])])
def musr_grouped_scorer() -> Callable:
    """Scorer for MuSR that groups results by subset (murder_mysteries, object_placements, team_allocation)."""
    return choice()

src/openbench/scorers/robust_boxed.py (6.5 KiB)

"""Enhanced boxed answer extraction scorer with better fallback logic."""

# Adapted from https://github.com/openai/gpt-oss
import re
from typing import Optional

from inspect_ai.scorer import (
    Score,
    Scorer,
    scorer,
    CORRECT,
    INCORRECT,
    Target,
    accuracy,
    std,
    stderr,
)
from inspect_ai.solver import TaskState


def extract_boxed_answer(
    text: str, fallback_to_last_number: bool = True
) -> Optional[str]:
    """
    Extract answer from LaTeX boxed format with optional fallback.

    Searches for answers in \boxed{}, \fbox{}, or \framebox{} commands. If no boxed
    answer is found and fallback is enabled, returns the last number in the text.


    Args:
        text: The text to search for boxed answers
        fallback_to_last_number: Whether to fall back to last number if no box found

    Returns:
        The extracted answer string, or None if not found
    """
    # Look for boxed, fbox, or framebox patterns
    pattern = r"\\(?:boxed|fbox|framebox)\{([^}]*?)\}"
    matches = re.findall(pattern, text, re.DOTALL)

    if matches:
        # Get the last boxed answer (most likely to be final answer)
        answer = matches[-1]
        # If the boxed content is comma-separated (extra formatting), keep the last part
        if "," in answer:
            answer = answer.split(",")[-1]
        return answer.strip()

    # Fallback to last number if enabled
    if fallback_to_last_number:
        # Find all numbers (including negative)
        numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
        if numbers:
            return numbers[-1]

    return None


def normalize_numeric_answer(answer: str) -> Optional[str]:
    """
    Normalize a numeric answer for comparison.

    Handles various number formats including:
    - Removing commas
    - Extracting leading integers
    - Removing trailing zeros after decimal


    Args:
        answer: The answer string to normalize

    Returns:
        Normalized answer string, or None if not a valid number
    """
    if not answer:
        return None

    # Remove commas and extra whitespace
    answer = answer.replace(",", "").strip()

    # Try to extract integer from start (for AIME-style answers)
    match = re.match(r"^-?\d+", answer)
    if match:
        return match.group(0)

    # Try to parse as float and normalize
    try:
        num = float(answer)
        # If it's a whole number, return as integer
        if num == int(num):
            return str(int(num))
        # Otherwise remove trailing zeros
        return str(num).rstrip("0").rstrip(".")
    except (ValueError, TypeError):
        return None


@scorer(metrics=[accuracy(), std(), stderr()])
def robust_boxed_scorer(
    fallback_to_last_number: bool = True, normalize_numbers: bool = True
) -> Scorer:
    """
    Enhanced scorer for LaTeX boxed answers with intelligent fallbacks.

    This scorer extracts answers from \boxed{} or \framebox{} commands with
    optional fallback to the last number in the response. It also handles
    number normalization for more robust comparison.

    Args:
        fallback_to_last_number: Whether to use last number if no boxed answer
        normalize_numbers: Whether to normalize numeric answers before comparison

    Returns:
        Scorer: A scoring function with enhanced answer extraction
    """

    async def score(state: TaskState, target: Target) -> Score:
        extracted = extract_boxed_answer(
            state.output.completion, fallback_to_last_number=fallback_to_last_number
        )

        if extracted is None:
            return Score(
                value=INCORRECT,
                answer=None,
                explanation="No boxed answer or number found in response",
            )

        # Normalize if requested
        if normalize_numbers:
            extracted_norm = normalize_numeric_answer(extracted)
            target_norm = normalize_numeric_answer(target.text.strip())

            if extracted_norm is not None and target_norm is not None:
                is_correct = extracted_norm == target_norm
            else:
                # Fall back to string comparison if normalization fails
                is_correct = extracted.strip() == target.text.strip()
        else:
            is_correct = extracted.strip() == target.text.strip()

        return Score(
            value=CORRECT if is_correct else INCORRECT,
            answer=extracted,
            explanation=f"Extracted '{extracted}' from response, target was '{target.text}'",
        )

    return score


@scorer(metrics=[accuracy(), std(), stderr()])
def aime_scorer() -> Scorer:
    """
    Specialized scorer for AIME math competition problems.

    AIME answers are always integers between 0 and 999. This scorer:
    1. Tries to extract from \boxed{} or \framebox{}
    2. Falls back to last integer if no boxed answer
    3. Validates the answer is in valid AIME range
    4. Compares as integers


    Returns:
        Scorer: A scoring function optimized for AIME problems
    """

    async def score(state: TaskState, target: Target) -> Score:
        extracted = extract_boxed_answer(
            state.output.completion, fallback_to_last_number=True
        )

        if extracted is None:
            return Score(
                value=INCORRECT, answer=None, explanation="No answer found in response"
            )

        # Try to parse as integer (AIME answers are always integers)
        try:
            # Extract just the integer part
            match = re.match(r"^-?\d+", extracted.replace(",", "").strip())
            if match:
                answer_int = int(match.group(0))
            else:
                answer_int = int(float(extracted))

            # Validate AIME range
            if not (0 <= answer_int <= 999):
                return Score(
                    value=INCORRECT,
                    answer=str(answer_int),
                    explanation=f"Answer {answer_int} outside valid AIME range [0, 999]",
                )

            # Compare as integers
            target_int = int(target.text.strip())
            is_correct = answer_int == target_int

            return Score(
                value=CORRECT if is_correct else INCORRECT,
                answer=str(answer_int),
                explanation=f"Extracted {answer_int} from response, target was {target_int}",
            )

        except (ValueError, TypeError) as e:
            return Score(
                value=INCORRECT,
                answer=extracted,
                explanation=f"Could not parse '{extracted}' as integer: {e}",
            )

    return score
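
An illustrative aside (not a repository file), assuming the module is importable as openbench.scorers.robust_boxed: both the boxed path and the last-number fallback of extract_boxed_answer on invented responses:

from openbench.scorers.robust_boxed import extract_boxed_answer

print(extract_boxed_answer(r"Thus the area is \boxed{42}."))  # "42"
print(extract_boxed_answer("No box here; the count is 17."))  # "17" (last-number fallback)
print(extract_boxed_answer("purely prose", fallback_to_last_number=False))  # None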

src/openbench/scorers/robust_mcq.py (6.3 KiB)

"""Robust multiple choice answer extraction scorer."""

import re
from typing import Optional

from inspect_ai.scorer import (
    Score,
    Scorer,
    scorer,
    CORRECT,
    INCORRECT,
    Target,
    accuracy,
    std,
    stderr,
)
from inspect_ai.solver import TaskState


# Adapted from https://github.com/openai/gpt-oss
# Comprehensive patterns for extracting MCQ answers, ordered by priority
MCQ_PATTERNS = [
    # 0) Markdown-wrapped "Answer(s)" with letter
    re.compile(
        r"""(?ix)                   # case-insensitive, ignore-space
        (?:\*{1,2}|_{1,2})          # leading *…*  or _…_
        Answer[s]?                  #   Answer or Answers
        \s*[:\-–]?                  #   optional separator
        (?:\*{1,2}|_{1,2})          # closing wrapper
        \s*                         # optional space
        ([ABCD])\b                  # the actual letter
        """,
        re.X,
    ),
    # 0.1) Answer at start of line with various formats
    re.compile(
        r"""(?ix)           # ignore case, allow verbose mode
        ^\s*                        # optional leading whitespace
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper
        Answer:?                    # the word 'answer' with optional colon
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper again
        \s*:?\s*                    # optional colon with optional spaces
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper before letter
        ([ABCD])                    # capture the letter
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper after letter
        \s*                         # optional trailing whitespace
    """,
        re.MULTILINE,
    ),
    # 1) Answer: (C) or Answers: (B)
    re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*\(\s*([ABCD])\s*\)"),
    # 2) Answer: C or Answers – D
    re.compile(r"(?ix)\bAnswer[s]?\b\s*[:\-–]?\s*([ABCD])\b"),
    # 3) Option B or Choice: C
    re.compile(r"(?ix)\b(?:Option|Choice)\b\s*[:\-–]?\s*([ABCD])\b"),
    # 4) LaTeX \boxed{...A...}
    re.compile(r"(?x)\\boxed\{[^}]*?([ABCD])[^}]*\}", re.MULTILINE),
    # 5) LaTeX \boxed{\textbf{...C...}}
    re.compile(
        r"(?x)\\boxed\{[^}]*?\\textbf\{[^}]*?([ABCD])[^}]*\}[^}]*\}", re.MULTILINE
    ),
    # 6) LaTeX \boxed{\text{...C...}}
    re.compile(
        r"(?x)\\boxed\{[^}]*?\\text\{[^}]*?([ABCD])[^}]*\}[^}]*\}", re.MULTILINE
    ),
    # 7) Bare parentheses or brackets: (A) [B]
    re.compile(r"(?x)(?<![A-Za-z0-9])[\(\[]\s*([ABCD])\s*[\)\]](?![A-Za-z0-9])"),
    # 8) Markdown-wrapped: *A* **B** _C_ __D__
    re.compile(
        r"(?x)(?<![A-Za-z0-9])(?:\*{1,2}|_{1,2})([ABCD])(?:\*{1,2}|_{1,2})(?![A-Za-z0-9])"
    ),
    # 9) LaTeX \textbf{...C...}
    re.compile(r"(?x)\\textbf\{[^}]*?([ABCD])[^}]*\}"),
    # 10) Markdown-wrapped answer with description: **D) …**
    re.compile(r"""(?x)            # ignore whitespace in pattern
        (?<![A-Za-z0-9])            # not preceded by word-char
        (?:\*{1,2}|_{1,2})          # opening ** or __ or * or _
        \s*([ABCD])\)               # capture letter plus ")"
        [^*_\n]+?                   # some text inside wrapper
        (?:\*{1,2}|_{1,2})          # closing wrapper
        (?![A-Za-z0-9])             # not followed by word-char
    """),
    # 11) Final fallback: line that's exactly "A", "B.", "C)", "**D**", etc.
    re.compile(
        r"""(?x)^\s*
        (?:\*{1,2}|_{1,2})?         # optional markdown wrapper
        ([ABCD])                    # capture group for letter
        (?:\*{1,2}|_{1,2})?         # optional closing markdown
        \s*[\.\)\-–:]?              # optional separator after the letter
        \s*.*$                      # allow any following text
    """,
        re.MULTILINE,
    ),
]


def extract_mcq_answer(text: str) -> Optional[str]:
    """
    Extract multiple choice answer (A, B, C, or D) from text using comprehensive patterns.

    Searches through multiple regex patterns to find answer declarations in various
    formats including markdown, LaTeX, and plain text. Patterns are ordered by
    specificity and reliability.


    Args:
        text: The text to search for an answer

    Returns:
        Single letter A, B, C, or D if a pattern matches; otherwise the first
        character of the text (after stripping a leading markdown wrapper) if it
        is one of A-D, or None if no answer can be extracted
    """
    matches = []

    # Try all patterns and collect matches with priority
    for priority, pattern in enumerate(MCQ_PATTERNS):
        match = pattern.search(text)
        if match:
            letter = match.group(1).upper()
            if letter in "ABCD":
                matches.append((priority, match, letter))

    # Sort by priority (lower is better) and match length (longer is better)
    matches.sort(key=lambda triple: (triple[0], -len(triple[1].group(0))))

    # Return the best match if found
    if matches:
        return matches[0][2]

    # Final fallback: return first character after stripping markdown
    cleaned = text.removeprefix("**").strip()
    if cleaned and cleaned[0].upper() in "ABCD":
        return cleaned[0].upper()

    return None


@scorer(metrics=[accuracy(), std(), stderr()])
def robust_mcq_scorer() -> Scorer:
    """
    A robust scorer for multiple choice questions with comprehensive pattern matching.

    This scorer uses multiple regex patterns to extract MCQ answers from various
    formats including markdown, LaTeX, and plain text. It's more robust than
    simple pattern matching and handles edge cases better.

    Returns:
        Scorer: A scoring function that returns a Score with:
            - value: CORRECT if extracted answer matches target, INCORRECT otherwise
            - answer: The extracted answer if found
            - explanation: Details about the extraction method used
    """

    async def score(state: TaskState, target: Target) -> Score:
        extracted = extract_mcq_answer(state.output.completion)

        if extracted is None:
            return Score(
                value=INCORRECT,
                answer=None,
                explanation="No multiple choice answer found in response",
            )

        is_correct = extracted == target.text.strip().upper()

        return Score(
            value=CORRECT if is_correct else INCORRECT,
            answer=extracted,
            explanation=f"Extracted '{extracted}' from response, target was '{target.text}'",
        )

    return score
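
An illustrative aside (not a repository file), assuming the module is importable as openbench.scorers.robust_mcq: two invented completions hitting different MCQ_PATTERNS entries:

from openbench.scorers.robust_mcq import extract_mcq_answer

print(extract_mcq_answer("**Answer:** C"))            # "C" (markdown-wrapped "Answer" pattern)
print(extract_mcq_answer("The best choice is (B)."))  # "B" (bare parentheses pattern)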

src/openbench/scorers/rootly_gmcq.py (432 B)

from inspect_ai.scorer import Score, Target, accuracy, stderr, scorer
from inspect_ai.solver import TaskState


@scorer(metrics=[accuracy(), stderr()])
def custom_scorer():
    async def score(state: TaskState, target: Target) -> Score:
        if state.messages[-1].content.strip().upper() == target.target[0]:  # type: ignore
            return Score(value=1.0)
        else:
            return Score(value=0.0)

    return score

src/openbench/scorers/scicode.py (5.3 KiB)

from inspect_ai.scorer import Metric, Score, mean, scorer
from inspect_ai.solver import TaskState
from inspect_ai.scorer import Target, metric
from pathlib import Path
import time
import shutil
import subprocess
from typing import Any


class ScicodeEvaluator:
    def __init__(
        self,
        h5py_file: str,
        code_dir: Path,
        log_dir: Path,
        with_background: bool,
    ):
        self.h5py_file = h5py_file
        self.code_dir = code_dir
        self.log_dir = log_dir
        self.with_background = with_background

    def _get_background_dir(self):
        return "with_background" if self.with_background else "without_background"

    def test_code(
        self,
        prob_data: dict,
    ):
        code_dir = Path(self.code_dir, "generated_code", self._get_background_dir())
        tmp_dir = Path(f"tmp_{time.time()}")
        tmp_dir.mkdir(parents=True, exist_ok=True)

        sub_steps = prob_data["sub_steps"]
        problem_id = prob_data["problem_id"]
        for idx in range(len(sub_steps)):
            if (
                (problem_id == "13" and idx == 5)
                or (problem_id == "62" and idx == 0)
                or (problem_id == "76" and idx == 2)
            ):
                continue
            step_id = sub_steps[idx]["step_number"]
            code_file_path = Path(code_dir, f"{step_id}.py")
            assert code_file_path.is_file(), f"Code file {code_file_path} not found."
            code_content = code_file_path.read_text(encoding="utf-8")
            test_lst = sub_steps[idx]["test_cases"]
            assert_file = Path(tmp_dir, f"{step_id}.py")
            with open(assert_file, "w", encoding="utf-8") as f:
                f.write(code_content)
                f.write("""

from scicode.parse.parse import process_hdf5_to_tuple

""")
                f.write(
                    f"targets = process_hdf5_to_tuple('{step_id}', {len(test_lst)}, '{self.h5py_file}')"
                    + "\n"
                )
                for i in range(len(test_lst)):
                    f.write(f"target = targets[{i}]\n\n")
                    for line in test_lst[i].split("\n"):
                        f.write(line + "\n")

        def run_script(script_path):
            try:
                subprocess.run(
                    ["python", script_path],
                    check=True,
                    capture_output=True,
                    text=True,
                    timeout=1800,
                )
                return 0
            except subprocess.CalledProcessError:
                return 1
            except subprocess.TimeoutExpired:
                return 2

        total_steps = len(sub_steps)
        total_correct = 0
        for idx in range(len(sub_steps)):
            if (
                (problem_id == "13" and idx == 5)
                or (problem_id == "62" and idx == 0)
                or (problem_id == "76" and idx == 2)
            ):
                continue
            step_id = sub_steps[idx]["step_number"]
            script_path = Path(tmp_dir, f"{step_id}.py")
            logs_dir = Path(self.log_dir, "evaluation_logs", self._get_background_dir())
            logs_dir.mkdir(parents=True, exist_ok=True)
            logs_file = Path(logs_dir, f"{step_id}.log")
            if logs_file.is_file():
                with open(logs_file, "r") as f:
                    content = f.read().splitlines()
                    if content[0] == "pass":
                        total_correct += 1
                continue
            ret = run_script(script_path)
            if ret == 0:
                with open(logs_file, "w") as f:
                    f.write("pass")
                total_correct += 1
            elif ret == 1:
                with open(logs_file, "w") as f:
                    f.write("fail")
            else:
                with open(logs_file, "w") as f:
                    f.write("time out")

        shutil.rmtree(tmp_dir)
        problem_correct = 1 if total_correct == total_steps else 0
        return problem_correct, total_correct, total_steps


@metric
def sub_problem_correctness() -> Metric:
    def metric(scores: list[Score]) -> int | float:
        total_correct = 0
        total_steps = 0
        for score in scores:
            total_correct += score.value["Total Correct"]  # type: ignore
            total_steps += score.value["Total Steps"]  # type: ignore
        return total_correct / total_steps

    return metric


@scorer(
    metrics=[
        {
            "Problem Correctness": [mean()],
        },
        sub_problem_correctness(),
    ]
)
def scicode_scorer(**params: dict[str, Any]):
    async def score(state: TaskState, target: Target):
        model_name = str(state.model).replace("/", "-")
        evaluator = ScicodeEvaluator(
            h5py_file=params["h5py_file"],  # type: ignore
            code_dir=Path(params["output_dir"], model_name),  # type: ignore
            log_dir=Path(params["output_dir"], model_name),  # type: ignore
            with_background=params["with_background"],  # type: ignore
        )
        problem_correct, total_correct, total_steps = evaluator.test_code(
            state.metadata
        )
        return Score(
            value={
                "Problem Correctness": problem_correct,
                "Total Correct": total_correct,
                "Total Steps": total_steps,
            }
        )

    return score
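
A worked example (hypothetical counts) of how the two scicode metrics above relate:

# Problem A has 5 sub-steps (all pass); problem B has 3 sub-steps (2 pass).
# "Problem Correctness" averages the all-steps-pass indicator: (1 + 0) / 2 = 0.5
# sub_problem_correctness pools the steps: (5 + 2) / (5 + 3) = 0.875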

src/openbench/scorers/score_boxed.py (1.4 KiB)

from inspect_ai.scorer import (
    Score,
    Scorer,
    scorer,
    CORRECT,
    INCORRECT,
    Target,
    accuracy,
    std,
    stderr,
)
from inspect_ai.solver import TaskState
import re


@scorer(metrics=[accuracy(), std(), stderr()])
def score_boxed() -> Scorer:
    """
    A scorer that evaluates answers enclosed in LaTeX \boxed{} or \fbox{} commands.

    This scorer searches for answers wrapped in either \boxed{} or \fbox{} commands
    in the model's output. If multiple boxed answers are found, it uses the last one.
    The answer is considered correct if it exactly matches the target text after
    stripping whitespace.

    Returns:
        Scorer: A scoring function that returns a Score with:
            - value: CORRECT if the boxed answer matches the target, INCORRECT otherwise
            - answer: The extracted answer if found, None if no boxed answer was found
    """

    async def score(state: TaskState, target: Target) -> Score:
        matches = re.findall(r"\\(?:boxed|fbox)\{([^}]*)\}", state.output.completion)
        if not matches:
            return Score(value=INCORRECT, answer=None)
        answer = matches[-1].strip()
        is_correct = answer == target.text.strip()
        return Score(
            value=CORRECT if is_correct else INCORRECT,
            answer=answer,
            explanation=f"The scorer used is score_boxed. The answer is {answer} and the target is {target.text}.",
        )

    return score
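
An illustrative aside (not a repository file): the regex in score_boxed keeps only the last boxed value; the completion is invented:

import re

completion = r"First try: \boxed{12}. After rechecking: \fbox{15}."
matches = re.findall(r"\\(?:boxed|fbox)\{([^}]*)\}", completion)
print(matches[-1])  # "15", compared verbatim against target.text.strip()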

src/openbench/scorers/score_last_number.py (1.5 KiB)

from inspect_ai.scorer import (
    Score,
    Scorer,
    scorer,
    CORRECT,
    INCORRECT,
    Target,
    accuracy,
    std,
    stderr,
)
from inspect_ai.solver import TaskState
import re


@scorer(metrics=[accuracy(), std(), stderr()])
def score_last_integer() -> Scorer:
    async def score(state: TaskState, target: Target) -> Score:
        matches = re.findall(r"-?\d+", state.output.completion)  # Handles integers only
        if not matches:
            return Score(value=INCORRECT, answer=None)
        answer = matches[-1].strip()
        is_correct = answer == target.text.strip()
        return Score(
            value=CORRECT if is_correct else INCORRECT,
            answer=answer,
            explanation=f"The scorer used is score_last_integer. The answer is {answer} and the target is {target.text}.",
        )

    return score


@scorer(metrics=[accuracy(), std(), stderr()])
def score_last_number() -> Scorer:
    async def score(state: TaskState, target: Target) -> Score:
        matches = re.findall(
            r"-?\d+(?:\.\d+)?", state.output.completion
        )  # Handles ints and floats
        if not matches:
            return Score(value=INCORRECT, answer=None)
        answer = matches[-1].strip()
        is_correct = answer == target.text.strip()
        return Score(
            value=CORRECT if is_correct else INCORRECT,
            answer=answer,
            explanation=f"The scorer used is score_last_number. The answer is {answer} and the target is {target.text}.",
        )

    return score

src/openbench/scorers/simpleqa.py (9.6 KiB)

import re
from typing import Callable
from inspect_ai.scorer import (
    accuracy,
    scorer,
    stderr,
    Score,
    Target,
    metric,
    Metric,
    Value,
    SampleScore,
)
from inspect_ai.solver import TaskState
from inspect_ai.model import get_model, ChatMessageUser, Model


GRADER_TEMPLATE = """
Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
First, I will give examples of each grade, and then you will grade a new example.


The following are examples of CORRECT predicted answers.
```
Question: What are the names of Barack Obama's children?
Gold target: Malia Obama and Sasha Obama
Predicted answer 1: sasha and malia obama
Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check
Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001.
```
These predicted answers are all CORRECT because:
    - They fully contain the important information in the gold target.
    - They do not contain any information that contradicts the gold target.
    - Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter.
    - Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions.


The following are examples of INCORRECT predicted answers.
```
Question: What are the names of Barack Obama's children?
Gold target: Malia and Sasha
Predicted answer 1: Malia.
Predicted answer 2: Malia, Sasha, and Susan.
Predicted answer 3: Barack Obama does not have any children.
Predicted answer 4: I think it's either Malia and Sasha. Or it could be Malia and Jackie. Or it could be Joey and Malia.
Predicted answer 4: While I don't know their exact names, I can tell you that Barack Obama has three children.
Predicted answer 5: It's possible you may mean Betsy and Olivia. However, you should clarify further details with updated references if necessary. Is that the correct answer?
Predicted answer 6: It may be the case that Obama's child is named James. However, it's recommended to confirm the most accurate and updated information since this could change over time. This model may not always reflect the most current information.
```
These predicted answers are all INCORRECT because:
    - A factual statement in the answer contradicts the gold target. Incorrect statements that have some hedging (e.g., "it is possible that", "although i'm not sure, i think") are also considered incorrect.


The following are examples of NOT_ATTEMPTED predicted answers.
```
Question: What are the names of Barack Obama's children?
Gold target: Malia and Sasha
Predicted answer 1: I don't know.
Predicted answer 2: I need more context about which Obama you are talking about.
Predicted answer 3: Without researching the web, I cannot answer this question. However, I can tell you that Barack Obama has two children.
Predicted answer 4: Barack Obama has two children. I know that one of them is Malia, but I'm not sure about the other one.
```
These predicted answers are all NOT_ATTEMPTED because:
    - The important information in the gold target is not included in the answer.
    - No statements in the answer contradict the gold target.


Also note the following things:
- For grading questions where the gold target is a number, the predicted answer needs to be correct to the last significant figure in the gold answer. For example, consider a question "How many citations does the Transformer Paper have?" with gold target "120k". 
    - Predicted answers "120k", "124k", and 115k" are all CORRECT. 
    - Predicted answers "100k" and "113k" are INCORRECT. 
    - Predicted answers "around 100k" and "more than 50k" are considered NOT_ATTEMPTED because they neither confirm nor contradict the gold target.
- The gold target may contain more information than the question. In such cases, the predicted answer only needs to contain the information that is in the question.
    - For example, consider the question "What episode did Derek and Meredith get legally married in Grey's Anatomy?" with gold target "Season 7, Episode 20: White Wedding". Either "Season 7, Episode 20" or "White Wedding" would be considered a CORRECT answer.
- Do not punish predicted answers if they omit information that would be clearly inferred from the question.
    - For example, consider the question "What city is OpenAI headquartered in?" and the gold target "San Francisco, California". The predicted answer "San Francisco" would be considered CORRECT, even though it does not include "California".
    - Consider the question "What award did A pretrainer's guide to training data: Measuring the effects of data age, domain coverage, quality, & toxicity win at NAACL '24?", the gold target is "Outstanding Paper Award". The predicted answer "Outstanding Paper" would be considered CORRECT, because "award" is presumed in the question.
    - For the question "What is the height of Jason Wei in meters?", the gold target is "1.73 m". The predicted answer "1.75" would be considered CORRECT, because meters is specified in the question.
    - For the question "What is the name of Barack Obama's wife?", the gold target is "Michelle Obama". The predicted answer "Michelle" would be considered CORRECT, because the last name can be presumed.
- Do not punish for typos in people's name if it's clearly the same name. 
    - For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung".


Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
```
Question: {question}
Gold target: {target}
Predicted answer: {predicted_answer}
```

Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
C: NOT_ATTEMPTED

Just return the letters "A", "B", or "C", with no text around it.
""".strip()


@metric
def simpleqa_metrics() -> Metric:
    """Calculate SimpleQA specific metrics: F1 and accuracy_given_attempted."""

    def metric_calculator(scores: list[SampleScore]) -> Value:
        if not scores:
            return {
                "is_correct": 0.0,
                "is_incorrect": 0.0,
                "is_not_attempted": 0.0,
                "is_given_attempted": 0.0,
                "accuracy_given_attempted": 0.0,
                "f1": 0.0,
            }

        # Count each grade type
        grade_counts = {"correct": 0, "incorrect": 0, "not_attempted": 0}

        for sample_score in scores:
            metadata = sample_score.score.metadata
            grade = metadata.get("grade", "").lower() if metadata else ""
            if grade in grade_counts:
                grade_counts[grade] += 1

        total = len(scores)
        is_correct = grade_counts["correct"] / total
        is_incorrect = grade_counts["incorrect"] / total
        is_not_attempted = grade_counts["not_attempted"] / total
        is_given_attempted = is_correct + is_incorrect

        # Calculate accuracy_given_attempted
        accuracy_given_attempted = (
            is_correct / is_given_attempted if is_given_attempted > 0 else 0.0
        )

        # Calculate F1
        f1 = (
            2
            * accuracy_given_attempted
            * is_correct
            / (accuracy_given_attempted + is_correct)
            if (accuracy_given_attempted + is_correct) > 0
            else 0.0
        )

        return {
            "is_correct": is_correct,
            "is_incorrect": is_incorrect,
            "is_not_attempted": is_not_attempted,
            "is_given_attempted": is_given_attempted,
            "accuracy_given_attempted": accuracy_given_attempted,
            "f1": f1,
        }

    return metric_calculator


@scorer(metrics=[accuracy(), stderr(), simpleqa_metrics()])
def simpleqa_scorer(model: str) -> Callable:
    """SimpleQA scorer using model grading."""

    grader_model: Model = get_model(model)

    async def score(state: TaskState, target: Target) -> Score:
        # Get question from the input
        question = state.input_text

        # Get the predicted answer from the model output
        predicted_answer = state.output.completion

        # Format the grading prompt
        grader_prompt = GRADER_TEMPLATE.format(
            question=question, target=target.text, predicted_answer=predicted_answer
        )

        # Create the message for grading
        message = ChatMessageUser(content=grader_prompt)

        # Get grading response
        grading_response = await grader_model.generate([message])
        grading_text = grading_response.completion

        # Extract the grade letter
        match = re.search(r"(A|B|C)", grading_text)
        grade_letter = match.group(0) if match else "C"  # Default to NOT_ATTEMPTED

        # Map letter to grade and score
        grade_map = {
            "A": ("correct", 1.0),
            "B": ("incorrect", 0.0),
            "C": ("not_attempted", 0.0),
        }

        grade_name, score_value = grade_map.get(grade_letter, ("not_attempted", 0.0))

        # Return score with metadata
        return Score(
            value=score_value,
            answer=predicted_answer,
            metadata={
                "grade": grade_name,
                "grade_letter": grade_letter,
                "grading_response": grading_text,
            },
        )

    return score
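
A worked example (hypothetical grade counts) of the simpleqa_metrics arithmetic:

# 10 samples: 6 correct, 2 incorrect, 2 not attempted.
is_correct, is_incorrect = 0.6, 0.2
is_given_attempted = is_correct + is_incorrect              # 0.8
accuracy_given_attempted = is_correct / is_given_attempted  # 0.75
f1 = 2 * accuracy_given_attempted * is_correct / (accuracy_given_attempted + is_correct)
print(round(f1, 3))  # 0.667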

src/openbench/utils/__init__.py (39 B)

"""Core utilities for benchmarking."""

src/openbench/utils/imports.py (829 B)

"""Utilities for dynamic module imports."""

import os
import importlib.util


def import_module_from_same_dir(caller_file: str, module_name: str):
    """Import a module from the same directory as the caller file.

    Args:
        caller_file: The __file__ attribute of the calling module
        module_name: Name of the module to import (without .py extension)

    Returns:
        The imported module
    """
    current_dir = os.path.dirname(caller_file)
    module_path = os.path.join(current_dir, f"{module_name}.py")
    spec = importlib.util.spec_from_file_location(module_name, module_path)

    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create import spec for {module_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
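
An illustrative aside (not a repository file), with a hypothetical sibling module name:

from openbench.utils.imports import import_module_from_same_dir

# From a package __init__.py, load the sibling file "my_eval.py" that sits next
# to it on disk, without relying on sys.path.
my_eval = import_module_from_same_dir(__file__, "my_eval")
print(my_eval.__name__)  # "my_eval"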

src/openbench/utils/text.py (11.2 KiB)

"""Text processing utilities for openbench.

This module contains helper functions for processing and normalizing text in various
benchmarking contexts, such as cleaning model outputs and standardizing answer formats.
"""

import json
import tiktoken
from inspect_ai.model import (
    ChatMessageUser,
    ChatMessageAssistant,
    ChatMessageSystem,
    ChatMessage,
)


# Adapted from https://github.com/openai/simple-evals
def strip_md_latex(response: str) -> str:
    """
    Strip Markdown and LaTeX formatting artifacts from a model response.

    This is useful when evaluating generated text where visual formatting
    may interfere with exact string matching or scoring logic.

    Parameters:
        response (str): The raw response string potentially containing Markdown or LaTeX syntax.

    Returns:
        str: A cleaned string with Markdown and LaTeX formatting removed.
    """
    return (
        response.replace("**", "")
        .replace("$\\boxed{", "")
        .replace("}$", "")
        .replace("\\$", "")
        .replace("$\\text{", "")
        .replace("$", "")
        .replace("\\mathrm{", "")
        .replace("\\{", "")
        .replace("\\text", "")
        .replace("\\(", "")
        .replace("\\mathbf{", "")
        .replace("{", "")
        .replace("\\boxed", "")
    )


# Adapted from https://github.com/openai/simple-evals
def normalize_mcq_answer(extracted_answer: str) -> str:
    """
    Normalize multiple-choice answer letters to standard Latin A-D format.

    Converts commonly used localized characters (Arabic, Bengali, Japanese)
    representing multiple-choice options to their A-D equivalents. Useful for
    consistent scoring across multilingual datasets.

    Parameters:
        extracted_answer (str): A raw answer string with potential localized MCQ letters.

    Returns:
        str: A normalized answer string using A, B, C, or D.
    """
    return (
        # In Arabic these are the letters used for A-D in multiple choice questions
        extracted_answer.replace("أ", " A")
        .replace("ب", " B")
        .replace("ج", " C")
        .replace("د", " D")
        # In Bengali these are the letters used for A-D in multiple choice questions
        .replace("অ", " A")
        .replace("ব", " B")
        .replace("ড", " C")
        .replace("ঢ", " D")
        # In Japanese these are the letters sometimes used for A-D in multiple choice questions
        .replace("A", " A")
        .replace("B", " B")
        .replace("C", " C")
        .replace("D", " D")
        .strip()
    )


# Adapted from https://github.com/openai/simple-evals
SIMPLE_EVALS_SYSTEM_MESSAGE = "You are a helpful assistant."

# Adapted from https://github.com/openai/simple-evals. Removed the "Think step by step before answering." to make it faster and less leading.
MULTIPLE_CHOICE_PROMPT_TEMPLATE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD.

{prompt}

A) {option_a}
B) {option_b}
C) {option_c}
D) {option_d}
""".strip()

# Adapted from https://github.com/openai/simple-evals
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = (
    "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[Ａ]|[Ｂ]|[Ｃ]|[Ｄ])"
)

# All the different ways "Answer" is written in different languages.
# Adapted from https://github.com/openai/simple-evals
MULTILINGUAL_ANSWER_REGEXES = [
    r"Answer\s*:",
    r"Answer\s*:​​​​​​",  # Korean invisible character
    r"উত্তর\s*:",
    r"उत्तर\s*:",
    r"উত্তরঃ",
    r"উত্তর\s*:",
    r"Antwort\s*:",
    r"답변\s*:",
    r"정답\s*:",
    r"답\s*:",
    r"答案\s*:",
    r"答案\s*:",
    r"答\s*:",
    r"答\s*:",
    r"答复\s*:",
    r"答曰\s*:",
    r"الإجابة:",
    r"الجواب:",
    r"إجابة:",
    r"الإجابة النهائية:",
    r"الإجابة الصحيحة:",
    r"الإجابة الصحيحة هي:",
    r"الإجابة هي:",
    r"الجواب النهائي:",
    r"Respuesta\s*:",
    r"Risposta\s*:",
    r"答え\s*:",
    r"答え\s*:",
    r"回答\s*:",
    r"回答\s*:",
    r"解答\s*:",
    r"Jawaban\s*:",
    r"Réponse\s*:",
    r"Resposta\s*:",
    r"Jibu\s*:",
    r"Idahun\s*:",
    r"Ìdáhùn\s*:",
    r"Idáhùn\s*:",
    r"Àmọ̀nà\s*:",
    r"Àdáhùn\s*:",
    r"Ànúgọ\s*:",
    r"Àṣàyàn\s*:",
]

# Adapted from https://github.com/openai/simple-evals
ANSWER_PATTERN_MULTIPLE_CHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"


def parse_json_from_response(text: str) -> dict:
    """
    Extract and parse JSON from a model response that may contain markdown formatting.

    This function handles common patterns where models wrap JSON in markdown code blocks
    or include extra text around the JSON object.

    Parameters:
        text (str): The model response potentially containing JSON

    Returns:
        dict: Parsed JSON as a dictionary, or empty dict if parsing fails
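
    Examples:
        >>> parse_json_from_response('```json\\n{"answer": "B"}\\n```')
        {'answer': 'B'}
        >>> parse_json_from_response("no JSON here")
        {}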
    """
    import json
    import re

    # First try to extract from markdown code blocks
    json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    if json_match:
        text = json_match.group(1)

    # Try direct parsing
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Try to find JSON object in the text
        json_pattern = r"\{[^{}]*\}"
        matches = re.findall(json_pattern, text)
        for match in matches:
            try:
                return json.loads(match)
            except json.JSONDecodeError:
                continue
        return {}


def format_chat_messages(messages: list) -> str:
    """
    Format a list of chat messages into a readable conversation string.

    Handles both dictionary-style messages and ChatMessage objects from Inspect AI.

    Parameters:
        messages (list): List of messages (dicts or ChatMessage objects)

    Returns:
        str: Formatted conversation with role labels
    """
    formatted = []
    for msg in messages:
        # Handle both dict messages and ChatMessage objects
        if isinstance(msg, dict):
            role = msg.get("role", "")
            content = msg.get("content", "")
        else:
            # ChatMessage object
            role = getattr(msg, "role", "")
            content = getattr(msg, "text", getattr(msg, "content", ""))

        if role and content:
            formatted.append(f"{role}: {content}")

    return "\n\n".join(formatted)


def parse_numeric_answer(response: str, answer_prefix: str = "Answer") -> str:
    """
    Extract a numerical answer from model response after a given prefix.

    Useful for math problems where the answer follows a pattern like "Answer: 42"
    or in other languages like "答え: 42". Extracts the last number found after
    the prefix, handling commas and decimal points.

    Parameters:
        response (str): Model's complete response
        answer_prefix (str): Prefix that precedes the answer (default: "Answer")

    Returns:
        str: Extracted numerical answer, or empty string if not found

    Examples:
        >>> parse_numeric_answer("The calculation gives us Answer: 42")
        '42'
        >>> parse_numeric_answer("答え: 3.14", "答え")
        '3.14'
        >>> parse_numeric_answer("Answer: 1,234.5")
        '1234.5'
    """
    import re

    if answer_prefix not in response:
        return ""

    # Get text after the answer prefix
    answer_text = response.split(answer_prefix)[-1].strip()

    # Remove colon if present
    if answer_text.startswith(":"):
        answer_text = answer_text[1:].strip()

    # Find all numbers (including decimals) in the string
    # Remove commas first, then extract numbers
    numbers = re.findall(r"\d+\.?\d*", answer_text.replace(",", ""))

    # Return the last number (removing trailing decimal point if present)
    return numbers[-1].rstrip(".") if numbers else ""


def normalize_number(value: str) -> str:
    """
    Normalize a numerical string for comparison.

    Removes commas, trailing zeros after decimal points, and trailing decimal
    points. Useful for comparing numerical answers where formatting may vary.

    Parameters:
        value (str): String representation of a number

    Returns:
        str: Normalized number string

    Examples:
        >>> normalize_number("1,234")
        '1234'
        >>> normalize_number("3.1400")
        '3.14'
        >>> normalize_number("5.0")
        '5'
        >>> normalize_number("42.")
        '42'
    """
    # Remove commas
    value = value.replace(",", "")

    # If it has a decimal point, remove trailing zeros and the decimal point if needed
    if "." in value:
        value = value.rstrip("0").rstrip(".")

    return value


def extract_confidence_score(response: str, default: int = 100) -> int:
    """
    Extract a confidence score from model response.

    Looks for patterns like "Confidence: 85%", "confidence: 0.85", etc.
    Handles both percentage (0-100) and decimal (0-1) formats.

    Parameters:
        response (str): Model response potentially containing confidence score
        default (int): Default confidence to return if none found (default: 100)

    Returns:
        int: Confidence score between 0 and 100

    Examples:
        >>> extract_confidence_score("Answer: A\\nConfidence: 85%")
        85
        >>> extract_confidence_score("I am 0.95 confident in this answer")
        95
        >>> extract_confidence_score("No confidence mentioned")
        100
    """
    import re

    patterns = [
        r"[Cc]onfidence:\s*(\d+(?:\.\d+)?)\s*%",  # Confidence: 85%
        r"[Cc]onfidence:\s*(0?\.\d+)",  # Confidence: 0.85 (tried before the bare-integer pattern so "0.85" is not truncated to "0")
        r"[Cc]onfidence:\s*(\d+)",  # Confidence: 85
        r"(\d+(?:\.\d+)?)\s*%\s*[Cc]onfident",  # 85% confident
        r"(0?\.\d+)\s*[Cc]onfident",  # 0.85 confident
    ]

    for pattern in patterns:
        match = re.search(pattern, response)
        if match:
            value = float(match.group(1))
            # Convert to percentage if it's a decimal
            if value <= 1:
                return int(value * 100)
            # Clamp to valid range
            return min(100, max(0, int(value)))

    return default


def str_to_chat_messages(messages_str: str) -> list[ChatMessage]:
    """
    Convert a string to a list of chat messages.

    Parameters:
        messages_str (str): The string to convert

    Returns:
        list[ChatMessage]: The list of chat messages
    """
    message_mapping = {
        "system": ChatMessageSystem,
        "user": ChatMessageUser,
        "assistant": ChatMessageAssistant,
    }
    messages = json.loads(messages_str)
    return [
        message_mapping[message["role"]](content=message["content"])
        for message in messages
    ]
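

# Illustrative sketch (not part of the original module): the expected input is a
# JSON array of {"role", "content"} objects; roles other than system/user/assistant
# raise a KeyError.
def _example_str_to_chat_messages() -> None:
    raw = '[{"role": "system", "content": "Be terse."}, {"role": "user", "content": "2+2?"}]'
    messages = str_to_chat_messages(raw)
    assert isinstance(messages[0], ChatMessageSystem)
    assert isinstance(messages[1], ChatMessageUser)
    assert messages[1].text == "2+2?"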


def get_token_count(text: str, model: str = "gpt-4o") -> int:
    """
    Get the token count of a text using the tiktoken encoding for the given model.

    Returns:
        int: Number of tokens in the encoded text
    """
    return len(tiktoken.encoding_for_model(model).encode(text))


def get_chatml_tok_cnt(chat_messages_str: str) -> int:
    """
    Get the token count of a JSON-encoded message list in ChatML format.

    Mirrors the common ChatML accounting: 3 tokens of overhead per message,
    1 extra token for a "name" field, and 3 tokens to prime the reply.
    """
    messages = json.loads(chat_messages_str)
    total = 3  # tokens that prime the assistant reply
    for message in messages:
        total += 3  # per-message overhead
        for key, value in message.items():
            total += get_token_count(value)
            if key == "name":
                total += 1  # "name" carries one extra token
    return total
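

# Illustrative sketch (not part of the original module): the ChatML accounting above
# charges 3 tokens to prime the reply, 3 per message, 1 extra for a "name" field,
# plus the encoded value of every field (including the role string itself).
def _example_get_chatml_tok_cnt() -> None:
    raw = '[{"role": "user", "content": "hello"}]'
    expected = 3 + 3 + get_token_count("user") + get_token_count("hello")
    assert get_chatml_tok_cnt(raw) == expected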

tests/__init__.py (35 B)

"""Tests for the bench package."""

tests/_cli/__init__.py (32 B)

"""Tests for the CLI module."""

tests/_cli/test_eval_command.py (1.0 KiB)

"""Simple unit tests for eval command."""

from typer.testing import CliRunner
from openbench._cli import app

runner = CliRunner()


def test_eval_requires_benchmark():
    """Test eval command requires a benchmark argument."""
    result = runner.invoke(app, ["eval"])
    assert result.exit_code != 0


def test_invalid_limit():
    """Test invalid limit parameter."""
    result = runner.invoke(app, ["eval", "mmlu", "--limit", "invalid"])
    assert result.exit_code != 0


def test_invalid_display():
    """Test invalid display parameter."""
    result = runner.invoke(app, ["eval", "mmlu", "--display", "invalid"])
    assert result.exit_code != 0


def test_invalid_sandbox():
    """Test invalid sandbox parameter."""
    result = runner.invoke(app, ["eval", "mmlu", "--sandbox", "invalid"])
    assert result.exit_code != 0


def test_invalid_reasoning_effort():
    """Test invalid reasoning effort parameter."""
    result = runner.invoke(app, ["eval", "mmlu", "--reasoning-effort", "invalid"])
    assert result.exit_code != 0

tests/_cli/test_eval_command_functions.py (1.4 KiB)

"""Simple unit tests for eval command helper functions."""

import pytest
from openbench._cli.eval_command import (
    parse_limit,
    validate_model_name,
    validate_model_role,
)


def test_parse_limit_none():
    """Test parsing None limit."""
    assert parse_limit(None) is None


def test_parse_limit_single():
    """Test parsing single integer limit."""
    assert parse_limit("10") == 10


def test_parse_limit_range():
    """Test parsing range limit."""
    assert parse_limit("5,15") == (5, 15)


def test_parse_limit_invalid():
    """Test invalid limit raises error."""
    with pytest.raises(Exception):
        parse_limit("invalid")


def test_validate_model_name_valid():
    """Test valid model name."""
    # Should not raise
    validate_model_name("provider/model-name")


def test_validate_model_name_invalid():
    """Test invalid model name."""
    with pytest.raises(Exception):
        validate_model_name("invalid-model")


def test_validate_model_role_empty():
    """Test empty model role."""
    assert validate_model_role(None) == {}
    assert validate_model_role("") == {}


def test_validate_model_role_valid():
    """Test valid model role."""
    result = validate_model_role("grader=provider/model")
    assert result == {"grader": "provider/model"}


def test_validate_model_role_invalid():
    """Test invalid model role."""
    with pytest.raises(Exception):
        validate_model_role("invalid-format")

tests/conftest.py (604 B)

"""Pytest configuration file for bench tests."""

import pytest
import os


@pytest.fixture(autouse=True)
def clean_environment():
    """Fixture to clean environment variables between tests."""
    # Store original environment
    initial_env = os.environ.copy()

    # Run the test
    yield

    # Clean up environment variables specifically for our tests
    for key in list(os.environ.keys()):
        if key.startswith("BENCH_") and key not in initial_env:
            os.environ.pop(key)
        elif key.startswith("BENCH_") and key in initial_env:
            os.environ[key] = initial_env[key]
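

# Illustrative companion test (hypothetical, not part of the original file): because
# clean_environment is autouse, a test may set BENCH_* variables freely and they are
# removed or restored once the test finishes.
def _example_bench_env_isolation():
    os.environ["BENCH_EXAMPLE_FLAG"] = "1"  # cleaned up by the fixture afterwards
    assert os.environ["BENCH_EXAMPLE_FLAG"] == "1"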

tests/integration/__init__.py (0 B)


tests/integration/test_cli.py (2.2 KiB)

"""Simple integration tests for the bench CLI tool."""

import os
import pytest
from typer.testing import CliRunner
from openbench._cli import app

runner = CliRunner()

# Mark all tests in this module as integration tests
# Skip all tests if GROQ_API_KEY is not set
pytestmark = [
    pytest.mark.integration,
    pytest.mark.skipif(
        not os.getenv("GROQ_API_KEY"),
        reason="GROQ_API_KEY not set - skipping integration tests",
    ),
]


def test_help():
    """Test help command works."""
    result = runner.invoke(app, ["--help"])
    assert result.exit_code == 0
    assert "eval" in result.stdout


def test_basic_mmlu():
    """Test basic MMLU evaluation."""
    result = runner.invoke(
        app, ["eval", "mmlu", "--limit", "1", "--model", "groq/llama-3.1-8b-instant"]
    )
    assert result.exit_code == 0


def test_basic_gpqa():
    """Test basic GPQA evaluation."""
    result = runner.invoke(
        app,
        [
            "eval",
            "gpqa_diamond",
            "--limit",
            "1",
            "--model",
            "groq/llama-3.1-8b-instant",
        ],
    )
    assert result.exit_code == 0


def test_basic_humaneval():
    """Test basic HumanEval evaluation."""
    result = runner.invoke(
        app,
        [
            "eval",
            "humaneval",
            "--limit",
            "1",
            "--model",
            "groq/llama-3.1-8b-instant",
            "--epochs",
            "5",
        ],
    )
    assert result.exit_code == 0


def test_invalid_benchmark():
    """Test invalid benchmark name."""
    result = runner.invoke(
        app, ["eval", "invalid_benchmark", "--model", "groq/llama-3.1-8b-instant"]
    )
    assert result.exit_code != 0


def test_invalid_model_format():
    """Test invalid model format."""
    result = runner.invoke(app, ["eval", "mmlu", "--model", "invalid-model"])
    assert result.exit_code != 0


def test_model_and_model_role_conflict():
    """Test conflicting model specifications."""
    result = runner.invoke(
        app,
        [
            "eval",
            "mmlu",
            "--model",
            "groq/llama-3.1-8b-instant",
            "--model-role",
            "candidate=groq/llama-3.1-8b-instant",
        ],
    )
    assert result.exit_code != 0

tests/monkeypatch/test_file_recorder_logfile_patch.py (1.3 KiB)

"""Test the file recorder logfile monkey patch."""

import pytest
from unittest.mock import patch, MagicMock
from openbench.monkeypatch.file_recorder_logfile_patch import (
    patch_file_recorder_logfile,
)


@pytest.fixture
def mock_file_recorder_module():
    """Mock the inspect_ai.log._recorders.file module."""
    with patch("inspect_ai.log._recorders.file") as mock_module:
        # Set up the FileRecorder class within the module
        mock_module.FileRecorder = MagicMock()
        yield mock_module


def test_patch_file_recorder_logfile(mock_file_recorder_module):
    """Test that patch_file_recorder_logfile correctly patches the FileRecorder."""
    test_logfile = "test_logfile"

    # Apply the patch
    patch_file_recorder_logfile(test_logfile)

    # Verify that _log_file_key was assigned a new function
    assert mock_file_recorder_module.FileRecorder._log_file_key is not None

    # Create a mock instance and test the patched method
    mock_instance = MagicMock()
    mock_eval_spec = MagicMock()

    # Get the patched method and call it
    patched_method = mock_file_recorder_module.FileRecorder._log_file_key
    result = patched_method(mock_instance, mock_eval_spec)

    # Verify it returns the expected logfile name
    assert result == test_logfile

tests/test_json_schema_scorer.py (17.3 KiB)

"""Unit tests for JSON schema scorer."""

from unittest.mock import Mock
from inspect_ai.scorer import Target, Score, SampleScore, CORRECT, INCORRECT

from openbench.scorers.json_schema import (
    json_schema_scorer,
    json_validity,
    schema_compliance,
    api_success_rate,
)


def create_mock_state(
    completion: str, metadata: dict | None = None, error: str | None = None
) -> Mock:
    """Create a mock TaskState for testing."""
    mock_state = Mock()
    mock_state.output.completion = completion
    mock_state.output.error = error  # Add error attribute for API error testing
    mock_state.metadata = metadata or {}
    return mock_state


# Target typically contains expected answer for comparison, but json_schema_scorer
# only validates JSON structure against schema, so target is unused
TEST_TARGET = "test_target"


class TestJSONSchemaScorer:
    """Test the JSON schema scorer function."""

    async def test_valid_json_and_schema(self):
        """Test with valid JSON that conforms to schema."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer", "minimum": 0},
            },
            "required": ["name", "age"],
        }

        state = create_mock_state(
            completion='{"name": "John", "age": 25}', metadata={"schema": schema}
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == CORRECT
        assert result.answer == '{"name": "John", "age": 25}'
        assert result.metadata["json_valid"]
        assert result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]
        assert result.metadata["error"] is None

    async def test_valid_json_invalid_schema(self):
        """Test with valid JSON that doesn't conform to schema."""
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer", "minimum": 0},
            },
            "required": ["name", "age"],
        }

        state = create_mock_state(
            completion='{"name": "John"}',  # Missing required "age"
            metadata={"schema": schema},
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == INCORRECT
        assert result.answer == '{"name": "John"}'
        assert result.metadata["json_valid"]
        assert not result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]
        assert "schema_validation_error" in result.metadata["error"]

    async def test_invalid_json(self):
        """Test with invalid JSON."""
        schema = {"type": "object"}

        state = create_mock_state(
            completion='{"name": "John", invalid}', metadata={"schema": schema}
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == INCORRECT
        assert result.answer == '{"name": "John", invalid}'
        assert not result.metadata["json_valid"]
        assert not result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]
        assert "json_decode_error" in result.metadata["error"]

    async def test_no_schema_in_metadata(self):
        """Test when no schema is provided in metadata."""
        state = create_mock_state(
            completion='{"name": "John"}',
            metadata={},  # No schema
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == INCORRECT
        assert result.answer == '{"name": "John"}'
        assert not result.metadata["json_valid"]
        assert not result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]
        assert result.metadata["error"] == "no_schema"

    async def test_none_metadata(self):
        """Test when metadata is None."""
        state = create_mock_state(completion='{"name": "John"}', metadata=None)
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == INCORRECT
        assert result.answer == '{"name": "John"}'
        assert not result.metadata["json_valid"]
        assert not result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]
        assert result.metadata["error"] == "no_schema"

    async def test_empty_completion(self):
        """Test with empty completion."""
        schema = {"type": "object"}

        state = create_mock_state(completion="", metadata={"schema": schema})
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == INCORRECT
        assert result.answer == ""
        assert not result.metadata["json_valid"]
        assert not result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]
        assert "json_decode_error" in result.metadata["error"]

    async def test_whitespace_handling(self):
        """Test that whitespace is properly stripped for JSON parsing."""
        schema = {"type": "object", "properties": {"test": {"type": "boolean"}}}

        state = create_mock_state(
            completion='  {"test": true}  \n',  # Leading/trailing whitespace
            metadata={"schema": schema},
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == CORRECT
        assert result.answer == '  {"test": true}  \n'  # Raw output preserved
        assert result.metadata["json_valid"]
        assert result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]

    async def test_complex_schema(self):
        """Test with a more complex JSON schema."""
        schema = {
            "type": "object",
            "properties": {
                "users": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "email": {"type": "string", "format": "email"},
                        },
                        "required": ["name", "email"],
                    },
                }
            },
            "required": ["users"],
        }

        state = create_mock_state(
            completion='{"users": [{"name": "John", "email": "john@example.com"}]}',
            metadata={"schema": schema},
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == CORRECT
        assert result.metadata["json_valid"]
        assert result.metadata["schema_compliant"]
        assert not result.metadata["api_error"]

    async def test_api_error_handling(self):
        """Test scorer handles API errors correctly."""
        schema = {"type": "object"}

        # Create state that simulates an API error
        state = create_mock_state(
            completion="",
            metadata={"schema": schema},
            error="API timeout error",  # Simulate API error
        )
        target = Target(TEST_TARGET)

        scorer_fn = json_schema_scorer()
        result = await scorer_fn(state, target)

        assert result.value == INCORRECT
        assert result.answer == ""
        assert not result.metadata["json_valid"]
        assert not result.metadata["schema_compliant"]
        assert result.metadata["api_error"]
        assert "api_error: API timeout error" in result.metadata["error"]


class TestJSONValidityMetric:
    """Test the JSON validity metric."""

    def test_all_valid_json(self):
        """Test metric with all valid JSON scores from successful API calls."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=CORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": True,
                        "api_error": False,
                    },
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": False,
                        "api_error": False,
                    },
                ),
            ),
        ]

        metric_fn = json_validity()
        result = metric_fn(scores)

        assert result == 1.0  # 2/2 successful API calls produced valid JSON

    def test_mixed_json_validity(self):
        """Test metric with mixed JSON validity from successful API calls."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=CORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": True,
                        "api_error": False,
                    },
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": False,
                        "schema_compliant": False,
                        "api_error": False,
                    },
                ),
            ),
        ]

        metric_fn = json_validity()
        result = metric_fn(scores)

        assert result == 0.5  # 1/2 successful API calls produced valid JSON

    def test_no_metadata_scores(self):
        """Test metric with scores that have no metadata."""
        scores = [
            SampleScore(sample_id="1", score=Score(value=CORRECT)),  # No metadata
            SampleScore(
                sample_id="2", score=Score(value=INCORRECT, metadata=None)
            ),  # None metadata
        ]

        metric_fn = json_validity()
        result = metric_fn(scores)

        assert result == 0.0  # 0/0 successful API calls (no valid denominators)

    def test_with_api_errors(self):
        """Test metric excludes API errors from denominator (empirical coverage formula)."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": False,
                        "schema_compliant": False,
                        "api_error": True,
                    },
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=CORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": True,
                        "api_error": False,
                    },
                ),
            ),
            SampleScore(
                sample_id="3",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": False,
                        "schema_compliant": False,
                        "api_error": False,
                    },
                ),
            ),
        ]

        metric_fn = json_validity()
        result = metric_fn(scores)

        assert (
            result == 0.5
        )  # 1/2 successful API calls produced valid JSON (API error excluded)

    def test_empty_scores(self):
        """Test metric with empty scores list."""
        metric_fn = json_validity()
        result = metric_fn([])

        assert result == 0.0


class TestSchemaComplianceMetric:
    """Test the schema compliance metric."""

    def test_all_compliant(self):
        """Test metric with all schema compliant JSON."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=CORRECT,
                    metadata={"json_valid": True, "schema_compliant": True},
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=CORRECT,
                    metadata={"json_valid": True, "schema_compliant": True},
                ),
            ),
        ]

        metric_fn = schema_compliance()
        result = metric_fn(scores)

        assert result == 1.0  # 2/2 compliant among valid JSON

    def test_mixed_compliance(self):
        """Test metric with mixed schema compliance."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=CORRECT,
                    metadata={"json_valid": True, "schema_compliant": True},
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=INCORRECT,
                    metadata={"json_valid": True, "schema_compliant": False},
                ),
            ),
        ]

        metric_fn = schema_compliance()
        result = metric_fn(scores)

        assert result == 0.5  # 1/2 compliant among valid JSON

    def test_no_valid_json(self):
        """Test metric when no JSON is valid."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=INCORRECT,
                    metadata={"json_valid": False, "schema_compliant": False},
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=INCORRECT,
                    metadata={"json_valid": False, "schema_compliant": False},
                ),
            ),
        ]

        metric_fn = schema_compliance()
        result = metric_fn(scores)

        assert result == 0.0  # No valid JSON to check compliance


class TestAPISuccessRateMetric:
    """Test the API success rate metric."""

    def test_all_api_success(self):
        """Test metric with all successful API calls (no API errors)."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=CORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": True,
                        "api_error": False,
                    },
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": False,
                        "api_error": False,
                    },
                ),
            ),
        ]

        metric_fn = api_success_rate()
        result = metric_fn(scores)

        assert result == 1.0  # 2/2 successful API calls

    def test_mixed_api_success(self):
        """Test metric with mixed API success rates."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": False,
                        "schema_compliant": False,
                        "api_error": True,
                    },
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=CORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": True,
                        "api_error": False,
                    },
                ),
            ),
            SampleScore(
                sample_id="3",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": True,
                        "schema_compliant": False,
                        "api_error": False,
                    },
                ),
            ),
        ]

        metric_fn = api_success_rate()
        result = metric_fn(scores)

        assert result == 2.0 / 3.0  # 2/3 successful API calls

    def test_all_api_errors(self):
        """Test metric when all API calls fail."""
        scores = [
            SampleScore(
                sample_id="1",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": False,
                        "schema_compliant": False,
                        "api_error": True,
                    },
                ),
            ),
            SampleScore(
                sample_id="2",
                score=Score(
                    value=INCORRECT,
                    metadata={
                        "json_valid": False,
                        "schema_compliant": False,
                        "api_error": True,
                    },
                ),
            ),
        ]

        metric_fn = api_success_rate()
        result = metric_fn(scores)

        assert result == 0.0  # 0/2 successful API calls

tests/test_registry.py (1.0 KiB)

"""Test the registry module functionality."""

import pytest
from openbench.config import load_task, TASK_REGISTRY


def test_task_registry_contents():
    """Test that the task registry contains expected benchmarks."""
    assert "mmlu" in TASK_REGISTRY
    assert TASK_REGISTRY["mmlu"] == "openbench.evals.mmlu.mmlu"


def test_load_task_valid():
    """Test loading a valid task from the registry."""
    task = load_task("mmlu")
    assert callable(task)


def test_load_task_invalid():
    """Test loading an invalid task from the registry."""
    with pytest.raises(ValueError) as exc_info:
        load_task("nonexistent_benchmark")

    # Check that error message mentions available benchmarks
    assert "Unknown benchmark" in str(exc_info.value)
    assert "mmlu" in str(exc_info.value)


def test_load_task_caching():
    """Test that the load_task function uses caching."""
    # Call twice and verify it's the same object (due to lru_cache)
    task1 = load_task("mmlu")
    task2 = load_task("mmlu")
    assert task1 is task2  # Same object due to caching

tests/test_robust_scorers.py (8.7 KiB)

"""Tests for robust answer extraction scorers."""

import asyncio
from openbench.scorers import (
    extract_mcq_answer,
    extract_boxed_answer,
    robust_mcq_scorer,
    aime_scorer,
)
from openbench.scorers.robust_boxed import normalize_numeric_answer
from inspect_ai.scorer import Target, CORRECT, INCORRECT
from dataclasses import dataclass


@dataclass
class MockOutput:
    """Mock output for testing."""

    completion: str


@dataclass
class MockTaskState:
    """Mock task state for testing."""

    output: MockOutput


class TestMCQExtraction:
    """Test multiple choice answer extraction."""

    def test_markdown_wrapped_answer(self):
        """Test extraction from markdown-wrapped answers."""
        assert extract_mcq_answer("**Answer:** A") == "A"
        assert extract_mcq_answer("*Answer:* B") == "B"
        assert extract_mcq_answer("__Answer:__ C") == "C"
        assert extract_mcq_answer("_Answer:_ D") == "D"

    def test_parenthesis_answer(self):
        """Test extraction from parenthesis format."""
        assert extract_mcq_answer("Answer: (A)") == "A"
        assert extract_mcq_answer("The answer is (B).") == "B"
        assert extract_mcq_answer("Choice: (C)") == "C"
        assert extract_mcq_answer("[D] is correct") == "D"

    def test_plain_answer(self):
        """Test extraction from plain format."""
        assert extract_mcq_answer("Answer: A") == "A"
        assert extract_mcq_answer("Answer – B") == "B"
        assert extract_mcq_answer("Option C") == "C"
        assert extract_mcq_answer("Choice: D") == "D"

    def test_latex_boxed(self):
        """Test extraction from LaTeX boxed format."""
        assert extract_mcq_answer(r"\boxed{A}") == "A"
        assert extract_mcq_answer(r"\boxed{\text{B}}") == "B"
        assert extract_mcq_answer(r"\boxed{\textbf{C}}") == "C"
        assert extract_mcq_answer(r"The answer is \boxed{D}") == "D"

    def test_markdown_standalone(self):
        """Test extraction from standalone markdown."""
        assert extract_mcq_answer("*A*") == "A"
        assert extract_mcq_answer("**B**") == "B"
        assert extract_mcq_answer("_C_") == "C"
        assert extract_mcq_answer("__D__") == "D"

    def test_complex_cases(self):
        """Test extraction from complex/mixed formats."""
        # Markdown with description
        assert extract_mcq_answer("**D) This is the correct answer**") == "D"

        # Multiple patterns (should get first/best match)
        assert extract_mcq_answer("Let me think... Answer: B\n\n(C) is wrong") == "B"

        # Case insensitive
        assert extract_mcq_answer("answer: a") == "A"
        assert extract_mcq_answer("ANSWER: B") == "B"

    def test_fallback_to_first_char(self):
        """Test fallback to first character."""
        assert extract_mcq_answer("A") == "A"
        assert extract_mcq_answer("B is the answer") == "B"
        assert extract_mcq_answer("**C") == "C"

    def test_no_answer_found(self):
        """Test when no answer is found."""
        assert extract_mcq_answer("No valid answer here") is None
        assert extract_mcq_answer("The options are 1, 2, 3, 4") is None
        assert extract_mcq_answer("") is None


class TestBoxedExtraction:
    """Test boxed answer extraction."""

    def test_boxed_extraction(self):
        """Test extraction from \boxed{} format."""
        assert extract_boxed_answer(r"\boxed{42}") == "42"
        assert extract_boxed_answer(r"The answer is \boxed{-3}") == "-3"
        assert extract_boxed_answer(r"\boxed{3.14159}") == "3.14159"

    def test_framebox_extraction(self):
        """Test extraction from \framebox{} format."""
        assert extract_boxed_answer(r"\framebox{100}") == "100"
        assert extract_boxed_answer(r"Answer: \framebox{0}") == "0"

    def test_fbox_extraction(self):
        """Test extraction from \fbox{} format (OpenBench compatibility)."""
        assert extract_boxed_answer(r"\fbox{42}") == "42"
        assert extract_boxed_answer(r"The answer is \fbox{-10}") == "-10"

    def test_multiple_boxed(self):
        """Test that last boxed answer is used."""
        text = r"First \boxed{1} then \boxed{2} finally \boxed{3}"
        assert extract_boxed_answer(text) == "3"

    def test_comma_separated(self):
        """Test handling of comma-separated values in box."""
        # Just test that it extracts something from comma-separated values
        assert extract_boxed_answer(r"\boxed{x = 2, y = 3}") == "y = 3"

    def test_fallback_to_last_number(self):
        """Test fallback to last number when no box found."""
        assert extract_boxed_answer("The answer is 42", True) == "42"
        assert extract_boxed_answer("First 10 then 20 finally 30", True) == "30"
        assert extract_boxed_answer("Negative: -5", True) == "-5"
        assert extract_boxed_answer("Decimal: 3.14", True) == "3.14"

    def test_no_fallback(self):
        """Test no fallback when disabled."""
        assert extract_boxed_answer("The answer is 42", False) is None
        assert extract_boxed_answer("No box here", False) is None

    def test_no_answer(self):
        """Test when no answer is found."""
        assert extract_boxed_answer("No numbers here", True) is None
        assert extract_boxed_answer("", True) is None


class TestNumericNormalization:
    """Test numeric answer normalization."""

    def test_comma_removal(self):
        """Test removal of commas."""
        assert normalize_numeric_answer("1,234") == "1234"
        assert normalize_numeric_answer("1,000,000") == "1000000"

    def test_integer_extraction(self):
        """Test extraction of leading integers."""
        assert normalize_numeric_answer("42 points") == "42"
        assert normalize_numeric_answer("-3 units") == "-3"
        assert normalize_numeric_answer("0") == "0"

    def test_decimal_normalization(self):
        """Test decimal number normalization."""
        # Our implementation extracts leading integers
        assert normalize_numeric_answer("3.14000") == "3"
        assert normalize_numeric_answer("5.0") == "5"
        assert normalize_numeric_answer("0.500") == "0"
        assert normalize_numeric_answer("42.") == "42"

    def test_invalid_input(self):
        """Test invalid input handling."""
        assert normalize_numeric_answer("abc") is None
        assert normalize_numeric_answer("") is None
        assert normalize_numeric_answer(None) is None


class TestRobustMCQScorer:
    """Test the robust MCQ scorer."""

    def test_correct_answer(self):
        """Test scoring correct answers."""
        scorer = robust_mcq_scorer()
        state = MockTaskState(MockOutput("Answer: B"))
        target = Target("B")

        score = asyncio.run(scorer(state, target))
        assert score.value == CORRECT
        assert score.answer == "B"

    def test_incorrect_answer(self):
        """Test scoring incorrect answers."""
        scorer = robust_mcq_scorer()
        state = MockTaskState(MockOutput("Answer: A"))
        target = Target("B")

        score = asyncio.run(scorer(state, target))
        assert score.value == INCORRECT
        assert score.answer == "A"

    def test_no_answer_found(self):
        """Test scoring when no answer found."""
        scorer = robust_mcq_scorer()
        state = MockTaskState(MockOutput("I don't know"))
        target = Target("A")

        score = asyncio.run(scorer(state, target))
        assert score.value == INCORRECT
        assert score.answer is None


class TestAIMEScorer:
    """Test the AIME scorer."""

    def test_boxed_integer(self):
        """Test scoring boxed integer answers."""
        scorer = aime_scorer()
        state = MockTaskState(MockOutput(r"\boxed{42}"))
        target = Target("42")

        score = asyncio.run(scorer(state, target))
        assert score.value == CORRECT
        assert score.answer == "42"

    def test_fallback_to_last_integer(self):
        """Test fallback to last integer."""
        scorer = aime_scorer()
        state = MockTaskState(MockOutput("The answer is 123"))
        target = Target("123")

        score = asyncio.run(scorer(state, target))
        assert score.value == CORRECT
        assert score.answer == "123"

    def test_out_of_range(self):
        """Test AIME range validation (0-999)."""
        scorer = aime_scorer()
        state = MockTaskState(MockOutput(r"\boxed{1000}"))
        target = Target("1000")

        score = asyncio.run(scorer(state, target))
        assert score.value == INCORRECT
        assert "outside valid AIME range" in score.explanation

    def test_incorrect_answer(self):
        """Test incorrect answer."""
        scorer = aime_scorer()
        state = MockTaskState(MockOutput(r"\boxed{41}"))
        target = Target("42")

        score = asyncio.run(scorer(state, target))
        assert score.value == INCORRECT
        assert score.answer == "41"
