Skip to content

Commit e9b5f7b

Browse files
authored
add temp_file.py (#2838)
1 parent 277ab0f commit e9b5f7b

6 files changed

Lines changed: 85 additions & 53 deletions

File tree

go/codegen/experimental/xgboost.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ const xgbTrainTemplate = `
141141
def step_entry_{{.StepIndex}}():
142142
import json
143143
import os
144-
import tempfile
144+
import runtime.temp_file as temp_file
145145
import runtime.feature.column as fc
146146
import runtime.feature.field_desc as fd
147147
import runtime.{{.Submitter}}.xgboost as xgboost_submitter
@@ -156,8 +156,7 @@ def step_entry_{{.StepIndex}}():
156156
model_params = json.loads('''{{.ModelParamsJSON}}''')
157157
train_params = json.loads('''{{.TrainParamsJSON}}''')
158158
159-
with tempfile.TemporaryDirectory() as temp_dir:
160-
os.chdir(temp_dir)
159+
with temp_file.TemporaryDirectory(as_cwd=True) as temp_dir:
161160
xgboost_submitter.train(original_sql='''{{.OriginalSQL}}''',
162161
model_image='''{{.ModelImage}}''',
163162
estimator='''{{.Estimator}}''',

python/runtime/local/xgboost/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
This module launches a XGBoost training task on host.
1515
"""
1616
import os
17-
import tempfile
1817
import types
1918

2019
import runtime.db as db
20+
import runtime.temp_file as temp_file
2121
import runtime.xgboost as xgboost_extended
2222
import xgboost as xgb
2323
from runtime.feature.compile import compile_ir_feature_columns
@@ -106,7 +106,7 @@ def build_dataset(fn, slct):
106106
else:
107107
bst = None
108108

109-
with tempfile.TemporaryDirectory() as tmp_dir_name:
109+
with temp_file.TemporaryDirectory() as tmp_dir_name:
110110
train_fn = os.path.join(tmp_dir_name, 'train.txt')
111111
val_fn = os.path.join(tmp_dir_name, 'val.txt')
112112
train_dataset = build_dataset(train_fn, select)

python/runtime/model/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
"""This module saves or loads the SQLFlow model.
1414
"""
1515
import os
16-
import tempfile
1716
from enum import Enum
1817

18+
import runtime.temp_file as temp_file
1919
from runtime.model import oss
2020
from runtime.model.db import read_with_generator, write_with_generator
2121
from runtime.model.tar import unzip_dir, zip_dir
@@ -143,7 +143,7 @@ def save_to_db(self, datasource, table, local_dir=None):
143143
if local_dir is None:
144144
local_dir = os.getcwd()
145145

146-
with tempfile.TemporaryDirectory() as tmp_dir:
146+
with temp_file.TemporaryDirectory() as tmp_dir:
147147
tarball = os.path.join(tmp_dir, TARBALL_NAME)
148148
self._zip(local_dir, tarball)
149149

@@ -178,7 +178,7 @@ def load_from_db(datasource, table, local_dir=None):
178178
if local_dir is None:
179179
local_dir = os.getcwd()
180180

181-
with tempfile.TemporaryDirectory() as tmp_dir:
181+
with temp_file.TemporaryDirectory() as tmp_dir:
182182
tarball = os.path.join(tmp_dir, TARBALL_NAME)
183183
gen = read_with_generator(datasource, table)
184184
with open(tarball, "wb") as f:
@@ -203,7 +203,7 @@ def save_to_oss(self, oss_model_dir, local_dir=None):
203203
if local_dir is None:
204204
local_dir = os.getcwd()
205205

206-
with tempfile.TemporaryDirectory() as tmp_dir:
206+
with temp_file.TemporaryDirectory() as tmp_dir:
207207
tarball = os.path.join(tmp_dir, TARBALL_NAME)
208208
self._zip(local_dir, tarball)
209209
oss.save_file(oss_model_dir, tarball, TARBALL_NAME)
@@ -225,7 +225,7 @@ def load_from_oss(oss_model_dir, local_dir=None):
225225
if local_dir is None:
226226
local_dir = os.getcwd()
227227

228-
with tempfile.TemporaryDirectory() as tmp_dir:
228+
with temp_file.TemporaryDirectory() as tmp_dir:
229229
tarball = os.path.join(tmp_dir, TARBALL_NAME)
230230
oss.load_file(oss_model_dir, tarball, TARBALL_NAME)
231231
return Model._unzip(local_dir, tarball)

python/runtime/model/model_test.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,27 @@
1212
# limitations under the License.
1313

1414
import os
15-
import tempfile
1615
import unittest
1716

1817
import runtime.model.oss as oss
18+
import runtime.temp_file as temp_file
1919
from runtime.model import EstimatorType, Model
2020
from runtime.testing import get_datasource
2121

2222

2323
class TestModel(unittest.TestCase):
24-
def setUp(self):
25-
self.cur_dir = os.getcwd()
26-
27-
def tearDown(self):
28-
os.chdir(self.cur_dir)
29-
3024
def test_save_load_db(self):
3125
table = "sqlflow_models.test_model"
3226
meta = {"model_params": {"n_classes": 3}}
3327
m = Model(EstimatorType.XGBOOST, meta)
3428
datasource = get_datasource()
3529

3630
# save mode
37-
with tempfile.TemporaryDirectory() as d:
31+
with temp_file.TemporaryDirectory() as d:
3832
m.save_to_db(datasource, table, d)
3933

4034
# load model
41-
with tempfile.TemporaryDirectory() as d:
35+
with temp_file.TemporaryDirectory() as d:
4236
m = Model.load_from_db(datasource, table, d)
4337
self.assertEqual(m._meta, meta)
4438

@@ -57,12 +51,12 @@ def test_save_load_oss(self):
5751

5852
# save model
5953
def save_to_oss():
60-
with tempfile.TemporaryDirectory() as d:
54+
with temp_file.TemporaryDirectory() as d:
6155
m.save_to_oss(oss_model_path, d)
6256

6357
# load model
6458
def load_from_oss():
65-
with tempfile.TemporaryDirectory() as d:
59+
with temp_file.TemporaryDirectory() as d:
6660
return Model.load_from_oss(oss_model_path, d)
6761

6862
with self.assertRaises(Exception):

python/runtime/model/tar_test.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,38 @@
1212
# limitations under the License.
1313

1414
import os
15-
import tempfile
1615
import unittest
1716

17+
import runtime.temp_file as temp_file
1818
from runtime.model.tar import unzip_dir, zip_dir
1919

2020

2121
class TestTarOperator(unittest.TestCase):
22-
def setUp(self):
23-
self.test_dir = tempfile.TemporaryDirectory()
24-
self.cur_dir = os.getcwd()
25-
os.chdir(self.test_dir.name)
26-
27-
def tearDown(self):
28-
self.test_dir.cleanup()
29-
os.chdir(self.cur_dir)
30-
3122
def test_tar(self):
32-
# create the test file tree:
33-
#
34-
# |-sqlflow_tar
35-
# |-sqlflow_sub_dir
36-
# |-hello.py
37-
test_dir = "sqlflow_tar"
38-
test_sub_dir = "sqlflow_sub_dir"
39-
test_py_file = "hello.py"
40-
test_py_content = "print('hello SQLFlow!')"
41-
42-
fullpath = os.path.join(test_dir, test_sub_dir)
43-
os.makedirs(fullpath)
44-
with open(os.path.join(fullpath, test_py_file), "w") as f:
45-
f.write(test_py_content)
46-
47-
zip_dir(fullpath, "sqlflow.tar.gz")
48-
unzip_dir("sqlflow.tar.gz", "output")
49-
self.assertTrue(os.path.isdir("output/sqlflow_tar/sqlflow_sub_dir"))
50-
self.assertTrue(
51-
os.path.isfile("output/sqlflow_tar/sqlflow_sub_dir/hello.py"))
52-
with open(os.path.join(fullpath, test_py_file), "r") as f:
53-
self.assertEqual(f.read(), test_py_content)
23+
with temp_file.TemporaryDirectory(as_cwd=True):
24+
# create the test file tree:
25+
#
26+
# |-sqlflow_tar
27+
# |-sqlflow_sub_dir
28+
# |-hello.py
29+
test_dir = "sqlflow_tar"
30+
test_sub_dir = "sqlflow_sub_dir"
31+
test_py_file = "hello.py"
32+
test_py_content = "print('hello SQLFlow!')"
33+
34+
fullpath = os.path.join(test_dir, test_sub_dir)
35+
os.makedirs(fullpath)
36+
with open(os.path.join(fullpath, test_py_file), "w") as f:
37+
f.write(test_py_content)
38+
39+
zip_dir(fullpath, "sqlflow.tar.gz")
40+
unzip_dir("sqlflow.tar.gz", "output")
41+
self.assertTrue(
42+
os.path.isdir("output/sqlflow_tar/sqlflow_sub_dir"))
43+
self.assertTrue(
44+
os.path.isfile("output/sqlflow_tar/sqlflow_sub_dir/hello.py"))
45+
with open(os.path.join(fullpath, test_py_file), "r") as f:
46+
self.assertEqual(f.read(), test_py_content)
5447

5548

5649
if __name__ == '__main__':

python/runtime/temp_file.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
import os
15+
import shutil
16+
import tempfile
17+
18+
19+
# NOTE: Python 2 does not have tempfile.TemporaryDirectory. To unify the code
20+
# of Python 2 and 3, we make the following class.
21+
class TemporaryDirectory(object):
22+
def __init__(self, as_cwd=False, suffix=None, prefix=None, dir=None):
23+
"""
24+
Create a temporary directory.
25+
26+
Args:
27+
as_cwd (bool): whether to change the current working directory
28+
as the created temporary directory.
29+
suffix (str): the suffix of the created temporary directory.
30+
prefix (str): the prefix of the created temporary directory.
31+
dir (str): where to create the temporary directory.
32+
"""
33+
self.tmp_dir = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
34+
self.as_cwd = as_cwd
35+
if self.as_cwd:
36+
self.old_dir = os.getcwd()
37+
38+
def __enter__(self, *args, **kwargs):
39+
if self.as_cwd:
40+
os.chdir(self.tmp_dir)
41+
return self.tmp_dir
42+
43+
def __exit__(self, *args, **kwargs):
44+
if self.as_cwd:
45+
os.chdir(self.old_dir)
46+
shutil.rmtree(self.tmp_dir, ignore_errors=True)

0 commit comments

Comments
 (0)