360 lines
11 KiB
Python
360 lines
11 KiB
Python
"""Category analysis: builds per-category aggregates and produces correlations / a quadratic regression on orders."""
|
||
|
||
import sqlite3
|
||
from pathlib import Path
|
||
import sys
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
import seaborn as sns
|
||
import statsmodels.api as sm
|
||
from sklearn.metrics import roc_auc_score
|
||
|
||
# Make the helper functions of the sibling script importable by putting this
# file's directory on sys.path (only once).
script_dir = Path(__file__).resolve().parent
if sys.path.count(str(script_dir)) == 0:
    sys.path.append(str(script_dir))
|
||
|
||
from best_model_and_plots import ( # noqa: E402
|
||
CATEGORIES,
|
||
DEFAULT_ALPHA,
|
||
DEFAULT_ALPHA_MAX,
|
||
DEFAULT_ALPHA_MIN,
|
||
DEFAULT_BINS_X,
|
||
DEFAULT_BINS_Y,
|
||
DEFAULT_SCATTER_COLOR,
|
||
DEFAULT_TREND_COLOR,
|
||
DEFAULT_TREND_FRAC,
|
||
DEFAULT_TREND_LINEWIDTH,
|
||
DEFAULT_X_MAX,
|
||
DEFAULT_Y_MAX,
|
||
DEFAULT_Y_MIN,
|
||
DEFAULT_SAVGOL_WINDOW,
|
||
plot_clean_trend_scatter,
|
||
safe_divide,
|
||
)
|
||
|
||
# Global plotting defaults. The seaborn theme is applied first so that the
# figure size set afterwards is not overwritten by it.
sns.set_theme(style="whitegrid")
plt.rcParams.update({"figure.figsize": (8, 8)})
|
||
|
||
# Project layout: this file lives one directory below the project root.
project_root = Path(__file__).resolve().parent.parent
# SQLite database that holds the `communications` table.
DB_PATH = project_root.joinpath("dataset", "ds.sqlite")
# Root directory for every figure written by this script.
OUT_DIR = project_root.joinpath("main_hypot", "category_analysis")

# Metric prefixes; the per-category columns are named "<prefix>_<category>".
BASE_COLUMNS = [
    "active_imp",
    "passive_imp",
    "active_click",
    "passive_click",
    "orders_amt",
]
# Synthetic categories: name -> list of base categories that are summed.
COMBINED = {"avia_hotel": ["avia", "hotel"]}
|
||
|
||
|
||
def load_raw(db_path: Path) -> pd.DataFrame:
    """Load the whole ``communications`` table from a SQLite database.

    Args:
        db_path: Path of the SQLite file to read.

    Returns:
        A DataFrame with every row and column of ``communications``; the
        ``business_dt`` column is parsed as datetime.

    Raises:
        Whatever ``sqlite3.connect`` or ``pandas.read_sql_query`` raise, e.g.
        when the table does not exist. The connection is closed in that case
        as well.
    """
    conn = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query(
            "select * from communications",
            conn,
            parse_dates=["business_dt"],
        )
    finally:
        # Close explicitly even when the query fails, so the database handle
        # is never leaked.
        conn.close()
|
||
|
||
|
||
def build_client_by_category(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate the raw rows to one row per client with per-category metrics.

    For every ``<metric>_<category>`` column (metrics from ``BASE_COLUMNS``,
    categories from ``CATEGORIES``) the values are summed per ``id``. The
    number of distinct ``business_dt`` values per client is stored as
    ``contact_days``. For each category two derived columns are added:
    ``imp_total_<cat>`` (active + passive impressions) and
    ``avg_imp_per_day_<cat>`` (``safe_divide`` of the total by
    ``contact_days``).

    Args:
        df: Raw communications table.

    Returns:
        The per-client DataFrame with ``id`` as a regular column.
    """
    aggregations = {}
    for metric in BASE_COLUMNS:
        for category in CATEGORIES:
            aggregations[f"{metric}_{category}"] = "sum"
    aggregations["business_dt"] = "nunique"

    per_client = df.groupby("id").agg(aggregations)
    per_client = per_client.reset_index()
    per_client = per_client.rename(columns={"business_dt": "contact_days"})

    for category in CATEGORIES:
        total_col = f"imp_total_{category}"
        active = per_client[f"active_imp_{category}"]
        passive = per_client[f"passive_imp_{category}"]
        per_client[total_col] = active + passive
        # safe_divide is a project helper; presumably it guards against
        # division by zero -- confirm in best_model_and_plots.
        per_client[f"avg_imp_per_day_{category}"] = safe_divide(
            per_client[total_col], per_client["contact_days"]
        )

    return per_client
|
||
|
||
|
||
def add_combined_category(client: pd.DataFrame, name: str, cats: list[str]) -> pd.DataFrame:
    """Add columns for a combined category that sums several base categories.

    For each metric in ``BASE_COLUMNS`` the columns ``<metric>_<c>`` of all
    ``cats`` are summed row-wise into ``<metric>_<name>``. Then
    ``imp_total_<name>`` and ``avg_imp_per_day_<name>`` are derived.

    Args:
        client: Per-client table; it is modified in place.
        name: Name of the combined category.
        cats: Base categories to sum.

    Returns:
        The same ``client`` object, with the new columns added.
    """
    for metric in BASE_COLUMNS:
        member_columns = [f"{metric}_{member}" for member in cats]
        client[f"{metric}_{name}"] = client[member_columns].sum(axis="columns")

    total_col = f"imp_total_{name}"
    active = client[f"active_imp_{name}"]
    passive = client[f"passive_imp_{name}"]
    client[total_col] = active + passive
    client[f"avg_imp_per_day_{name}"] = safe_divide(client[total_col], client["contact_days"])
    return client
|
||
|
||
|
||
def plot_category_correlation(client: pd.DataFrame, cat: str, out_dir: Path) -> None:
    """Save a correlation heatmap of the ``BASE_COLUMNS`` metrics of one category.

    Args:
        client: Per-client table containing the ``<metric>_<cat>`` columns.
        cat: Category whose metrics are correlated.
        out_dir: Directory for the image; created if missing. The file is
            written as ``corr_<cat>.png``.
    """
    metric_columns = [f"{base}_{cat}" for base in BASE_COLUMNS]
    corr_matrix = client[metric_columns].corr()

    fig, ax = plt.subplots(figsize=(6, 5))
    heatmap_options = {
        "annot": True,
        "fmt": ".2f",
        "cmap": "coolwarm",
        # Fixed colour scale so heatmaps of different categories are comparable.
        "vmin": -1,
        "vmax": 1,
        "linewidths": 0.5,
        "ax": ax,
    }
    sns.heatmap(corr_matrix, **heatmap_options)
    ax.set_title(f"Корреляции показов/кликов/заказов: {cat}")
    plt.tight_layout()

    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"corr_{cat}.png"
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"Saved correlation heatmap for {cat}: {path}")
|
||
|
||
|
||
def _quadratic_design(x):
    """Return the design matrix with columns ``[1, x, x**2]`` for *x*.

    ``has_constant="add"`` forces the intercept column. With the statsmodels
    default (``"skip"``) the column is omitted when a regressor column is
    already a non-zero constant, which would leave the model with fewer than
    three parameters and break ``params[2]`` / ``predict``.
    """
    x = np.asarray(x)
    return sm.add_constant(np.column_stack([x, x**2]), has_constant="add")


def fit_quadratic(
    cleaned: pd.DataFrame,
    x_col: str,
    y_col: str,
    trend_data=None,
    x_max: float = DEFAULT_X_MAX,
):
    """Fit ``y = b0 + b1*x + b2*x**2`` by OLS with HC3 robust covariance.

    The model is fitted on the trend curve (points with ``x <= x_max`` and a
    non-NaN y) when at least three such points exist; otherwise on the raw
    ``cleaned`` points.

    Args:
        cleaned: Table holding ``x_col`` and ``y_col``; rows with NaN in
            either column are dropped.
        x_col: Name of the regressor column.
        y_col: Name of the response column.
        trend_data: Optional ``(x, y)`` pair describing a smoothed trend, or
            ``None`` / ``(None, None)`` when there is no trend.
        x_max: Upper bound on x for the trend points that are used.

    Returns:
        ``(None, None)`` when fewer than three complete rows remain,
        otherwise ``(model, metrics)`` where ``metrics`` has the keys
        ``params``, ``pvalues``, ``r2_points`` (R^2 on the fitted sample),
        ``r2_trend`` (R^2 of the model against the clipped trend curve, NaN
        if not computable), ``auc_on_has_orders`` (ROC AUC of the model
        predictions for ``y > 0`` on the raw points, NaN if only one class
        is present) and ``effective_b2`` (the quadratic coefficient).
    """
    cleaned = cleaned[[x_col, y_col]].dropna()
    if len(cleaned) < 3:
        return None, None

    x_all = cleaned[x_col].to_numpy()
    y_true_all = cleaned[y_col].to_numpy()

    # Trend curve clipped to x <= x_max; reused for the fit and for r2_trend.
    trend_x = trend_y = None
    if trend_data is not None and trend_data[0] is not None:
        raw_x, raw_y = trend_data
        raw_x = np.asarray(raw_x)
        raw_y = np.asarray(raw_y)
        in_range = raw_x <= x_max
        trend_x = raw_x[in_range]
        trend_y = raw_y[in_range]

    # Prefer the trend curve as the fitting sample; fall back to raw points
    # when it offers fewer than three usable (non-NaN) points.
    x, y = x_all, y_true_all
    if trend_x is not None:
        usable = ~np.isnan(trend_y)
        if usable.sum() >= 3:
            x, y = trend_x[usable], trend_y[usable]

    model = sm.OLS(y, _quadratic_design(x)).fit(cov_type="HC3")

    # AUC of the regression output as a score for "client has any order".
    auc = float("nan")
    binary = (y_true_all > 0).astype(int)
    if len(np.unique(binary)) > 1:
        auc = roc_auc_score(binary, model.predict(_quadratic_design(x_all)))

    # R^2 against the clipped trend curve; NaNs in the trend are ignored.
    r2_trend = float("nan")
    if trend_x is not None and len(trend_x) > 1 and np.nanvar(trend_y) > 0:
        y_hat_trend = model.predict(_quadratic_design(trend_x))
        ss_res = np.nansum((trend_y - y_hat_trend) ** 2)
        ss_tot = np.nansum((trend_y - np.nanmean(trend_y)) ** 2)
        r2_trend = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")

    metrics = {
        "params": model.params,
        "pvalues": model.pvalues,
        "r2_points": model.rsquared,
        "r2_trend": r2_trend,
        "auc_on_has_orders": auc,
        "effective_b2": model.params[2],
    }
    return model, metrics
|
||
|
||
|
||
def plot_quad_for_category(
    client: pd.DataFrame,
    cat: str,
    *,
    base_out_dir: Path = OUT_DIR,
    x_max_overrides: dict | None = None,
    y_max_overrides: dict | None = None,
    savgol_overrides: dict | None = None,
    q_low_overrides: dict | None = None,
    q_high_overrides: dict | None = None,
    iqr_overrides: dict | None = None,
) -> None:
    """Plot scatter, trend and quadratic regression of orders vs. impressions.

    Uses ``avg_imp_per_day_<cat>`` as x and ``orders_amt_<cat>`` as y. The
    scatter and trend come from ``plot_clean_trend_scatter``; the quadratic
    fit from ``fit_quadratic``. The figure is saved as
    ``<base_out_dir>/orders_amt_<cat>/scatter_trend_quad.png``, or as
    ``scatter_trend.png`` in the same directory when no regression could be
    fitted. A summary line is printed to stdout.

    Args:
        client: Per-client table with the category columns.
        cat: Category to plot.
        base_out_dir: Root output directory.
        x_max_overrides: Per-category x limit; default ``DEFAULT_X_MAX``.
        y_max_overrides: Per-category y limit; default ``DEFAULT_Y_MAX``.
        savgol_overrides: Per-category ``savgol_window``; default
            ``DEFAULT_SAVGOL_WINDOW``.
        q_low_overrides: Per-category ``q_low``; default 0.05.
        q_high_overrides: Per-category ``q_high``; default 0.95.
        iqr_overrides: Per-category ``iqr_k``; default 1.5.
    """
    y_col = f"orders_amt_{cat}"
    x_col = f"avg_imp_per_day_{cat}"
    out_dir = base_out_dir / y_col
    x_max = (x_max_overrides or {}).get(cat, DEFAULT_X_MAX)
    y_max = (y_max_overrides or {}).get(cat, DEFAULT_Y_MAX)
    savgol_window = (savgol_overrides or {}).get(cat, DEFAULT_SAVGOL_WINDOW)
    q_low = (q_low_overrides or {}).get(cat, 0.05)
    q_high = (q_high_overrides or {}).get(cat, 0.95)
    iqr_k = (iqr_overrides or {}).get(cat, 1.5)

    res = plot_clean_trend_scatter(
        client,
        y_col=y_col,
        out_dir=out_dir,
        x_col=x_col,
        x_max=x_max,
        scatter_color=DEFAULT_SCATTER_COLOR,
        point_size=20,
        alpha=DEFAULT_ALPHA,
        iqr_k=iqr_k,
        q_low=q_low,
        q_high=q_high,
        alpha_min=DEFAULT_ALPHA_MIN,
        alpha_max=DEFAULT_ALPHA_MAX,
        bins_x=DEFAULT_BINS_X,
        bins_y=DEFAULT_BINS_Y,
        y_min=DEFAULT_Y_MIN,
        y_max=y_max,
        trend_frac=DEFAULT_TREND_FRAC,
        trend_color=DEFAULT_TREND_COLOR,
        trend_linewidth=DEFAULT_TREND_LINEWIDTH,
        savgol_window=savgol_window,
        return_components=True,
    )

    if res is None:
        print(f"[{cat}] Нет данных для построения тренда/регрессии")
        return

    fig, ax, cleaned, trend_data = res
    # Do not rely on the helper having created the directory before saving.
    out_dir.mkdir(parents=True, exist_ok=True)

    tx, ty = trend_data if trend_data is not None else (None, None)
    model, metrics = fit_quadratic(
        cleaned,
        x_col,
        y_col,
        trend_data=(tx, ty),
        x_max=x_max,
    )

    if model is None:
        print(f"[{cat}] Недостаточно точек для квадр. регрессии")
        fig.savefig(out_dir / "scatter_trend.png", dpi=150)
        plt.close(fig)
        return

    x_grid = np.linspace(cleaned[x_col].min(), min(cleaned[x_col].max(), x_max), 400)
    # has_constant="add": keep the intercept column even if the grid collapses
    # to a single non-zero value (statsmodels' default would omit it).
    X_grid = sm.add_constant(np.column_stack([x_grid, x_grid**2]), has_constant="add")
    y_hat = model.predict(X_grid)

    ax.plot(x_grid, y_hat, color="#1f77b4", linewidth=2.2, label="Квадр. регрессия")
    ax.legend()

    params = metrics["params"]
    pvals = metrics["pvalues"]
    b2_effective = metrics.get("effective_b2", params[2])
    if cat == "avia_hotel":
        # NOTE(review): the sign of b2 shown on the plot is forced negative for
        # this category, while the p-value and the console output below refer
        # to the actually fitted coefficient. Kept as in the original; confirm
        # that this reporting is intended.
        b2_effective = -abs(b2_effective)

    summary_lines = [
        f"R2_trend={metrics['r2_trend']:.3f}",
        f"AUC={metrics['auc_on_has_orders']:.3f}",
        f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
        f"b2={b2_effective:.3f} (p={pvals[2]:.3g})",
        f"n={len(cleaned)}",
    ]
    ax.text(
        0.02,
        0.95,
        "\n".join(summary_lines),
        transform=ax.transAxes,
        ha="left",
        va="top",
        fontsize=9,
        bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
    )

    quad_path = out_dir / "scatter_trend_quad.png"
    fig.tight_layout()
    fig.savefig(quad_path, dpi=150)
    plt.close(fig)
    print(f"[{cat}] Saved quad reg plot: {quad_path}")

    print(
        f"[{cat}] b0={params[0]:.4f}, b1={params[1]:.4f} (p={pvals[1]:.4g}), "
        f"b2={params[2]:.4f} (p={pvals[2]:.4g}), "
        f"R2_trend={metrics['r2_trend']:.4f}, AUC(has_order)={metrics['auc_on_has_orders']:.4f}"
    )
|
||
|
||
|
||
def main() -> None:
    """Run the whole category analysis.

    Loads the raw table from ``DB_PATH``, aggregates it per client, adds the
    combined categories from ``COMBINED``, then writes one correlation heatmap
    and one quadratic-regression plot per category under ``OUT_DIR``.
    """
    client = build_client_by_category(load_raw(DB_PATH))
    for combo_name, combo_cats in COMBINED.items():
        client = add_combined_category(client, combo_name, combo_cats)

    # Per-category plot tuning; categories missing from a mapping fall back to
    # the defaults inside plot_quad_for_category.
    overrides = {
        "x_max_overrides": {
            "ent": 4,
            "transport": 4,
            "avia": 4,
            "shopping": 6,
            "avia_hotel": 5,
            "super": 4,
        },
        "y_max_overrides": {
            "ent": 2.5,
            "transport": 6,
            "avia": 1.5,
            "shopping": 2.5,
            "avia_hotel": 2.0,
            "super": 5,
        },
        "savgol_overrides": {
            "ent": 301,
            "transport": 401,
            "avia": 301,
            "shopping": 201,
            "avia_hotel": 301,
        },
        "q_low_overrides": {"avia_hotel": 0.05},
        "q_high_overrides": {"avia_hotel": 0.9},
        "iqr_overrides": {"avia_hotel": 1.2},
    }

    corr_dir = OUT_DIR / "correlations"
    cats_all = CATEGORIES + list(COMBINED)

    # All heatmaps first, then all regression plots.
    for cat in cats_all:
        plot_category_correlation(client, cat, corr_dir)

    for cat in cats_all:
        plot_quad_for_category(client, cat, **overrides)
|
||
|
||
|
||
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|