跳转至

Function 函数系统概览

Function 函数系统是 SAGE Core 中用户逻辑的载体,它将用户定义的处理函数包装成可在算子系统中执行的标准化组件。函数系统提供了类型安全、资源管理、异常处理等功能,是连接用户代码和底层执行引擎的桥梁。

🏗️ 设计架构

函数系统采用分层抽象设计,支持多种类型的用户函数:

graph TD
    A[BaseFunction] --> B[SourceFunction]
    A --> C[MapFunction] 
    A --> D[FilterFunction]
    A --> E[SinkFunction]
    A --> F[FlatMapFunction]
    A --> G[KeyByFunction]
    A --> H[JoinFunction]
    A --> I[BatchFunction]
    A --> J[StatefulFunction]

    B --> K[LambdaSourceFunction]

    C --> L[LambdaMapFunction]

    D --> M[LambdaFilterFunction]

    E --> N[LambdaSinkFunction]

    F --> O[LambdaFlatMapFunction]

🧩 核心组件

1. 基础函数类 (BaseFunction)

所有函数的抽象基类,定义了函数的通用接口:

from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING
import logging

if TYPE_CHECKING:
    from sage.kernel.runtime.task_context import TaskContext

class BaseFunction(ABC):
    """
    BaseFunction is the abstract base class for all operator functions in SAGE.
    It defines the core interface and initializes a logger.
    """
    def __init__(self, *args, **kwargs):
        self.ctx: 'TaskContext' = None # 运行时注入
        self.router = None  # 运行时注入
        self._logger = None

    @property
    def logger(self):
        if not hasattr(self, "_logger") or self._logger is None:
            if self.ctx is None:
                self._logger = logging.getLogger("")
            else:
                self._logger = self.ctx.logger
        return self._logger

    @property
    def name(self):
        if self.ctx is None:
            return self.__class__.__name__
        return self.ctx.name

    @property
    def call_service(self):
        """
        同步服务调用语法糖

        用法:
            result = self.call_service["cache_service"].get("key1")
            data = self.call_service["db_service"].query("SELECT * FROM users")
        """
        if self.ctx is None:
            raise RuntimeError("Runtime context not initialized. Cannot access services.")

        return self.ctx.call_service()

    @property 
    def call_service_async(self):
        """
        异步服务调用语法糖

        用法:
            future = self.call_service_async["cache_service"].get("key1")
            result = future.result()  # 阻塞等待结果

            # 或者非阻塞检查
            if future.done():
                result = future.result()
        """
        if self.ctx is None:
            raise RuntimeError("Runtime context not initialized. Cannot access services.")

        return self.ctx.call_service_async()

    @abstractmethod
    def execute(self, data: Any):
        """
        Abstract method to be implemented by subclasses.

        Each function must define its own execute logic that processes input data
        and returns the output.

        :param data: Input data.
        :return: Output data.
        """
        pass

2. 任务上下文 (TaskContext)

提供函数执行时的环境信息和工具:

class TaskContext:
    def __init__(self, graph_node, transformation, env, execution_graph=None):
        self.name = graph_node.name
        self.env_name = env.name
        self.env_base_dir = env.env_base_dir
        self.parallel_index = graph_node.parallel_index
        self.parallelism = graph_node.parallelism
        self._logger = None
        self.is_spout = transformation.is_spout
        self.delay = 0.01
        self.stop_signal_num = graph_node.stop_signal_num

    @property
    def logger(self):
        """获取日志记录器"""
        return self._logger

    def call_service(self):
        """同步服务调用接口"""
        # 实际实现由运行时提供
        pass

    def call_service_async(self):
        """异步服务调用接口"""
        # 实际实现由运行时提供
        pass

3. 有状态函数 (StatefulFunction)

SAGE 提供了内置的状态管理功能:

import os
from sage.core.api.function.base_function import BaseFunction
from sage.kernel.utils.persistence.state import load_function_state, save_function_state

class StatefulFunction(BaseFunction):
    """
    有状态算子基类:自动在 init 恢复状态,
    并可通过 save_state() 持久化。
    """
    # 子类可覆盖:只保存 include 中字段
    __state_include__ = []
    # 默认排除 logger、私有属性和 runtime_context
    __state_exclude__ = ['logger', '_logger', 'ctx']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 恢复上次 checkpoint
        if self.ctx:
            chkpt_dir = os.path.join(self.ctx.env_base_dir, ".sage_checkpoints")
            chkpt_path = os.path.join(chkpt_dir, f"{self.ctx.name}.chkpt")
            load_function_state(self, chkpt_path)

    def save_state(self):
        """
        将当前对象状态持久化到 disk
        """
        if self.ctx:
            base = os.path.join(self.ctx.env_base_dir, ".sage_checkpoints")
            os.makedirs(base, exist_ok=True)
            path = os.path.join(base, f"{self.ctx.name}.chkpt")
            save_function_state(self, path)

🔧 函数类型

1. 源函数 (Source Functions)

负责数据生成和输入的函数:

from sage.core.api.function.base_function import BaseFunction

class StopSignal:
    """停止信号类,用于标识任务停止"""
    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"<StopSignal {self.name}>"

class SourceFunction(BaseFunction):
    """
    源函数基类 - 数据生产者

    源函数不接收输入数据,只产生输出数据
    通常用于读取文件、数据库、API等外部数据源
    """

    @abstractmethod
    def execute(self) -> Any:
        """
        执行源函数逻辑,生产数据

        Returns:
            生产的数据
        """
        pass

# 使用示例
class SimpleSourceFunction(SourceFunction):
    def __init__(self, data_list):
        super().__init__()
        self.data_list = data_list
        self.index = 0

    def execute(self):
        if self.index < len(self.data_list):
            data = self.data_list[self.index]
            self.index += 1
            return data
        else:
            return StopSignal("data_exhausted")

2. 映射函数 (Map Functions)

对数据进行一对一转换的函数:

from sage.core.api.function.base_function import BaseFunction
from sage.core.api.function.map_function import MapFunction

class MapFunction(BaseFunction):
    """
    映射函数基类 - 一对一数据变换

    映射函数接收一个输入,产生一个输出
    用于数据转换、增强、格式化等操作
    """

    @abstractmethod
    def execute(self, data: Any) -> Any:
        """
        执行映射变换

        Args:
            data: 输入数据

        Returns:
            变换后的数据
        """
        pass

# Lambda函数包装器
from typing import Callable
from sage.core.api.function.lambda_function import LambdaMapFunction

class LambdaMapFunction(MapFunction):
    """将 lambda 函数包装为 MapFunction"""

    def __init__(self, lambda_func: Callable[[Any], Any], **kwargs):
        super().__init__(**kwargs)
        self.lambda_func = lambda_func

    def execute(self, data: Any) -> Any:
        return self.lambda_func(data)

# 使用示例
text_processor = LambdaMapFunction(lambda x: x.strip().upper())
number_doubler = LambdaMapFunction(lambda x: x * 2)

3. 过滤函数 (Filter Functions)

用于数据过滤的谓词函数:

from sage.core.api.function.filter_function import FilterFunction
from sage.core.api.function.lambda_function import LambdaFilterFunction

class FilterFunction(BaseFunction):
    """
    FilterFunction 是专门用于 Filter 操作的函数基类。
    它定义了过滤条件函数的接口,用于判断数据是否应该通过过滤器。

    Filter 函数的主要作用是接收输入数据,返回布尔值表示数据是否通过过滤条件。
    """

    @abstractmethod
    def execute(self, data: Any) -> bool:
        """
        抽象方法,由子类实现具体的过滤逻辑。

        Args:
            data: 输入数据

        Returns:
            bool: True表示数据应该通过,False表示应该被过滤掉
        """
        pass

class LambdaFilterFunction(FilterFunction):
    """将返回布尔值的 lambda 函数包装为 FilterFunction"""

    def __init__(self, lambda_func: Callable[[Any], bool], **kwargs):
        super().__init__(**kwargs)
        self.lambda_func = lambda_func

    def execute(self, data: Any) -> bool:
        return self.lambda_func(data)

# 使用示例
positive_filter = LambdaFilterFunction(lambda x: x > 0)
non_empty_filter = LambdaFilterFunction(lambda x: x is not None and str(x).strip() != "")

4. 汇函数 (Sink Functions)

负责数据输出的函数:

from sage.core.api.function.sink_function import SinkFunction
from sage.core.api.function.lambda_function import LambdaSinkFunction

class SinkFunction(BaseFunction):
    """
    汇聚函数基类 - 数据消费者

    汇聚函数接收输入数据,通常不产生输出
    用于数据存储、发送、打印等终端操作
    """

    @abstractmethod
    def execute(self, data: Any) -> None:
        """
        执行汇聚操作

        Args:
            data: 输入数据
        """
        pass

class LambdaSinkFunction(SinkFunction):
    """将 lambda 函数包装为 SinkFunction"""

    def __init__(self, lambda_func: Callable[[Any], None], **kwargs):
        super().__init__(**kwargs)
        self.lambda_func = lambda_func

    def execute(self, data: Any) -> None:
        self.lambda_func(data)

# 使用示例
print_sink = LambdaSinkFunction(lambda x: print(f"Processing: {x}"))

class FileSinkFunction(SinkFunction):
    def __init__(self, filename):
        super().__init__()
        self.filename = filename
        self.file_handle = None

    def setup(self):
        self.file_handle = open(self.filename, 'w')

    def execute(self, data: Any) -> None:
        if self.file_handle:
            self.file_handle.write(str(data) + '\n')
            self.file_handle.flush()

    def cleanup(self):
        if self.file_handle:
            self.file_handle.close()

5. 其他函数类型

SAGE 还支持其他类型的函数:

FlatMapFunction - 扁平化映射

from sage.core.api.function.flatmap_function import FlatMapFunction
from sage.core.api.function.lambda_function import LambdaFlatMapFunction

class FlatMapFunction(BaseFunction):
    """
    扁平化映射函数基类 - 一对多数据变换

    FlatMap函数接收一个输入,产生多个输出(列表形式)
    用于数据分解、展开等操作
    """

    @abstractmethod
    def execute(self, data: Any) -> List[Any]:
        """
        执行扁平化映射变换

        Args:
            data: 输入数据

        Returns:
            变换后的数据列表
        """
        pass

class LambdaFlatMapFunction(FlatMapFunction):
    """将返回列表的 lambda 函数包装为 FlatMapFunction"""

    def __init__(self, lambda_func: Callable[[Any], List[Any]], **kwargs):
        super().__init__(**kwargs)
        self.lambda_func = lambda_func

    def execute(self, data: Any) -> List[Any]:
        result = self.lambda_func(data)
        if not isinstance(result, list):
            raise TypeError(f"FlatMap lambda function must return a list, got {type(result)}")
        return result

# 使用示例
sentence_splitter = LambdaFlatMapFunction(lambda x: x.split())

KeyByFunction - 键值分组

from sage.core.api.function.keyby_function import KeyByFunction

class KeyByFunction(BaseFunction):
    """
    KeyBy函数基类 - 数据分组

    用于根据键值对数据进行分组操作
    """

    @abstractmethod
    def execute(self, data: Any) -> Any:
        """
        提取分组键

        Args:
            data: 输入数据

        Returns:
            分组键
        """
        pass

⚡ 高级特性

1. Lambda函数包装器

SAGE 提供了便捷的Lambda函数包装器,可以快速将普通函数转换为SAGE函数:

from sage.core.api.function.lambda_function import (
    LambdaMapFunction, LambdaFilterFunction, LambdaFlatMapFunction, 
    LambdaSinkFunction
)

# 快速创建各种类型的函数
map_func = LambdaMapFunction(lambda x: x * 2)
filter_func = LambdaFilterFunction(lambda x: x > 0)
flatmap_func = LambdaFlatMapFunction(lambda x: x.split())
sink_func = LambdaSinkFunction(lambda x: print(x))

2. 服务调用功能

函数可以通过上下文调用系统服务:

class ServiceCallFunction(BaseFunction):
    def execute(self, data):
        # 同步调用服务
        result = self.call_service["cache_service"].get("key1")

        # 异步调用服务
        future = self.call_service_async["db_service"].query("SELECT * FROM users")

        # 处理结果
        if future.done():
            db_result = future.result()

        return {"cache": result, "db": db_result}

📋 最佳实践

1. 函数设计

  • 单一职责: 每个函数应该只做一件事情
  • 类型明确: 明确输入和输出的数据类型
  • 错误处理: 合理处理异常并记录日志
class TextCleanerFunction(MapFunction):
    """文本清理函数 - 良好的设计示例"""

    def execute(self, text: str) -> str:
        if not isinstance(text, str):
            self.logger.error(f"Expected string input, got {type(text)}")
            raise TypeError("Input must be a string")

        if not text.strip():
            return ""

        # 清理逻辑
        cleaned = text.strip().lower()
        cleaned = ' '.join(cleaned.split())  # 规范化空格

        return cleaned

2. 状态管理

  • 使用StatefulFunction: 对于需要维护状态的函数
  • 定期保存状态: 在关键点调用save_state()
  • 合理设计状态结构: 避免状态过大导致序列化问题
class CounterFunction(StatefulFunction):
    def __init__(self):
        super().__init__()
        self.count = 0

    def execute(self, data):
        self.count += 1

        # 每处理100个数据保存一次状态
        if self.count % 100 == 0:
            self.save_state()

        return {"data": data, "count": self.count}

3. 服务调用

  • 异步优先: 对于I/O操作使用异步服务调用
  • 错误处理: 检查服务调用的返回状态
  • 资源管理: 及时释放服务连接
class DatabaseQueryFunction(MapFunction):
    def execute(self, query_params):
        try:
            # 使用异步服务调用
            future = self.call_service_async["db_service"].query(query_params)

            if future.done():
                result = future.result()
                if result.get("success"):
                    return result.get("data")
                else:
                    self.logger.error(f"Database query failed: {result.get('error')}")
                    return None
            else:
                self.logger.warning("Database query timeout")
                return None

        except Exception as e:
            self.logger.error(f"Service call failed: {e}")
            return None

4. 测试和调试

  • 单元测试: 为每个函数编写测试
  • 模拟上下文: 在测试中模拟TaskContext
  • 日志记录: 充分利用logger记录关键信息
import unittest
from unittest.mock import Mock

class TestTextCleanerFunction(unittest.TestCase):
    def setUp(self):
        self.function = TextCleanerFunction()
        # 模拟上下文
        self.function.ctx = Mock()
        self.function.ctx.logger = Mock()

    def test_clean_normal_text(self):
        result = self.function.execute("  Hello   World  ")
        self.assertEqual(result, "hello world")

    def test_clean_empty_text(self):
        result = self.function.execute("   ")
        self.assertEqual(result, "")

    def test_invalid_input(self):
        with self.assertRaises(TypeError):
            self.function.execute(123)

下一步: 了解 Transformation 转换系统 如何优化函数执行。