diff --git a/spam_hypot/quad_regression_with_costs.png b/spam_hypot/quad_regression_with_costs.png new file mode 100644 index 0000000..321fed8 Binary files /dev/null and b/spam_hypot/quad_regression_with_costs.png differ diff --git a/spam_hypot/quadreg.py b/spam_hypot/quadreg.py new file mode 100644 index 0000000..ce2ce12 --- /dev/null +++ b/spam_hypot/quadreg.py @@ -0,0 +1,240 @@ +import sqlite3 +from pathlib import Path +import sys +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +import statsmodels.api as sm + +sns.set_theme(style="whitegrid") +plt.rcParams["figure.figsize"] = (10, 6) + +# ----------------------------- +# Load + feature engineering (как у тебя) +# ----------------------------- +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root / "preanalysis")) +import eda_utils as eda # noqa: E402 + +db_path = project_root / "dataset" / "ds.sqlite" +conn = sqlite3.connect(db_path) +df = pd.read_sql_query("select * from communications", conn, parse_dates=["business_dt"]) +conn.close() + +for cols, name in [ + (eda.ACTIVE_IMP_COLS, "active_imp_total"), + (eda.PASSIVE_IMP_COLS, "passive_imp_total"), + (eda.ACTIVE_CLICK_COLS, "active_click_total"), + (eda.PASSIVE_CLICK_COLS, "passive_click_total"), + (eda.ORDER_COLS, "orders_amt_total"), +]: + df[name] = df[cols].sum(axis=1) + +df["imp_total"] = df["active_imp_total"] + df["passive_imp_total"] +df["click_total"] = df["active_click_total"] + df["passive_click_total"] + +contact_days = df.groupby("id")["business_dt"].nunique().rename("contact_days") + +client = ( + df.groupby("id") + .agg( + imp_total=("imp_total", "sum"), + click_total=("click_total", "sum"), + orders_amt_total=("orders_amt_total", "sum"), + age=("age", "median"), + gender_cd=("gender_cd", lambda s: s.mode().iat[0]), + device_platform_cd=("device_platform_cd", lambda s: s.mode().iat[0]), + ) + .merge(contact_days, on="id", how="left") + .reset_index() +) + +client["order_rate"] = eda.safe_divide(client["orders_amt_total"], client["imp_total"]) +client["order_rate_pct"] = 100 * client["order_rate"] +client["avg_imp_per_day"] = eda.safe_divide(client["imp_total"], client["contact_days"]) + +# ----------------------------- +# Aggregate curve points (как у тебя) +# ----------------------------- +stats_imp = ( + client.groupby("avg_imp_per_day", as_index=False) + .agg( + orders_mean=("orders_amt_total", "mean"), + n_clients=("id", "count"), + ) + .sort_values("avg_imp_per_day") +).reset_index(drop=True) + +# ----------------------------- +# Filtering / outlier logic (как у тебя) +# ----------------------------- +K_MULT = 2 +ABS_DY_MIN = 1 +X_MAX = 16 + +stats_f = stats_imp[stats_imp["avg_imp_per_day"] <= X_MAX].copy().reset_index(drop=True) + +before = len(stats_f) +y = stats_f["orders_mean"] +abs_dy = y.diff().abs() + +prev3_mean = abs_dy.shift(1).rolling(window=3, min_periods=3).mean() +ratio = abs_dy / (prev3_mean.replace(0, np.nan)) + +is_outlier = ((abs_dy >= ABS_DY_MIN) & (ratio >= K_MULT)) | (y > 5) +is_outlier = is_outlier.fillna(False) + +stats_f = stats_f.loc[~is_outlier].copy().reset_index(drop=True) +after = len(stats_f) +print(f"Фильтрация: было {before}, стало {after}, убрали {before-after} точек") + +# ----------------------------- +# Smoothing (оставим для визуалки, но регрессию делаем по orders_mean) +# ----------------------------- +w = max(7, int(len(stats_f) * 0.05)) +if w % 2 == 0: + w += 1 + +stats_f["orders_smooth"] = ( + stats_f["orders_mean"] + .rolling(window=w, center=True, min_periods=1) + .mean() +) + +# ----------------------------- +# Cost line (как у тебя, нормировка "в единицах заказов") +# ----------------------------- +c = stats_f["orders_smooth"].max() / stats_f["avg_imp_per_day"].max() +stats_f["cost_line"] = c * stats_f["avg_imp_per_day"] + +# ----------------------------- +# Quadratic regression: orders_mean ~ 1 + x + x^2 +# WLS with weights = n_clients +# ----------------------------- +x = stats_f["avg_imp_per_day"].to_numpy() +y = stats_f["orders_mean"].to_numpy() +wts = stats_f["n_clients"].to_numpy().astype(float) + +X = np.column_stack([x, x**2]) +X = sm.add_constant(X) # [1, x, x^2] + +model = sm.WLS(y, X, weights=wts) +res = model.fit(cov_type="HC3") # робастные ошибки + +b0, b1, b2 = res.params +p_b1_two = res.pvalues[1] +p_b2_two = res.pvalues[2] + +# one-sided p-values for directional hypotheses +p_b1_pos = (p_b1_two / 2) if (b1 > 0) else (1 - p_b1_two / 2) +p_b2_neg = (p_b2_two / 2) if (b2 < 0) else (1 - p_b2_two / 2) + +# turning point (if concave) +x_star = None +y_star = None +if b2 < 0: + x_star = -b1 / (2 * b2) + y_star = b0 + b1 * x_star + b2 * x_star**2 + +# Intersection with cost line: b0 + b1 x + b2 x^2 = c x -> b2 x^2 + (b1 - c) x + b0 = 0 +x_cross = None +roots = np.roots([b2, (b1 - c), b0]) # may be complex +roots = [r.real for r in roots if abs(r.imag) < 1e-8] +roots_in_range = [r for r in roots if (stats_f["avg_imp_per_day"].min() <= r <= stats_f["avg_imp_per_day"].max())] +if roots_in_range: + # берём корень ближе к "правой" части (обычно пересечение интереснее там, где начинается невыгодно) + x_cross = max(roots_in_range) + +# ----------------------------- +# Print results + interpretation (по-человечески) +# ----------------------------- +print("\n=== Квадратичная регрессия (WLS, веса = n_clients, SE = HC3) ===") +print(res.summary()) + +print("\n=== Проверка гипотезы убывающей отдачи / спада ===") +print(f"β1 (линейный эффект): {b1:.6f}, двусторонний p={p_b1_two:.4g}, односторонний p(β1>0)={p_b1_pos:.4g}") +print(f"β2 (кривизна): {b2:.6f}, двусторонний p={p_b2_two:.4g}, односторонний p(β2<0)={p_b2_neg:.4g}") + +alpha = 0.05 +support = (b1 > 0) and (b2 < 0) and (p_b1_pos < alpha) and (p_b2_neg < alpha) + +if support: + print("\nВывод: данные поддерживают гипотезу нелинейности.") + print("Есть статистически значимый рост на малых x (β1>0) и насыщение/спад (β2<0).") +else: + print("\nВывод: строгого статистического подтверждения по знакам/значимости может не хватить.") + print("Но знак коэффициентов и форма кривой всё равно могут быть согласованы с гипотезой.") + print("На защите говори аккуратно: 'наблюдается тенденция/согласуется с гипотезой'.") + +if x_star is not None: + print(f"\nОценка 'порога насыщения' (вершина параболы): x* = {x_star:.3f} показов/день") + print(f"Прогноз среднего числа заказов в x*: y(x*) ≈ {y_star:.3f}") + if not (stats_f["avg_imp_per_day"].min() <= x_star <= stats_f["avg_imp_per_day"].max()): + print("Внимание: x* вне диапазона наблюдений, интерпретация как 'оптимума' сомнительная.") +else: + print("\nВершина не считается как максимум: β2 >= 0 (нет выпуклости вниз).") + +if x_cross is not None: + y_cross = b0 + b1 * x_cross + b2 * x_cross**2 + print(f"\nТочка пересечения с линейными расходами (в нормировке c={c:.4f}): x≈{x_cross:.3f}, y≈{y_cross:.3f}") +else: + print("\nПересечение с линией расходов в выбранной нормировке не найдено (или вне диапазона).") + +# ----------------------------- +# Plot: points + smooth + quadratic fit + cost + markers +# ----------------------------- +x_grid = np.linspace(stats_f["avg_imp_per_day"].min(), stats_f["avg_imp_per_day"].max(), 300) +y_hat = b0 + b1 * x_grid + b2 * x_grid**2 +cost_hat = c * x_grid + +plt.figure(figsize=(10, 8)) + +plt.plot( + stats_f["avg_imp_per_day"], stats_f["orders_mean"], + marker="o", linestyle="-", linewidth=1, alpha=0.3, + label="Среднее число заказов (по точкам)" +) + +plt.plot( + stats_f["avg_imp_per_day"], stats_f["orders_smooth"], + color="red", linewidth=2.2, + label="Сглаженный тренд (rolling mean)" +) + +plt.plot( + x_grid, y_hat, + color="blue", linewidth=2.5, + label="Квадратичная регрессия (WLS)" +) + +plt.plot( + x_grid, cost_hat, + color="black", linestyle="--", linewidth=2, + label="Линейные расходы на показы" +) + +if x_star is not None and (stats_f["avg_imp_per_day"].min() <= x_star <= stats_f["avg_imp_per_day"].max()): + plt.axvline(x_star, color="blue", linestyle=":", linewidth=2) + plt.scatter([x_star], [y_star], color="blue", zorder=5) + plt.text(x_star, y_star, f" x*={x_star:.2f}", va="bottom") + +if x_cross is not None: + y_cross = b0 + b1 * x_cross + b2 * x_cross**2 + plt.axvline(x_cross, color="black", linestyle=":", linewidth=2, alpha=0.8) + plt.scatter([x_cross], [y_cross], color="black", zorder=5) + plt.text(x_cross, y_cross, f" пересечение≈{x_cross:.2f}", va="top") + +plt.xlabel("Среднее число показов в день") +plt.ylabel("Среднее число заказов") +plt.title("Нелинейный эффект интенсивности коммуникаций: квадратичная регрессия") +plt.legend() +plt.grid(alpha=0.3) +plt.tight_layout() + +out_dir = project_root / "spam_hypot" +out_dir.mkdir(parents=True, exist_ok=True) +out_path = out_dir / "quad_regression_with_costs.png" +plt.savefig(out_path, dpi=150) +print(f"\nSaved: {out_path}")