We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1782e05 commit 90a77c0Copy full SHA for 90a77c0
1 file changed
lib/models/cls_hrnet.py
@@ -474,9 +474,12 @@ def forward(self, x):
474
475
y = self.final_layer(y)
476
477
- y = F.avg_pool2d(y, kernel_size=y.size()
478
- [2:]).view(y.size(0), -1)
479
-
+ if torch._C._get_tracing_state():
+ y = y.flatten(start_dim=2).mean(dim=2)
+ else:
480
+ y = F.avg_pool2d(y, kernel_size=y.size()
481
+ [2:]).view(y.size(0), -1)
482
+
483
y = self.classifier(y)
484
485
return y
0 commit comments