|
| 1 | +# Copyright (c) MONAI Consortium |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import json |
| 15 | +import os |
| 16 | +import tempfile |
| 17 | +import unittest |
| 18 | +from pathlib import Path |
| 19 | + |
| 20 | +from parameterized import parameterized |
| 21 | + |
| 22 | +from monai.bundle import ConfigParser |
| 23 | +from monai.data import load_exported_program |
| 24 | +from monai.networks import save_state |
| 25 | +from tests.test_utils import command_line_tests, skip_if_windows |
| 26 | + |
| 27 | +TESTS_PATH = Path(__file__).parents[1] |
| 28 | + |
| 29 | +# key_in_ckpt |
| 30 | +TEST_CASE_1 = [""] |
| 31 | +TEST_CASE_2 = ["model"] |
| 32 | + |
| 33 | + |
| 34 | +@skip_if_windows |
| 35 | +class TestExportCheckpoint(unittest.TestCase): |
| 36 | + |
| 37 | + def setUp(self): |
| 38 | + self._orig_cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES") |
| 39 | + |
| 40 | + def tearDown(self): |
| 41 | + if self._orig_cuda_env is not None: |
| 42 | + os.environ["CUDA_VISIBLE_DEVICES"] = self._orig_cuda_env |
| 43 | + else: |
| 44 | + os.environ.pop("CUDA_VISIBLE_DEVICES", None) |
| 45 | + |
| 46 | + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) |
| 47 | + def test_export(self, key_in_ckpt): |
| 48 | + meta_file = os.path.join(TESTS_PATH, "testing_data", "metadata.json") |
| 49 | + config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json") |
| 50 | + with tempfile.TemporaryDirectory() as tempdir: |
| 51 | + def_args = {"meta_file": "will be replaced by `meta_file` arg"} |
| 52 | + def_args_file = os.path.join(tempdir, "def_args.yaml") |
| 53 | + |
| 54 | + ckpt_file = os.path.join(tempdir, "model.pt") |
| 55 | + pt2_file = os.path.join(tempdir, "model.pt2") |
| 56 | + |
| 57 | + parser = ConfigParser() |
| 58 | + parser.export_config_file(config=def_args, filepath=def_args_file) |
| 59 | + parser.read_config(config_file) |
| 60 | + net = parser.get_parsed_content("network_def") |
| 61 | + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) |
| 62 | + |
| 63 | + cmd = [ |
| 64 | + "coverage", "run", "-m", "monai.bundle", "export_checkpoint", |
| 65 | + "network_def", "--filepath", pt2_file, |
| 66 | + "--meta_file", meta_file, |
| 67 | + "--config_file", f"['{config_file}','{def_args_file}']", |
| 68 | + "--ckpt_file", ckpt_file, |
| 69 | + "--key_in_ckpt", key_in_ckpt, |
| 70 | + "--args_file", def_args_file, |
| 71 | + "--input_shape", "[1, 1, 96, 96, 96]", |
| 72 | + ] |
| 73 | + command_line_tests(cmd) |
| 74 | + self.assertTrue(os.path.exists(pt2_file)) |
| 75 | + |
| 76 | + _, _metadata, extra_files = load_exported_program( |
| 77 | + pt2_file, more_extra_files=["inference.json", "def_args.json"] |
| 78 | + ) |
| 79 | + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) |
| 80 | + self.assertIn("network_def", json.loads(extra_files["inference.json"])) |
| 81 | + |
| 82 | + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) |
| 83 | + def test_default_value(self, key_in_ckpt): |
| 84 | + config_file = os.path.join(TESTS_PATH, "testing_data", "inference.json") |
| 85 | + with tempfile.TemporaryDirectory() as tempdir: |
| 86 | + def_args = {"meta_file": "will be replaced by `meta_file` arg"} |
| 87 | + def_args_file = os.path.join(tempdir, "def_args.yaml") |
| 88 | + ckpt_file = os.path.join(tempdir, "models", "model.pt") |
| 89 | + pt2_file = os.path.join(tempdir, "models", "model.pt2") |
| 90 | + |
| 91 | + parser = ConfigParser() |
| 92 | + parser.export_config_file(config=def_args, filepath=def_args_file) |
| 93 | + parser.read_config(config_file) |
| 94 | + net = parser.get_parsed_content("network_def") |
| 95 | + save_state(src=net if key_in_ckpt == "" else {key_in_ckpt: net}, path=ckpt_file) |
| 96 | + |
| 97 | + # check with default value |
| 98 | + cmd = [ |
| 99 | + "coverage", "run", "-m", "monai.bundle", "export_checkpoint", |
| 100 | + "--key_in_ckpt", key_in_ckpt, |
| 101 | + "--config_file", config_file, |
| 102 | + "--bundle_root", tempdir, |
| 103 | + "--input_shape", "[1, 1, 96, 96, 96]", |
| 104 | + ] |
| 105 | + command_line_tests(cmd) |
| 106 | + self.assertTrue(os.path.exists(pt2_file)) |
| 107 | + |
| 108 | + |
| 109 | +if __name__ == "__main__": |
| 110 | + unittest.main() |
0 commit comments