""" Validation Metrics for IDS Models Calculates Precision, Recall, F1-Score, False Positive Rate, Accuracy """ import numpy as np import pandas as pd from typing import Dict, Tuple, Optional from sklearn.metrics import ( precision_score, recall_score, f1_score, accuracy_score, confusion_matrix, roc_auc_score, classification_report ) import json class ValidationMetrics: """Calculate and track validation metrics for IDS models""" def __init__(self): self.history = [] def calculate( self, y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None ) -> Dict: """ Calculate all metrics Args: y_true: True labels (0=normal, 1=attack) y_pred: Predicted labels (0=normal, 1=attack) y_prob: Prediction probabilities (optional, for ROC-AUC) Returns: Dict with all metrics """ # Confusion matrix tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() # Core metrics precision = precision_score(y_true, y_pred, zero_division=0) recall = recall_score(y_true, y_pred, zero_division=0) f1 = f1_score(y_true, y_pred, zero_division=0) accuracy = accuracy_score(y_true, y_pred) # False Positive Rate (critical for IDS!) fpr = fp / (fp + tn) if (fp + tn) > 0 else 0 # True Negative Rate (Specificity) tnr = tn / (tn + fp) if (tn + fp) > 0 else 0 # Matthews Correlation Coefficient (good for imbalanced datasets) mcc_num = (tp * tn) - (fp * fn) mcc_den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) mcc = mcc_num / mcc_den if mcc_den > 0 else 0 metrics = { # Primary metrics 'precision': float(precision), 'recall': float(recall), 'f1_score': float(f1), 'accuracy': float(accuracy), 'false_positive_rate': float(fpr), # Additional metrics 'true_negative_rate': float(tnr), # Specificity 'matthews_corr_coef': float(mcc), # Confusion matrix 'true_positives': int(tp), 'false_positives': int(fp), 'true_negatives': int(tn), 'false_negatives': int(fn), # Sample counts 'total_samples': int(len(y_true)), 'total_attacks': int(np.sum(y_true == 1)), 'total_normal': int(np.sum(y_true == 0)), } # ROC-AUC if probabilities provided if y_prob is not None: try: roc_auc = roc_auc_score(y_true, y_prob) metrics['roc_auc'] = float(roc_auc) except Exception: metrics['roc_auc'] = None return metrics def calculate_per_class( self, y_true: np.ndarray, y_pred: np.ndarray, class_names: Optional[list] = None ) -> Dict: """ Calculate metrics per attack type Args: y_true: True class labels (attack types) y_pred: Predicted class labels class_names: List of class names Returns: Dict with per-class metrics """ if class_names is None: class_names = sorted(np.unique(np.concatenate([y_true, y_pred]))) # Get classification report as dict report = classification_report( y_true, y_pred, target_names=class_names, output_dict=True, zero_division=0 ) # Format per-class metrics per_class = {} for class_name in class_names: if class_name in report: per_class[class_name] = { 'precision': report[class_name]['precision'], 'recall': report[class_name]['recall'], 'f1_score': report[class_name]['f1-score'], 'support': report[class_name]['support'], } # Add macro/weighted averages per_class['macro_avg'] = report['macro avg'] per_class['weighted_avg'] = report['weighted avg'] return per_class def print_summary(self, metrics: Dict, title: str = "Validation Metrics"): """Print formatted metrics summary""" print(f"\n{'='*60}") print(f"{title:^60}") print(f"{'='*60}") print(f"\nšŸŽÆ Primary Metrics:") print(f" Precision: {metrics['precision']*100:6.2f}% (of 100 flagged, how many are real attacks)") print(f" Recall: {metrics['recall']*100:6.2f}% (of 100 
attacks, how many detected)") print(f" F1-Score: {metrics['f1_score']*100:6.2f}% (harmonic mean of P&R)") print(f" Accuracy: {metrics['accuracy']*100:6.2f}% (overall correctness)") print(f"\nāš ļø False Positive Analysis:") print(f" FP Rate: {metrics['false_positive_rate']*100:6.2f}% (normal traffic flagged as attack)") print(f" FP Count: {metrics['false_positives']:6d} (actual false positives)") print(f" TN Rate: {metrics['true_negative_rate']*100:6.2f}% (specificity - correct normal)") print(f"\nšŸ“Š Confusion Matrix:") print(f" Predicted Normal Predicted Attack") print(f" Actual Normal {metrics['true_negatives']:6d} {metrics['false_positives']:6d}") print(f" Actual Attack {metrics['false_negatives']:6d} {metrics['true_positives']:6d}") print(f"\nšŸ“ˆ Dataset Statistics:") print(f" Total Samples: {metrics['total_samples']:6d}") print(f" Total Attacks: {metrics['total_attacks']:6d} ({metrics['total_attacks']/metrics['total_samples']*100:.1f}%)") print(f" Total Normal: {metrics['total_normal']:6d} ({metrics['total_normal']/metrics['total_samples']*100:.1f}%)") if 'roc_auc' in metrics and metrics['roc_auc'] is not None: print(f"\nšŸŽ² ROC-AUC: {metrics['roc_auc']:6.4f}") if 'matthews_corr_coef' in metrics: print(f" MCC: {metrics['matthews_corr_coef']:6.4f} (correlation coefficient)") print(f"\n{'='*60}\n") def compare_models( self, model_metrics: Dict[str, Dict], highlight_best: bool = True ) -> pd.DataFrame: """ Compare metrics across multiple models Args: model_metrics: Dict of {model_name: metrics_dict} highlight_best: Print best model Returns: DataFrame with comparison """ comparison = pd.DataFrame(model_metrics).T # Select key columns key_cols = ['precision', 'recall', 'f1_score', 'accuracy', 'false_positive_rate'] comparison = comparison[key_cols] # Convert to percentages for col in key_cols: comparison[col] = comparison[col] * 100 # Round to 2 decimals comparison = comparison.round(2) if highlight_best: print("\nšŸ“Š Model Comparison:") print(comparison.to_string()) # Find best model (highest F1, lowest FPR) comparison['score'] = comparison['f1_score'] - comparison['false_positive_rate'] best_model = comparison['score'].idxmax() print(f"\nšŸ† Best Model: {best_model}") print(f" - F1-Score: {comparison.loc[best_model, 'f1_score']:.2f}%") print(f" - FPR: {comparison.loc[best_model, 'false_positive_rate']:.2f}%") return comparison def save_metrics(self, metrics: Dict, filepath: str): """Save metrics to JSON file""" with open(filepath, 'w') as f: json.dump(metrics, f, indent=2) print(f"[METRICS] Saved to {filepath}") def load_metrics(self, filepath: str) -> Dict: """Load metrics from JSON file""" with open(filepath) as f: metrics = json.load(f) return metrics def meets_production_criteria( self, metrics: Dict, min_precision: float = 0.90, max_fpr: float = 0.05, min_recall: float = 0.80 ) -> Tuple[bool, list]: """ Check if model meets production deployment criteria Args: metrics: Calculated metrics min_precision: Minimum acceptable precision (default 90%) max_fpr: Maximum acceptable FPR (default 5%) min_recall: Minimum acceptable recall (default 80%) Returns: (passes: bool, issues: list) """ issues = [] if metrics['precision'] < min_precision: issues.append( f"Precision {metrics['precision']*100:.1f}% < {min_precision*100:.0f}% " f"(too many false positives)" ) if metrics['false_positive_rate'] > max_fpr: issues.append( f"FPR {metrics['false_positive_rate']*100:.1f}% > {max_fpr*100:.0f}% " f"(flagging too much normal traffic)" ) if metrics['recall'] < min_recall: issues.append( f"Recall 
{metrics['recall']*100:.1f}% < {min_recall*100:.0f}% " f"(missing too many attacks)" ) passes = len(issues) == 0 if passes: print("āœ… Model meets production criteria!") else: print("āŒ Model does NOT meet production criteria:") for issue in issues: print(f" - {issue}") return passes, issues def calculate_confidence_metrics( detections: list, ground_truth: Dict[str, bool] ) -> Dict: """ Calculate metrics for confidence-based detection system Args: detections: List of detection dicts with 'source_ip' and 'confidence_level' ground_truth: Dict of {ip: is_attack (bool)} Returns: Metrics broken down by confidence level """ confidence_levels = ['high', 'medium', 'low'] metrics_by_confidence = {} for level in confidence_levels: level_detections = [d for d in detections if d.get('confidence_level') == level] if not level_detections: metrics_by_confidence[level] = { 'count': 0, 'true_positives': 0, 'false_positives': 0, 'precision': 0.0 } continue tp = sum(1 for d in level_detections if ground_truth.get(d['source_ip'], False)) fp = len(level_detections) - tp precision = tp / len(level_detections) if level_detections else 0 metrics_by_confidence[level] = { 'count': len(level_detections), 'true_positives': tp, 'false_positives': fp, 'precision': precision } return metrics_by_confidence
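

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It exercises ValidationMetrics and
# calculate_confidence_metrics on small synthetic arrays so the module can be
# run directly as a smoke test. The labels, probabilities, IP addresses, and
# detections below are made-up demo values, not part of the module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Synthetic binary labels: 1 = attack, 0 = normal (convention used above)
    y_true = np.array([0, 0, 0, 0, 1, 1, 1, 0, 1, 0])
    y_pred = np.array([0, 0, 1, 0, 1, 1, 0, 0, 1, 0])
    y_prob = np.array([0.1, 0.2, 0.7, 0.3, 0.9, 0.8, 0.4, 0.2, 0.95, 0.1])

    vm = ValidationMetrics()
    metrics = vm.calculate(y_true, y_pred, y_prob=y_prob)
    vm.print_summary(metrics, title="Demo Run")
    passes, issues = vm.meets_production_criteria(metrics)

    # Confidence-level breakdown over hypothetical detections
    detections = [
        {'source_ip': '10.0.0.5', 'confidence_level': 'high'},
        {'source_ip': '10.0.0.9', 'confidence_level': 'medium'},
        {'source_ip': '10.0.0.7', 'confidence_level': 'low'},
    ]
    ground_truth = {'10.0.0.5': True, '10.0.0.9': False, '10.0.0.7': False}
    print(calculate_confidence_metrics(detections, ground_truth))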