Source code for vasp.runners.slurm

"""SLURM runner for VASP execution on HPC clusters.

Submits VASP jobs to SLURM scheduler and monitors their status.
Designed for non-blocking async operation.
"""

from __future__ import annotations

import os
import subprocess
from typing import TYPE_CHECKING

from ..exceptions import VaspQueued, VaspRunning, VaspSubmitted
from .base import JobState, JobStatus, Runner

if TYPE_CHECKING:
    pass


class SlurmRunner(Runner):
    """Run VASP via SLURM scheduler.

    Submits jobs to SLURM and monitors status. All operations are
    non-blocking - run() submits and returns immediately.

    Args:
        partition: SLURM partition name.
        nodes: Number of nodes.
        ntasks_per_node: MPI tasks per node.
        time: Wall time limit (HH:MM:SS format).
        memory: Memory per node (e.g., '64G').
        account: SLURM account for billing.
        qos: Quality of service.
        vasp_command: Command to run VASP. Defaults to 'srun $VASP_COMMAND'
            if VASP_COMMAND is set, or 'srun vasp_std' otherwise.
        modules: List of modules to load before running.
        extra_sbatch: Additional #SBATCH directives as list of strings.
        constraint: Node constraint (e.g., 'gpu' or 'skylake').

    Example:
        >>> runner = SlurmRunner(
        ...     partition='compute',
        ...     nodes=2,
        ...     ntasks_per_node=48,
        ...     time='24:00:00',
        ...     modules=['vasp/6.3.0'],
        ... )
        >>>
        >>> calc = Vasp('my_calc', runner=runner, atoms=atoms)
        >>> try:
        ...     energy = calc.potential_energy
        ... except VaspSubmitted as e:
        ...     print(f"Submitted job {e.jobid}")
    """

    # File (inside the calculation directory) that records the submitted job ID.
    _JOBID_FILE = '.slurm_jobid'
    # File name of the generated batch script inside the calculation directory.
    _SCRIPT_NAME = 'submit.slurm'
    # Seconds to wait for squeue/scancel before giving up on the scheduler.
    _SLURM_TIMEOUT = 30

    def __init__(
        self,
        partition: str = 'normal',
        nodes: int = 1,
        ntasks_per_node: int = 24,
        time: str = '24:00:00',
        memory: str | None = None,
        account: str | None = None,
        qos: str | None = None,
        vasp_command: str | None = None,
        modules: list[str] | None = None,
        extra_sbatch: list[str] | None = None,
        constraint: str | None = None,
    ):
        self.partition = partition
        self.nodes = nodes
        self.ntasks_per_node = ntasks_per_node
        self.time = time
        self.memory = memory
        self.account = account
        self.qos = qos

        if vasp_command is not None:
            self.vasp_command = vasp_command
        else:
            # No explicit command: wrap $VASP_COMMAND (or 'vasp_std') in srun.
            base_cmd = os.environ.get('VASP_COMMAND', 'vasp_std')
            self.vasp_command = f'srun {base_cmd}'

        self.modules = modules or []
        self.extra_sbatch = extra_sbatch or []
        self.constraint = constraint

    def run(self, directory: str) -> JobStatus:
        """Submit VASP job to SLURM.

        Args:
            directory: Calculation directory.

        Returns:
            The current status, only when the calculation is already complete.

        Raises:
            VaspQueued: A tracked job is queued.
            VaspRunning: A tracked job is running.
            VaspSubmitted: A new job was submitted by this call.
            RuntimeError: sbatch failed or its output could not be parsed.
        """
        current = self.status(directory)

        if current.state == JobState.COMPLETE:
            return current

        if current.state == JobState.QUEUED:
            raise VaspQueued(message=f"Job {current.jobid} queued", jobid=current.jobid)

        if current.state == JobState.RUNNING:
            raise VaspRunning(message=f"Job {current.jobid} running", jobid=current.jobid)

        if current.state == JobState.FAILED:
            # Allow resubmission - clean up old job file
            self._cleanup_old_job(directory)

        # Submit new job
        jobid = self._submit(directory)
        raise VaspSubmitted(jobid=jobid)

    def status(self, directory: str) -> JobStatus:
        """Check SLURM job status.

        Asks squeue about the tracked job ID first; when there is no tracked
        job, or squeue reports no mapped state for it, the status is derived
        from the output files in ``directory``.
        """
        jobid = self._read_jobid(directory)

        if jobid:
            state = self._query_slurm(jobid)
            if state:
                return JobStatus(state, jobid=jobid)

        # Job not in SLURM - check output files
        return self._check_output_files(directory)

    def cancel(self, directory: str) -> bool:
        """Cancel SLURM job.

        Returns:
            True when no job is tracked for ``directory`` or scancel exits
            with status 0; False when scancel exits non-zero, is not
            installed, or does not answer within the timeout.
        """
        jobid = self._read_jobid(directory)
        if not jobid:
            return True

        try:
            result = subprocess.run(
                ['scancel', jobid],
                capture_output=True,
                text=True,
                # Same bound as squeue, so a hung scheduler cannot block the caller.
                timeout=self._SLURM_TIMEOUT,
            )
        except (FileNotFoundError, subprocess.TimeoutExpired):
            return False
        return result.returncode == 0

    def get_logs(self, directory: str, tail_lines: int = 100) -> str:
        """Get SLURM job output.

        Returns the last ``tail_lines`` lines of ``slurm-<jobid>.out`` when
        that file exists; otherwise defers to the base class implementation.
        A non-positive ``tail_lines`` yields an empty string for the SLURM file.
        """
        # Try SLURM output file first
        jobid = self._read_jobid(directory)
        if jobid:
            slurm_out = os.path.join(directory, f'slurm-{jobid}.out')
            if os.path.exists(slurm_out):
                if tail_lines <= 0:
                    # lines[-0:] would return the whole file, not zero lines.
                    return ''
                with open(slurm_out) as f:
                    lines = f.readlines()
                return ''.join(lines[-tail_lines:])

        # Fall back to OUTCAR
        return super().get_logs(directory, tail_lines)

    def _submit(self, directory: str) -> str:
        """Create and submit SLURM job script.

        Writes the batch script into ``directory``, runs sbatch there and
        records the returned job ID in the tracking file.

        Raises:
            RuntimeError: sbatch exited non-zero, or no job ID could be found
                in its output.
        """
        script = self._create_script(directory)
        script_path = os.path.join(directory, self._SCRIPT_NAME)

        with open(script_path, 'w') as f:
            f.write(script)

        result = subprocess.run(
            # sbatch runs with cwd=directory, so the script is named relative
            # to that directory; passing the joined path would be resolved
            # against the new cwd and break for relative directories.
            ['sbatch', self._SCRIPT_NAME],
            cwd=directory,
            capture_output=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"sbatch failed: {result.stderr}")

        jobid = self._parse_jobid(result.stdout)
        if jobid is None:
            raise RuntimeError(
                f"Could not parse job ID from sbatch output: {result.stdout!r}"
            )

        self._write_jobid(directory, jobid)
        return jobid

    @staticmethod
    def _parse_jobid(output: str) -> str | None:
        """Extract the numeric job ID from sbatch output.

        Accepts "Submitted batch job 123456", the same line followed by
        " on cluster NAME", and the "123456" / "123456;NAME" forms.
        Returns None when no numeric token is present.
        """
        for token in output.split():
            candidate = token.partition(';')[0]
            if candidate.isdigit():
                return candidate
        return None

    def _query_slurm(self, jobid: str) -> JobState | None:
        """Query SLURM for job state.

        Returns:
            The mapped JobState, or None when squeue is unavailable, times
            out, fails, prints nothing, or reports a state that is not mapped
            here (the caller then falls back to the output files).
        """
        try:
            result = subprocess.run(
                ['squeue', '-j', jobid, '-h', '-o', '%T'],
                capture_output=True,
                text=True,
                timeout=self._SLURM_TIMEOUT,
            )
        except (FileNotFoundError, subprocess.TimeoutExpired):
            return None

        if result.returncode != 0:
            return None

        # squeue may print one line per task (e.g. job arrays); use the first.
        tokens = result.stdout.split()
        if not tokens:
            return None
        state = tokens[0]

        mapping = {
            'PENDING': JobState.QUEUED,
            'CONFIGURING': JobState.QUEUED,
            'REQUEUED': JobState.QUEUED,
            'REQUEUE_HOLD': JobState.QUEUED,
            'SUSPENDED': JobState.QUEUED,
            'RUNNING': JobState.RUNNING,
            'COMPLETING': JobState.RUNNING,
            'RESIZING': JobState.RUNNING,
        }

        return mapping.get(state)

    def _check_output_files(self, directory: str) -> JobStatus:
        """Check calculation status from output files.

        Relies on ``_check_outcar_complete`` and ``_check_outcar_error``,
        which are not defined in this class (presumably inherited from
        Runner - confirm against the base class).
        """
        if self._check_outcar_complete(directory):
            return JobStatus(JobState.COMPLETE)

        error = self._check_outcar_error(directory)
        if error:
            return JobStatus(JobState.FAILED, message=error)

        # An OUTCAR that is neither complete nor flagged with an error is
        # reported as a failed (interrupted) run.
        outcar = os.path.join(directory, 'OUTCAR')
        if os.path.exists(outcar):
            return JobStatus(JobState.FAILED, message="OUTCAR incomplete")

        return JobStatus(JobState.NOT_STARTED)

    def _create_script(self, directory: str) -> str:
        """Generate SLURM batch script.

        The job name is the base name of ``directory``, truncated to 50
        characters. Optional directives are emitted only when set.
        """
        job_name = os.path.basename(os.path.abspath(directory))[:50]

        lines = [
            '#!/bin/bash',
            f'#SBATCH --job-name={job_name}',
            f'#SBATCH --partition={self.partition}',
            f'#SBATCH --nodes={self.nodes}',
            f'#SBATCH --ntasks-per-node={self.ntasks_per_node}',
            f'#SBATCH --time={self.time}',
            '#SBATCH --output=slurm-%j.out',
            '#SBATCH --error=slurm-%j.err',
        ]

        if self.memory:
            lines.append(f'#SBATCH --mem={self.memory}')
        if self.account:
            lines.append(f'#SBATCH --account={self.account}')
        if self.qos:
            lines.append(f'#SBATCH --qos={self.qos}')
        if self.constraint:
            lines.append(f'#SBATCH --constraint={self.constraint}')

        lines.extend(f'#SBATCH {directive}' for directive in self.extra_sbatch)

        lines.append('')
        lines.append('# Load modules')
        lines.extend(f'module load {mod}' for mod in self.modules)

        lines.append('')
        lines.append('# Run VASP')
        lines.append(self.vasp_command)

        return '\n'.join(lines) + '\n'

    def _read_jobid(self, directory: str) -> str | None:
        """Read SLURM job ID from tracking file.

        Returns None when the tracking file is missing or empty.
        """
        path = os.path.join(directory, self._JOBID_FILE)
        try:
            with open(path) as f:
                # An empty tracking file means "no job", same as a missing one.
                return f.read().strip() or None
        except FileNotFoundError:
            return None

    def _write_jobid(self, directory: str, jobid: str) -> None:
        """Save SLURM job ID to tracking file."""
        path = os.path.join(directory, self._JOBID_FILE)
        with open(path, 'w') as f:
            f.write(jobid)

    def _cleanup_old_job(self, directory: str) -> None:
        """Remove old job tracking file before resubmission."""
        path = os.path.join(directory, self._JOBID_FILE)
        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing to clean up.
            pass

    def __repr__(self) -> str:
        return (
            f"SlurmRunner(partition={self.partition!r}, "
            f"nodes={self.nodes}, time={self.time!r})"
        )