# Source code for mlopus.kedro.hooks.mlflow_artifacts._mlflow_artifacts

import logging
from typing import TypeVar, Dict

from kedro.framework.hooks import hook_impl
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline

import mlopus
from mlopus.utils import pydantic
from .input import PipelineInput
from .output import PipelineOutput
from ..hook_factory import HookWithFactory  # noqa: TID252

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Generic type variable; not referenced in the visible part of this module.
T = TypeVar("T")


class MlflowArtifacts(mlopus.mlflow.MlflowRunMixin, HookWithFactory):
    """Hook to set up inputs and collect outputs using MLflow artifacts.

    Find `here <https://github.com/lariel-fernandes/mlopus/tree/main/examples/2_a_kedro_project/conf/full/parameters/hooks/mlflow_artifacts.yml>`_
    a fully commented example covering all settings that can be customized in this hook.

    Lifecycle, as wired by the hook triggers below:

    - before the pipeline runs, every enabled input that applies to the
      pipeline is set up;
    - after a successful run, every enabled output that applies to the
      pipeline is collected;
    - after a failed run, outputs are collected only if ``collect_on_error``
      is true.
    """

    collect_on_error: bool = pydantic.Field(
        default=False,
        description="Collect available outputs if pipeline fails.",
    )

    inputs: Dict[str, PipelineInput] = pydantic.Field(
        default_factory=dict,
        description="Configure inputs.",
    )

    outputs: Dict[str, PipelineOutput] = pydantic.Field(
        default_factory=dict,
        description="Configure outputs.",
    )

    @pydantic.model_validator(mode="before")  # noqa
    @classmethod
    def _backwards_compat_inputs_outputs(cls, values: dict):
        """Accept ``inputs``/``outputs`` given as lists instead of dicts.

        A list is converted to a dict keyed by the stringified position of
        each element (``"0"``, ``"1"``, ...). Judging by the method name this
        supports an older, list-based configuration format.

        :param values: Raw data handed to the model before field validation.
        :return: A shallow copy of ``values`` with list-valued ``inputs`` /
            ``outputs`` replaced by dicts, or ``values`` itself, untouched,
            when it is not a dict.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict (e.g. an already-built instance). There is nothing to
        # convert in that case, and calling `.get` on it would fail.
        if not isinstance(values, dict):
            return values

        # Work on a shallow copy so the caller's dict is never modified.
        patched = dict(values)
        for key in ("inputs", "outputs"):
            if isinstance(value := patched.get(key), list):
                patched[key] = {str(n): x for n, x in enumerate(value)}
        return patched

    # =======================================================================================================
    # === Artifact handlers =================================================================================

    def _setup_inputs(self, pipeline_name: str) -> None:
        """Set up every enabled input that is used by the given pipeline.

        When an input has ``log_lineage`` set and its setup returns a
        non-``None`` lineage argument, that argument is registered as an input
        in the lineage of the current run.

        :param pipeline_name: Name of the pipeline about to run.
        """
        for name, i in self.inputs.items():
            if not (i.enabled and i.used_by(pipeline_name)):
                continue

            logger.info("Setting up input '%s'", name)
            lineage_arg = i.using(self.run_manager.mlflow_api).setup(default_run_id=self.run_manager.run.id)

            if i.log_lineage and lineage_arg is not None:
                mlopus.lineage.of(self.run_manager.run).with_input(lineage_arg).register()

    def _collect_outputs(self, pipeline_name: str) -> None:
        """Collect every enabled output that is used by the given pipeline.

        An output whose collection returns ``None`` is reported as missing and
        skipped. Otherwise, when the output has ``log_lineage`` set, the
        returned lineage argument is registered as an output in the lineage of
        the current run.

        :param pipeline_name: Name of the pipeline that ran.
        """
        for name, o in self.outputs.items():
            if not (o.enabled and o.used_by(pipeline_name)):
                continue

            logger.info("Collecting output '%s'", name)
            lineage_arg = o.using(self.run_manager.mlflow_api).collect(default_run_id=self.run_manager.run.id)

            if lineage_arg is None:
                logger.info("Output '%s' is missing (skipped)", name)
            elif o.log_lineage:
                mlopus.lineage.of(self.run_manager.run).with_output(lineage_arg).register()

    # =======================================================================================================
    # === Hook triggers =====================================================================================

    @hook_impl
    def before_pipeline_run(self, run_params: dict, pipeline: Pipeline, catalog: DataCatalog):  # noqa
        """Set up the inputs for the pipeline named in ``run_params["pipeline_name"]``."""
        self._setup_inputs(run_params["pipeline_name"])

    @hook_impl
    def after_pipeline_run(self, run_params: dict, run_result: dict, pipeline: Pipeline, catalog: DataCatalog):  # noqa
        """Collect the outputs for the pipeline named in ``run_params["pipeline_name"]``."""
        self._collect_outputs(run_params["pipeline_name"])

    @hook_impl
    def on_pipeline_error(self, error: Exception, run_params: dict, pipeline: Pipeline, catalog: DataCatalog):  # noqa
        """Collect outputs after a failure, but only when ``collect_on_error`` is enabled."""
        if self.collect_on_error:
            self._collect_outputs(run_params["pipeline_name"])