AI 驱动的数据库内核:当学习型查询优化器遇上代价模型
AI 驱动的数据库内核当学习型查询优化器遇上代价模型一、优化器选错执行计划比没有索引更可怕生产环境一条核心报表 SQL数据量从百万级增长到千万级后查询耗时从 800ms 飙到 45 秒。EXPLAIN显示优化器选择了全表扫描而非索引范围扫描。原因统计信息ndv唯一值数量严重失真优化器基于错误代价估算做出了错误决策。手动ANALYZE TABLE后恢复但三天后问题复现。传统基于统计信息的代价模型在高数据倾斜、高频写入场景下统计信息滞后是结构性缺陷。AI 驱动的学习型优化器Learned Query Optimizer试图从历史执行反馈中学习替代或增强传统代价估算这正是本文要拆解的核心命题。二、学习型查询优化器的架构与代价模型重构2.1 传统代价模型的结构性缺陷InnoDB 优化器的代价估算依赖mysql.innodb_table_stats和mysql.innodb_index_stats核心公式Cost IO_Cost CPU_Cost IO_Cost 页面数 × 随机读取代价权重 CPU_Cost 评估行数 × 条件评估代价权重问题在于评估行数来自统计信息采样采样率默认仅 1/16 页高倾斜列的ndv误差可达 10 倍以上。2.2 学习型优化器的架构学习型优化器的核心思路用历史查询的实际执行数据真实行数、真实 IO 次数、真实耗时训练模型预测新查询的代价。flowchart LR A[SQL 输入] -- B[解析与 Plan 枚举] B -- C[特征提取] C -- D{代价预测模型} D --|传统代价模型| E[统计信息估算代价] D --|学习型代价模型| F[ML 模型预测代价] E -- G[代价融合加权] F -- G G -- H[选择最优 Plan] H -- I[执行并采集反馈] I -- J[反馈数据入训练集] J -- D2.3 特征工程SQL 如何变成模型输入学习型优化器的关键在于特征提取。一条 SQL 需要编码为固定维度向量特征类别具体特征编码方式表级特征行数、平均行长度、索引数量归一化数值列级特征ndv、null_ratio、数据倾斜度归一化数值谓词特征等值/范围/IN、选择性估计one-hot 数值Join 特征Join 类型、连接列 ndv 比率one-hot 数值Plan 特征扫描方式、Join 算法、排序方式one-hot2.4 代价融合策略生产环境不会直接用 ML 模型替换传统优化器而是加权融合Final_Cost α × Traditional_Cost (1 - α) × Learned_Costα 初始值为 1.0完全信任传统模型随着反馈数据积累逐步降低。当模型置信度低于阈值时自动回退到传统代价。三、基于反馈学习的代价校准实现3.1 执行反馈采集器import time import threading import pymysql from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple import numpy as np import logging logger logging.getLogger(__name__) dataclass class QueryFeedback: 单次查询执行反馈记录 query_fingerprint: str # SQL 指纹, 归一化后的模板 estimated_rows: int # 优化器估算行数 actual_rows: int # 实际返回行数 estimated_cost: float # 优化器估算代价 execution_time_ms: float # 实际执行耗时(ms) scan_type: str # 扫描方式: full_scan / index_scan / range_scan plan_hash: str # 执行计划哈希 timestamp: float field(default_factorytime.time) property def estimation_error(self) - float: 估算误差倍数, 越大说明统计信息越不准 if self.estimated_rows 0: return float(inf) return max(self.actual_rows / self.estimated_rows, self.estimated_rows / max(self.actual_rows, 1)) class FeedbackCollector: 执行反馈采集器, 通过 performance_schema 采集真实执行数据 def __init__(self, mysql_config: dict, sample_rate: float 0.1): self.mysql_config mysql_config self.sample_rate sample_rate # 采样率, 避免全量采集影响性能 self._feedback_buffer: List[QueryFeedback] [] self._lock threading.Lock() self._stop_event threading.Event() def _get_connection(self): return pymysql.connect(**self.mysql_config) def collect_from_events_statements(self) - List[QueryFeedback]: 从 performance_schema.events_statements_summary_by_digest 采集 sql SELECT DIGEST_TEXT, SUM_ROWS_EXAMINED, SUM_ROWS_AFFECTED, SUM_TIMER_WAIT / 1000000000 AS total_ms, COUNT_STAR, FIRST_SEEN, LAST_SEEN FROM performance_schema.events_statements_summary_by_digest WHERE DIGEST_TEXT IS NOT NULL AND SUM_ROWS_EXAMINED 0 AND COUNT_STAR 5 ORDER BY SUM_TIMER_WAIT DESC LIMIT 200 feedbacks [] try: with self._get_connection() as conn: with conn.cursor() as cur: cur.execute(sql) for row in cur.fetchall(): digest_text row[0] rows_examined row[1] rows_affected row[2] exec_time_ms row[3] / max(row[4], 1) # 平均单次耗时 count_star row[4] # 采样: 只采集部分查询 if hash(digest_text) % 100 self.sample_rate * 100: continue feedback QueryFeedback( query_fingerprintdigest_text[:200], estimated_rows0, # 需要从 EXPLAIN 补充 actual_rowsrows_examined, estimated_cost0.0, execution_time_msexec_time_ms, scan_typeunknown, plan_hash, ) feedbacks.append(feedback) except pymysql.err.OperationalError as e: logger.error(f采集反馈失败: {e}) return feedbacks def enrich_with_explain(self, feedback: QueryFeedback) - Optional[QueryFeedback]: 用 EXPLAIN 补充优化器估算信息 try: with self._get_connection() as conn: with conn.cursor() as cur: cur.execute(fEXPLAIN {feedback.query_fingerprint}) explain_rows cur.fetchall() if explain_rows: first_row explain_rows[0] feedback.estimated_rows first_row[9] or 0 # rows 列 feedback.scan_type first_row[1] or unknown # access_type return feedback except Exception as e: logger.warning(fEXPLAIN 补充失败: {e}) return None def compute_correction_factor(self) - Dict[str, float]: 计算各 SQL 指纹的代价校准因子 with self._lock: buffer list(self._feedback_buffer) # 按 SQL 指纹分组, 计算平均估算误差 grouped: Dict[str, List[float]] defaultdict(list) for fb in buffer: if fb.estimation_error ! float(inf): grouped[fb.query_fingerprint].append(fb.estimation_error) correction {} for fingerprint, errors in grouped.items(): if len(errors) 3: # 至少 3 次采样才计算校准因子 # 取中位数, 避免极端值干扰 correction[fingerprint] float(np.median(errors)) return correction def start_background_collection(self, interval_seconds: int 60): 启动后台采集线程 def _worker(): while not self._stop_event.is_set(): try: new_feedbacks self.collect_from_events_statements() with self._lock: self._feedback_buffer.extend(new_feedbacks) # 限制缓冲区大小, 防止内存溢出 if len(self._feedback_buffer) 10000: self._feedback_buffer self._feedback_buffer[-5000:] logger.info(f采集到 {len(new_feedbacks)} 条反馈) except Exception as e: logger.error(f后台采集异常: {e}) self._stop_event.wait(interval_seconds) thread threading.Thread(target_worker, daemonTrue) thread.start() logger.info(反馈采集后台线程已启动) def stop(self): self._stop_event.set()3.2 校准因子注入优化器采集到校准因子后通过 MySQL 8.0 的optimizer_switch和engine_condition_pushdown等机制间接影响优化器决策或通过 Proxy 层如 ProxySQL改写 SQL Hint-- 对已知估算偏差大的查询, 强制指定索引 SELECT /* INDEX(orders idx_create_time) */ order_id, amount FROM orders WHERE create_time BETWEEN 2025-01-01 AND 2025-01-31; -- 或通过 session 级别调整代价权重 SET SESSION optimizer_switch index_condition_pushdownon;更彻底的方案是在数据库代理层实现代价校准Proxy 拦截 SQL查询校准因子表自动注入 Hint。四、学习型优化器的现实边界与架构妥协4.1 冷启动问题模型训练需要足够的反馈数据。新上线的业务、数据量突变的场景模型无历史数据可用。解决方案用传统代价模型作为先验模型置信度不足时自动回退。但回退阈值的选择本身就是一个需要调参的问题。4.2 数据分布漂移训练数据来自历史分布当业务模式变化如大促期间数据倾斜剧变模型预测精度急剧下降。需要设置反馈数据的衰减窗口如只使用最近 7 天数据训练但这又与数据量需求矛盾。4.3 推理延迟ML 模型推理需要时间即使轻量级模型如 XGBoost单次推理也在微秒级。对于执行时间本就小于 1ms 的简单查询优化器开销占比过高。生产方案必须设置短路机制简单查询单表、无 Join、索引命中跳过模型推理。4.4 可解释性缺失优化器选错 Plan 时DBA 需要知道原因。传统代价模型可以逐步推演ML 模型是黑盒。生产环境必须保留完整的代价计算日志包括传统代价、学习代价、融合权重、最终决策依据。4.5 禁用场景数据量小于 10 万行的小表优化器本身很少选错高频简单点查QPS 10 万推理开销不可接受强合规场景审计要求优化器决策可追溯黑盒模型无法满足五、总结AI 驱动的学习型优化器并非替代传统代价模型而是在统计信息失真场景下的增强手段。其核心架构是传统代价 学习代价的融合策略通过执行反馈采集、校准因子计算、代价加权融合三个环节形成闭环。生产落地的关键挑战不在模型精度而在冷启动、数据漂移、推理延迟和可解释性四个工程问题。务实的路径是先用反馈数据量化统计信息偏差再逐步引入模型校准始终保持传统优化器作为兜底。任何没有回退机制的 AI 优化器都是生产事故的定时炸弹。