#!/usr/bin/env python3
"""Manage the libvirt virtual machines used for testing.

Each selected host can be started, shut down, or reverted to the snapshot
named "ansible-edda".
"""
import argparse
import contextlib
from typing import Dict, Iterable, Iterator, List, Optional, Sequence

import libvirt

CONNECTION = "qemu:///system"
INVENTORY = ["heimdall-virt", "valkyrie-virt", "yggdrasil-virt"]


class VirtDomain:
    """Thin wrapper around a single ``libvirt.virDomain``."""

    # Name of the snapshot that ``revert`` restores.
    __ANSIBLE_EDDA_SNAPSHOT = "ansible-edda"

    def __init__(self, libvirt_domain: libvirt.virDomain):
        self.__domain = libvirt_domain

    def start(self) -> None:
        """Boot the domain; do nothing if it is already running."""
        if not self.__domain.isActive():
            self.__domain.create()

    def stop(self) -> None:
        """Request shutdown of a running domain; do nothing if it is not running.

        ``virDomain.shutdown`` only asks the guest to shut down and does not
        wait for it, so the domain may still be active when this returns.
        """
        if self.__domain.isActive():
            self.__domain.shutdown()

    def revert(self) -> None:
        """Revert the (stopped) domain to the "ansible-edda" snapshot.

        Raises:
            RuntimeError: if the domain is still running.
            libvirt.libvirtError: propagated from the snapshot lookup or
                the revert call.
        """
        if self.__domain.isActive():
            raise RuntimeError(
                f"\"{self.__domain.name()}\" must be stopped before reverting"
            )
        snap = self.__domain.snapshotLookupByName(self.__ANSIBLE_EDDA_SNAPSHOT)
        self.__domain.revertToSnapshot(snap)


class VirtManager:
    """Apply start/stop/revert to a set of domains on one libvirt connection.

    Hosts that are not defined on the hypervisor are skipped; any other
    lookup error is re-raised. Call ``close`` (or use ``virt_manager``) to
    release the connection.
    """

    def __init__(self, conn_name: str, hosts: Iterable[str]):
        """Open ``conn_name`` and look up each of ``hosts`` on it.

        Raises:
            libvirt.libvirtError: if opening the connection fails, or a
                lookup fails for a reason other than VIR_ERR_NO_DOMAIN.
        """
        self.__conn = libvirt.open(conn_name)
        self.__inventory: Dict[str, VirtDomain] = {}
        try:
            for hostname in hosts:
                try:
                    domain = self.__conn.lookupByName(hostname)
                except libvirt.libvirtError as e:
                    # Unknown hosts are skipped rather than treated as fatal.
                    if e.get_error_code() != libvirt.VIR_ERR_NO_DOMAIN:
                        raise
                else:
                    self.__inventory[hostname] = VirtDomain(domain)
        except BaseException:
            # On failure the caller never gets this object, so it cannot
            # close the connection itself; do it here before propagating.
            self.__conn.close()
            raise

    def close(self) -> None:
        """Close the underlying libvirt connection."""
        self.__conn.close()

    @property
    def inventory(self) -> Dict[str, VirtDomain]:
        """Mapping of host name to domain for every host that was found."""
        return self.__inventory

    def start(self) -> None:
        """Start every domain in the inventory."""
        for vdom in self.__inventory.values():
            vdom.start()

    def stop(self) -> None:
        """Request shutdown of every domain in the inventory."""
        for vdom in self.__inventory.values():
            vdom.stop()

    def revert(self) -> None:
        """Revert every domain; stops at the first one that raises."""
        for vdom in self.__inventory.values():
            vdom.revert()


@contextlib.contextmanager
def virt_manager(connection_name: str, hosts: Iterable[str]) -> Iterator[VirtManager]:
    """Yield a ``VirtManager`` and always close its connection afterwards."""
    vmgr = VirtManager(connection_name, hosts)
    try:
        yield vmgr
    finally:
        # Runs even when the managed action raises (e.g. revert on a
        # running domain), so the connection is never leaked.
        vmgr.close()


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; each subcommand stores a VirtManager method in ``func``."""
    parser = argparse.ArgumentParser(description="Manage virtual machines for testing")
    parser.add_argument("--connect", type=str, default=CONNECTION,
                        help="hypervisor connection URI")
    parser.add_argument("--limit", type=str, default=','.join(INVENTORY),
                        help="limit to selected hosts")
    # dest + required make a missing subcommand a usage error instead of an
    # AttributeError on ``args.func`` later.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True
    commands = (
        ("start", VirtManager.start),
        ("stop", VirtManager.stop),
        ("revert", VirtManager.revert),
    )
    for command, action in commands:
        subparsers.add_parser(command).set_defaults(func=action)
    return parser


def _parse_hosts(limit: str) -> List[str]:
    """Split a comma-separated host list, dropping blanks and surrounding spaces."""
    return [host.strip() for host in limit.split(",") if host.strip()]


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse ``argv`` (default: sys.argv) and run the chosen action."""
    args = _build_parser().parse_args(argv)
    # Use the URI given via --connect (which defaults to CONNECTION).
    with virt_manager(args.connect, _parse_hosts(args.limit)) as vmgr:
        args.func(vmgr)


if __name__ == "__main__":
    main()