#!/usr/bin/env python3
"""
IDS Hybrid ML Training Script

Trains Extended Isolation Forest with Feature Selection on CICIDS2017
or synthetic data, and validates with production-grade metrics.
"""

import argparse
import sys
from collections import Counter
from contextlib import closing
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd

# Import our modules
from ml_hybrid_detector import MLHybridDetector
from dataset_loader import CICIDS2017Loader
from validation_metrics import ValidationMetrics

# Minimum risk_score for a detection to count as a positive prediction.
DETECTION_THRESHOLD = 60


def _percent(part, whole) -> float:
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is zero."""
    return part / whole * 100 if whole else 0.0


def _detected_ips(detections, threshold: float = DETECTION_THRESHOLD) -> set:
    """Return the set of source IPs whose detection risk_score is >= threshold."""
    return {d['source_ip'] for d in detections if d['risk_score'] >= threshold}


def _predictions_for(logs_df: pd.DataFrame, detected_ips: set) -> np.ndarray:
    """
    Build a 0/1 prediction array aligned positionally with ``logs_df`` rows.

    A row is predicted as an attack (1) when its ``source_ip`` is in
    ``detected_ips``. Membership is computed on the column itself, so the
    result does not depend on the DataFrame's index labels (a sliced or
    shuffled frame still yields one prediction per row, in row order).
    """
    return logs_df['source_ip'].isin(detected_ips).astype(int).to_numpy()


def train_on_real_traffic(db_config: dict, days: int = 7) -> pd.DataFrame:
    """
    Load real traffic from the PostgreSQL ``network_logs`` table.

    Args:
        db_config: Keyword arguments forwarded to ``psycopg2.connect``.
        days: Size of the look-back window, in days. Must be positive.

    Returns:
        DataFrame with one row per log entry (at most 1,000,000 rows,
        newest first).

    Raises:
        ValueError: If ``days`` is not positive or the query returns no rows.
    """
    import psycopg2
    from psycopg2.extras import RealDictCursor

    if days <= 0:
        raise ValueError(f"days must be positive, got {days}")

    print(f"[TRAIN] Loading last {days} days of real traffic from database...")

    # The placeholder is kept outside the quoted literal so the driver binds
    # it as a value instead of splicing it into the interval string.
    query = """
        SELECT timestamp, source_ip, dest_ip, dest_port, protocol,
               packets, bytes, action
        FROM network_logs
        WHERE timestamp > NOW() - %s * INTERVAL '1 day'
        ORDER BY timestamp DESC
        LIMIT 1000000
    """

    # closing() guarantees the connection is released even if the query fails;
    # psycopg2's own `with conn:` only ends the transaction, it does not close.
    with closing(psycopg2.connect(**db_config)) as conn:
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query, (days,))
            rows = cursor.fetchall()

    if not rows:
        raise ValueError("No data found in database")

    df = pd.DataFrame(rows)
    print(f"[TRAIN] Loaded {len(df)} logs from database")
    return df


def train_unsupervised(args):
    """
    Train the unsupervised model (no labels needed).

    Uses synthetic data or real traffic from the database depending on
    ``args.source``, then prints a summary of the training result.

    Args:
        args: Parsed CLI namespace (uses model_dir, source, n_samples,
            days and the db_* options).

    Returns:
        The trained ``MLHybridDetector``.

    Raises:
        ValueError: If ``args.source`` is not 'synthetic' or 'database'.
    """
    print("\n" + "="*70)
    print(" IDS HYBRID ML TRAINING - UNSUPERVISED MODE")
    print("="*70)

    detector = MLHybridDetector(model_dir=args.model_dir)

    # Load data
    if args.source == 'synthetic':
        print("\n[TRAIN] Using synthetic dataset...")
        loader = CICIDS2017Loader()
        logs_df = loader.create_sample_dataset(n_samples=args.n_samples)
        # Remove labels for unsupervised training
        logs_df = logs_df.drop(['is_attack', 'attack_type'], axis=1, errors='ignore')
    elif args.source == 'database':
        db_config = {
            'host': args.db_host,
            'port': args.db_port,
            'database': args.db_name,
            'user': args.db_user,
            'password': args.db_password,
        }
        logs_df = train_on_real_traffic(db_config, days=args.days)
    else:
        raise ValueError(f"Invalid source: {args.source}")

    # Train
    print(f"\n[TRAIN] Training on {len(logs_df)} logs...")
    result = detector.train_unsupervised(logs_df)

    # Guarded so an empty result (zero unique IPs) does not crash the summary.
    anomaly_pct = _percent(result['anomalies_detected'], result['unique_ips'])

    # Print results
    print("\n" + "="*70)
    print(" TRAINING RESULTS")
    print("="*70)
    print(f" Records processed: {result['records_processed']:,}")
    print(f" Unique IPs: {result['unique_ips']:,}")
    print(f" Features (total): {result['features_total']}")
    print(f" Features (selected): {result['features_selected']}")
    print(f" Anomalies detected: {result['anomalies_detected']:,} ({anomaly_pct:.1f}%)")
    print(f" Contamination: {result['contamination']*100:.1f}%")
    print(f" Model type: {result['model_type']}")
    print("="*70)

    print(f"\n✅ Training completed! Models saved to: {args.model_dir}")
    print("\nNext steps:")
    print(" 1. Test detection: python python_ml/test_detection.py")
    print(" 2. Validate with CICIDS2017: python python_ml/train_hybrid.py --validate")

    return detector


def validate_with_cicids(args):
    """
    Validate the model against the CICIDS2017 dataset.

    Loads (or trains, when ``args.retrain`` is set or no saved model can be
    loaded) a detector, runs detection on the test split, and computes
    metrics through ``ValidationMetrics``. The metrics are saved as a
    timestamped JSON file under ``args.model_dir``.

    Args:
        args: Parsed CLI namespace (uses cicids_dir, sample_frac, retrain,
            model_dir).

    Returns:
        Tuple ``(detector, metrics)``.

    Exits the process with status 1 when the dataset is not found.
    """
    print("\n" + "="*70)
    print(" IDS HYBRID ML VALIDATION - CICIDS2017")
    print("="*70)

    # Load dataset
    loader = CICIDS2017Loader(data_dir=args.cicids_dir)

    # Check if dataset exists
    exists, _missing = loader.check_dataset_exists()
    if not exists:
        print("\n❌ CICIDS2017 dataset not found!")
        print(loader.download_instructions())
        sys.exit(1)

    print("\n[VALIDATE] Loading CICIDS2017 dataset...")

    # Use sample for faster testing; 0 (or negative) means the full dataset
    sample_frac = args.sample_frac if args.sample_frac > 0 else 1.0
    train_df, val_df, test_df = loader.load_and_process_all(
        sample_frac=sample_frac,
        train_ratio=0.7,
        val_ratio=0.15
    )

    print("\n[VALIDATE] Dataset split:")
    print(f" Train: {len(train_df):,} samples")
    print(f" Val: {len(val_df):,} samples")
    print(f" Test: {len(test_df):,} samples")

    # Load or train model
    detector = MLHybridDetector(model_dir=args.model_dir)

    if args.retrain or not detector.load_models():
        print("\n[VALIDATE] Training new model on CICIDS2017 training set...")
        # Use normal traffic only for unsupervised training
        normal_train = train_df[train_df['is_attack'] == 0].drop(
            ['is_attack', 'attack_type'], axis=1
        )
        detector.train_unsupervised(normal_train)
        print(f" Trained on {len(normal_train):,} normal traffic samples")
    else:
        print("\n[VALIDATE] Using existing trained model")

    # Validate on test set
    print("\n[VALIDATE] Running detection on test set...")
    test_logs = test_df.drop(['is_attack', 'attack_type'], axis=1)
    detections = detector.detect(test_logs, mode='all')

    # Convert detections to binary predictions: an IP counts as detected when
    # any of its detections reaches the risk threshold.
    detection_threshold = DETECTION_THRESHOLD
    detected_ips = _detected_ips(detections, detection_threshold)
    print(f"[VALIDATE] Detected {len(detected_ips)} unique IPs above threshold {detection_threshold}")

    # Predictions are aligned by row position, not by index label, so this is
    # correct whatever index the loader left on the test split.
    y_true = test_df['is_attack'].values
    y_pred = _predictions_for(test_df, detected_ips)

    # Calculate metrics
    print("\n[VALIDATE] Calculating validation metrics...")
    validator = ValidationMetrics()
    metrics = validator.calculate(y_true, y_pred)

    # Print summary
    validator.print_summary(metrics, title="CICIDS2017 Validation Results")

    # Check production criteria
    print("\n[VALIDATE] Checking production deployment criteria...")
    passes, issues = validator.meets_production_criteria(
        metrics,
        min_precision=0.90,
        max_fpr=0.05,
        min_recall=0.80
    )

    # Save metrics (make sure the target directory exists first)
    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    metrics_file = model_dir / f"validation_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    validator.save_metrics(metrics, str(metrics_file))

    if passes:
        print("\n🎉 Model ready for production deployment!")
    else:
        print("\n⚠️ Model needs improvement before production")
        print("\nSuggestions:")
        print(f" - Adjust contamination parameter (currently {detector.config['eif_contamination']})")
        print(" - Increase n_estimators for more stable predictions")
        print(" - Review feature selection threshold")

    return detector, metrics


def test_on_synthetic(args):
    """
    Quick end-to-end test on synthetic data to verify the system works.

    Generates ``args.n_samples`` synthetic logs, trains on the normal
    traffic in the first 70% of rows, runs detection on the remaining 30%
    and prints validation metrics.

    Args:
        args: Parsed CLI namespace (uses n_samples, model_dir).

    Returns:
        Tuple ``(detector, metrics)``.
    """
    print("\n" + "="*70)
    print(" IDS HYBRID ML TEST - SYNTHETIC DATA")
    print("="*70)

    # Create synthetic dataset
    loader = CICIDS2017Loader()
    df = loader.create_sample_dataset(n_samples=args.n_samples)

    n_normal = (df['is_attack'] == 0).sum()
    n_attack = (df['is_attack'] == 1).sum()
    print(f"\n[TEST] Created synthetic dataset: {len(df)} samples")
    print(f" Normal: {n_normal:,} ({_percent(n_normal, len(df)):.1f}%)")
    print(f" Attacks: {n_attack:,} ({_percent(n_attack, len(df)):.1f}%)")

    # Split
    n = len(df)
    n_train = int(n * 0.7)
    train_df = df.iloc[:n_train]
    test_df = df.iloc[n_train:]

    # Train on normal traffic only
    detector = MLHybridDetector(model_dir=args.model_dir)
    normal_train = train_df[train_df['is_attack'] == 0].drop(
        ['is_attack', 'attack_type'], axis=1
    )

    print(f"\n[TEST] Training on {len(normal_train):,} normal samples...")
    detector.train_unsupervised(normal_train)

    # Test detection
    test_logs = test_df.drop(['is_attack', 'attack_type'], axis=1)
    detections = detector.detect(test_logs, mode='all')

    print("\n[TEST] Detection results:")
    print(f" Total detections: {len(detections)}")

    # Count by confidence; Counter tolerates levels other than the three shown
    confidence_counts = Counter(d['confidence_level'] for d in detections)
    print(f" High confidence: {confidence_counts['high']}")
    print(f" Medium confidence: {confidence_counts['medium']}")
    print(f" Low confidence: {confidence_counts['low']}")

    # Show top 5 detections
    print("\n[TEST] Top 5 detections:")
    for i, d in enumerate(detections[:5], 1):
        print(f" {i}. {d['source_ip']}: risk={d['risk_score']:.1f}, "
              f"type={d['anomaly_type']}, confidence={d['confidence_level']}")

    # Validation - map detections to test_df rows using source_ip.
    # test_df is a tail slice, so its index labels start at n_train; the
    # predictions must therefore be aligned by position, not by label.
    detected_ips = _detected_ips(detections, DETECTION_THRESHOLD)
    y_true = test_df['is_attack'].values
    y_pred = _predictions_for(test_df, detected_ips)

    validator = ValidationMetrics()
    metrics = validator.calculate(y_true, y_pred)
    validator.print_summary(metrics, title="Synthetic Test Results")

    print("\n✅ System test completed successfully!")

    return detector, metrics


def main():
    """Parse command-line arguments and run the selected mode."""
    parser = argparse.ArgumentParser(
        description="Train and validate IDS Hybrid ML Detector",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Quick test with synthetic data
  python train_hybrid.py --test

  # Train on real traffic from database
  python train_hybrid.py --source database --days 7

  # Validate with CICIDS2017 (full dataset)
  python train_hybrid.py --validate

  # Validate with CICIDS2017 (10% sample for testing)
  python train_hybrid.py --validate --sample 0.1
"""
    )

    # Mode selection
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument('--train', action='store_true', help='Train unsupervised model')
    mode.add_argument('--validate', action='store_true', help='Validate with CICIDS2017')
    mode.add_argument('--test', action='store_true', help='Quick test with synthetic data')

    # Data source
    parser.add_argument('--source', choices=['synthetic', 'database'], default='synthetic',
                        help='Data source for training (default: synthetic)')

    # Database options
    parser.add_argument('--db-host', default='localhost', help='Database host')
    parser.add_argument('--db-port', type=int, default=5432, help='Database port')
    parser.add_argument('--db-name', default='ids', help='Database name')
    parser.add_argument('--db-user', default='postgres', help='Database user')
    parser.add_argument('--db-password', help='Database password')
    parser.add_argument('--days', type=int, default=7, help='Days of traffic to load from DB')

    # CICIDS2017 options
    parser.add_argument('--cicids-dir', default='datasets/cicids2017',
                        help='CICIDS2017 dataset directory')
    # argparse %-formats help strings, so a literal percent sign must be "%%".
    parser.add_argument('--sample', type=float, dest='sample_frac', default=0,
                        help='Sample fraction of CICIDS2017 (0.1 = 10%%, 0 = all)')
    parser.add_argument('--retrain', action='store_true',
                        help='Force retrain even if model exists')

    # General options
    parser.add_argument('--model-dir', default='models', help='Model save directory')
    parser.add_argument('--n-samples', type=int, default=10000,
                        help='Number of synthetic samples to generate')

    args = parser.parse_args()

    # Validate database password if needed
    if args.source == 'database' and not args.db_password:
        print("Error: --db-password required when using --source database")
        sys.exit(1)

    # Execute mode
    try:
        if args.test:
            test_on_synthetic(args)
        elif args.validate:
            validate_with_cicids(args)
        elif args.train:
            train_unsupervised(args)
    except KeyboardInterrupt:
        print("\n\nInterrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and exit.
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    main()