Skip to content

Commit 530fa0e

Browse files
authored
Merge pull request #13 from BowenBao/onnx
Enable ONNX export
2 parents 1782e05 + 90a77c0 commit 530fa0e

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

lib/models/cls_hrnet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,12 @@ def forward(self, x):
474474

475475
y = self.final_layer(y)
476476

477-
y = F.avg_pool2d(y, kernel_size=y.size()
478-
[2:]).view(y.size(0), -1)
479-
477+
if torch._C._get_tracing_state():
478+
y = y.flatten(start_dim=2).mean(dim=2)
479+
else:
480+
y = F.avg_pool2d(y, kernel_size=y.size()
481+
[2:]).view(y.size(0), -1)
482+
480483
y = self.classifier(y)
481484

482485
return y

0 commit comments

Comments
 (0)