Skip to content

Commit 90a77c0

Browse files
committed
Enable ONNX export
1 parent 1782e05 commit 90a77c0

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

lib/models/cls_hrnet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,9 +474,12 @@ def forward(self, x):
474474

475475
y = self.final_layer(y)
476476

477-
y = F.avg_pool2d(y, kernel_size=y.size()
478-
[2:]).view(y.size(0), -1)
479-
477+
if torch._C._get_tracing_state():
478+
y = y.flatten(start_dim=2).mean(dim=2)
479+
else:
480+
y = F.avg_pool2d(y, kernel_size=y.size()
481+
[2:]).view(y.size(0), -1)
482+
480483
y = self.classifier(y)
481484

482485
return y

0 commit comments

Comments
 (0)