# Search-R1 Code Study Guide

## Project Overview

Search-R1 is a reinforcement-learning framework for training **LLMs that interleave reasoning with search**. The model learns to call a search engine (a tool call) during its reasoning process in order to fetch external knowledge.

### Key Features

- Built on the veRL framework
- Supports multiple RL algorithms (PPO, GRPO, REINFORCE)
- Supports multiple LLMs (Llama3, Qwen2.5, etc.)
- Supports multiple search engines (local sparse/dense retrievers, online search engines)
- The model learns autonomously when to search, how to search, and how to use the search results

---

## Overall Architecture (with code locations)

```
User question
    ↓
Initialize LLM
    📁 verl/trainer/main_ppo.py:127-131
    tokenizer = hf_tokenizer(local_path)
    ↓
LLM generation loop (run_llm_loop)
    📁 search_r1/llm_agent/generation.py:220-319

    1. LLM generates a response (reason / search / answer)
       📁 generation.py:242-249
       gen_output = self._generate_with_gpu_padding(...)
       responses_ids, responses_str = self._postprocess_...
       ↓
    2. Parse the action
       📁 generation.py:407-436
       postprocess_predictions(predictions)
       - extract <search>query</search> via regex
       - or extract <answer>answer</answer>
       pattern = r'<(search|answer)>(.*?)</\1>'
       ↓
    3. Execute the action
       📁 generation.py:353-405
       execute_predictions(responses_str, ...)
       ├─ call the retrieval API: batch_search(search_queries)
       │    📁 generation.py:438-469
       │    requests.post(search_url, json=payload)
       └─ fetch documents and format them
            next_obs.append('<information>...</information>')
       ↓
    4. Update the context
       📁 generation.py:267-276
       rollings = self._update_rolling_state(...)
       ├─ concatenate: [old context, new generation, search results]
       │    📁 generation.py:93-118
       └─ create attention_mask and position_ids
       ↓
    continue the loop or stop (controlled by active_mask)
       📁 generation.py:257-262
       if not active_mask.sum(): break
    ↓
Compute rewards (answer matching)
    📁 verl/trainer/main_ppo.py:32-97 (RewardManager)
    ├─ decode generated sequences: tokenizer.decode(sequences)
    ├─ extract the answer from <answer>...</answer>
    ├─ compute the EM score: compute_score_em(solution_str, ground_truth)
    └─ reward_tensor[i, valid_response_length - 1] = score
    ↓
PPO model update
    📁 verl/trainer/ppo/ray_trainer.py (RayPPOTrainer)
    📁 verl/trainer/ppo/core_algos.py (PPO algorithm implementation)
    ├─ compute the advantage
    ├─ actor update: policy gradient
    └─ critic update: value function
```
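To make the action-tag protocol in step 2 concrete, here is a small, self-contained demo of the same regex the flow above references; the sample strings are made up for illustration.

```python
import re

# Same pattern referenced in generation.py: the backreference \1 forces
# the closing tag to match the opening tag.
pattern = r'<(search|answer)>(.*?)</\1>'

samples = [
    "<think>I need more facts.</think> <search>capital of France</search>",
    "<think>I know this now.</think> <answer>Paris</answer>",
    "no valid tag here",
]

for text in samples:
    match = re.search(pattern, text, re.DOTALL)
    if match:
        action, content = match.group(1), match.group(2).strip()
    else:
        action, content = None, ""
    print(action, repr(content))
# -> search 'capital of France' / answer 'Paris' / None ''
```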
---

## Core Code Layout

```
Search-R1/
├── train_ppo.sh                     # PPO training launch script
├── train_grpo.sh                    # GRPO training launch script
├── infer.py                         # Inference script (test a trained model)
├── retrieval_launch.sh              # Launch the local retrieval server
│
├── search_r1/
│   ├── llm_agent/
│   │   ├── generation.py            # Core: LLM generation manager
│   │   └── tensor_helper.py         # Tensor utility helpers
│   └── search/
│       ├── retrieval_server.py      # Local retrieval server (FastAPI)
│       ├── retrieval.py             # Retriever implementations (BM25/Dense)
│       ├── google_search_server.py  # Google Search API
│       └── index_builder.py         # Index-building tool
│
├── verl/                            # veRL reinforcement-learning framework
│   ├── trainer/
│   │   ├── main_ppo.py              # PPO training entry point
│   │   └── ppo/
│   │       ├── ray_trainer.py       # Ray distributed trainer
│   │       └── core_algos.py        # Core PPO algorithms
│   └── workers/                     # Distributed workers
│
└── scripts/
    └── data_process/
        └── nq_search.py             # Data preprocessing script
```

---

## Complete Call Chain

### Full training call stack

```
1. Launch script
   📁 train_ppo.sh
   └─> python3 -m verl.trainer.main_ppo

2. Training entry point
   📁 verl/trainer/main_ppo.py:105
   └─> main(config)
       └─> ray.get(main_task.remote(config))

3. Main task
   📁 verl/trainer/main_ppo.py:114
   └─> main_task(config)
       ├─> tokenizer = hf_tokenizer(local_path)        # load the tokenizer
       ├─> reward_fn = RewardManager(tokenizer, ...)   # build the reward function
       ├─> trainer = RayPPOTrainer(config, ...)        # build the PPO trainer
       ├─> trainer.init_workers()                      # init distributed workers
       └─> trainer.fit()                               # start training

4. PPO training loop
   📁 verl/trainer/ppo/ray_trainer.py
   └─> RayPPOTrainer.fit()
       │
       ├─ each epoch:
       ├─> self._gen_trajectory(data_batch)            # generate trajectories
       │     📁 ray_trainer.py:~line 400
       │     │
       │     ├─> calls LLMGenerationManager
       │     │     📁 search_r1/llm_agent/generation.py:220
       │     │     └─> run_llm_loop(gen_batch, initial_input_ids)
       │     │         │
       │     │         ├─ for step in range(max_turns):    # multi-turn loop
       │     │         │   │
       │     │         │   ├─> [step 1] generate responses
       │     │         │   │   gen_output = self._generate_with_gpu_padding(rollings_active)
       │     │         │   │   📁 generation.py:246
       │     │         │   │   └─> actor_rollout_wg.generate_sequences()
       │     │         │   │       (generation runs on vLLM)
       │     │         │   │
       │     │         │   ├─> [step 2] post-process responses
       │     │         │   │   responses_ids, responses_str = self._postprocess_responses()
       │     │         │   │   📁 generation.py:54-75
       │     │         │   │   (truncate at </search> or </answer>)
       │     │         │   │
       │     │         │   ├─> [step 3] execute predictions
       │     │         │   │   next_obs, dones, valid_action, is_search =
       │     │         │   │       self.execute_predictions(responses_str, ...)
       │     │         │   │   📁 generation.py:353-405
       │     │         │   │   │
       │     │         │   │   ├─> postprocess_predictions()      # parse actions
       │     │         │   │   │   📁 generation.py:407-436
       │     │         │   │   │   └─ regex: r'<(search|answer)>(.*?)</\1>'
       │     │         │   │   │
       │     │         │   │   ├─> batch_search(search_queries)   # batched search
       │     │         │   │   │   📁 generation.py:438-448
       │     │         │   │   │   └─> requests.post(search_url, json=payload)
       │     │         │   │   │       │
       │     │         │   │   │       └─> retrieval server
       │     │         │   │   │           📁 search_r1/search/retrieval_server.py:326
       │     │         │   │   │           └─> @app.post("/retrieve")
       │     │         │   │   │               └─> retriever.batch_search()
       │     │         │   │   │                   📁 retrieval_server.py:143-271
       │     │         │   │   │                   ├─ BM25Retriever._batch_search()
       │     │         │   │   │                   └─ DenseRetriever._batch_search()
       │     │         │   │   │                       └─> encoder.encode(query)      # encode
       │     │         │   │   │                           └─> index.search(emb, k)   # FAISS search
       │     │         │   │   │
       │     │         │   │   └─> format the observation: '<information>docs</information>'
       │     │         │   │
       │     │         │   ├─> [step 4] update the rolling state
       │     │         │   │   rollings = self._update_rolling_state(
       │     │         │   │       rollings, responses_ids, next_obs_ids)
       │     │         │   │   📁 generation.py:93-118
       │     │         │   │   └─ concatenate: [old context, new generation, search results]
       │     │         │   │
       │     │         │   └─> update active_mask (mark finished samples)
       │     │         │       📁 generation.py:257-262
       │     │         │
       │     │         └─> return final_output (complete trajectories)
       │     │
       │     ├─> actor_rollout_wg.compute_log_prob()        # compute log probs
       │     └─> ref_policy_wg.compute_ref_log_prob()       # reference-policy log probs
       │
       ├─> self._compute_rewards(rollout_data)              # compute rewards
       │     📁 ray_trainer.py:~line 500
       │     └─> reward_fn(rollout_data)
       │         📁 verl/trainer/main_ppo.py:41-97
       │         └─> RewardManager.__call__()
       │             │
       │             ├─ decode: tokenizer.decode(sequences)
       │             ├─ pick the scoring function: compute_score_fn = _select_rm_score_fn()
       │             │   📁 main_ppo.py:25-29
       │             │   └─ qa_em.compute_score_em()
       │             │       📁 verl/utils/reward_score/qa_em.py
       │             │       └─ extract the answer and compare against ground_truth
       │             │
       │             └─ reward_tensor[i, last_token] = score
       │
       ├─> apply_kl_penalty(rollout_data, kl_ctrl)          # apply the KL penalty
       │     📁 ray_trainer.py:91-124
       │     └─ KL divergence: kl = log_prob - ref_log_prob
       │        adjusted reward: rewards += kl_penalty
       │
       ├─> critic_wg.compute_values(rollout_data)           # critic computes values
       │
       ├─> compute the advantage (GAE)
       │     📁 verl/trainer/ppo/core_algos.py
       │     └─> compute_advantage()
       │         └─ GAE: A_t = δ_t + (γλ)δ_{t+1} + ...
       │
       ├─> self._update_actor(rollout_data)                 # update the actor
       │     📁 ray_trainer.py:~line 600
       │     └─> actor_wg.update_policy(ppo_data)
       │         └─ PPO loss = -min(ratio*A, clip(ratio)*A)
       │
       ├─> self._update_critic(rollout_data)                # update the critic
       │     📁 ray_trainer.py:~line 700
       │     └─> critic_wg.update_critic(ppo_data)
       │         └─ MSE loss = (V - target_V)^2
       │
       └─ save checkpoint
          📁 ray_trainer.py:~line 800
```
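The GAE recursion and the clipped PPO objective at the bottom of this call stack are standard; the sketch below is a generic PyTorch rendering of both formulas, not the actual `compute_advantage`/`update_policy` implementation in verl (names and tensor shapes are illustrative).

```python
import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """Generic GAE: A_t = delta_t + (gamma*lam)*delta_{t+1} + ...
    rewards: [T] tensor for one trajectory; values: [T+1] tensor with a
    bootstrap value appended at index T."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

def ppo_clip_loss(log_prob, old_log_prob, advantage, clip_eps=0.2):
    """Clipped PPO objective: -min(ratio*A, clip(ratio)*A)."""
    ratio = torch.exp(log_prob - old_log_prob)
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantage
    return -torch.min(unclipped, clipped).mean()
```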
### Call chain during inference

```
📁 infer.py
└─> main()
    ├─> model = AutoModelForCausalLM.from_pretrained(model_id)
    ├─> tokenizer = AutoTokenizer.from_pretrained(model_id)
    │
    └─> while True:                        # generation loop
        │
        ├─> outputs = model.generate(
        │       input_ids,
        │       stopping_criteria=StopOnSequence(["</search>"])
        │   )
        │
        ├─> check for termination
        │   if outputs[0][-1] in eos_tokens:
        │       break
        │
        ├─> extract the query
        │   query = get_query(output_text)   # regex: r"<search>(.*?)</search>"
        │
        ├─> run the search
        │   if query:
        │       search_results = search(query)
        │       └─> requests.post("http://127.0.0.1:8000/retrieve", ...)
        │
        └─> update the prompt
            prompt += f'\n\n<information>{search_results}</information>\n\n'
```

---

## Key Code Snippets (with line numbers)

### 1. Main LLM generation loop

**File:** `search_r1/llm_agent/generation.py:220-319`

```python
def run_llm_loop(self, gen_batch, initial_input_ids: torch.Tensor):
    """Run main LLM generation loop."""
    # Initialize state
    active_mask = torch.ones(batch_size, dtype=torch.bool)   # Line 226
    turns_stats = torch.ones(batch_size, dtype=torch.int)    # Line 227

    # Main generation loop
    for step in range(self.config.max_turns):                # Line 234
        if not active_mask.sum():                            # Line 235 - stop once all samples are done
            break

        # ===== Step 1: LLM generation =====
        rollings_active = DataProto.from_dict({              # Line 243
            k: v[active_mask] for k, v in rollings.batch.items()
        })
        gen_output = self._generate_with_gpu_padding(rollings_active)  # Line 246

        # ===== Step 2: post-processing =====
        responses_ids, responses_str = self._postprocess_responses(   # Line 249
            gen_output.batch['responses']
        )

        # ===== Step 3: execute actions =====
        next_obs, dones, valid_action, is_search = self.execute_predictions(  # Line 253
            responses_str, self.tokenizer.pad_token, active_mask
        )

        # ===== Step 4: update state =====
        curr_active_mask = torch.tensor([not done for done in dones])  # Line 257
        active_mask = active_mask * curr_active_mask                   # Line 258
        rollings = self._update_rolling_state(                         # Line 267
            rollings, responses_ids, next_obs_ids
        )
```
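The `active_mask` bookkeeping above is easy to misread: samples that emit `<answer>` drop out of the batch while the rest keep rolling. A tiny self-contained illustration (the `dones` values are made up):

```python
import torch

active_mask = torch.ones(4, dtype=torch.bool)        # 4 samples, all running

# One turn: samples 1 and 3 emit <answer> (done=1); the others keep searching
dones = [0, 1, 0, 1]
curr_active_mask = torch.tensor([not d for d in dones])
active_mask = active_mask * curr_active_mask
print(active_mask)             # tensor([ True, False,  True, False])
print(int(active_mask.sum()))  # 2 samples remain active in the next turn
```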
### 2. Action parsing and execution

**File:** `search_r1/llm_agent/generation.py:353-436`

```python
def execute_predictions(self, predictions: List[str], pad_token: str,
                        active_mask=None, do_search=True):
    """Execute predictions across multiple environments."""
    # Line 367: parse actions
    cur_actions, contents = self.postprocess_predictions(predictions)

    # Line 370-375: batched search
    search_queries = [content for action, content in zip(cur_actions, contents)
                      if action == 'search']
    if do_search:
        search_results = self.batch_search(search_queries)  # call the search API

    # Line 377-404: build an observation for every sample
    for i, (action, active) in enumerate(zip(cur_actions, active_mask)):
        if action == 'answer':
            next_obs.append('')
            dones.append(1)      # mark as finished
        elif action == 'search':
            next_obs.append(f'<information>{search_results.pop(0)}</information>')
            dones.append(0)      # keep generating
        else:
            next_obs.append('My previous action is invalid. ...')
            dones.append(0)

    return next_obs, dones, valid_action, is_search

def postprocess_predictions(self, predictions: List[str]):
    """Process predictions into actions."""
    actions, contents = [], []
    pattern = r'<(search|answer)>(.*?)</\1>'                 # Line 422
    for prediction in predictions:
        match = re.search(pattern, prediction, re.DOTALL)    # Line 423
        if match:
            content = match.group(2).strip()  # text inside the tag
            action = match.group(1)           # 'search' or 'answer'
        else:
            content = ''
            action = None
        actions.append(action)
        contents.append(content)
    return actions, contents

def batch_search(self, queries: List[str]):
    """Batch search queries."""
    payload = {                                              # Line 452
        "queries": queries,
        "topk": self.config.topk,
        "return_scores": True
    }
    results = requests.post(self.config.search_url, json=payload).json()  # Line 458
    return [self._passages2string(result) for result in results['result']]
```
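Since `batch_search` is just an HTTP POST against the retrieval server, you can exercise the endpoint by hand. A minimal client sketch, assuming a server running on the default `http://127.0.0.1:8000` as in the launch script:

```python
import requests

# Payload shape mirrors batch_search in generation.py
payload = {
    "queries": ["capital of France", "who wrote Hamlet"],
    "topk": 3,
    "return_scores": True,
}

resp = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()

# resp["result"] holds one list per query; with return_scores=True each
# entry is {"document": {...}, "score": ...}
for query, docs in zip(payload["queries"], resp["result"]):
    print(query)
    for item in docs:
        print("  score=%.3f  %s" % (item["score"], item["document"]["contents"][:60]))
```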
### 3. Reward computation

**File:** `verl/trainer/main_ppo.py:32-97`

```python
class RewardManager:
    def __call__(self, data: DataProto):
        """Compute rewards for generated sequences."""
        reward_tensor = torch.zeros_like(data.batch['responses'])  # Line 48

        for i in range(len(data)):                                 # Line 54
            # Line 57-66: slice out the valid prompt and response
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
            valid_response_ids = response_ids[:valid_response_length]

            # Line 69-70: decode the sequence
            sequences = torch.cat((valid_prompt_ids, valid_response_ids))
            sequences_str = self.tokenizer.decode(sequences)

            # Line 72-78: compute the score
            ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
            compute_score_fn = _select_rm_score_fn(data_source)
            score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth)

            # Line 80: the reward goes on the last token
            reward_tensor[i, valid_response_length - 1] = score

        return reward_tensor
```

### 4. Retrieval server

**File:** `search_r1/search/retrieval_server.py:326-358`

```python
app = FastAPI()

@app.post("/retrieve")                                   # Line 326
def retrieve_endpoint(request: QueryRequest):
    """Endpoint for batch retrieval."""
    if not request.topk:
        request.topk = config.retrieval_topk             # Line 338

    # Line 341-345: batched retrieval
    results, scores = retriever.batch_search(
        query_list=request.queries,
        num=request.topk,
        return_score=request.return_scores
    )

    # Line 348-357: format the response
    resp = []
    for i, single_result in enumerate(results):
        if request.return_scores:
            combined = [{"document": doc, "score": score}
                        for doc, score in zip(single_result, scores[i])]
            resp.append(combined)
    return {"result": resp}
```

**Dense retriever core:**

**File:** `search_r1/search/retrieval_server.py:207-271`

```python
class DenseRetriever(BaseRetriever):
    def __init__(self, config):
        # Line 210: load the FAISS index
        self.index = faiss.read_index(self.index_path)
        if config.faiss_gpu:
            # Line 211-215: move it to GPU
            self.index = faiss.index_cpu_to_all_gpus(self.index)

        # Line 217: load the corpus
        self.corpus = load_corpus(self.corpus_path)

        # Line 218-224: initialize the encoder (E5/BGE)
        self.encoder = Encoder(...)

    def _batch_search(self, query_list: List[str], num: int = None):
        """Batch search with dense retriever."""
        results = []
        scores = []

        # Line 249-254: encode and search in batches
        for start_idx in range(0, len(query_list), self.batch_size):
            query_batch = query_list[start_idx:start_idx + self.batch_size]
            batch_emb = self.encoder.encode(query_batch)                     # encode queries
            batch_scores, batch_idxs = self.index.search(batch_emb, k=num)  # FAISS search
            # FAISS returns numpy arrays; convert to lists before flattening
            batch_scores, batch_idxs = batch_scores.tolist(), batch_idxs.tolist()

            # Line 257-260: load the documents
            flat_idxs = sum(batch_idxs, [])
            batch_results = load_docs(self.corpus, flat_idxs)
            batch_results = [batch_results[i*num:(i+1)*num]
                             for i in range(len(batch_idxs))]

            results.extend(batch_results)
            scores.extend(batch_scores)
        return results, scores
```
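The two FAISS calls doing the real work here have simple semantics. A self-contained toy example with a flat inner-product index over random vectors (no relation to the real `e5_Flat.index`):

```python
import numpy as np
import faiss

dim = 8
rng = np.random.default_rng(0)

# Build a tiny flat inner-product index over 100 "document" embeddings
index = faiss.IndexFlatIP(dim)
doc_embs = rng.standard_normal((100, dim)).astype('float32')
index.add(doc_embs)

# Search: for each query embedding, return the top-k (score, doc_id) pairs
query_embs = rng.standard_normal((2, dim)).astype('float32')
scores, idxs = index.search(query_embs, 3)   # shapes: (2, 3) and (2, 3)
print(scores)
print(idxs)   # row i holds the ids of the top-3 docs for query i
```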
### 5. Inference script core loop

**File:** `infer.py:82-129`

```python
# Line 82-84: define the stopping condition
target_sequences = ["</search>", " </search>", ...]
stopping_criteria = StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])

# Line 94-128: generation loop
while True:
    # Line 95-96: encode the input
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Line 99-107: generate
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        stopping_criteria=stopping_criteria,   # stop at </search>
        do_sample=True,
        temperature=0.7
    )

    # Line 109-113: check whether we hit EOS
    if outputs[0][-1].item() in curr_eos:
        output_text = tokenizer.decode(outputs[0][input_ids.shape[1]:])
        print(output_text)
        break

    # Line 115-123: run the search
    output_text = tokenizer.decode(outputs[0][input_ids.shape[1]:])
    tmp_query = get_query(output_text)          # regex-extract the query
    if tmp_query:
        search_results = search(tmp_query)      # call the retrieval API

    # Line 125-126: update the prompt
    search_text = curr_search_template.format(
        output_text=output_text, search_results=search_results
    )
    prompt += search_text
```

---

## Detailed Flow Walkthrough

### 1. Training flow (train_ppo.sh)

#### 1.1 Environment setup

```bash
# Select GPUs
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# Point at the data
export DATA_DIR='data/nq_search'
# Choose the base model
export BASE_MODEL='meta-llama/Llama-3.2-3B'
```

#### 1.2 Launch training

```bash
python3 -m verl.trainer.main_ppo \
    data.train_files=$DATA_DIR/train.parquet \
    actor_rollout_ref.model.path=$BASE_MODEL \
    max_turns=2 \
    retriever.url="http://127.0.0.1:8000/retrieve"
```

**Key parameters:**

- `max_turns=2`: at most 2 search turns
- `max_start_length=2048`: maximum length of the initial prompt
- `max_response_length=500`: maximum length of each generation
- `max_obs_length=500`: maximum length of a search result
- `retriever.url`: retrieval API address
- `retriever.topk=3`: each search returns the top 3 documents

#### 1.3 Training entry point (verl/trainer/main_ppo.py)

**Core component:**

```python
# 1. Reward manager
class RewardManager:
    def __call__(self, data: DataProto):
        # For each sample:
        # - decode the generated sequence
        # - extract the answer
        # - compare against ground_truth
        # - compute the EM (Exact Match) score
        score = compute_score_fn(solution_str, ground_truth)
        reward_tensor[i, valid_response_length - 1] = score
        return reward_tensor
```

**Key points:**

- The reward is assigned only on the **last token** of the sequence
- Exact Match is the reward signal
- Multiple datasets are supported (NQ, HotpotQA, TriviaQA, etc.)

**Trainer initialization:**

```python
trainer = RayPPOTrainer(
    config=config,
    tokenizer=tokenizer,
    role_worker_mapping=role_worker_mapping,
    reward_fn=reward_fn,
    val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
```
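`compute_score_em` lives in `verl/utils/reward_score/qa_em.py`; its exact implementation is not reproduced in this guide, but rule-based EM scoring for QA usually looks like the following sketch. The normalization rules are an assumption, modeled on the common SQuAD recipe, not the repository's actual code.

```python
import re
import string

def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, squeeze whitespace (SQuAD-style)."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def em_score(solution_str: str, ground_truth: str) -> float:
    """1.0 iff the text inside <answer>...</answer> matches the ground truth."""
    matches = re.findall(r"<answer>(.*?)</answer>", solution_str, re.DOTALL)
    if not matches:
        return 0.0
    # Take the last tag: the prompt template itself also contains an
    # <answer> ... </answer> example, so the first match may be spurious.
    pred = normalize_answer(matches[-1])
    return 1.0 if pred == normalize_answer(ground_truth) else 0.0

print(em_score("<think>...</think><answer> Paris </answer>", "paris"))  # 1.0
```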
### 2. Core LLM generation flow (search_r1/llm_agent/generation.py)

#### 2.1 The LLMGenerationManager class

This is the most important class in the project; it manages the multi-turn interplay between LLM generation and search.

```python
class LLMGenerationManager:
    def __init__(self, tokenizer, actor_rollout_wg, config):
        self.tokenizer = tokenizer
        self.actor_rollout_wg = actor_rollout_wg   # vLLM generator
        self.config = config
```

#### 2.2 Main generation loop (run_llm_loop)

This is the central method; it implements the reason-search interaction:

```python
def run_llm_loop(self, gen_batch, initial_input_ids):
    # Initialize state
    active_mask = torch.ones(batch_size, dtype=torch.bool)  # which samples are still running
    turns_stats = torch.ones(batch_size, dtype=torch.int)   # per-sample turn counts

    # Main loop: at most max_turns turns
    for step in range(self.config.max_turns):
        # ===== Step 1: LLM generation =====
        gen_output = self._generate_with_gpu_padding(rollings_active)
        responses_ids, responses_str = self._postprocess_responses(gen_output.batch['responses'])

        # ===== Step 2: execute the action (search/answer) =====
        next_obs, dones, valid_action, is_search = self.execute_predictions(
            responses_str, self.tokenizer.pad_token, active_mask
        )

        # ===== Step 3: update state =====
        # Update active_mask (mark finished samples)
        curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
        active_mask = active_mask * curr_active_mask

        # Update the context (append search results)
        rollings = self._update_rolling_state(rollings, responses_ids, next_obs_ids)

        # Exit once every sample has finished
        if not active_mask.sum():
            break

    return final_output
```

**Key methods in detail:**

##### (1) _postprocess_responses - post-process generated text

```python
def _postprocess_responses(self, responses):
    # Decode, then truncate at </search> or </answer>
    responses_str = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
    responses_str = [
        resp.split('</search>')[0] + '</search>' if '</search>' in resp
        else resp.split('</answer>')[0] + '</answer>' if '</answer>' in resp
        else resp
        for resp in responses_str
    ]
    return responses, responses_str
```

##### (2) execute_predictions - execute the action and collect observations

```python
def execute_predictions(self, predictions, pad_token, active_mask, do_search=True):
    # Parse actions (search/answer)
    cur_actions, contents = self.postprocess_predictions(predictions)

    # Batch all search actions into one call
    search_queries = [content for action, content in zip(cur_actions, contents)
                      if action == 'search']
    if do_search:
        search_results = self.batch_search(search_queries)

    # Build an observation for each sample
    for i, (action, active) in enumerate(zip(cur_actions, active_mask)):
        if action == 'answer':
            next_obs.append('')     # finished, no observation
            dones.append(1)
        elif action == 'search':
            next_obs.append(f'\n\n<information>{search_results.pop(0)}</information>\n\n')
            dones.append(0)         # keep going
        else:
            # Invalid action: feed back a hint
            next_obs.append('\nMy previous action is invalid. ...')
            dones.append(0)

    return next_obs, dones, valid_action, is_search
```

##### (3) postprocess_predictions - parse the action

```python
def postprocess_predictions(self, predictions):
    actions, contents = [], []
    pattern = r'<(search|answer)>(.*?)</\1>'
    for prediction in predictions:
        match = re.search(pattern, prediction, re.DOTALL)
        if match:
            content = match.group(2).strip()  # text inside the tag
            action = match.group(1)           # 'search' or 'answer'
        else:
            content = ''
            action = None
        actions.append(action)
        contents.append(content)
    return actions, contents
```

##### (4) batch_search - batched search

```python
def batch_search(self, queries):
    payload = {
        "queries": queries,
        "topk": self.config.topk,
        "return_scores": True
    }
    results = requests.post(self.config.search_url, json=payload).json()['result']
    return [self._passages2string(result) for result in results]

def _passages2string(self, retrieval_result):
    format_reference = ''
    for idx, doc_item in enumerate(retrieval_result):
        content = doc_item['document']['contents']
        title = content.split("\n")[0]
        text = "\n".join(content.split("\n")[1:])
        format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
    return format_reference
```

##### (5) _update_rolling_state - update the context

```python
def _update_rolling_state(self, rollings, cur_responses, next_obs_ids):
    # Concatenate: old context + new generation + observation (search results)
    new_input_ids = self.tensor_fn.concatenate_with_padding([
        rollings.batch['input_ids'],
        cur_responses,
        next_obs_ids
    ])

    # Build the attention mask and position ids
    new_attention_mask = self.tensor_fn.create_attention_mask(new_input_ids)
    new_position_ids = self.tensor_fn.create_position_ids(new_attention_mask)

    # Truncate to the maximum length
    max_len = min(self.config.max_prompt_length, effective_len)

    return DataProto.from_dict({
        'input_ids': new_input_ids[:, -max_len:],
        'position_ids': new_position_ids[:, -max_len:],
        'attention_mask': new_attention_mask[:, -max_len:]
    })
```
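`create_attention_mask` and `create_position_ids` live in `tensor_helper.py` and are not shown here; for left-padded batches they are typically a one-liner each, as in this hedged sketch (the `pad_token_id` and shapes are illustrative, not the repository's actual code):

```python
import torch

pad_token_id = 0  # illustrative; the real value comes from the tokenizer

def create_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """1 for real tokens, 0 for padding."""
    return (input_ids != pad_token_id).long()

def create_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """Positions count only real tokens, so left padding stays at 0."""
    return (attention_mask.cumsum(dim=1) - 1).clamp(min=0) * attention_mask

ids = torch.tensor([[0, 0, 5, 6, 7]])    # one left-padded row
mask = create_attention_mask(ids)        # tensor([[0, 0, 1, 1, 1]])
pos = create_position_ids(mask)          # tensor([[0, 0, 0, 1, 2]])
print(mask, pos)
```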
### 3. Retrieval server (search_r1/search/retrieval_server.py)

#### 3.1 Retriever types

**BM25 retriever (sparse):**

```python
class BM25Retriever(BaseRetriever):
    def __init__(self, config):
        from pyserini.search.lucene import LuceneSearcher
        self.searcher = LuceneSearcher(self.index_path)
        self.corpus = load_corpus(self.corpus_path)

    def _search(self, query, num=None):
        hits = self.searcher.search(query, num)
        results = load_docs(self.corpus, [hit.docid for hit in hits])
        return results
```

**Dense retriever:**

```python
class DenseRetriever(BaseRetriever):
    def __init__(self, config):
        self.index = faiss.read_index(self.index_path)   # FAISS index
        if config.faiss_gpu:
            self.index = faiss.index_cpu_to_all_gpus(self.index)
        self.corpus = load_corpus(self.corpus_path)
        self.encoder = Encoder(...)                      # E5/BGE-style encoder

    def _search(self, query, num=None):
        query_emb = self.encoder.encode(query)              # encode the query
        scores, idxs = self.index.search(query_emb, k=num)  # FAISS search
        results = load_docs(self.corpus, idxs[0])
        return results
```

#### 3.2 The FastAPI server

```python
app = FastAPI()

@app.post("/retrieve")
def retrieve_endpoint(request: QueryRequest):
    # Batched retrieval
    results, scores = retriever.batch_search(
        query_list=request.queries,
        num=request.topk,
        return_score=request.return_scores
    )

    # Format the response
    resp = []
    for i, single_result in enumerate(results):
        combined = [{"document": doc, "score": score}
                    for doc, score in zip(single_result, scores[i])]
        resp.append(combined)
    return {"result": resp}

# Start the server
uvicorn.run(app, host="0.0.0.0", port=8000)
```

**Launching it (retrieval_launch.sh):**

```bash
conda activate retriever
python search_r1/search/retrieval_server.py \
    --index_path /path/to/e5_Flat.index \
    --corpus_path /path/to/wiki-18.jsonl \
    --topk 3 \
    --retriever_name e5 \
    --retriever_model intfloat/e5-base-v2 \
    --faiss_gpu
```

---

### 4. Inference flow (infer.py)

Once training finishes, use the inference script to try the model.

```python
# 1. Load the model
model_id = "PeterJinGo/SearchR1-nq_hotpotqa_train-qwen2.5-7b-em-ppo"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModelForCausalLM.from_pretrained(model_id)

# 2. Build the prompt
question = "What is the capital of France?"
prompt = f"""Answer the given question. \
You must conduct reasoning inside <think> and </think> first. \
If you lack knowledge, call search engine by <search> query </search>. \
If no further knowledge needed, provide answer in <answer> answer </answer>.
Question: {question}
"""

# 3. Define the stopping condition (stop at </search>)
class StopOnSequence(transformers.StoppingCriteria):
    def __init__(self, target_sequences, tokenizer):
        # Store each target as a tensor so it can be compared directly
        self.targets = [torch.tensor(tokenizer.encode(seq, add_special_tokens=False))
                        for seq in target_sequences]

    def __call__(self, input_ids, scores, **kwargs):
        for target in self.targets:
            if torch.equal(input_ids[0, -len(target):].cpu(), target):
                return True
        return False

stopping_criteria = StoppingCriteriaList([
    StopOnSequence(["</search>", "</search>\n"], tokenizer)
])

# 4. Generation loop
while True:
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Generate
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        stopping_criteria=stopping_criteria,
        do_sample=True,
        temperature=0.7
    )

    # Stop on EOS
    if outputs[0][-1].item() in curr_eos:
        output_text = tokenizer.decode(outputs[0][input_ids.shape[1]:])
        print(output_text)
        break

    # Otherwise, run the search
    output_text = tokenizer.decode(outputs[0][input_ids.shape[1]:])
    query = get_query(output_text)     # extract <search>...</search>
    if query:
        search_results = search(query) # call the retrieval API
    prompt += f'\n\n<information>{search_results}</information>\n\n'
```

---

## Data Formats

### Training data format

```python
data = {
    "data_source": "nq",          # dataset name
    "prompt": [{
        "role": "user",
        "content": question,
    }],
    "ability": "fact-reasoning",
    "reward_model": {
        "style": "rule",          # rule-based reward
        "ground_truth": solution  # the correct answer
    },
    "extra_info": {
        'split': 'train',
        'index': 0,
    }
}
```

### Corpus format (corpus.jsonl)

```json
{"id": "0", "contents": "\"Evan Morris\"\nEvan L. Morris was a lobbyist..."}
{"id": "1", "contents": "\"Paris\"\nParis is the capital of France..."}
```

---

## Key Concepts

### 1. Prompt format

The prompt the model receives uses special tags:

- `<think>`: the reasoning process
- `<search>query</search>`: a search action
- `<information>docs</information>`: search results
- `<answer>answer</answer>`: the final answer

**Example:**

```
Question: What is the capital of France?

<think> I need to search for information about France's capital. </think>
<search>capital of France</search>

<information>
Doc 1(Title: France) Paris is the capital and largest city of France...
</information>

<think> Based on the search results, the capital of France is Paris. </think>
<answer>Paris</answer>
```

### 2. Reward scheme

- **Exact Match (EM)**: the reward is 1 when the answer exactly matches the ground truth, otherwise 0
- The reward lands only on the **last token** of the sequence (sparse reward)
- PPO learns when and how to search from this signal

### 3. Multi-turn interaction

- `max_turns=2`: at most 2 search turns
- Each turn: generate → parse action → execute → observe → update context
- `active_mask` tracks which samples are still generating

### 4. Information masking (info_mask)

- During training, the search-result spans (`<information>...</information>`) are masked out
- Only tokens the model generated itself (reasoning, searches, answers) receive RL updates
- This avoids taking gradients through external knowledge; see the sketch below
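A hedged sketch of what such masking can look like: build a mask that zeroes out token spans lying between `<information>` and `</information>`, then exclude those positions from the policy loss. The tag ids and helper names here are illustrative, not the repository's actual implementation.

```python
import torch

def build_info_mask(token_ids, info_start_id, info_end_id):
    """Return a mask that is 0 inside <information>...</information> spans
    and 1 elsewhere. info_start_id / info_end_id are the (hypothetical)
    token ids of the opening and closing tags."""
    mask = torch.ones_like(token_ids)
    inside = False
    for t, tok in enumerate(token_ids.tolist()):
        if tok == info_start_id:
            inside = True
        if inside:
            mask[t] = 0          # retrieved text: no gradient
        if tok == info_end_id:
            inside = False
    return mask

# Usage: weight the per-token policy-gradient loss by the mask so that
# retrieved evidence never contributes to the update, e.g.
# loss = (per_token_loss * info_mask).sum() / info_mask.sum()
```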
---

## Training Tips

### 1. Distributed training

- Ray + FSDP/Megatron for distributed training
- Multi-node training is supported (30B+ models)
- GPU memory optimizations: offload parameters, gradients, and optimizer state

### 2. Hyperparameters

```bash
actor_rollout_ref.actor.optim.lr=1e-6        # actor learning rate
critic.optim.lr=1e-5                         # critic learning rate
algorithm.kl_ctrl.kl_coef=0.001              # KL-divergence coefficient
actor_rollout_ref.actor.ppo_mini_batch_size=256
actor_rollout_ref.rollout.temperature=1      # sampling temperature
```

### 3. Data parallelism

- `data.train_batch_size=512`: 512 samples per training step
- `actor_rollout_ref.rollout.n_agent=1`: one rollout agent

---

## FAQ

### Q1: How do I switch retrievers?

Change the arguments to `retrieval_server.py`:

```bash
--retriever_name bge   # bm25/e5/bge
--retriever_model BAAI/bge-base-en-v1.5
```

### Q2: How do I use an online search engine?

Use `google_search_server.py` or `serp_search_server.py`:

```bash
python search_r1/search/google_search_server.py --api_key YOUR_KEY
```

### Q3: How do I change the number of search turns?

Edit `train_ppo.sh`:

```bash
max_turns=3   # allow up to 3 search turns
```

### Q4: How do I add a custom dataset?

1. Add a processing script under `scripts/data_process/`
2. Register the data source and its scoring function in `verl/trainer/main_ppo.py`
3. Prepare a corpus in jsonl format

---

## Further Reading

### Papers

1. [Search-R1 Paper 1](https://arxiv.org/abs/2503.09516)
2. [Search-R1 Paper 2](https://arxiv.org/abs/2505.15117)
3. DeepSeek-R1
4. TinyZero

### Related code

- [veRL](https://github.com/volcengine/verl) - the RL framework
- [RAGEN](https://github.com/ZihanWang314/RAGEN) - RAG with RL
- [vLLM](https://github.com/vllm-project/vllm) - the inference engine

---

## Experiment Logs

- [v0.1 experiment logs](https://wandb.ai/peterjin/Search-R1-nq_hotpotqa_train)
- [v0.2 experiment logs](https://wandb.ai/peterjin/Search-R1-v0.2)
- [v0.3 experiment logs](https://wandb.ai/peterjin/Search-R1-v0.3)

---

## Quick Start

### Full pipeline (5 steps)

```bash
# 1. Set up the environment
conda create -n searchr1 python=3.9
conda activate searchr1
pip install torch==2.4.0
pip install vllm==0.6.3
pip install -e .

# 2. Download the data
save_path=/path/to/save
python scripts/download.py --save_path $save_path
cat $save_path/part_* > $save_path/e5_Flat.index
gzip -d $save_path/wiki-18.jsonl.gz

# 3. Process the data
python scripts/data_process/nq_search.py

# 4. Launch the retrieval server
conda activate retriever
bash retrieval_launch.sh

# 5. Start training
conda activate searchr1
bash train_ppo.sh
```

---

## Summary

Through reinforcement learning, Search-R1 teaches an LLM to:

1. **Decide autonomously** when external knowledge is needed
2. **Use tools**: how to formulate queries for a search engine
3. **Integrate information**: how to reason over retrieved results
4. **Work end to end**: the full path from raw question to final answer

The core idea is to interleave **reasoning** and **search**, letting the model discover its own policy from a sparse reward signal (answer correctness) instead of supervised labels for when to search.