Module qute.models.unet
UNet.
Classes
class UNet (campaign_transforms: qute.campaigns._campaigns.CampaignTransforms, spatial_dims: int = 2, in_channels: int = 1, out_channels: int = 3, class_names: Optional[tuple] = None, channels=(16, 32, 64), strides: Optional[tuple] = None, criterion=DiceCELoss(include_background=True, to_onehot_y=False, softmax=True), metrics=DiceMetric(include_background=True, reduction="mean", get_not_nans=False), learning_rate: float = 0.01, optimizer_class=torch.optim.adamw.AdamW, lr_scheduler_class=torch.optim.lr_scheduler.PolynomialLR, lr_scheduler_parameters: dict = {'total_iters': 100, 'power': 0.95}, num_res_units: int = 0, dropout: float = 0.0)
-
Wrap MONAI's UNet architecture into a PyTorch Lightning module.
The default settings are compatible with a classification task, where a single-channel input image is transformed into a three-class label image.
Constructor.
Parameters
campaign_transforms
:CampaignTransforms
- Define all transforms necessary for training, validation, testing and (full) prediction. @see qute.transforms.CampaignTransforms for documentation.
spatial_dims
:int = 2
- Whether 2D or 3D data.
in_channels
:int = 1
- Number of input channels.
out_channels
:int = 3
- Number of output channels (or labels, or classes).
class_names
:Optional[tuple] = None
- Names of the output classes (for logging purposes). If omitted, they will default to ("class_0", "class_1", …).
channels
:tuple = (16, 32, 64)
- Number of neurons per layer.
strides
:Optional[tuple] = None
- Strides for down-sampling. If omitted, defaults to (2,) * (len(channels) - 1).
criterion
:DiceCELoss(include_background=True, to_onehot_y=False, softmax=True)
- Loss function. Please NOTE: for classification, the loss function must convert y to OneHot. The default loss function applies to a multi-class target and includes the background class.
metrics
:DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
-
Metrics used for validation and test. Set to None to omit.
The default metric applies to a three-class target and includes the background (index = 0) class in the calculation.
learning_rate
:float = 1e-2
- Learning rate for optimization.
optimizer_class
:AdamW
- Optimizer.
lr_scheduler_class
:PolynomialLR
- Learning rate scheduler.
lr_scheduler_parameters
:dict = {"total_iters": 100, "power": 0.95}
- Dictionary of scheduler parameters.
num_res_units
:int = 0
- Number of residual units for the UNet.
dropout
:float = 0.0
- Dropout ratio.
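For orientation, here is a minimal construction sketch. The campaign class name (SegmentationCampaignTransforms2D) is a hypothetical placeholder; substitute the CampaignTransforms implementation your project defines.

from qute.models.unet import UNet

campaign = SegmentationCampaignTransforms2D()  # hypothetical CampaignTransforms
model = UNet(
    campaign_transforms=campaign,
    spatial_dims=2,
    in_channels=1,
    out_channels=3,
    class_names=("background", "cell", "membrane"),
    channels=(16, 32, 64),
    num_res_units=2,
)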
Source code
class UNet(pl.LightningModule):
    """Wrap MONAI's UNet architecture into a PyTorch Lightning module.

    The default settings are compatible with a classification task, where
    a single-channel input image is transformed into a three-class label image.
    """

    def __init__(
        self,
        campaign_transforms: CampaignTransforms,
        spatial_dims: int = 2,
        in_channels: int = 1,
        out_channels: int = 3,
        class_names: Optional[tuple] = None,
        channels=(16, 32, 64),
        strides: Optional[tuple] = None,
        criterion=DiceCELoss(include_background=True, to_onehot_y=False, softmax=True),
        metrics=DiceMetric(
            include_background=True, reduction="mean", get_not_nans=False
        ),
        learning_rate: float = 1e-2,
        optimizer_class=AdamW,
        lr_scheduler_class=PolynomialLR,
        lr_scheduler_parameters: dict = {"total_iters": 100, "power": 0.95},
        num_res_units: int = 0,
        dropout: float = 0.0,
    ):
        """
        Constructor.

        Parameters
        ----------

        campaign_transforms: CampaignTransforms
            Define all transforms necessary for training, validation, testing
            and (full) prediction. @see `qute.transforms.CampaignTransforms`
            for documentation.

        spatial_dims: int = 2
            Whether 2D or 3D data.

        in_channels: int = 1
            Number of input channels.

        out_channels: int = 3
            Number of output channels (or labels, or classes).

        class_names: Optional[tuple] = None
            Names of the output classes (for logging purposes). If omitted,
            they will default to ("class_0", "class_1", ...).

        channels: tuple = (16, 32, 64)
            Number of neurons per layer.

        strides: Optional[tuple] = None
            Strides for down-sampling. If omitted, defaults to
            (2,) * (len(channels) - 1).

        criterion: DiceCELoss(include_background=True, to_onehot_y=False, softmax=True)
            Loss function. Please NOTE: for classification, the loss function
            must convert `y` to OneHot. The default loss function applies to a
            multi-class target and includes the background class.

        metrics: DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
            Metrics used for validation and test. Set to None to omit.
            The default metric applies to a three-class target and includes
            the background (index = 0) class in the calculation.

        learning_rate: float = 1e-2
            Learning rate for optimization.

        optimizer_class=AdamW
            Optimizer.

        lr_scheduler_class=PolynomialLR
            Learning rate scheduler.

        lr_scheduler_parameters={"total_iters": 100, "power": 0.95}
            Dictionary of scheduler parameters.

        num_res_units: int = 0
            Number of residual units for the UNet.

        dropout: float = 0.0
            Dropout ratio.
        """
        super().__init__()
        self.campaign_transforms = campaign_transforms
        self.criterion = criterion
        self.metrics = metrics
        self.learning_rate = learning_rate
        self.optimizer_class = optimizer_class
        self.scheduler_class = lr_scheduler_class
        self.scheduler_parameters = lr_scheduler_parameters
        if class_names is None:
            class_names = list(f"class_{i}" for i in range(out_channels))
        self.class_names = class_names
        if strides is None:
            strides = (2,) * (len(channels) - 1)
        self.net = MonaiUNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=channels,
            strides=strides,
            num_res_units=num_res_units,
            dropout=dropout,
        )

        # Log the hyperparameters
        self.save_hyperparameters(ignore=["criterion", "metrics"])

    def configure_optimizers(self):
        """Configure and return the optimizer and scheduler."""
        optimizer = self.optimizer_class(self.parameters(), lr=self.learning_rate)
        scheduler = {
            "scheduler": self.scheduler_class(optimizer, **self.scheduler_parameters),
            "monitor": "val_loss",
            "interval": "step",  # Call "scheduler.step()" after every batch (1 step)
            "frequency": 1,  # Update scheduler after every step
            "strict": True,  # Ensures the scheduler is strictly followed (PyTorch Lightning parameter)
        }
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Perform a training step."""
        x, y = batch
        y_hat = self.net(x)
        loss = self.criterion(y_hat, y)
        self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Perform a validation step."""
        x, y = batch
        y_hat = self.net(x)
        val_loss = self.criterion(y_hat, y)

        # Log the loss
        self.log(
            "val_loss",
            val_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        # Update the metrics if needed
        if self.metrics is not None:
            if self.campaign_transforms.get_val_metrics_transforms() is not None:
                val_metrics = self.metrics(
                    self.campaign_transforms.get_val_metrics_transforms()(y_hat), y
                )
            else:
                val_metrics = self.metrics(y_hat, y)

            # Compute and log the mean metrics score per class
            mean_val_per_class = val_metrics.nanmean(dim=0)

            # Do we have more than one output class?
            if len(self.class_names) > 1:
                # Make sure to log the correct class name in case the background
                # is not considered in the calculation
                start = len(self.class_names) - val_metrics.shape[1]
                for i, val_score in enumerate(mean_val_per_class):
                    self.log(
                        f"val_metrics_{self.class_names[start + i]}",
                        torch.tensor([val_score]),
                        on_step=False,
                        on_epoch=True,
                    )
            else:
                self.log(
                    "val_metrics",
                    torch.tensor([mean_val_per_class]),
                    on_step=False,
                    on_epoch=True,
                )

        return {"val_loss": val_loss}

    def test_step(self, batch, batch_idx):
        """Perform a test step."""
        x, y = batch
        y_hat = self.net(x)
        test_loss = self.criterion(y_hat, y)
        self.log("test_loss", test_loss)
        if self.metrics is not None:
            if self.campaign_transforms.get_test_metrics_transforms() is not None:
                test_metrics = self.metrics(
                    self.campaign_transforms.get_test_metrics_transforms()(y_hat), y
                )
            else:
                test_metrics = self.metrics(y_hat, y)

            # Compute and log the mean metrics score per class
            mean_test_per_class = test_metrics.nanmean(dim=0)

            # Do we have more than one output class?
            if len(self.class_names) > 1:
                # Make sure to log the correct class name in case the background
                # is not considered in the calculation
                start = len(self.class_names) - test_metrics.shape[1]
                for i, test_score in enumerate(mean_test_per_class):
                    self.log(
                        f"test_metrics_{self.class_names[start + i]}",
                        torch.tensor([test_score]),
                        on_step=False,
                        on_epoch=True,
                    )
            else:
                self.log(
                    "test_metrics",
                    torch.tensor([mean_test_per_class]),
                    on_step=False,
                    on_epoch=True,
                )

        return {"test_loss": test_loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """The predict step creates a label image from the output one-hot tensor."""
        x, _ = batch
        y_hat = self.net(x)
        if self.campaign_transforms.get_post_inference_transforms() is not None:
            label = self.campaign_transforms.get_post_inference_transforms()(y_hat)
        else:
            label = y_hat
        return label

    def full_inference(
        self,
        data_loader: DataLoader,
        target_folder: Union[Path, str],
        roi_size: Tuple[int, ...],
        batch_size: int,
        overlap: float = 0.25,
        transpose: bool = True,
        output_dtype: Optional[Union[str, np.dtype]] = None,
        prefix: str = "pred_",
    ):
        """Run inference on full images using the given model.

        Parameters
        ----------

        data_loader: DataLoader
            DataLoader for the image files to be predicted on.

        target_folder: Union[Path, str]
            Path to the folder where to store the predicted images.

        roi_size: Tuple[int, ...]
            Size of the patch for the sliding window prediction. It must
            match the patch size during training.

        batch_size: int
            Number of parallel batches to run.

        overlap: float
            Fraction of overlap between ROIs.

        transpose: bool
            Whether to transpose the image before saving, to compensate for
            the default behavior of monai.transforms.LoadImage().

        output_dtype: Optional[np.dtype]
            Optional NumPy dtype for the output image. Omit to save the
            output of inference without casting.

        prefix: str = "pred_"
            Prefix to append to the file name. Set to "" to keep the original
            file name.

        Returns
        -------

        result: bool
            True if the inference was successful, False otherwise.
        """

        # Make sure the target folder exists
        Path(target_folder).mkdir(parents=True, exist_ok=True)

        # Retrieve file names from the dataloader
        input_file_names = data_loader.dataset.dataset.data
        if len(input_file_names) == 0:
            print("No input files provided to process. Quitting.")
            return False

        # Device
        device = get_device()

        # Make sure the model is on the device
        self.net.to(device)

        # Switch to evaluation mode
        self.net.eval()

        # Instantiate the inferer
        sliding_window_inferer = SlidingWindowInferer(
            roi_size=roi_size,
            sw_batch_size=batch_size,
            overlap=overlap,
            mode=BlendMode.GAUSSIAN,
            sigma_scale=0.125,
            device=device,
        )

        # Process all images
        c = 0
        with torch.no_grad():
            for images in data_loader:
                # Apply sliding inference over ROI size
                outputs = sliding_window_inferer(
                    inputs=images.to(device),
                    network=self.net,
                )

                # Apply post-transforms?
                outputs = self.campaign_transforms.get_post_full_inference_transforms()(
                    outputs
                )

                # Retrieve the image from the GPU (if needed)
                preds = outputs.cpu().numpy()

                # Process one batch at a time
                for pred in preds:
                    # Drop the channel singleton dimension
                    if pred.shape[0] == 1:
                        pred = pred.squeeze(0)

                    if transpose:
                        # Transpose to undo the effect of monai.transform.LoadImage(d)
                        pred = pred.T

                    # Type-cast if needed
                    if output_dtype is not None:
                        # Make sure not to wrap around
                        if np.issubdtype(output_dtype, np.integer):
                            info = np.iinfo(output_dtype)
                            pred[pred < info.min] = info.min
                            pred[pred > info.max] = info.max
                        pred = pred.astype(output_dtype)

                    # Save prediction image as tiff file
                    output_name = (
                        Path(target_folder) / f"{prefix}{input_file_names[c].stem}.tif"
                    )
                    c += 1
                    with TiffWriter(output_name) as tif:
                        tif.write(pred)

                    # Inform
                    print(f"Saved {output_name}.")

        print("Prediction completed.")

        # Return success
        return True

    @staticmethod
    def full_inference_ensemble(
        models: list,
        data_loader: DataLoader,
        target_folder: Union[Path, str],
        post_full_inference_transforms: Transform,
        roi_size: Tuple[int, ...],
        batch_size: int,
        voting_mechanism: str = "mode",
        weights: Optional[list] = None,
        overlap: float = 0.25,
        transpose: bool = True,
        save_individual_preds: bool = False,
        output_dtype: Optional[Union[str, np.dtype]] = None,
        prefix: str = "pred_",
        ensemble_prefix: str = "ensemble_",
    ):
        """Run inference on full images using an ensemble of models.

        Parameters
        ----------

        models: list
            List of trained UNet models to use for ensemble prediction.

        data_loader: DataLoader
            DataLoader for the image files to be predicted on.

        target_folder: Union[Path, str]
            Path to the folder where to store the predicted images.

        post_full_inference_transforms: Transform
            Composition of transforms to be applied to the result of the
            sliding window inference (whole image).

        roi_size: Tuple[int, ...]
            Size of the patch for the sliding window prediction. It must
            match the patch size during training.

        batch_size: int
            Number of parallel batches to run.

        voting_mechanism: str = "mode"
            Voting mechanism to assign the final class among the predictions
            from the ensemble of models. One of "mode" (default) and "mean".
            "mode": pick the most common class among the predictions for
            each pixel.
            "mean": (rounded) weighted mean of the predicted classes per
            pixel. The `weights` argument defines the relative contribution
            of the models.

        weights: Optional[list]
            List of weights for each of the contributions. Only used if
            `voting_mechanism` is "mean".

        overlap: float
            Fraction of overlap between ROIs.

        transpose: bool
            Whether to transpose the image before saving, to compensate for
            the default behavior of monai.transforms.LoadImage().

        save_individual_preds: bool
            Whether to save the individual predictions of each model.

        output_dtype: Optional[np.dtype]
            Optional NumPy dtype for the output image. Omit to save the
            output of inference without casting.

        prefix: str = "pred_"
            Prefix to append to the file name. Set to "" to keep the original
            file name.

        ensemble_prefix: str = "ensemble_"
            Prefix to append to the ensemble prediction file name. Set to ""
            to keep the original file name.

        Returns
        -------

        result: bool
            True if the inference was successful, False otherwise.
        """

        if voting_mechanism not in ["mode", "mean"]:
            raise ValueError("`voting mechanism` must be one of 'mode' or 'mean'.")

        if voting_mechanism == "mean":
            if weights is None or len(models) != len(weights):
                raise ValueError(
                    "The number of weights must match the number of models."
                )

            # Turn the weights into a NumPy array of float 32 bit
            weights = np.array(weights, dtype=np.float32)
            weights = weights / weights.sum()

        if not isinstance(models[0], UNet) or not hasattr(models[0], "net"):
            raise ValueError("The models must be of type `qute.models.unet.UNet`.")

        # Make sure the target folder exists
        Path(target_folder).mkdir(parents=True, exist_ok=True)

        # Retrieve file names from the dataloader
        input_file_names = data_loader.dataset.dataset.data
        if len(input_file_names) == 0:
            print("No input files provided to process. Quitting.")
            return False

        # If needed, create the sub-folders for the individual predictions
        if save_individual_preds:
            for f in range(len(models)):
                fold_subfolder = Path(target_folder) / f"fold_{f}"
                Path(fold_subfolder).mkdir(parents=True, exist_ok=True)

        # Device
        device = get_device()

        # Switch to evaluation mode on all models
        for model in models:
            model.net.eval()

        # Instantiate the inferer
        sliding_window_inferer = SlidingWindowInferer(
            roi_size=roi_size,
            sw_batch_size=batch_size,
            overlap=overlap,
            mode=BlendMode.GAUSSIAN,
            sigma_scale=0.125,
            device=device,
        )

        c = 0
        with torch.no_grad():
            for images in data_loader:
                predictions = [[] for _ in range(len(models))]

                # Now process all models
                for n, model in enumerate(models):
                    # Make sure the model is on the device
                    model.to(device)

                    # Apply sliding inference over ROI size
                    outputs = sliding_window_inferer(
                        inputs=images.to(device),
                        network=model.net,
                    )

                    # Apply post-transforms?
                    outputs = post_full_inference_transforms(outputs)

                    # Retrieve the image from the GPU (if needed)
                    preds = outputs.cpu().numpy()

                    stored_preds = []
                    for pred in preds:
                        # Drop the channel singleton dimension
                        if pred.shape[0] == 1:
                            pred = pred.squeeze(0)

                        if transpose:
                            # Transpose to undo the effect of monai.transform.LoadImage(d)
                            pred = pred.T

                        stored_preds.append(pred)

                    # Store
                    predictions[n] = stored_preds

                # Iterate over all images in the batch
                pred_dim = len(models)
                batch_dim = len(predictions[0])

                for b in range(batch_dim):
                    # Apply selected voting mechanism
                    if voting_mechanism == "mean":
                        # Apply weighted mean (and rounding) of the predictions per pixel

                        # Iterate over all predictions from the models
                        for p in range(pred_dim):
                            if p == 0:
                                ensemble_pred = weights[p] * predictions[p][b]
                            else:
                                ensemble_pred += weights[p] * predictions[p][b]
                        ensemble_pred = np.round(ensemble_pred).astype(np.int32)

                    elif voting_mechanism == "mode":
                        # Select the mode of the predictions per pixel

                        # Store predictions in a stack
                        target = np.zeros(
                            (
                                pred_dim,
                                predictions[0][0].shape[0],
                                predictions[0][0].shape[1],
                            )
                        )

                        # Iterate over all predictions from the models
                        for p in range(pred_dim):
                            target[p, :, :] = predictions[p][b]
                        values, _ = torch.mode(torch.tensor(target), dim=0)
                        ensemble_pred = values.numpy().astype(np.int32)
                    else:
                        raise ValueError(
                            "`voting mechanism` must be one of 'mode' or 'mean'."
                        )

                    # Type-cast if needed
                    if output_dtype is not None:
                        # Make sure not to wrap around
                        if np.issubdtype(output_dtype, np.integer):
                            info = np.iinfo(output_dtype)
                            ensemble_pred[ensemble_pred < info.min] = info.min
                            ensemble_pred[ensemble_pred > info.max] = info.max
                        ensemble_pred = ensemble_pred.astype(output_dtype)

                    # Save ensemble prediction image as tiff file
                    output_name = (
                        Path(target_folder)
                        / f"{ensemble_prefix}{input_file_names[c].stem}.tif"
                    )
                    with TiffWriter(output_name) as tif:
                        tif.write(ensemble_pred)

                    # Inform
                    print(f"Saved {output_name}.")

                    # Save individual predictions?
                    if save_individual_preds:
                        # Iterate over all predictions from the models
                        for p in range(len(predictions)):
                            # Save prediction image as tiff file
                            output_name = (
                                Path(target_folder)
                                / f"fold_{p}"
                                / f"{prefix}{input_file_names[c].stem}.tif"
                            )

                            # Get current prediction
                            current_pred = predictions[p][b]

                            # Type-cast if needed
                            if output_dtype is not None:
                                # Make sure not to wrap around
                                if np.issubdtype(output_dtype, np.integer):
                                    info = np.iinfo(output_dtype)
                                    current_pred[current_pred < info.min] = info.min
                                    current_pred[current_pred > info.max] = info.max
                                current_pred = current_pred.astype(output_dtype)

                            # Save
                            with TiffWriter(output_name) as tif:
                                tif.write(current_pred)

                    # Update global file counter c
                    c += 1

        print("Ensemble prediction completed.")

        # Return success
        return True

    @staticmethod
    def load_from_checkpoint_and_swap_output_layer(
        checkpoint_path: Union[Path, str],
        new_out_channels: int,
        new_campaign_transforms: CampaignTransforms,
        new_criterion,
        new_metrics,
        class_names: tuple[str, ...],
        previous_out_channels: int = 1,
        strict: bool = True,
        verbose: bool = False,
        map_location: Optional[torch.device] = None,
    ):
        """Load a model from a checkpoint and modify it by replacing the last
        Conv2d layer with a new Conv2d layer that has a specified number of
        output channels.

        Parameters
        ----------

        checkpoint_path: Union[Path, str]
            Full path to the checkpoint file to load the model from.

        new_out_channels: int
            The number of output channels for the new Conv2d layer.

        new_campaign_transforms: CampaignTransforms
            New CampaignTransforms for the loaded model.

        new_criterion: loss function
            New criterion for the loaded model.

        new_metrics: metrics
            New metrics for the loaded model.

        class_names: tuple[str, ...]
            Class names for the new outputs.

        previous_out_channels: int = 1
            Number of output channels in the last convolutional layer of the
            loaded model. Since this method expects a regression model,
            previous_out_channels defaults to 1, but it should also work for
            a different number of output channels.

        strict: bool = True
            Set to True for strict loading of the model (all modules and
            parameters must match).

        verbose: bool = False
            Set to True for verbose info when scanning the model. Use this if
            something goes wrong and you want to report an issue.

        map_location: Optional[torch.device]
            The device to map the model's weights to when loading the
            checkpoint. Default is None, which means the model is loaded to
            the current device.

        Returns
        -------

        model:
            The model with the last Conv2d layer replaced by a new Conv2d
            layer with the specified number of output channels.
        """

        # Check inputs
        if new_out_channels != len(class_names):
            raise ValueError(
                f"Please provide a valid number of class names ({new_out_channels})."
            )

        # Load the model from checkpoint
        model = UNet.load_from_checkpoint(
            checkpoint_path=checkpoint_path,
            map_location=map_location,
            strict=strict,
            campaign_transforms=new_campaign_transforms,
            criterion=new_criterion,
            metrics=new_metrics,
        )

        # Debug: assert that the campaign was replaced
        assert model.campaign_transforms == new_campaign_transforms

        # Debug: assert that the criterion was replaced
        assert model.criterion == new_criterion

        # Debug: assert that the metrics was replaced
        assert model.metrics == new_metrics

        # List to store all matching Conv2d layers
        matching_layers = []

        # Helper function to collect all matching Conv2d layers
        def collect_matching_conv2d_layers(
            module: nn.Module,
            previous_out_channels: int,
            depth: int = 0,
            parent_name="",
        ):
            for name, child in module.named_children():
                full_name = f"{parent_name}.{name}" if parent_name else name
                if isinstance(child, nn.Conv2d):
                    if verbose:
                        print(
                            f"Found Conv2d layer ('{full_name}') with {child.out_channels} output channel(s)"
                        )
                    if child.out_channels == previous_out_channels:
                        matching_layers.append((module, name, full_name, child))
                else:
                    collect_matching_conv2d_layers(
                        child, previous_out_channels, depth + 1, full_name
                    )

        # Collect all matching Conv2d layers
        collect_matching_conv2d_layers(model, previous_out_channels)

        # Ensure we found at least one matching layer
        if not matching_layers:
            raise ValueError(
                f"No Conv2d layer with {previous_out_channels} channel(s) found."
            )

        # Get the last matching layer
        parent_module, name, full_name, last_conv_layer = matching_layers[-1]
        in_channels = last_conv_layer.in_channels
        out_channels = last_conv_layer.out_channels

        # Print the identified last Conv2d layer if needed
        if verbose:
            print(
                f"Found {len(matching_layers)} module(s) with {in_channels} input channel(s) and {out_channels} output channel(s)."
            )
            print(
                f"Replacing last Conv2d layer ('{full_name}') with {in_channels} input channel(s) and {new_out_channels} output channel(s)."
            )

        # Assert that the last Conv2d's output channels match the expected previous_out_channels
        assert (
            out_channels == previous_out_channels
        ), f"Expected last Conv2d output channels to be {previous_out_channels}, but got {out_channels}"

        # Create the new Conv2d layer and initialize its weights
        new_conv = nn.Conv2d(in_channels, new_out_channels, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(new_conv.weight)
        if new_conv.bias is not None:
            nn.init.constant_(new_conv.bias, 0)

        # Replace the last Conv2d layer
        setattr(parent_module, name, new_conv)

        # Set the new class names
        model.class_names = class_names

        # Update hyperparameters to reflect new output channels and class names
        model.hparams.out_channels = new_out_channels
        model.hparams.class_names = class_names

        # Log the updated hyperparameters (including criterion and metrics)
        model.save_hyperparameters()

        # Return the loaded and modified model
        return model
Ancestors
- pytorch_lightning.core.module.LightningModule
- lightning_fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
- pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- pytorch_lightning.core.hooks.ModelHooks
- pytorch_lightning.core.hooks.DataHooks
- pytorch_lightning.core.hooks.CheckpointHooks
- torch.nn.modules.module.Module
Static methods
def full_inference_ensemble(models: list, data_loader: monai.data.dataloader.DataLoader, target_folder: Union[pathlib.Path, str], post_full_inference_transforms: monai.transforms.transform.Transform, roi_size: Tuple[int, ...], batch_size: int, voting_mechanism: str = 'mode', weights: Optional[list] = None, overlap: float = 0.25, transpose: bool = True, save_individual_preds: bool = False, output_dtype: Union[str, numpy.dtype, ForwardRef(None)] = None, prefix: str = 'pred_', ensemble_prefix: str = 'ensemble_')
-
Run inference on full images using an ensemble of models.
Parameters
models
:list
- List of trained UNet models to use for ensemble prediction.
data_loader
:DataLoader
- DataLoader for the image files to be predicted on.
target_folder
:Union[Path, str]
- Path to the folder where to store the predicted images.
post_full_inference_transforms
:Transform
- Composition of transforms to be applied to the result of the sliding window inference (whole image).
roi_size
:Tuple[int, ...]
- Size of the patch for the sliding window prediction. It must match the patch size during training.
batch_size
:int
- Number of parallel batches to run.
voting_mechanism
:str = "mode"
- Voting mechanism to assign the final class among the predictions from the ensemble of models.
One of "mode" (default" and "mean").
"mode": pick the most common class among the predictions for each pixel.
"mean": (rounded) weighted mean of the predicted classed per pixel. The
weights
argument defines the relative contribution of the models. weights
:Optional[list]
- List of weights for each of the contributions. Only used if
voting_mechanism
is "mean". overlap
:float
- Fraction of overlap between ROIs.
transpose
:bool
- Whether to transpose the image before saving, to compensate for the default behavior of monai.transforms.LoadImage().
save_individual_preds
:bool
- Whether to save the individual predictions of each model.
output_dtype
:Optional[np.dtype]
- Optional NumPy dtype for the output image. Omit to save the output of inference without casting.
prefix
:str = "pred_"
- Prefix to append to the file name. Set to "" to keep the original file name.
ensemble_prefix
:str = "ensemble_pred_"
- Prefix to append to the ensemble prediction file name. Set to "" to keep the original file name.
Returns
result
:bool
- True if the inference was successful, False otherwise.
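As a sketch, an ensemble run with mode voting might look like the following; models, pred_loader, and campaign are placeholder names assumed to exist (e.g., one trained UNet per cross-validation fold):

import numpy as np

UNet.full_inference_ensemble(
    models=models,  # hypothetical list of trained UNet instances
    data_loader=pred_loader,  # hypothetical DataLoader over full images
    target_folder="results/ensemble",
    post_full_inference_transforms=campaign.get_post_full_inference_transforms(),
    roi_size=(640, 640),
    batch_size=4,
    voting_mechanism="mode",
    save_individual_preds=True,
    output_dtype=np.int32,
)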
def load_from_checkpoint_and_swap_output_layer(checkpoint_path: Union[pathlib.Path, str], new_out_channels: int, new_campaign_transforms: qute.campaigns._campaigns.CampaignTransforms, new_criterion, new_metrics, class_names: tuple[str, ...], previous_out_channels: int = 1, strict: bool = True, verbose: bool = False, map_location: Optional[torch.device] = None)
-
Load a model from a checkpoint and modify it by replacing the last Conv2d layer with a new Conv2d layer that has a specified number of output channels.
Parameters
checkpoint_path
:Union[Path, str]
- Full path to the checkpoint file to load the model from.
new_out_channels
:int
- The number of output channels for the new Conv2d layer.
new_campaign_transforms
:CampaignTransforms
- New CampaignTransforms for the loaded model.
new_criterion
:loss function
- New criterion for the loaded model.
new_metrics
:metrics
- New metrics for the loaded model.
class_names
:tuple[str, …]
- Class names for the new outputs.
previous_out_channels
:int = 1
- Number of output channels in the last convolutional layer of the loaded model. Since this method expects a regression model, previous_out_channels defaults to 1, but it should also work for a different number of output channels.
strict
:bool = True
- Set to True for strict loading of the model (all modules and parameters must match).
verbose
:bool = False
- Set to True for verbose info when scanning the model. Use this if something goes wrong and you want to report an issue.
map_location
:Optional[torch.device]
- The device to map the model's weights to when loading the checkpoint. Default is None, which means the model is loaded to the current device.
Returns
model
- The model with the last Conv2d layer replaced by a new Conv2d layer with the specified number of output channels.
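For example, to repurpose a single-channel regression checkpoint for three-class segmentation (a sketch; the checkpoint path and the new campaign are assumptions):

from monai.losses import DiceCELoss
from monai.metrics import DiceMetric

model = UNet.load_from_checkpoint_and_swap_output_layer(
    checkpoint_path="checkpoints/regression.ckpt",  # hypothetical path
    new_out_channels=3,
    new_campaign_transforms=new_campaign,  # hypothetical CampaignTransforms
    new_criterion=DiceCELoss(include_background=True, to_onehot_y=False, softmax=True),
    new_metrics=DiceMetric(include_background=True, reduction="mean"),
    class_names=("background", "cell", "membrane"),
    previous_out_channels=1,
)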
Methods
def configure_optimizers(self)
-
Configure and return the optimizer and scheduler.
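Both the optimizer and the scheduler classes come straight from the constructor arguments, so they can be swapped without subclassing. A sketch, using standard torch.optim classes (the campaign object is a hypothetical placeholder, as above):

from torch.optim import SGD
from torch.optim.lr_scheduler import PolynomialLR

model = UNet(
    campaign_transforms=campaign,  # hypothetical CampaignTransforms
    optimizer_class=SGD,
    lr_scheduler_class=PolynomialLR,
    lr_scheduler_parameters={"total_iters": 500, "power": 0.9},
    learning_rate=1e-3,
)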
def full_inference(self, data_loader: monai.data.dataloader.DataLoader, target_folder: Union[pathlib.Path, str], roi_size: Tuple[int, ...], batch_size: int, overlap: float = 0.25, transpose: bool = True, output_dtype: Union[str, numpy.dtype, ForwardRef(None)] = None, prefix: str = 'pred_')
-
Run inference on full images using the given model.
Parameters
data_loader
:DataLoader
- DataLoader for the image files to be predicted on.
target_folder
:Union[Path, str]
- Path to the folder where to store the predicted images.
roi_size
:Tuple[int, ...]
- Size of the patch for the sliding window prediction. It must match the patch size during training.
batch_size
:int
- Number of parallel batches to run.
overlap
:float
- Fraction of overlap between ROIs.
transpose
:bool
- Whether to transpose the image before saving, to compensate for the default behavior of monai.transforms.LoadImage().
output_dtype
:Optional[np.dtype]
- Optional NumPy dtype for the output image. Omit to save the output of inference without casting.
prefix
:str = "pred_"
- Prefix to append to the file name. Set to "" to keep the original file name.
Returns
result
:bool
- True if the inference was successful, False otherwise.
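A sketch of a full-image prediction run; the trained model, the DataLoader, and the folder and ROI values are assumptions for illustration:

import numpy as np

model.full_inference(
    data_loader=pred_loader,  # hypothetical DataLoader over full images
    target_folder="results/full",
    roi_size=(640, 640),  # must match the training patch size
    batch_size=4,
    overlap=0.25,
    output_dtype=np.int32,
)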
def predict_step(self, batch, batch_idx, dataloader_idx=0)
-
The predict step creates a label image from the output one-hot tensor.
def test_step(self, batch, batch_idx)
-
Perform a test step.
def training_step(self, batch, batch_idx)
-
Perform a training step.
def validation_step(self, batch, batch_idx)
-
Perform a validation step.
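Since UNet is a LightningModule, these hooks are driven by the standard PyTorch Lightning Trainer. A minimal sketch (the data module and epoch count are assumptions):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, datamodule=data_module)  # runs training_step and validation_step
trainer.test(model, datamodule=data_module)  # runs test_step
labels = trainer.predict(model, datamodule=data_module)  # runs predict_step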