Add dataset loader and validation metrics modules
Introduces `CICIDS2017Loader` for dataset handling and `ValidationMetrics` class for calculating performance metrics in Python.

Replit-Commit-Author: Agent
Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528
Replit-Commit-Checkpoint-Type: intermediate_checkpoint
Replit-Commit-Event-Id: ad530f16-3a16-44a3-8fed-6c5d56775c77
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/F6DiMv4
This commit is contained in:
parent 932931457e
commit 350a0994bd
374 python_ml/dataset_loader.py (new file)
@@ -0,0 +1,374 @@
"""
CICIDS2017 Dataset Loader and Preprocessor
Downloads, cleans, and maps CICIDS2017 features to IDS feature space
"""

import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, Tuple, Optional
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CICIDS2017Loader:
    """
    Loads and preprocesses CICIDS2017 dataset
    Maps 80 CIC features to 25 IDS features
    """

    DATASET_INFO = {
        'name': 'CICIDS2017',
        'source': 'Canadian Institute for Cybersecurity',
        'url': 'https://www.unb.ca/cic/datasets/ids-2017.html',
        'size_gb': 7.8,
        'files': [
            'Monday-WorkingHours.pcap_ISCX.csv',
            'Tuesday-WorkingHours.pcap_ISCX.csv',
            'Wednesday-workingHours.pcap_ISCX.csv',
            'Thursday-WorkingHours-Morning-WebAttacks.pcap_ISCX.csv',
            'Thursday-WorkingHours-Afternoon-Infilteration.pcap_ISCX.csv',
            'Friday-WorkingHours-Morning.pcap_ISCX.csv',
            'Friday-WorkingHours-Afternoon-PortScan.pcap_ISCX.csv',
            'Friday-WorkingHours-Afternoon-DDos.pcap_ISCX.csv',
        ]
    }

    # Mapping CIC feature names → IDS feature names
    FEATURE_MAPPING = {
        # Volume features
        'Total Fwd Packets': 'total_packets',
        'Total Backward Packets': 'total_packets',  # Combined
        'Total Length of Fwd Packets': 'total_bytes',
        'Total Length of Bwd Packets': 'total_bytes',  # Combined
        'Flow Duration': 'time_span_seconds',

        # Temporal features
        'Flow Packets/s': 'conn_per_second',
        'Flow Bytes/s': 'bytes_per_second',
        'Fwd Packets/s': 'packets_per_conn',

        # Protocol diversity
        'Protocol': 'unique_protocols',
        'Destination Port': 'unique_dest_ports',

        # Port scanning
        'Fwd PSH Flags': 'port_scan_score',
        'Fwd URG Flags': 'port_scan_score',

        # Behavioral
        'Fwd Packet Length Mean': 'avg_packet_size',
        'Fwd Packet Length Std': 'packet_size_variance',
        'Bwd Packet Length Mean': 'avg_packet_size',
        'Bwd Packet Length Std': 'packet_size_variance',

        # Burst patterns
        'Subflow Fwd Packets': 'max_burst',
        'Subflow Fwd Bytes': 'burst_variance',
    }

    # Attack type mapping
    ATTACK_LABELS = {
        'BENIGN': 'normal',
        'DoS Hulk': 'ddos',
        'DoS GoldenEye': 'ddos',
        'DoS slowloris': 'ddos',
        'DoS Slowhttptest': 'ddos',
        'DDoS': 'ddos',
        'PortScan': 'port_scan',
        'FTP-Patator': 'brute_force',
        'SSH-Patator': 'brute_force',
        'Bot': 'botnet',
        'Web Attack – Brute Force': 'brute_force',
        'Web Attack – XSS': 'suspicious',
        'Web Attack – Sql Injection': 'suspicious',
        'Infiltration': 'suspicious',
        'Heartbleed': 'suspicious',
    }

    def __init__(self, data_dir: str = "datasets/cicids2017"):
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def download_instructions(self) -> str:
        """Return download instructions for CICIDS2017"""
        instructions = f"""
╔══════════════════════════════════════════════════════════════════╗
║            CICIDS2017 Dataset Download Instructions               ║
╚══════════════════════════════════════════════════════════════════╝

Dataset: {self.DATASET_INFO['name']}
Source:  {self.DATASET_INFO['source']}
Size:    {self.DATASET_INFO['size_gb']} GB
URL:     {self.DATASET_INFO['url']}

MANUAL DOWNLOAD (Recommended):
1. Visit: {self.DATASET_INFO['url']}
2. Register/Login (free account required)
3. Download CSV files for all days (Monday-Friday)
4. Extract to: {self.data_dir.absolute()}

Expected files:
"""
        for i, fname in enumerate(self.DATASET_INFO['files'], 1):
            instructions += f"  {i}. {fname}\n"

        instructions += "\nAfter download, run: python_ml/train_hybrid.py --validate\n"
        instructions += "=" * 66

        return instructions

    def check_dataset_exists(self) -> Tuple[bool, list]:
        """Check if dataset files exist"""
        missing_files = []
        for fname in self.DATASET_INFO['files']:
            fpath = self.data_dir / fname
            if not fpath.exists():
                missing_files.append(fname)

        exists = len(missing_files) == 0
        return exists, missing_files

    def load_day(self, day_file: str, sample_frac: float = 1.0) -> pd.DataFrame:
        """
        Load single day CSV file
        sample_frac: fraction to sample (0.1 = 10% for testing)
        """
        fpath = self.data_dir / day_file

        if not fpath.exists():
            raise FileNotFoundError(f"Dataset file not found: {fpath}")

        logger.info(f"Loading {day_file}...")

        # CICIDS2017 has known issues: extra space before column names, inf values
        df = pd.read_csv(fpath, skipinitialspace=True)

        # Strip whitespace from column names
        df.columns = df.columns.str.strip()

        # Sample if requested
        if sample_frac < 1.0:
            df = df.sample(frac=sample_frac, random_state=42)
            logger.info(f"Sampled {len(df)} rows ({sample_frac*100:.0f}%)")

        return df

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Clean and preprocess CICIDS2017 data
        - Remove NaN and Inf values
        - Fix data types
        - Map labels
        """
        logger.info(f"Preprocessing {len(df)} rows...")

        # Replace inf with NaN, then drop
        df = df.replace([np.inf, -np.inf], np.nan)
        df = df.dropna()

        # Map attack labels
        if ' Label' in df.columns:
            df['attack_type'] = df[' Label'].map(self.ATTACK_LABELS)
            df['is_attack'] = (df['attack_type'] != 'normal').astype(int)
        elif 'Label' in df.columns:
            df['attack_type'] = df['Label'].map(self.ATTACK_LABELS)
            df['is_attack'] = (df['attack_type'] != 'normal').astype(int)
        else:
            logger.warning("No label column found, assuming all BENIGN")
            df['attack_type'] = 'normal'
            df['is_attack'] = 0

        # Remove unknown attack types
        df = df[df['attack_type'].notna()]

        logger.info(f"After preprocessing: {len(df)} rows")
        logger.info(f"Attack distribution:\n{df['attack_type'].value_counts()}")

        return df

    def map_to_ids_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Map 80 CICIDS2017 features → 25 IDS features
        This is approximate mapping for validation purposes
        """
        logger.info("Mapping CICIDS features to IDS feature space...")

        ids_features = {}

        # Volume features (combine fwd+bwd)
        ids_features['total_packets'] = (
            df.get('Total Fwd Packets', 0) +
            df.get('Total Backward Packets', 0)
        )
        ids_features['total_bytes'] = (
            df.get('Total Length of Fwd Packets', 0) +
            df.get('Total Length of Bwd Packets', 0)
        )
        ids_features['conn_count'] = 1  # Each row = 1 flow
        ids_features['avg_packet_size'] = df.get('Fwd Packet Length Mean', 0)
        ids_features['bytes_per_second'] = df.get('Flow Bytes/s', 0)

        # Temporal features
        ids_features['time_span_seconds'] = df.get('Flow Duration', 0) / 1_000_000  # Microseconds to seconds
        ids_features['conn_per_second'] = df.get('Flow Packets/s', 0)
        ids_features['hour_of_day'] = 12  # Unknown, use midday
        ids_features['day_of_week'] = 3  # Unknown, use Wednesday

        # Burst detection (approximate)
        ids_features['max_burst'] = df.get('Subflow Fwd Packets', 0)
        ids_features['avg_burst'] = df.get('Subflow Fwd Packets', 0)
        ids_features['burst_variance'] = df.get('Subflow Fwd Bytes', 0).apply(lambda x: max(0, x))
        ids_features['avg_interval'] = 1.0  # Unknown

        # Protocol diversity
        ids_features['unique_protocols'] = 1  # Each row = single protocol
        ids_features['unique_dest_ports'] = 1
        ids_features['unique_dest_ips'] = 1
        ids_features['protocol_entropy'] = 0
        ids_features['tcp_ratio'] = (df.get('Protocol', 6) == 6).astype(int)
        ids_features['udp_ratio'] = (df.get('Protocol', 17) == 17).astype(int)

        # Port scanning detection
        ids_features['unique_ports_contacted'] = df.get('Destination Port', 0).apply(lambda x: 1 if x > 0 else 0)
        ids_features['port_scan_score'] = (df.get('Fwd PSH Flags', 0) + df.get('Fwd URG Flags', 0)) / 2
        ids_features['sequential_ports'] = 0

        # Behavioral anomalies
        ids_features['packets_per_conn'] = ids_features['total_packets']
        ids_features['packet_size_variance'] = df.get('Fwd Packet Length Std', 0)
        ids_features['blocked_ratio'] = 0

        # Add labels
        ids_features['attack_type'] = df['attack_type']
        ids_features['is_attack'] = df['is_attack']

        ids_df = pd.DataFrame(ids_features)

        # Clip negative values
        numeric_cols = ids_df.select_dtypes(include=[np.number]).columns
        ids_df[numeric_cols] = ids_df[numeric_cols].clip(lower=0)

        logger.info(f"Mapped to {len(ids_df.columns)} IDS features")
        return ids_df

    def load_and_process_all(
        self,
        sample_frac: float = 1.0,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Load all days, preprocess, map to IDS features, and split
        Returns: train_df, val_df, test_df
        """
        exists, missing = self.check_dataset_exists()
        if not exists:
            raise FileNotFoundError(
                f"Missing dataset files: {missing}\n\n"
                f"{self.download_instructions()}"
            )

        all_data = []
        for fname in self.DATASET_INFO['files']:
            try:
                df = self.load_day(fname, sample_frac=sample_frac)
                df = self.preprocess(df)
                df_ids = self.map_to_ids_features(df)
                all_data.append(df_ids)
            except Exception as e:
                logger.error(f"Failed to load {fname}: {e}")
                continue

        if not all_data:
            raise ValueError("No data loaded successfully")

        # Combine all days
        combined = pd.concat(all_data, ignore_index=True)
        logger.info(f"Combined dataset: {len(combined)} rows")

        # Shuffle
        combined = combined.sample(frac=1, random_state=42).reset_index(drop=True)

        # Split train/val/test
        n = len(combined)
        n_train = int(n * train_ratio)
        n_val = int(n * val_ratio)

        train_df = combined.iloc[:n_train]
        val_df = combined.iloc[n_train:n_train+n_val]
        test_df = combined.iloc[n_train+n_val:]

        logger.info(f"Split: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")

        return train_df, val_df, test_df

    def create_sample_dataset(self, n_samples: int = 10000) -> pd.DataFrame:
        """
        Create synthetic sample dataset for testing
        Mimics CICIDS2017 structure
        """
        logger.info(f"Creating sample dataset ({n_samples} samples)...")

        np.random.seed(42)

        # Generate synthetic features
        data = {
            'total_packets': np.random.lognormal(3, 1.5, n_samples).astype(int),
            'total_bytes': np.random.lognormal(8, 2, n_samples).astype(int),
            'conn_count': np.ones(n_samples, dtype=int),
            'avg_packet_size': np.random.normal(500, 200, n_samples),
            'bytes_per_second': np.random.lognormal(6, 2, n_samples),
            'time_span_seconds': np.random.exponential(10, n_samples),
            'conn_per_second': np.random.exponential(5, n_samples),
            'hour_of_day': np.random.randint(0, 24, n_samples),
            'day_of_week': np.random.randint(0, 7, n_samples),
            'max_burst': np.random.poisson(20, n_samples),
            'avg_burst': np.random.poisson(15, n_samples),
            'burst_variance': np.random.exponential(5, n_samples),
            'avg_interval': np.random.exponential(0.1, n_samples),
            'unique_protocols': np.ones(n_samples, dtype=int),
            'unique_dest_ports': np.ones(n_samples, dtype=int),
            'unique_dest_ips': np.ones(n_samples, dtype=int),
            'protocol_entropy': np.zeros(n_samples),
            'tcp_ratio': np.random.choice([0, 1], n_samples, p=[0.3, 0.7]),
            'udp_ratio': np.random.choice([0, 1], n_samples, p=[0.7, 0.3]),
            'unique_ports_contacted': np.ones(n_samples, dtype=int),
            'port_scan_score': np.random.beta(1, 10, n_samples),
            'sequential_ports': np.zeros(n_samples, dtype=int),
            'packets_per_conn': np.random.lognormal(3, 1.5, n_samples),
            'packet_size_variance': np.random.exponential(100, n_samples),
            'blocked_ratio': np.zeros(n_samples),
        }

        # Generate labels: 90% normal, 10% attacks
        is_attack = np.random.choice([0, 1], n_samples, p=[0.9, 0.1])
        attack_types = np.where(
            is_attack == 1,
            np.random.choice(['ddos', 'port_scan', 'brute_force', 'suspicious'], n_samples),
            'normal'
        )

        data['is_attack'] = is_attack
        data['attack_type'] = attack_types

        df = pd.DataFrame(data)

        # Make attacks more extreme
        attack_mask = df['is_attack'] == 1
        df.loc[attack_mask, 'total_packets'] *= 10
        df.loc[attack_mask, 'total_bytes'] *= 15
        df.loc[attack_mask, 'conn_per_second'] *= 20

        logger.info(f"Sample dataset created: {len(df)} rows")
        logger.info(f"Attack distribution:\n{df['attack_type'].value_counts()}")

        return df


# Utility function
def get_cicids2017_loader(data_dir: str = "datasets/cicids2017") -> CICIDS2017Loader:
    """Factory function to get loader instance"""
    return CICIDS2017Loader(data_dir)
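A minimal usage sketch for the loader above (not part of the diff itself, and assuming `python_ml` is importable as a package from the project root): it checks for the CSVs, prints the download instructions and falls back to the synthetic sample dataset when files are missing, and otherwise builds the train/val/test split from a 10% sample.

# Hypothetical usage sketch, not included in this commit.
from python_ml.dataset_loader import get_cicids2017_loader

loader = get_cicids2017_loader("datasets/cicids2017")

exists, missing = loader.check_dataset_exists()
if not exists:
    # Dataset not downloaded yet: show instructions and use synthetic data
    # so the rest of the pipeline can still be exercised.
    print(loader.download_instructions())
    df = loader.create_sample_dataset(n_samples=10000)
else:
    # A 10% sample keeps a first validation pass fast; use sample_frac=1.0 for the full run.
    train_df, val_df, test_df = loader.load_and_process_all(sample_frac=0.1)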
324 python_ml/validation_metrics.py (new file)
@@ -0,0 +1,324 @@
"""
Validation Metrics for IDS Models
Calculates Precision, Recall, F1-Score, False Positive Rate, Accuracy
"""

import numpy as np
import pandas as pd
from typing import Dict, Tuple, Optional
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    accuracy_score,
    confusion_matrix,
    roc_auc_score,
    classification_report
)
import json


class ValidationMetrics:
    """Calculate and track validation metrics for IDS models"""

    def __init__(self):
        self.history = []

    def calculate(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        y_prob: Optional[np.ndarray] = None
    ) -> Dict:
        """
        Calculate all metrics

        Args:
            y_true: True labels (0=normal, 1=attack)
            y_pred: Predicted labels (0=normal, 1=attack)
            y_prob: Prediction probabilities (optional, for ROC-AUC)

        Returns:
            Dict with all metrics
        """
        # Confusion matrix
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

        # Core metrics
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        accuracy = accuracy_score(y_true, y_pred)

        # False Positive Rate (critical for IDS!)
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0

        # True Negative Rate (Specificity)
        tnr = tn / (tn + fp) if (tn + fp) > 0 else 0

        # Matthews Correlation Coefficient (good for imbalanced datasets)
        mcc_num = (tp * tn) - (fp * fn)
        mcc_den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        mcc = mcc_num / mcc_den if mcc_den > 0 else 0

        metrics = {
            # Primary metrics
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
            'accuracy': float(accuracy),
            'false_positive_rate': float(fpr),

            # Additional metrics
            'true_negative_rate': float(tnr),  # Specificity
            'matthews_corr_coef': float(mcc),

            # Confusion matrix
            'true_positives': int(tp),
            'false_positives': int(fp),
            'true_negatives': int(tn),
            'false_negatives': int(fn),

            # Sample counts
            'total_samples': int(len(y_true)),
            'total_attacks': int(np.sum(y_true == 1)),
            'total_normal': int(np.sum(y_true == 0)),
        }

        # ROC-AUC if probabilities provided
        if y_prob is not None:
            try:
                roc_auc = roc_auc_score(y_true, y_prob)
                metrics['roc_auc'] = float(roc_auc)
            except Exception:
                metrics['roc_auc'] = None

        return metrics

    def calculate_per_class(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
        class_names: Optional[list] = None
    ) -> Dict:
        """
        Calculate metrics per attack type

        Args:
            y_true: True class labels (attack types)
            y_pred: Predicted class labels
            class_names: List of class names

        Returns:
            Dict with per-class metrics
        """
        if class_names is None:
            class_names = sorted(np.unique(np.concatenate([y_true, y_pred])))

        # Get classification report as dict
        report = classification_report(
            y_true,
            y_pred,
            target_names=class_names,
            output_dict=True,
            zero_division=0
        )

        # Format per-class metrics
        per_class = {}
        for class_name in class_names:
            if class_name in report:
                per_class[class_name] = {
                    'precision': report[class_name]['precision'],
                    'recall': report[class_name]['recall'],
                    'f1_score': report[class_name]['f1-score'],
                    'support': report[class_name]['support'],
                }

        # Add macro/weighted averages
        per_class['macro_avg'] = report['macro avg']
        per_class['weighted_avg'] = report['weighted avg']

        return per_class

    def print_summary(self, metrics: Dict, title: str = "Validation Metrics"):
        """Print formatted metrics summary"""
        print(f"\n{'='*60}")
        print(f"{title:^60}")
        print(f"{'='*60}")

        print(f"\n🎯 Primary Metrics:")
        print(f"   Precision: {metrics['precision']*100:6.2f}%  (of 100 flagged, how many are real attacks)")
        print(f"   Recall:    {metrics['recall']*100:6.2f}%  (of 100 attacks, how many detected)")
        print(f"   F1-Score:  {metrics['f1_score']*100:6.2f}%  (harmonic mean of P&R)")
        print(f"   Accuracy:  {metrics['accuracy']*100:6.2f}%  (overall correctness)")

        print(f"\n⚠️  False Positive Analysis:")
        print(f"   FP Rate:   {metrics['false_positive_rate']*100:6.2f}%  (normal traffic flagged as attack)")
        print(f"   FP Count:  {metrics['false_positives']:6d}  (actual false positives)")
        print(f"   TN Rate:   {metrics['true_negative_rate']*100:6.2f}%  (specificity - correct normal)")

        print(f"\n📊 Confusion Matrix:")
        print(f"                   Predicted Normal   Predicted Attack")
        print(f"   Actual Normal   {metrics['true_negatives']:6d}             {metrics['false_positives']:6d}")
        print(f"   Actual Attack   {metrics['false_negatives']:6d}             {metrics['true_positives']:6d}")

        print(f"\n📈 Dataset Statistics:")
        print(f"   Total Samples: {metrics['total_samples']:6d}")
        print(f"   Total Attacks: {metrics['total_attacks']:6d} ({metrics['total_attacks']/metrics['total_samples']*100:.1f}%)")
        print(f"   Total Normal:  {metrics['total_normal']:6d} ({metrics['total_normal']/metrics['total_samples']*100:.1f}%)")

        if 'roc_auc' in metrics and metrics['roc_auc'] is not None:
            print(f"\n🎲 ROC-AUC: {metrics['roc_auc']:6.4f}")

        if 'matthews_corr_coef' in metrics:
            print(f"   MCC:     {metrics['matthews_corr_coef']:6.4f} (correlation coefficient)")

        print(f"\n{'='*60}\n")

    def compare_models(
        self,
        model_metrics: Dict[str, Dict],
        highlight_best: bool = True
    ) -> pd.DataFrame:
        """
        Compare metrics across multiple models

        Args:
            model_metrics: Dict of {model_name: metrics_dict}
            highlight_best: Print best model

        Returns:
            DataFrame with comparison
        """
        comparison = pd.DataFrame(model_metrics).T

        # Select key columns
        key_cols = ['precision', 'recall', 'f1_score', 'accuracy', 'false_positive_rate']
        comparison = comparison[key_cols]

        # Convert to percentages
        for col in key_cols:
            comparison[col] = comparison[col] * 100

        # Round to 2 decimals
        comparison = comparison.round(2)

        if highlight_best:
            print("\n📊 Model Comparison:")
            print(comparison.to_string())

            # Find best model (highest F1, lowest FPR)
            comparison['score'] = comparison['f1_score'] - comparison['false_positive_rate']
            best_model = comparison['score'].idxmax()

            print(f"\n🏆 Best Model: {best_model}")
            print(f"   - F1-Score: {comparison.loc[best_model, 'f1_score']:.2f}%")
            print(f"   - FPR: {comparison.loc[best_model, 'false_positive_rate']:.2f}%")

        return comparison

    def save_metrics(self, metrics: Dict, filepath: str):
        """Save metrics to JSON file"""
        with open(filepath, 'w') as f:
            json.dump(metrics, f, indent=2)
        print(f"[METRICS] Saved to {filepath}")

    def load_metrics(self, filepath: str) -> Dict:
        """Load metrics from JSON file"""
        with open(filepath) as f:
            metrics = json.load(f)
        return metrics

    def meets_production_criteria(
        self,
        metrics: Dict,
        min_precision: float = 0.90,
        max_fpr: float = 0.05,
        min_recall: float = 0.80
    ) -> Tuple[bool, list]:
        """
        Check if model meets production deployment criteria

        Args:
            metrics: Calculated metrics
            min_precision: Minimum acceptable precision (default 90%)
            max_fpr: Maximum acceptable FPR (default 5%)
            min_recall: Minimum acceptable recall (default 80%)

        Returns:
            (passes: bool, issues: list)
        """
        issues = []

        if metrics['precision'] < min_precision:
            issues.append(
                f"Precision {metrics['precision']*100:.1f}% < {min_precision*100:.0f}% "
                f"(too many false positives)"
            )

        if metrics['false_positive_rate'] > max_fpr:
            issues.append(
                f"FPR {metrics['false_positive_rate']*100:.1f}% > {max_fpr*100:.0f}% "
                f"(flagging too much normal traffic)"
            )

        if metrics['recall'] < min_recall:
            issues.append(
                f"Recall {metrics['recall']*100:.1f}% < {min_recall*100:.0f}% "
                f"(missing too many attacks)"
            )

        passes = len(issues) == 0

        if passes:
            print("✅ Model meets production criteria!")
        else:
            print("❌ Model does NOT meet production criteria:")
            for issue in issues:
                print(f"   - {issue}")

        return passes, issues


def calculate_confidence_metrics(
    detections: list,
    ground_truth: Dict[str, bool]
) -> Dict:
    """
    Calculate metrics for confidence-based detection system

    Args:
        detections: List of detection dicts with 'source_ip' and 'confidence_level'
        ground_truth: Dict of {ip: is_attack (bool)}

    Returns:
        Metrics broken down by confidence level
    """
    confidence_levels = ['high', 'medium', 'low']
    metrics_by_confidence = {}

    for level in confidence_levels:
        level_detections = [d for d in detections if d.get('confidence_level') == level]

        if not level_detections:
            metrics_by_confidence[level] = {
                'count': 0,
                'true_positives': 0,
                'false_positives': 0,
                'precision': 0.0
            }
            continue

        tp = sum(1 for d in level_detections if ground_truth.get(d['source_ip'], False))
        fp = len(level_detections) - tp
        precision = tp / len(level_detections) if level_detections else 0

        metrics_by_confidence[level] = {
            'count': len(level_detections),
            'true_positives': tp,
            'false_positives': fp,
            'precision': precision
        }

    return metrics_by_confidence
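A minimal usage sketch for the metrics module above (not part of the diff itself; the arrays are stand-ins, assuming predictions would come from a trained binary attack classifier on the test split): it computes the metric dictionary, prints the formatted summary, checks the production criteria, and saves the results.

# Hypothetical usage sketch, not included in this commit.
import numpy as np
from python_ml.validation_metrics import ValidationMetrics

# Stand-in labels/probabilities; in practice these come from the test split.
y_true = np.array([0, 0, 1, 1, 0, 1, 0, 0])
y_prob = np.array([0.1, 0.4, 0.8, 0.9, 0.2, 0.3, 0.05, 0.6])
y_pred = (y_prob >= 0.5).astype(int)  # threshold probabilities into 0/1 labels

vm = ValidationMetrics()
metrics = vm.calculate(y_true, y_pred, y_prob=y_prob)
vm.print_summary(metrics, title="Hybrid IDS - Test Split")
passes, issues = vm.meets_production_criteria(metrics)
vm.save_metrics(metrics, "validation_results.json")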