Files
dano2025/main_hypot/category_quadreg.py
2025-12-16 01:51:05 +03:00

360 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Категорийный анализ: собирает агрегаты по категориям и строит корреляции/квадратичную регрессию по заказам."""
import sqlite3
from pathlib import Path
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from sklearn.metrics import roc_auc_score
# Make the sibling helper module (best_model_and_plots, imported right below)
# importable regardless of the current working directory.
script_dir = Path(__file__).resolve().parent
if str(script_dir) not in sys.path:
    sys.path.append(str(script_dir))
from best_model_and_plots import ( # noqa: E402
CATEGORIES,
DEFAULT_ALPHA,
DEFAULT_ALPHA_MAX,
DEFAULT_ALPHA_MIN,
DEFAULT_BINS_X,
DEFAULT_BINS_Y,
DEFAULT_SCATTER_COLOR,
DEFAULT_TREND_COLOR,
DEFAULT_TREND_FRAC,
DEFAULT_TREND_LINEWIDTH,
DEFAULT_X_MAX,
DEFAULT_Y_MAX,
DEFAULT_Y_MIN,
DEFAULT_SAVGOL_WINDOW,
plot_clean_trend_scatter,
safe_divide,
)
# Global plotting defaults applied to every figure produced by this module.
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (8, 8)

# Filesystem layout, resolved relative to this file: two directory levels up
# is treated as the project root.
project_root = Path(__file__).resolve().parent.parent
DB_PATH = project_root.joinpath("dataset", "ds.sqlite")
OUT_DIR = project_root.joinpath("main_hypot", "category_analysis")

# Metric prefixes; per-category columns are named f"{base}_{category}".
BASE_COLUMNS = [
    "active_imp",
    "passive_imp",
    "active_click",
    "passive_click",
    "orders_amt",
]

# Derived categories: name -> source categories whose columns get summed.
COMBINED = {"avia_hotel": ["avia", "hotel"]}
def load_raw(db_path: Path) -> pd.DataFrame:
    """Load the whole ``communications`` table from the SQLite file at ``db_path``.

    The ``business_dt`` column is parsed into datetimes.

    Raises:
        FileNotFoundError: if ``db_path`` does not exist. Without this check
            ``sqlite3.connect`` would silently create an empty database file
            and the query would fail later with a misleading "no such table".
    """
    db_path = Path(db_path)
    if not db_path.exists():
        raise FileNotFoundError(f"SQLite database not found: {db_path}")
    conn = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query(
            "select * from communications", conn, parse_dates=["business_dt"]
        )
    finally:
        # sqlite3.Connection's own context manager only commits/rolls back and
        # does not close, so close explicitly, even when the query fails.
        conn.close()
def build_client_by_category(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate communication metrics per client (``id``) and per category.

    Every ``{base}_{category}`` column (base in BASE_COLUMNS, category in
    CATEGORIES) is summed per client, and the number of distinct
    ``business_dt`` values per client is stored as ``contact_days``.
    For each category two derived columns are appended:

    * ``imp_total_{cat}``: active + passive impressions;
    * ``avg_imp_per_day_{cat}``: ``imp_total_{cat}`` / ``contact_days``,
      computed with ``safe_divide`` (presumably guards against zero
      denominators; confirm in best_model_and_plots).
    """
    sum_columns = [f"{col}_{cat}" for col in BASE_COLUMNS for cat in CATEGORIES]
    aggregations = dict.fromkeys(sum_columns, "sum")
    aggregations["business_dt"] = "nunique"

    grouped = df.groupby("id").agg(aggregations)
    client = grouped.reset_index().rename(columns={"business_dt": "contact_days"})

    for cat in CATEGORIES:
        total_name = f"imp_total_{cat}"
        active = client[f"active_imp_{cat}"]
        passive = client[f"passive_imp_{cat}"]
        client[total_name] = active + passive
        client[f"avg_imp_per_day_{cat}"] = safe_divide(
            client[total_name], client["contact_days"]
        )
    return client
def add_combined_category(client: pd.DataFrame, name: str, cats: list[str]) -> pd.DataFrame:
    """Add columns for a combined category ``name`` built from the categories ``cats``.

    For every base metric, ``{base}_{name}`` is the row-wise sum
    (``DataFrame.sum(axis=1)``) of ``{base}_{c}`` over ``c`` in ``cats``. Then
    ``imp_total_{name}`` and ``avg_imp_per_day_{name}`` are derived like the
    per-category columns. ``client`` is modified in place and also returned.
    """
    member_columns = {
        base: [f"{base}_{member}" for member in cats] for base in BASE_COLUMNS
    }
    for base, sources in member_columns.items():
        client[f"{base}_{name}"] = client[sources].sum(axis=1)

    total_name = f"imp_total_{name}"
    active = client[f"active_imp_{name}"]
    passive = client[f"passive_imp_{name}"]
    client[total_name] = active + passive
    client[f"avg_imp_per_day_{name}"] = safe_divide(
        client[total_name], client["contact_days"]
    )
    return client
def plot_category_correlation(client: pd.DataFrame, cat: str, out_dir: Path) -> None:
    """Save an annotated correlation heatmap of the base metrics of one category.

    Correlates the ``{base}_{cat}`` columns (base in BASE_COLUMNS) with
    ``DataFrame.corr`` (Pearson by default) and writes ``corr_{cat}.png`` into
    ``out_dir``, creating the directory if needed.
    """
    cols = [f"{base}_{cat}" for base in BASE_COLUMNS]
    corr = client[cols].corr()
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"corr_{cat}.png"

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        sns.heatmap(
            corr,
            annot=True,
            fmt=".2f",
            cmap="coolwarm",
            vmin=-1,
            vmax=1,
            linewidths=0.5,
            ax=ax,
        )
        ax.set_title(f"Корреляции показов/кликов/заказов: {cat}")
        # Lay out this specific figure instead of relying on pyplot's "current figure".
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        # Always release the figure so pyplot does not accumulate open figures
        # when drawing or saving fails.
        plt.close(fig)
    print(f"Saved correlation heatmap for {cat}: {path}")
def _filter_trend(trend_data, x_max):
    """Restrict a ``(tx, ty)`` trend curve to ``tx <= x_max`` and drop NaN ``ty``.

    Returns ``(None, None)`` when no trend is supplied, i.e. ``trend_data`` is
    None or either of its two components is None.
    """
    if trend_data is None or trend_data[0] is None or trend_data[1] is None:
        return None, None
    tx = np.asarray(trend_data[0])
    ty = np.asarray(trend_data[1])
    # A NaN tx also fails the comparison, so such points are dropped too.
    mask = (tx <= x_max) & ~np.isnan(ty)
    return tx[mask], ty[mask]


def _quad_design(x):
    """Return the design matrix ``[1, x, x**2]`` for the quadratic model.

    ``has_constant="add"`` forces the intercept column. With the statsmodels
    default ("skip"), a constant non-zero ``x`` is mistaken for an existing
    constant, so the intercept is dropped and ``params[2]`` or ``predict`` breaks.
    """
    x = np.asarray(x, dtype=float)
    return sm.add_constant(np.column_stack([x, x**2]), has_constant="add")


def fit_quadratic(
    cleaned: pd.DataFrame,
    x_col: str,
    y_col: str,
    trend_data=None,
    x_max: float = DEFAULT_X_MAX,
):
    """Fit ``y = b0 + b1*x + b2*x**2`` by OLS with HC3 robust covariance.

    Fitting target: the trend curve ``trend_data`` (a ``(tx, ty)`` pair;
    presumably the smoothed trend from ``plot_clean_trend_scatter``, confirm),
    restricted to ``tx <= x_max`` with NaN ``ty`` removed, when at least 3 such
    points remain. Otherwise the model is fitted directly to the non-NaN
    ``(x_col, y_col)`` rows of ``cleaned``.

    Returns:
        ``(None, None)`` if ``cleaned`` has fewer than 3 complete rows, else
        ``(model, metrics)``. ``model`` is the fitted statsmodels result and
        ``metrics`` contains:

        * ``params`` / ``pvalues``: coefficients ``[b0, b1, b2]`` and p-values;
        * ``r2_points``: ``model.rsquared``, i.e. R² on whatever data was fitted
          (the trend curve when it was used);
        * ``r2_trend``: R² of the model's predictions against the filtered trend
          points (NaN if fewer than 2 points or zero variance);
        * ``auc_on_has_orders``: ROC AUC of the model's predictions on all
          complete rows of ``cleaned`` for the target ``y > 0`` (NaN if only one
          class is present);
        * ``effective_b2``: the fitted quadratic coefficient ``b2``.
    """
    points = cleaned[[x_col, y_col]].dropna()
    if len(points) < 3:
        return None, None
    x_all = points[x_col].to_numpy()
    y_all = points[y_col].to_numpy()

    tx, ty = _filter_trend(trend_data, x_max)
    # Prefer the smoothed trend as the regression target. Fall back to the raw
    # points when too few trend points lie inside the x range.
    if tx is not None and len(tx) >= 3:
        x_fit, y_fit = tx, ty
    else:
        x_fit, y_fit = x_all, y_all
    model = sm.OLS(y_fit, _quad_design(x_fit)).fit(cov_type="HC3")

    auc = float("nan")
    has_orders = (y_all > 0).astype(int)
    if np.unique(has_orders).size > 1:
        auc = roc_auc_score(has_orders, model.predict(_quad_design(x_all)))

    r2_trend = float("nan")
    if tx is not None and len(tx) > 1:
        ss_tot = np.sum((ty - ty.mean()) ** 2)
        if ss_tot > 0:
            ss_res = np.sum((ty - model.predict(_quad_design(tx))) ** 2)
            r2_trend = 1 - ss_res / ss_tot

    metrics = {
        "params": model.params,
        "pvalues": model.pvalues,
        "r2_points": model.rsquared,
        "r2_trend": r2_trend,
        "auc_on_has_orders": auc,
        "effective_b2": model.params[2],
    }
    return model, metrics
def _override(overrides: dict | None, cat: str, default):
    """Return ``overrides[cat]`` when present, else ``default`` (``overrides`` may be None)."""
    return (overrides or {}).get(cat, default)


def plot_quad_for_category(
    client: pd.DataFrame,
    cat: str,
    *,
    base_out_dir: Path = OUT_DIR,
    x_max_overrides: dict | None = None,
    y_max_overrides: dict | None = None,
    savgol_overrides: dict | None = None,
    q_low_overrides: dict | None = None,
    q_high_overrides: dict | None = None,
    iqr_overrides: dict | None = None,
) -> None:
    """Plot scatter, trend, and quadratic regression of orders vs. impressions/day.

    Uses ``avg_imp_per_day_{cat}`` as x and ``orders_amt_{cat}`` as y. The
    cleaned scatter and trend come from ``plot_clean_trend_scatter``; the
    quadratic curve and annotated metrics come from ``fit_quadratic``. Each
    ``*_overrides`` dict maps a category name to a tuning value; categories
    absent from it use the module defaults (q_low=0.05, q_high=0.95, iqr_k=1.5).

    Output (under ``base_out_dir / f"orders_amt_{cat}"``):
    ``scatter_trend_quad.png`` on success, or ``scatter_trend.png`` when there
    are too few points for the regression. Nothing is saved here when
    ``plot_clean_trend_scatter`` returns None.
    """
    y_col = f"orders_amt_{cat}"
    x_col = f"avg_imp_per_day_{cat}"
    out_dir = base_out_dir / y_col
    x_max = _override(x_max_overrides, cat, DEFAULT_X_MAX)
    y_max = _override(y_max_overrides, cat, DEFAULT_Y_MAX)
    savgol_window = _override(savgol_overrides, cat, DEFAULT_SAVGOL_WINDOW)
    q_low = _override(q_low_overrides, cat, 0.05)
    q_high = _override(q_high_overrides, cat, 0.95)
    iqr_k = _override(iqr_overrides, cat, 1.5)

    res = plot_clean_trend_scatter(
        client,
        y_col=y_col,
        out_dir=out_dir,
        x_col=x_col,
        x_max=x_max,
        scatter_color=DEFAULT_SCATTER_COLOR,
        point_size=20,
        alpha=DEFAULT_ALPHA,
        iqr_k=iqr_k,
        q_low=q_low,
        q_high=q_high,
        alpha_min=DEFAULT_ALPHA_MIN,
        alpha_max=DEFAULT_ALPHA_MAX,
        bins_x=DEFAULT_BINS_X,
        bins_y=DEFAULT_BINS_Y,
        y_min=DEFAULT_Y_MIN,
        y_max=y_max,
        trend_frac=DEFAULT_TREND_FRAC,
        trend_color=DEFAULT_TREND_COLOR,
        trend_linewidth=DEFAULT_TREND_LINEWIDTH,
        savgol_window=savgol_window,
        return_components=True,
    )
    if res is None:
        print(f"[{cat}] Нет данных для построения тренда/регрессии")
        return

    fig, ax, cleaned, trend_data = res
    try:
        tx, ty = trend_data if trend_data is not None else (None, None)
        model, metrics = fit_quadratic(
            cleaned,
            x_col,
            y_col,
            trend_data=(tx, ty),
            x_max=x_max,
        )
        # Ensure the target directory exists before any savefig below.
        out_dir.mkdir(parents=True, exist_ok=True)
        if model is None:
            print(f"[{cat}] Недостаточно точек для квадр. регрессии")
            fig.savefig(out_dir / "scatter_trend.png", dpi=150)
            return

        x_grid = np.linspace(cleaned[x_col].min(), min(cleaned[x_col].max(), x_max), 400)
        # has_constant="add": keep the intercept column even for a degenerate
        # (constant) grid so the design matches the 3 fitted parameters.
        X_grid = sm.add_constant(np.column_stack([x_grid, x_grid**2]), has_constant="add")
        ax.plot(x_grid, model.predict(X_grid), color="#1f77b4", linewidth=2.2, label="Квадр. регрессия")
        ax.legend()

        params = metrics["params"]
        pvals = metrics["pvalues"]
        # Annotate the fitted b2 as-is, so the label agrees with the plotted curve
        # and the console log below; its sign is not altered per category.
        summary_lines = [
            f"R2_trend={metrics['r2_trend']:.3f}",
            f"AUC={metrics['auc_on_has_orders']:.3f}",
            f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
            f"b2={params[2]:.3f} (p={pvals[2]:.3g})",
            f"n={len(cleaned)}",
        ]
        ax.text(
            0.02,
            0.95,
            "\n".join(summary_lines),
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
        )
        quad_path = out_dir / "scatter_trend_quad.png"
        fig.tight_layout()
        fig.savefig(quad_path, dpi=150)
        print(f"[{cat}] Saved quad reg plot: {quad_path}")
        print(
            f"[{cat}] b0={params[0]:.4f}, b1={params[1]:.4f} (p={pvals[1]:.4g}), "
            f"b2={params[2]:.4f} (p={pvals[2]:.4g}), "
            f"R2_trend={metrics['r2_trend']:.4f}, AUC(has_order)={metrics['auc_on_has_orders']:.4f}"
        )
    finally:
        # Release the figure on every path, including early returns and errors.
        plt.close(fig)
def main() -> None:
    """Run the whole category analysis.

    Loads communications from DB_PATH, aggregates them per client and category,
    adds the combined categories from COMBINED, then, for every category,
    saves a correlation heatmap (into OUT_DIR / "correlations") and a
    scatter/trend/quadratic-regression plot with per-category tuning.
    """
    # Per-category plot tuning (x_max, y_max, savgol window, quantile and IQR
    # cleaning). Categories missing from a mapping use the defaults applied
    # inside plot_quad_for_category.
    overrides = {
        "x_max_overrides": {
            "ent": 4,
            "transport": 4,
            "avia": 4,
            "shopping": 6,
            "avia_hotel": 5,
            "super": 4,
        },
        "y_max_overrides": {
            "ent": 2.5,
            "transport": 6,
            "avia": 1.5,
            "shopping": 2.5,
            "avia_hotel": 2.0,
            "super": 5,
        },
        "savgol_overrides": {
            "ent": 301,
            "transport": 401,
            "avia": 301,
            "shopping": 201,
            "avia_hotel": 301,
        },
        "q_low_overrides": {"avia_hotel": 0.05},
        "q_high_overrides": {"avia_hotel": 0.9},
        "iqr_overrides": {"avia_hotel": 1.2},
    }

    client = build_client_by_category(load_raw(DB_PATH))
    for combo_name, combo_cats in COMBINED.items():
        client = add_combined_category(client, combo_name, combo_cats)

    categories = [*CATEGORIES, *COMBINED]
    corr_dir = OUT_DIR / "correlations"
    for cat in categories:
        plot_category_correlation(client, cat, corr_dir)
    for cat in categories:
        plot_quad_for_category(client, cat, **overrides)


if __name__ == "__main__":
    main()