ansible-edda/scripts/testing/vmgr.py

96 lines
2.7 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import contextlib
import libvirt
from typing import Dict, Iterable
# libvirt URI used as the default value of the --connect option.
CONNECTION = "qemu:///system"

# Domain names managed by default; comma-joined to form the default --limit.
INVENTORY = [
    "heimdall-virt",
    "valkyrie-virt",
    "yggdrasil-virt",
]
class VirtDomain:
    """Thin wrapper around a single libvirt domain.

    ``start`` and ``stop`` do nothing when the domain is already in the
    requested state; ``revert`` refuses to act on a running domain.
    """

    # Name of the snapshot that revert() rolls the domain back to.
    __ANSIBLE_EDDA_SNAPSHOT = "ansible-edda"

    def __init__(self, libvirt_domain: libvirt.virDomain):
        self.__domain = libvirt_domain

    def start(self) -> None:
        """Start the domain via ``create()`` unless it is already active."""
        if self.__domain.isActive():
            return
        self.__domain.create()

    def stop(self) -> None:
        """Call ``shutdown()`` on the domain if it is active; otherwise do nothing.

        NOTE(review): libvirt's shutdown is a request to the guest, so the
        domain may still be active when this returns — confirm before relying
        on an immediate revert().
        """
        if not self.__domain.isActive():
            return
        self.__domain.shutdown()

    def revert(self) -> None:
        """Roll the domain back to the ``ansible-edda`` snapshot.

        :raises RuntimeError: if the domain is currently active.
        """
        domain = self.__domain
        if domain.isActive():
            raise RuntimeError(f'"{domain.name()}" must be stopped before reverting')
        snapshot = domain.snapshotLookupByName(self.__ANSIBLE_EDDA_SNAPSHOT)
        domain.revertToSnapshot(snapshot)
class VirtManager:
    """Apply start/stop/revert to a set of libvirt domains on one connection.

    Only names that resolve to an existing domain end up in ``inventory``;
    names with no matching domain are skipped without error.
    """

    def __init__(self, conn_name: str, hosts: Iterable[str]):
        """Open ``conn_name`` and wrap every domain named in ``hosts``.

        :param conn_name: libvirt connection URI passed to ``libvirt.open``.
        :param hosts: domain names to look up.
        :raises libvirt.libvirtError: if a lookup fails for any reason other
            than ``VIR_ERR_NO_DOMAIN``. The connection is closed before the
            error propagates.
        """
        self.__conn = libvirt.open(conn_name)
        self.__inventory: Dict[str, VirtDomain] = {}
        try:
            for hostname in hosts:
                try:
                    domain = self.__conn.lookupByName(hostname)
                except libvirt.libvirtError as e:
                    # A missing domain is tolerated; any other failure is real.
                    if e.get_error_code() != libvirt.VIR_ERR_NO_DOMAIN:
                        raise
                    continue
                self.__inventory[hostname] = VirtDomain(domain)
        except BaseException:
            # The constructor is failing, so the caller never gets an object
            # to close(); release the connection here to avoid leaking it.
            self.__conn.close()
            raise

    def close(self) -> None:
        """Close the underlying libvirt connection."""
        self.__conn.close()

    @property
    def inventory(self) -> Dict[str, VirtDomain]:
        """Mapping of host name to wrapped domain, for hosts that were found."""
        return self.__inventory

    def start(self) -> None:
        """Start every domain in the inventory."""
        for vdom in self.__inventory.values():
            vdom.start()

    def stop(self) -> None:
        """Stop every domain in the inventory."""
        for vdom in self.__inventory.values():
            vdom.stop()

    def revert(self) -> None:
        """Revert every domain in the inventory, in iteration order.

        Stops at the first domain whose ``revert()`` raises (e.g. one that is
        still running); domains after it are not touched.
        """
        for vdom in self.__inventory.values():
            vdom.revert()
@contextlib.contextmanager
def virt_manager(connection_name: str, hosts: Iterable[str]):
    """Yield a ``VirtManager`` for ``hosts`` and always close its connection.

    :param connection_name: libvirt connection URI.
    :param hosts: domain names to manage.
    """
    vmgr = VirtManager(connection_name, hosts)
    try:
        yield vmgr
    finally:
        # Without finally, an exception in the with-body would skip close()
        # and leak the libvirt connection.
        vmgr.close()
def main(argv=None) -> None:
    """Parse command-line arguments and apply the chosen action to the hosts.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Manage virtual machines for testing")
    parser.add_argument("--connect", type=str, default=CONNECTION,
                        help="hypervisor connection URI")
    parser.add_argument("--limit", type=str, default=','.join(INVENTORY),
                        help="limit to selected hosts")

    # Without a required subcommand, args.func would be missing and the script
    # would crash with AttributeError instead of printing a usage error.
    # dest is set so argparse can name the missing argument in that error.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    start_parser = subparsers.add_parser("start")
    start_parser.set_defaults(func=VirtManager.start)
    stop_parser = subparsers.add_parser("stop")
    stop_parser.set_defaults(func=VirtManager.stop)
    revert_parser = subparsers.add_parser("revert")
    revert_parser.set_defaults(func=VirtManager.revert)

    args = parser.parse_args(argv)

    # Tolerate whitespace and empty entries such as "a, b" or a trailing comma.
    hosts = [name.strip() for name in args.limit.split(',') if name.strip()]

    # Honour --connect; it was previously parsed but ignored.
    with virt_manager(args.connect, hosts) as vmgr:
        args.func(vmgr)


if __name__ == "__main__":
    main()