# Source code for train_lib.clients.pht_client

import requests
from io import BytesIO
import tarfile
import pika
import json
from typing import Union, List
from tarfile import TarFile
from dotenv import load_dotenv, find_dotenv
import os
import logging
import base64

LOGGER = logging.getLogger(__name__)


class PHTClient:
    """
    Client class for interacting with PHT services (the central Train API, rabbit mq and vault).
    """

    # Client id used for Basic auth against the Train API when no other id is given.
    _DEFAULT_CLIENT_ID = "TRAIN_BUILDER"
    # Size of the chunks read while streaming downloads; requests' iter_content defaults to 1 byte.
    _STREAM_CHUNK_SIZE = 64 * 1024

    def __init__(self, api_url: str, api_port: int = 5555, api_token: str = None, ampq_url: str = None,
                 vault_url: str = None, vault_token: str = None):
        """
        Set up connection parameters for the services (train api, rabbit mq and vault).

        :param api_url: endpoint of the central TrainAPI. Endpoints are appended to it verbatim, so it
            should end with a "/".
        :param api_port: port of the TrainAPI. It is stored as ``self.port`` but is not used when
            building request urls in this class.
        :param api_token: token to be passed to the api
        :param ampq_url: ampq url containing username and password for connecting to rabbitmq
        :param vault_url: url of the vault api used for storing public keys and routes. Paths such as
            "v1/..." are appended to it verbatim, so it should end with a "/".
        :param vault_token: token sent as the ``X-Vault-Token`` header on every vault request
        """
        self.api_url = api_url
        self.port = api_port
        self.token = api_token
        self.vault_url = vault_url
        self.api_headers = None
        self.vault_headers = None
        self._create_headers(api_token, vault_token)
        self.rmq_params = None
        if ampq_url:
            self.rmq_params = pika.URLParameters(ampq_url)

    def publish_message_rabbit_mq(self, message: Union[str, bytes, List[str], dict], exchange: str = "pht",
                                  exchange_type: str = "topic", routing_key: str = "pht"):
        """
        Publish a JSON encoded message to rabbit mq.

        The exchange is declared as durable before publishing. If no ampq url was configured, a
        connection to rabbit mq on localhost is attempted. The connection is always closed afterwards,
        even if declaring the exchange or publishing fails.

        :param message: the message to be published. It is serialized with ``json.dumps``; bytes are
            decoded as utf-8 first because json cannot encode bytes directly.
        :param exchange: the identifier of the exchange
        :param exchange_type: type of the exchange, e.g. "topic"
        :param routing_key: routing key used for publishing the message
        :return:
        """
        # Serialize first, so that an unserializable message never opens a broker connection.
        json_message = self._encode_message(message)
        if self.rmq_params:
            connection = pika.BlockingConnection(self.rmq_params)
        else:
            LOGGER.info("No connection to rabbit mq specified, attempting connection on localhost")
            connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost'))
        try:
            channel = connection.channel()
            channel.exchange_declare(exchange=exchange, exchange_type=exchange_type, durable=True)
            channel.basic_publish(exchange=exchange, routing_key=routing_key, body=json_message)
            LOGGER.info(" [x] Sent %r", json_message)
        finally:
            connection.close()

    @staticmethod
    def _encode_message(message) -> bytes:
        """
        Serialize a message as utf-8 encoded JSON.

        :param message: JSON serializable object. ``bytes``/``bytearray`` are decoded as utf-8 first
            and therefore end up as a JSON string.
        :return: the utf-8 encoded JSON document
        """
        if isinstance(message, (bytes, bytearray)):
            message = message.decode("utf-8")
        return json.dumps(message).encode("utf-8")

    def get_train_files_archive(self, train_id: str, token: str = None, client_id: str = None) -> BytesIO:
        """
        Download the tar archive containing the files for building a train from the UI api.

        :param train_id: identifier of the train. The request goes to ``{api_url}{train_id}/files/download``.
        :param token: optional api token overriding the token configured on this client
        :param client_id: optional client id for Basic auth; defaults to "TRAIN_BUILDER"
        :return: in-memory BytesIO holding the downloaded tar data, positioned at its start
        """
        endpoint = f"{train_id}/files/download"
        return self._get_tar_archive_from_stream(endpoint, token=token, client_id=client_id)

    def _resolve_api_headers(self, token: str = None, client_id: str = None) -> dict:
        """
        Build Train API headers, falling back to this client's token and the default client id.

        Empty or missing values fall back to ``self.token`` and "TRAIN_BUILDER". Without this
        fallback, the literal string "None" would be sent as the credential.
        """
        return self._create_api_headers(token or self.token, client_id or self._DEFAULT_CLIENT_ID)

    def _get_tar_archive_from_stream(self, endpoint: str, params: dict = None, external_endpoint: bool = False,
                                     token: str = None, client_id: str = None) -> BytesIO:
        """
        Read a stream of tar data from the given endpoint into an in-memory BytesIO object.

        :param endpoint: address relative to this instance's api address
        :param params: dictionary containing additional parameters to be passed to the request
        :param external_endpoint: if True, the request is sent to ``endpoint`` as-is. Otherwise the url
            is built as ``api_url + endpoint``.
        :param token: optional api token; defaults to the token configured on this client
        :param client_id: optional client id for Basic auth; defaults to "TRAIN_BUILDER"
        :return: BytesIO containing the downloaded data, positioned at its start
        :raises requests.HTTPError: if the server responds with an error status
        """
        url = endpoint if external_endpoint else self.api_url + endpoint
        headers = self._resolve_api_headers(token, client_id)
        file_obj = BytesIO()
        with requests.get(url, params=params, headers=headers, stream=True) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=self._STREAM_CHUNK_SIZE):
                file_obj.write(chunk)
        file_obj.seek(0)
        return file_obj

    def get_user_pk(self, user_id):
        """
        Get the public key associated with the given user_id from vault storage.

        :param user_id: identifier of the user in vault
        :return: hex string containing an rsa public key (the ``rsa_public_key`` field of the secret)
        :raises requests.HTTPError: if vault responds with an error status
        """
        url = f"{self.vault_url}v1/user_pks/{user_id}"
        r = requests.get(url, headers=self.vault_headers)
        r.raise_for_status()
        # The secret's fields are nested under data -> data in the vault response.
        data = r.json()["data"]
        return data["data"]["rsa_public_key"]

    def get_station_pk(self, station_id):
        """
        Get the rsa public key of the station specified by station id from vault storage.

        :param station_id: identifier of the station in vault
        :return: hex string containing an rsa public key (the ``rsa_station_public_key`` field of the secret)
        :raises requests.HTTPError: if vault responds with an error status
        """
        url = f"{self.vault_url}v1/station_pks/{station_id}"
        r = requests.get(url, headers=self.vault_headers)
        r.raise_for_status()
        public_key = r.json()["data"]["data"]["rsa_station_public_key"]
        return public_key

    def get_multiple_station_pks(self, station_ids: List) -> dict:
        """
        Get the public keys of several stations.

        :param station_ids: iterable of station identifiers
        :return: dictionary mapping each station id to its public key (one vault request per station)
        """
        return {station_id: self.get_station_pk(station_id) for station_id in station_ids}

    def post_route_to_vault(self, train_id, route, periodic=False):
        """
        Store the route of a train in vault under ``v1/kv-pht-routes/data/{train_id}``.

        :param train_id: identifier of the train; also stored as the repository suffix
        :param route: sequence of station identifiers; every entry is converted to ``str``
        :param periodic: flag stored alongside the route
        :raises requests.HTTPError: if vault responds with an error status
        """
        route = [str(station) for station in route]
        vault_url = f"{self.vault_url}v1/kv-pht-routes/data/{train_id}"
        payload = {
            # NOTE(review): presumably vault KV v2 check-and-set, i.e. the write only succeeds if no
            # route exists yet for this train - confirm against the vault setup.
            "options": {
                "cas": 0
            },
            "data": {
                "harborProjects": route,
                "repositorySuffix": str(train_id),
                "periodic": periodic
            }
        }
        r = requests.post(vault_url, headers=self.vault_headers, data=json.dumps(payload))
        r.raise_for_status()

    def _create_headers(self, api_token, vault_token):
        """
        Build the request headers for the Train API and for vault.

        :param api_token: token used for Basic auth with the default client id
        :param vault_token: token sent as ``X-Vault-Token``
        """
        self.api_headers = self._create_api_headers(api_token)
        # ``headers`` is kept as an alias so existing code reading ``client.headers`` keeps working.
        self.headers = self.api_headers
        self.vault_headers = {"X-Vault-Token": vault_token}

    def _create_api_headers(self, api_token: str, client_id: str = _DEFAULT_CLIENT_ID):
        """
        Build a Basic auth header from ``client_id:api_token``.

        :param api_token: token used as the password part of the credentials
        :param client_id: client id used as the user part of the credentials
        :return: dictionary with a single ``Authorization`` header
        """
        auth_string = f"{client_id}:{api_token}"
        auth_string = base64.b64encode(auth_string.encode("utf-8")).decode()
        headers = {"Authorization": f"Basic {auth_string}"}
        return headers