|
12 | 12 | # limitations under the License. |
13 | 13 |
|
14 | 14 | import os |
15 | | -import tempfile |
16 | 15 | import unittest |
17 | 16 |
|
| 17 | +import runtime.temp_file as temp_file |
18 | 18 | from runtime.model.tar import unzip_dir, zip_dir |
19 | 19 |
|
20 | 20 |
|
21 | 21 | class TestTarOperator(unittest.TestCase): |
22 | | - def setUp(self): |
23 | | - self.test_dir = tempfile.TemporaryDirectory() |
24 | | - self.cur_dir = os.getcwd() |
25 | | - os.chdir(self.test_dir.name) |
26 | | - |
27 | | - def tearDown(self): |
28 | | - self.test_dir.cleanup() |
29 | | - os.chdir(self.cur_dir) |
30 | | - |
31 | 22 | def test_tar(self): |
32 | | - # create the test file tree: |
33 | | - # |
34 | | - # |-sqlflow_tar |
35 | | - # |-sqlflow_sub_dir |
36 | | - # |-hello.py |
37 | | - test_dir = "sqlflow_tar" |
38 | | - test_sub_dir = "sqlflow_sub_dir" |
39 | | - test_py_file = "hello.py" |
40 | | - test_py_content = "print('hello SQLFlow!')" |
41 | | - |
42 | | - fullpath = os.path.join(test_dir, test_sub_dir) |
43 | | - os.makedirs(fullpath) |
44 | | - with open(os.path.join(fullpath, test_py_file), "w") as f: |
45 | | - f.write(test_py_content) |
46 | | - |
47 | | - zip_dir(fullpath, "sqlflow.tar.gz") |
48 | | - unzip_dir("sqlflow.tar.gz", "output") |
49 | | - self.assertTrue(os.path.isdir("output/sqlflow_tar/sqlflow_sub_dir")) |
50 | | - self.assertTrue( |
51 | | - os.path.isfile("output/sqlflow_tar/sqlflow_sub_dir/hello.py")) |
52 | | - with open(os.path.join(fullpath, test_py_file), "r") as f: |
53 | | - self.assertEqual(f.read(), test_py_content) |
| 23 | + with temp_file.TemporaryDirectory(as_cwd=True): |
| 24 | + # create the test file tree: |
| 25 | + # |
| 26 | + # |-sqlflow_tar |
| 27 | + # |-sqlflow_sub_dir |
| 28 | + # |-hello.py |
| 29 | + test_dir = "sqlflow_tar" |
| 30 | + test_sub_dir = "sqlflow_sub_dir" |
| 31 | + test_py_file = "hello.py" |
| 32 | + test_py_content = "print('hello SQLFlow!')" |
| 33 | + |
| 34 | + fullpath = os.path.join(test_dir, test_sub_dir) |
| 35 | + os.makedirs(fullpath) |
| 36 | + with open(os.path.join(fullpath, test_py_file), "w") as f: |
| 37 | + f.write(test_py_content) |
| 38 | + |
| 39 | + zip_dir(fullpath, "sqlflow.tar.gz") |
| 40 | + unzip_dir("sqlflow.tar.gz", "output") |
| 41 | + self.assertTrue( |
| 42 | + os.path.isdir("output/sqlflow_tar/sqlflow_sub_dir")) |
| 43 | + self.assertTrue( |
| 44 | + os.path.isfile("output/sqlflow_tar/sqlflow_sub_dir/hello.py")) |
| 45 | + with open(os.path.join(fullpath, test_py_file), "r") as f: |
| 46 | + self.assertEqual(f.read(), test_py_content) |
54 | 47 |
|
55 | 48 |
|
56 | 49 | if __name__ == '__main__': |
|
0 commit comments