函数接口 (Functions)
SAGE Kernel 提供了丰富的函数接口，支持用户定义各种数据处理逻辑。所有函数都继承自抽象基类 BaseFunction，提供类型安全和性能优化。
🧩 函数类型概览
BaseFunction (抽象基类)
├── MapFunction # 一对一转换
├── FlatMapFunction # 一对多转换
├── FilterFunction # 过滤操作
├── ReduceFunction # 归约操作
├── AggregateFunction # 聚合操作
├── ProcessFunction # 通用处理函数
├── SinkFunction # 输出函数
├── SourceFunction # 数据源函数
├── KeySelector # 键选择器
├── JoinFunction # 连接函数
└── CoMapFunction # 协同映射函数
🔄 转换函数
MapFunction - 一对一转换
from sage.core.api.function import MapFunction
from typing import TypeVar
T = TypeVar('T')  # input element type
U = TypeVar('U')  # output element type


class MapFunction(BaseFunction[T, U]):
    """Base class for one-to-one transformations.

    Subclasses override :meth:`map`; every input element produces exactly
    one output element.
    """

    def map(self, value: T) -> U:
        """Transform a single element and return the result."""
        # Abstract hook: concrete subclasses must override this.
        raise NotImplementedError()
# 示例实现
class SquareFunction(MapFunction[int, int]):
    """Map each integer to its square."""

    def map(self, value: int) -> int:
        squared = value * value
        return squared
class ParseJsonFunction(MapFunction[str, dict]):
    """Parse each input line as a JSON object.

    Malformed input does not raise; it becomes an error record, so one bad
    line does not fail the whole stream.
    """

    def map(self, json_str: str) -> dict:
        """Return the decoded JSON object, or an error record.

        Returns:
            The parsed ``dict`` on success.
            ``{"error": "invalid_json", "raw": json_str}`` when the text is
            not valid JSON.
            ``{"error": "not_an_object", "raw": json_str}`` when the text is
            valid JSON but not an object (e.g. ``[1, 2]`` or ``null``). Such a
            value would otherwise violate the declared ``dict`` return type.
        """
        import json  # imported here: this example has no module-level `import json`

        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            return {"error": "invalid_json", "raw": json_str}
        if not isinstance(parsed, dict):
            return {"error": "not_an_object", "raw": json_str}
        return parsed
class UserProfileExtractor(MapFunction[dict, UserProfile]):
    """Build a ``UserProfile`` from a raw user dict.

    ``id`` and ``name`` are required (a missing key raises ``KeyError``).
    ``email`` defaults to ``None`` and ``age`` defaults to ``0``.
    """
    # NOTE(review): UserProfile is neither defined nor imported in this
    # example; presumably a user-defined record with id/name/email/age fields.

    def map(self, user_data: dict) -> UserProfile:
        required_id = user_data["id"]
        required_name = user_data["name"]
        optional_email = user_data.get("email")
        optional_age = user_data.get("age", 0)
        return UserProfile(
            id=required_id,
            name=required_name,
            email=optional_email,
            age=optional_age,
        )
# Usage. `numbers`, `json_lines` and `user_data` are placeholders for
# existing streams; they are not defined in this example.
numbers.map(SquareFunction())
json_lines.map(ParseJsonFunction())
user_data.map(UserProfileExtractor())
FlatMapFunction - 一对多转换
from sage.core.api.function import FlatMapFunction
from typing import Iterable
class FlatMapFunction(BaseFunction[T, Iterable[U]]):
    """Base class for one-to-many transformations.

    A single input element may expand into any number of output elements.
    """

    def flat_map(self, value: T) -> Iterable[U]:
        """Expand one element into an iterable of output elements."""
        # Subclasses override this; returning a list or writing a generator
        # (``yield``) are both fine.
        raise NotImplementedError()
# 示例实现
class SplitWordsFunction(FlatMapFunction[str, str]):
    """Split a sentence into lower-cased, whitespace-separated words."""

    def flat_map(self, sentence: str) -> Iterable[str]:
        lowered = sentence.lower()
        return lowered.split()
class ExpandEventsFunction(FlatMapFunction[dict, dict]):
    """Flatten a batch into its events, tagging each one with batch metadata."""

    def flat_map(self, batch: dict) -> Iterable[dict]:
        """Yield a copy of every event in ``batch["events"]``.

        Each yielded event gains ``batch_id`` and ``batch_timestamp`` keys.
        The events are copied rather than modified in place, so the upstream
        ``batch`` object is left unchanged. A batch without an ``events`` key
        yields nothing.

        Raises:
            KeyError: if the batch contains events but lacks ``id`` or
                ``timestamp``.
        """
        for event in batch.get("events", []):
            yield {
                **event,
                "batch_id": batch["id"],
                "batch_timestamp": batch["timestamp"],
            }
class GenerateNGramsFunction(FlatMapFunction[str, str]):
    """Emit every contiguous run of ``n`` whitespace-separated words."""

    def __init__(self, n: int = 2):
        """Configure the n-gram length.

        Args:
            n: Number of words per n-gram; must be at least 1.

        Raises:
            ValueError: if ``n < 1``. With ``n == 0`` the function would emit
                empty strings, and a negative ``n`` produces meaningless
                slices.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.n = n

    def flat_map(self, text: str) -> Iterable[str]:
        """Yield the n-grams of ``text`` in order (none if it has fewer than ``n`` words)."""
        words = text.split()
        for start in range(len(words) - self.n + 1):
            yield " ".join(words[start:start + self.n])
# Usage. `sentences`, `batches` and `text` are placeholders for existing
# streams; they are not defined in this example.
sentences.flat_map(SplitWordsFunction())
batches.flat_map(ExpandEventsFunction())
text.flat_map(GenerateNGramsFunction(3)) # 3-grams
FilterFunction - 过滤操作
from sage.core.api.function import FilterFunction
class FilterFunction(BaseFunction[T, bool]):
    """Base class for predicates that decide which elements to keep."""

    def filter(self, value: T) -> bool:
        """Return ``True`` to keep ``value`` in the stream, ``False`` to drop it."""
        # Abstract hook for subclasses.
        raise NotImplementedError()
# 示例实现
class AdultUserFilter(FilterFunction[dict]):
    """Keep users whose ``age`` is at least 18; a missing age counts as 0."""

    def filter(self, user: dict) -> bool:
        age = user.get("age", 0)
        return 18 <= age
class ValidEmailFilter(FilterFunction[str]):
    """Cheap syntactic e-mail check; this is not full RFC 5322 validation."""

    def filter(self, email: str) -> bool:
        """Keep ``email`` if it looks like ``local@domain.tld``.

        Requirements:
        - exactly one ``@``;
        - a non-empty local part;
        - a domain containing a dot somewhere other than its first or last
          character.

        The former check (``"@" in email`` plus a dot after the first ``@``)
        accepted strings such as ``"@."``, ``"@x.y"`` and ``"a@b@c.d"``.
        """
        local, sep, domain = email.partition("@")
        if not sep or not local or "@" in domain:
            return False
        # Slicing off the ends rejects domains like ".com" or "example.".
        return "." in domain[1:-1]
class PriceRangeFilter(FilterFunction[dict]):
    """Keep products whose ``price`` lies in the closed range [min_price, max_price].

    A product without a ``price`` key is treated as costing 0.
    """

    def __init__(self, min_price: float, max_price: float):
        self.min_price, self.max_price = min_price, max_price

    def filter(self, product: dict) -> bool:
        price = product.get("price", 0)
        within_range = self.min_price <= price <= self.max_price
        return within_range
# Usage. `users`, `emails` and `products` are placeholders for existing
# streams; they are not defined in this example.
users.filter(AdultUserFilter())
emails.filter(ValidEmailFilter())
products.filter(PriceRangeFilter(10.0, 100.0))
🔑 键值函数
KeySelector - 键选择器
from sage.core.api.function import KeySelector
K = TypeVar('K')  # key type


class KeySelector(BaseFunction[T, K]):
    """Base class for extracting a grouping key from an element."""

    def get_key(self, value: T) -> K:
        """Return the key under which ``value`` is grouped."""
        raise NotImplementedError()  # must be overridden by subclasses
# 示例实现
class UserIdKeySelector(KeySelector[dict, str]):
    """Key user records by their ``id`` field."""

    def get_key(self, user: dict) -> str:
        key = user["id"]
        return key
class TimestampKeySelector(KeySelector[dict, int]):
    """Bucket events by integer-dividing ``timestamp`` by 3600 (hourly buckets)."""

    def get_key(self, event: dict) -> int:
        # 3600 means one hour only if timestamps are in seconds.
        # NOTE(review): other examples on this page work in milliseconds
        # (e.g. timeout_ms); confirm the unit of event["timestamp"].
        seconds_per_hour = 3600
        return event["timestamp"] // seconds_per_hour
class CompositeKeySelector(KeySelector[dict, tuple]):
    """Key records by the pair ``(category, region)``."""

    def get_key(self, record: dict) -> tuple:
        category = record["category"]
        region = record["region"]
        return category, region
# Usage. `users`, `events` and `sales` are placeholders for existing
# streams; they are not defined in this example.
users.key_by(UserIdKeySelector())
events.key_by(TimestampKeySelector())
sales.key_by(CompositeKeySelector())
ReduceFunction - 归约操作
from sage.core.api.function import ReduceFunction
class ReduceFunction(BaseFunction[T, T]):
    """Base class for combining two values that share a key into one value.

    Both inputs and the output have the same type ``T``.
    """

    def reduce(self, value1: T, value2: T) -> T:
        """Combine ``value1`` and ``value2`` into a single value."""
        raise NotImplementedError()  # subclass responsibility
# 示例实现
class SumReduceFunction(ReduceFunction[int]):
    """Add two integers."""

    def reduce(self, value1: int, value2: int) -> int:
        total = value1 + value2
        return total
class MaxReduceFunction(ReduceFunction[float]):
    """Keep the larger of two values."""

    def reduce(self, value1: float, value2: float) -> float:
        # Same tie-breaking as max(): value1 wins unless value2 is strictly greater.
        return value2 if value2 > value1 else value1
class MergeUserFunction(ReduceFunction[dict]):
    """Merge two versions of a user record so that the newer one wins."""

    def reduce(self, user1: dict, user2: dict) -> dict:
        """Return a merged copy of ``user1`` and ``user2``.

        The record with the larger ``timestamp`` is the newer one. A missing
        timestamp counts as 0, and ties favour ``user1``.

        Fields from the newer record, including its ``timestamp``, override
        those of the older record. Fields that only the older record has are
        kept. Neither input is mutated.
        """
        # The former version copied the newer record and then overwrote it
        # with the older record's fields, so stale data won every conflict.
        if user1.get("timestamp", 0) >= user2.get("timestamp", 0):
            newer, older = user1, user2
        else:
            newer, older = user2, user1
        return {**older, **newer}
# Usage. `numbers`, `scores` and `user_updates` are placeholders for
# existing streams; they are not defined in this example.
numbers.key_by(lambda x: x % 2).reduce(SumReduceFunction())
# NOTE(review): elements of `scores` are indexed like dicts (x["user_id"]),
# but MaxReduceFunction calls max() on whole elements. Dicts are not orderable
# in Python 3, so this raises TypeError unless `scores` holds orderable
# values. Consider a reducer that compares the score field instead.
scores.key_by(lambda x: x["user_id"]).reduce(MaxReduceFunction())
user_updates.key_by(lambda x: x["id"]).reduce(MergeUserFunction())
📊 聚合函数
AggregateFunction - 聚合操作
from sage.core.api.function import AggregateFunction
ACC = TypeVar('ACC')  # accumulator type
OUT = TypeVar('OUT')  # result type


class AggregateFunction(BaseFunction[T, ACC, OUT]):
    """Base class for incremental aggregation through an accumulator.

    - create_accumulator() makes an initial accumulator.
    - add() folds elements into it.
    - get_result() turns it into the output.
    - merge() combines partial accumulators (for distributed aggregation).
    """

    def create_accumulator(self) -> ACC:
        """Return a fresh initial accumulator."""
        raise NotImplementedError()

    def add(self, accumulator: ACC, value: T) -> ACC:
        """Fold ``value`` into ``accumulator`` and return the updated accumulator."""
        raise NotImplementedError()

    def get_result(self, accumulator: ACC) -> OUT:
        """Derive the final output from ``accumulator``."""
        raise NotImplementedError()

    def merge(self, acc1: ACC, acc2: ACC) -> ACC:
        """Combine two accumulators into one (used for distributed aggregation)."""
        raise NotImplementedError()
# 示例实现
from typing import Any  # `Any` was used below without ever being imported


class CountAggregateFunction(AggregateFunction[Any, int, int]):
    """Count elements of any type."""

    def create_accumulator(self) -> int:
        return 0

    def add(self, accumulator: int, value: Any) -> int:
        # The element itself is ignored; only its occurrence is counted.
        return accumulator + 1

    def get_result(self, accumulator: int) -> int:
        return accumulator

    def merge(self, acc1: int, acc2: int) -> int:
        return acc1 + acc2
class AvgAggregateFunction(AggregateFunction[float, tuple, float]):
    """Arithmetic mean; the accumulator is a ``(sum, count)`` tuple."""

    def create_accumulator(self) -> tuple:
        return 0.0, 0

    def add(self, accumulator: tuple, value: float) -> tuple:
        running_sum, seen = accumulator
        return running_sum + value, seen + 1

    def get_result(self, accumulator: tuple) -> float:
        running_sum, seen = accumulator
        if seen > 0:
            return running_sum / seen
        return 0.0  # empty input averages to 0.0 instead of dividing by zero

    def merge(self, acc1: tuple, acc2: tuple) -> tuple:
        total = acc1[0] + acc2[0]
        count = acc1[1] + acc2[1]
        return total, count
import heapq


class TopKAggregateFunction(AggregateFunction[int, list, list]):
    """Keep the ``k`` largest integers seen so far, in descending order."""

    def __init__(self, k: int = 10):
        """Configure how many values to retain.

        Args:
            k: Number of values to keep; must be non-negative.

        Raises:
            ValueError: if ``k`` is negative. A negative slice bound would
                silently drop elements instead of keeping the top ``k``.
        """
        if k < 0:
            raise ValueError(f"k must be >= 0, got {k}")
        self.k = k

    def create_accumulator(self) -> list:
        return []

    def add(self, accumulator: list, value: int) -> list:
        """Return a new descending top-k list that includes ``value``.

        The former version appended to and re-sorted the caller's list in
        place. heapq.nlargest gives the same ordering as
        ``sorted(..., reverse=True)[:k]`` without mutating ``accumulator``.
        """
        return heapq.nlargest(self.k, [*accumulator, value])

    def get_result(self, accumulator: list) -> list:
        return accumulator

    def merge(self, acc1: list, acc2: list) -> list:
        return heapq.nlargest(self.k, [*acc1, *acc2])
🔧 处理函数
ProcessFunction - 通用处理
from sage.core.api.function import ProcessFunction, ProcessContext
class ProcessFunction(BaseFunction[T, U]):
    """General-purpose processing hook with advanced features.

    Side outputs and timers are reached through the ``ProcessContext``
    passed to each call.
    """

    def process(self, value: T, ctx: ProcessContext[U]) -> None:
        """Handle one element; results are emitted through ``ctx`` rather than returned."""
        raise NotImplementedError()

    def on_timer(self, timestamp: int, ctx: ProcessContext[U]) -> None:
        """Timer callback receiving the firing ``timestamp``; does nothing by default."""
        return None
# 示例实现
class ValidationFunction(ProcessFunction[dict, dict]):
    """Send complete records to the main output and the rest to the ``errors`` side output."""

    def process(self, record: dict, ctx: ProcessContext[dict]):
        if not self.is_valid(record):
            # Rejected records are reported on the "errors" side output as text.
            ctx.output_to_side("errors", f"Invalid: {record}")
            return
        ctx.emit(record)  # main output

    def is_valid(self, record: dict) -> bool:
        """A record is valid when it has the ``id``, ``timestamp`` and ``data`` keys."""
        for field in ["id", "timestamp", "data"]:
            if field not in record:
                return False
        return True
class SessionTimeoutFunction(ProcessFunction[dict, dict]):
    """Pass events through while tracking per-session activity.

    - Every event refreshes its session's last-seen time and registers a
      timer ``timeout_ms`` later.
    - When a timer fires, sessions idle for at least ``timeout_ms`` are
      dropped and reported on the ``timeouts`` side output.
    """

    def __init__(self, timeout_ms: int = 30000):
        self.timeout_ms = timeout_ms
        # session_id -> ctx.timestamp() of that session's most recent event
        self.sessions = {}

    def process(self, event: dict, ctx: ProcessContext[dict]):
        sid = event["session_id"]
        now = ctx.timestamp()
        self.sessions[sid] = now
        ctx.register_timer(now + self.timeout_ms)
        ctx.emit(event)

    def on_timer(self, timestamp: int, ctx: ProcessContext[dict]):
        # Collect first, then delete, so the dict is not mutated while iterating.
        expired = []
        for sid, seen in self.sessions.items():
            if timestamp - seen >= self.timeout_ms:
                expired.append(sid)
        for sid in expired:
            self.sessions.pop(sid)
            ctx.output_to_side("timeouts", {"session_id": sid, "timeout": timestamp})
📤 输入输出函数
SourceFunction - 数据源
from sage.core.api.function import SourceFunction, SourceContext
class SourceFunction(BaseFunction[None, T]):
    """Base class for stream sources: no input, produces elements of type ``T``."""

    def run(self, ctx: SourceContext[T]) -> None:
        """Generate data through ``ctx``; subclasses must override."""
        raise NotImplementedError()

    def cancel(self) -> None:
        """Cancel the source; the default implementation does nothing."""
        return None
# 示例实现
class CounterSourceFunction(SourceFunction[int]):
    """Emit 0, 1, ..., max_count - 1, pausing ``interval_ms`` between values."""

    def __init__(self, max_count: int = 100, interval_ms: int = 1000):
        self.max_count = max_count
        self.interval_ms = interval_ms
        self.running = True  # cleared by cancel() to stop run()

    def run(self, ctx: SourceContext[int]):
        """Emit counter values until ``max_count`` is reached or cancel() is called."""
        import time  # imported here: this example has no module-level `import time`

        pause_s = self.interval_ms / 1000.0
        count = 0
        while count < self.max_count:
            if count:
                # Sleep *between* values, not after the last one, so run()
                # returns as soon as the final value has been emitted.
                time.sleep(pause_s)
            # Re-checked after each sleep so a cancel during the pause is honoured.
            if not self.running:
                break
            ctx.emit(count)
            count += 1

    def cancel(self):
        self.running = False
class FileSourceFunction(SourceFunction[str]):
    """Emit each line of a text file with surrounding whitespace stripped."""

    def __init__(self, file_path: str, encoding: str = "utf-8"):
        self.file_path = file_path
        # Explicit encoding: the platform default differs between systems.
        self.encoding = encoding
        self.running = True  # cleared by cancel()

    def run(self, ctx: SourceContext[str]):
        with open(self.file_path, 'r', encoding=self.encoding) as f:
            for line in f:
                if not self.running:
                    break
                ctx.emit(line.strip())

    def cancel(self):
        """Stop reading before the next line; the base-class cancel() is a no-op."""
        self.running = False
class KafkaSourceFunction(SourceFunction[dict]):
    """Consume JSON messages from a Kafka topic and emit the decoded values."""

    def __init__(self, bootstrap_servers: str, topic: str, group_id: str,
                 poll_timeout_ms: int = 1000):
        self.bootstrap_servers = bootstrap_servers
        self.topic = topic
        self.group_id = group_id
        # Upper bound on how long run() blocks before re-checking cancel().
        self.poll_timeout_ms = poll_timeout_ms
        self.running = True

    def run(self, ctx: SourceContext[dict]):
        """Poll the topic until cancelled, always closing the consumer.

        Iterating the consumer directly blocks until the next message
        arrives, so cancel() could be ignored indefinitely on a quiet topic.
        poll() returns after ``poll_timeout_ms``, so the flag is re-checked
        regularly.

        Every record of a fetched batch is emitted before the flag is
        checked again, so fetched records are not silently dropped.
        """
        import json  # imported here: this example has no module-level `import json`
        from kafka import KafkaConsumer

        consumer = KafkaConsumer(
            self.topic,
            bootstrap_servers=self.bootstrap_servers,
            group_id=self.group_id,
            value_deserializer=lambda x: json.loads(x.decode('utf-8')),
        )
        try:
            while self.running:
                batches = consumer.poll(timeout_ms=self.poll_timeout_ms)
                for records in batches.values():
                    for record in records:
                        ctx.emit(record.value)
        finally:
            consumer.close()  # release network resources even on error/cancel

    def cancel(self):
        self.running = False
SinkFunction - 数据输出
from sage.core.api.function import SinkFunction
class SinkFunction(BaseFunction[T, None]):
    """Base class for stream outputs.

    open() acquires resources, sink() writes one element, and close()
    releases resources.
    """

    def open(self, context) -> None:
        """Initialise resources; does nothing by default."""
        return None

    def sink(self, value: T) -> None:
        """Write a single element; subclasses must override."""
        raise NotImplementedError()

    def close(self) -> None:
        """Clean up resources; does nothing by default."""
        return None
# 示例实现
from typing import Any  # `Any` was used below without ever being imported


class PrintSinkFunction(SinkFunction[Any]):
    """Print every element to stdout, preceded by an optional prefix."""

    def __init__(self, prefix: str = ""):
        self.prefix = prefix

    def sink(self, value: Any):
        print(f"{self.prefix}{value}")
class FileSinkFunction(SinkFunction[str]):
    """Write each element as one line of a text file (the file is truncated on open)."""

    def __init__(self, file_path: str, encoding: str = "utf-8"):
        self.file_path = file_path
        # Explicit encoding: the platform default differs between systems.
        self.encoding = encoding
        self.file = None

    def open(self, context):
        self.file = open(self.file_path, 'w', encoding=self.encoding)

    def sink(self, value: str):
        """Write ``value`` plus a newline and flush it immediately.

        Raises:
            RuntimeError: if the file is not open (before open() or after close()).
        """
        if self.file is None:
            raise RuntimeError("FileSinkFunction.sink() called while the file is not open")
        self.file.write(f"{value}\n")
        self.file.flush()

    def close(self):
        """Close the file; safe to call repeatedly or when open() never ran."""
        if self.file:
            self.file.close()
            self.file = None
class DatabaseSinkFunction(SinkFunction[dict]):
    """Insert each record as one row of a PostgreSQL table via psycopg2."""

    def __init__(self, connection_string: str, table_name: str):
        self.connection_string = connection_string
        # May be schema-qualified, e.g. "public.events".
        self.table_name = table_name
        self.connection = None

    def open(self, context):
        import psycopg2
        self.connection = psycopg2.connect(self.connection_string)

    def sink(self, record: dict):
        """Insert ``record`` (column name -> value) as one row and commit.

        Table and column names are quoted with ``psycopg2.sql.Identifier``
        instead of being pasted into the SQL text. Record keys come from the
        stream, so pasting them in allowed SQL injection. Quoted identifiers
        are case-sensitive; multi-part identifiers need psycopg2 >= 2.8.

        A failed insert rolls the transaction back. Otherwise the connection
        would stay in an aborted state and reject every later record.

        Raises:
            ValueError: if ``record`` is empty (there is nothing to insert).
        """
        from psycopg2 import sql

        if not record:
            raise ValueError("cannot insert an empty record")
        columns = list(record.keys())
        values = list(record.values())
        query = sql.SQL("INSERT INTO {table} ({cols}) VALUES ({vals})").format(
            table=sql.Identifier(*self.table_name.split(".")),
            cols=sql.SQL(",").join([sql.Identifier(c) for c in columns]),
            vals=sql.SQL(",").join(sql.Placeholder() * len(columns)),
        )
        # `with connection` commits on success and rolls back on error;
        # `with cursor` closes the cursor in both cases.
        with self.connection:
            with self.connection.cursor() as cursor:
                cursor.execute(query, values)

    def close(self):
        if self.connection:
            self.connection.close()
            self.connection = None
🔗 连接函数
JoinFunction - 流连接
from sage.core.api.function import JoinFunction
T1 = TypeVar('T1')  # left-stream element type (was used but never defined)
T2 = TypeVar('T2')  # right-stream element type (was used but never defined)


class JoinFunction(BaseFunction[T1, T2, OUT]):
    """Base class for combining a pair of elements from two streams."""

    def join(self, left: T1, right: T2) -> OUT:
        """Combine ``left`` (first stream) and ``right`` (second stream) into one output."""
        raise NotImplementedError()
# 示例实现
class UserOrderJoinFunction(JoinFunction[dict, dict, dict]):
    """Enrich an order with the name and e-mail of the user who placed it."""

    def join(self, user: dict, order: dict) -> dict:
        enriched = {}
        enriched["order_id"] = order["id"]
        enriched["user_name"] = user["name"]
        enriched["user_email"] = user["email"]
        enriched["order_amount"] = order["amount"]
        enriched["order_time"] = order["timestamp"]
        return enriched
class ClickImpressionJoinFunction(JoinFunction[dict, dict, dict]):
    """Pair a click with its impression and compute the delay between them."""

    def join(self, click: dict, impression: dict) -> dict:
        ad_id = click["ad_id"]
        user_id = click["user_id"]
        clicked_at = click["timestamp"]
        shown_at = impression["timestamp"]
        return {
            "ad_id": ad_id,
            "user_id": user_id,
            "click_time": clicked_at,
            "impression_time": shown_at,
            # Expressed in the same unit as the input timestamps.
            "conversion_delay": clicked_at - shown_at,
        }
CoMapFunction - 协同映射
from sage.core.api.function import CoMapFunction
T1 = TypeVar('T1')  # first-stream element type (was used but never defined)
T2 = TypeVar('T2')  # second-stream element type (was used but never defined)


class CoMapFunction(BaseFunction[T1, T2, OUT]):
    """Map two connected streams to one output type, with one method per input stream."""

    def map1(self, value: T1) -> OUT:
        """Transform an element of the first stream."""
        raise NotImplementedError()

    def map2(self, value: T2) -> OUT:
        """Transform an element of the second stream."""
        raise NotImplementedError()
# 示例实现
from typing import Optional  # for the "no alert" return type


class AlertCoMapFunction(CoMapFunction[dict, dict, Optional[str]]):
    """Turn failed logins and ERROR-level system events into alert strings.

    Both methods return ``None`` for inputs that need no alert. The output
    type is therefore ``Optional[str]``; it was previously annotated ``str``.
    """
    # NOTE(review): downstream presumably needs to drop the None results.

    def map1(self, user_action: dict) -> Optional[str]:
        if user_action["action"] == "login_failed":
            return f"Security Alert: Failed login attempt by user {user_action['user_id']}"
        return None

    def map2(self, system_event: dict) -> Optional[str]:
        if system_event["level"] == "ERROR":
            return f"System Alert: {system_event['message']}"
        return None
class MetricsCoMapFunction(CoMapFunction[dict, dict, dict]):
    """Normalise user and system metrics into one record shape tagged by ``type``."""

    def map1(self, user_metric: dict) -> dict:
        return self._tag_metric("user_metric", user_metric)

    def map2(self, system_metric: dict) -> dict:
        return self._tag_metric("system_metric", system_metric)

    @staticmethod
    def _tag_metric(kind: str, metric: dict) -> dict:
        """Project ``metric`` onto type/metric/value/timestamp fields."""
        return {
            "type": kind,
            "metric": metric["metric_name"],
            "value": metric["value"],
            "timestamp": metric["timestamp"],
        }
🎯 最佳实践
1. 函数状态管理
class StatefulProcessFunction(ProcessFunction[str, int]):
    """Running word count: emits the updated count of each incoming word."""
    # NOTE(review): the counts live in a plain dict on this instance and gain
    # a key for every distinct word; nothing ever removes entries.

    def __init__(self):
        self.word_count = {}  # word -> occurrences seen so far

    def process(self, word: str, ctx: ProcessContext[int]):
        updated = self.word_count.get(word, 0) + 1
        self.word_count[word] = updated
        ctx.emit(updated)
2. 错误处理
class RobustMapFunction(MapFunction[str, dict]):
    """Parse JSON, turning decoding failures into error records instead of raising."""

    def map(self, json_str: str) -> dict:
        """Return the decoded JSON object, or an error record.

        The error record has three keys:
        - ``error``: the message;
        - ``raw_input``: the original argument;
        - ``timestamp``: ``time.time()`` at the failure.

        Only decoding failures are caught:
        - ``ValueError`` covers ``json.JSONDecodeError`` and undecodable bytes;
        - ``TypeError`` covers non-string input.
        Anything else is a genuine bug and propagates.

        Valid JSON that is not an object (e.g. ``[1, 2]``) is also reported
        as an error, because the declared return type is ``dict``.
        """
        import json  # imported here: this example has no module-level imports
        import time

        try:
            parsed = json.loads(json_str)
        except (ValueError, TypeError) as e:
            error = str(e)
        else:
            if isinstance(parsed, dict):
                return parsed
            error = f"expected a JSON object, got {type(parsed).__name__}"
        return {
            "error": error,
            "raw_input": json_str,
            "timestamp": time.time(),
        }
3. 性能优化
class OptimizedAggregateFunction(AggregateFunction[int, int, int]):
    """Sum integers, keeping all state inside the accumulator.

    The former version buffered values in ``self.batch`` and folded them in
    only every ``batch_size`` elements. That lost data in two ways:
    - values still buffered never reached get_result() or merge();
    - one function instance may serve several accumulators, so buffered
      values could be added to the wrong one.

    Integer addition is already O(1), so no buffering is needed. If batching
    is ever worthwhile, keep the buffer inside the accumulator, never on
    ``self``.

    The former version also lacked create_accumulator/get_result/merge,
    which the base class leaves unimplemented.
    """

    def __init__(self):
        # Kept for backward compatibility with code that reads these
        # attributes; they no longer influence the result.
        self.batch_size = 1000
        self.batch = []

    def create_accumulator(self) -> int:
        return 0

    def add(self, accumulator: int, value: int) -> int:
        return accumulator + value

    def get_result(self, accumulator: int) -> int:
        return accumulator

    def merge(self, acc1: int, acc2: int) -> int:
        return acc1 + acc2
4. 资源管理
class DatabaseSinkFunction(SinkFunction[dict]):
    """Write records through a connection pool created in open()."""
    # NOTE(review): create_connection_pool() and self.insert_record() are not
    # defined in this example; they stand in for the real pool factory and
    # insert logic.

    def open(self, context):
        self.connection_pool = create_connection_pool()

    def sink(self, record: dict):
        # Borrow one pooled connection per record; the context manager is
        # expected to hand it back to the pool afterwards.
        with self.connection_pool.get_connection() as conn:
            self.insert_record(conn, record)

    def close(self):
        """Close the pool. Safe if open() never ran or failed, and safe to call twice."""
        pool = getattr(self, "connection_pool", None)
        if pool is not None:
            pool.close()
            self.connection_pool = None