@@ -34,7 +34,7 @@ num_boost_round = {{.NumBoostRound}}
3434maximize = True if "{{.Maximize}}" == "true" else False
3535early_stopping_rounds = {{.EarlyStoppingRounds}}
3636if early_stopping_rounds == -1:
37- early_stopping_rounds = None
37+ early_stopping_rounds = None
3838
3939{{if ne .ParamsCfgJSON ""}}
4040params = {{.ParamsCfgJSON}}
@@ -58,22 +58,20 @@ feature_specs["{{$value.FeatureName}}"] = {
5858}
5959{{end}}
6060
61-
62-
6361conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
6462
6563def xgb_dataset(fn, dataset_sql):
66- gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
67- with open(fn, 'w') as f:
68- for item in gen():
69- features, label = item
70- row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
71- f.write("\t".join(row_data) + "\n")
72- # TODO(yancey1989): genearte group and weight text file if necessary
73- return xgb.DMatrix(fn)
74-
75- dtrain = xgb_dataset('train.txt', "{{.TrainingDatasetSQL}}")
76- dtest = xgb_dataset('test.txt', "{{.ValidationDatasetSQL}}")
64+ gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
65+ with open(fn, 'w') as f:
66+ for item in gen():
67+ features, label = item
68+ row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
69+ f.write("\t".join(row_data) + "\n")
70+ # TODO(yancey1989): generate group and weight text file if necessary
71+ return xgb.DMatrix(fn)
72+
73+ dtrain = xgb_dataset('train.txt', """{{.TrainingDatasetSQL}}""")
74+ dtest = xgb_dataset('test.txt', """{{.ValidationDatasetSQL}}""")
7775
7876train_args = {}
7977train_args["num_boost_round"] = num_boost_round
@@ -84,3 +82,74 @@ train_args["evals"] = [(dtrain, "train"), (dtest, "validation")]
8482bst = xgb.train(params, dtrain, **train_args)
8583bst.save_model("{{.Save}}")
8684`
85+
// xgbPredictTemplateText is the text/template source of the Python program
// that runs XGBoost prediction. Once rendered, the program:
//
//  1. connects to the database described by .Driver, .Database, .User,
//     .Password, .Host, .Port and .Auth;
//  2. dumps the rows selected by .PredictionDatasetSQL into predict.txt and
//     loads that file with xgb.DMatrix;
//  3. loads the booster stored at .Save and predicts on the loaded data;
//  4. writes every feature row, followed by its predicted class, into
//     .TableName through buffered_db_writer.
//
// The template reads .Session (key/value pairs copied into session_cfg),
// .X (FeatureName, Dtype, Delimiter, InputShape, IsSparse per feature) and
// .Y.FeatureName (name of the result column appended to the feature columns).
const xgbPredictTemplateText = `
import xgboost as xgb
import numpy as np
from sqlflow_submitter.db import connect, db_generator, buffered_db_writer

driver="{{.Driver}}"

{{if ne .Database ""}}
database="{{.Database}}"
{{else}}
database=""
{{end}}

session_cfg = {}
{{ range $k, $v := .Session }}
session_cfg["{{$k}}"] = "{{$v}}"
{{end}}

feature_column_names = [{{range .X}}
"{{.FeatureName}}",
{{end}}]

{{/* Convert go side featureSpec to python dict for input_fn */}}
feature_specs = dict()
{{ range $value := .X }}
feature_specs["{{$value.FeatureName}}"] = {
    "feature_name": "{{$value.FeatureName}}",
    "dtype": "{{$value.Dtype}}",
    "delimiter": "{{$value.Delimiter}}",
    "shape": {{$value.InputShape}},
    "is_sparse": "{{$value.IsSparse}}" == "true"
}
{{end}}

conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")

def xgb_dataset(fn, dataset_sql):
    # Dump the rows selected by dataset_sql into the text file fn and return
    # it loaded as an xgb.DMatrix. Each line holds the label followed by one
    # 'index:value' field per feature, separated by tabs.
    #
    # NOTE(review): the label column name passed to db_generator is "";
    # confirm that it still yields a subscriptable label, because label[0]
    # is read for every row below.
    gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "", feature_specs)
    with open(fn, 'w') as f:
        for item in gen():
            features, label = item
            row_data = [str(label[0])] + ["%d:%f" % (i, v) for i, v in enumerate(features)]
            f.write("\t".join(row_data) + "\n")
    # TODO(yancey1989): generate group and weight text file if necessary
    return xgb.DMatrix(fn)

dpred = xgb_dataset('predict.txt', """{{.PredictionDatasetSQL}}""")

bst = xgb.Booster({'nthread': 4})  # init model
bst.load_model("{{.Save}}")  # load model
preds = bst.predict(dpred)
# TODO(typhoonzero): regression models may have different behavior
# argmax over axis=1 needs preds to be 2-D (one row of scores per sample).
pred_classes = np.argmax(np.array(preds), axis=1)

# Copy the list so that appending the result column does not also mutate
# feature_column_names.
result_column_names = list(feature_column_names)
result_column_names.append("{{.Y.FeatureName}}")

# Re-read the dumped feature file line by line; both context managers make
# sure the file handle and the writer are released even if a write fails.
with open("predict.txt", "r") as feature_file_read:
    with buffered_db_writer(driver, conn, "{{.TableName}}", result_column_names, 100) as w:
        for line_no, line in enumerate(feature_file_read):
            # Skip the leading label field and keep only the value part of
            # every 'index:value' pair.
            row = [i.split(":")[1] for i in line.rstrip("\n").split("\t")[1:]]
            row.append(pred_classes[line_no])
            w.write(row)
`
0 commit comments