Files
clutch/src/etl/extract_snapshots.py
2026-02-12 16:32:45 +08:00

368 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
L1B 快照引擎 (Parquet 版本)
这是第一阶段 (Phase 1) 的核心 ETL 脚本。
它负责从 CS2 .dem 文件中提取 Tick 级别的快照,并将其保存为高压缩率的 Parquet 文件。
用法:
python src/etl/extract_snapshots.py --demo_dir data/demos --output_dir data/processed
配置:
调整下方的参数以控制数据粒度
"""
import os
import argparse
import pandas as pd
import numpy as np
from demoparser2 import DemoParser # 核心依赖
import logging
import sys
# ==============================================================================
# ⚙️ 配置与调优参数 (可修改参数区)
# ==============================================================================
# [重要] 采样率
# 多久截取一次快照?
# 较低值 = 数据更多,精度更高,处理更慢。
# 较高值 = 数据更少,处理更快。
SNAPSHOT_INTERVAL_SECONDS = 2 # 👈 建议值: 1-5秒 (默认: 2s)
# [重要] 回合过滤器
# 包含哪些回合?
# 'clutch_only': 仅保留发生残局 (<= 3v3) 的回合。
# 'all': 保留所有回合 (数据集会非常巨大)。
FILTER_MODE = 'clutch_only' # 👈 选项: 'all' | 'clutch_only'
# [重要] 残局定义
# 什么样的局面算作“残局”?
MAX_PLAYERS_PER_TEAM = 2 # 👈 建议值: 2 (意味着 <= 2vX 或 Xv2)
# 字段选择 (用于优化)
# 仅从 demo 中提取这些字段以节省内存
WANTED_FIELDS = [
"game_time", # 游戏时间
"team_num", # 队伍编号
"player_name", # 玩家昵称
"steamid", # Steam ID
"X", "Y", "Z", # 坐标位置
"view_X", "view_Y", # 视角角度
"health", # 生命值
"armor_value", # 护甲值
"has_defuser", # 是否有拆弹钳
"has_helmet", # 是否有头盔
"active_weapon_name", # 当前手持武器
"flash_duration", # 致盲持续时间 (是否被白)
"is_alive", # 是否存活
"balance" # [NEW] 剩余金钱 (Correct field name)
]
# ==============================================================================
# 配置结束
# ==============================================================================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def is_clutch_situation(ct_alive, t_alive):
"""
检查当前状态是否符合“残局”场景。
条件: 至少有一方队伍的存活人数 <= MAX_PLAYERS_PER_TEAM。
(例如: 2v5 对于剩2人的那队来说就是残局)
"""
if ct_alive == 0 or t_alive == 0:
return False
# 用户需求: "对面有几个人都无所谓,只要一方剩两个人"
# 含义: 如果 CT <= N 或者 T <= N即视为残局。
is_ct_clutch = (ct_alive <= MAX_PLAYERS_PER_TEAM)
is_t_clutch = (t_alive <= MAX_PLAYERS_PER_TEAM)
return is_ct_clutch or is_t_clutch
def process_demo(demo_path, output_dir, delete_source=False):
"""
解析单个 .dem 文件并将快照导出为 Parquet 格式。
"""
demo_name = os.path.basename(demo_path).replace('.dem', '')
output_path = os.path.join(output_dir, f"{demo_name}.parquet")
if os.path.exists(output_path):
logging.info(f"跳过 {demo_name}, 文件已存在。")
if delete_source:
try:
os.remove(demo_path)
logging.info(f"已删除源文件 (因为已存在处理结果): {demo_path}")
except Exception as e:
logging.warning(f"删除源文件失败: {e}")
return
logging.info(f"正在处理: {demo_name}")
try:
parser = DemoParser(demo_path)
# 1. 解析元数据 (地图, 头部信息)
header = parser.parse_header()
map_name = header.get("map_name", "unknown")
# 2. 提取事件 (回合开始/结束, 炸弹) 以识别回合边界
# [修复] 解析 round_start 事件以获取 round 信息,解决 KeyError: 'round'
# [新增] 解析 round_end 事件以获取 round_winner 信息
# [新增] 解析 bomb 事件以获取 is_bomb_planted 和 bomb_site
event_names = ["round_start", "round_end", "bomb_planted", "bomb_defused", "bomb_exploded"]
parsed_events = parser.parse_events(event_names)
round_df = None
winner_df = None
bomb_events = []
# parse_events 返回 [(event_name, df), ...]
for event_name, event_data in parsed_events:
if event_name == "round_start":
round_df = event_data
elif event_name == "round_end":
winner_df = event_data
elif event_name in ["bomb_planted", "bomb_defused", "bomb_exploded"]:
# 统一处理炸弹事件
# bomb_planted 有 site 字段
# 其他可能没有,需要填充
temp_df = event_data.copy()
temp_df['event_type'] = event_name
if 'site' not in temp_df.columns:
temp_df['site'] = 0
bomb_events.append(temp_df[['tick', 'event_type', 'site']])
# 3. 提取玩家状态 (繁重的工作)
# 我们先获取所有 Tick 的数据,然后再进行过滤
df = parser.parse_ticks(WANTED_FIELDS)
# [修复] 将 Round 信息合并到 DataFrame
if round_df is not None and not round_df.empty:
# 确保按 tick 排序
round_df = round_df.sort_values('tick')
df = df.sort_values('tick')
# 使用 merge_asof 将最近的 round_start 匹配给每个 tick
# direction='backward' 意味着找 tick <= 当前tick 的最近一次 round_start
df = pd.merge_asof(df, round_df[['tick', 'round']], on='tick', direction='backward')
# 填充 NaN (比赛开始前的 tick) 为 0
df['round'] = df['round'].fillna(0).astype(int)
else:
logging.warning(f"{demo_name} 中未找到 round_start 事件,默认为第 1 回合")
df['round'] = 1
# [新增] 将 Winner 信息合并到 DataFrame
if winner_df is not None and not winner_df.empty:
# winner_df 包含 'round' 和 'winner'
# 这里的 'round' 是结束的回合号。
# 我们直接将 winner 映射到 df 中的 round 列
# 清洗 winner 数据 (T -> 0, CT -> 1)
# 注意: demoparser2 返回的 winner 可能是 int (2/3) 也可能是 str ('T'/'CT')
# 我们先统一转为字符串处理
winner_map = df[['round']].copy().drop_duplicates()
# 建立 round -> winner 字典
# 过滤无效的 winner
valid_winners = winner_df.dropna(subset=['winner'])
round_winner_dict = {}
for _, row in valid_winners.iterrows():
r = row['round']
w = row['winner']
if w == 'T' or w == 2:
round_winner_dict[r] = 0 # T wins
elif w == 'CT' or w == 3:
round_winner_dict[r] = 1 # CT wins
# 映射到主 DataFrame
df['round_winner'] = df['round'].map(round_winner_dict)
# 移除没有结果的回合 (例如 warmup 或未结束的回合)
# df = df.dropna(subset=['round_winner']) # 暂时保留,由后续步骤决定是否丢弃
else:
logging.warning(f"{demo_name} 中未找到 round_end 事件,无法标记胜者")
df['round_winner'] = None
# [新增] 合并炸弹状态 (is_bomb_planted)
if bomb_events:
bomb_df = pd.concat(bomb_events).sort_values('tick')
# 逻辑:
# bomb_planted -> is_planted=1, site=X
# bomb_defused/exploded -> is_planted=0, site=0
# round_start/end -> 也可以作为重置点 (state=0),但我们没有把它们放入 bomb_events
# 我们假设 round_start 时炸弹肯定没下,但 merge_asof 会延续上一个状态
# 所以我们需要把 round_start 也加入作为重置事件
if round_df is not None:
reset_df = round_df[['tick']].copy()
reset_df['event_type'] = 'reset'
reset_df['site'] = 0
bomb_df = pd.concat([bomb_df, reset_df]).sort_values('tick')
# 计算状态
# 1 = Planted, 0 = Not Planted
bomb_df['is_bomb_planted'] = bomb_df['event_type'].apply(lambda x: 1 if x == 'bomb_planted' else 0)
# site 已经在 bomb_planted 事件中有值,其他为 0
# 使用 merge_asof 传播状态
# 注意bomb_df 可能有同一 tick 多个事件merge_asof 取最后一个
# 所以我们要确保排序正确 (reset 应该在 planted 之前reset 是 round_start肯定在 planted 之前)
# 只需要 tick, is_bomb_planted, site
state_df = bomb_df[['tick', 'is_bomb_planted', 'site']].copy()
df = pd.merge_asof(df, state_df, on='tick', direction='backward')
# 填充 NaN 为 0 (未下包)
df['is_bomb_planted'] = df['is_bomb_planted'].fillna(0).astype(int)
df['site'] = df['site'].fillna(0).astype(int)
else:
df['is_bomb_planted'] = 0
df['site'] = 0
# 4. 数据清洗与优化
# 将 team_num 转换为 int (CT=3, T=2)
df['team_num'] = df['team_num'].fillna(0).astype(int)
# 5. 应用采样间隔过滤器
# 我们不需要每一帧 (128/s),而是每 N 秒取一帧
# 近似计算: tick_rate 大约是 64 或 128。
# 我们使用 'game_time' 来过滤。
df['time_bin'] = (df['game_time'] // SNAPSHOT_INTERVAL_SECONDS).astype(int)
# [修复] 采样逻辑优化:找出每个 (round, time_bin) 的起始 tick保留该 tick 的所有玩家数据
# 旧逻辑 groupby().first() 会丢失其他玩家数据
bin_start_ticks = df.groupby(['round', 'time_bin'])['tick'].min()
selected_ticks = bin_start_ticks.values
# 提取快照 (包含被选中 tick 的所有玩家行)
snapshot_df = df[df['tick'].isin(selected_ticks)].copy()
# 6. 应用残局逻辑过滤器
if FILTER_MODE == 'clutch_only':
# 我们需要计算每一帧各队的存活人数
# snapshot_df 已经是采样后的数据 (每个 tick 包含所有玩家)
# 高效的存活人数计算:
alive_counts = snapshot_df[snapshot_df['is_alive'] == True].groupby(['round', 'time_bin', 'team_num']).size().unstack(fill_value=0)
# 确保列存在 (2=T, 3=CT)
if 2 not in alive_counts.columns: alive_counts[2] = 0
if 3 not in alive_counts.columns: alive_counts[3] = 0
# 过滤出满足残局条件的帧
# alive_counts 的索引是 (round, time_bin)
clutch_mask = [is_clutch_situation(row[3], row[2]) for index, row in alive_counts.iterrows()]
valid_indices = alive_counts[clutch_mask].index
# 过滤主 DataFrame
# 构建一个复合键用于快速过滤
snapshot_df['frame_id'] = list(zip(snapshot_df['round'], snapshot_df['time_bin']))
valid_frame_ids = set(valid_indices)
snapshot_df = snapshot_df[snapshot_df['frame_id'].isin(valid_frame_ids)].copy()
snapshot_df.drop(columns=['frame_id'], inplace=True)
if snapshot_df.empty:
logging.warning(f"{demo_name} 中未找到有效快照 (过滤器: {FILTER_MODE})")
return
# 7. 添加元数据
snapshot_df['match_id'] = demo_name
snapshot_df['map_name'] = map_name
# [优化] 数据类型降维与压缩
# 这一步能显著减少内存占用和文件体积
# Float64 -> Float32
float_cols = ['X', 'Y', 'Z', 'view_X', 'view_Y', 'game_time', 'flash_duration']
for col in float_cols:
if col in snapshot_df.columns:
snapshot_df[col] = snapshot_df[col].astype('float32')
# Int64 -> Int8/Int16
# team_num: 2 or 3 -> int8
snapshot_df['team_num'] = snapshot_df['team_num'].astype('int8')
# health, armor: 0-100 -> int16 (uint8 也可以但 pandas 对 uint 支持有时候有坑)
for col in ['health', 'armor_value', 'balance', 'site']:
if col in snapshot_df.columns:
snapshot_df[col] = snapshot_df[col].fillna(0).astype('int16')
# round, tick: int32 (enough for millions)
snapshot_df['round'] = snapshot_df['round'].astype('int16')
snapshot_df['tick'] = snapshot_df['tick'].astype('int32')
# Booleans -> int8 or bool
bool_cols = ['is_alive', 'has_defuser', 'has_helmet', 'is_bomb_planted']
for col in bool_cols:
if col in snapshot_df.columns:
snapshot_df[col] = snapshot_df[col].astype('int8') # 0/1 is better for ML sometimes
# Drop redundant columns
if 'time_bin' in snapshot_df.columns:
snapshot_df.drop(columns=['time_bin'], inplace=True)
# 8. 保存为 Parquet (L1B 层)
# 使用 zstd 压缩算法,通常比 snappy 压缩率高 30-50%
snapshot_df.to_parquet(output_path, index=False, compression='zstd')
logging.info(f"已保存 {len(snapshot_df)} 条快照到 {output_path} (压缩模式: ZSTD)")
# [NEW] 删除源文件逻辑
if delete_source:
try:
os.remove(demo_path)
logging.info(f"处理成功,已删除源文件: {demo_path}")
except Exception as e:
logging.warning(f"删除源文件失败: {e}")
except Exception as e:
logging.error(f"处理失败 {demo_name}: {str(e)}")
# 如果是 Source 1 错误,给予明确提示
if "Source1" in str(e):
logging.error("❌ 这是一个 CS:GO (Source 1) 的 Demo本系统仅支持 CS2 (Source 2) Demo。")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(description="L1B 快照引擎")
parser.add_argument('--demo_dir', type=str, default='data/demos', help='输入 .dem 文件的目录')
parser.add_argument('--file', type=str, help='处理单个 .dem 文件 (如果指定此参数,将忽略 --demo_dir)')
parser.add_argument('--output_dir', type=str, default='data/processed', help='输出 .parquet 文件的目录')
parser.add_argument('--delete-source', action='store_true', help='处理成功后删除源文件')
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
# 模式 1: 单文件处理
if args.file:
if not os.path.exists(args.file):
logging.error(f"文件不存在: {args.file}")
return
if not args.file.endswith('.dem'):
logging.error(f"无效的文件扩展名: {args.file}")
return
process_demo(args.file, args.output_dir, delete_source=args.delete_source)
return
# 模式 2: 目录批处理
if not os.path.exists(args.demo_dir):
logging.warning(f"目录不存在: {args.demo_dir}")
return
demo_files = [os.path.join(args.demo_dir, f) for f in os.listdir(args.demo_dir) if f.endswith('.dem')]
if not demo_files:
logging.warning(f"{args.demo_dir} 中未找到 .dem 文件。请添加 demo 文件。")
return
for demo_path in demo_files:
process_demo(demo_path, args.output_dir, delete_source=args.delete_source)
if __name__ == "__main__":
main()