Add exception handling to the model training process to log failures and improve robustness. Replit-Commit-Author: Agent Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528 Replit-Commit-Checkpoint-Type: intermediate_checkpoint Replit-Commit-Event-Id: 9c7ad6b8-3e9d-41fe-83f7-6b2a48f8ff44 Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/2lUhxO2
379 lines
13 KiB
Python
379 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
IDS Hybrid ML Training Script
|
|
Trains an Extended Isolation Forest with feature selection on CICIDS2017, synthetic data, or real traffic from PostgreSQL
|
|
Validates with production-grade metrics
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
import numpy as np
|
|
from datetime import datetime
|
|
|
|
# Import our modules
|
|
from ml_hybrid_detector import MLHybridDetector
|
|
from dataset_loader import CICIDS2017Loader
|
|
from validation_metrics import ValidationMetrics
|
|
|
|
|
|
def train_on_real_traffic(db_config: dict, days: int = 7) -> pd.DataFrame:
    """
    Load real traffic from the PostgreSQL ``network_logs`` table.

    Fetches at most 1,000,000 rows from the last ``days`` days, newest first.

    Args:
        db_config: Keyword arguments forwarded to ``psycopg2.connect``
            (host, port, database, user, password).
        days: Size of the look-back window in days; must be positive.

    Returns:
        DataFrame with the columns timestamp, source_ip, dest_ip, dest_port,
        protocol, packets, bytes and action.

    Raises:
        ValueError: if ``days`` is not positive or the query returns no rows.
    """
    # Validate before importing the driver so bad input fails fast with a
    # clear message (a negative window would silently select nothing).
    if days <= 0:
        raise ValueError(f"days must be a positive number, got {days!r}")

    import psycopg2
    from psycopg2.extras import RealDictCursor

    print(f"[TRAIN] Loading last {days} days of real traffic from database...")

    # The day count is bound as a real query parameter and multiplied by a
    # one-day interval. The previous form spliced it into the text of an
    # INTERVAL literal ('%s days'), which only worked for bare integers.
    query = """
        SELECT
            timestamp,
            source_ip,
            dest_ip,
            dest_port,
            protocol,
            packets,
            bytes,
            action
        FROM network_logs
        WHERE timestamp > NOW() - (%s * INTERVAL '1 day')
        ORDER BY timestamp DESC
        LIMIT 1000000
    """

    conn = psycopg2.connect(**db_config)
    try:
        # The cursor is closed by the with-block; the connection is closed in
        # ``finally`` so a failing query no longer leaks it.
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query, (days,))
            rows = cursor.fetchall()
    finally:
        conn.close()

    if not rows:
        raise ValueError("No data found in database")

    df = pd.DataFrame(rows)
    print(f"[TRAIN] Loaded {len(df)} logs from database")

    return df
|
|
|
|
|
|
def train_unsupervised(args):
    """
    Train the unsupervised detector (no labels needed).

    Uses synthetic data or real traffic from PostgreSQL, depending on
    ``args.source``, trains an ``MLHybridDetector`` and prints a summary.

    Args:
        args: Parsed CLI namespace. Reads model_dir, source and n_samples.
            For the database source it also reads db_host, db_port,
            db_name, db_user, db_password and days.

    Returns:
        The trained ``MLHybridDetector``.

    Raises:
        ValueError: if ``args.source`` is not 'synthetic' or 'database', or
            if no training rows are available.
    """
    print("\n" + "=" * 70)
    print(" IDS HYBRID ML TRAINING - UNSUPERVISED MODE")
    print("=" * 70)

    detector = MLHybridDetector(model_dir=args.model_dir)

    # Load data
    if args.source == 'synthetic':
        print("\n[TRAIN] Using synthetic dataset...")
        loader = CICIDS2017Loader()
        logs_df = loader.create_sample_dataset(n_samples=args.n_samples)
        # Remove labels for unsupervised training
        logs_df = logs_df.drop(['is_attack', 'attack_type'], axis=1, errors='ignore')

    elif args.source == 'database':
        db_config = {
            'host': args.db_host,
            'port': args.db_port,
            'database': args.db_name,
            'user': args.db_user,
            'password': args.db_password,
        }
        logs_df = train_on_real_traffic(db_config, days=args.days)

    else:
        raise ValueError(f"Invalid source: {args.source}")

    # Fail with a clear message rather than inside the detector.
    if logs_df.empty:
        raise ValueError("No training data available")

    # Train
    print(f"\n[TRAIN] Training on {len(logs_df)} logs...")
    result = detector.train_unsupervised(logs_df)

    unique_ips = result['unique_ips']
    anomalies = result['anomalies_detected']
    # Guard the percentage: zero unique IPs would otherwise raise
    # ZeroDivisionError after training has already succeeded.
    anomaly_pct = anomalies / unique_ips * 100 if unique_ips else 0.0

    # Print results
    print("\n" + "=" * 70)
    print(" TRAINING RESULTS")
    print("=" * 70)
    print(f" Records processed: {result['records_processed']:,}")
    print(f" Unique IPs: {unique_ips:,}")
    print(f" Features (total): {result['features_total']}")
    print(f" Features (selected): {result['features_selected']}")
    print(f" Anomalies detected: {anomalies:,} ({anomaly_pct:.1f}%)")
    print(f" Contamination: {result['contamination']*100:.1f}%")
    print(f" Model type: {result['model_type']}")
    print("=" * 70)

    print(f"\n✅ Training completed! Models saved to: {args.model_dir}")
    print("\nNext steps:")
    print(" 1. Test detection: python python_ml/test_detection.py")
    print(" 2. Validate with CICIDS2017: python python_ml/train_hybrid.py --validate")

    return detector
|
|
|
|
|
|
def validate_with_cicids(args):
    """
    Validate the detector against the labelled CICIDS2017 dataset.

    Loads and splits the dataset (train 0.7 / val 0.15). It then either loads
    the saved model or, when ``args.retrain`` is set or loading fails, trains
    a new one on the normal-traffic rows of the training split. Detection runs
    on the test split, and Precision, Recall, F1 and FPR are computed via
    ``ValidationMetrics``. The metrics are saved as JSON under ``args.model_dir``.

    Args:
        args: Parsed CLI namespace. Reads cicids_dir, sample_frac, model_dir
            and retrain, plus the optional detection_threshold (default 60).

    Returns:
        Tuple ``(detector, metrics)``.

    Exits:
        Calls ``sys.exit(1)`` when the dataset files are missing.
    """
    print("\n" + "=" * 70)
    print(" IDS HYBRID ML VALIDATION - CICIDS2017")
    print("=" * 70)

    label_columns = ['is_attack', 'attack_type']

    # Load dataset
    loader = CICIDS2017Loader(data_dir=args.cicids_dir)

    # Check if dataset exists
    exists, _missing = loader.check_dataset_exists()
    if not exists:
        print("\n❌ CICIDS2017 dataset not found!")
        print(loader.download_instructions())
        sys.exit(1)

    print("\n[VALIDATE] Loading CICIDS2017 dataset...")

    # A sample fraction of 0 (the CLI default) means "use the whole dataset".
    sample_frac = args.sample_frac if args.sample_frac > 0 else 1.0
    train_df, val_df, test_df = loader.load_and_process_all(
        sample_frac=sample_frac,
        train_ratio=0.7,
        val_ratio=0.15,
    )

    print("\n[VALIDATE] Dataset split:")
    print(f" Train: {len(train_df):,} samples")
    print(f" Val: {len(val_df):,} samples")
    print(f" Test: {len(test_df):,} samples")

    # Load or train model
    detector = MLHybridDetector(model_dir=args.model_dir)

    # `or` short-circuits: load_models() is only attempted when not retraining.
    if args.retrain or not detector.load_models():
        print("\n[VALIDATE] Training new model on CICIDS2017 training set...")
        # Use normal traffic only for unsupervised training
        normal_train = train_df[train_df['is_attack'] == 0].drop(
            label_columns, axis=1, errors='ignore'
        )
        detector.train_unsupervised(normal_train)
        print(f" Trained on {len(normal_train):,} normal traffic samples")
    else:
        print("\n[VALIDATE] Using existing trained model")

    # Validate on test set
    print("\n[VALIDATE] Running detection on test set...")
    test_logs = test_df.drop(label_columns, axis=1, errors='ignore')
    detections = detector.detect(test_logs, mode='all')

    # Source IPs whose risk_score reaches the threshold count as detected.
    # Callers may override the threshold by setting args.detection_threshold.
    detection_threshold = getattr(args, 'detection_threshold', 60)
    detected_ips = {d['source_ip'] for d in detections if d['risk_score'] >= detection_threshold}

    print(f"[VALIDATE] Detected {len(detected_ips)} unique IPs above threshold {detection_threshold}")

    y_true = test_df['is_attack'].values
    # Flag every test row whose source IP was detected. isin() is positional
    # by construction. The previous iterrows() loop wrote y_pred[<index label>],
    # which is wrong (or raises IndexError) when the split's index is not 0..n-1.
    y_pred = test_df['source_ip'].isin(detected_ips).to_numpy(dtype=int)

    # Calculate metrics
    print("\n[VALIDATE] Calculating validation metrics...")
    validator = ValidationMetrics()
    metrics = validator.calculate(y_true, y_pred)

    # Print summary
    validator.print_summary(metrics, title="CICIDS2017 Validation Results")

    # Check production criteria
    print("\n[VALIDATE] Checking production deployment criteria...")
    passes, _issues = validator.meets_production_criteria(
        metrics,
        min_precision=0.90,
        max_fpr=0.05,
        min_recall=0.80,
    )

    # Save metrics (create the directory in case no model was ever saved there)
    metrics_file = Path(args.model_dir) / f"validation_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    metrics_file.parent.mkdir(parents=True, exist_ok=True)
    validator.save_metrics(metrics, str(metrics_file))

    if passes:
        print("\n🎉 Model ready for production deployment!")
    else:
        print("\n⚠️ Model needs improvement before production")
        print("\nSuggestions:")
        print(f" - Adjust contamination parameter (currently {detector.config['eif_contamination']})")
        print(" - Increase n_estimators for more stable predictions")
        print(" - Review feature selection threshold")

    return detector, metrics
|
|
|
|
|
|
def test_on_synthetic(args):
    """
    Quick end-to-end check on synthetic data to verify the system works.

    Generates ``args.n_samples`` labelled synthetic rows and trains on the
    normal rows of the first 70%. It then runs detection on the remaining 30%
    and prints confidence counts, the five highest-risk detections and
    validation metrics.

    Args:
        args: Parsed CLI namespace. Reads n_samples and model_dir, plus the
            optional detection_threshold (default 60).

    Returns:
        Tuple ``(detector, metrics)``.

    Raises:
        ValueError: if the generated dataset is empty.
    """
    print("\n" + "=" * 70)
    print(" IDS HYBRID ML TEST - SYNTHETIC DATA")
    print("=" * 70)

    label_columns = ['is_attack', 'attack_type']

    # Create synthetic dataset
    loader = CICIDS2017Loader()
    df = loader.create_sample_dataset(n_samples=args.n_samples)
    if df.empty:
        # Also keeps the percentage divisions below well-defined.
        raise ValueError("Synthetic dataset is empty; use a positive --n-samples")

    n = len(df)
    n_normal = int((df['is_attack'] == 0).sum())
    n_attack = int((df['is_attack'] == 1).sum())

    print(f"\n[TEST] Created synthetic dataset: {n} samples")
    print(f" Normal: {n_normal:,} ({n_normal/n*100:.1f}%)")
    print(f" Attacks: {n_attack:,} ({n_attack/n*100:.1f}%)")

    # Split 70/30 by position
    n_train = int(n * 0.7)
    train_df = df.iloc[:n_train]
    test_df = df.iloc[n_train:]

    # Train on normal traffic only
    detector = MLHybridDetector(model_dir=args.model_dir)
    normal_train = train_df[train_df['is_attack'] == 0].drop(label_columns, axis=1, errors='ignore')

    print(f"\n[TEST] Training on {len(normal_train):,} normal samples...")
    detector.train_unsupervised(normal_train)

    # Test detection
    test_logs = test_df.drop(label_columns, axis=1, errors='ignore')
    detections = detector.detect(test_logs, mode='all')

    print("\n[TEST] Detection results:")
    print(f" Total detections: {len(detections)}")

    # Count by confidence; tolerate levels other than high/medium/low
    # instead of failing with KeyError.
    confidence_counts = {'high': 0, 'medium': 0, 'low': 0}
    for d in detections:
        level = d['confidence_level']
        confidence_counts[level] = confidence_counts.get(level, 0) + 1

    print(f" High confidence: {confidence_counts['high']}")
    print(f" Medium confidence: {confidence_counts['medium']}")
    print(f" Low confidence: {confidence_counts['low']}")

    # Show top 5 detections. The stable sort leaves an already risk-ordered
    # list unchanged.
    top_detections = sorted(detections, key=lambda d: d['risk_score'], reverse=True)[:5]
    print("\n[TEST] Top 5 detections:")
    for i, d in enumerate(top_detections, 1):
        print(f" {i}. {d['source_ip']}: risk={d['risk_score']:.1f}, "
              f"type={d['anomaly_type']}, confidence={d['confidence_level']}")

    # Validation: source IPs at or above the threshold count as detected.
    detection_threshold = getattr(args, 'detection_threshold', 60)
    detected_ips = {d['source_ip'] for d in detections if d['risk_score'] >= detection_threshold}

    y_true = test_df['is_attack'].values
    # Positional mapping via isin(). test_df keeps the original index (it
    # starts at n_train), so the old `y_pred[i]` with i from iterrows()
    # indexed past the end of y_pred and raised IndexError.
    y_pred = test_df['source_ip'].isin(detected_ips).to_numpy(dtype=int)

    validator = ValidationMetrics()
    metrics = validator.calculate(y_true, y_pred)
    validator.print_summary(metrics, title="Synthetic Test Results")

    print("\n✅ System test completed!")

    # Check if ensemble was trained
    if detector.ensemble_classifier is None:
        print("\n⚠️ WARNING: System running in IF-only mode (no ensemble)")
        print(" This may occur with very clean datasets")
        print(" Expected metrics will be lower than hybrid mode")

    return detector, metrics
|
|
|
|
|
|
def main():
    """
    Command-line entry point: parse arguments and run the selected mode.

    Exactly one of --train, --validate or --test is required.

    Exit status:
        - 2 (from argparse) on invalid arguments.
        - 1 when --source database has no password (flag or $PGPASSWORD).
        - 1 on Ctrl-C.
        - 1 when the selected mode raises an exception.
    """
    import os  # local import: only needed for the PGPASSWORD fallback

    parser = argparse.ArgumentParser(
        description="Train and validate IDS Hybrid ML Detector",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Quick test with synthetic data
  python train_hybrid.py --test

  # Train on real traffic from database (password via --db-password or $PGPASSWORD)
  python train_hybrid.py --train --source database --days 7

  # Validate with CICIDS2017 (full dataset)
  python train_hybrid.py --validate

  # Validate with CICIDS2017 (10% sample for testing)
  python train_hybrid.py --validate --sample 0.1
""",
    )

    # Mode selection
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument('--train', action='store_true', help='Train unsupervised model')
    mode.add_argument('--validate', action='store_true', help='Validate with CICIDS2017')
    mode.add_argument('--test', action='store_true', help='Quick test with synthetic data')

    # Data source
    parser.add_argument('--source', choices=['synthetic', 'database'], default='synthetic',
                        help='Data source for training (default: synthetic)')

    # Database options
    parser.add_argument('--db-host', default='localhost', help='Database host')
    parser.add_argument('--db-port', type=int, default=5432, help='Database port')
    parser.add_argument('--db-name', default='ids', help='Database name')
    parser.add_argument('--db-user', default='postgres', help='Database user')
    # Falling back to $PGPASSWORD keeps the secret out of the process list.
    parser.add_argument('--db-password', default=os.environ.get('PGPASSWORD'),
                        help='Database password (default: $PGPASSWORD)')
    parser.add_argument('--days', type=int, default=7, help='Days of traffic to load from DB')

    # CICIDS2017 options
    parser.add_argument('--cicids-dir', default='datasets/cicids2017',
                        help='CICIDS2017 dataset directory')
    # argparse %-formats help strings, so a literal percent sign must be
    # written as %% (a bare "10%," made --help crash with ValueError).
    parser.add_argument('--sample', type=float, dest='sample_frac', default=0,
                        help='Sample fraction of CICIDS2017 (0.1 = 10%%, 0 = all)')
    parser.add_argument('--retrain', action='store_true',
                        help='Force retrain even if model exists')

    # General options
    parser.add_argument('--model-dir', default='models', help='Model save directory')
    parser.add_argument('--n-samples', type=int, default=10000,
                        help='Number of synthetic samples to generate')

    args = parser.parse_args()

    # Reject out-of-range numbers up front with a usage message instead of a
    # traceback deep inside data loading. The negated chain also rejects NaN.
    if not 0 <= args.sample_frac <= 1:
        parser.error("--sample must be between 0 and 1")
    if args.days <= 0:
        parser.error("--days must be a positive integer")
    if args.n_samples <= 0:
        parser.error("--n-samples must be a positive integer")

    # Validate database password if needed
    if args.source == 'database' and not args.db_password:
        print("Error: --db-password (or PGPASSWORD) required when using --source database",
              file=sys.stderr)
        sys.exit(1)

    # Execute mode
    try:
        if args.test:
            test_on_synthetic(args)
        elif args.validate:
            validate_with_cicids(args)
        elif args.train:
            train_unsupervised(args)

    except KeyboardInterrupt:
        print("\n\nInterrupted by user", file=sys.stderr)
        sys.exit(1)

    except Exception as e:
        # Top-level boundary: report the failure with its traceback and exit
        # non-zero. SystemExit (e.g. a missing dataset) is not caught here.
        import traceback
        print(f"\n❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|