gadem
This commit is contained in:
152
old data/quadreg.py
Normal file
152
old data/quadreg.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import statsmodels.api as sm
|
||||
from sklearn.metrics import r2_score, roc_auc_score
|
||||
|
||||
import best_model_and_plots as bmp
|
||||
|
||||
# Константы из scatter-скрипта
|
||||
X_COL = bmp.X_COL
|
||||
Y_COL = "orders_amt_total"
|
||||
X_MAX = bmp.DEFAULT_X_MAX
|
||||
Y_MIN = bmp.DEFAULT_Y_MIN
|
||||
Y_MAX = bmp.DEFAULT_Y_MAX
|
||||
|
||||
|
||||
def fit_quadratic(
|
||||
cleaned: bmp.pd.DataFrame,
|
||||
trend_data: Optional[Tuple[np.ndarray, np.ndarray]],
|
||||
*,
|
||||
x_col: str = X_COL,
|
||||
y_col: str = Y_COL,
|
||||
x_max: float = X_MAX,
|
||||
) -> Tuple[Optional[sm.regression.linear_model.RegressionResultsWrapper], dict]:
|
||||
"""Фитит y ~ 1 + x + x^2. Если есть тренд, использует его как целевое для r2_trend."""
|
||||
df = cleaned[[x_col, y_col]].dropna()
|
||||
if len(df) < 3:
|
||||
return None, {}
|
||||
|
||||
if trend_data is not None and trend_data[0] is not None:
|
||||
tx, ty = trend_data
|
||||
tx = np.asarray(tx)
|
||||
ty = np.asarray(ty)
|
||||
mask = (tx <= x_max) & ~np.isnan(ty)
|
||||
tx = tx[mask]
|
||||
ty = ty[mask]
|
||||
else:
|
||||
tx = ty = None
|
||||
|
||||
x = df[x_col].to_numpy()
|
||||
y = df[y_col].to_numpy()
|
||||
|
||||
X_design = sm.add_constant(np.column_stack([x, x**2]))
|
||||
model = sm.OLS(y, X_design).fit(cov_type="HC3")
|
||||
|
||||
auc = np.nan
|
||||
binary = (y > 0).astype(int)
|
||||
if len(np.unique(binary)) > 1:
|
||||
auc = roc_auc_score(binary, model.predict(X_design))
|
||||
|
||||
r2_trend = np.nan
|
||||
if tx is not None and len(tx) >= 3:
|
||||
X_trend = sm.add_constant(np.column_stack([tx, tx**2]))
|
||||
y_hat_trend = model.predict(X_trend)
|
||||
if np.nanvar(ty) > 0:
|
||||
r2_trend = r2_score(ty, y_hat_trend)
|
||||
|
||||
metrics = {
|
||||
"auc": auc,
|
||||
"r2_trend": r2_trend,
|
||||
}
|
||||
return model, metrics
|
||||
|
||||
|
||||
def plot_overall_quad(
|
||||
x_max: float = X_MAX,
|
||||
y_min: float = Y_MIN,
|
||||
y_max: float = Y_MAX,
|
||||
savgol_window: int = bmp.DEFAULT_SAVGOL_WINDOW,
|
||||
) -> None:
|
||||
out_dir = bmp.BASE_OUT_DIR / Y_COL
|
||||
|
||||
res = bmp.plot_clean_trend_scatter(
|
||||
bmp.load_client_level(bmp.DB_PATH),
|
||||
y_col=Y_COL,
|
||||
out_dir=out_dir,
|
||||
x_col=X_COL,
|
||||
x_max=x_max,
|
||||
scatter_color=bmp.DEFAULT_SCATTER_COLOR,
|
||||
point_size=bmp.DEFAULT_POINT_SIZE,
|
||||
alpha=bmp.DEFAULT_TREND_ALPHA,
|
||||
iqr_k=bmp.DEFAULT_IQR_K,
|
||||
q_low=bmp.DEFAULT_Q_LOW,
|
||||
q_high=bmp.DEFAULT_Q_HIGH,
|
||||
alpha_min=bmp.DEFAULT_ALPHA_MIN,
|
||||
alpha_max=bmp.DEFAULT_ALPHA_MAX,
|
||||
bins_x=bmp.DEFAULT_BINS_X,
|
||||
bins_y=bmp.DEFAULT_BINS_Y,
|
||||
y_min=y_min,
|
||||
y_max=y_max,
|
||||
trend_frac=bmp.DEFAULT_TREND_FRAC,
|
||||
trend_color=bmp.DEFAULT_TREND_COLOR,
|
||||
trend_linewidth=bmp.DEFAULT_TREND_LINEWIDTH,
|
||||
trend_method=bmp.DEFAULT_TREND_METHOD,
|
||||
savgol_window=savgol_window,
|
||||
return_components=True,
|
||||
)
|
||||
|
||||
if res is None:
|
||||
print("Нет данных для построения графика")
|
||||
return
|
||||
|
||||
fig, ax, cleaned, trend_data = res
|
||||
model, metrics = fit_quadratic(cleaned, trend_data, x_col=X_COL, y_col=Y_COL, x_max=x_max)
|
||||
|
||||
if model is None:
|
||||
print("Недостаточно точек для квадратичной регрессии")
|
||||
fig.savefig(out_dir / "scatter_trend.png", dpi=150)
|
||||
bmp.plt.close(fig)
|
||||
return
|
||||
|
||||
# Квадратичная линия поверх существующего тренда
|
||||
x_grid = np.linspace(0, x_max, 400)
|
||||
X_grid = sm.add_constant(np.column_stack([x_grid, x_grid**2]))
|
||||
y_grid = model.predict(X_grid)
|
||||
ax.plot(x_grid, y_grid, color="blue", linewidth=2.2, linestyle="--", label="Квадр. регрессия")
|
||||
ax.legend()
|
||||
|
||||
params = model.params
|
||||
pvals = model.pvalues
|
||||
summary_lines = [
|
||||
f"R2_trend={metrics['r2_trend']:.3f}",
|
||||
f"AUC={metrics['auc']:.3f}",
|
||||
f"b1={params[1]:.3f} (p={pvals[1]:.3g})",
|
||||
f"b2={params[2]:.3f} (p={pvals[2]:.3g})",
|
||||
f"n={len(cleaned)}",
|
||||
]
|
||||
ax.text(
|
||||
0.02,
|
||||
0.95,
|
||||
"\n".join(summary_lines),
|
||||
transform=ax.transAxes,
|
||||
ha="left",
|
||||
va="top",
|
||||
fontsize=9,
|
||||
bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.65, edgecolor="gray"),
|
||||
)
|
||||
|
||||
quad_path = out_dir / "scatter_trend_quad.png"
|
||||
fig.tight_layout()
|
||||
fig.savefig(quad_path, dpi=150)
|
||||
bmp.plt.close(fig)
|
||||
print(f"Saved {quad_path}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
plot_overall_quad()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user