Commit 26e7919
Parent(s): 1c48c2d

Upgrade the ways of importing timm modules (>=1.0.23).
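The upgrade below swaps the legacy pre-1.0 timm import paths for the current ones. As a minimal, version-tolerant sketch of the same idea (an illustration, not the committed code; it assumes only these three symbols are needed and that the removed lines used the standard pre-1.0 `timm.models.layers` path):

    # Hedged sketch: try the timm >= 1.0 location first, then fall back
    # to the pre-1.0 module path for older installations.
    try:
        from timm.layers import DropPath, to_2tuple, trunc_normal_
    except ImportError:
        from timm.models.layers import DropPath, to_2tuple, trunc_normal_

The commit itself pins the new `timm.layers` path directly, which matches its stated requirement of timm >= 1.0.23.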
birefnet.py CHANGED (+10 -16)

@@ -164,8 +164,8 @@ import torch
 import torch.nn as nn
 from functools import partial
 
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from timm.models.registry import register_model
+from timm.layers import DropPath, to_2tuple, trunc_normal_
+
 
 import math
 
@@ -545,7 +545,6 @@ def _conv_filter(state_dict, patch_size=16):
     return out_dict
 
 
-## @register_model
 class pvt_v2_b0(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b0, self).__init__(
@@ -555,7 +554,6 @@ class pvt_v2_b0(PyramidVisionTransformerImpr):
 
 
 
-## @register_model
 class pvt_v2_b1(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b1, self).__init__(
@@ -563,7 +561,6 @@ class pvt_v2_b1(PyramidVisionTransformerImpr):
             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0, drop_path_rate=0.1)
 
-## @register_model
 class pvt_v2_b2(PyramidVisionTransformerImpr):
     def __init__(self, in_channels=3, **kwargs):
         super(pvt_v2_b2, self).__init__(
@@ -571,7 +568,6 @@ class pvt_v2_b2(PyramidVisionTransformerImpr):
             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
 
-## @register_model
 class pvt_v2_b3(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b3, self).__init__(
@@ -579,7 +575,6 @@ class pvt_v2_b3(PyramidVisionTransformerImpr):
             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0, drop_path_rate=0.1)
 
-## @register_model
 class pvt_v2_b4(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b4, self).__init__(
@@ -588,7 +583,6 @@ class pvt_v2_b4(PyramidVisionTransformerImpr):
             drop_rate=0.0, drop_path_rate=0.1)
 
 
-## @register_model
 class pvt_v2_b5(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b5, self).__init__(
@@ -612,7 +606,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 import numpy as np
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 # from config import Config
 
@@ -1193,7 +1187,7 @@ class SwinTransformer(nn.Module):
             # interpolate the position embedding to the corresponding size
             absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
             x = (x + absolute_pos_embed) # B Wh*Ww C
-
+
         outs = []#x.contiguous()]
         x = x.flatten(2).transpose(1, 2)
         x = self.pos_drop(x)
@@ -1250,13 +1244,13 @@ class DeformableConv2d(nn.Module):
                  bias=False):
 
         super(DeformableConv2d, self).__init__()
-
+
         assert type(kernel_size) == tuple or type(kernel_size) == int
 
         kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
         self.stride = stride if type(stride) == tuple else (stride, stride)
         self.padding = padding
-
+
         self.offset_conv = nn.Conv2d(in_channels,
                                      2 * kernel_size[0] * kernel_size[1],
                                      kernel_size=kernel_size,
@@ -1266,7 +1260,7 @@ class DeformableConv2d(nn.Module):
 
         nn.init.constant_(self.offset_conv.weight, 0.)
         nn.init.constant_(self.offset_conv.bias, 0.)
-
+
         self.modulator_conv = nn.Conv2d(in_channels,
                                         1 * kernel_size[0] * kernel_size[1],
                                         kernel_size=kernel_size,
@@ -1290,7 +1284,7 @@ class DeformableConv2d(nn.Module):
 
         offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
         modulator = 2. * torch.sigmoid(self.modulator_conv(x))
-
+
         x = deform_conv2d(
             input=x,
             offset=offset,
@@ -1488,7 +1482,7 @@ class ResBlk(nn.Module):
 
         self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
         self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
-
+
         self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
 
     def forward(self, x):
@@ -2139,7 +2133,7 @@ class Decoder(nn.Module):
         self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
-
+
         self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
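Since the `## @register_model` decorators were already commented out, the accompanying registry import (removed in the first hunk) is dropped rather than migrated. If registration were ever revived, recent timm versions expose the decorator under `timm.models`; a hypothetical sketch (the factory name and wrapper are illustrative only, not part of this commit):

    from timm.models import register_model  # registry location in timm >= 1.0

    @register_model
    def pvt_v2_b2_example(**kwargs):
        # Hypothetical factory: timm registers named factory functions,
        # not the classes themselves.
        return pvt_v2_b2(**kwargs)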