ZhengPeng7 committed
Commit 26e7919 · Parent(s): 1c48c2d

Upgrade the way timm modules are imported (timm >= 1.0.23).
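
Context: in timm 1.x the layer helpers (DropPath, to_2tuple, trunc_normal_) moved from timm.models.layers to timm.layers, and the timm.models.registry path was deprecated and no longer resolves on recent releases, so the old imports fail. This commit switches to the new paths outright; a minimal sketch of a version-tolerant alternative follows (the try/except fallback is an illustration, not part of this commit):

    # Sketch: resolve the layer helpers across timm versions (illustration only,
    # not part of this commit).
    try:
        from timm.layers import DropPath, to_2tuple, trunc_normal_          # timm >= 1.x layout
    except ImportError:
        from timm.models.layers import DropPath, to_2tuple, trunc_normal_   # older timm releases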

Files changed (1):
  1. birefnet.py +10 -16
birefnet.py CHANGED
@@ -164,8 +164,8 @@ import torch
 import torch.nn as nn
 from functools import partial
 
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from timm.models.registry import register_model
+from timm.layers import DropPath, to_2tuple, trunc_normal_
+
 
 import math
 
@@ -545,7 +545,6 @@ def _conv_filter(state_dict, patch_size=16):
     return out_dict
 
 
-## @register_model
 class pvt_v2_b0(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b0, self).__init__(
@@ -555,7 +554,6 @@ class pvt_v2_b0(PyramidVisionTransformerImpr):
 
 
 
-## @register_model
 class pvt_v2_b1(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b1, self).__init__(
@@ -563,7 +561,6 @@ class pvt_v2_b1(PyramidVisionTransformerImpr):
             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0, drop_path_rate=0.1)
 
-## @register_model
 class pvt_v2_b2(PyramidVisionTransformerImpr):
     def __init__(self, in_channels=3, **kwargs):
         super(pvt_v2_b2, self).__init__(
@@ -571,7 +568,6 @@ class pvt_v2_b2(PyramidVisionTransformerImpr):
             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
 
-## @register_model
 class pvt_v2_b3(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b3, self).__init__(
@@ -579,7 +575,6 @@ class pvt_v2_b3(PyramidVisionTransformerImpr):
             qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0, drop_path_rate=0.1)
 
-## @register_model
 class pvt_v2_b4(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b4, self).__init__(
@@ -588,7 +583,6 @@ class pvt_v2_b4(PyramidVisionTransformerImpr):
             drop_rate=0.0, drop_path_rate=0.1)
 
 
-## @register_model
 class pvt_v2_b5(PyramidVisionTransformerImpr):
     def __init__(self, **kwargs):
         super(pvt_v2_b5, self).__init__(
@@ -612,7 +606,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 import numpy as np
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 # from config import Config
 
@@ -1193,7 +1187,7 @@ class SwinTransformer(nn.Module):
             # interpolate the position embedding to the corresponding size
             absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
             x = (x + absolute_pos_embed) # B Wh*Ww C
-            
+
         outs = []#x.contiguous()]
         x = x.flatten(2).transpose(1, 2)
         x = self.pos_drop(x)
@@ -1250,13 +1244,13 @@ class DeformableConv2d(nn.Module):
                  bias=False):
 
         super(DeformableConv2d, self).__init__()
-        
+
        assert type(kernel_size) == tuple or type(kernel_size) == int
 
        kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
        self.stride = stride if type(stride) == tuple else (stride, stride)
        self.padding = padding
-        
+
        self.offset_conv = nn.Conv2d(in_channels,
                                     2 * kernel_size[0] * kernel_size[1],
                                     kernel_size=kernel_size,
@@ -1266,7 +1260,7 @@ class DeformableConv2d(nn.Module):
 
        nn.init.constant_(self.offset_conv.weight, 0.)
        nn.init.constant_(self.offset_conv.bias, 0.)
-        
+
        self.modulator_conv = nn.Conv2d(in_channels,
                                        1 * kernel_size[0] * kernel_size[1],
                                        kernel_size=kernel_size,
@@ -1290,7 +1284,7 @@ class DeformableConv2d(nn.Module):
 
        offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
        modulator = 2. * torch.sigmoid(self.modulator_conv(x))
-        
+
        x = deform_conv2d(
            input=x,
            offset=offset,
@@ -1488,7 +1482,7 @@ class ResBlk(nn.Module):
 
        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
-        
+
        self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
 
    def forward(self, x):
@@ -2139,7 +2133,7 @@ class Decoder(nn.Module):
        self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
        self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
        self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
-        
+
        self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
        self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
        self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
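
Note on the registry: the deleted "from timm.models.registry import register_model" only served the already commented-out "## @register_model" decorators, so both go away cleanly. If registration were ever needed again, the decorator now lives under timm.models; a hedged sketch of the moved import (illustration only, not part of this commit):

    # Sketch: the model-registry decorator also moved in timm 1.x (illustration only).
    try:
        from timm.models import register_model           # current timm layout
    except ImportError:
        from timm.models.registry import register_model  # pre-1.x layout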
 
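
A quick way to confirm the upgrade took effect, assuming timm is installed (hypothetical sanity check, not part of the commit):

    # Hypothetical sanity check for the new import path.
    import timm
    from timm.layers import DropPath, to_2tuple, trunc_normal_

    print(timm.__version__)  # expect >= 1.0.23, per the commit message
    print(to_2tuple(7))      # (7, 7): the helpers resolve from the new location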