174 lines
6.4 KiB
Python
174 lines
6.4 KiB
Python
#!/usr/local/venvs/pg_backup_service/bin/python
|
|
|
|
import argparse
|
|
import logging
|
|
import signal
|
|
import sys
|
|
import textwrap
|
|
import time
|
|
from math import floor, log
|
|
from typing import Any, Union
|
|
|
|
from packaging.version import Version
|
|
from pgcos import PgUtil
|
|
|
|
|
|
class PgBackupService:
    """Start a PostgreSQL low-level API backup and keep it open until signalled.

    The instance connects to the local server through the unix socket directory
    ``/var/run/postgresql`` as user ``postgres``, detects the server version to
    pick the matching backup function names, and registers SIGTERM/SIGINT
    handlers that stop the backup and exit.
    """

    __slots__ = ["args", "debug", "dryrun", "logger", "start_time", "delay", "password", "pg", "version"]  # noqa: RUF023

    def __init__(self):
        """Parse CLI arguments, set up logging, connect and install signal handlers."""
        self.args = self.get_args()
        self.debug = self.args.debug
        # Each flag drives its own attribute. Previously ``dryrun`` was copied from
        # ``args.debug``, so ``--dryrun`` still ran the real backup and ``--debug``
        # silently suppressed every statement.
        self.dryrun = self.args.dryrun
        # configure_logger() reads self.debug / self.dryrun, so it must come after them.
        self.logger = self.configure_logger()
        self.start_time = time.time()
        # Seconds slept between iterations of the loop in start().
        self.delay = 5
        self.pg = PgUtil(
            {
                "host": "/var/run/postgresql",
                "database": "postgres",
                "user": "postgres",
                "application_name": f"{self.__class__.__name__}",
            }
        )
        self.version = self.get_server_major_version()
        self.logger.info(f"{self.__class__.__name__} instance created, running against postgres {self.version}")

        # Both termination signals end the backup cleanly through stop().
        signal.signal(signal.SIGTERM, self.__handle_signal)
        signal.signal(signal.SIGINT, self.__handle_signal)

    @staticmethod
    def format_time(seconds: Union[int, float]) -> str:
        """Return a human readable duration for ``seconds``.

        The value is scaled to seconds, minutes or hours and rounded to two
        decimals, e.g. ``90`` -> ``"1.5 m"``. Durations of 60 hours and more stay
        expressed in hours (the largest available unit) instead of raising
        ``IndexError`` as they did before.
        """
        units = ["s", "m", "h"]
        if seconds < 1:
            # log() is undefined for 0 and negative values; report them as seconds.
            power = 0
        else:
            # Clamp the base-60 exponent to the largest unit we can name.
            power = min(floor(log(seconds, 60)), len(units) - 1)
        return f"{round(seconds / 60 ** power, 2)} {units[power]}"

    @staticmethod
    def get_args() -> argparse.Namespace:
        """Parse and return CLI arguments (``--debug`` and ``--dryrun``)."""
        parser = argparse.ArgumentParser(
            description=textwrap.dedent(
                """
                Utility to start and stop postgres backup using the low level API.
                Intended to be run as a systemd service as postgres user.
                Cleanup is done on SIGTERM and SIGINT.
                https://www.postgresql.org/docs/current/continuous-archiving.html#BACKUP-LOWLEVEL-BASE-BACKUP
                https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADMIN-BACKUP
                """
            ),
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )
        parser.add_argument("--debug", help="debug mode", action="store_true")
        parser.add_argument("--dryrun", help="show what would be done, but do nothing", action="store_true")
        return parser.parse_args()

    def configure_logger(self) -> logging.Logger:
        """Configure root logging and return this module's logger.

        Messages are prefixed with ``DRYRUN - `` in dry-run mode and with the level
        name in debug mode. ``force=True`` replaces any handlers already installed
        on the root logger.
        """
        # Use the real module name; the quoted literal "__name__" used before
        # created a logger that was literally called "__name__".
        logger = logging.getLogger(__name__)
        log_format = f"{'DRYRUN - ' if self.dryrun else ''}{'%(levelname)s - ' if self.debug else ''}%(message)s"
        logging.basicConfig(
            format=log_format,
            datefmt="%Y-%m-%d %H:%M:%S",
            level=logging.INFO,
            handlers=[logging.StreamHandler()],
            force=True,
        )
        if self.debug:
            logger.setLevel(logging.DEBUG)
        return logger

    def pg_connect(self) -> None:
        """Connect to the local postgres server; log and exit with status 1 on failure."""
        error_log = "could not connect to local postgres"
        try:
            self.pg.connect()
            if not self.pg.is_connected():
                self.logger.error(error_log)
                # SystemExit is not an Exception subclass, so the handler below
                # does not intercept this exit.
                sys.exit(1)
        except Exception as err:
            self.logger.error(f"{error_log}: {err}")  # noqa: TRY400
            sys.exit(1)

    def get_server_major_version(self) -> Version:
        """Connect and return the server version reported by ``server_version``.

        Only the text before the first space is parsed, which drops any suffix
        the server appends after the version number.
        """
        statement = "SELECT CURRENT_SETTING('server_version')"
        self.pg_connect()
        raw_version = self.pg.execute_fetch_one(statement=statement)[0]
        return Version(raw_version.split(" ")[0])

    def __handle_signal(self, sig_num: int, curr_stack_frame: Any):
        """Handle signal received, to use exclusively in signal.signal().

        Logs the signal description and delegates to stop(), which exits the
        process. ``curr_stack_frame`` is required by the handler signature and
        is not used.
        """
        self.logger.info(f"{signal.strsignal(sig_num)} received")
        self.stop()

    def start(self) -> None:
        """Start the postgres backup and block until the process is terminated.

        The method never returns normally: it loops forever and relies on the
        signal handlers (or a fatal error) to end the process.
        """
        label = "pitr_snapshot"
        data = [label]

        # The backup functions were renamed in PostgreSQL 15.
        if self.version >= Version("15"):
            statement = "SELECT PG_BACKUP_START(%s, true)"
        else:
            statement = "SELECT PG_START_BACKUP(%s, true, false)"

        self.pg_connect()

        if self.debug:
            self.logger.debug(f"{self.pg.mogrify(statement=statement, data=data)}")
        elif self.dryrun:
            self.logger.info(f"{self.pg.mogrify(statement=statement, data=data)}")

        while True:
            if not self.dryrun:
                # The start statement is re-issued on every iteration. Once the
                # backup is running the server answers "a backup is already in
                # progress", which is tolerated below; any other error ends the
                # service. NOTE(review): presumably this doubles as a connection
                # liveness probe - confirm before replacing it with a plain sleep.
                try:
                    pg_backup_start = self.pg.execute_fetch_one(statement=statement, data=data)[0]
                    self.logger.info(f"backup started at WAL location '{pg_backup_start}' with label '{label}'")
                except Exception as err:
                    if "WAL generated with full_page_writes=off was replayed" in str(err):
                        self.logger.info(
                            "starting a backup on a standby with full_page_writes=off on the primary is not possible"
                        )
                        # Exits with status 0: this case is reported, not treated as a failure.
                        sys.exit()
                    if "ERROR: a backup is already in progress" not in str(err):
                        self.logger.error(err)  # noqa: TRY400
                        sys.exit(1)
            time.sleep(self.delay)
            self.logger.info("backup in progress...")

    def stop(self) -> None:
        """Stop postgres backup gracefully, log the total run time and exit.

        Exits with status 0 on success (and in dry-run mode), or with status 1
        when the stop statement fails.
        """
        self.logger.info("stopping backup")

        # The backup functions were renamed in PostgreSQL 15.
        if self.version >= Version("15"):
            statement = "SELECT PG_BACKUP_STOP()"
        else:
            statement = "SELECT PG_STOP_BACKUP(false)"

        if self.debug:
            self.logger.debug(f"{self.pg.mogrify(statement=statement)}")
        elif self.dryrun:
            self.logger.info(f"{self.pg.mogrify(statement=statement)}")

        if not self.dryrun:
            try:
                pg_stop_backup = self.pg.execute_fetch_one(statement)[0]
            except Exception as err:
                self.logger.error(err)  # noqa: TRY400
                sys.exit(1)
            self.logger.info(f"backup finished:\n{pg_stop_backup}")

        end_time = time.time()
        self.logger.info(f"service ran for {self.format_time(end_time - self.start_time)}")
        sys.exit(0)
|
|
|
|
|
|
def main():
    """Create the backup service and run it; start() blocks until the process exits."""
    PgBackupService().start()


# Run the service only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|