Latest metrics and executive notes are available in the report.
Step-by-step: clean/encode → split/scale → train a PyTorch logistic model → evaluate.
import numpy as np
import pandas as pd
import torch
from torch import nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import metrics

SEED = 42  # fixed random seed for reproducibility

class TorchLogReg(nn.Module):
    """Logistic regression: one linear layer that outputs raw logits."""
    def __init__(self, in_features):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, x):
        # No sigmoid here: BCEWithLogitsLoss applies it internally.
        return self.linear(x)
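The clean/encode and split/scale steps from the outline are not shown in these snippets. Below is a minimal sketch of one way to do them. It assumes `df` holds the raw table with a `class_asd` column whose labels are either 0/1 or 'YES'/'NO'. The helper name `prepare_data`, the median fill, the one-hot encoding and the 60/20/20 split are illustrative choices, not the notebook's confirmed pipeline.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def prepare_data(df, target="class_asd", seed=42):
    """Hypothetical helper: clean/encode, then split/scale into train/val/test."""
    df = df.drop_duplicates().dropna(subset=[target]).copy()
    y = df.pop(target)
    # Assumed label format: map string labels such as 'YES'/'NO' to 1/0.
    if not pd.api.types.is_numeric_dtype(y):
        y = y.str.strip().str.upper().map({"YES": 1, "NO": 0})
    y = y.astype(int).to_numpy()

    # Fill numeric gaps with the column median, then one-hot encode text columns.
    num_cols = df.select_dtypes(include="number").columns
    df[num_cols] = df[num_cols].fillna(df[num_cols].median())
    X = pd.get_dummies(df, drop_first=True, dtype=float)

    # 60/20/20 stratified split so each part keeps the overall class ratio.
    X_tr, X_tmp, y_tr, y_tmp = train_test_split(
        X, y, test_size=0.4, stratify=y, random_state=seed)
    X_val, X_te, y_val, y_te = train_test_split(
        X_tmp, y_tmp, test_size=0.5, stratify=y_tmp, random_state=seed)

    # Fit the scaler on training rows only, so val/test statistics do not leak.
    scaler = StandardScaler().fit(X_tr)
    X_tr, X_val, X_te = (scaler.transform(p).astype(np.float32) for p in (X_tr, X_val, X_te))
    return X_tr, X_val, X_te, y_tr, y_val, y_te, list(X.columns)
```

Fitting `StandardScaler` after the split, rather than before it, is the design choice that matters most here: it keeps the validation and test sets truly unseen.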
Training uses BCEWithLogitsLoss with Adam, early stopping on validation loss, and standardization via StandardScaler. See the notebook for the full pipeline and evaluation code.
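The training loop described above can be sketched as follows. It uses full-batch Adam on `BCEWithLogitsLoss`, records per-epoch losses in a `history` dict (the shape the training-curve snippet later in this section reads), and restores the weights with the lowest validation loss. The helper `train_logreg`, the learning rate, `patience` and the epoch cap are hypothetical. The inputs are assumed to be already standardized float arrays.

```python
import copy
import numpy as np
import torch
from torch import nn

class TorchLogReg(nn.Module):
    """Same single-layer logistic model as above (outputs logits)."""
    def __init__(self, in_features):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.linear(x)

def train_logreg(X_tr, y_tr, X_val, y_val, lr=1e-2, max_epochs=500, patience=20, seed=42):
    """Hypothetical helper: full-batch training with early stopping on val loss."""
    torch.manual_seed(seed)
    Xt = torch.as_tensor(X_tr, dtype=torch.float32)
    yt = torch.as_tensor(y_tr, dtype=torch.float32).reshape(-1, 1)
    Xv = torch.as_tensor(X_val, dtype=torch.float32)
    yv = torch.as_tensor(y_val, dtype=torch.float32).reshape(-1, 1)

    model = TorchLogReg(Xt.shape[1])
    loss_fn = nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    history = {"train_loss": [], "val_loss": []}
    best_val, best_state, bad_epochs = float("inf"), None, 0

    for _ in range(max_epochs):
        model.train()
        opt.zero_grad()
        loss = loss_fn(model(Xt), yt)  # training loss before this epoch's update
        loss.backward()
        opt.step()

        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(Xv), yv).item()
        history["train_loss"].append(loss.item())
        history["val_loss"].append(val_loss)

        # Early stopping: keep the best weights, stop after `patience` epochs without improvement.
        if val_loss < best_val - 1e-6:
            best_val, best_state, bad_epochs = val_loss, copy.deepcopy(model.state_dict()), 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break

    model.load_state_dict(best_state)
    return model, history
```

Full-batch updates are a simplification that suits a small tabular dataset. Mini-batches via `torch.utils.data.DataLoader` would be the usual choice for larger data.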
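# Class balance (Matplotlib/Seaborn)
import os, matplotlib.pyplot as plt, seaborn as sns
os.makedirs('figures', exist_ok=True)  # savefig fails if the folder is missing
vc = df['class_asd'].value_counts().sort_index()
plt.figure()  # fresh figure so later plots don't draw onto these axes
ax = sns.barplot(x=vc.index, y=vc.values, palette=['#4C78A8','#F58518'])
ax.set_xlabel('class_asd'); ax.set_ylabel('count'); ax.set_title('Class Balance')
plt.tight_layout(); plt.savefig('figures/eda_class_balance.png', dpi=180)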
# Correlation heatmap
num_corr = df.select_dtypes(include='number').corr()
plt.figure(figsize=(10, 8))  # new figure; otherwise the heatmap lands on the previous axes
sns.heatmap(num_corr, cmap='coolwarm', center=0, vmin=-1, vmax=1, cbar=True)
plt.title('Correlation Heatmap')
plt.tight_layout(); plt.savefig('figures/eda_corr_heatmap.png', dpi=180)
Counts of ASD vs non‑ASD labels in the dataset.
Pairwise correlations among numeric features (red = positive, blue = negative; deeper color = stronger).
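The two plots can be backed by numbers. This sketch returns the class proportions and each numeric feature's Pearson correlation with the label, ranked by magnitude. It assumes `class_asd` is already encoded as 0/1, since a string label would be excluded by `select_dtypes` and cannot be correlated. `eda_summary` is a hypothetical helper.

```python
import pandas as pd

def eda_summary(df, target="class_asd"):
    """Hypothetical helper: class proportions plus feature/label correlations."""
    proportions = df[target].value_counts(normalize=True).sort_index()
    # Numeric features only, excluding the (assumed 0/1) label itself.
    num = df.select_dtypes(include="number").drop(columns=[target], errors="ignore")
    corr = num.corrwith(df[target])  # Pearson correlation of each column with the label
    ranked = corr.reindex(corr.abs().sort_values(ascending=False).index)
    return proportions, ranked
```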
# Age histogram by class
plt.figure()  # new figure for each plot
# common_norm=False: each class's density is normalized on its own
sns.histplot(data=df, x='age', hue='class_asd', stat='density', common_norm=False, element='step')
plt.tight_layout(); plt.savefig('figures/hist_age_by_class.png', dpi=180)
# Result (A1–A10 total) histogram by class
plt.figure()
sns.histplot(data=df, x='result', hue='class_asd', stat='density', common_norm=False, element='step')
plt.tight_layout(); plt.savefig('figures/hist_result_by_class.png', dpi=180)
Distribution of ages split by label.
Distribution of screening-score totals (A1–A10) split by label.
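With `stat='density'` and `common_norm=False`, each class gets its own histogram whose bar areas sum to 1, so a small class is not dwarfed by a large one. The NumPy sketch below reproduces that normalization on bin edges shared by all classes (mirroring seaborn's `common_bins=True` default). `class_densities` is a hypothetical name, and NaN values are dropped first as an assumption.

```python
import numpy as np

def class_densities(values, labels, bins=20):
    """Hypothetical helper: per-class density histograms on shared bin edges."""
    values, labels = np.asarray(values, dtype=float), np.asarray(labels)
    keep = ~np.isnan(values)  # ignore missing values
    values, labels = values[keep], labels[keep]
    edges = np.histogram_bin_edges(values, bins=bins)  # same bins for every class
    densities = {}
    for cls in np.unique(labels):
        # density=True scales each class so its bar areas sum to 1
        densities[cls], _ = np.histogram(values[labels == cls], bins=edges, density=True)
    return edges, densities
```

Because each class integrates to 1 separately, bar heights compare the shape of the two distributions, not their sizes.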
# Used app before by class
plt.figure()  # new figure so the count plots don't overlap
sns.countplot(data=df, x='used_app_before', hue='class_asd')
plt.tight_layout(); plt.savefig('figures/count_used_app_before_by_class.png', dpi=180)
# Family autism history by class ('austim' is the column name as spelled in the data)
plt.figure()
sns.countplot(data=df, x='austim', hue='class_asd')
plt.tight_layout(); plt.savefig('figures/count_austim_by_class.png', dpi=180)
Self‑reported family history of autism split by label.
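Count plots show absolute numbers per category, but the share of each label within a category is often easier to compare. A sketch using `pandas.crosstab` with `normalize='index'` follows. `label_rate_by_category` is a hypothetical helper that works the same way for `used_app_before`, `austim` or `jundice`.

```python
import pandas as pd

def label_rate_by_category(df, col, target="class_asd"):
    """Hypothetical helper: share of each label within each category (rows sum to 1)."""
    return pd.crosstab(df[col], df[target], normalize="index")
```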
# Jaundice counts by class ('jundice' is the column name as spelled in the data)
plt.figure()
sns.countplot(data=df, x='jundice', hue='class_asd')
plt.tight_layout(); plt.savefig('figures/count_jundice_by_class.png', dpi=180)
# RandomForest feature importances (top 20)
from sklearn.ensemble import RandomForestClassifier
rf = RandomForestClassifier(
    n_estimators=400, max_depth=12, min_samples_leaf=5,
    max_features='sqrt', class_weight='balanced', random_state=42)
rf.fit(X_train, y_train)  # X_train must be a DataFrame so .columns gives feature names
imp = pd.Series(rf.feature_importances_, index=X_train.columns).sort_values(ascending=False).head(20)
plt.figure(figsize=(8, 6))
imp.plot(kind='barh')
plt.gca().invert_yaxis()  # largest importance at the top
plt.tight_layout(); plt.savefig('figures/featimp_RandomForest.png', dpi=180)
Reported jaundice split by label.
Impurity-based feature importances from the random forest (top 20); values can shift when the model is refit with different data or seeds.
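Impurity-based importances can be misleading for features with many unique values. As a cross-check (a different technique from the one plotted above), scikit-learn's `permutation_importance` measures how much a held-out score drops when a single column is shuffled. The sketch assumes a fitted classifier and a held-out `DataFrame`, so column names are available. `permutation_ranking`, the ROC AUC scorer and `n_repeats=10` are illustrative choices.

```python
import pandas as pd
from sklearn.inspection import permutation_importance

def permutation_ranking(model, X_eval, y_eval, top=20, seed=42):
    """Hypothetical helper: rank features by mean drop in held-out ROC AUC."""
    result = permutation_importance(
        model, X_eval, y_eval, scoring="roc_auc", n_repeats=10, random_state=seed)
    return (pd.Series(result.importances_mean, index=X_eval.columns)
              .sort_values(ascending=False)
              .head(top))
```

Scoring on held-out rows rather than the training set keeps the ranking tied to what the model actually generalizes.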
# Training curve (loss recorded each epoch)
plt.figure()
plt.plot(history['train_loss'], label='train'); plt.plot(history['val_loss'], label='val')
plt.xlabel('epoch'); plt.ylabel('BCE loss')
plt.legend(); plt.tight_layout(); plt.savefig('figures/torch_logreg_training_curve.png', dpi=180)
# Confusion matrix (sklearn)
from sklearn.metrics import ConfusionMatrixDisplay
ConfusionMatrixDisplay.from_predictions(y_test, y_pred)
plt.tight_layout(); plt.savefig('figures/cm_TorchLogReg.png', dpi=180)
Loss across epochs (train vs validation).
Predicted vs true labels on held‑out test data.
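The confusion matrix needs hard labels in `y_pred`, while the model emits logits. The sketch below covers that last step, assuming the test-set logits have already been moved to a NumPy array. The sigmoid turns logits into probabilities, a 0.5 threshold (an assumed default) turns those into labels, and `sklearn.metrics` computes the usual scores. `evaluate_logits` is a hypothetical helper.

```python
import numpy as np
from sklearn import metrics

def evaluate_logits(logits, y_true, threshold=0.5):
    """Hypothetical helper: logits -> probabilities -> labels -> metrics."""
    z = np.asarray(logits, dtype=float).ravel()
    y_true = np.asarray(y_true).astype(int).ravel()
    probs = 0.5 * (1.0 + np.tanh(0.5 * z))  # sigmoid(z), written to avoid exp overflow
    y_pred = (probs >= threshold).astype(int)
    scores = {
        "accuracy": metrics.accuracy_score(y_true, y_pred),
        "precision": metrics.precision_score(y_true, y_pred, zero_division=0),
        "recall": metrics.recall_score(y_true, y_pred, zero_division=0),
        "f1": metrics.f1_score(y_true, y_pred, zero_division=0),
        "roc_auc": metrics.roc_auc_score(y_true, probs),  # threshold-free ranking metric
        "confusion_matrix": metrics.confusion_matrix(y_true, y_pred, labels=[0, 1]),
    }
    return scores, y_pred
```

The returned `y_pred` is the kind of array `ConfusionMatrixDisplay.from_predictions(y_test, y_pred)` expects.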
What these visuals mean for you