Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mmagic/datasets/transforms/random_degradations.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _apply_random_blur(self, imgs):
return imgs

def __call__(self, results):
"""Call this transform."""
if np.random.uniform() > self.params.get('prob', 1):
return results

Expand All @@ -147,6 +148,7 @@ def __call__(self, results):
return results

def __repr__(self):
"""Print the basic information of the transform."""
repr_str = self.__class__.__name__
repr_str += (f'(params={self.params}, keys={self.keys})')
return repr_str
Expand Down Expand Up @@ -208,6 +210,7 @@ def _apply_random_compression(self, imgs):
return outputs

def __call__(self, results):
"""Call this transform."""
if np.random.uniform() > self.params.get('prob', 1):
return results

Expand All @@ -217,6 +220,7 @@ def __call__(self, results):
return results

def __repr__(self):
"""Print the basic information of the transform."""
repr_str = self.__class__.__name__
repr_str += (f'(params={self.params}, keys={self.keys})')
return repr_str
Expand Down Expand Up @@ -329,6 +333,7 @@ def _apply_random_noise(self, imgs):
return imgs

def __call__(self, results):
"""Call this transform."""
if np.random.uniform() > self.params.get('prob', 1):
return results

Expand All @@ -338,6 +343,7 @@ def __call__(self, results):
return results

def __repr__(self):
"""Print the basic information of the transform."""
repr_str = self.__class__.__name__
repr_str += (f'(params={self.params}, keys={self.keys})')
return repr_str
Expand Down Expand Up @@ -443,6 +449,7 @@ def _random_resize(self, imgs):
return outputs

def __call__(self, results):
"""Call this transform."""
if np.random.uniform() > self.params.get('prob', 1):
return results

Expand All @@ -452,6 +459,7 @@ def __call__(self, results):
return results

def __repr__(self):
"""Print the basic information of the transform."""
repr_str = self.__class__.__name__
repr_str += (f'(params={self.params}, keys={self.keys})')
return repr_str
Expand Down Expand Up @@ -519,6 +527,7 @@ def _apply_random_compression(self, imgs):
return outputs

def __call__(self, results):
"""Call this transform."""
if np.random.uniform() > self.params.get('prob', 1):
return results

Expand All @@ -528,6 +537,7 @@ def __call__(self, results):
return results

def __repr__(self):
"""Print the basic information of the transform."""
repr_str = self.__class__.__name__
repr_str += (f'(params={self.params}, keys={self.keys})')
return repr_str
Expand Down Expand Up @@ -593,6 +603,7 @@ def _build_degradations(self, degradations):
return degradations

def __call__(self, results):
"""Call this transform."""
# shuffle degradations
if len(self.shuffle_idx) > 0:
shuffle_list = [self.degradations[i] for i in self.shuffle_idx]
Expand All @@ -611,6 +622,7 @@ def __call__(self, results):
return results

def __repr__(self):
"""Print the basic information of the transform."""
repr_str = self.__class__.__name__
repr_str += (f'(degradations={self.degradations}, '
f'keys={self.keys}, '
Expand Down
5 changes: 5 additions & 0 deletions mmagic/models/editors/deblurganv2/deblurganv2_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self, norm_layer, num_filter=256, pretrained='imagenet'):
param.requires_grad = False

def unfreeze(self):
"""Unfreeze params."""
for param in self.inception.parameters():
param.requires_grad = True

Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(self,
num_filter // 2, output_ch, kernel_size=3, padding=1)

def unfreeze(self):
"""Unfreeze params."""
self.fpn.unfreeze()

def forward(self, x):
Expand Down Expand Up @@ -256,6 +258,7 @@ def __init__(self, norm_layer, num_filters=256):
param.requires_grad = False

def unfreeze(self):
"""Unfreeze params."""
for param in self.inception.parameters():
param.requires_grad = True

Expand Down Expand Up @@ -338,6 +341,7 @@ def __init__(self,
num_filter // 2, output_ch, kernel_size=3, padding=1)

def unfreeze(self):
"""unfreeze the fpn network."""
self.fpn.unfreeze()

def forward(self, x):
Expand Down Expand Up @@ -423,6 +427,7 @@ def __init__(self, norm_layer, num_filters=128, pretrained=None):
param.requires_grad = False

def unfreeze(self):
"""Unfreeze params."""
for param in self.features.parameters():
param.requires_grad = True

Expand Down