feat: Optimize LSTM with Attention, add Stacking Ensemble and SHAP analysis
.gitignore (vendored): 4 lines changed
@@ -17,6 +17,10 @@ database/**/*.db
# Local demo snapshots (large)
data/processed/
data/demos/
data/sequences/

# Reports and Artifacts
reports/

# Local downloads / raw captures
output_arena/
PROJECT_REVIEW.md (new file, 64 lines)
@@ -0,0 +1,64 @@
# Clutch-IQ Project Review

## 1. Overview

**Clutch-IQ** is a real-time win-probability prediction system for CS2 (Counter-Strike 2). The project implements the full pipeline from raw demo parsing through feature extraction and model training to real-time inference.

* **Completeness**: A- (core loop closed; MVP stage complete)
* **Tech stack**: Python, Pandas, XGBoost, Flask, Streamlit, Demoparser2
* **Use cases**: tactical analysis, live streaming overlays, real-time assistance

---

## 2. Detailed Review

### ✅ Strengths

#### 1. Solid engineering architecture

Rather than piling all code into one file, the project uses a clean layered architecture:

* **ETL layer (`src/etl/`)**: Streaming processing (`auto_pipeline.py`) neatly solves the storage pain of bulk demo files, an impressive engineering practice for a personal project.
* **Feature layer (`src/features/`)**: Feature definitions are extracted into `definitions.py`, guaranteeing that training and inference share one standard and effectively preventing **training-serving skew**.
* **Model layer (`src/training/`)**: Training and validation are decoupled (`train.py` vs `evaluate.py`), and match-level splitting is strictly enforced to avoid temporal data leakage, showing solid ML fundamentals.

#### 2. Substantial feature engineering

Beyond basic K/D stats, the project brings in computational geometry:

* **Spatial features**: `t_area` (convex-hull area) and `pincer_index` are advanced features that capture the "tactics" dimension beyond raw aim.
* **Economy features**: Equipment value and cash flow are broken out separately, allowing accurate judgment of eco rounds versus full-buy rounds.

#### 3. Complete documentation

The project includes `AI_FULL_STACK_GUIDE.md` and `PROJECT_DEEP_DIVE.md`: not just code, but the reasoning and theory behind it. This is vital for maintenance and collaboration.

### ⚠️ Areas for Improvement

#### 1. Test Coverage

* **Status**: `evaluate.py` validates model quality, but there are no per-function **unit tests**.
* **Risk**: A change to the convex-hull algorithm in `spatial.py` could silently break the feature computation, surfacing only as a later drop in model accuracy.
* **Suggestion**: Adopt `pytest` and write test cases for the feature-computation functions.
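For instance, a `pytest` module could pin down known geometries for the hull-area feature. This is a sketch, not the project's actual API: `compute_t_area` here is an inline stand-in for the real `spatial.py` helper, built on `scipy.spatial.ConvexHull` (whose 2D `.volume` attribute is the enclosed area).

```python
import numpy as np
from scipy.spatial import ConvexHull

# Stand-in for the real spatial.py helper; in the project you would
# import the actual function from src/features/spatial.py instead.
def compute_t_area(positions):
    if len(positions) < 3:
        return 0.0  # fewer than 3 players cannot span an area
    return ConvexHull(positions).volume  # in 2D, .volume is the area

def test_unit_square_area():
    pts = np.array([[0, 0], [1, 0], [1, 1], [0, 1]])
    assert abs(compute_t_area(pts) - 1.0) < 1e-9

def test_degenerate_inputs_return_zero():
    # A lone player or a pair of players spans no area
    assert compute_t_area(np.array([[3.0, 4.0]])) == 0.0
    assert compute_t_area(np.array([[0, 0], [5, 5]])) == 0.0

def test_interior_point_does_not_change_hull():
    square = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
    with_center = np.vstack([square, [[1, 1]]])
    assert abs(compute_t_area(square) - compute_t_area(with_center)) < 1e-9
```

Running `pytest` on such a file catches a broken hull computation immediately, rather than weeks later as an accuracy regression.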

#### 2. Configuration Management

* **Status**: Some paths and parameters (file paths, model hyperparameters) appear to be hard-coded.
* **Suggestion**: Move all tunable parameters into a `config.yaml` or `.env` file so the project deploys easily across machines.
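A minimal sketch of the `.env`-style approach using only the standard library; the `CLUTCH_*` variable names are illustrative, not ones the project defines:

```python
import os

# Read tunable parameters from the environment, with safe defaults.
# With python-dotenv (or plain `export` in the shell) these can live
# in a .env file instead of being hard-coded per machine.
MODEL_PATH = os.getenv("CLUTCH_MODEL_PATH", "models/clutch_model_v1.json")
DATA_DIR = os.getenv("CLUTCH_DATA_DIR", "data/processed")
SEQ_LEN = int(os.getenv("CLUTCH_SEQ_LEN", "10"))

print(MODEL_PATH, DATA_DIR, SEQ_LEN)
```

Every script then imports these constants from one module instead of redefining its own paths.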

#### 3. Robustness (Error Handling and Logging)

* **Status**: Basic logging exists, but under high concurrency (e.g. frequent GSI pushes) the robustness of `inference/app.py` needs hardening.
* **Suggestion**: Add a request-queue mechanism so a traffic burst cannot overwhelm the inference service.
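One way to sketch that backpressure with the standard library only; the queue size and payload shape are placeholders, and a real service would pair this with a worker thread draining the queue:

```python
import queue

# Bounded queue: at most 100 pending GSI payloads. Extra traffic is
# rejected immediately instead of piling up and exhausting memory.
REQUEST_QUEUE = queue.Queue(maxsize=100)

def enqueue_request(payload) -> bool:
    """Returns False (i.e. respond HTTP 503 upstream) when the queue is full."""
    try:
        REQUEST_QUEUE.put_nowait(payload)
        return True
    except queue.Full:
        return False

# A worker thread would loop on REQUEST_QUEUE.get() and run inference.
accepted = sum(enqueue_request({"tick": i}) for i in range(150))
print(accepted)  # 100: the queue absorbs 100 payloads and rejects the other 50
```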

---

## 3. Scoring

| Dimension | Score (1-10) | Comments |
| :--- | :---: | :--- |
| **Architecture** | **9.0** | Clear modules; streaming processing is a plus. |
| **Code quality** | **8.5** | Consistent style, readable, well-factored functions. |
| **Algorithmic depth** | **8.0** | XGBoost is the right choice and the features are inventive; room to grow (e.g. sequence models). |
| **Completeness** | **8.5** | Core loop runs end to end; documentation is thorough. |
| **Novelty** | **7.5** | The spatial features stand out, but overall this is a classic ML paradigm. |

### 🏆 Verdict: Excellent

**Clutch-IQ** is a personal project with **production-grade potential**. It goes beyond typical "demo code" and shows complete full-stack AI engineering thinking. The storage optimizations and leakage prevention in particular show a deep grasp of real engineering problems.

**Next steps**:

1. **Containerization**: Write a `Dockerfile` for one-command environment setup.
2. **Visualization**: Improve the dashboard with feature-importance explanation charts (SHAP plots).
3. **Live integration**: Finish the GSI configuration and run the system in an actual game.
models/clutch_attention_lstm_v1.pth (new binary file, not shown)
models/clutch_lstm_v1.pth (new binary file, not shown)
src/analysis/ensemble_analysis.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import os
import sys

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, log_loss

# Add project root to path
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from src.inference.stacking_ensemble import StackingEnsemble

# Configuration
XGB_MODEL_PATH = "models/clutch_model_v1.json"
LSTM_MODEL_PATH = "models/clutch_attention_lstm_v1.pth"
TEST_DATA_PATH = "data/processed/test_set.parquet"


def analyze_ensemble():
    if not os.path.exists(TEST_DATA_PATH):
        print("Test data not found.")
        return

    print(f"Loading data from {TEST_DATA_PATH}...")
    df = pd.read_parquet(TEST_DATA_PATH)
    y = df['round_winner'].values

    # Initialize ensemble (to reuse get_base_predictions)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ensemble = StackingEnsemble(XGB_MODEL_PATH, LSTM_MODEL_PATH, device)

    print("Generating predictions...")
    # Get base model predictions
    meta_features = ensemble.get_base_predictions(df)
    prob_xgb = meta_features['prob_xgb'].values
    prob_lstm = meta_features['prob_lstm'].values

    # 1. Correlation analysis
    correlation = np.corrcoef(prob_xgb, prob_lstm)[0, 1]
    print("\n[Correlation Analysis]")
    print(f"Correlation between XGBoost and LSTM predictions: {correlation:.4f}")

    # 2. Performance comparison (log loss & accuracy)
    acc_xgb = accuracy_score(y, (prob_xgb > 0.5).astype(int))
    ll_xgb = log_loss(y, prob_xgb)

    acc_lstm = accuracy_score(y, (prob_lstm > 0.5).astype(int))
    ll_lstm = log_loss(y, prob_lstm)

    print("\n[Performance Comparison]")
    print(f"XGBoost - Acc: {acc_xgb:.2%}, LogLoss: {ll_xgb:.4f}")
    print(f"LSTM    - Acc: {acc_lstm:.2%}, LogLoss: {ll_lstm:.4f}")

    # 3. Disagreement analysis: where do the models disagree?
    pred_xgb = (prob_xgb > 0.5).astype(int)
    pred_lstm = (prob_lstm > 0.5).astype(int)

    disagreement_mask = pred_xgb != pred_lstm
    disagreement_count = np.sum(disagreement_mask)
    print("\n[Disagreement Analysis]")
    print(f"Models disagree on {disagreement_count} / {len(df)} samples ({disagreement_count / len(df):.2%})")

    if disagreement_count > 0:
        # Who is right when they disagree?
        disagreements = df[disagreement_mask].copy()
        y_disagree = y[disagreement_mask]
        pred_xgb_disagree = pred_xgb[disagreement_mask]
        pred_lstm_disagree = pred_lstm[disagreement_mask]

        xgb_correct = np.sum(pred_xgb_disagree == y_disagree)
        lstm_correct = np.sum(pred_lstm_disagree == y_disagree)

        print("In disagreement cases:")
        print(f"  XGBoost correct: {xgb_correct} times")
        print(f"  LSTM correct:    {lstm_correct} times")

        # Show a few examples
        print("\nExample Disagreements:")
        disagreements['prob_xgb'] = prob_xgb[disagreement_mask]
        disagreements['prob_lstm'] = prob_lstm[disagreement_mask]
        disagreements['actual'] = y_disagree
        print(disagreements[['round', 'tick', 'prob_xgb', 'prob_lstm', 'actual']].head(5))


if __name__ == "__main__":
    analyze_ensemble()
src/analysis/explain_model_shap.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import xgboost as xgb

# Add project root to path
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from src.features.definitions import FEATURE_COLUMNS

# Configuration
MODEL_PATH = "models/clutch_model_v1.json"
TEST_DATA_PATH = "data/processed/test_set.parquet"
REPORT_DIR = "reports"


def explain_model():
    # 1. Ensure the report directory exists
    os.makedirs(REPORT_DIR, exist_ok=True)

    # 2. Load data
    if not os.path.exists(TEST_DATA_PATH):
        print(f"Error: Test data not found at {TEST_DATA_PATH}")
        return

    print(f"Loading test data from {TEST_DATA_PATH}...")
    df = pd.read_parquet(TEST_DATA_PATH)
    X = df[FEATURE_COLUMNS]

    # 3. Load model
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model not found at {MODEL_PATH}")
        return

    print(f"Loading XGBoost model from {MODEL_PATH}...")
    model = xgb.XGBClassifier()
    model.load_model(MODEL_PATH)

    # 4. Calculate SHAP values
    print("Calculating SHAP values (this may take a moment)...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # 5. Generate the summary plot
    print("Generating SHAP summary plot...")
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X, show=False)

    summary_plot_path = os.path.join(REPORT_DIR, "shap_summary_v1.png")
    plt.savefig(summary_plot_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Saved summary plot to: {summary_plot_path}")

    # 6. Generate the bar plot (global feature importance)
    print("Generating SHAP bar plot...")
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X, plot_type="bar", show=False)

    bar_plot_path = os.path.join(REPORT_DIR, "shap_importance_v1.png")
    plt.savefig(bar_plot_path, bbox_inches='tight', dpi=300)
    plt.close()
    print(f"Saved importance plot to: {bar_plot_path}")

    # 7. Interview insights
    print("\n" + "=" * 50)
    print("          DATA ANALYST INTERVIEW INSIGHTS          ")
    print("=" * 50)

    # Mean absolute SHAP value per feature = global importance
    mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
    feature_importance = pd.DataFrame({
        'feature': FEATURE_COLUMNS,
        'importance': mean_abs_shap
    }).sort_values('importance', ascending=False)

    top_3 = feature_importance.head(3)

    print("When an interviewer asks: 'What drives your model's predictions?'")
    print("You can answer based on this data:")
    print(f"1. The most critical factor is '{top_3.iloc[0]['feature']}'.")
    print(f"   (Impact score: {top_3.iloc[0]['importance']:.4f})")
    print(f"2. Followed by '{top_3.iloc[1]['feature']}' and '{top_3.iloc[2]['feature']}'.")
    print("\nBusiness interpretation:")
    print("- If 'economy' features rank top: the model confirms that money buys win rate.")
    print("- If 'spatial' features rank top: the model understands map control is key.")
    print("- If 'status' (health/alive) features rank top: the model relies on basic manpower advantage.")
    print("-" * 50)
    print(f"Check the visualizations in the '{REPORT_DIR}' folder to practice your storytelling.")


if __name__ == "__main__":
    explain_model()
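The importance ranking in step 7 is just the mean absolute SHAP value per column, which can be checked on a synthetic matrix (the feature names and values below are made up for illustration):

```python
import numpy as np
import pandas as pd

# Synthetic SHAP matrix: 4 samples x 3 features
shap_values = np.array([
    [ 0.2, -0.5, 0.1],
    [-0.1,  0.4, 0.0],
    [ 0.3, -0.6, 0.2],
    [-0.2,  0.5, 0.1],
])
features = ['alive_diff', 'equip_value', 'game_time']

importance = (
    pd.DataFrame({'feature': features,
                  'importance': np.abs(shap_values).mean(axis=0)})
    .sort_values('importance', ascending=False)
)
print(importance.iloc[0]['feature'])  # 'equip_value' (mean |SHAP| = 0.5)
```

Signs cancel in a plain mean, which is why the absolute value is taken first: a feature that pushes predictions strongly in both directions is still important.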
src/inference/ensemble_framework.py (new file, 182 lines)
@@ -0,0 +1,182 @@
"""
Ensemble Framework: XGBoost + LSTM Fusion
=========================================
This script demonstrates the framework for combining the static state analysis of XGBoost
with the temporal trend analysis of LSTM to produce a robust final prediction.

Methodology:
1. Load both trained models (XGBoost .json, LSTM .pth).
2. Prepare input data:
   - XGBoost: takes single-frame features (24 dims).
   - LSTM: takes a sequence of the last 10 frames (10x24 dims).
3. Generate independent probabilities: P_xgb, P_lstm.
4. Fuse predictions using weighted averaging:
   P_final = alpha * P_xgb + (1 - alpha) * P_lstm
5. Evaluate performance on the test set.

Usage:
    python src/inference/ensemble_framework.py
"""

import os
import sys

import numpy as np
import pandas as pd
import torch
import xgboost as xgb
from sklearn.metrics import accuracy_score, classification_report, log_loss

# Ensure imports work
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from src.training.models import ClutchLSTM
from src.features.definitions import FEATURE_COLUMNS

# Configuration
XGB_MODEL_PATH = "models/clutch_v1.model.json"
LSTM_MODEL_PATH = "models/clutch_lstm_v1.pth"
TEST_DATA_PATH = "data/processed/test_set.parquet"  # The test set saved by train.py

# Fusion hyperparameter.
# 0.6 means we trust XGBoost slightly more (it currently has the higher accuracy, ~84% vs ~77% for the LSTM).
ALPHA = 0.6


class ClutchEnsemble:
    def __init__(self, xgb_path, lstm_path, device='cpu'):
        self.device = device

        # Load XGBoost
        print(f"Loading XGBoost from {xgb_path}...")
        self.xgb_model = xgb.XGBClassifier()
        self.xgb_model.load_model(xgb_path)

        # Load LSTM (input_dim must match the 24 training features)
        print(f"Loading LSTM from {lstm_path}...")
        self.lstm_model = ClutchLSTM(input_dim=len(FEATURE_COLUMNS)).to(device)
        self.lstm_model.load_state_dict(torch.load(lstm_path, map_location=device))
        self.lstm_model.eval()

    def predict(self, df):
        """
        End-to-end prediction on a dataframe.
        Note: this handles the complexity of alignment.
        The LSTM needs 10 frames, so the first 9 frames of each round cannot have LSTM predictions.
        Strategy: fall back to XGBoost for those frames.
        """
        # Reset the index so positional indexing into probs_lstm stays aligned
        df = df.reset_index(drop=True)

        # 1. XGBoost predictions (fast, parallel)
        print("Generating XGBoost predictions...")
        X_xgb = df[FEATURE_COLUMNS]
        # predict_proba returns [prob_class_0, prob_class_1]; train.py maps T=0, CT=1,
        # so with classes_ == [0, 1] column 1 is the CT win probability.
        probs_xgb = self.xgb_model.predict_proba(X_xgb)[:, 1]

        # 2. LSTM predictions (sequential)
        print("Generating LSTM predictions...")
        # create_sequences in sequence_prep.py returns plain arrays and strips the
        # index, so we build the sequences here ourselves to keep row alignment with df.
        probs_lstm = np.full(len(df), np.nan)

        seq_len = 10
        inputs_list = []
        indices_list = []

        # Group by match and round to avoid cross-round leakage in sequence creation
        grouped = df.groupby(['match_id', 'round'])

        for (match_id, round_num), group in grouped:
            group = group.sort_values('tick')
            data = group[FEATURE_COLUMNS].values

            if len(data) < seq_len:
                continue

            for i in range(len(data) - seq_len + 1):
                # The sequence ends at index i + seq_len - 1;
                # that index is the row we are predicting for.
                row_idx = group.index[i + seq_len - 1]
                seq = data[i: i + seq_len]

                inputs_list.append(seq)
                indices_list.append(row_idx)

        if len(inputs_list) > 0:
            inputs_tensor = torch.FloatTensor(np.array(inputs_list)).to(self.device)

            with torch.no_grad():
                # Full-batch prediction; for <10k rows this fits in memory,
                # otherwise switch to mini-batches.
                outputs = self.lstm_model(inputs_tensor)
                # outputs is [batch, 1] (sigmoid)
                lstm_preds = outputs.cpu().numpy().flatten()

            # Fill the aligned slots
            probs_lstm[indices_list] = lstm_preds

        # 3. Fusion
        print("Fusing predictions...")
        final_probs = []

        for p_x, p_l in zip(probs_xgb, probs_lstm):
            if np.isnan(p_l):
                # Fall back to XGBoost when there is insufficient history (start of round)
                final_probs.append(p_x)
            else:
                # Weighted average
                final_probs.append(ALPHA * p_x + (1 - ALPHA) * p_l)

        return np.array(final_probs)


def main():
    if not os.path.exists(TEST_DATA_PATH):
        print(f"Test set not found at {TEST_DATA_PATH}. Please run training first.")
        return

    print(f"Loading test set from {TEST_DATA_PATH}...")
    df_test = pd.read_parquet(TEST_DATA_PATH)

    # Ground truth
    y_true = df_test['round_winner'].map({'T': 0, 'CT': 1}).values

    # Initialize ensemble
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ensemble = ClutchEnsemble(XGB_MODEL_PATH, LSTM_MODEL_PATH, device)

    # Predict
    y_prob = ensemble.predict(df_test)
    y_pred = (y_prob > 0.5).astype(int)

    # Evaluate
    acc = accuracy_score(y_true, y_pred)
    ll = log_loss(y_true, y_prob)

    print("\n" + "=" * 50)
    print("          ENSEMBLE MODEL RESULTS          ")
    print("=" * 50)
    print(f"🔥 Final Accuracy: {acc:.2%}")
    print(f"📉 Log Loss: {ll:.4f}")
    print("-" * 50)
    print("Detailed Report:")
    print(classification_report(y_true, y_pred, target_names=['T', 'CT']))
    print("=" * 50)

    # Compare with standalone XGBoost for reference
    print("\n[Reference] Standalone XGBoost Performance:")
    y_prob_xgb = ensemble.xgb_model.predict_proba(df_test[FEATURE_COLUMNS])[:, 1]
    y_pred_xgb = (y_prob_xgb > 0.5).astype(int)
    print(f"XGB Accuracy: {accuracy_score(y_true, y_pred_xgb):.2%}")


if __name__ == "__main__":
    main()
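The fusion rule `P_final = alpha * P_xgb + (1 - alpha) * P_lstm` with its NaN fallback can be exercised in isolation and vectorized with `np.where`; a small sketch with made-up probabilities:

```python
import numpy as np

ALPHA = 0.6  # same weight as in ensemble_framework.py

def fuse(probs_xgb: np.ndarray, probs_lstm: np.ndarray, alpha: float = ALPHA) -> np.ndarray:
    """Weighted average of the two model outputs; falls back to the XGBoost
    probability wherever the LSTM has no prediction (NaN, i.e. fewer than
    10 frames of history at the start of a round)."""
    fused = alpha * probs_xgb + (1 - alpha) * probs_lstm
    return np.where(np.isnan(probs_lstm), probs_xgb, fused)

p_xgb = np.array([0.8, 0.5, 0.3])
p_lstm = np.array([0.6, np.nan, 0.7])
print(fuse(p_xgb, p_lstm))  # [0.72 0.5  0.46]
```

Note that `fused` contains NaN wherever `probs_lstm` does, but `np.where` replaces exactly those slots, so no NaN survives.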
src/inference/stacking_ensemble.py (new file, 217 lines)
@@ -0,0 +1,217 @@
"""
Stacking Ensemble Framework (Advanced Fusion)
=============================================
Beyond simple weighted averaging, this script implements 'Stacking' (stacked generalization).
It trains a meta-learner (logistic regression) to intelligently combine the predictions
of the base models (XGBoost + LSTM) based on the current context (e.g., game time).

Architecture:
1. Base layer: XGBoost, LSTM
2. Meta layer: logistic regression
   Input:  [Prob_XGB, Prob_LSTM, Game_Time, Team_Alive_Diff]
   Output: final probability

Why this is better:
- It learns WHEN to trust which model (e.g., trust the LSTM more in the late game).
- It can correct systematic biases of the base models.

Usage:
    python src/inference/stacking_ensemble.py
"""

import os
import sys

import numpy as np
import pandas as pd
import torch
import xgboost as xgb
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, log_loss
from sklearn.model_selection import train_test_split

# Ensure imports work
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from src.training.models import ClutchAttentionLSTM
from src.features.definitions import FEATURE_COLUMNS

# Configuration
XGB_MODEL_PATH = "models/clutch_model_v1.json"
LSTM_MODEL_PATH = "models/clutch_attention_lstm_v1.pth"
TEST_DATA_PATH = "data/processed/test_set.parquet"


class StackingEnsemble:
    def __init__(self, xgb_path, lstm_path, device='cpu'):
        self.device = device
        self.meta_learner = LogisticRegression()
        self.is_fitted = False

        # Load base models
        print("Loading base model: XGBoost...")
        self.xgb_model = xgb.XGBClassifier()
        self.xgb_model.load_model(xgb_path)

        print("Loading base model: LSTM...")
        self.lstm_model = ClutchAttentionLSTM(input_dim=len(FEATURE_COLUMNS), hidden_dim=64, num_layers=2, dropout=0.5).to(device)
        self.lstm_model.load_state_dict(torch.load(lstm_path, map_location=device))
        self.lstm_model.eval()

    def get_base_predictions(self, df):
        """Generates features for the meta-learner."""
        # Reset the index to ensure positional indexing works for probs_lstm
        df = df.reset_index(drop=True)

        # 1. XGBoost probabilities
        X_xgb = df[FEATURE_COLUMNS]
        probs_xgb = self.xgb_model.predict_proba(X_xgb)[:, 1]

        # 2. LSTM probabilities
        probs_lstm = np.full(len(df), np.nan)
        seq_len = 10
        inputs_list = []
        indices_list = []

        grouped = df.groupby(['match_id', 'round'])

        for (match_id, round_num), group in grouped:
            group = group.sort_values('tick')
            data = group[FEATURE_COLUMNS].values

            if len(data) < seq_len:
                continue

            for i in range(len(data) - seq_len + 1):
                row_idx = group.index[i + seq_len - 1]
                seq = data[i: i + seq_len]
                inputs_list.append(seq)
                indices_list.append(row_idx)

        if len(inputs_list) > 0:
            inputs_tensor = torch.FloatTensor(np.array(inputs_list)).to(self.device)
            with torch.no_grad():
                outputs = self.lstm_model(inputs_tensor)
                lstm_preds = outputs.cpu().numpy().flatten()
            probs_lstm[indices_list] = lstm_preds

        # Handle NaNs (start of rounds) by filling with the XGBoost prediction.
        # This imputation step is crucial for the meta-learner.
        mask_nan = np.isnan(probs_lstm)
        probs_lstm[mask_nan] = probs_xgb[mask_nan]

        # 3. Construct meta-features.
        # 'game_time' lets the meta-learner learn temporal weighting;
        # 'alive_diff' tells it how lopsided the situation is.
        meta_features = pd.DataFrame({
            'prob_xgb': probs_xgb,
            'prob_lstm': probs_lstm,
            'game_time': df['game_time'].values,
            'alive_diff': df['alive_diff'].values
        })

        return meta_features

    def fit_meta_learner(self, df, y):
        """Train the meta-learner on a validation set."""
        print("Generating meta-features for training...")
        X_meta = self.get_base_predictions(df)

        print("Training meta-learner (logistic regression)...")
        self.meta_learner.fit(X_meta, y)
        self.is_fitted = True

        # Analyze the learned weights
        coefs = self.meta_learner.coef_[0]
        print("\n[Meta-Learner Insights]")
        print("How much does it trust each signal?")
        print(f"  Weight on XGBoost:   {coefs[0]:.4f}")
        print(f"  Weight on LSTM:      {coefs[1]:.4f}")
        print(f"  Weight on GameTime:  {coefs[2]:.4f}")
        print(f"  Weight on AliveDiff: {coefs[3]:.4f}")
        print("(Positive weight = positive correlation with CT winning)\n")

    def predict(self, df):
        if not self.is_fitted:
            raise ValueError("Meta-learner not fitted! Call fit_meta_learner first.")

        X_meta = self.get_base_predictions(df)
        return self.meta_learner.predict_proba(X_meta)[:, 1]


def main():
    if not os.path.exists(TEST_DATA_PATH):
        print("Test data not found.")
        return

    print(f"Loading data from {TEST_DATA_PATH}...")
    df = pd.read_parquet(TEST_DATA_PATH)

    # Target is already 0/1, so no need to map 'T'/'CT'
    y = df['round_winner'].values

    # Split data for meta-learning: a 'meta-train' set to fit the combiner and a
    # 'meta-test' set to evaluate it. Since the test set only holds a couple of
    # matches, split by match_id when possible.
    unique_matches = df['match_id'].unique()

    if len(unique_matches) >= 2:
        # Split 50/50 by match
        mid = len(unique_matches) // 2
        meta_train_matches = unique_matches[:mid]
        meta_test_matches = unique_matches[mid:]

        train_mask = df['match_id'].isin(meta_train_matches)
        test_mask = df['match_id'].isin(meta_test_matches)

        df_meta_train = df[train_mask]
        y_meta_train = y[train_mask]

        df_meta_test = df[test_mask]
        y_meta_test = y[test_mask]

        print(f"Split: Meta-Train ({len(df_meta_train)} rows) | Meta-Test ({len(df_meta_test)} rows)")
    else:
        print("Not enough matches for match-level split. Using random split (caution: leakage).")
        df_meta_train, df_meta_test, y_meta_train, y_meta_test = train_test_split(df, y, test_size=0.5, random_state=42)

    # Initialize ensemble
    device = "cuda" if torch.cuda.is_available() else "cpu"
    stacking_model = StackingEnsemble(XGB_MODEL_PATH, LSTM_MODEL_PATH, device)

    # 1. Train the meta-learner
    stacking_model.fit_meta_learner(df_meta_train, y_meta_train)

    # 2. Evaluate on the held-out meta-test set
    print("Evaluating stacking ensemble...")
    y_prob = stacking_model.predict(df_meta_test)
    y_pred = (y_prob > 0.5).astype(int)

    acc = accuracy_score(y_meta_test, y_pred)
    ll = log_loss(y_meta_test, y_prob)

    print("\n" + "=" * 50)
    print("          STACKING ENSEMBLE RESULTS          ")
    print("=" * 50)
    print(f"Final Accuracy: {acc:.2%}")
    print(f"Log Loss: {ll:.4f}")
    print("-" * 50)
    print(classification_report(y_meta_test, y_pred, target_names=['T', 'CT']))

    # Baseline comparison
    print("=" * 50)
    print("[Baselines on Meta-Test Set]")

    # XGBoost baseline
    X_test_xgb = df_meta_test[FEATURE_COLUMNS]
    y_pred_xgb = stacking_model.xgb_model.predict(X_test_xgb)
    acc_xgb = accuracy_score(y_meta_test, y_pred_xgb)
    print(f"XGBoost only: {acc_xgb:.2%}")

    # LSTM baseline
    meta_features_test = stacking_model.get_base_predictions(df_meta_test)
    probs_lstm = meta_features_test['prob_lstm'].values
    y_pred_lstm = (probs_lstm > 0.5).astype(int)
    acc_lstm = accuracy_score(y_meta_test, y_pred_lstm)
    print(f"LSTM only: {acc_lstm:.2%}")
    print("=" * 50)


if __name__ == "__main__":
    main()
src/training/models.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.hidden_dim = hidden_dim
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x shape: (batch, seq, hidden)

        # Attention scores, shape: (batch, seq, 1)
        scores = self.attn(x)

        # Softmax over the sequence dimension, shape: (batch, seq, 1)
        weights = F.softmax(scores, dim=1)

        # Weighted sum: element-wise multiply (broadcast), then sum over seq.
        # context shape: (batch, hidden)
        context = torch.sum(x * weights, dim=1)

        return context, weights


class ClutchLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim=64, num_layers=2, output_dim=1, dropout=0.2):
        super(ClutchLSTM, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # LSTM layer; batch_first=True means input shape is (batch, seq, feature)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout)

        # Fully connected layer
        self.fc = nn.Linear(hidden_dim, output_dim)

        # Sigmoid activation for binary classification
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x shape: (batch, seq, feature)

        # Initialize the hidden state with zeros.
        # Using x.device keeps the tensors on the same device (CPU/GPU).
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)

        # Forward propagate the LSTM; out shape: (batch, seq, hidden_dim)
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step:
        # out[:, -1, :] takes the last output of the sequence
        out = self.fc(out[:, -1, :])
        out = self.sigmoid(out)
        return out


class ClutchAttentionLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, num_layers=2, output_dim=1, dropout=0.3):
        super(ClutchAttentionLSTM, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # 1. Input layer norm (stabilizes training)
        self.layer_norm = nn.LayerNorm(input_dim)

        # 2. LSTM (increased hidden_dim for capacity)
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout, bidirectional=False)

        # 3. Attention mechanism
        self.attention = Attention(hidden_dim)

        # 4. Fully connected head
        self.fc1 = nn.Linear(hidden_dim, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(64, output_dim)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch, seq, feature)
        x = self.layer_norm(x)

        # LSTM; lstm_out: (batch, seq, hidden)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        lstm_out, _ = self.lstm(x, (h0, c0))

        # Attention; context: (batch, hidden)
        context, attn_weights = self.attention(lstm_out)

        # MLP head
        out = self.fc1(context)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = self.sigmoid(out)

        return out
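The `Attention` module above is softmax pooling over time: score each timestep, softmax over the sequence dimension, then take a weighted sum. The same computation in numpy (random hidden states stand in for LSTM outputs, and the weight vector `w` stands in for the bias-free part of `nn.Linear(hidden_dim, 1)`):

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
batch, seq, hidden = 2, 10, 4
lstm_out = rng.normal(size=(batch, seq, hidden))  # (batch, seq, hidden)
w = rng.normal(size=(hidden, 1))                  # scoring vector

scores = lstm_out @ w                       # (batch, seq, 1)
weights = softmax(scores, axis=1)           # softmax over the sequence dim
context = (lstm_out * weights).sum(axis=1)  # (batch, hidden), like torch.sum(x * weights, dim=1)

print(context.shape)                # (2, 4)
print(weights.sum(axis=1).ravel())  # each sequence's weights sum to 1
```

Unlike `out[:, -1, :]` in `ClutchLSTM`, which keeps only the final timestep, this pooling lets any timestep contribute according to its learned score.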
src/training/sequence_prep.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""
Sequence Data Preparation for LSTM/GRU Models

This script transforms the L2 frame-level data into L3 sequence data suitable for RNNs.
Output shape: (Batch_Size, Sequence_Length, Num_Features)
"""

import logging
import os
import sys

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# Ensure we can import from src
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from src.training.train import load_data, preprocess_features
from src.features.definitions import FEATURE_COLUMNS

# Configuration
SEQ_LEN = 10  # 10 frames * 2 s/frame = 20 seconds of context
DATA_DIR = "data/processed"
OUTPUT_DIR = "data/sequences"

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class ClutchSequenceDataset(Dataset):
    def __init__(self, sequences, targets):
        self.sequences = torch.FloatTensor(sequences)
        self.targets = torch.FloatTensor(targets)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.targets[idx]


def create_sequences(df, seq_len=10):
    """
    Creates sliding-window sequences from the dataframe.
    Assumes df is already sorted by match, round, and time.
    """
    sequences = []
    targets = []

    # Group by match and round so windows never cross boundaries;
    # 'tick' is used to sort, assuming it increases with time.
    grouped = df.groupby(['match_id', 'round'])

    match_ids = []  # match_id for each sequence, for match-level splitting later

    for (match_id, round_num), group in grouped:
        group = group.sort_values('tick')

        data = group[FEATURE_COLUMNS].values
        labels = group['round_winner'].values  # 0 for T, 1 for CT

        # Need enough frames for at least one window
        if len(data) < seq_len:
            continue

        # Create sliding windows
        for i in range(len(data) - seq_len + 1):
            seq = data[i: i + seq_len]
            label = labels[i + seq_len - 1]  # label of the last frame in the window

            sequences.append(seq)
            targets.append(label)
            match_ids.append(match_id)

    return np.array(sequences), np.array(targets), np.array(match_ids)

def prepare_sequence_data(save=True):
|
||||
if not os.path.exists(OUTPUT_DIR):
|
||||
os.makedirs(OUTPUT_DIR)
|
||||
|
||||
logging.info("Loading raw data...")
|
||||
raw_df = load_data(DATA_DIR)
|
||||
|
||||
logging.info("Preprocessing to frame-level features...")
|
||||
df = preprocess_features(raw_df)
|
||||
|
||||
logging.info(f"Frame-level df shape: {df.shape}")
|
||||
logging.info(f"Columns: {df.columns.tolist()}")
|
||||
logging.info(f"Round Winner unique values: {df['round_winner'].unique()}")
|
||||
|
||||
# Ensure target is numeric
|
||||
# Check if mapping is needed
|
||||
if df['round_winner'].dtype == 'object':
|
||||
logging.info("Mapping targets from T/CT to 0/1...")
|
||||
df['round_winner'] = df['round_winner'].map({'T': 0, 'CT': 1})
|
||||
|
||||
logging.info(f"Round Winner unique values after mapping: {df['round_winner'].unique()}")
|
||||
df = df.dropna(subset=['round_winner'])
|
||||
logging.info(f"df shape after dropna: {df.shape}")
|
||||
|
||||
logging.info(f"Creating sequences (Length={SEQ_LEN})...")
|
||||
X, y, matches = create_sequences(df, SEQ_LEN)
|
||||
|
||||
logging.info(f"Generated {len(X)} sequences.")
|
||||
logging.info(f"Shape: {X.shape}")
|
||||
|
||||
if save:
|
||||
logging.info(f"Saving to {OUTPUT_DIR}...")
|
||||
np.save(os.path.join(OUTPUT_DIR, "X_seq.npy"), X)
|
||||
np.save(os.path.join(OUTPUT_DIR, "y_seq.npy"), y)
|
||||
np.save(os.path.join(OUTPUT_DIR, "matches_seq.npy"), matches)
|
||||
|
||||
return X, y, matches
|
||||
|
||||
if __name__ == "__main__":
|
||||
prepare_sequence_data()
|
||||
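The sliding-window logic in `create_sequences` is easy to sanity-check on toy data. A standalone sketch of the same windowing (it does not import the project; shapes and values are synthetic):

```python
import numpy as np

def sliding_windows(data, labels, seq_len):
    """Replicates the per-round windowing used in create_sequences."""
    seqs, targets = [], []
    for i in range(len(data) - seq_len + 1):
        seqs.append(data[i : i + seq_len])
        targets.append(labels[i + seq_len - 1])  # label of the last frame
    return np.array(seqs), np.array(targets)

# 7 frames, 3 features, seq_len=3 -> 5 overlapping windows
data = np.arange(21).reshape(7, 3)
labels = np.array([0, 0, 0, 1, 1, 1, 1])
X, y = sliding_windows(data, labels, seq_len=3)
print(X.shape)  # (5, 3, 3)
print(y)        # [0 1 1 1 1]
```

Note each target is the label of the window's *last* frame, so consecutive windows overlap by `seq_len - 1` frames — which is exactly why the train/test split must be done at the match level, not per sequence.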
@@ -93,6 +93,9 @@ def preprocess_features(df):
    df['is_t'] = (df['team_num'] == 2).astype(int)
    df['is_ct'] = (df['team_num'] == 3).astype(int)

    # Fill NA in 'is_alive' before conversion
    df['is_alive'] = df['is_alive'].fillna(0)

    # Calculate player-specific metrics
    df['t_alive'] = df['is_t'] * df['is_alive'].astype(int)
    df['ct_alive'] = df['is_ct'] * df['is_alive'].astype(int)
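The hunk above derives per-player alive flags by multiplying indicator columns. A tiny illustration of the same trick on synthetic data (column names match the diff; the values are made up):

```python
import pandas as pd

df = pd.DataFrame({
    'team_num': [2, 2, 3, 3, 3],    # 2 = T, 3 = CT
    'is_alive': [1, 0, 1, 1, None], # None simulates a missing frame
})

df['is_t'] = (df['team_num'] == 2).astype(int)
df['is_ct'] = (df['team_num'] == 3).astype(int)
df['is_alive'] = df['is_alive'].fillna(0)  # fill NA before int conversion
df['t_alive'] = df['is_t'] * df['is_alive'].astype(int)
df['ct_alive'] = df['is_ct'] * df['is_alive'].astype(int)

print(df['t_alive'].sum(), df['ct_alive'].sum())  # 1 2
```

The `fillna(0)` before `astype(int)` is the point of the hunk: casting a column containing NaN to int raises, so a missing alive flag must be treated as dead first.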
207
src/training/train_lstm.py
Normal file
@@ -0,0 +1,207 @@
"""
Train Attention-LSTM Model for Clutch-IQ
"""

import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss, classification_report

# Ensure we can import from src
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from src.training.models import ClutchAttentionLSTM

# Config
SEQ_DIR = os.path.join("data", "sequences")
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "clutch_attention_lstm_v1.pth")
BATCH_SIZE = 32
EPOCHS = 50
LR = 0.001
PATIENCE = 10


class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0


def train_lstm():
    os.makedirs(MODEL_DIR, exist_ok=True)

    # 1. Load Data
    x_path = os.path.join(SEQ_DIR, "X_seq.npy")
    y_path = os.path.join(SEQ_DIR, "y_seq.npy")
    m_path = os.path.join(SEQ_DIR, "matches_seq.npy")

    if not os.path.exists(x_path) or not os.path.exists(y_path):
        print(f"Data not found at {SEQ_DIR}. Please run src/training/sequence_prep.py first.")
        return

    print("Loading sequence data...")
    X = np.load(x_path)
    y = np.load(y_path)

    # Load match IDs if available, else warn and fall back to a random split
    if os.path.exists(m_path):
        matches = np.load(m_path)
        print(f"Loaded match IDs. Shape: {matches.shape}")
    else:
        print("Warning: matches_seq.npy not found. Using random split (risk of leakage).")
        matches = None

    print(f"Data Shape: X={X.shape}, y={y.shape}")

    # 2. Split
    if matches is not None:
        # Group split: assign whole matches to train or test to prevent leakage
        unique_matches = np.unique(matches)
        print(f"Total unique matches: {len(unique_matches)}")

        # Shuffle matches
        np.random.seed(42)
        np.random.shuffle(unique_matches)

        n_train = int(len(unique_matches) * 0.8)
        train_match_ids = unique_matches[:n_train]
        test_match_ids = unique_matches[n_train:]

        print(f"Train matches: {len(train_match_ids)}, Test matches: {len(test_match_ids)}")

        train_mask = np.isin(matches, train_match_ids)
        test_mask = np.isin(matches, test_match_ids)

        X_train, X_test = X[train_mask], X[test_mask]
        y_train, y_test = y[train_mask], y[test_mask]
    else:
        # Stratify to preserve class balance in both splits
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

    print(f"Train set: {X_train.shape}, Test set: {X_test.shape}")

    # 3. Convert to PyTorch Tensors
    train_data = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).float())
    test_data = TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test).float())

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    # 4. Model Setup
    input_dim = X.shape[2]  # number of features
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on {device}...")

    # Initialize Attention LSTM with higher dropout and lower complexity
    model = ClutchAttentionLSTM(input_dim=input_dim, hidden_dim=64, num_layers=2, dropout=0.5).to(device)

    criterion = nn.BCELoss()
    # Weight decay adds L2 regularization
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    early_stopping = EarlyStopping(patience=PATIENCE, min_delta=0.0001)

    # 5. Train Loop
    best_loss = float('inf')

    print("-" * 50)
    print(f"{'Epoch':<6} | {'Train Loss':<12} | {'Val Loss':<12} | {'Val Acc':<10} | {'LR':<10}")
    print("-" * 50)

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            labels = labels.unsqueeze(1)  # (batch, 1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)

        train_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                labels = labels.unsqueeze(1)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)

                predicted = (outputs > 0.5).float()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_loss /= len(test_loader.dataset)
        val_acc = correct / total

        current_lr = optimizer.param_groups[0]['lr']
        print(f"{epoch+1:<6} | {train_loss:<12.4f} | {val_loss:<12.4f} | {val_acc:<10.2%} | {current_lr:.1e}")

        # Scheduler & checkpointing
        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), MODEL_PATH)

        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

    print("Training Complete.")

    # 6. Final Evaluation
    print(f"Loading best model from {MODEL_PATH}...")
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # squeeze(1) (not squeeze()) so a final batch of size 1 keeps its batch dim
            preds = (outputs.squeeze(1) > 0.5).float()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    print("-" * 50)
    print("Detailed Classification Report (Best Model):")
    print(classification_report(all_labels, all_preds, target_names=['T (Terrorist)', 'CT (Counter-Terrorist)']))
    print("=" * 50)


if __name__ == "__main__":
    train_lstm()
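The match-level split in `train_lstm` can be checked independently. A sketch on toy data (using `np.random.default_rng` rather than the script's global seed) confirming that `np.isin`-based group masks never place sequences from one match in both splits:

```python
import numpy as np

def group_split(matches, train_frac=0.8, seed=42):
    """Assigns whole matches to train/test, mirroring the split in train_lstm."""
    unique = np.unique(matches)
    rng = np.random.default_rng(seed)
    rng.shuffle(unique)
    n_train = int(len(unique) * train_frac)
    train_ids, test_ids = unique[:n_train], unique[n_train:]
    return np.isin(matches, train_ids), np.isin(matches, test_ids)

matches = np.repeat(np.arange(10), 50)  # 10 matches, 50 sequences each
train_mask, test_mask = group_split(matches)

# Every sequence lands in exactly one split, and no match straddles both
assert (train_mask ^ test_mask).all()
assert not set(matches[train_mask]) & set(matches[test_mask])
print(train_mask.sum(), test_mask.sum())  # 400 100
```

Because overlapping sliding windows from one round are near-duplicates, this group split (rather than a random one) is what keeps the reported validation accuracy honest.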