From 350a0994bdff8ff0281e4e1b483d3415f2adeeb7 Mon Sep 17 00:00:00 2001 From: marco370 <48531002-marco370@users.noreply.replit.com> Date: Mon, 24 Nov 2025 15:55:30 +0000 Subject: [PATCH] Add dataset loader and validation metrics modules Introduces `CICIDS2017Loader` for dataset handling and `ValidationMetrics` class for calculating performance metrics in Python. Replit-Commit-Author: Agent Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528 Replit-Commit-Checkpoint-Type: intermediate_checkpoint Replit-Commit-Event-Id: ad530f16-3a16-44a3-8fed-6c5d56775c77 Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/F6DiMv4 --- python_ml/dataset_loader.py | 374 ++++++++++++++++++++++++++++++++ python_ml/validation_metrics.py | 324 +++++++++++++++++++++++++++ 2 files changed, 698 insertions(+) create mode 100644 python_ml/dataset_loader.py create mode 100644 python_ml/validation_metrics.py diff --git a/python_ml/dataset_loader.py b/python_ml/dataset_loader.py new file mode 100644 index 0000000..ae51a1e --- /dev/null +++ b/python_ml/dataset_loader.py @@ -0,0 +1,374 @@ +""" +CICIDS2017 Dataset Loader and Preprocessor +Downloads, cleans, and maps CICIDS2017 features to IDS feature space +""" + +import pandas as pd +import numpy as np +from pathlib import Path +from typing import Dict, Tuple, Optional +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class CICIDS2017Loader: + """ + Loads and preprocesses CICIDS2017 dataset + Maps 80 CIC features to 25 IDS features + """ + + DATASET_INFO = { + 'name': 'CICIDS2017', + 'source': 'Canadian Institute for Cybersecurity', + 'url': 'https://www.unb.ca/cic/datasets/ids-2017.html', + 'size_gb': 7.8, + 'files': [ + 'Monday-WorkingHours.pcap_ISCX.csv', + 'Tuesday-WorkingHours.pcap_ISCX.csv', + 'Wednesday-workingHours.pcap_ISCX.csv', + 'Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv', + 'Thursday-WorkingHours-Afternoon-Infilteration.pcap_ISCX.csv', + 'Friday-WorkingHours-Morning.pcap_ISCX.csv', + 'Friday-WorkingHours-Afternoon-PortScan.pcap_ISCX.csv', + 'Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv', + ] + } + + # Mapping CIC feature names → IDS feature names + FEATURE_MAPPING = { + # Volume features + 'Total Fwd Packets': 'total_packets', + 'Total Backward Packets': 'total_packets', # Combined + 'Total Length of Fwd Packets': 'total_bytes', + 'Total Length of Bwd Packets': 'total_bytes', # Combined + 'Flow Duration': 'time_span_seconds', + + # Temporal features + 'Flow Packets/s': 'conn_per_second', + 'Flow Bytes/s': 'bytes_per_second', + 'Fwd Packets/s': 'packets_per_conn', + + # Protocol diversity + 'Protocol': 'unique_protocols', + 'Destination Port': 'unique_dest_ports', + + # Port scanning + 'Fwd PSH Flags': 'port_scan_score', + 'Fwd URG Flags': 'port_scan_score', + + # Behavioral + 'Fwd Packet Length Mean': 'avg_packet_size', + 'Fwd Packet Length Std': 'packet_size_variance', + 'Bwd Packet Length Mean': 'avg_packet_size', + 'Bwd Packet Length Std': 'packet_size_variance', + + # Burst patterns + 'Subflow Fwd Packets': 'max_burst', + 'Subflow Fwd Bytes': 'burst_variance', + } + + # Attack type mapping + ATTACK_LABELS = { + 'BENIGN': 'normal', + 'DoS Hulk': 'ddos', + 'DoS GoldenEye': 'ddos', + 'DoS slowloris': 'ddos', + 'DoS Slowhttptest': 'ddos', + 'DDoS': 'ddos', + 'PortScan': 'port_scan', + 'FTP-Patator': 'brute_force', + 'SSH-Patator': 'brute_force', + 'Bot': 'botnet', + 
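# NOTE (assumption): the publicly distributed CSVs encode the dash in the
+        # "Web Attack" labels as Windows-1252 byte 0x96, so depending on the
+        # encoding passed to pd.read_csv these keys may need a different dash.
+        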
'Web Attack – Brute Force': 'brute_force', + 'Web Attack – XSS': 'suspicious', + 'Web Attack – Sql Injection': 'suspicious', + 'Infiltration': 'suspicious', + 'Heartbleed': 'suspicious', + } + + def __init__(self, data_dir: str = "datasets/cicids2017"): + self.data_dir = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + def download_instructions(self) -> str: + """Return download instructions for CICIDS2017""" + instructions = f""" +╔══════════════════════════════════════════════════════════════════╗ +║ CICIDS2017 Dataset Download Instructions ║ +╚══════════════════════════════════════════════════════════════════╝ + +Dataset: {self.DATASET_INFO['name']} +Source: {self.DATASET_INFO['source']} +Size: {self.DATASET_INFO['size_gb']} GB +URL: {self.DATASET_INFO['url']} + +MANUAL DOWNLOAD (Recommended): +1. Visit: {self.DATASET_INFO['url']} +2. Register/Login (free account required) +3. Download CSV files for all days (Monday-Friday) +4. Extract to: {self.data_dir.absolute()} + +Expected files: +""" + for i, fname in enumerate(self.DATASET_INFO['files'], 1): + instructions += f" {i}. {fname}\n" + + instructions += f"\nAfter download, run: python_ml/train_hybrid.py --validate\n" + instructions += "=" * 66 + + return instructions + + def check_dataset_exists(self) -> Tuple[bool, list]: + """Check if dataset files exist""" + missing_files = [] + for fname in self.DATASET_INFO['files']: + fpath = self.data_dir / fname + if not fpath.exists(): + missing_files.append(fname) + + exists = len(missing_files) == 0 + return exists, missing_files + + def load_day(self, day_file: str, sample_frac: float = 1.0) -> pd.DataFrame: + """ + Load single day CSV file + sample_frac: fraction to sample (0.1 = 10% for testing) + """ + fpath = self.data_dir / day_file + + if not fpath.exists(): + raise FileNotFoundError(f"Dataset file not found: {fpath}") + + logger.info(f"Loading {day_file}...") + + # CICIDS2017 has known issues: extra space before column names, inf values + df = pd.read_csv(fpath, skipinitialspace=True) + + # Strip whitespace from column names + df.columns = df.columns.str.strip() + + # Sample if requested + if sample_frac < 1.0: + df = df.sample(frac=sample_frac, random_state=42) + logger.info(f"Sampled {len(df)} rows ({sample_frac*100:.0f}%)") + + return df + + def preprocess(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Clean and preprocess CICIDS2017 data + - Remove NaN and Inf values + - Fix data types + - Map labels + """ + logger.info(f"Preprocessing {len(df)} rows...") + + # Replace inf with NaN, then drop + df = df.replace([np.inf, -np.inf], np.nan) + df = df.dropna() + + # Map attack labels + if ' Label' in df.columns: + df['attack_type'] = df[' Label'].map(self.ATTACK_LABELS) + df['is_attack'] = (df['attack_type'] != 'normal').astype(int) + elif 'Label' in df.columns: + df['attack_type'] = df['Label'].map(self.ATTACK_LABELS) + df['is_attack'] = (df['attack_type'] != 'normal').astype(int) + else: + logger.warning("No label column found, assuming all BENIGN") + df['attack_type'] = 'normal' + df['is_attack'] = 0 + + # Remove unknown attack types + df = df[df['attack_type'].notna()] + + logger.info(f"After preprocessing: {len(df)} rows") + logger.info(f"Attack distribution:\n{df['attack_type'].value_counts()}") + + return df + + def map_to_ids_features(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Map 80 CICIDS2017 features → 25 IDS features + This is approximate mapping for validation purposes + """ + logger.info("Mapping CICIDS features to IDS feature space...") + + 
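# Assumption baked into this mapping: each CICIDS2017 row is one
+        # bidirectional flow, so connection-level aggregates (conn_count and
+        # the unique_* counts) are fixed at 1, directional statistics are
+        # combined where possible, and unknowable features get neutral constants.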
+        def col(name: str, default: float = 0.0) -> pd.Series:
+            # Column accessor: return a constant Series when a CIC column is
+            # absent, so the vector ops below never receive a bare scalar
+            if name in df.columns:
+                return df[name]
+            return pd.Series(default, index=df.index)
+
+        ids_features = {}
+
+        # Volume features (combine fwd+bwd)
+        ids_features['total_packets'] = (
+            col('Total Fwd Packets') +
+            col('Total Backward Packets')
+        )
+        ids_features['total_bytes'] = (
+            col('Total Length of Fwd Packets') +
+            col('Total Length of Bwd Packets')
+        )
+        ids_features['conn_count'] = 1  # Each row = 1 flow
+        ids_features['avg_packet_size'] = col('Fwd Packet Length Mean')
+        ids_features['bytes_per_second'] = col('Flow Bytes/s')
+
+        # Temporal features
+        ids_features['time_span_seconds'] = col('Flow Duration') / 1_000_000  # Microseconds to seconds
+        ids_features['conn_per_second'] = col('Flow Packets/s')
+        ids_features['hour_of_day'] = 12  # Unknown, use midday
+        ids_features['day_of_week'] = 3  # Unknown, use Wednesday
+
+        # Burst detection (approximate)
+        ids_features['max_burst'] = col('Subflow Fwd Packets')
+        ids_features['avg_burst'] = col('Subflow Fwd Packets')
+        ids_features['burst_variance'] = col('Subflow Fwd Bytes').clip(lower=0)
+        ids_features['avg_interval'] = 1.0  # Unknown
+
+        # Protocol diversity
+        ids_features['unique_protocols'] = 1  # Each row = single protocol
+        ids_features['unique_dest_ports'] = 1
+        ids_features['unique_dest_ips'] = 1
+        ids_features['protocol_entropy'] = 0
+        ids_features['tcp_ratio'] = (col('Protocol') == 6).astype(int)
+        ids_features['udp_ratio'] = (col('Protocol') == 17).astype(int)
+
+        # Port scanning detection
+        ids_features['unique_ports_contacted'] = (col('Destination Port') > 0).astype(int)
+        ids_features['port_scan_score'] = (col('Fwd PSH Flags') + col('Fwd URG Flags')) / 2
+        ids_features['sequential_ports'] = 0
+
+        # Behavioral anomalies
+        ids_features['packets_per_conn'] = ids_features['total_packets']
+        ids_features['packet_size_variance'] = col('Fwd Packet Length Std')
+        ids_features['blocked_ratio'] = 0
+
+        # Add labels
+        ids_features['attack_type'] = df['attack_type']
+        ids_features['is_attack'] = df['is_attack']
+
+        ids_df = pd.DataFrame(ids_features)
+
+        # Clip negative values
+        numeric_cols = ids_df.select_dtypes(include=[np.number]).columns
+        ids_df[numeric_cols] = ids_df[numeric_cols].clip(lower=0)
+
+        logger.info(f"Mapped to {len(ids_df.columns)} IDS features")
+        return ids_df
+
+    def load_and_process_all(
+        self,
+        sample_frac: float = 1.0,
+        train_ratio: float = 0.7,
+        val_ratio: float = 0.15
+    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+        """
+        Load all days, preprocess, map to IDS features, and split
+        Returns: train_df, val_df, test_df
+        """
+        exists, missing = self.check_dataset_exists()
+        if not exists:
+            raise FileNotFoundError(
+                f"Missing dataset files: {missing}\n\n"
+                f"{self.download_instructions()}"
+            )
+
+        all_data = []
+        for fname in self.DATASET_INFO['files']:
+            try:
+                df = self.load_day(fname, sample_frac=sample_frac)
+                df = self.preprocess(df)
+                df_ids = self.map_to_ids_features(df)
+                all_data.append(df_ids)
+            except Exception as e:
+                logger.error(f"Failed to load {fname}: {e}")
+                continue
+
+        if not all_data:
+            raise ValueError("No data loaded successfully")
+
+        # Combine all days
+        combined = pd.concat(all_data, ignore_index=True)
+        logger.info(f"Combined dataset: {len(combined)} rows")
+
+        # Shuffle
+        combined = combined.sample(frac=1, random_state=42).reset_index(drop=True)
+
+        # Split train/val/test
+        n = len(combined)
+        n_train = int(n * train_ratio)
+        n_val = int(n * val_ratio)
+
+        train_df = combined.iloc[:n_train]
+        val_df = 
combined.iloc[n_train:n_train+n_val] + test_df = combined.iloc[n_train+n_val:] + + logger.info(f"Split: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}") + + return train_df, val_df, test_df + + def create_sample_dataset(self, n_samples: int = 10000) -> pd.DataFrame: + """ + Create synthetic sample dataset for testing + Mimics CICIDS2017 structure + """ + logger.info(f"Creating sample dataset ({n_samples} samples)...") + + np.random.seed(42) + + # Generate synthetic features + data = { + 'total_packets': np.random.lognormal(3, 1.5, n_samples).astype(int), + 'total_bytes': np.random.lognormal(8, 2, n_samples).astype(int), + 'conn_count': np.ones(n_samples, dtype=int), + 'avg_packet_size': np.random.normal(500, 200, n_samples), + 'bytes_per_second': np.random.lognormal(6, 2, n_samples), + 'time_span_seconds': np.random.exponential(10, n_samples), + 'conn_per_second': np.random.exponential(5, n_samples), + 'hour_of_day': np.random.randint(0, 24, n_samples), + 'day_of_week': np.random.randint(0, 7, n_samples), + 'max_burst': np.random.poisson(20, n_samples), + 'avg_burst': np.random.poisson(15, n_samples), + 'burst_variance': np.random.exponential(5, n_samples), + 'avg_interval': np.random.exponential(0.1, n_samples), + 'unique_protocols': np.ones(n_samples, dtype=int), + 'unique_dest_ports': np.ones(n_samples, dtype=int), + 'unique_dest_ips': np.ones(n_samples, dtype=int), + 'protocol_entropy': np.zeros(n_samples), + 'tcp_ratio': np.random.choice([0, 1], n_samples, p=[0.3, 0.7]), + 'udp_ratio': np.random.choice([0, 1], n_samples, p=[0.7, 0.3]), + 'unique_ports_contacted': np.ones(n_samples, dtype=int), + 'port_scan_score': np.random.beta(1, 10, n_samples), + 'sequential_ports': np.zeros(n_samples, dtype=int), + 'packets_per_conn': np.random.lognormal(3, 1.5, n_samples), + 'packet_size_variance': np.random.exponential(100, n_samples), + 'blocked_ratio': np.zeros(n_samples), + } + + # Generate labels: 90% normal, 10% attacks + is_attack = np.random.choice([0, 1], n_samples, p=[0.9, 0.1]) + attack_types = np.where( + is_attack == 1, + np.random.choice(['ddos', 'port_scan', 'brute_force', 'suspicious'], n_samples), + 'normal' + ) + + data['is_attack'] = is_attack + data['attack_type'] = attack_types + + df = pd.DataFrame(data) + + # Make attacks more extreme + attack_mask = df['is_attack'] == 1 + df.loc[attack_mask, 'total_packets'] *= 10 + df.loc[attack_mask, 'total_bytes'] *= 15 + df.loc[attack_mask, 'conn_per_second'] *= 20 + + logger.info(f"Sample dataset created: {len(df)} rows") + logger.info(f"Attack distribution:\n{df['attack_type'].value_counts()}") + + return df + + +# Utility function +def get_cicids2017_loader(data_dir: str = "datasets/cicids2017") -> CICIDS2017Loader: + """Factory function to get loader instance""" + return CICIDS2017Loader(data_dir) diff --git a/python_ml/validation_metrics.py b/python_ml/validation_metrics.py new file mode 100644 index 0000000..0f89ed1 --- /dev/null +++ b/python_ml/validation_metrics.py @@ -0,0 +1,324 @@ +""" +Validation Metrics for IDS Models +Calculates Precision, Recall, F1-Score, False Positive Rate, Accuracy +""" + +import numpy as np +import pandas as pd +from typing import Dict, Tuple, Optional +from sklearn.metrics import ( + precision_score, + recall_score, + f1_score, + accuracy_score, + confusion_matrix, + roc_auc_score, + classification_report +) +import json + + +class ValidationMetrics: + """Calculate and track validation metrics for IDS models""" + + def __init__(self): + self.history = [] + + def calculate( + self, + 
y_true: np.ndarray,
+        y_pred: np.ndarray,
+        y_prob: Optional[np.ndarray] = None
+    ) -> Dict:
+        """
+        Calculate all metrics
+
+        Args:
+            y_true: True labels (0=normal, 1=attack)
+            y_pred: Predicted labels (0=normal, 1=attack)
+            y_prob: Prediction probabilities (optional, for ROC-AUC)
+
+        Returns:
+            Dict with all metrics
+        """
+        # Confusion matrix: labels pinned so ravel() always yields four
+        # cells, even when only one class appears in y_true/y_pred
+        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
+
+        # Core metrics
+        precision = precision_score(y_true, y_pred, zero_division=0)
+        recall = recall_score(y_true, y_pred, zero_division=0)
+        f1 = f1_score(y_true, y_pred, zero_division=0)
+        accuracy = accuracy_score(y_true, y_pred)
+
+        # False Positive Rate (critical for IDS!)
+        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
+
+        # True Negative Rate (Specificity)
+        tnr = tn / (tn + fp) if (tn + fp) > 0 else 0
+
+        # Matthews Correlation Coefficient (good for imbalanced datasets);
+        # cast to float so the denominator cannot overflow int64 on
+        # CICIDS-sized sample counts
+        mcc_num = (tp * tn) - (fp * fn)
+        mcc_den = np.sqrt(float(tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
+        mcc = mcc_num / mcc_den if mcc_den > 0 else 0
+
+        metrics = {
+            # Primary metrics
+            'precision': float(precision),
+            'recall': float(recall),
+            'f1_score': float(f1),
+            'accuracy': float(accuracy),
+            'false_positive_rate': float(fpr),
+
+            # Additional metrics
+            'true_negative_rate': float(tnr),  # Specificity
+            'matthews_corr_coef': float(mcc),
+
+            # Confusion matrix
+            'true_positives': int(tp),
+            'false_positives': int(fp),
+            'true_negatives': int(tn),
+            'false_negatives': int(fn),
+
+            # Sample counts
+            'total_samples': int(len(y_true)),
+            'total_attacks': int(np.sum(y_true == 1)),
+            'total_normal': int(np.sum(y_true == 0)),
+        }
+
+        # ROC-AUC if probabilities provided
+        if y_prob is not None:
+            try:
+                roc_auc = roc_auc_score(y_true, y_prob)
+                metrics['roc_auc'] = float(roc_auc)
+            except Exception:
+                metrics['roc_auc'] = None
+
+        return metrics
+
+    def calculate_per_class(
+        self,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        class_names: Optional[list] = None
+    ) -> Dict:
+        """
+        Calculate metrics per attack type
+
+        Args:
+            y_true: True class labels (attack types)
+            y_pred: Predicted class labels
+            class_names: List of class names
+
+        Returns:
+            Dict with per-class metrics
+        """
+        if class_names is None:
+            class_names = sorted(np.unique(np.concatenate([y_true, y_pred])))
+
+        # Get classification report as dict
+        report = classification_report(
+            y_true,
+            y_pred,
+            target_names=class_names,
+            output_dict=True,
+            zero_division=0
+        )
+
+        # Format per-class metrics
+        per_class = {}
+        for class_name in class_names:
+            if class_name in report:
+                per_class[class_name] = {
+                    'precision': report[class_name]['precision'],
+                    'recall': report[class_name]['recall'],
+                    'f1_score': report[class_name]['f1-score'],
+                    'support': report[class_name]['support'],
+                }
+
+        # Add macro/weighted averages
+        per_class['macro_avg'] = report['macro avg']
+        per_class['weighted_avg'] = report['weighted avg']
+
+        return per_class
+
+    def print_summary(self, metrics: Dict, title: str = "Validation Metrics"):
+        """Print formatted metrics summary"""
+        print(f"\n{'='*60}")
+        print(f"{title:^60}")
+        print(f"{'='*60}")
+
+        print(f"\n🎯 Primary Metrics:")
+        print(f"   Precision:  {metrics['precision']*100:6.2f}%  (of 100 flagged, how many are real attacks)")
+        print(f"   Recall:     {metrics['recall']*100:6.2f}%  (of 100 attacks, how many detected)")
+        print(f"   F1-Score:   {metrics['f1_score']*100:6.2f}%  (harmonic mean of P&R)")
+        print(f"   Accuracy:   {metrics['accuracy']*100:6.2f}%  (overall correctness)")
+
+        print(f"\n⚠️  False Positive Analysis:")
+        print(f"   FP Rate:    
{metrics['false_positive_rate']*100:6.2f}% (normal traffic flagged as attack)") + print(f" FP Count: {metrics['false_positives']:6d} (actual false positives)") + print(f" TN Rate: {metrics['true_negative_rate']*100:6.2f}% (specificity - correct normal)") + + print(f"\n📊 Confusion Matrix:") + print(f" Predicted Normal Predicted Attack") + print(f" Actual Normal {metrics['true_negatives']:6d} {metrics['false_positives']:6d}") + print(f" Actual Attack {metrics['false_negatives']:6d} {metrics['true_positives']:6d}") + + print(f"\n📈 Dataset Statistics:") + print(f" Total Samples: {metrics['total_samples']:6d}") + print(f" Total Attacks: {metrics['total_attacks']:6d} ({metrics['total_attacks']/metrics['total_samples']*100:.1f}%)") + print(f" Total Normal: {metrics['total_normal']:6d} ({metrics['total_normal']/metrics['total_samples']*100:.1f}%)") + + if 'roc_auc' in metrics and metrics['roc_auc'] is not None: + print(f"\n🎲 ROC-AUC: {metrics['roc_auc']:6.4f}") + + if 'matthews_corr_coef' in metrics: + print(f" MCC: {metrics['matthews_corr_coef']:6.4f} (correlation coefficient)") + + print(f"\n{'='*60}\n") + + def compare_models( + self, + model_metrics: Dict[str, Dict], + highlight_best: bool = True + ) -> pd.DataFrame: + """ + Compare metrics across multiple models + + Args: + model_metrics: Dict of {model_name: metrics_dict} + highlight_best: Print best model + + Returns: + DataFrame with comparison + """ + comparison = pd.DataFrame(model_metrics).T + + # Select key columns + key_cols = ['precision', 'recall', 'f1_score', 'accuracy', 'false_positive_rate'] + comparison = comparison[key_cols] + + # Convert to percentages + for col in key_cols: + comparison[col] = comparison[col] * 100 + + # Round to 2 decimals + comparison = comparison.round(2) + + if highlight_best: + print("\n📊 Model Comparison:") + print(comparison.to_string()) + + # Find best model (highest F1, lowest FPR) + comparison['score'] = comparison['f1_score'] - comparison['false_positive_rate'] + best_model = comparison['score'].idxmax() + + print(f"\n🏆 Best Model: {best_model}") + print(f" - F1-Score: {comparison.loc[best_model, 'f1_score']:.2f}%") + print(f" - FPR: {comparison.loc[best_model, 'false_positive_rate']:.2f}%") + + return comparison + + def save_metrics(self, metrics: Dict, filepath: str): + """Save metrics to JSON file""" + with open(filepath, 'w') as f: + json.dump(metrics, f, indent=2) + print(f"[METRICS] Saved to {filepath}") + + def load_metrics(self, filepath: str) -> Dict: + """Load metrics from JSON file""" + with open(filepath) as f: + metrics = json.load(f) + return metrics + + def meets_production_criteria( + self, + metrics: Dict, + min_precision: float = 0.90, + max_fpr: float = 0.05, + min_recall: float = 0.80 + ) -> Tuple[bool, list]: + """ + Check if model meets production deployment criteria + + Args: + metrics: Calculated metrics + min_precision: Minimum acceptable precision (default 90%) + max_fpr: Maximum acceptable FPR (default 5%) + min_recall: Minimum acceptable recall (default 80%) + + Returns: + (passes: bool, issues: list) + """ + issues = [] + + if metrics['precision'] < min_precision: + issues.append( + f"Precision {metrics['precision']*100:.1f}% < {min_precision*100:.0f}% " + f"(too many false positives)" + ) + + if metrics['false_positive_rate'] > max_fpr: + issues.append( + f"FPR {metrics['false_positive_rate']*100:.1f}% > {max_fpr*100:.0f}% " + f"(flagging too much normal traffic)" + ) + + if metrics['recall'] < min_recall: + issues.append( + f"Recall {metrics['recall']*100:.1f}% < 
{min_recall*100:.0f}% " + f"(missing too many attacks)" + ) + + passes = len(issues) == 0 + + if passes: + print("✅ Model meets production criteria!") + else: + print("❌ Model does NOT meet production criteria:") + for issue in issues: + print(f" - {issue}") + + return passes, issues + + +def calculate_confidence_metrics( + detections: list, + ground_truth: Dict[str, bool] +) -> Dict: + """ + Calculate metrics for confidence-based detection system + + Args: + detections: List of detection dicts with 'source_ip' and 'confidence_level' + ground_truth: Dict of {ip: is_attack (bool)} + + Returns: + Metrics broken down by confidence level + """ + confidence_levels = ['high', 'medium', 'low'] + metrics_by_confidence = {} + + for level in confidence_levels: + level_detections = [d for d in detections if d.get('confidence_level') == level] + + if not level_detections: + metrics_by_confidence[level] = { + 'count': 0, + 'true_positives': 0, + 'false_positives': 0, + 'precision': 0.0 + } + continue + + tp = sum(1 for d in level_detections if ground_truth.get(d['source_ip'], False)) + fp = len(level_detections) - tp + precision = tp / len(level_detections) if level_detections else 0 + + metrics_by_confidence[level] = { + 'count': len(level_detections), + 'true_positives': tp, + 'false_positives': fp, + 'precision': precision + } + + return metrics_by_confidence
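
Usage sketch: a minimal end-to-end run of the two new modules, assuming the
python_ml directory is importable as a package and using scikit-learn's
RandomForestClassifier as a stand-in classifier (the hybrid model behind
train_hybrid.py, referenced in the download instructions, is not part of
this patch):

    from sklearn.ensemble import RandomForestClassifier

    from python_ml.dataset_loader import get_cicids2017_loader
    from python_ml.validation_metrics import ValidationMetrics

    loader = get_cicids2017_loader()
    exists, _missing = loader.check_dataset_exists()

    if exists:
        # Real dataset: a 10% sample keeps the smoke test tractable
        train_df, val_df, test_df = loader.load_and_process_all(sample_frac=0.1)
    else:
        # No CSVs downloaded yet: fall back to the synthetic sample set
        df = loader.create_sample_dataset(n_samples=10_000)
        split = int(len(df) * 0.85)
        train_df, test_df = df.iloc[:split], df.iloc[split:]

    # All columns except the two label columns are numeric model inputs
    feature_cols = [c for c in train_df.columns if c not in ('attack_type', 'is_attack')]

    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(train_df[feature_cols], train_df['is_attack'])

    vm = ValidationMetrics()
    results = vm.calculate(
        test_df['is_attack'].to_numpy(),
        clf.predict(test_df[feature_cols]),
        y_prob=clf.predict_proba(test_df[feature_cols])[:, 1],
    )
    vm.print_summary(results, title="RandomForest baseline")
    vm.meets_production_criteria(results)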
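
The confidence-level helper takes plain lists and dicts; a tiny sketch with
made-up IPs and confidence levels:

    from python_ml.validation_metrics import calculate_confidence_metrics

    detections = [
        {'source_ip': '10.0.0.5', 'confidence_level': 'high'},
        {'source_ip': '10.0.0.9', 'confidence_level': 'medium'},
    ]
    ground_truth = {'10.0.0.5': True, '10.0.0.9': False}  # ip -> actually an attack?

    by_level = calculate_confidence_metrics(detections, ground_truth)
    print(by_level['high']['precision'])    # 1.0 (its one detection was a real attack)
    print(by_level['medium']['precision'])  # 0.0 (its one detection was benign)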