Skip to content

Commit 1b0233f

Browse files
authored
Add python table writer to output structed query results (#2832)
* add python table writer to output structed query results * fix pylint * update * update * update
1 parent fec2e8e commit 1b0233f

6 files changed

Lines changed: 90 additions & 2 deletions

File tree

go/codegen/experimental/codegen_normal_stmt.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,16 @@ var normalStmtStepTmpl = `
2424
def step_entry_{{.StepIndex}}():
2525
import runtime
2626
import runtime.dbapi
27+
from runtime.dbapi import table_writer
28+
2729
conn = runtime.dbapi.connect("{{.DataSource}}")
2830
stmt = """{{.Stmt}}"""
2931
if conn.is_query(stmt):
3032
rs = conn.query(stmt)
31-
# write rs to stdout using protobuf table writer
33+
tw = table_writer.ProtobufWriter(rs)
34+
lines = tw.dump_strings()
35+
for l in lines:
36+
print(l)
3237
else:
3338
success = conn.execute(stmt)
3439
if not success:

go/proto/sqlflow.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ message Session {
9090
// 2. `USE ...`, `DELETE ...`
9191
// 3. `SELECT ... TO TRAIN/PREDICT ...`
9292
message Request {
93-
string sql = 1; // The SQL statement to be executed.
93+
string sql = 1; // The SQL statement to be executed.
9494
Session session = 2;
9595
}
9696

python/runtime/dbapi/mysql_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from unittest import TestCase
1616

1717
from runtime import testing
18+
from runtime.dbapi import table_writer
1819
from runtime.dbapi.mysql import MySQLConnection
1920

2021

@@ -70,6 +71,16 @@ def test_get_table_schema(self):
7071
('petal_length', 'FLOAT'), ('petal_width', 'FLOAT'),
7172
('class', 'INT')], col_info)
7273

74+
def test_proto_table_writer(self):
75+
conn = MySQLConnection(testing.get_datasource())
76+
rs = conn.query("select * from iris.train limit 10;")
77+
self.assertTrue(rs.success())
78+
tw = table_writer.ProtobufWriter(rs)
79+
lines = tw.dump_strings()
80+
self.assertTrue(lines[0].find(
81+
"head { column_names: \"sepal_length\" column_names: \"sepal_width\" column_names: \"petal_length\" column_names: \"petal_width\" column_names: \"class\" }" # noqa: E501
82+
) >= 0)
83+
7384

7485
if __name__ == "__main__":
7586
unittest.main()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from runtime.dbapi.table_writer.protobuf_writer import ProtobufWriter
15+
16+
__all__ = ["ProtobufWriter"]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from google.protobuf import text_format, wrappers_pb2
15+
from runtime.dbapi.connection import ResultSet
16+
from runtime.dbapi.table_writer import sqlflow_pb2
17+
18+
19+
class ProtobufWriter:
20+
def __init__(self, result_set):
21+
assert isinstance(result_set, ResultSet)
22+
column_info = result_set.column_info()
23+
self.all_responses = []
24+
head = sqlflow_pb2.Head()
25+
for field_name, _ in column_info:
26+
head.column_names.append(field_name)
27+
self.all_responses.append(sqlflow_pb2.Response(head=head))
28+
for row in result_set:
29+
pb_row = sqlflow_pb2.Row()
30+
for col in row:
31+
any_msg = self.pod_to_pb_any(col)
32+
any = pb_row.data.add()
33+
any.Pack(any_msg)
34+
self.all_responses.append(sqlflow_pb2.Response(row=pb_row))
35+
36+
@staticmethod
37+
def pod_to_pb_any(value):
38+
if isinstance(value, bool):
39+
v = wrappers_pb2.BoolValue(value=value)
40+
elif isinstance(value, int):
41+
v = wrappers_pb2.Int32Value(value=value)
42+
elif isinstance(value, float):
43+
v = wrappers_pb2.FloatValue(value=value)
44+
elif isinstance(value, str):
45+
v = wrappers_pb2.StringValue(value=value)
46+
else:
47+
raise ValueError("not supported cell data type: %s" % type(value))
48+
return v
49+
50+
def dump_strings(self):
51+
lines = []
52+
for resp in self.all_responses:
53+
lines.append(text_format.MessageToString(resp, as_one_line=True))
54+
return lines

scripts/test/prepare.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ git clone https://github.com/sql-machine-learning/models.git
4848
git checkout v0.0.6 -b v0.0.6 && \
4949
python setup.py install)
5050

51+
protoc --python_out=python/runtime/dbapi/table_writer/ -I go/proto sqlflow.proto
52+
5153
# 3. install java parser
5254
echo "Build parser gRPC servers in Java ..."
5355

0 commit comments

Comments
 (0)