Module qute.director

Training directors for restoration and segmentation models.

Classes

class CellRestorationDemoDirector (config_file: Union[pathlib.Path, str])

Restoration Demo Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
Source code
class CellRestorationDemoDirector(RestorationDirector):
    """Restoration Demo Training Director."""

    @override
    def _setup_data_module(self):
        """Set up data module."""

        # Data module
        data_module = CellRestorationDemo(
            campaign_transforms=self.campaign_transforms,
            download_dir=self.config.project_dir,
            seed=self.config.seed,
            batch_size=self.config.batch_size,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            train_fraction=self.config.train_fraction,
            val_fraction=self.config.val_fraction,
            test_fraction=self.config.test_fraction,
            inference_batch_size=self.config.inference_batch_size,
        )

        # Return data module
        return data_module

Ancestors

  • qute.director._director.RestorationDirector
  • qute.director._director.Director
  • abc.ABC
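
A minimal usage sketch (the configuration file name is a placeholder). The director should be instantiated from the program entry point (__main__), since the base Director warns otherwise:

from qute.director import CellRestorationDemoDirector

if __name__ == "__main__":
    # Instantiating from the entry point avoids the Director warning
    # (see Director below); the file name is a placeholder.
    director = CellRestorationDemoDirector("restoration_demo_config.ini")
    director.run()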
class CellSegmentationDemoDirector (config_file: Union[pathlib.Path, str])

Segmentation Demo Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
Source code
class CellSegmentationDemoDirector(SegmentationDirector):
    """Segmentation Demo Training Director."""

    @override
    def _setup_data_module(self):
        """Set up data module."""

        # Data module
        data_module = CellSegmentationDemo(
            campaign_transforms=self.campaign_transforms,
            download_dir=self.config.project_dir,
            seed=self.config.seed,
            batch_size=self.config.batch_size,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            train_fraction=self.config.train_fraction,
            val_fraction=self.config.val_fraction,
            test_fraction=self.config.test_fraction,
            inference_batch_size=self.config.inference_batch_size,
        )

        # Return data module
        return data_module

Ancestors

  • qute.director._director.SegmentationDirector
  • qute.director._director.Director
  • abc.ABC
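
The segmentation demo follows the same pattern. A hedged sketch for inference, assuming the configuration sets trainer_mode to "predict" and provides source_model_path and source_for_prediction (both checked by Director._predict()):

from qute.director import CellSegmentationDemoDirector

if __name__ == "__main__":
    # With trainer_mode = "predict", run() dispatches to _predict();
    # the configuration file name is a placeholder.
    director = CellSegmentationDemoDirector("segmentation_demo_predict.ini")
    director.run()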
class Director (config_file: Union[pathlib.Path, str])

Abstract base class defining the interface for all directors.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
Source code
class Director(ABC):
    """Abstract base class defining the interface for all directors."""

    def __init__(self, config_file: Union[Path, str]) -> None:
        """Constructor.

        Parameters
        ----------

        config_file: Union[Path, str]
            Full path to the configuration file.
        """

        # Check if the director is instantiated from a program entry point
        frame = inspect.currentframe()
        cont = True
        while cont:
            caller_module = frame.f_globals["__name__"]
            if caller_module == "qute.director._director":
                frame = frame.f_back
            else:
                cont = False
        caller_module = frame.f_globals["__name__"]
        if caller_module != "__main__":
            print(
                "Warning: the Director is not instantiated from the program entry point (`__main__`).",
                file=sys.stderr,
            )
            print(
                "This could cause issues in particular on macOS and Windows.",
                file=sys.stderr,
            )

        # Store the configuration file
        self.config_file = config_file

        # Parse it
        self.config = Config(self.config_file)
        if not self.config.parse():
            raise Exception("Invalid config file")

        # Keep a reference to the project
        self.project = None

        # Keep references to other important objects
        self.campaign_transforms = None
        self.data_module = None
        self.criterion = None
        self.metrics = None
        self.lr_scheduler_class = None
        self.lr_scheduler_parameters = None
        self.early_stopping = None
        self.model_checkpoint = None
        self.lr_monitor = None
        self.training_callbacks = []
        self.trainer = None
        self.model = None
        self.steps_per_epoch = 0

    @abstractmethod
    def _setup_campaign_transforms(self):
        """Set up campaign transforms."""
        raise NotImplementedError("Reimplement in child class.")

    @abstractmethod
    def _setup_data_module(self):
        """Set up data module."""
        raise NotImplementedError("Reimplement in child class.")

    @abstractmethod
    def _setup_loss(self):
        """Set up loss function."""
        raise NotImplementedError("Reimplement in child class.")

    @abstractmethod
    def _setup_metrics(self):
        """Set up metrics."""
        raise NotImplementedError("Reimplement in child class.")

    def run(self):
        """Run the process as configured in the configuration file."""

        # Seeding
        set_global_rng_seed(self.config.seed, workers=True)

        # Get the mode
        if self.config.trainer_mode == "train":
            self._train()

        elif self.config.trainer_mode == "resume":
            self._resume()

        elif self.config.trainer_mode == "predict":
            self._predict()

        else:
            raise ValueError(
                "Trainer mode must be one of 'train' or 'resume' or 'predict'."
            )

    def _train(self):
        """Run a training from scratch."""

        # Set up components common to train and resume
        self._setup_basis_for_training_and_resume()

        # Set up model
        self.model = self._setup_model()

        # Run common training and testing operations for the 'train' and 'resume' trainer modes.
        self._run_common_train_and_test()

    def _resume(self):
        """Resume training from a saved state."""

        # Set up components common to train and resume
        self._setup_basis_for_training_and_resume()

        # Load specified model
        model_class = self._get_model_class()
        self.model = model_class.load_from_checkpoint(
            self.project.selected_model_path,
            criterion=self.criterion,
            metrics=self.metrics,
            class_names=self.config.class_names,
        )

        # Run common training and testing operations for the 'train' and 'resume' trainer modes.
        self._run_common_train_and_test()

    def _predict(self):
        """Predict using a trained model."""

        # Check that the config contains the path to the model to load
        if self.config.source_model_path is None:
            raise ValueError("No path for model to load found in the configuration!")

        # Check that the source path for prediction is specified in the configuration
        if self.config.source_for_prediction is None:
            raise ValueError(
                "No source path for prediction specified in the configuration!"
            )

        # Set up project
        self.project = Project(self.config)

        # Check that the model exists
        if not self.config.source_model_path.is_file():
            raise ValueError(
                f"The model {self.config.source_model_path} does not exist!"
            )

        # Initialize the campaign transforms
        self.campaign_transforms = self._setup_campaign_transforms()

        # Initialize data module
        self.data_module = self._setup_data_module()

        # Load existing model
        model_class = self._get_model_class()
        self.model = model_class.load_from_checkpoint(self.config.source_model_path)

        # Inform
        print(f"Predicting with model {self.config.source_model_path}")

        # Display target folder
        if self.config.target_for_prediction is None:
            print("Target for prediction not specified in configuration.")
        print(f"Predictions saved to {self.project.target_for_prediction}.")

        # Run full inference
        self.model.full_inference(
            data_loader=self.data_module.inference_dataloader(
                input_folder=self.project.source_for_prediction
            ),
            target_folder=self.project.target_for_prediction,
            roi_size=self.config.patch_size,
            batch_size=self.config.inference_batch_size,
            transpose=False,
            output_dtype=self.config.output_dtype,
        )

    def _setup_basis_for_training_and_resume(self):
        """Initialize the basic components for training or resume."""

        # Set up the project
        self.project = self._setup_project()

        # Set up the transform campaign
        self.campaign_transforms = self._setup_campaign_transforms()

        # Set up the data module
        self.data_module = self._setup_data_module()

        # Inform
        print(f"Working directory: {self.project.run_dir}")

        # Calculate the number of steps per epoch
        self.data_module.prepare_data()
        self.data_module.setup("train")
        self.steps_per_epoch = len(self.data_module.train_dataloader())

        # Print the train, validation and test sets to file
        self.data_module.print_sets(filename=self.project.run_dir / "image_sets.txt")

        # Set up loss function
        self.criterion = self._setup_loss()

        # Set up metrics
        self.metrics = self._setup_metrics()

        # Set up trainer callbacks
        (
            self.training_callbacks,
            self.early_stopping,
            self.model_checkpoint,
            self.lr_monitor,
        ) = self._setup_trainer_callbacks()

        # Set up the scheduler
        self.lr_scheduler_class, self.lr_scheduler_parameters = self._setup_scheduler()

        # Set up trainer
        self.trainer = self._setup_trainer()

    def _run_common_train_and_test(self):
        """Run common training and testing operations for 'train' and
        'resume' trained modes."""

        # Copy the configuration file to the run folder
        self.project.copy_configuration_file()

        # Train
        self.trainer.fit(self.model, datamodule=self.data_module)

        # Print path to best model
        print(f"Best model: {self.model_checkpoint.best_model_path}")
        print(f"Best model score: {self.model_checkpoint.best_model_score}")

        # Store the best score
        self.project.store_best_score(
            self.config.checkpoint_monitor, self.model_checkpoint.best_model_score
        )

        # Set it into the project
        self.project.selected_model_path = self.model_checkpoint.best_model_path

        # Re-load weights from best model
        model_class = self._get_model_class()
        model = model_class.load_from_checkpoint(
            self.project.selected_model_path,
            strict=False,
            criterion=self.criterion,
            metrics=self.metrics,
        )

        # Test
        self.trainer.test(model, dataloaders=self.data_module.test_dataloader())

        # If there is no source_for_prediction in the configuration file,
        # we inform and skip the full inference
        if self.config.source_for_prediction is None:
            print(
                "Source for prediction not specified in configuration. Skipping full inference."
            )
            return

        # Display target folder
        if self.config.target_for_prediction is None:
            print("Target for prediction not specified in configuration.")
        print(f"Predictions saved to {self.project.target_for_prediction}.")

        # Run full inference
        self.model.full_inference(
            data_loader=self.data_module.inference_dataloader(
                input_folder=self.project.source_for_prediction
            ),
            target_folder=self.project.target_for_prediction,
            roi_size=self.config.patch_size,
            batch_size=self.config.inference_batch_size,
            transpose=False,
            output_dtype=self.config.output_dtype,
        )

    def _setup_project(self):
        """Set up the project."""
        if self.config is None:
            raise Exception("No configuration found.")

        # Initialize the project
        project = Project(self.config, clean=True)

        # Return the project
        return project

    def _setup_scheduler(self):
        """Set up scheduler."""
        if self.data_module is None:
            raise Exception("Data module is not set.")

        # Set up learning rate scheduler
        lr_scheduler_class = OneCycleLR
        lr_scheduler_parameters = {
            "total_steps": self.steps_per_epoch * self.config.max_epochs,
            "div_factor": 5.0,
            "max_lr": self.config.learning_rate,
            "pct_start": 0.5,
            "anneal_strategy": "cos",
        }

        return lr_scheduler_class, lr_scheduler_parameters

    def _setup_trainer_callbacks(self):
        """Set up trainer callbacks."""

        # Callbacks
        if self.config.use_early_stopping:
            early_stopping = EarlyStopping(
                monitor=self.config.checkpoint_monitor,
                patience=self.config.early_stopping_patience,
                mode=self.config.checkpoint_mode,
                verbose=True,
            )
        else:
            early_stopping = None
        model_checkpoint = ModelCheckpoint(
            dirpath=self.project.models_dir,
            monitor=self.config.checkpoint_monitor,
            mode=self.config.checkpoint_mode,
            verbose=True,
        )
        lr_monitor = LearningRateMonitor(logging_interval="step")

        # Add them to the training callbacks list
        if self.config.use_early_stopping:
            training_callbacks = [
                early_stopping,
                model_checkpoint,
                lr_monitor,
            ]
        else:
            training_callbacks = [
                model_checkpoint,
                lr_monitor,
            ]

        return training_callbacks, early_stopping, model_checkpoint, lr_monitor

    def _setup_trainer(self):
        """Set up the trainer."""
        # Instantiate the Trainer
        trainer = pl.Trainer(
            default_root_dir=self.project.results_dir,
            accelerator=device.get_accelerator(),
            devices=1,
            precision=self.config.precision,
            callbacks=self.training_callbacks,
            max_epochs=self.config.max_epochs,
            log_every_n_steps=1,
        )

        # Disable logging of the default hp_metric
        trainer.logger._default_hp_metric = False

        # Return the trainer
        return trainer

    def _setup_model(self):
        """Set up the model."""

        # Get and check the model from the configuration
        model = self.config.model_class
        if model not in ["unet", "attention_unet", "swin_unetr"]:
            raise ValueError(
                "The 'model' must be one of 'unet', 'attention_unet', or 'swin_unetr'."
            )

        if model == "unet":

            # Set up the UNet model
            model = UNet(
                campaign_transforms=self.campaign_transforms,
                spatial_dims=3 if self.config.is_3d else 2,
                in_channels=self.config.in_channels,
                out_channels=self.config.out_channels,
                class_names=self.config.class_names,
                num_res_units=self.config.num_res_units,
                criterion=self.criterion,
                channels=self.config.channels,
                strides=self.config.strides,
                metrics=self.metrics,
                learning_rate=self.config.learning_rate,
                lr_scheduler_class=self.lr_scheduler_class,
                lr_scheduler_parameters=self.lr_scheduler_parameters,
            )

        elif model == "attention_unet":

            # Set up the Attention UNet model
            model = AttentionUNet(
                campaign_transforms=self.campaign_transforms,
                spatial_dims=3 if self.config.is_3d else 2,
                in_channels=self.config.in_channels,
                out_channels=self.config.out_channels,
                class_names=self.config.class_names,
                criterion=self.criterion,
                channels=self.config.channels,
                strides=self.config.strides,
                metrics=self.metrics,
                learning_rate=self.config.learning_rate,
                lr_scheduler_class=self.lr_scheduler_class,
                lr_scheduler_parameters=self.lr_scheduler_parameters,
            )

        elif model == "swin_unetr":

            # Set up the SwinUNETR model
            model = SwinUNETR(
                campaign_transforms=self.campaign_transforms,
                in_channels=self.config.in_channels,
                out_channels=self.config.out_channels,
                class_names=self.config.class_names,
                spatial_dims=3 if self.config.is_3d else 2,
                depths=self.config.depths,
                num_heads=self.config.num_heads,
                feature_size=self.config.feature_size,
                criterion=self.criterion,
                metrics=self.metrics,
                learning_rate=self.config.learning_rate,
                lr_scheduler_class=self.lr_scheduler_class,
                lr_scheduler_parameters=self.lr_scheduler_parameters,
            )

        else:
            raise ValueError(
                "The 'model' must be one of 'unet', 'attention_unet', or 'swin_unetr'."
            )

        # Inform
        print(f"Using model: {self._get_model_class()} ")

        # Return the model
        return model

    def _get_model_class(self):
        """Return the class of the model being used."""
        if self.config.model_class == "unet":
            model_class = UNet
        elif self.config.model_class == "attention_unet":
            model_class = AttentionUNet
        elif self.config.model_class == "swin_unetr":
            model_class = SwinUNETR
        else:
            raise ValueError(f"Bad value for model type {self.config.model_class};")
        return model_class

Ancestors

  • abc.ABC

Subclasses

  • qute.director._director.EnsembleDirector
  • qute.director._director.RestorationDirector
  • qute.director._director.SegmentationDirector

Methods

def run(self)

Run the process as configured in the configuration file.
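
As the source above shows, _setup_scheduler() returns a scheduler class together with its parameter dictionary, so a subclass can swap out the default OneCycleLR. A hedged sketch using torch.optim.lr_scheduler.CosineAnnealingLR; the eta_min value is illustrative:

from torch.optim.lr_scheduler import CosineAnnealingLR

from qute.director import SegmentationDirector


class CosineSegmentationDirector(SegmentationDirector):
    """Illustrative director that replaces OneCycleLR with cosine annealing."""

    def _setup_scheduler(self):
        """Set up a cosine-annealing learning rate scheduler."""
        lr_scheduler_class = CosineAnnealingLR
        lr_scheduler_parameters = {
            "T_max": self.steps_per_epoch * self.config.max_epochs,
            "eta_min": 1e-6,  # illustrative lower bound
        }
        return lr_scheduler_class, lr_scheduler_parameters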

class EnsembleCellSegmentationDemoDirector (config_file: Union[pathlib.Path, str], num_folds: int)

Ensemble Segmentation Demo Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
num_folds : int
Number of folds for cross-validation.
Source code
class EnsembleCellSegmentationDemoDirector(EnsembleSegmentationDirector):
    """Ensemble Segmentation Demo Training Director."""

    @override
    def _setup_data_module(self):
        """Set up data module."""

        # Data module
        data_module = CellSegmentationDemo(
            campaign_transforms=self.campaign_transforms,
            download_dir=self.config.project_dir,
            seed=self.config.seed,
            num_folds=self.num_folds,
            batch_size=self.config.batch_size,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            train_fraction=self.config.train_fraction,
            val_fraction=self.config.val_fraction,
            test_fraction=self.config.test_fraction,
            inference_batch_size=self.config.inference_batch_size,
        )

        # Return the data module
        return data_module

Ancestors

  • qute.director._director.EnsembleSegmentationDirector
  • qute.director._director.EnsembleDirector
  • qute.director._director.SegmentationDirector
  • qute.director._director.Director
  • abc.ABC
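
A hedged usage sketch: num_folds drives the k-fold cross-validation loop in EnsembleDirector._train(), training one model per fold and combining predictions by majority vote ("mode"). File name and fold count are placeholders:

from qute.director import EnsembleCellSegmentationDemoDirector

if __name__ == "__main__":
    director = EnsembleCellSegmentationDemoDirector(
        "ensemble_demo_config.ini", num_folds=5
    )
    director.run()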
class EnsembleDirector (config_file: Union[pathlib.Path, str], num_folds: int)

Ensemble Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
num_folds : int
Number of folds for cross-validation.
Source code
class EnsembleDirector(Director, ABC):
    """Ensemble Training Director."""

    def __init__(self, config_file: Union[Path, str], num_folds: int) -> None:
        """Constructor.

        Parameters
        ----------

        config_file: Union[Path, str]
            Full path to the configuration file.

        num_folds: int
            Number of folds for cross-validation.
        """

        super().__init__(config_file=config_file)
        self.num_folds = num_folds
        self.current_fold = -1

        # Keep track of the trained models
        self._best_models = []

    def _setup_trainer_callbacks(self):
        """Set up trainer callbacks."""

        # Callbacks
        if self.config.use_early_stopping:
            early_stopping = EarlyStopping(
                monitor=self.config.checkpoint_monitor,
                patience=self.config.early_stopping_patience,
                mode=self.config.checkpoint_mode,
                verbose=True,
            )
        else:
            early_stopping = None
        model_checkpoint = ModelCheckpoint(
            dirpath=self.project.models_dir / f"fold_{self.current_fold}",
            monitor=self.config.checkpoint_monitor,
            mode=self.config.checkpoint_mode,
            verbose=True,
        )
        lr_monitor = LearningRateMonitor(logging_interval="step")

        # Add them to the training callbacks list
        if self.config.use_early_stopping:
            training_callbacks = [
                early_stopping,
                model_checkpoint,
                lr_monitor,
            ]
        else:
            training_callbacks = [
                model_checkpoint,
                lr_monitor,
            ]

        return training_callbacks, early_stopping, model_checkpoint, lr_monitor

    def _setup_trainer(self):
        """Set up the trainer."""
        # Instantiate the Trainer
        trainer = pl.Trainer(
            default_root_dir=self.project.results_dir / f"fold_{self.current_fold}",
            accelerator=device.get_accelerator(),
            devices=1,
            precision=self.config.precision,
            callbacks=self.training_callbacks,
            max_epochs=self.config.max_epochs,
            log_every_n_steps=1,
        )

        # Disable logging of the default hp_metric
        trainer.logger._default_hp_metric = False

        # Return the trainer
        return trainer

    def _setup_basis_for_training_and_resume(self):
        """Initialize the basic components for training or resume."""

        # Set up the project
        self.project = self._setup_project()

        # Set up the transform campaign
        self.campaign_transforms = self._setup_campaign_transforms()

        # Set up the data module
        self.data_module = self._setup_data_module()

        # Inform
        print(f"Working directory: {self.project.run_dir}")

        # Calculate the number of steps per epoch
        self.data_module.prepare_data()
        self.data_module.setup("train")
        self.steps_per_epoch = len(self.data_module.train_dataloader())

        # Print the train, validation and test sets to file
        self.data_module.print_sets(filename=self.project.run_dir / "image_sets.txt")

        # Set up loss function
        self.criterion = self._setup_loss()

        # Set up metrics
        self.metrics = self._setup_metrics()

        # Set up trainer callbacks
        (
            self.training_callbacks,
            self.early_stopping,
            self.model_checkpoint,
            self.lr_monitor,
        ) = self._setup_trainer_callbacks()

        # Set up the scheduler
        self.lr_scheduler_class, self.lr_scheduler_parameters = self._setup_scheduler()

        # Set up trainer
        self.trainer = self._setup_trainer()

    def _train(self):
        """Run a training from scratch."""

        # Set up components common to train and resume
        self._setup_basis_for_training_and_resume()

        # Copy the configuration file to the run folder
        self.project.copy_configuration_file()

        # Run training with n-fold cross-validation
        for fold in range(self.num_folds):

            # Store current fold
            self.current_fold = fold

            # Inform
            print(f"Fold {fold}: starting training.")

            # Set the fold for current training
            self.data_module.set_fold(self.current_fold)

            # Update the number of steps per epoch
            self.steps_per_epoch = len(self.data_module.train_dataloader())

            # Reset the learning rate scheduler
            self.lr_scheduler_class, self.lr_scheduler_parameters = (
                self._setup_scheduler()
            )

            # Initialize new model
            self.model = self._setup_model()

            # Set up trainer callbacks
            (
                self.training_callbacks,
                self.early_stopping,
                self.model_checkpoint,
                self.lr_monitor,
            ) = self._setup_trainer_callbacks()

            # Instantiate the Trainer
            self.trainer = pl.Trainer(
                default_root_dir=self.project.results_dir / f"fold_{self.current_fold}",
                accelerator=device.get_accelerator(),
                devices=1,
                precision=self.config.precision,
                callbacks=self.training_callbacks,
                max_epochs=self.config.max_epochs,
                log_every_n_steps=1,
            )

            # Disable logging of the default hp_metric
            self.trainer.logger._default_hp_metric = False

            # Train the model for this fold
            self.trainer.fit(self.model, datamodule=self.data_module)

            # Print path to best model
            print(f"Fold {fold}: best model = {self.model_checkpoint.best_model_path}")
            print(
                f"Fold {fold}: best model score: {self.model_checkpoint.best_model_score}"
            )

            # Store the best score
            self.project.store_best_score(
                monitor=self.config.checkpoint_monitor,
                score=float(self.model_checkpoint.best_model_score),
                fold=self.current_fold,
            )

            # Set it into the project
            self.project.selected_model_path = self.model_checkpoint.best_model_path

            # Re-load weights from best model
            model_class = self._get_model_class()
            model = model_class.load_from_checkpoint(
                self.model_checkpoint.best_model_path,
                strict=False,
                criterion=self.criterion,
                metrics=self.metrics,
            )

            # Append to list of best models for inference
            self._best_models.append(model)

            # Test
            self.trainer.test(model, dataloaders=self.data_module.test_dataloader())

        # If there is no source_for_prediction in the configuration file,
        # we inform and skip the full inference
        if self.config.source_for_prediction is None:
            print(
                "Source for prediction not specified in configuration. Skipping full inference."
            )
            return

        # Display target folder
        if self.config.target_for_prediction is None:
            print("Target for prediction not specified in configuration.")
        print(f"Predictions saved to {self.project.target_for_prediction}.")

        # Run ensemble prediction
        UNet.full_inference_ensemble(
            self._best_models,
            data_loader=self.data_module.inference_dataloader(
                input_folder=self.project.source_for_prediction
            ),
            target_folder=self.project.target_for_prediction,
            post_full_inference_transforms=self.campaign_transforms.get_post_full_inference_transforms(),
            roi_size=self.config.patch_size,
            batch_size=self.config.inference_batch_size,
            transpose=False,
            save_individual_preds=True,
            voting_mechanism="mode",
            weights=None,
            output_dtype=self.config.output_dtype,
        )

    def _resume(self):
        """Resume ensemble training using trained models."""
        raise NotImplementedError("Currently not supported!")

    def _predict(self):
        """Predict using a trained model."""
        # Check that the config contains the path to the model to load
        if self.config.source_model_path is None:
            raise ValueError("No path for models to load found in the configuration!")

        # Check that the source path for prediction is specified in the configuration
        if self.config.source_for_prediction is None:
            raise ValueError(
                "No source path for prediction specified in the configuration!"
            )

        # Check the model path
        if not self.config.source_model_path.is_dir():
            raise ValueError("Invalid path for models!")

        # Set up project
        self.project = Project(self.config)

        # Initialize the campaign transforms
        self.campaign_transforms = self._setup_campaign_transforms()

        # Initialize data module
        self.data_module = self._setup_data_module()

        # Load all models
        models = self._load_models(models_dir=self.config.source_model_path)

        # Run ensemble prediction
        UNet.full_inference_ensemble(
            models,
            data_loader=self.data_module.inference_dataloader(
                input_folder=self.project.source_for_prediction
            ),
            target_folder=self.project.target_for_prediction,
            post_full_inference_transforms=self.campaign_transforms.get_post_full_inference_transforms(),
            roi_size=self.config.patch_size,
            batch_size=self.config.inference_batch_size,
            transpose=False,
            save_individual_preds=True,
            voting_mechanism="mode",
            weights=None,
            output_dtype=self.config.output_dtype,
        )

    def _load_models(self, models_dir: Path):
        """Reload all models found in the per-fold model folders."""

        # Re-load all (best) models
        models = []
        fold = 0
        while True:

            # Look for the model checkpoint for the current fold
            found = list(models_dir.glob(f"fold_{fold}/*.ckpt"))
            if len(found) == 0:
                break

            # Try loading the model; on failure, skip to the next fold
            try:
                model = UNet.load_from_checkpoint(found[0])
            except Exception:
                print(f"Could not load the trained model {found[0]} for fold {fold}!")
                fold += 1
                continue

            # Add it to the list
            models.append(model)

            # Inform
            print(f"Fold {fold}: re-loaded model = {found[0]}")

            # Increase fold number
            fold += 1

        print(f"Loaded {len(models)} trained models.")
        return models

Ancestors

  • qute.director._director.Director
  • abc.ABC

Subclasses

  • qute.director._director.EnsembleSegmentationDirector
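
When predicting, _load_models() scans source_model_path for one checkpoint per fold with the glob pattern fold_{fold}/*.ckpt, stopping at the first missing folder; fold numbering must therefore be contiguous from 0. A sketch of the expected layout (checkpoint file names are placeholders):

source_model_path/
├── fold_0/
│   └── best_model.ckpt
├── fold_1/
│   └── best_model.ckpt
└── fold_2/
    └── best_model.ckpt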
class EnsembleSegmentationDirector (config_file: Union[pathlib.Path, str], num_folds: int)

Ensemble Segmentation Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
num_folds : int
Number of folds for cross-validation.
Source code
class EnsembleSegmentationDirector(EnsembleDirector, SegmentationDirector):
    """Ensemble Segmentation Training Director."""

    def __init__(self, config_file: Union[Path, str], num_folds: int) -> None:
        super().__init__(config_file, num_folds)

    def _setup_data_module(self):
        """Set up data module with folds."""

        # Data module
        data_module = DataModuleLocalFolder(
            campaign_transforms=self.campaign_transforms,
            data_dir=self.config.data_dir,  # Point to the root of the data directory
            seed=self.config.seed,
            num_folds=self.num_folds,  # Ensemble
            batch_size=self.config.batch_size,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            train_fraction=self.config.train_fraction,
            val_fraction=self.config.val_fraction,
            test_fraction=self.config.test_fraction,
            source_images_sub_folder=self.config.source_images_sub_folder,
            target_images_sub_folder=self.config.target_images_sub_folder,
            source_images_label=self.config.source_images_label,
            target_images_label=self.config.target_images_label,
            inference_batch_size=self.config.inference_batch_size,
        )

        # Return data module
        return data_module

    def _setup_basis_for_training_and_resume(self):
        """Initialize the basic components for training or resume."""

        # Set up the project
        self.project = self._setup_project()

        # Set up the transform campaign
        self.campaign_transforms = self._setup_campaign_transforms()

        # Set up the data module (with folds)
        self.data_module = self._setup_data_module()

        # Calculate the number of steps per epoch
        self.data_module.prepare_data()
        self.data_module.setup("train")
        self.steps_per_epoch = len(self.data_module.train_dataloader())

        # Inform
        print(f"Working directory: {self.project.run_dir}")

        # Set up loss function
        self.criterion = self._setup_loss()

        # Set up metrics
        self.metrics = self._setup_metrics()

Ancestors

  • qute.director._director.EnsembleDirector
  • qute.director._director.SegmentationDirector
  • qute.director._director.Director
  • abc.ABC

Subclasses

  • qute.director._director.EnsembleCellSegmentationDemoDirector
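
DataModuleLocalFolder points at the root of the data directory, with source and target images in the sub-folders named by source_images_sub_folder and target_images_sub_folder in the configuration. A sketch of the implied layout (the sub-folder names are placeholders):

data_dir/
├── images/    # source_images_sub_folder
└── labels/    # target_images_sub_folder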
class RestorationDirector (config_file: Union[pathlib.Path, str])

Restoration Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
Source code
class RestorationDirector(Director):
    """Restoration Training Director."""

    @override
    def _setup_metrics(self):
        """Set up metrics."""

        # Metrics
        metrics = MeanAbsoluteError()

        # Return metrics
        return metrics

    @override
    def _setup_loss(self):
        """Set up loss function."""

        # Set up loss function
        criterion = MSELoss()

        # Return loss
        return criterion

    @override
    def _setup_data_module(self):
        """Set up data module."""

        # Data module
        data_module = DataModuleLocalFolder(
            campaign_transforms=self.campaign_transforms,
            data_dir=self.config.data_dir,  # Point to the root of the data directory
            seed=self.config.seed,
            batch_size=self.config.batch_size,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            train_fraction=self.config.train_fraction,
            val_fraction=self.config.val_fraction,
            test_fraction=self.config.test_fraction,
            source_images_sub_folder=self.config.source_images_sub_folder,
            target_images_sub_folder=self.config.target_images_sub_folder,
            source_images_label=self.config.source_images_label,
            target_images_label=self.config.target_images_label,
            inference_batch_size=self.config.inference_batch_size,
        )

        # Return data module
        return data_module

    @override
    def _setup_campaign_transforms(self):
        """Set up campaign transforms."""

        # Initialize default, example Restoration Campaign Transform
        campaign_transforms = RestorationCampaignTransforms(
            min_intensity=0,
            max_intensity=15472,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
        )

        # Return campaign transforms
        return campaign_transforms

Ancestors

  • qute.director._director.Director
  • abc.ABC

Subclasses

  • qute.director._director.CellRestorationDemoDirector
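
The default campaign transforms hard-code the example intensity range 0 to 15472. A hedged sketch for adapting the range to other data; the import location of RestorationCampaignTransforms is assumed, and the maximum intensity is a placeholder:

from qute.campaigns import RestorationCampaignTransforms  # assumed import location

from qute.director import RestorationDirector


class MyRestorationDirector(RestorationDirector):
    """Illustrative director with a custom intensity range."""

    def _setup_campaign_transforms(self):
        """Set up campaign transforms for the local dataset."""
        return RestorationCampaignTransforms(
            min_intensity=0,
            max_intensity=65535,  # placeholder: full 16-bit range
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
        )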
class SegmentationDirector (config_file: Union[pathlib.Path, str])

Segmentation Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
Source code
class SegmentationDirector(Director):
    """Segmentation Training Director."""

    @override
    def _setup_metrics(self):
        """Set up metrics."""

        # Metrics
        metrics = DiceMetric(
            include_background=self.config.include_background,
            reduction="mean_batch",
            get_not_nans=False,
        )

        # Return metrics
        return metrics

    @override
    def _setup_loss(self):
        """Set up loss function."""

        # Set up loss function
        criterion = DiceCELoss(
            include_background=self.config.include_background,
            to_onehot_y=False,
            softmax=True,
        )

        # Return criterion
        return criterion

    @override
    def _setup_data_module(self):
        """Set up data module."""

        # Data module
        data_module = DataModuleLocalFolder(
            campaign_transforms=self.campaign_transforms,
            data_dir=self.config.data_dir,  # Point to the root of the data directory
            seed=self.config.seed,
            num_folds=1,  # No ensemble
            batch_size=self.config.batch_size,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            train_fraction=self.config.train_fraction,
            val_fraction=self.config.val_fraction,
            test_fraction=self.config.test_fraction,
            source_images_sub_folder=self.config.source_images_sub_folder,
            target_images_sub_folder=self.config.target_images_sub_folder,
            source_images_label=self.config.source_images_label,
            target_images_label=self.config.target_images_label,
            inference_batch_size=self.config.inference_batch_size,
        )

        # Return data module
        return data_module

    @override
    def _setup_campaign_transforms(self):
        """Set up campaign transforms."""

        # Consistency check
        if self.config.is_3d:
            raise ValueError("Check the value of `is_3d` in the configuration file.")

        # Initialize default, example Segmentation Campaign Transform
        campaign_transforms = SegmentationCampaignTransforms2D(
            num_classes=self.config.out_channels,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
        )

        # Return the campaign transforms
        return campaign_transforms

Ancestors

  • qute.director._director.Director
  • abc.ABC

Subclasses

  • qute.director._director.CellSegmentationDemoDirector
  • qute.director._director.EnsembleSegmentationDirector
  • qute.director._director.SegmentationDirector3D
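
_setup_loss() can likewise be overridden to re-balance the criterion. A hedged sketch using the lambda_dice and lambda_ce weights of MONAI's DiceCELoss; the weight values are illustrative:

from monai.losses import DiceCELoss

from qute.director import SegmentationDirector


class ReweightedSegmentationDirector(SegmentationDirector):
    """Illustrative director that emphasizes the Dice term of the loss."""

    def _setup_loss(self):
        """Set up a re-weighted DiceCELoss."""
        return DiceCELoss(
            include_background=self.config.include_background,
            to_onehot_y=False,
            softmax=True,
            lambda_dice=0.7,  # illustrative weights
            lambda_ce=0.3,
        )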
class SegmentationDirector3D (config_file: Union[pathlib.Path, str])

Segmentation 3D Training Director.

Constructor.

Parameters

config_file : Union[Path, str]
Full path to the configuration file.
Source code
class SegmentationDirector3D(SegmentationDirector):
    """Segmentation 3D Training Director."""

    @override
    def _setup_campaign_transforms(self):
        """Set up campaign transforms."""

        # Consistency check
        if not self.config.is_3d:
            raise ValueError("Check the value of `is_3d` in the configuration file.")

        # Initialize default, example Segmentation Campaign Transform
        campaign_transforms = SegmentationCampaignTransforms3D(
            num_classes=self.config.out_channels,
            patch_size=self.config.patch_size,
            num_patches=self.config.num_patches,
            voxel_size=self.config.voxel_size,
            to_isotropic=self.config.to_isotropic,
            upscale_z=self.config.up_scale_z,
        )

        # Return the campaign transforms
        return campaign_transforms

Ancestors

  • qute.director._director.SegmentationDirector
  • qute.director._director.Director
  • abc.ABC
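
A hedged usage sketch: the configuration must set is_3d to true, since _setup_campaign_transforms() raises otherwise; the file name is a placeholder:

from qute.director import SegmentationDirector3D

if __name__ == "__main__":
    director = SegmentationDirector3D("segmentation_3d_config.ini")
    director.run()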