Files
dano2025/main_hypot/quadreg.py
2025-12-16 01:51:05 +03:00

156 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Обёртка для построения общей квадратичной регрессии заказов от среднего числа показов."""
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import r2_score, roc_auc_score
import best_model_and_plots as bmp
# Defaults reused from the scatter/trend module (best_model_and_plots), so this
# script works with the same x column and the same X/Y limit defaults.
X_COL = bmp.X_COL
X_MAX, Y_MIN, Y_MAX = bmp.DEFAULT_X_MAX, bmp.DEFAULT_Y_MIN, bmp.DEFAULT_Y_MAX
# Response column for this regression; defined here rather than taken from bmp.
Y_COL = "orders_amt_total"
def fit_quadratic(
    cleaned: bmp.pd.DataFrame,
    trend_data: Optional[Tuple[np.ndarray, np.ndarray]],
    *,
    x_col: str = X_COL,
    y_col: str = Y_COL,
    x_max: float = X_MAX,
) -> Tuple[Optional[sm.regression.linear_model.RegressionResultsWrapper], dict]:
    """Fit ``y ~ 1 + x + x**2`` by OLS with HC3 robust standard errors.

    Args:
        cleaned: Frame holding at least ``x_col`` and ``y_col``. Rows with a
            missing value in either column are dropped before fitting.
        trend_data: Optional ``(tx, ty)`` pair of arrays describing a trend
            curve. Points with ``tx > x_max`` (or NaN ``tx``) and points with
            NaN ``ty`` are ignored. When usable, the fitted parabola is scored
            against it via ``r2_trend``.
        x_col: Regressor column name.
        y_col: Response column name.
        x_max: Upper bound on ``tx`` when scoring against the trend.

    Returns:
        ``(model, metrics)``.

        ``model`` is ``None`` and ``metrics`` is ``{}`` when there are fewer
        than 3 complete rows or fewer than 3 distinct ``x`` values. In those
        cases the three coefficients are not identifiable.

        Otherwise ``model`` is the statsmodels results object with parameters
        ordered ``[const, x, x^2]``, and ``metrics`` contains:

        * ``"auc"``: ROC AUC of the fitted values used as a score for
          ``y > 0`` (NaN if every ``y`` falls on the same side of 0).
        * ``"r2_trend"``: R^2 of the parabola evaluated at ``tx`` against
          ``ty`` (NaN if fewer than 3 usable trend points remain or ``ty``
          is constant).
    """
    df = cleaned[[x_col, y_col]].dropna()
    if len(df) < 3:
        return None, {}

    # Cast explicitly: pandas nullable dtypes (e.g. Int64) would otherwise come
    # back as object arrays, and integer x**2 could overflow.
    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)

    # Three coefficients need at least three distinct x values. With fewer,
    # the design is rank-deficient, so b1/b2 and their p-values are arbitrary.
    # For a constant x, sm.add_constant's default would even drop the
    # intercept, leaving fewer than three parameters.
    if np.unique(x).size < 3:
        return None, {}

    # has_constant="add" guarantees the design is exactly [1, x, x^2].
    X_design = sm.add_constant(np.column_stack([x, x**2]), has_constant="add")
    model = sm.OLS(y, X_design).fit(cov_type="HC3")

    # How well the fitted curve ranks y > 0 above y <= 0; undefined when only
    # one of the two classes is present.
    auc = np.nan
    is_positive = (y > 0).astype(int)
    if np.unique(is_positive).size > 1:
        auc = roc_auc_score(is_positive, model.fittedvalues)

    r2_trend = np.nan
    if (
        trend_data is not None
        and trend_data[0] is not None
        and trend_data[1] is not None
    ):
        tx, ty = (np.asarray(arr, dtype=float) for arr in trend_data)
        # NaN tx compares False against x_max, so it is dropped here as well.
        keep = (tx <= x_max) & ~np.isnan(ty)
        tx, ty = tx[keep], ty[keep]
        if tx.size >= 3 and np.nanvar(ty) > 0:
            # Same explicit intercept so predict() always sees 3 columns, even
            # if the trend x (or x^2) happens to be constant.
            X_trend = sm.add_constant(np.column_stack([tx, tx**2]), has_constant="add")
            r2_trend = r2_score(ty, model.predict(X_trend))

    metrics = {
        "auc": auc,
        "r2_trend": r2_trend,
    }
    return model, metrics
def plot_overall_quad(
    x_max: float = X_MAX,
    y_min: float = Y_MIN,
    y_max: float = Y_MAX,
    savgol_window: int = bmp.DEFAULT_SAVGOL_WINDOW,
) -> None:
    """Draw the cleaned ``Y_COL``-vs-``X_COL`` scatter with a quadratic fit overlaid.

    The base figure (scatter plus trend) comes from
    ``bmp.plot_clean_trend_scatter``. This function then:

    * fits ``fit_quadratic`` on the cleaned points;
    * overlays the fitted parabola on ``[0, x_max]``;
    * adds a summary box with R2_trend, AUC, b1/b2 and their p-values, and the
      number of observations actually used in the fit;
    * saves the figure as ``scatter_trend_quad.png`` in
      ``bmp.BASE_OUT_DIR / Y_COL``.

    If the quadratic cannot be fitted, the unannotated figure is saved as
    ``scatter_trend.png`` instead. If the helper returns no data, nothing is
    saved.

    Args:
        x_max: Forwarded to the scatter helper. Also bounds the trend
            comparison and the regression-curve grid.
        y_min: Forwarded to the scatter helper.
        y_max: Forwarded to the scatter helper.
        savgol_window: Smoothing window forwarded to the scatter helper
            (presumably Savitzky-Golay, judging by the name; verify in bmp).
    """
    # Build the base scatter/trend plot from best_model_and_plots; the quadratic
    # curve is added on top of it below.
    out_dir = bmp.BASE_OUT_DIR / Y_COL
    res = bmp.plot_clean_trend_scatter(
        bmp.load_client_level(bmp.DB_PATH),
        y_col=Y_COL,
        out_dir=out_dir,
        x_col=X_COL,
        x_max=x_max,
        scatter_color=bmp.DEFAULT_SCATTER_COLOR,
        point_size=bmp.DEFAULT_POINT_SIZE,
        alpha=bmp.DEFAULT_TREND_ALPHA,
        iqr_k=bmp.DEFAULT_IQR_K,
        q_low=bmp.DEFAULT_Q_LOW,
        q_high=bmp.DEFAULT_Q_HIGH,
        alpha_min=bmp.DEFAULT_ALPHA_MIN,
        alpha_max=bmp.DEFAULT_ALPHA_MAX,
        bins_x=bmp.DEFAULT_BINS_X,
        bins_y=bmp.DEFAULT_BINS_Y,
        y_min=y_min,
        y_max=y_max,
        trend_frac=bmp.DEFAULT_TREND_FRAC,
        trend_color=bmp.DEFAULT_TREND_COLOR,
        trend_linewidth=bmp.DEFAULT_TREND_LINEWIDTH,
        trend_method=bmp.DEFAULT_TREND_METHOD,
        savgol_window=savgol_window,
        return_components=True,
    )
    if res is None:
        print("Нет данных для построения графика")
        return
    fig, ax, cleaned, trend_data = res

    # Close the figure on every path, including when fitting/plotting/saving
    # raises, so figures are not leaked.
    try:
        model, metrics = fit_quadratic(cleaned, trend_data, x_col=X_COL, y_col=Y_COL, x_max=x_max)
        if model is None:
            print("Недостаточно точек для квадратичной регрессии")
            fig.savefig(out_dir / "scatter_trend.png", dpi=150)
            return

        # Quadratic curve drawn over the existing trend line.
        x_grid = np.linspace(0, x_max, 400)
        X_grid = sm.add_constant(np.column_stack([x_grid, x_grid**2]), has_constant="add")
        y_grid = model.predict(X_grid)
        ax.plot(x_grid, y_grid, color="blue", linewidth=2.2, linestyle="--", label="Квадр. регрессия")
        ax.legend()

        # params / pvalues are ordered [const, x, x^2].
        params = model.params
        pvals = model.pvalues
        summary_lines = [
            f"R2_trend={metrics['r2_trend']:.3f}",
            f"AUC={metrics['auc']:.3f}",
            f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
            f"b2={params[2]:.3f} (p={pvals[2]:.3g})",
            # Rows actually used by the fit: fit_quadratic drops rows with NaN
            # in x/y, so len(cleaned) could overstate it.
            f"n={int(model.nobs)}",
        ]
        ax.text(
            0.02,
            0.95,
            "\n".join(summary_lines),
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
        )
        quad_path = out_dir / "scatter_trend_quad.png"
        fig.tight_layout()
        fig.savefig(quad_path, dpi=150)
    finally:
        bmp.plt.close(fig)
    print(f"Saved {quad_path}")
def main() -> None:
    """Script entry point: build and save the overall quadratic-regression plot
    using the module's default limits."""
    plot_overall_quad()


if __name__ == "__main__":
    main()