# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Simple client to communicate with a Presto server.
"""
import json
import logging
import os
import socket
import urlparse
from httplib import HTTPConnection, HTTPException
from tempfile import mkstemp
from StringIO import StringIO
from fabric.operations import get
from fabric.state import env
from fabric.utils import error
from jks import jks, base64, textwrap
from prestoadmin.util.constants import REMOTE_CONF_DIR, CONFIG_PROPERTIES
from prestoadmin.util.exception import InvalidArgumentError
from prestoadmin.util.httpscacertconnection import HTTPSCaCertConnection
from prestoadmin.util.local_config_util import get_coordinator_directory, get_topology_path
from prestoadmin.util.presto_config import PrestoConfig, LDAP_CLIENT_USER_KEY, LDAP_CLIENT_PASSWORD_KEY
# Module-level logger shared by every PrestoClient instance.
_LOGGER = logging.getLogger(__name__)

# NOTE(review): despite the "_MS" suffix this value is handed unchanged to the
# connection constructors as their ``timeout`` argument, which httplib
# interprets as seconds. Confirm the intended unit before changing it.
URL_TIMEOUT_MS = 5000
# Default row threshold used by PrestoClient._get_rows.
NUM_ROWS = 1000
# Keys looked up in the JSON document returned by the server.
DATA_RESP = "data"
NEXT_URI_RESP = "nextUri"
# Key read from fabric's ``env.conf`` to pick one key out of a keystore that
# holds several.
CERTIFICATE_ALIAS = 'certificate_alias'


class PrestoClient:
    """
    Minimal HTTP(S) client that submits SQL to a Presto server and collects
    the resulting rows.

    Attributes:
        server: host the requests are sent to
        user: value sent in the X-Presto-User header
        coordinator_config: configuration object queried for the port and the
            HTTPS / LDAP / keystore settings
        port: port taken from coordinator_config (HTTPS or HTTP)
        ca_file_path: path of the PEM file written for HTTPS connections, or
            "" when none has been written
        keystore_data: cached raw keystore contents, or "" when not fetched
        rows: rows collected for the current query
        next_uri: 'nextUri' of the most recent response, or "" when absent
        response_from_server: most recent decoded JSON response
    """

    def __init__(self, server, user, coordinator_config=None):
        """
        Args:
            server: host of the Presto server
            user: user name to run queries as
            coordinator_config: configuration object; when None,
                PrestoConfig.coordinator_config() is used
        """
        # immutable stuff
        self.server = server
        self.user = user
        if coordinator_config is None:
            coordinator_config = PrestoConfig.coordinator_config()
        self.coordinator_config = coordinator_config
        self.port = PrestoClient._get_configured_port(self.coordinator_config)

        # mutable stuff
        self.ca_file_path = ""
        self.keystore_data = ""
        self.rows = []
        self.next_uri = ''
        self.response_from_server = {}

    @staticmethod
    def _remove_silently(path):
        """
        Delete the file at path, ignoring a missing or undeletable file.

        An empty path is a no-op. Only OSError is suppressed so that
        unrelated failures (for example KeyboardInterrupt) still propagate.
        """
        if not path:
            return
        try:
            os.remove(path)
        except OSError:
            pass

    def close(self):
        """
        Remove the PEM file written for HTTPS connections, if any.

        The cached path is cleared as well so that a later HTTPS connection
        writes a fresh file instead of reusing the deleted one.
        """
        PrestoClient._remove_silently(self.ca_file_path)
        self.ca_file_path = ""

    def _clear_old_results(self):
        """Reset the per-query state before a new query is submitted."""
        self.rows = []
        self.next_uri = ''
        self.response_from_server = {}

    def run_sql(self, sql, schema="default", catalog="hive"):
        """
        Execute a query connecting to Presto server using passed parameters.

        Args:
            sql: SQL query to be executed
            schema: Presto schema to be used while executing query
                (default=default)
            catalog: Catalog to be used by the server

        Returns:
            list of rows or None if client was unable to connect to Presto
        """
        if not self._execute_query(sql, schema, catalog):
            return None
        return self._get_rows()

    def _execute_query(self, sql, schema, catalog):
        """
        POST the statement to /v1/statement and store the decoded response
        in self.response_from_server.

        Args:
            sql: SQL query to be executed
            schema: value of the X-Presto-Schema header
            catalog: value of the X-Presto-Catalog header

        Returns:
            True when the server answered with status 200 and valid JSON;
            False on a non-200 status or on HTTPException / socket.error.

        Raises:
            InvalidArgumentError: sql, server or user is empty
            ValueError: the response body could not be decoded as JSON
        """
        if not sql:
            raise InvalidArgumentError("SQL query missing")

        if not self.server:
            raise InvalidArgumentError("Server IP missing")

        if not self.user:
            raise InvalidArgumentError("Username missing")

        self._clear_old_results()

        headers = {"X-Presto-Catalog": catalog,
                   "X-Presto-Schema": schema,
                   "X-Presto-User": self.user,
                   "X-Presto-Source": "presto-admin"}
        answer = ''
        try:
            _LOGGER.info(
                "Connecting to server at: %s:%s as user %s to execute "
                "query %s", self.server, self.port, self.user, sql)
            conn = self._get_connection()
            # Close the connection on every path, including when the request
            # or the read raises.
            try:
                self._add_auth_headers(headers)
                conn.request("POST", "/v1/statement", sql, headers)
                response = conn.getresponse()

                if response.status != 200:
                    _LOGGER.error("Connection error: %s %s",
                                  response.status, response.reason)
                    return False

                answer = response.read()
            finally:
                conn.close()

            self.response_from_server = json.loads(answer)
            _LOGGER.info("Query executed successfully: %s", sql)
            return True
        except (HTTPException, socket.error) as e:
            # Format the exception itself: ``e.message`` is empty for
            # exceptions built from several arguments such as socket.error.
            _LOGGER.error("Error connecting to presto server at: %s:%s %s",
                          self.server, self.port, e)
            return False
        except ValueError as e:
            _LOGGER.error(
                'Error connecting to Presto server: %s error from server: %s',
                e, answer)
            # Bare raise keeps the original traceback.
            raise

    def _get_response_from(self, uri):
        """
        Sends a GET request to the Presto server at the specified next_uri
        and updates the response

        Remove the scheme and host/port from the uri; the connection itself
        has that information.

        Args:
            uri: absolute URI taken from a previous response

        Returns:
            True when the server answered with status 200, False otherwise
        """
        parts = list(urlparse.urlsplit(uri))
        # Empty strings (rather than None) drop scheme and netloc.
        parts[0] = ''
        parts[1] = ''
        location = urlparse.urlunsplit(parts)

        conn = self._get_connection()
        # Close the connection on every path, including when the request or
        # the read raises.
        try:
            headers = {"X-Presto-User": self.user}
            self._add_auth_headers(headers)
            conn.request("GET", location, headers=headers)
            response = conn.getresponse()

            if response.status != 200:
                _LOGGER.error("Error making GET request to %s: %s %s",
                              uri, response.status, response.reason)
                return False

            answer = response.read()
        finally:
            conn.close()

        self.response_from_server = json.loads(answer)
        _LOGGER.info("GET request successful for uri: %s", uri)
        return True

    def _build_results_from_response(self, collect_rows=True):
        """
        Build result from the response

        The response_from_server may contain up to 3 uri's.
        1. link to fetch the next packet of data ('nextUri')
        2. TODO: information about the query execution ('infoUri')
        3. TODO: cancel the query ('partialCancelUri').

        Args:
            collect_rows: when False only next_uri is updated and any 'data'
                in the response is ignored
        """
        response = self.response_from_server
        self.next_uri = response.get(NEXT_URI_RESP, "")

        if collect_rows and DATA_RESP in response:
            self.rows.extend(response[DATA_RESP])

    def _get_rows(self, num_of_rows=NUM_ROWS):
        """
        Get the rows returned from the query.

        The client sends GET requests to the server using the 'nextUri'
        from the previous response until the servers response does not
        contain anymore 'nextUri's.  When there is no 'nextUri' the query is
        finished

        Note that this can only be called once and does not page through
        the results.

        Args:
            num_of_rows: threshold after which no further data is collected;
                1000 by default. A whole response is appended while the
                number of rows held is still <= num_of_rows, so the result
                may be longer than num_of_rows.

        Returns:
            the collected rows; [] when num_of_rows is 0, when the initial
            response carries no 'nextUri', or when a GET request fails
        """
        if num_of_rows == 0:
            return []

        self._build_results_from_response()

        if not self._get_next_uri():
            return []

        while self._get_next_uri():
            if not self._get_response_from(self._get_next_uri()):
                return []
            # Always advance next_uri, even once enough rows are held.
            # Otherwise the same URI would be requested forever.
            self._build_results_from_response(
                collect_rows=len(self.rows) <= num_of_rows)
        return self.rows

    def _get_next_uri(self):
        """Return the 'nextUri' of the latest response ("" when absent)."""
        return self.next_uri

    def _get_connection(self):
        """
        Return a new connection object to the server: HTTPS when the
        coordinator configuration asks for it, plain HTTP otherwise.
        """
        if self.coordinator_config.use_https():
            return self._get_https_connection()
        return HTTPConnection(self.server, self.port, False, URL_TIMEOUT_MS)

    @staticmethod
    def _get_configured_port(coordinator_config):
        """Return the HTTPS port when HTTPS is enabled, else the HTTP port."""
        if coordinator_config.use_https():
            return coordinator_config.get_https_port()
        return coordinator_config.get_http_port()

    def _get_https_connection(self):
        """Return an HTTPSCaCertConnection that uses the generated PEM file."""
        ca_file_path = self._get_pem()
        return HTTPSCaCertConnection(
            self.server, self.port, None, None, ca_file_path, False,
            URL_TIMEOUT_MS)

    def _fetch_keystore_data(self):
        """
        Return the raw client keystore contents, downloading them with
        fabric's get() on first use and caching them on the instance.
        """
        if not self.keystore_data:
            remote_keystore_path = \
                self.coordinator_config.get_client_keystore_path()
            keystore_buffer = StringIO()
            get(remote_keystore_path, keystore_buffer, use_sudo=True)
            self.keystore_data = keystore_buffer.getvalue()
        return self.keystore_data

    def _pem_string(self, der_bytes, type):
        """
        Return der_bytes as a PEM block labelled with type.

        The base64 payload is wrapped at 64 characters; wrapped lines are
        joined with "\\r\\n" while the BEGIN/END lines end in "\\n".
        """
        encoded = base64.b64encode(der_bytes).decode('ascii')
        result = "-----BEGIN %s-----\n" % type
        result += "\r\n".join(textwrap.wrap(encoded, 64))
        result += "\n-----END %s-----\n" % type
        return result

    def _write_pem_file(self, directory, der_bytes_list, type):
        """
        Write every entry of der_bytes_list as a PEM block into a new
        temporary file inside directory.

        Args:
            directory: directory the file is created in
            der_bytes_list: DER encoded objects, written in order
            type: PEM label, also used (lower-cased, spaces replaced by
                dashes) as the file name prefix

        Returns:
            path of the file that was written
        """
        file_prefix = '%s-' % type.lower().replace(' ', '-')
        fd, pem_path = mkstemp(suffix='.pem', prefix=file_prefix,
                               dir=directory)
        # https://www.digicert.com/ssl-support/pem-ssl-creation.htm
        # Wrap the descriptor returned by mkstemp so it is closed even when
        # writing fails.
        with os.fdopen(fd, 'w') as pem_file:
            for der_bytes in der_bytes_list:
                pem_file.write(self._pem_string(der_bytes, type))
        return pem_path

    def _get_pem(self):
        """
        Return the path of a PEM file holding the certificate chain of the
        client key, creating the file on first use.

        When the keystore holds exactly one private key that key is used;
        otherwise the key is chosen by _get_private_key.
        """
        if self.ca_file_path:
            # Already written; no need to parse the keystore again.
            return self.ca_file_path

        keystore = jks.KeyStore.loads(
            self._fetch_keystore_data(),
            self.coordinator_config.get_client_keystore_password())

        private_keys = keystore.private_keys
        if len(private_keys) == 1:
            private_key = list(private_keys.values())[0]
        else:
            private_key = self._get_private_key(keystore)

        # Each member of the cert chain is a tuple (cert_type, cert_data)
        # We only need to write the data out to the .PEM file.
        #
        # This usage is shown in the example in the README.md on github:
        # https://github.com/kurtbrose/pyjks
        self.ca_file_path = self._write_pem_file(
            get_coordinator_directory(),
            [cert[1] for cert in private_key.cert_chain], 'CERTIFICATE')
        return self.ca_file_path

    def _get_private_key(self, keystore):
        """
        Return the private key whose alias is configured under
        CERTIFICATE_ALIAS in fabric's env.conf.

        A missing setting or an unknown alias is reported through fabric's
        error(). None is returned in the case that error() returns instead
        of aborting.
        """
        all_keys = ", ".join(keystore.private_keys.keys())
        try:
            alias = env.conf[CERTIFICATE_ALIAS]
        except KeyError:
            error('Multiple keys found in %s. Set %s in %s. Available aliases are %s' %
                  (self.coordinator_config.get_client_keystore_path(),
                   CERTIFICATE_ALIAS, get_topology_path(), all_keys))
            return None

        try:
            return keystore.private_keys[alias]
        except KeyError:
            error('No alias %s found in %s. Available aliases are %s' %
                  (alias, self.coordinator_config.get_client_keystore_path(),
                   all_keys))
            return None

    def _add_auth_headers(self, headers):
        """
        Add the HTTP basic authentication header to headers (in place) when
        the coordinator configuration enables LDAP.
        """
        ldap_enabled = self.coordinator_config.use_ldap()
        if ldap_enabled:
            headers.update(self._create_auth_headers(
                self.coordinator_config.get_ldap_user(),
                self.coordinator_config.get_ldap_password()))
        _LOGGER.info("Using LDAP = %s", ldap_enabled)

    @staticmethod
    def _create_auth_headers(user, password):
        """
        Build the HTTP basic authentication header for user / password.

        Returns:
            {'Authorization': 'Basic <base64 of user:password>'}, or {} after
            reporting through fabric's error() when user or password is
            empty or user contains ':'.
        """
        if not user:
            error('LDAP user (taken from %s in %s on the coordinator) cannot be null or empty' %
                  (LDAP_CLIENT_USER_KEY, os.path.join(REMOTE_CONF_DIR, CONFIG_PROPERTIES)))
            return {}

        if not password:
            error('LDAP password (taken from %s in %s on the coordinator) cannot be null or empty' %
                  (LDAP_CLIENT_PASSWORD_KEY, os.path.join(REMOTE_CONF_DIR, CONFIG_PROPERTIES)))
            return {}

        if ':' in user:
            # A colon would make the "user:password" pair ambiguous.
            error("LDAP user cannot contain ':': %s" % user)
            return {}

        # base64 encode the username and password
        credentials = '%s:%s' % (user, password)
        if not isinstance(credentials, bytes):
            credentials = credentials.encode('utf-8')
        auth = base64.b64encode(credentials)
        if not isinstance(auth, str):
            auth = auth.decode('ascii')
        return {'Authorization': 'Basic %s' % auth}