Skip to content

Commit 670b7bc

Browse files
authored
Fix couler evaluate step (#2848)
* fix couler generate evaluation step * fix couler evaluate step * fix test
1 parent e9b5f7b commit 670b7bc

2 files changed

Lines changed: 12 additions & 7 deletions

File tree

go/cmd/sqlflowserver/e2e_workflow_test.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,30 @@ func CaseWorkflowTrainAndPredictDNN(t *testing.T) {
120120
sqlProgram := fmt.Sprintf(`
121121
SELECT * FROM %s LIMIT 10;
122122
123-
SELECT *
124-
FROM %s
123+
SELECT * FROM %s
125124
TO TRAIN DNNClassifier
126125
WITH
127126
model.n_classes = 3,
128127
model.hidden_units = [10, 20],
129128
validation.select = "SELECT * FROM %s"
130-
COLUMN sepal_length, sepal_width, petal_length, petal_width
131129
LABEL class
132130
INTO %s;
133131
134-
SELECT *
135-
FROM %s
132+
SELECT * FROM %s
133+
TO EVALUATE %s
134+
WITH validation.metrics="Accuracy"
135+
LABEL class
136+
INTO %s.sqlflow_iris_eval_result;
137+
138+
SELECT * FROM %s
136139
TO PREDICT %s.class
137140
USING %s;
138141
139142
SELECT *
140143
FROM %s LIMIT 5;
141-
`, caseTrainTable, caseTrainTable, caseTestTable, caseInto, caseTestTable, casePredictTable, caseInto, casePredictTable)
144+
`, caseTrainTable, caseTrainTable, caseTestTable, caseInto,
145+
caseTestTable, caseInto, caseDB,
146+
caseTestTable, casePredictTable, caseInto, casePredictTable)
142147

143148
conn, err := createRPCConn()
144149
if err != nil {

go/workflow/couler/codegen.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func GenFiller(programIR []ir.SQLFlowStmt, session *pb.Session) (*Filler, error)
137137

138138
for _, sqlIR := range programIR {
139139
switch i := sqlIR.(type) {
140-
case *ir.NormalStmt, *ir.PredictStmt, *ir.ExplainStmt:
140+
case *ir.NormalStmt, *ir.PredictStmt, *ir.ExplainStmt, *ir.EvaluateStmt:
141141
// TODO(typhoonzero): get model image used when training.
142142
sqlStmt := &sqlStatement{
143143
OriginalSQL: sqlIR.GetOriginalSQL(), IsExtendedSQL: sqlIR.IsExtended(),

0 commit comments

Comments
 (0)