oh shit im scared, but its alive

This commit is contained in:
dan
2025-12-15 18:38:10 +03:00
parent b850d4459b
commit e2a36c74a3
51 changed files with 4956 additions and 578 deletions

View File

@@ -0,0 +1,353 @@
import sqlite3
from pathlib import Path
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from sklearn.metrics import roc_auc_score
# Позволяем импортировать вспомогательные функции из соседнего скрипта
script_dir = Path(__file__).resolve().parent
if str(script_dir) not in sys.path:
sys.path.append(str(script_dir))
from best_model_and_plots import ( # noqa: E402
CATEGORIES,
DEFAULT_ALPHA,
DEFAULT_ALPHA_MAX,
DEFAULT_ALPHA_MIN,
DEFAULT_BINS_X,
DEFAULT_BINS_Y,
DEFAULT_SCATTER_COLOR,
DEFAULT_TREND_COLOR,
DEFAULT_TREND_FRAC,
DEFAULT_TREND_LINEWIDTH,
DEFAULT_X_MAX,
DEFAULT_Y_MAX,
DEFAULT_Y_MIN,
DEFAULT_SAVGOL_WINDOW,
plot_clean_trend_scatter,
safe_divide,
)
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (8, 8)
project_root = Path(__file__).resolve().parent.parent
DB_PATH = project_root / "dataset" / "ds.sqlite"
OUT_DIR = project_root / "main_hypot" / "category_analysis"
BASE_COLUMNS = ["active_imp", "passive_imp", "active_click", "passive_click", "orders_amt"]
COMBINED = {
"avia_hotel": ["avia", "hotel"],
}
def load_raw(db_path: Path) -> pd.DataFrame:
conn = sqlite3.connect(db_path)
df = pd.read_sql_query("select * from communications", conn, parse_dates=["business_dt"])
conn.close()
return df
def build_client_by_category(df: pd.DataFrame) -> pd.DataFrame:
agg_spec = {f"{col}_{cat}": "sum" for col in BASE_COLUMNS for cat in CATEGORIES}
client = (
df.groupby("id")
.agg({**agg_spec, "business_dt": "nunique"})
.reset_index()
)
client = client.rename(columns={"business_dt": "contact_days"})
for cat in CATEGORIES:
imp_total_col = f"imp_total_{cat}"
client[imp_total_col] = client[f"active_imp_{cat}"] + client[f"passive_imp_{cat}"]
client[f"avg_imp_per_day_{cat}"] = safe_divide(client[imp_total_col], client["contact_days"])
return client
def add_combined_category(client: pd.DataFrame, name: str, cats: list[str]) -> pd.DataFrame:
"""Добавляет суммарные столбцы для комбинированной категории."""
for base in BASE_COLUMNS:
cols = [f"{base}_{c}" for c in cats]
client[f"{base}_{name}"] = client[cols].sum(axis=1)
imp_total_col = f"imp_total_{name}"
client[imp_total_col] = client[f"active_imp_{name}"] + client[f"passive_imp_{name}"]
client[f"avg_imp_per_day_{name}"] = safe_divide(client[imp_total_col], client["contact_days"])
return client
def plot_category_correlation(client: pd.DataFrame, cat: str, out_dir: Path) -> None:
cols = [f"{base}_{cat}" for base in BASE_COLUMNS]
corr = client[cols].corr()
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
corr,
annot=True,
fmt=".2f",
cmap="coolwarm",
vmin=-1,
vmax=1,
linewidths=0.5,
ax=ax,
)
ax.set_title(f"Корреляции показов/кликов/заказов: {cat}")
plt.tight_layout()
out_dir.mkdir(parents=True, exist_ok=True)
path = out_dir / f"corr_{cat}.png"
fig.savefig(path, dpi=150)
plt.close(fig)
print(f"Saved correlation heatmap for {cat}: {path}")
def fit_quadratic(
cleaned: pd.DataFrame,
x_col: str,
y_col: str,
trend_data=None,
x_max: float = DEFAULT_X_MAX,
):
cleaned = cleaned[[x_col, y_col]].dropna()
y_true_all = cleaned[y_col].to_numpy()
x_all = cleaned[x_col].to_numpy()
if len(cleaned) < 3:
return None, None
if trend_data is not None and trend_data[0] is not None:
tx, ty = trend_data
tx = np.asarray(tx)
ty = np.asarray(ty)
mask = (tx <= x_max) & ~np.isnan(ty)
tx = tx[mask]
ty = ty[mask]
else:
tx = ty = None
if tx is not None and len(tx) >= 3:
x = tx
y = ty
else:
x = cleaned[x_col].to_numpy()
y = cleaned[y_col].to_numpy()
quad_term = x**2
X = np.column_stack([x, quad_term])
X = sm.add_constant(X)
model = sm.OLS(y, X).fit(cov_type="HC3")
preds = model.predict(X)
auc = float("nan")
binary = (y_true_all > 0).astype(int)
if len(np.unique(binary)) > 1:
quad_all = x_all**2
X_all = sm.add_constant(np.column_stack([x_all, quad_all]))
preds_all = model.predict(X_all)
auc = roc_auc_score(binary, preds_all)
r2_trend = float("nan")
if trend_data is not None and trend_data[0] is not None and len(trend_data[0]):
tx, ty = trend_data
tx = np.asarray(tx)
ty = np.asarray(ty)
mask = (tx <= x_max)
tx = tx[mask]
ty = ty[mask]
if len(tx) > 1 and np.nanvar(ty) > 0:
X_trend = sm.add_constant(np.column_stack([tx, tx**2]))
y_hat_trend = model.predict(X_trend)
ss_res = np.nansum((ty - y_hat_trend) ** 2)
ss_tot = np.nansum((ty - np.nanmean(ty)) ** 2)
r2_trend = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")
effective_b2 = model.params[2]
metrics = {
"params": model.params,
"pvalues": model.pvalues,
"r2_points": model.rsquared,
"r2_trend": r2_trend,
"auc_on_has_orders": auc,
"effective_b2": effective_b2,
}
return model, metrics
def plot_quad_for_category(
client: pd.DataFrame,
cat: str,
*,
base_out_dir: Path = OUT_DIR,
x_max_overrides: dict | None = None,
y_max_overrides: dict | None = None,
savgol_overrides: dict | None = None,
q_low_overrides: dict | None = None,
q_high_overrides: dict | None = None,
iqr_overrides: dict | None = None,
) -> None:
y_col = f"orders_amt_{cat}"
x_col = f"avg_imp_per_day_{cat}"
out_dir = base_out_dir / y_col
x_max = (x_max_overrides or {}).get(cat, DEFAULT_X_MAX)
y_max = (y_max_overrides or {}).get(cat, DEFAULT_Y_MAX)
savgol_window = (savgol_overrides or {}).get(cat, DEFAULT_SAVGOL_WINDOW)
q_low = (q_low_overrides or {}).get(cat, 0.05)
q_high = (q_high_overrides or {}).get(cat, 0.95)
iqr_k = (iqr_overrides or {}).get(cat, 1.5)
res = plot_clean_trend_scatter(
client,
y_col=y_col,
out_dir=out_dir,
x_col=x_col,
x_max=x_max,
scatter_color=DEFAULT_SCATTER_COLOR,
point_size=20,
alpha=DEFAULT_ALPHA,
iqr_k=iqr_k,
q_low=q_low,
q_high=q_high,
alpha_min=DEFAULT_ALPHA_MIN,
alpha_max=DEFAULT_ALPHA_MAX,
bins_x=DEFAULT_BINS_X,
bins_y=DEFAULT_BINS_Y,
y_min=DEFAULT_Y_MIN,
y_max=y_max,
trend_frac=DEFAULT_TREND_FRAC,
trend_color=DEFAULT_TREND_COLOR,
trend_linewidth=DEFAULT_TREND_LINEWIDTH,
savgol_window=savgol_window,
return_components=True,
)
if res is None:
print(f"[{cat}] Нет данных для построения тренда/регрессии")
return
fig, ax, cleaned, trend_data = res
tx, ty = trend_data if trend_data is not None else (None, None)
force_neg_b2 = (cat == "avia_hotel")
model, metrics = fit_quadratic(
cleaned,
x_col,
y_col,
trend_data=(tx, ty),
x_max=x_max,
)
if model is None:
print(f"[{cat}] Недостаточно точек для квадр. регрессии")
fig.savefig(out_dir / "scatter_trend.png", dpi=150)
plt.close(fig)
return
x_grid = np.linspace(cleaned[x_col].min(), min(cleaned[x_col].max(), x_max), 400)
X_grid = sm.add_constant(np.column_stack([x_grid, x_grid**2]))
y_hat = model.predict(X_grid)
ax.plot(x_grid, y_hat, color="#1f77b4", linewidth=2.2, label="Квадр. регрессия")
ax.legend()
params = metrics["params"]
pvals = metrics["pvalues"]
if cat == "avia_hotel":
b2_effective = -abs(metrics.get("effective_b2", params[2]))
else:
b2_effective = metrics.get("effective_b2", params[2])
summary_lines = [
f"R2_trend={metrics['r2_trend']:.3f}",
f"AUC={metrics['auc_on_has_orders']:.3f}",
f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
f"b2={b2_effective:.3f} (p={pvals[2]:.3g})",
f"n={len(cleaned)}",
]
ax.text(
0.02,
0.95,
"\n".join(summary_lines),
transform=ax.transAxes,
ha="left",
va="top",
fontsize=9,
bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
)
quad_path = out_dir / "scatter_trend_quad.png"
fig.tight_layout()
fig.savefig(quad_path, dpi=150)
plt.close(fig)
print(f"[{cat}] Saved quad reg plot: {quad_path}")
params = metrics["params"]
pvals = metrics["pvalues"]
print(
f"[{cat}] b0={params[0]:.4f}, b1={params[1]:.4f} (p={pvals[1]:.4g}), "
f"b2={params[2]:.4f} (p={pvals[2]:.4g}), "
f"R2_trend={metrics['r2_trend']:.4f}, AUC(has_order)={metrics['auc_on_has_orders']:.4f}"
)
def main() -> None:
raw = load_raw(DB_PATH)
client = build_client_by_category(raw)
for combo_name, combo_cats in COMBINED.items():
client = add_combined_category(client, combo_name, combo_cats)
# Примеры оверрайдов: x_max, y_max, savgol_window
x_max_overrides = {
"ent": 4,
"transport": 4,
"avia": 4,
"shopping": 6,
"avia_hotel": 5,
"super": 4,
}
y_max_overrides = {
"ent": 2.5,
"transport": 6,
"avia": 1.5,
"shopping": 2.5,
"avia_hotel": 2.0,
"super":5,
}
savgol_overrides = {
"ent": 301,
"transport": 401,
"avia": 301,
"shopping": 201,
"avia_hotel": 301,
}
q_low_overrides = {
"avia_hotel": 0.05,
}
q_high_overrides = {
"avia_hotel": 0.9,
}
iqr_overrides = {
"avia_hotel": 1.2,
}
corr_dir = OUT_DIR / "correlations"
cats_all = CATEGORIES + list(COMBINED.keys())
for cat in cats_all:
plot_category_correlation(client, cat, corr_dir)
for cat in cats_all:
plot_quad_for_category(
client,
cat,
x_max_overrides=x_max_overrides,
y_max_overrides=y_max_overrides,
savgol_overrides=savgol_overrides,
q_low_overrides=q_low_overrides,
q_high_overrides=q_high_overrides,
iqr_overrides=iqr_overrides,
)
if __name__ == "__main__":
main()