diff --git a/deployment/run_ml_training.sh b/deployment/run_ml_training.sh old mode 100644 new mode 100755 diff --git a/deployment/setup_ml_training_timer.sh b/deployment/setup_ml_training_timer.sh old mode 100644 new mode 100755 diff --git a/python_ml/train_hybrid.py b/python_ml/train_hybrid.py index 575b785..693b2d1 100644 --- a/python_ml/train_hybrid.py +++ b/python_ml/train_hybrid.py @@ -60,6 +60,44 @@ def train_on_real_traffic(db_config: dict, days: int = 7) -> pd.DataFrame: return df +def save_training_history(db_config: dict, result: dict): + """ + Save training results to database training_history table + """ + import psycopg2 + + MODEL_VERSION = "2.0.0" # Hybrid ML Detector version + + print(f"\n[TRAIN] Saving training history to database...") + + try: + conn = psycopg2.connect(**db_config) + cursor = conn.cursor() + + cursor.execute(""" + INSERT INTO training_history + (model_version, records_processed, features_count, training_duration, status, notes) + VALUES (%s, %s, %s, %s, %s, %s) + """, ( + MODEL_VERSION, + result['records_processed'], + result['features_selected'], # Use selected features count + 0, # duration not implemented yet + 'success', + f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']} - {result['model_type']}" + )) + + conn.commit() + cursor.close() + conn.close() + + print(f"[TRAIN] āœ… Training history saved (version {MODEL_VERSION})") + + except Exception as e: + print(f"[TRAIN] ⚠ Failed to save training history: {e}") + # Don't fail the whole training if just logging fails + + def train_unsupervised(args): """ Train unsupervised model (no labels needed) @@ -71,6 +109,9 @@ def train_unsupervised(args): detector = MLHybridDetector(model_dir=args.model_dir) + # Database config for later use + db_config = None + # Load data if args.source == 'synthetic': print("\n[TRAIN] Using synthetic dataset...") @@ -109,6 +150,10 @@ def train_unsupervised(args): print(f" Model type: {result['model_type']}") print("="*70) + # Save training history to database (if using database source) + if db_config and args.source == 'database': + save_training_history(db_config, result) + print(f"\nāœ… Training completed! Models saved to: {args.model_dir}") print(f"\nNext steps:") print(f" 1. Test detection: python python_ml/test_detection.py")