Files
dano2025/spam_hypot/best_model_and_plots.py
2025-12-13 03:10:17 +03:00

145 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sqlite3
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Global plotting defaults: seaborn's white-grid look and a 10x5 inch
# default figure size (individual figures may still override it).
sns.set_theme(style="whitegrid")
plt.rcParams.update({"figure.figsize": (10, 5)})
# Make the shared EDA helpers in <project_root>/preanalysis importable.
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root / "preanalysis"))
import eda_utils as eda  # noqa: E402

# Load the raw communications table (one row per record; business_dt parsed
# as a datetime column).
db_path = project_root / "dataset" / "ds.sqlite"
# sqlite3.connect() silently creates an empty database when the file does not
# exist, which would only surface later as a confusing "no such table" error
# (and leave a stray empty file behind). Fail early with a clear message.
if not db_path.is_file():
    raise FileNotFoundError(f"SQLite database not found: {db_path}")
conn = sqlite3.connect(db_path)
try:
    df = pd.read_sql_query("select * from communications", conn, parse_dates=["business_dt"])
finally:
    # Always release the connection, even if the query fails. Using the
    # Connection as a context manager would only commit/roll back, not close.
    conn.close()
# --- per-row totals ----------------------------------------------------------
# Collapse each group of raw metric columns (column groups defined in
# eda_utils) into a single per-row sum, keyed by the new column name.
row_total_sources = {
    "active_imp_total": eda.ACTIVE_IMP_COLS,
    "passive_imp_total": eda.PASSIVE_IMP_COLS,
    "active_click_total": eda.ACTIVE_CLICK_COLS,
    "passive_click_total": eda.PASSIVE_CLICK_COLS,
    "orders_amt_total": eda.ORDER_COLS,
}
for total_col, source_cols in row_total_sources.items():
    df[total_col] = df[source_cols].sum(axis=1)

# Overall impressions / clicks per row = active + passive.
df["imp_total"] = df["active_imp_total"] + df["passive_imp_total"]
df["click_total"] = df["active_click_total"] + df["passive_click_total"]
# --- client-level aggregation ------------------------------------------------
def _first_mode(values: pd.Series) -> object:
    """Return the most frequent non-null value of *values*, or NaN if none.

    ``Series.mode`` drops nulls and returns its modes in sorted order, so ties
    resolve to the smallest value. For an all-null group ``mode()`` is empty
    and ``.iat[0]`` would raise IndexError, so NaN is returned instead.
    """
    modes = values.mode()
    return modes.iat[0] if not modes.empty else np.nan


# Number of distinct dates on which each client appears in the data.
contact_days = df.groupby("id")["business_dt"].nunique().rename("contact_days")
# One row per client: summed activity, median age and the most frequent
# categorical attributes; contact_days is attached on the "id" index level.
client = (
    df.groupby("id")
    .agg(
        imp_total=("imp_total", "sum"),
        click_total=("click_total", "sum"),
        orders_amt_total=("orders_amt_total", "sum"),
        age=("age", "median"),
        gender_cd=("gender_cd", _first_mode),
        device_platform_cd=("device_platform_cd", _first_mode),
    )
    .merge(contact_days, on="id", how="left")
    .reset_index()
)
# eda.safe_divide presumably guards against zero denominators -- see eda_utils.
# orders_amt_total / imp_total (intended as orders per impression; confirm that
# "amt" is a count rather than a money amount).
client["order_rate"] = eda.safe_divide(client["orders_amt_total"], client["imp_total"])
client["order_rate_pct"] = 100 * client["order_rate"]  # percent, for a human-readable scale
client["avg_imp_per_day"] = eda.safe_divide(client["imp_total"], client["contact_days"])
# For every distinct avg_imp_per_day value (exact values, no binning): the mean
# of orders_amt_total across clients sharing that value and how many such
# clients there are, ordered by avg_imp_per_day ascending.
stats_imp = (
    client.groupby("avg_imp_per_day")
    .agg(orders_mean=("orders_amt_total", "mean"), n_clients=("id", "count"))
    .reset_index()
    .sort_values("avg_imp_per_day")
)
# --- vertical outlier removal ------------------------------------------------
# A point is dropped when either
#   * its jump from the previous point is at least ABS_DY_MIN in absolute terms
#     AND at least K_MULT times the mean of the three preceding jumps, or
#   * its orders_mean exceeds the hard cap Y_MAX.
K_MULT = 2  # jump must be >= K_MULT x recent mean jump; raise to 3/10 for a looser filter
ABS_DY_MIN = 1  # jumps smaller than this are never considered outliers
X_MAX = 16  # keep only points with avg_imp_per_day <= X_MAX
Y_MAX = 5  # points with orders_mean above this are always dropped

stats_imp = stats_imp.sort_values("avg_imp_per_day").reset_index(drop=True)

# 1) Restrict to the x-range of interest.
stats_f = stats_imp[stats_imp["avg_imp_per_day"] <= X_MAX].copy().reset_index(drop=True)

# 2) Flag vertical outliers from the jump (dy) between neighbouring points.
before = len(stats_f)
y = stats_f["orders_mean"]
abs_dy = y.diff().abs()
# Mean of the three jumps *before* the current one (shift(1) excludes it).
prev3_mean = abs_dy.shift(1).rolling(window=3, min_periods=3).mean()
# A zero baseline becomes NaN so the ratio is NaN rather than inf; as a result
# a jump right after a perfectly flat stretch is never flagged by the ratio rule.
ratio = abs_dy / prev3_mean.replace(0, np.nan)
is_jump_outlier = (abs_dy >= ABS_DY_MIN) & (ratio >= K_MULT)
is_outlier = is_jump_outlier | (y > Y_MAX)
# The first points lack three preceding jumps; their NaN comparisons evaluate
# to False, and fillna guards against any missing values left in the mask.
is_outlier = is_outlier.fillna(False)
stats_f = stats_f.loc[~is_outlier].copy().reset_index(drop=True)
after = len(stats_f)
cleaned = before - after
print(f"{before} - {after}, cleaned: {cleaned}")
# --- smoothing: centred rolling mean over the surviving points ---------------
# The window is ~5% of the points but never narrower than 7, and is bumped to
# an odd size so the centred window is symmetric around each point.
w = max(7, int(len(stats_f) * 0.05))
w += 1 - w % 2
stats_f["orders_smooth"] = stats_f["orders_mean"].rolling(w, center=True, min_periods=1).mean()
# --- reference cost line (linear in impressions) -----------------------------
# Scaled so it is comparable to the order counts: at the largest
# avg_imp_per_day it equals the maximum of the smoothed order curve.
x_peak = stats_f["avg_imp_per_day"].max()
c = stats_f["orders_smooth"].max() / x_peak
stats_f["cost_line"] = stats_f["avg_imp_per_day"] * c
# --- plot: raw per-x means, smoothed trend and the scaled cost line ----------
fig, ax = plt.subplots(figsize=(10, 8))
x_vals = stats_f["avg_imp_per_day"]

# Raw means as faint markers, the smoothed trend on top, cost line dashed.
ax.plot(x_vals, stats_f["orders_mean"], marker="o", linewidth=1, alpha=0.3, label="Среднее число заказов")
ax.plot(x_vals, stats_f["orders_smooth"], color="red", linewidth=2.5, label="Сглаженный тренд заказов")
ax.plot(
    x_vals,
    stats_f["cost_line"],
    color="black",
    linestyle="--",
    linewidth=2,
    label="Линейные расходы на показы",
)

ax.set_xlabel("Среднее число показов в день")
ax.set_ylabel("Среднее число заказов")
ax.set_title("Зависимость заказов от интенсивности коммуникаций")
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()

out_path = project_root / "spam_hypot" / "orders_vs_avg_imp_with_costs.png"
fig.savefig(out_path, dpi=150)
print("Saved orders_vs_avg_imp_with_costs.png")