ids.alfacom.it/python_ml/main.py
marco370 8b16800bb6 Update system to use hybrid detector and improve validation accuracy
Update main.py endpoints to use the hybrid detector and improve validation logic in train_hybrid.py by mapping detections using source_ip. Also, add synthetic source_ip to dataset_loader.py for both CICIDS2017 and synthetic datasets.

2025-11-24 16:02:49 +00:00


"""
IDS Backend FastAPI - Intrusion Detection System
Gestisce training ML, detection real-time e comunicazione con router MikroTik
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, Security, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, Dict
from datetime import datetime, timedelta
import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import asyncio
import secrets
from ml_analyzer import MLAnalyzer
from ml_hybrid_detector import MLHybridDetector
from mikrotik_manager import MikroTikManager
from ip_geolocation import get_geo_service
# Load environment variables
load_dotenv()

# API Key Security
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

def get_api_key():
    """Get API key from environment"""
    return os.getenv("IDS_API_KEY")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Verify API key for internal service communication"""
    expected_key = get_api_key()
    # In development without an API key, allow access (backward compatibility)
    if not expected_key:
        return True
    if not api_key or api_key != expected_key:
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API key"
        )
    return True
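
# Illustrative usage (not from the original source): internal callers pass the
# key via the X-API-Key header on secured endpoints, e.g.
#   curl -H "X-API-Key: $IDS_API_KEY" http://localhost:8000/services/status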

app = FastAPI(title="IDS API", version="2.0.0")

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances - try hybrid first, fall back to legacy
USE_HYBRID_DETECTOR = os.getenv("USE_HYBRID_DETECTOR", "true").lower() == "true"

if USE_HYBRID_DETECTOR:
    print("[ML] Using Hybrid ML Detector (Extended Isolation Forest + Feature Selection)")
    ml_detector = MLHybridDetector(model_dir="models")
    # Try to load an existing model
    if not ml_detector.load_models():
        print("[ML] No hybrid model found, will use on-demand training")
    ml_analyzer = None  # Legacy disabled
else:
    print("[ML] Using Legacy ML Analyzer (standard Isolation Forest)")
    ml_analyzer = MLAnalyzer(model_dir="models")
    ml_detector = None

mikrotik_manager = MikroTikManager()

# Database connection
def get_db_connection():
    return psycopg2.connect(
        host=os.getenv("PGHOST"),
        port=os.getenv("PGPORT"),
        database=os.getenv("PGDATABASE"),
        user=os.getenv("PGUSER"),
        password=os.getenv("PGPASSWORD")
    )
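
# Illustrative .env (placeholder values, not from the original source); the PG*
# names are the standard libpq environment variables read above:
#   PGHOST=localhost
#   PGPORT=5432
#   PGDATABASE=ids
#   PGUSER=ids_user
#   PGPASSWORD=change-me
#   IDS_API_KEY=some-long-random-string
#   USE_HYBRID_DETECTOR=true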

# Pydantic models
class TrainRequest(BaseModel):
    max_records: int = 10000
    hours_back: int = 24
    contamination: float = 0.01


class DetectRequest(BaseModel):
    max_records: int = 5000
    hours_back: int = 1
    risk_threshold: float = 60.0
    auto_block: bool = False


class BlockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"
    comment: Optional[str] = None
    timeout_duration: str = "1h"


class UnblockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"
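
# Note: every field above except ip_address has a default, so a minimal
# /block-ip body is just {"ip_address": "203.0.113.7"} (illustrative address).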

# API Endpoints
@app.get("/")
async def root():
    # Check which detector is active
    if USE_HYBRID_DETECTOR:
        model_loaded = ml_detector.isolation_forest is not None
        model_type = "hybrid"
    else:
        model_loaded = ml_analyzer.model is not None
        model_type = "legacy"
    return {
        "service": "IDS API",
        "version": "2.0.0",
        "status": "running",
        "model_type": model_type,
        "model_loaded": model_loaded,
        "use_hybrid": USE_HYBRID_DETECTOR
    }
@app.get("/health")
async def health_check():
"""Check system health"""
try:
conn = get_db_connection()
conn.close()
db_status = "connected"
except Exception as e:
db_status = f"error: {str(e)}"
# Check model status
if USE_HYBRID_DETECTOR:
model_status = "loaded" if ml_detector.isolation_forest is not None else "not_loaded"
model_type = "hybrid (EIF + Feature Selection)"
else:
model_status = "loaded" if ml_analyzer.model is not None else "not_loaded"
model_type = "legacy (Isolation Forest)"
return {
"status": "healthy",
"database": db_status,
"ml_model": model_status,
"ml_model_type": model_type,
"timestamp": datetime.now().isoformat()
}
@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
"""
Addestra il modello ML sui log recenti
Esegue in background per non bloccare l'API
"""
def do_training():
conn = None
cursor = None
try:
print("[TRAIN] Inizio training...")
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Fetch logs recenti
min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
query = """
SELECT * FROM network_logs
WHERE timestamp >= %s
ORDER BY timestamp DESC
LIMIT %s
"""
cursor.execute(query, (min_timestamp, request.max_records))
logs = cursor.fetchall()
if not logs:
print("[TRAIN] Nessun log trovato per training")
return
print(f"[TRAIN] Trovati {len(logs)} log per training")
# Converti in DataFrame
df = pd.DataFrame(logs)
# Training - usa detector appropriato
print("[TRAIN] Addestramento modello...")
if USE_HYBRID_DETECTOR:
print("[TRAIN] Using Hybrid ML Detector")
result = ml_detector.train_unsupervised(df)
else:
print("[TRAIN] Using Legacy ML Analyzer")
result = ml_analyzer.train(df, contamination=request.contamination)
print(f"[TRAIN] Modello addestrato: {result}")
# Salva nel database
print("[TRAIN] Salvataggio training history nel database...")
cursor.execute("""
INSERT INTO training_history
(model_version, records_processed, features_count, training_duration, status, notes)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
"1.0.0",
result['records_processed'],
result['features_count'],
0, # duration non ancora implementato
result['status'],
f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
))
conn.commit()
print("[TRAIN] ✅ Training history salvato con successo!")
cursor.close()
conn.close()
print(f"[TRAIN] Completato: {result}")
except Exception as e:
print(f"[TRAIN ERROR] ❌ Errore durante training: {e}")
import traceback
traceback.print_exc()
if conn:
conn.rollback()
finally:
if cursor:
cursor.close()
if conn:
conn.close()
# Esegui in background
background_tasks.add_task(do_training)
return {
"message": "Training avviato in background",
"max_records": request.max_records,
"hours_back": request.hours_back
}
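
# Illustrative request (defaults shown explicitly; the endpoint returns
# immediately while training runs in the background):
#   curl -X POST http://localhost:8000/train \
#        -H "Content-Type: application/json" \
#        -d '{"max_records": 10000, "hours_back": 24, "contamination": 0.01}'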
@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
"""
Rileva anomalie nei log recenti
Opzionalmente blocca automaticamente IP anomali
"""
# Check model loaded
if USE_HYBRID_DETECTOR:
if ml_detector.isolation_forest is None:
# Try to load
if not ml_detector.load_models():
raise HTTPException(
status_code=400,
detail="Modello hybrid non addestrato. Esegui /train prima."
)
else:
if ml_analyzer.model is None:
# Prova a caricare modello salvato
if not ml_analyzer.load_model():
raise HTTPException(
status_code=400,
detail="Modello non addestrato. Esegui /train prima."
)
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Fetch logs recenti
min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
query = """
SELECT * FROM network_logs
WHERE timestamp >= %s
ORDER BY timestamp DESC
LIMIT %s
"""
cursor.execute(query, (min_timestamp, request.max_records))
logs = cursor.fetchall()
if not logs:
return {"detections": [], "message": "Nessun log da analizzare"}
# Converti in DataFrame
df = pd.DataFrame(logs)
# Detection - usa detector appropriato
if USE_HYBRID_DETECTOR:
print("[DETECT] Using Hybrid ML Detector")
# Hybrid detector returns different format
detections = ml_detector.detect(df, mode='confidence')
# Convert to legacy format for compatibility
for det in detections:
det['confidence'] = det['confidence_level'] # Map confidence_level to confidence
else:
print("[DETECT] Using Legacy ML Analyzer")
detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)
# Geolocation lookup service - BATCH ASYNC per performance
geo_service = get_geo_service()
# Estrai lista IP unici per batch lookup
unique_ips = list(set(det['source_ip'] for det in detections))
# Batch lookup async (VELOCE - tutti in parallelo!)
geo_results = await geo_service.lookup_batch_async(unique_ips)
# Salva detections nel database
for det in detections:
# Get geo info from batch results
geo_info = geo_results.get(det['source_ip'])
# Controlla se già esiste
cursor.execute(
"SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
(det['source_ip'],)
)
existing = cursor.fetchone()
if existing:
# Aggiorna esistente (con geo info)
cursor.execute("""
UPDATE detections
SET risk_score = %s, confidence = %s, anomaly_type = %s,
reason = %s, log_count = %s, last_seen = %s,
country = %s, country_code = %s, city = %s,
organization = %s, as_number = %s, as_name = %s, isp = %s
WHERE id = %s
""", (
det['risk_score'], det['confidence'], det['anomaly_type'],
det['reason'], det['log_count'], det['last_seen'],
geo_info.get('country') if geo_info else None,
geo_info.get('country_code') if geo_info else None,
geo_info.get('city') if geo_info else None,
geo_info.get('organization') if geo_info else None,
geo_info.get('as_number') if geo_info else None,
geo_info.get('as_name') if geo_info else None,
geo_info.get('isp') if geo_info else None,
existing['id']
))
else:
# Inserisci nuovo (con geo info)
cursor.execute("""
INSERT INTO detections
(source_ip, risk_score, confidence, anomaly_type, reason,
log_count, first_seen, last_seen,
country, country_code, city, organization, as_number, as_name, isp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
det['source_ip'], det['risk_score'], det['confidence'],
det['anomaly_type'], det['reason'], det['log_count'],
det['first_seen'], det['last_seen'],
geo_info.get('country') if geo_info else None,
geo_info.get('country_code') if geo_info else None,
geo_info.get('city') if geo_info else None,
geo_info.get('organization') if geo_info else None,
geo_info.get('as_number') if geo_info else None,
geo_info.get('as_name') if geo_info else None,
geo_info.get('isp') if geo_info else None
))
conn.commit()
# Auto-block se richiesto
blocked_count = 0
if request.auto_block and detections:
# Fetch routers abilitati
cursor.execute("SELECT * FROM routers WHERE enabled = true")
routers = cursor.fetchall()
if routers:
for det in detections:
if det['risk_score'] >= 80: # Solo rischio CRITICO
# Controlla whitelist
cursor.execute(
"SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
(det['source_ip'],)
)
if cursor.fetchone():
continue # Skip whitelisted
# Blocca su tutti i router
results = await mikrotik_manager.block_ip_on_all_routers(
routers,
det['source_ip'],
comment=f"IDS: {det['anomaly_type']}"
)
if any(results.values()):
# Aggiorna detection
cursor.execute("""
UPDATE detections
SET blocked = true, blocked_at = NOW()
WHERE source_ip = %s
""", (det['source_ip'],))
blocked_count += 1
conn.commit()
cursor.close()
conn.close()
return {
"detections": detections,
"total": len(detections),
"blocked": blocked_count if request.auto_block else 0,
"message": f"Trovate {len(detections)} anomalie"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
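
# Illustrative request: analyze the last hour of logs and auto-block IPs whose
# risk_score reaches the CRITICAL threshold (>= 80) on all enabled routers:
#   curl -X POST http://localhost:8000/detect \
#        -H "Content-Type: application/json" \
#        -d '{"max_records": 5000, "hours_back": 1, "risk_threshold": 60.0, "auto_block": true}'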
@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
"""Blocca manualmente un IP su tutti i router"""
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Controlla whitelist
cursor.execute(
"SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
(request.ip_address,)
)
if cursor.fetchone():
raise HTTPException(
status_code=400,
detail=f"IP {request.ip_address} è in whitelist"
)
# Fetch routers
cursor.execute("SELECT * FROM routers WHERE enabled = true")
routers = cursor.fetchall()
if not routers:
raise HTTPException(status_code=400, detail="Nessun router configurato")
# Blocca su tutti i router
results = await mikrotik_manager.block_ip_on_all_routers(
routers,
request.ip_address,
list_name=request.list_name,
comment=request.comment or "Manual block",
timeout_duration=request.timeout_duration
)
success_count = sum(1 for v in results.values() if v)
cursor.close()
conn.close()
return {
"ip_address": request.ip_address,
"blocked_on": success_count,
"total_routers": len(routers),
"results": results
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
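
# Illustrative request (203.0.113.7 is a documentation address):
#   curl -X POST http://localhost:8000/block-ip \
#        -H "Content-Type: application/json" \
#        -d '{"ip_address": "203.0.113.7", "comment": "manual test", "timeout_duration": "1h"}'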
@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
"""Sblocca un IP da tutti i router"""
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Fetch routers
cursor.execute("SELECT * FROM routers WHERE enabled = true")
routers = cursor.fetchall()
if not routers:
raise HTTPException(status_code=400, detail="Nessun router configurato")
# Sblocca da tutti i router
results = await mikrotik_manager.unblock_ip_on_all_routers(
routers,
request.ip_address,
list_name=request.list_name
)
success_count = sum(1 for v in results.values() if v)
# Aggiorna database
cursor.execute("""
UPDATE detections
SET blocked = false
WHERE source_ip = %s
""", (request.ip_address,))
conn.commit()
cursor.close()
conn.close()
return {
"ip_address": request.ip_address,
"unblocked_from": success_count,
"total_routers": len(routers),
"results": results
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stats")
async def get_stats():
"""Statistiche sistema"""
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Log stats
cursor.execute("SELECT COUNT(*) as total FROM network_logs")
result = cursor.fetchone()
total_logs = result['total'] if result else 0
cursor.execute("""
SELECT COUNT(*) as recent FROM network_logs
WHERE timestamp >= NOW() - INTERVAL '1 hour'
""")
result = cursor.fetchone()
recent_logs = result['recent'] if result else 0
# Detection stats
cursor.execute("SELECT COUNT(*) as total FROM detections")
result = cursor.fetchone()
total_detections = result['total'] if result else 0
cursor.execute("""
SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
""")
result = cursor.fetchone()
blocked_ips = result['blocked'] if result else 0
# Router stats
cursor.execute("SELECT COUNT(*) as total FROM routers WHERE enabled = true")
result = cursor.fetchone()
active_routers = result['total'] if result else 0
# Latest training
cursor.execute("""
SELECT * FROM training_history
ORDER BY trained_at DESC LIMIT 1
""")
latest_training = cursor.fetchone()
cursor.close()
conn.close()
return {
"logs": {
"total": total_logs,
"last_hour": recent_logs
},
"detections": {
"total": total_detections,
"blocked": blocked_ips
},
"routers": {
"active": active_routers
},
"latest_training": dict(latest_training) if latest_training else None
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Service management endpoints (secured with API key)
@app.get("/services/status")
async def get_services_status(authorized: bool = Security(verify_api_key)):
    """
    Check the status of all services managed by systemd.
    REQUIRES: a valid API key.
    """
    try:
        import subprocess
        services_to_check = {
            "ml_backend": "ids-ml-backend",
            "syslog_parser": "ids-syslog-parser"
        }
        results = {}
        for service_name, systemd_unit in services_to_check.items():
            try:
                # Check the systemd service status
                result = subprocess.run(
                    ["systemctl", "is-active", systemd_unit],
                    capture_output=True,
                    text=True,
                    timeout=5
                )
                is_active = result.stdout.strip() == "active"

                # Get more details if active
                details = {}
                if is_active:
                    status_result = subprocess.run(
                        ["systemctl", "status", systemd_unit, "--no-pager"],
                        capture_output=True,
                        text=True,
                        timeout=5
                    )
                    # Extract the PID from the status output if available
                    for line in status_result.stdout.split('\n'):
                        if 'Main PID:' in line:
                            pid = line.split('Main PID:')[1].strip().split()[0]
                            details['pid'] = pid
                            break
                results[service_name] = {
                    "running": is_active,
                    "systemd_unit": systemd_unit,
                    "details": details
                }
            except subprocess.TimeoutExpired:
                results[service_name] = {
                    "running": False,
                    "error": "Timeout checking service"
                }
            except FileNotFoundError:
                # systemctl not available (likely a development environment)
                results[service_name] = {
                    "running": False,
                    "error": "systemd not available"
                }
            except Exception as e:
                results[service_name] = {
                    "running": False,
                    "error": str(e)
                }
        return {
            "services": results,
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    # Try to load a saved legacy model; ml_analyzer is None when the hybrid
    # detector is active (the hybrid model is already loaded at import time)
    if ml_analyzer is not None:
        ml_analyzer.load_model()
    print("🚀 Starting IDS API on http://0.0.0.0:8000")
    print("📚 Docs available at http://0.0.0.0:8000/docs")
    uvicorn.run(app, host="0.0.0.0", port=8000)