Module qute.campaigns
Full training campaign definitions.
Classes
class CampaignTransforms
-
Abstract base class that defines all transforms needed for a full training campaign.
Expand source code
from abc import ABC, abstractmethod


class CampaignTransforms(ABC):
    """Abstract base class that defines all transforms needed for a full training campaign."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_train_transforms(self):
        """Return a composition of (dictionary) MapTransforms needed to train on a patch.

        These transforms are applied to the training dataset, to prepare the inputs
        to be fed into the model for the forward pass of training.
        """
        pass

    @abstractmethod
    def get_valid_transforms(self):
        """Return a composition of Transforms needed to validate on a patch.

        These transforms are applied to the validation dataset, to prepare the inputs
        to be fed into the model for validation.
        """
        pass

    @abstractmethod
    def get_test_transforms(self):
        """Return a composition of Transforms needed to test on a patch.

        These transforms are applied to the test dataset, to prepare the inputs
        to be fed into the model for testing.
        """
        pass

    @abstractmethod
    def get_inference_transforms(self):
        """Define inference transforms to predict on patch.

        These transforms are applied to the images that will go through full inference.
        Note that the patch size will be the same as for training; a sliding-window
        approach will be used to predict the whole image.
        """
        pass

    @abstractmethod
    def get_post_inference_transforms(self):
        """Define post inference transforms to apply after prediction on patch.

        These transforms are applied to the output of the full inference (that is
        the full images, not the patches) to prepare them to be saved to disk as
        final inference images.
        """
        pass

    @abstractmethod
    def get_post_full_inference_transforms(self):
        """Define post inference transforms to apply after reconstructed prediction on whole image.

        These transforms are applied to the images that have gone through full inference.
        They apply to the whole image as reconstructed by the sliding windows and perform
        whatever transform is necessary to create the final output to be saved to disk.
        """
        pass

    @abstractmethod
    def get_val_metrics_transforms(self):
        """Define default transforms for validation metric calculation on a patch.

        These transforms are applied to the validation dataset after the images have gone
        through the forward pass, to prepare the output -- if needed -- for the validation
        metrics to be applied.
        """
        pass

    @abstractmethod
    def get_test_metrics_transforms(self):
        """Define default transforms for testing metric calculation on a patch.

        These transforms are applied to the test dataset after the images have gone
        through the forward pass, to prepare the output -- if needed -- for the test
        metrics to be applied.
        """
        pass
Ancestors
- abc.ABC
Subclasses
- qute.campaigns._campaigns.RestorationCampaignTransforms
- qute.campaigns._campaigns.SegmentationCampaignTransforms2D
- qute.campaigns._campaigns.SegmentationCampaignTransforms3D
- qute.campaigns._campaigns.SegmentationCampaignTransformsIDT2D
- qute.campaigns._campaigns.SegmentationCampaignTransformsIDT3D
- qute.campaigns._campaigns.SelfSupervisedRestorationCampaignTransforms
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict on patch.
These transforms are applied to the images that will go through full inference. Note that the patch size will be the same as for training; a sliding-window approach will be used to predict the whole image.
def get_post_full_inference_transforms(self)
-
Define post inference transforms to apply after reconstructed prediction on whole image.
These transforms are applied to the images that have gone through full inference. They apply to the whole image as reconstructed by the sliding windows and perform whatever transform is necessary to create the final output to be saved to disk.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
These transforms are applied to the output of the full inference (that is the full images, not the patches) to prepare them to be saved to disk as final inference images.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation on a patch.
These transforms are applied to the test dataset after the images have gone through the forward pass, to prepare the output – if needed – for the test metrics to be applied.
def get_test_transforms(self)
-
Return a composition of Transforms needed to test on a patch.
These transforms are applied to the test dataset, to prepare the inputs to be fed into the model for testing.
def get_train_transforms(self)
-
Return a composition of (dictionary) MapTransforms needed to train on a patch.
These transforms are applied to the training dataset, to prepare the inputs to be fed into the model for the forward pass of training.
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation on a patch.
These transforms are applied to the validation dataset after the images have gone through the forward pass, to prepare the output – if needed – for the validation metrics to be applied.
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate on a patch.
These transforms are applied to the validation dataset, to prepare the inputs to be fed into the model for validation.
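The contract imposed by an abstract base class like this one can be sketched with the standard library alone. The classes below are hypothetical stand-ins (two abstract methods instead of the eight listed above) and do not import qute; they only illustrate that a subclass becomes instantiable once every abstract method is overridden.

```python
from abc import ABC, abstractmethod


class MiniCampaignTransforms(ABC):
    """Hypothetical two-method stand-in for CampaignTransforms."""

    @abstractmethod
    def get_train_transforms(self):
        ...

    @abstractmethod
    def get_valid_transforms(self):
        ...


class IncompleteCampaign(MiniCampaignTransforms):
    """Overrides only one of the two abstract methods."""

    def get_train_transforms(self):
        # Placeholder step names, not real transform objects.
        return ["load", "crop", "normalize"]


class CompleteCampaign(IncompleteCampaign):
    """Overrides the remaining abstract method, so it can be instantiated."""

    def get_valid_transforms(self):
        return ["load", "normalize"]


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abstract methods are still missing."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

A custom campaign would presumably follow the same pattern against the real base class, overriding all eight methods.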
class RestorationCampaignTransforms (min_intensity: int = 0, max_intensity: int = 65535, patch_size: tuple = (640, 640), num_patches: int = 1)
-
Example restoration campaign transforms.
Constructor.
By default, these transforms apply to a single-channel input image to predict a single-channel output.
Expand source code
class RestorationCampaignTransforms(CampaignTransforms): """Example restoration campaign transforms.""" def __init__( self, min_intensity: int = 0, max_intensity: int = 65535, patch_size: tuple = (640, 640), num_patches: int = 1, ): """Constructor. By default, these transforms apply to a single-channel input image to predict a single-channel output. """ super().__init__() self.min_intensity = min_intensity self.max_intensity = max_intensity self.patch_size = patch_size self.num_patches = num_patches def get_train_transforms(self): """Return a composition of Transforms needed to train (patch).""" train_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, ), RandSpatialCropd( keys=("image", "label"), roi_size=self.patch_size, random_size=False, ), MinMaxNormalized( keys=("image", "label"), min_intensity=self.min_intensity, max_intensity=self.max_intensity, ), ToPyTorchLightningOutputd( image_dtype=torch.float32, label_dtype=torch.float32, ), ] ) return train_transforms def get_valid_transforms(self): """Return a composition of Transforms needed to validate (patch).""" val_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, ), RandSpatialCropd( keys=("image", "label"), roi_size=self.patch_size, random_size=False, ), MinMaxNormalized( keys=("image", "label"), min_intensity=self.min_intensity, max_intensity=self.max_intensity, ), ToPyTorchLightningOutputd( image_dtype=torch.float32, label_dtype=torch.float32, ), ] ) return val_transforms def get_test_transforms(self): """Return a composition of Transforms needed to test (patch).""" return self.get_valid_transforms() def get_inference_transforms(self): """Define inference transforms to predict (patch).""" inference_transforms = Compose( [ CustomTIFFReader( ensure_channel_first=True, dtype=torch.float32, ), MinMaxNormalize( min_intensity=self.min_intensity, max_intensity=self.max_intensity ), ] ) 
return inference_transforms def get_post_inference_transforms(self): """Define post inference transforms to apply after prediction on patch.""" post_inference_transforms = Compose( [ Scale( factor=self.max_intensity, dtype=torch.int32, ) ] ) return post_inference_transforms def get_post_full_inference_transforms(self): """Define post full-inference transforms to apply after reconstructed prediction on whole image.""" return self.get_post_inference_transforms() def get_val_metrics_transforms(self): """Define default transforms for validation metric calculation (patch).""" val_metrics_transforms = Compose([]) return val_metrics_transforms def get_test_metrics_transforms(self): """Define default transforms for testing metric calculation (patch).""" return self.get_val_metrics_transforms()
Ancestors
- qute.campaigns._campaigns.CampaignTransforms
- abc.ABC
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict (patch).
def get_post_full_inference_transforms(self)
-
Define post full-inference transforms to apply after reconstructed prediction on whole image.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation (patch).
def get_test_transforms(self)
-
Return a composition of Transforms needed to test (patch).
def get_train_transforms(self)
-
Return a composition of Transforms needed to train (patch).
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation (patch).
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate (patch).
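The intensity handling of this class (normalize with min_intensity and max_intensity before the model, scale by max_intensity and cast to an integer type afterwards) can be illustrated with plain arithmetic. The sketch below assumes that the normalization is the usual linear min-max formula and that the integer cast truncates; neither is confirmed by the listing above, and the function names are hypothetical.

```python
def min_max_normalize(value, min_intensity=0, max_intensity=65535):
    """Map [min_intensity, max_intensity] linearly onto [0, 1] (assumed formula)."""
    return (value - min_intensity) / (max_intensity - min_intensity)


def scale_back(normalized, factor=65535):
    """Multiply by factor and truncate to an integer (assumed cast behavior)."""
    return int(normalized * factor)


# Round trip for a 16-bit pixel value with the default intensity range.
restored = scale_back(min_max_normalize(1000))
```

Under these assumptions the round trip can differ from the input by one intensity level, because floating-point rounding followed by truncation may land just below the original value.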
class SegmentationCampaignTransforms2D (num_classes: int = 3, patch_size: tuple = (640, 640), num_patches: int = 1)
-
Example segmentation campaign transforms.
Constructor.
By default, these transforms apply to a single-channel input image to predict three output classes.
Expand source code
class SegmentationCampaignTransforms2D(CampaignTransforms): """Example segmentation campaign transforms.""" def __init__( self, num_classes: int = 3, patch_size: tuple = (640, 640), num_patches: int = 1 ): """Constructor. By default, these transforms apply to a single-channel input image to predict three output classes. """ super().__init__() self.num_classes = num_classes self.patch_size = patch_size self.num_patches = num_patches def get_train_transforms(self): """Return a composition of Transforms needed to train (patch).""" train_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, ), RandRotated( keys=("image", "label"), prob=0.75, range_x=0.4, mode=["bilinear", "nearest"], padding_mode="reflection", ), RandCropByPosNegLabeld( keys=("image", "label"), label_key="label", spatial_size=self.patch_size, pos=1.0, neg=1.0, num_samples=self.num_patches, image_key="image", image_threshold=0.0, allow_smaller=False, lazy=False, ), ZNormalized(keys=("image",)), RandRotate90d(keys=("image", "label"), prob=0.5, spatial_axes=(-2, -1)), RandGaussianNoised(keys=("image",), prob=0.2), RandGaussianSmoothd(keys=("image",), prob=0.2), AsDiscreted(keys=["label"], to_onehot=self.num_classes), ToPyTorchLightningOutputd(), ] ) return train_transforms def get_valid_transforms(self): """Return a composition of Transforms needed to validate (patch).""" val_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, ), RandCropByPosNegLabeld( keys=("image", "label"), label_key="label", spatial_size=self.patch_size, pos=1.0, neg=1.0, num_samples=self.num_patches, image_key="image", image_threshold=0.0, allow_smaller=False, lazy=False, ), ZNormalized(keys=("image",)), AsDiscreted(keys=["label"], to_onehot=self.num_classes), ToPyTorchLightningOutputd(), ] ) return val_transforms def get_test_transforms(self): """Return a composition of Transforms needed to test (patch).""" 
return self.get_valid_transforms() def get_inference_transforms(self): """Define inference transforms to predict (patch).""" inference_transforms = Compose( [ CustomTIFFReader( ensure_channel_first=True, dtype=torch.float32, ), ZNormalize(), ] ) return inference_transforms def get_post_inference_transforms(self): """Define post inference transforms to apply after prediction on patch.""" post_inference_transforms = Compose([OneHotToMaskBatch()]) return post_inference_transforms def get_post_full_inference_transforms(self): """Define post full-inference transforms to apply after reconstructed prediction on whole image.""" return self.get_post_inference_transforms() def get_val_metrics_transforms(self): """Define default transforms for validation metric calculation (patch).""" val_metrics_transforms = Compose( [Activations(sigmoid=True), AsDiscrete(threshold=0.5)] ) return val_metrics_transforms def get_test_metrics_transforms(self): """Define default transforms for testing metric calculation (patch).""" return self.get_val_metrics_transforms()
Ancestors
- qute.campaigns._campaigns.CampaignTransforms
- abc.ABC
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict (patch).
def get_post_full_inference_transforms(self)
-
Define post full-inference transforms to apply after reconstructed prediction on whole image.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation (patch).
def get_test_transforms(self)
-
Return a composition of Transforms needed to test (patch).
def get_train_transforms(self)
-
Return a composition of Transforms needed to train (patch).
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation (patch).
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate (patch).
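The metrics transforms of this class chain a sigmoid activation with a 0.5 threshold. A minimal pure-Python sketch of that chain on a list of raw model outputs follows; the helper names are hypothetical, and treating values exactly at the threshold as foreground is an assumption.

```python
import math


def sigmoid(x):
    """Logistic function mapping a raw score to the interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def binarize_logits(logits, threshold=0.5):
    """Apply the sigmoid, then mark values at or above the threshold as 1."""
    return [1 if sigmoid(v) >= threshold else 0 for v in logits]


# Negative logits fall below 0.5 after the sigmoid, positive ones above it.
mask = binarize_logits([-2.0, 0.5, 3.0])
```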
class SegmentationCampaignTransforms3D (num_classes: int = 3, patch_size: tuple = (20, 300, 300), num_patches: int = 1, voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), to_isotropic: bool = False, upscale_z: bool = True)
-
Example 3D segmentation campaign transforms.
Constructor.
By default, these transforms apply to a single-channel input image to predict three output classes.
Parameters
num_classes
:int = 3
- Number of output classes to predict.
patch_size
:tuple = (20, 300, 300)
- Patch size to pass through the neural network.
num_patches
:int = 1
- Number of patches per image to extract.
voxel_size
:Optional[tuple] = (1.0, 1.0, 1.0)
- Voxel size to use for setting the metadata of the image. Omit to set to (1.0, 1.0, 1.0).
to_isotropic
:bool = False
- Set to True to resample the image to near-isotropic XYZ resolution. Ignored if all voxel sizes are the same.
upscale_z
:bool = True
- Only considered if to_isotropic is True. If True, interpolate z to reach the resolution of x and y. If False, sub-sample x and y to reach the resolution of z. In both cases, the z resolution is assumed to be worse than the x and y resolution.
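How to_isotropic and upscale_z determine the resampling target can be sketched without NumPy. The function below is a hypothetical re-statement of the constructor logic shown in the source listing, assuming the voxel size is ordered (z, y, x); it mirrors the listing's slice [1:2], which holds a single in-plane entry.

```python
def target_voxel_size(voxel_size, to_isotropic=False, upscale_z=True):
    """Return the voxel size the images would be resampled to, ordered (z, y, x)."""
    target = list(voxel_size)
    if to_isotropic:
        if upscale_z:
            # In-plane sizes stay untouched; z takes the in-plane value.
            # The slice [1:2] mirrors the listing and contains one entry only.
            in_plane = target[1:2]
            target[0] = sum(in_plane) / len(in_plane)
        else:
            # z stays untouched; the in-plane sizes are coarsened to match z.
            target[1:] = [target[0]] * len(target[1:])
    return target


upscaled = target_voxel_size((2.0, 0.5, 0.5), to_isotropic=True, upscale_z=True)
downscaled = target_voxel_size((2.0, 0.5, 0.5), to_isotropic=True, upscale_z=False)
```

With a z spacing of 2.0 and an in-plane spacing of 0.5, upscaling z yields an isotropic 0.5 grid, while sub-sampling x and y yields an isotropic 2.0 grid.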
Expand source code
class SegmentationCampaignTransforms3D(CampaignTransforms): """Example 3D segmentation campaign transforms.""" def __init__( self, num_classes: int = 3, patch_size: tuple = (20, 300, 300), num_patches: int = 1, voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), to_isotropic: bool = False, upscale_z: bool = True, ): """Constructor. By default, these transforms apply to a single-channel input image to predict three output classes. PARAMETERS ---------- num_classes: int = 3 Number ouf output classes to predict. patch_size: tuple = (20, 300, 300) Patch size to pass through the neural network. num_patches: int = 1 Number of patches per image to extract. voxel_size: Optional[tuple] = (1.0, 1.0, 1.0) Voxel size to use for setting the metadata of the image. Omit to set to (1.0, 1.0, 1.0). to_isotropic: bool (Optional, False) Se to True to resample the image to near-isotropic XYZ resolution. Ignored if all voxel sizes are the same. upscale_z: bool = True Only considered it `to_isotropic` is True. If True, interpolate z to reach the resolution of x and y. Please notice that it is assumed that the z resolution is worse than the x and y resolution. If False, sub-sample x and y to reach the resolution of z. Please notice that it is assumed that the z resolution is worse than the x and y resolution. """ super().__init__() self.num_classes = num_classes self.patch_size = patch_size self.num_patches = num_patches self.voxel_size = np.array(voxel_size) self.target_voxel_size = self.voxel_size.copy() self.to_isotropic = to_isotropic self.upscale_z = upscale_z if self.to_isotropic: # Should we upscale the image to keep the higher resolution, or downscale it # to preserve the lowest resolution? 
if self.upscale_z: # x and y are left untouched; z is scaled up to # match (rounded) anisotropic resolution self.target_voxel_size[0] = self.target_voxel_size[1:2].mean() else: # z is left untouched; x and y are scaled down # to match (rounded) anisotropic resolution self.target_voxel_size[1:] = self.target_voxel_size[0] def get_train_transforms(self): """Return a composition of Transforms needed to train (patch).""" train_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, as_meta_tensor=True, voxel_size=self.voxel_size, ), CustomResamplerd( keys=("image", "label"), target_voxel_size=self.target_voxel_size, input_voxel_size=self.voxel_size, mode=("trilinear", "nearest"), ), RandCropByPosNegLabeld( keys=("image", "label"), label_key="label", spatial_size=self.patch_size, pos=1.0, neg=1.0, num_samples=self.num_patches, image_key="image", image_threshold=0.0, allow_smaller=False, lazy=False, ), LabelToTwoClassMaskd(keys=("label",), border_thickness=1), ZNormalized(keys=("image",)), RandGaussianNoised(keys=("image",), prob=0.2), RandGaussianSmoothd(keys=("image",), prob=0.2), AsDiscreted(keys=["label"], to_onehot=self.num_classes), ToPyTorchLightningOutputd(), ] ) return train_transforms def get_valid_transforms(self): """Return a composition of Transforms needed to validate (patch).""" val_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, as_meta_tensor=True, voxel_size=self.voxel_size, ), CustomResamplerd( keys=("image", "label"), target_voxel_size=self.target_voxel_size, input_voxel_size=self.voxel_size, mode=("trilinear", "nearest"), ), RandCropByPosNegLabeld( keys=("image", "label"), label_key="label", spatial_size=self.patch_size, pos=1.0, neg=1.0, num_samples=self.num_patches, image_key="image", image_threshold=0.0, allow_smaller=False, lazy=False, ), LabelToTwoClassMaskd(keys=("label",), border_thickness=1), 
ZNormalized(keys=("image",)), AsDiscreted(keys=["label"], to_onehot=self.num_classes), ToPyTorchLightningOutputd(), ] ) return val_transforms def get_test_transforms(self): """Return a composition of Transforms needed to test (patch).""" return self.get_valid_transforms() def get_inference_transforms(self): """Define inference transforms to predict (patch).""" inference_transforms = Compose( [ CustomTIFFReader( ensure_channel_first=True, dtype=torch.float32, as_meta_tensor=True, voxel_size=self.voxel_size, ), CustomResampler( target_voxel_size=self.target_voxel_size, input_voxel_size=self.voxel_size, mode="trilinear", ), ZNormalize(), ] ) return inference_transforms def get_post_inference_transforms(self): """Define post inference transforms to apply after prediction on patch.""" post_inference_transforms = Compose( [ OneHotToMaskBatch(), ] ) return post_inference_transforms def get_post_full_inference_transforms(self): """Define post full-inference transforms to apply after prediction on patch.""" if self.to_isotropic: post_full_inference_transforms = Compose( [ OneHotToMaskBatch(), CustomResampler( target_voxel_size=self.voxel_size, input_voxel_size=self.target_voxel_size, mode="nearest", with_batch_dim=True, ), ] ) else: post_full_inference_transforms = Compose( [ OneHotToMaskBatch(), ] ) return post_full_inference_transforms def get_val_metrics_transforms(self): """Define default transforms for validation metric calculation (patch).""" val_metrics_transforms = Compose( [Activations(sigmoid=True), AsDiscrete(threshold=0.5)] ) return val_metrics_transforms def get_test_metrics_transforms(self): """Define default transforms for testing metric calculation (patch).""" return self.get_val_metrics_transforms()
Ancestors
- qute.campaigns._campaigns.CampaignTransforms
- abc.ABC
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict (patch).
def get_post_full_inference_transforms(self)
-
Define post full-inference transforms to apply after reconstructed prediction on whole image.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation (patch).
def get_test_transforms(self)
-
Return a composition of Transforms needed to test (patch).
def get_train_transforms(self)
-
Return a composition of Transforms needed to train (patch).
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation (patch).
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate (patch).
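The segmentation classes one-hot encode labels for training (to_onehot=num_classes) and turn the multi-channel prediction back into a label mask after inference. The sketch below shows both directions on a flat list of pixels. It assumes that the conversion back to a mask is a per-pixel argmax over the class channels, which the listing above does not confirm, and both helper names are hypothetical.

```python
def one_hot(labels, num_classes):
    """Return channel-first one-hot channels: result[c][i] is 1 if labels[i] == c."""
    return [[1 if label == c else 0 for label in labels] for c in range(num_classes)]


def channels_to_mask(scores):
    """Return, for each pixel, the index of the channel with the highest score."""
    num_pixels = len(scores[0])
    return [
        max(range(len(scores)), key=lambda c: scores[c][i])
        for i in range(num_pixels)
    ]


encoded = one_hot([0, 2, 1], num_classes=3)
decoded = channels_to_mask(encoded)
```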
class SegmentationCampaignTransformsIDT2D (patch_size: tuple = (640, 640), num_patches: int = 1)
-
Example 2D segmentation campaign transforms using regression to Inverse Distance Transform.
Constructor.
By default, these transforms apply to a single-channel input image to predict three output classes.
Expand source code
class SegmentationCampaignTransformsIDT2D(CampaignTransforms): """Example 2D segmentation campaign transforms using regression to Inverse Distance Transform.""" def __init__(self, patch_size: tuple = (640, 640), num_patches: int = 1): """Constructor. By default, these transforms apply to a single-channel input image to predict three output classes. """ super().__init__() self.patch_size = patch_size self.num_patches = num_patches def get_train_transforms(self): """Return a composition of Transforms needed to train (patch).""" train_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, ), RandCropByPosNegLabeld( keys=("image", "label"), label_key="label", spatial_size=self.patch_size, pos=1.0, neg=0.0, num_samples=self.num_patches, image_key="image", image_threshold=0.0, allow_smaller=False, lazy=False, ), ZNormalized(keys=("image",)), RandRotate90d(keys=("image", "label"), prob=0.5, spatial_axes=(-2, -1)), RandGaussianNoised(keys=("image",), prob=0.2), RandGaussianSmoothd(keys=("image",), prob=0.2), NormalizedDistanceTransformd( keys=("label",), reverse=True, do_not_zero=True, add_seed_channel=True, seed_radius=2, ), ToPyTorchLightningOutputd(label_key="label", label_dtype=torch.float32), ] ) return train_transforms def get_valid_transforms(self): """Return a composition of Transforms needed to validate (patch).""" val_transforms = Compose( [ CustomTIFFReaderd( keys=("image", "label"), ensure_channel_first=True, dtype=torch.float32, ), RandCropByPosNegLabeld( keys=("image", "label"), label_key="label", spatial_size=self.patch_size, pos=1.0, neg=0.0, num_samples=self.num_patches, image_key="image", image_threshold=0.0, allow_smaller=False, lazy=False, ), ZNormalized(keys=("image",)), NormalizedDistanceTransformd( keys=("label",), reverse=True, do_not_zero=True, add_seed_channel=True, seed_radius=2, ), ToPyTorchLightningOutputd(label_key="label", label_dtype=torch.float32), ] ) return val_transforms def 
get_test_transforms(self): """Return a composition of Transforms needed to test (patch).""" return self.get_valid_transforms() def get_inference_transforms(self): """Define inference transforms to predict (patch).""" inference_transforms = Compose( [ CustomTIFFReader( ensure_channel_first=True, dtype=torch.float32, ), ZNormalize(), ] ) return inference_transforms def get_post_inference_transforms(self): """Define post inference transforms to apply after prediction on patch.""" post_inference_transforms = Compose( [ WatershedAndLabelTransform( use_seed_channel=True, dt_threshold=0.02, with_batch_dim=True ) ] ) return post_inference_transforms def get_post_full_inference_transforms(self): """Define post full-inference transforms to apply after reconstructed prediction on whole image.""" return self.get_post_inference_transforms() def get_val_metrics_transforms(self): """Define default transforms for validation metric calculation (patch).""" val_metrics_transforms = Compose([]) return val_metrics_transforms def get_test_metrics_transforms(self): """Define default transforms for testing metric calculation (patch).""" return self.get_val_metrics_transforms()
Ancestors
- qute.campaigns._campaigns.CampaignTransforms
- abc.ABC
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict (patch).
def get_post_full_inference_transforms(self)
-
Define post full-inference transforms to apply after reconstructed prediction on whole image.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation (patch).
def get_test_transforms(self)
-
Return a composition of Transforms needed to test (patch).
def get_train_transforms(self)
-
Return a composition of Transforms needed to train (patch).
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation (patch).
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate (patch).
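The pos and neg arguments passed to RandCropByPosNegLabeld control how often a patch is centered on a foreground pixel: MONAI documents the probability as pos / (pos + neg). The helper below is a hypothetical illustration of that ratio, comparing the values used by this class (pos=1.0, neg=0.0) with those of SegmentationCampaignTransforms2D (pos=1.0, neg=1.0).

```python
def foreground_center_probability(pos, neg):
    """Probability that a random crop is centered on a foreground pixel."""
    if pos < 0 or neg < 0:
        raise ValueError("pos and neg must be non-negative")
    if pos + neg == 0:
        raise ValueError("pos and neg cannot both be zero")
    return pos / (pos + neg)


idt_probability = foreground_center_probability(1.0, 0.0)
segmentation_probability = foreground_center_probability(1.0, 1.0)
```

With neg=0.0 every patch is centered on a labeled object, which seems a sensible choice when the target is a distance transform defined on the objects themselves.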
class SegmentationCampaignTransformsIDT3D (patch_size: tuple = (20, 300, 300), num_patches: int = 1, voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), to_isotropic: bool = False, upscale_z: bool = True)
-
Example 3D segmentation campaign transforms using regression to Inverse Distance Transform.
Constructor.
By default, these transforms apply to a single-channel input image to predict three output classes.
Parameters
patch_size
:tuple = (20, 300, 300)
- Patch size to pass through the neural network.
num_patches
:int = 1
- Number of patches per image to extract.
voxel_size
:Optional[tuple] = (1.0, 1.0, 1.0)
- Voxel size to use for setting the metadata of the image. Omit to set to (1.0, 1.0, 1.0).
to_isotropic
:bool = False
- Set to True to resample the image to near-isotropic XYZ resolution. Ignored if all voxel sizes are the same.
upscale_z
:bool = True
- Only considered if to_isotropic is True. If True, interpolate z to reach the resolution of x and y. If False, sub-sample x and y to reach the resolution of z. In both cases, the z resolution is assumed to be worse than the x and y resolution.
Expand source code
class SegmentationCampaignTransformsIDT3D(CampaignTransforms): """Example 3D segmentation campaign transforms using regression to Inverse Distance Transform.""" def __init__( self, patch_size: tuple = (20, 300, 300), num_patches: int = 1, voxel_size: tuple[float, float, float] = (1.0, 1.0, 1.0), to_isotropic: bool = False, upscale_z: bool = True, ): """Constructor. By default, these transforms apply to a single-channel input image to predict three output classes. PARAMETERS ---------- patch_size: tuple = (20, 300, 300) Patch size to pass through the neural network. num_patches: int = 1 Number of patches per image to extract. voxel_size: Optional[tuple] = (1.0, 1.0, 1.0) Voxel size to use for setting the metadata of the image. Omit to set to (1.0, 1.0, 1.0). to_isotropic: bool (Optional, False) Se to True to resample the image to near-isotropic XYZ resolution. Ignored if all voxel sizes are the same. upscale_z: bool = True Only considered it `to_isotropic` is True. If True, interpolate z to reach the resolution of x and y. Please notice that it is assumed that the z resolution is worse than the x and y resolution. If False, sub-sample x and y to reach the resolution of z. Please notice that it is assumed that the z resolution is worse than the x and y resolution. """ super().__init__() self.patch_size = patch_size self.num_patches = num_patches self.voxel_size = np.array(voxel_size) self.target_voxel_size = self.voxel_size.copy() self.to_isotropic = to_isotropic self.upscale_z = upscale_z if self.to_isotropic: # Should we upscale the image to keep the higher resolution, or downscale it # to preserve the lowest resolution? 
            if self.upscale_z:
                # x and y are left untouched; z is scaled up to
                # match (rounded) anisotropic resolution
                self.target_voxel_size[0] = self.target_voxel_size[1:2].mean()
            else:
                # z is left untouched; x and y are scaled down
                # to match (rounded) anisotropic resolution
                self.target_voxel_size[1:] = self.target_voxel_size[0]

    def get_train_transforms(self):
        """Return a composition of Transforms needed to train (patch)."""
        train_transforms = Compose(
            [
                CustomTIFFReaderd(
                    keys=("image", "label"),
                    ensure_channel_first=True,
                    dtype=torch.float32,
                    as_meta_tensor=True,
                    voxel_size=self.voxel_size,
                ),
                CustomResamplerd(
                    keys=("image", "label"),
                    target_voxel_size=self.target_voxel_size,
                    input_voxel_size=self.voxel_size,
                    mode=("trilinear", "nearest"),
                ),
                RandCropByPosNegLabeld(
                    keys=("image", "label"),
                    label_key="label",
                    spatial_size=self.patch_size,
                    pos=1.0,
                    neg=1.0,
                    num_samples=self.num_patches,
                    image_key="image",
                    image_threshold=0.0,
                    allow_smaller=False,
                    lazy=False,
                ),
                ZNormalized(keys=("image",)),
                RandGaussianNoised(keys=("image",), prob=0.2),
                RandGaussianSmoothd(keys=("image",), prob=0.2),
                NormalizedDistanceTransformd(
                    keys=("label",),
                    reverse=True,
                    do_not_zero=True,
                    add_seed_channel=True,
                    seed_radius=2,
                ),
                ToPyTorchLightningOutputd(label_key="label", label_dtype=torch.float32),
            ]
        )
        return train_transforms

    def get_valid_transforms(self):
        """Return a composition of Transforms needed to validate (patch)."""
        val_transforms = Compose(
            [
                CustomTIFFReaderd(
                    keys=("image", "label"),
                    ensure_channel_first=True,
                    dtype=torch.float32,
                    as_meta_tensor=True,
                    voxel_size=self.voxel_size,
                ),
                CustomResamplerd(
                    keys=("image", "label"),
                    target_voxel_size=self.target_voxel_size,
                    input_voxel_size=self.voxel_size,
                    mode=("trilinear", "nearest"),
                ),
                RandCropByPosNegLabeld(
                    keys=("image", "label"),
                    label_key="label",
                    spatial_size=self.patch_size,
                    pos=1.0,
                    neg=1.0,
                    num_samples=self.num_patches,
                    image_key="image",
                    image_threshold=0.0,
                    allow_smaller=False,
                    lazy=False,
                ),
                ZNormalized(keys=("image",)),
                NormalizedDistanceTransformd(
                    keys=("label",),
                    reverse=True,
                    do_not_zero=True,
                    add_seed_channel=True,
                    seed_radius=2,
                ),
                ToPyTorchLightningOutputd(),
            ]
        )
        return val_transforms

    def get_test_transforms(self):
        """Return a composition of Transforms needed to test (patch)."""
        return self.get_valid_transforms()

    def get_inference_transforms(self):
        """Define inference transforms to predict (patch)."""
        inference_transforms = Compose(
            [
                CustomTIFFReader(
                    ensure_channel_first=True,
                    dtype=torch.float32,
                    as_meta_tensor=True,
                    voxel_size=self.voxel_size,
                ),
                CustomResampler(
                    target_voxel_size=self.target_voxel_size,
                    input_voxel_size=self.voxel_size,
                    mode="trilinear",
                ),
                ZNormalize(),
            ]
        )
        return inference_transforms

    def get_post_inference_transforms(self):
        """Define post inference transforms to apply after prediction on patch."""
        post_inference_transforms = Compose(
            [
                WatershedAndLabelTransform(
                    use_seed_channel=True, dt_threshold=0.05, with_batch_dim=True
                ),
            ]
        )
        return post_inference_transforms

    def get_post_full_inference_transforms(self):
        """Define post full-inference transforms to apply after reconstructed prediction on whole image."""
        if self.to_isotropic:
            post_full_inference_transforms = Compose(
                [
                    # DebugExtractChannel(
                    #     channel_num=0,
                    #     mask=False,
                    #     to_binary=False
                    # ),
                    WatershedAndLabelTransform(
                        use_seed_channel=True, dt_threshold=0.02, with_batch_dim=True
                    ),
                    CustomResampler(
                        target_voxel_size=self.voxel_size,
                        input_voxel_size=self.target_voxel_size,
                        mode="nearest",
                        with_batch_dim=True,
                    ),
                ]
            )
        else:
            post_full_inference_transforms = Compose(
                [
                    WatershedAndLabelTransform(
                        use_seed_channel=True, dt_threshold=0.02, with_batch_dim=True
                    )
                ]
            )
        return post_full_inference_transforms

    def get_val_metrics_transforms(self):
        """Define default transforms for validation metric calculation (patch)."""
        val_metrics_transforms = Compose([])
        return val_metrics_transforms

    def get_test_metrics_transforms(self):
        """Define default transforms for testing metric calculation (patch)."""
        return self.get_val_metrics_transforms()
Ancestors
- qute.campaigns._campaigns.CampaignTransforms
- abc.ABC
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict (patch).
def get_post_full_inference_transforms(self)
-
Define post full-inference transforms to apply after reconstructed prediction on whole image.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation (patch).
def get_test_transforms(self)
-
Return a composition of Transforms needed to test (patch).
def get_train_transforms(self)
-
Return a composition of Transforms needed to train (patch).
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation (patch).
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate (patch).
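The voxel-size adjustment performed in the constructor of the class documented above can be illustrated in isolation. The following is a minimal sketch in plain Python, not the class's own code: the function name `isotropic_target_voxel_size` is hypothetical, the (z, y, x) axis order is inferred from the source comments, and the starting target size is assumed to be a copy of the input voxel size. The sketch averages both in-plane spacings, whereas the source slice `[1:2]` selects only the element at index 1; the two agree whenever y and x share the same spacing.

```python
def isotropic_target_voxel_size(voxel_size, upscale_z=True):
    """Return a (z, y, x) target voxel size with equal spacing on all axes.

    Hypothetical helper for illustration only; assumes index 0 is z.
    """
    target = [float(v) for v in voxel_size]
    if upscale_z:
        # y and x are left untouched; z is set to the mean in-plane spacing.
        in_plane = target[1:]
        target[0] = sum(in_plane) / len(in_plane)
    else:
        # z is left untouched; y and x are set to the z spacing.
        target[1:] = [target[0]] * len(target[1:])
    return target


if __name__ == "__main__":
    print(isotropic_target_voxel_size((1.0, 0.25, 0.25), upscale_z=True))
    print(isotropic_target_voxel_size((1.0, 0.25, 0.25), upscale_z=False))
```

With `upscale_z=True` the z axis is resampled to the finer in-plane spacing, which increases the number of z planes; with `upscale_z=False` the in-plane axes are resampled to the coarser z spacing, which shrinks the image in y and x.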
class SelfSupervisedRestorationCampaignTransforms (patch_size: tuple = (640, 640), num_patches: int = 1)
-
Example self-supervised restoration campaign transforms.
Constructor.
By default, these transforms apply to a single-channel input image to predict a single-channel output.
Expand source code
class SelfSupervisedRestorationCampaignTransforms(CampaignTransforms):
    """Example self-supervised restoration campaign transforms."""

    def __init__(
        self,
        patch_size: tuple = (640, 640),
        num_patches: int = 1,
    ):
        """Constructor.

        By default, these transforms apply to a single-channel input image to
        predict a single-channel output.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def get_train_transforms(self):
        """Return a composition of Transforms needed to train (patch)."""
        train_transforms = Compose(
            [
                CustomTIFFReaderd(
                    keys=("image", "target"),
                    ensure_channel_first=True,
                    dtype=torch.float32,
                ),
                RandRotated(
                    keys=("image", "target"),
                    prob=0.75,
                    range_x=0.4,
                    mode=["bilinear", "bilinear"],
                    padding_mode="reflection",
                ),
                ZNormalized(
                    keys=("image", "target"),
                ),
                RandGaussianNoised(
                    keys=("image",),
                    prob=1.0,
                    mean=0.0,
                    std=0.5,
                ),
                RandSpatialCropd(
                    keys=("image", "target"),
                    roi_size=self.patch_size,
                    random_size=False,
                ),
                ToPyTorchLightningOutputd(
                    image_key="image",
                    label_key="target",
                    image_dtype=torch.float32,
                    label_dtype=torch.float32,
                ),
            ]
        )
        return train_transforms

    def get_valid_transforms(self):
        """Return a composition of Transforms needed to validate (patch)."""
        val_transforms = Compose(
            [
                CustomTIFFReaderd(
                    keys=("image", "target"),
                    ensure_channel_first=True,
                    dtype=torch.float32,
                ),
                ZNormalized(
                    keys=("image", "target"),
                ),
                RandGaussianNoised(
                    keys=("image",),
                    prob=1.0,
                    mean=0.0,
                    std=0.5,
                ),
                RandSpatialCropd(
                    keys=("image", "target"),
                    roi_size=self.patch_size,
                    random_size=False,
                ),
                ToPyTorchLightningOutputd(
                    image_key="image",
                    label_key="target",
                    image_dtype=torch.float32,
                    label_dtype=torch.float32,
                ),
            ]
        )
        return val_transforms

    def get_test_transforms(self):
        """Return a composition of Transforms needed to test (patch)."""
        return self.get_valid_transforms()

    def get_inference_transforms(self):
        """Define inference transforms to predict (patch)."""
        inference_transforms = Compose(
            [
                CustomTIFFReader(
                    ensure_channel_first=True,
                    dtype=torch.float32,
                ),
                ZNormalize(),
                RandGaussianNoise(
                    prob=1.0,
                    mean=0.0,
                    std=0.5,
                ),
            ]
        )
        return inference_transforms

    def get_post_inference_transforms(self):
        """Define post inference transforms to apply after prediction on patch."""
        post_inference_transforms = Compose()
        return post_inference_transforms

    def get_post_full_inference_transforms(self):
        """Define post full-inference transforms to apply after reconstructed prediction on whole image."""
        return self.get_post_inference_transforms()

    def get_val_metrics_transforms(self):
        """Define default transforms for validation metric calculation (patch)."""
        val_metrics_transforms = Compose([])
        return val_metrics_transforms

    def get_test_metrics_transforms(self):
        """Define default transforms for testing metric calculation (patch)."""
        return self.get_val_metrics_transforms()
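The training pipeline in the source above normalizes both `image` and `target` and then adds Gaussian noise (mean 0.0, standard deviation 0.5) to `image` only, so the model learns to map a noisy input onto its clean counterpart. The idea can be sketched in plain Python on flat lists of pixel values. This is an illustration under assumptions, not the library's implementation: `z_normalize` and `make_restoration_pair` are hypothetical names, and z-normalization is assumed here to mean subtracting the mean and dividing by the population standard deviation.

```python
import random
import statistics


def z_normalize(values):
    """Center values on zero and scale them to unit standard deviation."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    if std == 0.0:
        # Assumption: constant input is only centered, to avoid dividing by zero.
        return [v - mean for v in values]
    return [(v - mean) / std for v in values]


def make_restoration_pair(image, target, noise_std=0.5, rng=None):
    """Return (noisy normalized image, clean normalized target)."""
    if rng is None:
        rng = random.Random()
    norm_image = z_normalize(image)
    norm_target = z_normalize(target)
    # Noise is added to the input only; the target stays clean.
    noisy_image = [v + rng.gauss(0.0, noise_std) for v in norm_image]
    return noisy_image, norm_target


if __name__ == "__main__":
    pixels = [10.0, 12.0, 14.0, 16.0, 18.0]
    noisy, clean = make_restoration_pair(pixels, pixels, rng=random.Random(0))
    print(noisy)
    print(clean)
```

Because the noise is applied after normalization, a standard deviation of 0.5 corresponds to half of the image's own intensity spread, regardless of the original intensity range.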
Ancestors
- qute.campaigns._campaigns.CampaignTransforms
- abc.ABC
Methods
def get_inference_transforms(self)
-
Define inference transforms to predict (patch).
def get_post_full_inference_transforms(self)
-
Define post full-inference transforms to apply after reconstructed prediction on whole image.
def get_post_inference_transforms(self)
-
Define post inference transforms to apply after prediction on patch.
def get_test_metrics_transforms(self)
-
Define default transforms for testing metric calculation (patch).
def get_test_transforms(self)
-
Return a composition of Transforms needed to test (patch).
def get_train_transforms(self)
-
Return a composition of Transforms needed to train (patch).
def get_val_metrics_transforms(self)
-
Define default transforms for validation metric calculation (patch).
def get_valid_transforms(self)
-
Return a composition of Transforms needed to validate (patch).
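Both classes in this section implement `get_test_transforms` by returning `get_valid_transforms()`, and `get_test_metrics_transforms` by returning `get_val_metrics_transforms()`, so test and validation data always pass through identical preprocessing. The pattern is sketched below with only the standard library. The class names `TransformsBase` and `ExampleTransforms` are hypothetical stand-ins, and the pipelines are represented as plain lists of step names rather than real transform objects.

```python
from abc import ABC, abstractmethod


class TransformsBase(ABC):
    """Hypothetical stand-in for an abstract campaign-transforms base class."""

    @abstractmethod
    def get_valid_transforms(self):
        """Return the validation pipeline."""

    @abstractmethod
    def get_test_transforms(self):
        """Return the test pipeline."""


class ExampleTransforms(TransformsBase):
    """Hypothetical subclass; pipelines are plain lists of step names."""

    def __init__(self, patch_size=(640, 640)):
        super().__init__()
        self.patch_size = patch_size

    def get_valid_transforms(self):
        return ["read", "z_normalize", f"crop{self.patch_size}"]

    def get_test_transforms(self):
        # Delegate so that test and validation preprocessing stay identical.
        return self.get_valid_transforms()


if __name__ == "__main__":
    transforms = ExampleTransforms(patch_size=(64, 64))
    print(transforms.get_test_transforms())
```

A subclass that needs different test-time preprocessing would override `get_test_transforms` directly instead of delegating.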