Source code for prestoadmin.prestoclient

# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Simple client to communicate with a Presto server.
"""
import json
import logging
import os
import socket
import urlparse
from httplib import HTTPConnection, HTTPException
from tempfile import mkstemp

from StringIO import StringIO
from fabric.operations import get
from fabric.state import env
from fabric.utils import error
from jks import jks, base64, textwrap
from prestoadmin.util.constants import REMOTE_CONF_DIR, CONFIG_PROPERTIES
from prestoadmin.util.exception import InvalidArgumentError
from prestoadmin.util.httpscacertconnection import HTTPSCaCertConnection
from prestoadmin.util.local_config_util import get_coordinator_directory, get_topology_path
from prestoadmin.util.presto_config import PrestoConfig, LDAP_CLIENT_USER_KEY, LDAP_CLIENT_PASSWORD_KEY

_LOGGER = logging.getLogger(__name__)
# Passed as the last constructor argument to both connection types created by
# PrestoClient; for httplib.HTTPConnection that positional slot is ``timeout``.
# NOTE(review): httplib interprets ``timeout`` in seconds, so despite the _MS
# suffix this presumably yields a 5000 s socket timeout -- confirm the intended
# unit before changing the value.
URL_TIMEOUT_MS = 5000
# Default value of the ``num_of_rows`` limit in PrestoClient._get_rows().
NUM_ROWS = 1000
# Keys read from the JSON bodies returned by the Presto statement endpoint:
# the rows of the current result page, and the URI of the next page to fetch.
DATA_RESP = "data"
NEXT_URI_RESP = "nextUri"

# Key looked up in env.conf to choose the keystore entry when the client
# keystore contains more than one private key.
CERTIFICATE_ALIAS = 'certificate_alias'


class PrestoClient:
    """
    Minimal HTTP(S) client that submits SQL statements to a Presto
    coordinator and collects the result rows.

    The coordinator configuration decides the port, whether HTTPS is used and
    whether LDAP basic-auth headers are sent. For HTTPS, the certificate
    chain of the client keystore's private key is written to a temporary
    ``.pem`` file that is used as the CA file; call ``close()`` to delete it.
    """

    def __init__(self, server, user, coordinator_config=None):
        # immutable stuff
        self.server = server
        self.user = user
        if coordinator_config is None:
            coordinator_config = PrestoConfig.coordinator_config()
        self.coordinator_config = coordinator_config
        self.port = PrestoClient._get_configured_port(self.coordinator_config)

        # mutable stuff
        self.ca_file_path = ""
        self.keystore_data = ""
        self.rows = []
        self.next_uri = ''
        self.response_from_server = {}

    @staticmethod
    def _remove_silently(path):
        """Delete ``path``, ignoring OS-level failures such as a missing file."""
        try:
            os.remove(path)
        except OSError:
            pass

    def close(self):
        """
        Delete the temporary CA ``.pem`` file, if one was written.

        The cached path is reset so that a later HTTPS connection regenerates
        the file instead of pointing at the deleted one.
        """
        if self.ca_file_path:
            PrestoClient._remove_silently(self.ca_file_path)
            self.ca_file_path = ""

    def _clear_old_results(self):
        """Reset the per-query state before a new statement is submitted."""
        self.rows = []
        self.next_uri = ''
        self.response_from_server = {}

    def run_sql(self, sql, schema="default", catalog="hive"):
        """
        Execute a query connecting to Presto server using passed parameters.

        Args:
            sql: SQL query to be executed
            schema: Presto schema to be used while executing query
                (default=default)
            catalog: Catalog to be used by the server

        Returns:
            list of rows or None if client was unable to connect to Presto
        """
        if not self._execute_query(sql, schema, catalog):
            return None
        return self._get_rows()

    def _execute_query(self, sql, schema, catalog):
        """
        POST ``sql`` to /v1/statement and store the decoded JSON reply in
        ``self.response_from_server``.

        Returns:
            True on HTTP 200, False on a non-200 status or an HTTP/socket
            error.

        Raises:
            InvalidArgumentError: if sql, server or user is empty.
            ValueError: if the server's reply is not valid JSON (re-raised
                with its original traceback after logging).
        """
        if not sql:
            raise InvalidArgumentError("SQL query missing")
        if not self.server:
            raise InvalidArgumentError("Server IP missing")
        if not self.user:
            raise InvalidArgumentError("Username missing")

        self._clear_old_results()

        headers = {"X-Presto-Catalog": catalog,
                   "X-Presto-Schema": schema,
                   "X-Presto-User": self.user,
                   "X-Presto-Source": "presto-admin"}
        answer = ''
        conn = None
        try:
            _LOGGER.info("Connecting to server at: " + self.server +
                         ":" + str(self.port) + " as user " + self.user +
                         " to execute query " + sql)
            conn = self._get_connection()
            self._add_auth_headers(headers)
            conn.request("POST", "/v1/statement", sql, headers)
            response = conn.getresponse()

            if response.status != 200:
                _LOGGER.error("Connection error: " +
                              str(response.status) + " " + response.reason)
                return False

            answer = response.read()
            self.response_from_server = json.loads(answer)
            _LOGGER.info("Query executed successfully: %s" % (sql))
            return True
        except (HTTPException, socket.error) as e:
            # str(e) rather than e.message: socket.error(errno, msg) has an
            # empty .message, which used to drop the failure reason.
            _LOGGER.error("Error connecting to presto server at: " +
                          self.server + ":" + str(self.port) + ' ' + str(e))
            return False
        except ValueError as e:
            _LOGGER.error('Error connecting to Presto server: ' + str(e) +
                          ' error from server: ' + answer)
            # Bare raise keeps the original traceback.
            raise
        finally:
            # Close on every path, including exceptions from request/read.
            if conn is not None:
                conn.close()

    def _get_response_from(self, uri):
        """
        Sends a GET request to the Presto server at the specified next_uri
        and updates the response

        Remove the scheme and host/port from the uri; the connection itself
        has that information.

        Returns:
            True on HTTP 200 (``self.response_from_server`` is replaced),
            False otherwise. Socket/HTTP errors propagate to the caller.
        """
        parts = list(urlparse.urlsplit(uri))
        parts[0] = ''  # scheme
        parts[1] = ''  # netloc
        location = urlparse.urlunsplit(parts)

        conn = self._get_connection()
        try:
            headers = {"X-Presto-User": self.user}
            self._add_auth_headers(headers)
            conn.request("GET", location, headers=headers)
            response = conn.getresponse()

            if response.status != 200:
                _LOGGER.error("Error making GET request to %s: %s %s" %
                              (uri, response.status, response.reason))
                return False

            answer = response.read()
        finally:
            conn.close()

        self.response_from_server = json.loads(answer)
        _LOGGER.info("GET request successful for uri: " + uri)
        return True

    def _build_results_from_response(self, max_rows=None):
        """
        Build result from the response

        The reponse_from_server may contain up to 3 uri's.
        1. link to fetch the next packet of data ('nextUri')
        2. TODO: information about the query execution ('infoUri')
        3. TODO: cancel the query ('partialCancelUri').

        ``self.next_uri`` is always updated from the current response (an
        empty string when absent). The response's data rows, if any, are
        appended to ``self.rows``; when ``max_rows`` is given, only as many
        rows as still fit under that limit are appended.
        """
        self.next_uri = self.response_from_server.get(NEXT_URI_RESP, "")

        data = self.response_from_server.get(DATA_RESP)
        if not data:
            return
        if max_rows is not None:
            data = data[:max(max_rows - len(self.rows), 0)]
        self.rows.extend(data)

    def _get_rows(self, num_of_rows=NUM_ROWS):
        """
        Get the rows returned from the query.

        The client sends GET requests to the server using the 'nextUri'
        from the previous response until the servers response does not
        contain anymore 'nextUri's. When there is no 'nextUri' the query is
        finished

        Note that this can only be called once and does not page through
        the results.

        Parameters:
            num_of_rows: maximum number of rows to return, 1000 by default.
                None means no limit. Pages are still followed to the end
                once the limit is reached; further rows are just not kept.

        Returns:
            the collected rows, or [] if num_of_rows is 0 or a GET request
            fails.
        """
        if num_of_rows == 0:
            return []

        # The initial response may already hold rows (even without a
        # nextUri), so they are collected before following any links.
        self._build_results_from_response(num_of_rows)
        while self._get_next_uri():
            if not self._get_response_from(self._get_next_uri()):
                return []
            # Always called so next_uri advances; otherwise the same URI
            # would be requested forever once the row limit was reached.
            self._build_results_from_response(num_of_rows)
        return self.rows

    def _get_next_uri(self):
        return self.next_uri

    def _get_connection(self):
        """Return an HTTPS or plain HTTP connection, per the coordinator config."""
        if self.coordinator_config.use_https():
            return self._get_https_connection()
        # NOTE(review): httplib's timeout argument is in seconds, while the
        # constant's name suggests milliseconds -- confirm the intended value.
        return HTTPConnection(self.server, self.port, False, URL_TIMEOUT_MS)

    @staticmethod
    def _get_configured_port(coordinator_config):
        if coordinator_config.use_https():
            return coordinator_config.get_https_port()
        return coordinator_config.get_http_port()

    def _get_https_connection(self):
        ca_file_path = self._get_pem()
        return HTTPSCaCertConnection(
            self.server, self.port, None, None, ca_file_path, False,
            URL_TIMEOUT_MS)

    def _fetch_keystore_data(self):
        """Download the client keystore from the coordinator once and cache it."""
        if not self.keystore_data:
            remote_keystore_path = \
                self.coordinator_config.get_client_keystore_path()
            keystore_data = StringIO()
            get(remote_keystore_path, keystore_data, use_sudo=True)
            keystore_data.seek(0)
            self.keystore_data = keystore_data.getvalue()
        return self.keystore_data

    def _pem_string(self, der_bytes, type):
        """Wrap DER bytes as a PEM block labelled ``type``."""
        result = "-----BEGIN %s-----\n" % type
        result += "\r\n".join(
            textwrap.wrap(base64.b64encode(der_bytes).decode('ascii'), 64))
        result += "\n-----END %s-----\n" % type
        return result

    def _write_pem_file(self, directory, der_bytes_list, type):
        """
        Write the DER blobs as consecutive PEM blocks into a new temporary
        file under ``directory`` and return its path.
        """
        prefix = os.path.join(directory,
                              '%s-' % type.lower().replace(' ', '-'))
        fd, pem_path = mkstemp('.pem', prefix)
        # https://www.digicert.com/ssl-support/pem-ssl-creation.htm
        # Write through the descriptor mkstemp opened, so it is closed exactly
        # once even when writing fails.
        with os.fdopen(fd, 'w') as pem_file:
            for der_bytes in der_bytes_list:
                pem_file.write(self._pem_string(der_bytes, type))
        return pem_path

    def _get_pem(self):
        """
        Return the path of the CA ``.pem`` file, creating it on first use
        from the certificate chain of the keystore's private key.
        """
        if self.ca_file_path:
            # Already written; no need to decrypt the keystore again.
            return self.ca_file_path

        keystore_data = self._fetch_keystore_data()
        keystore = jks.KeyStore.loads(
            keystore_data,
            self.coordinator_config.get_client_keystore_password())

        if len(keystore.private_keys) == 1:
            private_key = list(keystore.private_keys.values())[0]
        else:
            private_key = self._get_private_key(keystore)

        # Each member of the cert chain is a tuple (cert_type, cert_data)
        # We only need to write the data out to the .PEM file. This usage is
        # shown in the example in the README.md on github:
        # https://github.com/kurtbrose/pyjks
        self.ca_file_path = self._write_pem_file(
            get_coordinator_directory(),
            [cert[1] for cert in private_key.cert_chain],
            'CERTIFICATE')
        return self.ca_file_path

    def _get_private_key(self, keystore):
        """
        Pick the private key named by env.conf[CERTIFICATE_ALIAS]; reports a
        fabric error when the alias is unset or not in the keystore.
        """
        all_keys = ", ".join(keystore.private_keys.keys())
        try:
            alias = env.conf[CERTIFICATE_ALIAS]
        except KeyError:
            error('Multiple keys found in %s. Set %s in %s. '
                  'Available aliases are %s' %
                  (self.coordinator_config.get_client_keystore_path(),
                   CERTIFICATE_ALIAS, get_topology_path(), all_keys))

        try:
            return keystore.private_keys[alias]
        except KeyError:
            error('No alias %s found in %s. Available aliases are %s' %
                  (alias, self.coordinator_config.get_client_keystore_path(),
                   all_keys))

    def _add_auth_headers(self, headers):
        """Add LDAP basic-auth headers to ``headers`` (in place) if enabled."""
        use_ldap = self.coordinator_config.use_ldap()
        if use_ldap:
            headers.update(self._create_auth_headers(
                self.coordinator_config.get_ldap_user(),
                self.coordinator_config.get_ldap_password()))
        _LOGGER.info("Using LDAP = %s" % use_ldap)

    @staticmethod
    def _create_auth_headers(user, password):
        """
        Build an HTTP Basic ``Authorization`` header for ``user``/``password``.

        On invalid input a fabric error is reported and ``{}`` is returned.
        """
        if not user:
            error('LDAP user (taken from %s in %s on the coordinator) '
                  'cannot be null or empty' %
                  (LDAP_CLIENT_USER_KEY,
                   os.path.join(REMOTE_CONF_DIR, CONFIG_PROPERTIES)))
            return {}

        if not password:
            error('LDAP password (taken from %s in %s on the coordinator) '
                  'cannot be null or empty' %
                  (LDAP_CLIENT_PASSWORD_KEY,
                   os.path.join(REMOTE_CONF_DIR, CONFIG_PROPERTIES)))
            return {}

        if ':' in user:
            # Basic auth uses ':' as the user/password separator.
            error("LDAP user cannot contain ':': %s" % user)
            return {}

        # b64encode emits no line breaks, unlike the deprecated encodestring.
        auth = base64.b64encode('%s:%s' % (user, password))
        return {'Authorization': 'Basic %s' % auth}