Skip to content

Commit fec2e8e

Browse files
authored
Remove hdfs params in Python (#2834)
* Remove hdfs params in Python * format code
1 parent 2c17885 commit fec2e8e

16 files changed

Lines changed: 52 additions & 173 deletions

File tree

python/runtime/db.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,7 @@ def buffered_db_writer(conn, table_name, table_schema, buff_size=100):
199199
elif driver == "mysql":
200200
w = db_writer.MySQLDBWriter(conn, table_name, table_schema, buff_size)
201201
elif driver == "hive":
202-
w = db_writer.HiveDBWriter(
203-
conn,
204-
table_name,
205-
table_schema,
206-
buff_size,
207-
hdfs_namenode_addr=conn.param("hdfs_namenode_addr", ""),
208-
hive_location=conn.param("hive_location", ""),
209-
hdfs_user=conn.param("hdfs_user", ""),
210-
hdfs_pass=conn.param("hdfs_pass", ""))
202+
w = db_writer.HiveDBWriter(conn, table_name, table_schema, buff_size)
211203
elif driver == "paiio":
212204
w = db_writer.PAIMaxComputeDBWriter(table_name, table_schema,
213205
buff_size)

python/runtime/db_writer/hive.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,15 @@
2121

2222

2323
class HiveDBWriter(BufferedDBWriter):
24-
def __init__(self,
25-
conn,
26-
table_name,
27-
table_schema,
28-
buff_size=10000,
29-
hdfs_namenode_addr="",
30-
hive_location="",
31-
hdfs_user="",
32-
hdfs_pass=""):
24+
def __init__(self, conn, table_name, table_schema, buff_size=10000):
3325
super().__init__(conn, table_name, table_schema, buff_size)
3426
self.tmp_f = tempfile.NamedTemporaryFile(dir="./")
3527
self.f = open(self.tmp_f.name, "w")
3628
self.schema_idx = self._indexing_table_schema(table_schema)
37-
self.hdfs_namenode_addr = hdfs_namenode_addr
38-
self.hive_location = hive_location
39-
self.hdfs_user = hdfs_user
40-
self.hdfs_pass = hdfs_pass
29+
self.hdfs_namenode_addr = conn.param("hdfs_namenode_addr")
30+
self.hive_location = conn.param("hive_location")
31+
self.hdfs_user = conn.uripts.username
32+
self.hdfs_pass = conn.uripts.password
4133

4234
def _indexing_table_schema(self, table_schema):
4335
column_list = self.conn.get_table_schema(self.table_name)

python/runtime/dbapi/hive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def __init__(self, conn_uri):
8282
if k.startswith("session.")])
8383

8484
def _get_result_set(self, statement):
85-
cursor = self._conn.cursor(configuration=self._session_cfg)
85+
cursor = self._conn.cursor(user=self.uripts.username,
86+
configuration=self._session_cfg)
8687
try:
8788
cursor.execute(statement.rstrip(";"))
8889
return HiveResultSet(cursor)

python/runtime/pai/tensorflow/evaluate.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import sys
1515

1616
import tensorflow as tf
17+
from runtime.dbapi.paiio import PaiIOConnection
1718
from runtime.model import oss
1819
from runtime.pai.pai_distributed import define_tf_flags
1920
from runtime.tensorflow import is_tf_estimator
@@ -131,12 +132,5 @@ def _evaluate(datasource,
131132

132133
if result_table:
133134
metric_name_list = ["loss"] + validation_metrics
134-
write_result_metrics(result_metrics,
135-
metric_name_list,
136-
result_table,
137-
"paiio",
138-
None,
139-
hdfs_namenode_addr="",
140-
hive_location="",
141-
hdfs_user="",
142-
hdfs_pass="")
135+
write_result_metrics(result_metrics, metric_name_list, result_table,
136+
PaiIOConnection.from_table(result_table))

python/runtime/pai/tensorflow/explain.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,19 @@ def _input_fn():
113113
return dataset.batch(1).cache()
114114

115115
estimator = init_model_with_feature_column(estimator_cls, model_params)
116-
driver = "paiio"
117116
conn = PaiIOConnection.from_table(result_table) if result_table else None
118117
if estimator_cls in (tf.estimator.BoostedTreesClassifier,
119118
tf.estimator.BoostedTreesRegressor):
120119
explain_boosted_trees(datasource, estimator, _input_fn, plot_type,
121-
result_table, feature_column_names, driver, conn,
122-
"", "", "", "", oss_dest, oss_ak, oss_sk,
123-
oss_endpoint, oss_bucket_name)
120+
result_table, feature_column_names, conn,
121+
oss_dest, oss_ak, oss_sk, oss_endpoint,
122+
oss_bucket_name)
124123
else:
125124
shap_dataset = pd.DataFrame(columns=feature_column_names)
126125
for i, (features, label) in enumerate(_input_fn()):
127126
shap_dataset.loc[i] = [
128127
item.numpy()[0][0] for item in features.values()
129128
]
130129
explain_dnns(datasource, estimator, shap_dataset, plot_type,
131-
result_table, feature_column_names, driver, conn, "", "",
132-
"", "", oss_dest, oss_ak, oss_sk, oss_endpoint,
133-
oss_bucket_name)
130+
result_table, feature_column_names, conn, oss_dest,
131+
oss_ak, oss_sk, oss_endpoint, oss_bucket_name)

python/runtime/pai/tensorflow/predict.py

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def _predict(datasource,
9797
model_params.update(feature_columns)
9898
is_estimator = is_tf_estimator(estimator)
9999

100-
driver = "paiio"
101100
conn = PaiIOConnection.from_table(pai_table)
102101
selected_cols = db.selected_cols(conn, None)
103102
predict_generator = db.db_generator(conn, None)
@@ -107,42 +106,16 @@ def _predict(datasource,
107106
# functional model need field_metas parameter
108107
model_params["field_metas"] = feature_metas
109108
print("Start predicting using keras model...")
110-
keras_predict(estimator,
111-
model_params,
112-
save,
113-
result_table,
114-
feature_column_names,
115-
feature_metas,
116-
train_label_name,
117-
result_col_name,
118-
driver,
119-
conn,
120-
predict_generator,
121-
selected_cols,
122-
hdfs_namenode_addr="",
123-
hive_location="",
124-
hdfs_user="",
125-
hdfs_pass="")
109+
keras_predict(estimator, model_params, save, result_table,
110+
feature_column_names, feature_metas, train_label_name,
111+
result_col_name, conn, predict_generator, selected_cols)
126112
else:
127113
model_params['model_dir'] = save
128114
print("Start predicting using estimator model...")
129-
estimator_predict(estimator,
130-
model_params,
131-
save,
132-
result_table,
133-
feature_column_names,
134-
feature_column_names_map,
135-
feature_columns,
136-
feature_metas,
137-
train_label_name,
138-
result_col_name,
139-
driver,
140-
conn,
141-
predict_generator,
142-
selected_cols,
143-
hdfs_namenode_addr="",
144-
hive_location="",
145-
hdfs_user="",
146-
hdfs_pass="")
115+
estimator_predict(estimator, model_params, save, result_table,
116+
feature_column_names, feature_column_names_map,
117+
feature_columns, feature_metas, train_label_name,
118+
result_col_name, conn, predict_generator,
119+
selected_cols)
147120

148121
print("Done predicting. Predict table : %s" % result_table)

python/runtime/pai/xgboost/evaluate.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ def evaluate(datasource, select, data_table, result_table, oss_model_path,
5353
label_meta=label_meta,
5454
result_table=result_table,
5555
validation_metrics=metrics,
56-
hdfs_namenode_addr="",
57-
hive_location="",
58-
hdfs_user="",
59-
hdfs_pass="",
6056
is_pai=True,
6157
pai_table=data_table,
6258
model_params=model_params,

python/runtime/pai/xgboost/explain.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,6 @@ def explain(datasource, select, data_table, result_table, label_column,
5555
result_table=result_table,
5656
is_pai=True,
5757
pai_explain_table=data_table,
58-
hdfs_namenode_addr="",
59-
hive_location="",
60-
hdfs_user="",
61-
hdfs_pass="",
6258
# (TODO:lhw) save/load explain result storage info into/from FLAGS
6359
oss_dest="",
6460
oss_ak="",

python/runtime/pai/xgboost/predict.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,6 @@ def predict(datasource, select, data_table, result_table, label_column,
5858
pred_label_meta=label_meta,
5959
result_table=result_table,
6060
is_pai=True,
61-
hdfs_namenode_addr="",
62-
hive_location="",
63-
hdfs_user="",
64-
hdfs_pass="",
6561
pai_table=data_table,
6662
model_params=model_params,
6763
train_params=train_params,

python/runtime/tensorflow/evaluate.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ def evaluate(datasource,
3636
save="",
3737
batch_size=1,
3838
validation_steps=None,
39-
verbose=0,
40-
hdfs_namenode_addr="",
41-
hive_location="",
42-
hdfs_user="",
43-
hdfs_pass=""):
39+
verbose=0):
4440
estimator_cls = import_model(estimator_string)
4541
is_estimator = is_tf_estimator(estimator_cls)
4642
set_log_level(verbose, is_estimator)
@@ -68,18 +64,10 @@ def evaluate(datasource,
6864

6965
# write result metrics to a table
7066
conn = connect_with_data_source(datasource)
71-
driver = conn.driver
7267
if result_table:
7368
metric_name_list = ["loss"] + validation_metrics
74-
write_result_metrics(result_metrics,
75-
metric_name_list,
76-
result_table,
77-
driver,
78-
conn,
79-
hdfs_namenode_addr=hdfs_namenode_addr,
80-
hive_location=hive_location,
81-
hdfs_user=hdfs_user,
82-
hdfs_pass=hdfs_pass)
69+
write_result_metrics(result_metrics, metric_name_list, result_table,
70+
conn)
8371

8472

8573
def estimator_evaluate(estimator, eval_dataset, validation_metrics):
@@ -152,9 +140,7 @@ def get_features(sample, label):
152140
return result_metrics
153141

154142

155-
def write_result_metrics(result_metrics, metric_name_list, result_table,
156-
driver, conn, hdfs_namenode_addr, hive_location,
157-
hdfs_user, hdfs_pass):
143+
def write_result_metrics(result_metrics, metric_name_list, result_table, conn):
158144
# NOTE: assume that the result table is already created with columns:
159145
# loss | metric_names ...
160146
column_names = metric_name_list

0 commit comments

Comments
 (0)