Files
dano2025/old data/quadreg.py
2025-12-16 01:51:05 +03:00

153 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import r2_score, roc_auc_score
import best_model_and_plots as bmp
# Settings shared with the scatter/trend script (best_model_and_plots), reused
# here as the defaults for the quadratic-regression overlay.
X_COL = bmp.X_COL
Y_COL = "orders_amt_total"  # response column for the quadratic fit
X_MAX, Y_MIN, Y_MAX = bmp.DEFAULT_X_MAX, bmp.DEFAULT_Y_MIN, bmp.DEFAULT_Y_MAX
def fit_quadratic(
    cleaned: bmp.pd.DataFrame,
    trend_data: Optional[Tuple[np.ndarray, np.ndarray]],
    *,
    x_col: str = X_COL,
    y_col: str = Y_COL,
    x_max: float = X_MAX,
) -> Tuple[Optional[sm.regression.linear_model.RegressionResultsWrapper], dict]:
    """Fit ``y ~ 1 + x + x^2`` by OLS with HC3 robust standard errors.

    Parameters
    ----------
    cleaned:
        Frame holding ``x_col`` and ``y_col``; rows with a NaN in either
        column are dropped before fitting. NOTE(review): the fit itself uses
        every complete row and is not limited to ``x <= x_max`` -- presumably
        ``cleaned`` is already restricted to the plotted range; confirm.
    trend_data:
        Optional ``(tx, ty)`` pair describing a trend curve. When both arrays
        are present, the fitted parabola is evaluated at ``tx`` and scored
        against ``ty`` to produce ``r2_trend``.
    x_col, y_col:
        Names of the predictor and response columns.
    x_max:
        Only trend points with ``tx <= x_max`` enter the ``r2_trend`` score.

    Returns
    -------
    tuple
        ``(model, metrics)``. ``model`` is the statsmodels results object, or
        ``None`` (with ``metrics == {}``) when fewer than 3 complete rows are
        available. ``metrics`` contains:

        * ``"auc"`` -- ROC AUC of the fitted values used as a score for
          ``y > 0``; NaN when all or none of the rows have ``y > 0``.
        * ``"r2_trend"`` -- R^2 of the parabola against the trend curve; NaN
          when there are fewer than 3 usable trend points or the trend is flat.
    """
    df = cleaned[[x_col, y_col]].dropna()
    if len(df) < 3:
        return None, {}

    def _design(values: np.ndarray) -> np.ndarray:
        # has_constant="add" always prepends the intercept. The statsmodels
        # default ("skip") drops it when ``values`` is constant and non-zero,
        # which would leave a 2-column design that no longer matches the
        # 3-parameter model (and break params[2] / 3-column predictions).
        return sm.add_constant(np.column_stack([values, values**2]), has_constant="add")

    # Cast to float so that x**2 cannot silently overflow for integer columns.
    x = df[x_col].to_numpy(dtype=float)
    y = df[y_col].to_numpy(dtype=float)
    X_design = _design(x)
    model = sm.OLS(y, X_design).fit(cov_type="HC3")

    # AUC only makes sense when both classes of the "y > 0" indicator occur.
    auc = np.nan
    binary = (y > 0).astype(int)
    if len(np.unique(binary)) > 1:
        auc = roc_auc_score(binary, model.predict(X_design))

    r2_trend = np.nan
    if trend_data is not None and trend_data[0] is not None and trend_data[1] is not None:
        tx = np.asarray(trend_data[0], dtype=float)
        ty = np.asarray(trend_data[1], dtype=float)
        # Keep only points inside the plotted x-range where both coordinates
        # are finite; r2_score rejects NaN/inf values.
        mask = np.isfinite(tx) & np.isfinite(ty) & (tx <= x_max)
        tx, ty = tx[mask], ty[mask]
        # R^2 is undefined for a constant target, so require some variance.
        if len(tx) >= 3 and np.var(ty) > 0:
            r2_trend = r2_score(ty, model.predict(_design(tx)))

    metrics = {
        "auc": auc,
        "r2_trend": r2_trend,
    }
    return model, metrics
def plot_overall_quad(
    x_max: float = X_MAX,
    y_min: float = Y_MIN,
    y_max: float = Y_MAX,
    savgol_window: int = bmp.DEFAULT_SAVGOL_WINDOW,
) -> None:
    """Draw the cleaned scatter + trend for ``Y_COL`` and overlay a quadratic fit.

    Loads client-level data from ``bmp.DB_PATH`` and builds the scatter/trend
    figure with ``bmp.plot_clean_trend_scatter``. That call uses the module
    defaults plus the given axis limits and ``savgol_window``. The function
    then fits ``y ~ 1 + x + x^2`` via ``fit_quadratic`` and draws the fitted
    curve on ``[0, x_max]``. The plot is annotated with R^2 against the trend,
    AUC, the linear/quadratic coefficients with their p-values, and the number
    of rows actually used in the fit. The result is saved to
    ``bmp.BASE_OUT_DIR / Y_COL / "scatter_trend_quad.png"``.

    If no figure can be produced, a message is printed and nothing is saved.
    If there are too few points for the regression, the plain figure is saved
    as ``scatter_trend.png`` instead.
    """
    out_dir = bmp.BASE_OUT_DIR / Y_COL
    res = bmp.plot_clean_trend_scatter(
        bmp.load_client_level(bmp.DB_PATH),
        y_col=Y_COL,
        out_dir=out_dir,
        x_col=X_COL,
        x_max=x_max,
        scatter_color=bmp.DEFAULT_SCATTER_COLOR,
        point_size=bmp.DEFAULT_POINT_SIZE,
        alpha=bmp.DEFAULT_TREND_ALPHA,
        iqr_k=bmp.DEFAULT_IQR_K,
        q_low=bmp.DEFAULT_Q_LOW,
        q_high=bmp.DEFAULT_Q_HIGH,
        alpha_min=bmp.DEFAULT_ALPHA_MIN,
        alpha_max=bmp.DEFAULT_ALPHA_MAX,
        bins_x=bmp.DEFAULT_BINS_X,
        bins_y=bmp.DEFAULT_BINS_Y,
        y_min=y_min,
        y_max=y_max,
        trend_frac=bmp.DEFAULT_TREND_FRAC,
        trend_color=bmp.DEFAULT_TREND_COLOR,
        trend_linewidth=bmp.DEFAULT_TREND_LINEWIDTH,
        trend_method=bmp.DEFAULT_TREND_METHOD,
        savgol_window=savgol_window,
        return_components=True,
    )
    if res is None:
        print("Нет данных для построения графика")
        return
    fig, ax, cleaned, trend_data = res
    # Idempotent: guarantees the target directory exists before any savefig.
    out_dir.mkdir(parents=True, exist_ok=True)
    try:
        model, metrics = fit_quadratic(cleaned, trend_data, x_col=X_COL, y_col=Y_COL, x_max=x_max)
        if model is None:
            print("Недостаточно точек для квадратичной регрессии")
            fig.savefig(out_dir / "scatter_trend.png", dpi=150)
            return

        # Quadratic curve drawn on top of the existing trend line.
        x_grid = np.linspace(0, x_max, 400)
        # has_constant="add" keeps the grid design at 3 columns (1, x, x^2),
        # matching the fitted parameters regardless of the grid's values.
        X_grid = sm.add_constant(np.column_stack([x_grid, x_grid**2]), has_constant="add")
        y_grid = model.predict(X_grid)
        ax.plot(x_grid, y_grid, color="blue", linewidth=2.2, linestyle="--", label="Квадр. регрессия")
        ax.legend()

        # Positional access is safe whether statsmodels returns arrays or Series.
        params = np.asarray(model.params)
        pvals = np.asarray(model.pvalues)
        summary_lines = [
            f"R2_trend={metrics['r2_trend']:.3f}",
            f"AUC={metrics['auc']:.3f}",
            f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
            f"b2={params[2]:.3f} (p={pvals[2]:.3g})",
            # Report the rows the model was actually fit on (NaN rows in the
            # x/y columns are dropped by fit_quadratic), not len(cleaned).
            f"n={int(model.nobs)}",
        ]
        ax.text(
            0.02,
            0.95,
            "\n".join(summary_lines),
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
        )
        quad_path = out_dir / "scatter_trend_quad.png"
        fig.tight_layout()
        fig.savefig(quad_path, dpi=150)
    finally:
        # Always release the figure, even if fitting or saving raised.
        bmp.plt.close(fig)
    print(f"Saved {quad_path}")
def main() -> None:
    """Script entry point: build the quadratic-regression plot with default settings."""
    plot_overall_quad()


if __name__ == "__main__":
    main()