Skip to content

Commit 4ad13bb

Browse files
authored
add the config of stage 1
1 parent 5631f68 commit 4ad13bb

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

lib/models/cls_hrnet.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,21 @@ def __init__(self, cfg, **kwargs):
263263
bias=False)
264264
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
265265
self.relu = nn.ReLU(inplace=True)
266-
self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
266+
267+
self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
268+
num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
269+
block = blocks_dict[self.stage1_cfg['BLOCK']]
270+
num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
271+
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
272+
stage1_out_channel = block.expansion*num_channels
267273

268274
self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
269275
num_channels = self.stage2_cfg['NUM_CHANNELS']
270276
block = blocks_dict[self.stage2_cfg['BLOCK']]
271277
num_channels = [
272278
num_channels[i] * block.expansion for i in range(len(num_channels))]
273279
self.transition1 = self._make_transition_layer(
274-
[256], num_channels)
280+
[stage1_out_channel], num_channels)
275281
self.stage2, pre_stage_channels = self._make_stage(
276282
self.stage2_cfg, num_channels)
277283

0 commit comments

Comments
 (0)