"""Per-category analysis of impressions versus orders.

Loads the ``communications`` table from a SQLite database, aggregates it per
client and per category, and writes two kinds of figures under ``OUT_DIR``:
correlation heatmaps and scatter/trend plots with a quadratic OLS fit.
"""

import sqlite3
import sys
from contextlib import closing
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from sklearn.metrics import roc_auc_score

# Make helper functions from the neighbouring script importable.
script_dir = Path(__file__).resolve().parent
if str(script_dir) not in sys.path:
    sys.path.append(str(script_dir))

from best_model_and_plots import (  # noqa: E402
    CATEGORIES,
    DEFAULT_ALPHA,
    DEFAULT_ALPHA_MAX,
    DEFAULT_ALPHA_MIN,
    DEFAULT_BINS_X,
    DEFAULT_BINS_Y,
    DEFAULT_SCATTER_COLOR,
    DEFAULT_TREND_COLOR,
    DEFAULT_TREND_FRAC,
    DEFAULT_TREND_LINEWIDTH,
    DEFAULT_X_MAX,
    DEFAULT_Y_MAX,
    DEFAULT_Y_MIN,
    DEFAULT_SAVGOL_WINDOW,
    plot_clean_trend_scatter,
    safe_divide,
)

sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (8, 8)

project_root = Path(__file__).resolve().parent.parent
DB_PATH = project_root / "dataset" / "ds.sqlite"
OUT_DIR = project_root / "main_hypot" / "category_analysis"

# Per-category metric prefixes; actual columns are named f"{prefix}_{category}".
BASE_COLUMNS = ["active_imp", "passive_imp", "active_click", "passive_click", "orders_amt"]

# Synthetic categories: name -> list of source categories whose metrics are summed.
COMBINED = {
    "avia_hotel": ["avia", "hotel"],
}


def load_raw(db_path: Path) -> pd.DataFrame:
    """Read the whole ``communications`` table into a DataFrame.

    ``business_dt`` is parsed as a datetime column.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        The table contents as a DataFrame.
    """
    # ``closing`` guarantees the connection is released even if the query
    # raises; sqlite3's own context manager only manages transactions.
    with closing(sqlite3.connect(db_path)) as conn:
        return pd.read_sql_query(
            "select * from communications",
            conn,
            parse_dates=["business_dt"],
        )


def build_client_by_category(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate raw rows to one row per ``id`` with per-category totals.

    For every category in ``CATEGORIES`` the ``BASE_COLUMNS`` metrics are
    summed, and two derived columns are added: ``imp_total_<cat>`` (active +
    passive impressions) and ``avg_imp_per_day_<cat>`` (total impressions
    divided by the number of distinct ``business_dt`` values via
    ``safe_divide``).

    Args:
        df: Raw data containing ``id``, ``business_dt`` and the
            ``<metric>_<category>`` columns.

    Returns:
        A DataFrame with one row per ``id``.
    """
    agg_spec = {f"{col}_{cat}": "sum" for col in BASE_COLUMNS for cat in CATEGORIES}
    client = (
        df.groupby("id")
        .agg({**agg_spec, "business_dt": "nunique"})
        .reset_index()
        .rename(columns={"business_dt": "contact_days"})
    )

    for cat in CATEGORIES:
        imp_total_col = f"imp_total_{cat}"
        client[imp_total_col] = client[f"active_imp_{cat}"] + client[f"passive_imp_{cat}"]
        client[f"avg_imp_per_day_{cat}"] = safe_divide(
            client[imp_total_col], client["contact_days"]
        )
    return client


def add_combined_category(client: pd.DataFrame, name: str, cats: list[str]) -> pd.DataFrame:
    """Add summed columns for a combined category.

    For each metric in ``BASE_COLUMNS`` the columns of the source categories
    are summed row-wise into ``<metric>_<name>``; ``imp_total_<name>`` and
    ``avg_imp_per_day_<name>`` are then derived the same way as for regular
    categories.

    Args:
        client: Per-client frame; it is modified in place.
        name: Name of the combined category.
        cats: Source categories to sum.

    Returns:
        The same ``client`` object, for call chaining.
    """
    for base in BASE_COLUMNS:
        source_cols = [f"{base}_{c}" for c in cats]
        client[f"{base}_{name}"] = client[source_cols].sum(axis=1)

    imp_total_col = f"imp_total_{name}"
    client[imp_total_col] = client[f"active_imp_{name}"] + client[f"passive_imp_{name}"]
    client[f"avg_imp_per_day_{name}"] = safe_divide(
        client[imp_total_col], client["contact_days"]
    )
    return client


def plot_category_correlation(client: pd.DataFrame, cat: str, out_dir: Path) -> None:
    """Save a correlation heatmap of the ``BASE_COLUMNS`` metrics for one category.

    Args:
        client: Per-client frame containing the ``<metric>_<cat>`` columns.
        cat: Category name.
        out_dir: Directory for ``corr_<cat>.png``; created if missing.
    """
    cols = [f"{base}_{cat}" for base in BASE_COLUMNS]
    corr = client[cols].corr()

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        corr,
        annot=True,
        fmt=".2f",
        cmap="coolwarm",
        vmin=-1,
        vmax=1,
        linewidths=0.5,
        ax=ax,
    )
    ax.set_title(f"Корреляции показов/кликов/заказов: {cat}")
    # Lay out this figure explicitly rather than pyplot's implicit current one.
    fig.tight_layout()

    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / f"corr_{cat}.png"
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"Saved correlation heatmap for {cat}: {path}")


def _quad_design(x) -> np.ndarray:
    """Return the design matrix ``[1, x, x**2]`` for the quadratic model."""
    x = np.asarray(x, dtype=float)
    # has_constant="add" always prepends the intercept. The default ("skip")
    # drops it when a column is already constant, which would change the
    # number of columns and break params[2] / predict().
    return sm.add_constant(np.column_stack([x, x**2]), has_constant="add")


def _trend_within_range(trend_data, x_max: float):
    """Return trend points with ``x <= x_max`` and non-NaN ``y``.

    Returns ``(None, None)`` when no trend data is available.
    """
    if trend_data is None or trend_data[0] is None or trend_data[1] is None:
        return None, None
    tx = np.asarray(trend_data[0], dtype=float)
    ty = np.asarray(trend_data[1], dtype=float)
    keep = (tx <= x_max) & ~np.isnan(ty)
    return tx[keep], ty[keep]


def fit_quadratic(
    cleaned: pd.DataFrame,
    x_col: str,
    y_col: str,
    trend_data=None,
    x_max: float = DEFAULT_X_MAX,
):
    """Fit ``y = b0 + b1*x + b2*x**2`` by OLS with HC3 robust covariance.

    The model is fit on the trend curve (restricted to ``x <= x_max``) when it
    has at least three usable points, otherwise on the raw points.

    Args:
        cleaned: Frame with the ``x_col`` and ``y_col`` columns; rows with NaN
            in either column are dropped.
        x_col: Name of the predictor column.
        y_col: Name of the response column.
        trend_data: Optional ``(x, y)`` pair of trend-curve values.
        x_max: Upper bound on trend ``x`` values used for fitting and scoring.

    Returns:
        ``(model, metrics)``, or ``(None, None)`` if fewer than three rows
        remain after dropping NaNs. ``metrics`` contains:

        * ``params`` / ``pvalues``: taken from the fitted model;
        * ``r2_points``: ``model.rsquared``, i.e. R² on whichever data the
          model was actually fit to;
        * ``r2_trend``: R² of the model's predictions against the trend
          curve, NaN if there is no usable trend;
        * ``auc_on_has_orders``: ROC AUC of the predictions on all raw points
          against the indicator ``y > 0``, NaN if only one class is present;
        * ``effective_b2``: the quadratic coefficient ``params[2]``.
    """
    cleaned = cleaned[[x_col, y_col]].dropna()
    if len(cleaned) < 3:
        return None, None

    x_all = cleaned[x_col].to_numpy()
    y_true_all = cleaned[y_col].to_numpy()

    tx, ty = _trend_within_range(trend_data, x_max)
    if tx is not None and len(tx) >= 3:
        x, y = tx, ty
    else:
        x, y = x_all, y_true_all

    model = sm.OLS(y, _quad_design(x)).fit(cov_type="HC3")

    # AUC is only defined when both classes (with / without orders) occur.
    auc = float("nan")
    has_order = (y_true_all > 0).astype(int)
    if len(np.unique(has_order)) > 1:
        auc = roc_auc_score(has_order, model.predict(_quad_design(x_all)))

    r2_trend = float("nan")
    if tx is not None and len(tx) > 1 and np.nanvar(ty) > 0:
        residuals = ty - model.predict(_quad_design(tx))
        ss_res = np.nansum(residuals**2)
        ss_tot = np.nansum((ty - np.nanmean(ty)) ** 2)
        r2_trend = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")

    metrics = {
        "params": model.params,
        "pvalues": model.pvalues,
        "r2_points": model.rsquared,
        "r2_trend": r2_trend,
        "auc_on_has_orders": auc,
        "effective_b2": model.params[2],
    }
    return model, metrics


def plot_quad_for_category(
    client: pd.DataFrame,
    cat: str,
    *,
    base_out_dir: Path = OUT_DIR,
    x_max_overrides: dict | None = None,
    y_max_overrides: dict | None = None,
    savgol_overrides: dict | None = None,
    q_low_overrides: dict | None = None,
    q_high_overrides: dict | None = None,
    iqr_overrides: dict | None = None,
) -> None:
    """Draw the scatter/trend plot for one category and overlay a quadratic fit.

    The base plot comes from ``plot_clean_trend_scatter``; the quadratic curve
    and a metrics box are added on top and the figure is saved as
    ``scatter_trend_quad.png`` under ``base_out_dir / "orders_amt_<cat>"``.
    If the regression cannot be fit, the base plot is saved as
    ``scatter_trend.png`` instead.

    Args:
        client: Per-client frame with ``orders_amt_<cat>`` and
            ``avg_imp_per_day_<cat>`` columns.
        cat: Category name.
        base_out_dir: Root output directory.
        x_max_overrides: Per-category override of ``DEFAULT_X_MAX``.
        y_max_overrides: Per-category override of ``DEFAULT_Y_MAX``.
        savgol_overrides: Per-category override of ``DEFAULT_SAVGOL_WINDOW``.
        q_low_overrides: Per-category override of ``q_low`` (default 0.05).
        q_high_overrides: Per-category override of ``q_high`` (default 0.95).
        iqr_overrides: Per-category override of ``iqr_k`` (default 1.5).
    """
    y_col = f"orders_amt_{cat}"
    x_col = f"avg_imp_per_day_{cat}"
    out_dir = base_out_dir / y_col

    x_max = (x_max_overrides or {}).get(cat, DEFAULT_X_MAX)
    y_max = (y_max_overrides or {}).get(cat, DEFAULT_Y_MAX)
    savgol_window = (savgol_overrides or {}).get(cat, DEFAULT_SAVGOL_WINDOW)
    q_low = (q_low_overrides or {}).get(cat, 0.05)
    q_high = (q_high_overrides or {}).get(cat, 0.95)
    iqr_k = (iqr_overrides or {}).get(cat, 1.5)

    res = plot_clean_trend_scatter(
        client,
        y_col=y_col,
        out_dir=out_dir,
        x_col=x_col,
        x_max=x_max,
        scatter_color=DEFAULT_SCATTER_COLOR,
        point_size=20,
        alpha=DEFAULT_ALPHA,
        iqr_k=iqr_k,
        q_low=q_low,
        q_high=q_high,
        alpha_min=DEFAULT_ALPHA_MIN,
        alpha_max=DEFAULT_ALPHA_MAX,
        bins_x=DEFAULT_BINS_X,
        bins_y=DEFAULT_BINS_Y,
        y_min=DEFAULT_Y_MIN,
        y_max=y_max,
        trend_frac=DEFAULT_TREND_FRAC,
        trend_color=DEFAULT_TREND_COLOR,
        trend_linewidth=DEFAULT_TREND_LINEWIDTH,
        savgol_window=savgol_window,
        return_components=True,
    )
    if res is None:
        print(f"[{cat}] Нет данных для построения тренда/регрессии")
        return

    fig, ax, cleaned, trend_data = res
    tx, ty = trend_data if trend_data is not None else (None, None)

    # Make sure the target directory exists before saving; harmless if the
    # plotting helper has already created it.
    out_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): for the combined category the b2 value shown on the plot
    # is forced negative (-|b2|), while the p-value next to it and the console
    # log below still refer to the real coefficient. Confirm this is intended.
    force_neg_b2 = cat == "avia_hotel"

    model, metrics = fit_quadratic(
        cleaned,
        x_col,
        y_col,
        trend_data=(tx, ty),
        x_max=x_max,
    )
    if model is None:
        print(f"[{cat}] Недостаточно точек для квадр. регрессии")
        fig.savefig(out_dir / "scatter_trend.png", dpi=150)
        plt.close(fig)
        return

    x_grid = np.linspace(cleaned[x_col].min(), min(cleaned[x_col].max(), x_max), 400)
    # Same design-matrix construction as in fit_quadratic, so the number of
    # columns always matches the fitted model.
    y_hat = model.predict(_quad_design(x_grid))
    ax.plot(x_grid, y_hat, color="#1f77b4", linewidth=2.2, label="Квадр. регрессия")
    ax.legend()

    params = metrics["params"]
    pvals = metrics["pvalues"]
    b2_effective = metrics.get("effective_b2", params[2])
    if force_neg_b2:
        b2_effective = -abs(b2_effective)

    summary_lines = [
        f"R2_trend={metrics['r2_trend']:.3f}",
        f"AUC={metrics['auc_on_has_orders']:.3f}",
        f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
        f"b2={b2_effective:.3f} (p={pvals[2]:.3g})",
        f"n={len(cleaned)}",
    ]
    ax.text(
        0.02,
        0.95,
        "\n".join(summary_lines),
        transform=ax.transAxes,
        ha="left",
        va="top",
        fontsize=9,
        bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
    )

    quad_path = out_dir / "scatter_trend_quad.png"
    fig.tight_layout()
    fig.savefig(quad_path, dpi=150)
    plt.close(fig)
    print(f"[{cat}] Saved quad reg plot: {quad_path}")

    print(
        f"[{cat}] b0={params[0]:.4f}, b1={params[1]:.4f} (p={pvals[1]:.4g}), "
        f"b2={params[2]:.4f} (p={pvals[2]:.4g}), "
        f"R2_trend={metrics['r2_trend']:.4f}, AUC(has_order)={metrics['auc_on_has_orders']:.4f}"
    )


def main() -> None:
    """Run the full pipeline: load, aggregate, and plot every category."""
    raw = load_raw(DB_PATH)
    client = build_client_by_category(raw)
    for combo_name, combo_cats in COMBINED.items():
        client = add_combined_category(client, combo_name, combo_cats)

    # Example overrides: x_max, y_max, savgol_window
    x_max_overrides = {
        "ent": 4,
        "transport": 4,
        "avia": 4,
        "shopping": 6,
        "avia_hotel": 5,
        "super": 4,
    }
    y_max_overrides = {
        "ent": 2.5,
        "transport": 6,
        "avia": 1.5,
        "shopping": 2.5,
        "avia_hotel": 2.0,
        "super": 5,
    }
    savgol_overrides = {
        "ent": 301,
        "transport": 401,
        "avia": 301,
        "shopping": 201,
        "avia_hotel": 301,
    }
    q_low_overrides = {
        "avia_hotel": 0.05,
    }
    q_high_overrides = {
        "avia_hotel": 0.9,
    }
    iqr_overrides = {
        "avia_hotel": 1.2,
    }

    corr_dir = OUT_DIR / "correlations"
    # Unpacking works whether CATEGORIES is a list or a tuple.
    cats_all = [*CATEGORIES, *COMBINED]

    for cat in cats_all:
        plot_category_correlation(client, cat, corr_dir)

    for cat in cats_all:
        plot_quad_for_category(
            client,
            cat,
            x_max_overrides=x_max_overrides,
            y_max_overrides=y_max_overrides,
            savgol_overrides=savgol_overrides,
            q_low_overrides=q_low_overrides,
            q_high_overrides=q_high_overrides,
            iqr_overrides=iqr_overrides,
        )


if __name__ == "__main__":
    main()