"""
|
|
IDS Backend FastAPI - Intrusion Detection System
|
|
Gestisce training ML, detection real-time e comunicazione con router MikroTik
|
|
"""
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Security, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, Dict
from datetime import datetime, timedelta
import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import asyncio
import secrets

from ml_analyzer import MLAnalyzer
from mikrotik_manager import MikroTikManager

# Load environment variables
load_dotenv()


# API Key Security
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


def get_api_key():
    """Get the API key from the environment."""
    return os.getenv("IDS_API_KEY")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Verify the API key for internal service communication."""
    expected_key = get_api_key()

    # In development without an API key, allow access (backward compatibility)
    if not expected_key:
        return True

    # Constant-time comparison so the check does not leak key contents via timing
    if not api_key or not secrets.compare_digest(api_key, expected_key):
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API key"
        )
    return True


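# A minimal client-side sketch of authenticating against the secured endpoints
# below (illustrative only; assumes the service listens on localhost:8000 and
# that IDS_API_KEY is set in both environments):
#
#   import os
#   import requests
#
#   resp = requests.get(
#       "http://localhost:8000/services/status",
#       headers={"X-API-Key": os.environ["IDS_API_KEY"]},
#       timeout=10,
#   )
#   resp.raise_for_status()
#   print(resp.json())

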
app = FastAPI(title="IDS API", version="1.0.0")

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Global instances
ml_analyzer = MLAnalyzer(model_dir="models")
mikrotik_manager = MikroTikManager()


# Database connection
def get_db_connection():
    return psycopg2.connect(
        host=os.getenv("PGHOST"),
        port=os.getenv("PGPORT"),
        database=os.getenv("PGDATABASE"),
        user=os.getenv("PGUSER"),
        password=os.getenv("PGPASSWORD")
    )


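# Example .env entries consumed above (illustrative values only; substitute
# your deployment's own credentials):
#
#   PGHOST=localhost
#   PGPORT=5432
#   PGDATABASE=ids
#   PGUSER=ids_user
#   PGPASSWORD=change-me
#   IDS_API_KEY=...  # e.g. python -c "import secrets; print(secrets.token_hex(32))"

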
# Pydantic models
class TrainRequest(BaseModel):
    max_records: int = 10000
    hours_back: int = 24
    contamination: float = 0.01


class DetectRequest(BaseModel):
    max_records: int = 5000
    hours_back: int = 1
    risk_threshold: float = 60.0
    auto_block: bool = False


class BlockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"
    comment: Optional[str] = None
    timeout_duration: str = "1h"


class UnblockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"


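# Example request bodies for the models above (values are illustrative):
#
#   POST /train       {"max_records": 10000, "hours_back": 24, "contamination": 0.01}
#   POST /detect      {"max_records": 5000, "hours_back": 1,
#                      "risk_threshold": 60.0, "auto_block": false}
#   POST /block-ip    {"ip_address": "203.0.113.7", "timeout_duration": "1h"}
#   POST /unblock-ip  {"ip_address": "203.0.113.7"}

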
# API Endpoints

@app.get("/")
async def root():
    return {
        "service": "IDS API",
        "version": "1.0.0",
        "status": "running",
        "model_loaded": ml_analyzer.model is not None
    }


@app.get("/health")
async def health_check():
    """Check system health"""
    try:
        conn = get_db_connection()
        conn.close()
        db_status = "connected"
    except Exception as e:
        db_status = f"error: {str(e)}"

    return {
        "status": "healthy",
        "database": db_status,
        "ml_model": "loaded" if ml_analyzer.model is not None else "not_loaded",
        "timestamp": datetime.now().isoformat()
    }


@app.post("/train")
|
|
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
|
|
"""
|
|
Addestra il modello ML sui log recenti
|
|
Esegue in background per non bloccare l'API
|
|
"""
|
|
def do_training():
|
|
conn = None
|
|
cursor = None
|
|
try:
|
|
print("[TRAIN] Inizio training...")
|
|
conn = get_db_connection()
|
|
cursor = conn.cursor(cursor_factory=RealDictCursor)
|
|
|
|
# Fetch logs recenti
|
|
min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
|
|
query = """
|
|
SELECT * FROM network_logs
|
|
WHERE timestamp >= %s
|
|
ORDER BY timestamp DESC
|
|
LIMIT %s
|
|
"""
|
|
cursor.execute(query, (min_timestamp, request.max_records))
|
|
logs = cursor.fetchall()
|
|
|
|
if not logs:
|
|
print("[TRAIN] Nessun log trovato per training")
|
|
return
|
|
|
|
print(f"[TRAIN] Trovati {len(logs)} log per training")
|
|
|
|
# Converti in DataFrame
|
|
df = pd.DataFrame(logs)
|
|
|
|
# Training
|
|
print("[TRAIN] Addestramento modello...")
|
|
result = ml_analyzer.train(df, contamination=request.contamination)
|
|
print(f"[TRAIN] Modello addestrato: {result}")
|
|
|
|
# Salva nel database
|
|
print("[TRAIN] Salvataggio training history nel database...")
|
|
cursor.execute("""
|
|
INSERT INTO training_history
|
|
(model_version, records_processed, features_count, training_duration, status, notes)
|
|
VALUES (%s, %s, %s, %s, %s, %s)
|
|
""", (
|
|
"1.0.0",
|
|
result['records_processed'],
|
|
result['features_count'],
|
|
0, # duration non ancora implementato
|
|
result['status'],
|
|
f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
|
|
))
|
|
conn.commit()
|
|
print("[TRAIN] ✅ Training history salvato con successo!")
|
|
|
|
cursor.close()
|
|
conn.close()
|
|
|
|
print(f"[TRAIN] Completato: {result}")
|
|
|
|
except Exception as e:
|
|
print(f"[TRAIN ERROR] ❌ Errore durante training: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
if conn:
|
|
conn.rollback()
|
|
finally:
|
|
if cursor:
|
|
cursor.close()
|
|
if conn:
|
|
conn.close()
|
|
|
|
# Esegui in background
|
|
background_tasks.add_task(do_training)
|
|
|
|
return {
|
|
"message": "Training avviato in background",
|
|
"max_records": request.max_records,
|
|
"hours_back": request.hours_back
|
|
}
|
|
|
|
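# A minimal sketch of kicking off training from another service (assumes the
# API is reachable on localhost:8000; the call returns immediately because
# training runs server-side in the background):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/train",
#       json={"max_records": 20000, "hours_back": 48, "contamination": 0.02},
#       timeout=10,
#   )
#   print(resp.json())  # {"message": "Training started in background", ...}

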
@app.post("/detect")
|
|
async def detect_anomalies(request: DetectRequest):
|
|
"""
|
|
Rileva anomalie nei log recenti
|
|
Opzionalmente blocca automaticamente IP anomali
|
|
"""
|
|
if ml_analyzer.model is None:
|
|
# Prova a caricare modello salvato
|
|
if not ml_analyzer.load_model():
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="Modello non addestrato. Esegui /train prima."
|
|
)
|
|
|
|
try:
|
|
conn = get_db_connection()
|
|
cursor = conn.cursor(cursor_factory=RealDictCursor)
|
|
|
|
# Fetch logs recenti
|
|
min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
|
|
query = """
|
|
SELECT * FROM network_logs
|
|
WHERE timestamp >= %s
|
|
ORDER BY timestamp DESC
|
|
LIMIT %s
|
|
"""
|
|
cursor.execute(query, (min_timestamp, request.max_records))
|
|
logs = cursor.fetchall()
|
|
|
|
if not logs:
|
|
return {"detections": [], "message": "Nessun log da analizzare"}
|
|
|
|
# Converti in DataFrame
|
|
df = pd.DataFrame(logs)
|
|
|
|
# Detection
|
|
detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)
|
|
|
|
# Salva detections nel database
|
|
for det in detections:
|
|
# Controlla se già esiste
|
|
cursor.execute(
|
|
"SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
|
|
(det['source_ip'],)
|
|
)
|
|
existing = cursor.fetchone()
|
|
|
|
if existing:
|
|
# Aggiorna esistente
|
|
cursor.execute("""
|
|
UPDATE detections
|
|
SET risk_score = %s, confidence = %s, anomaly_type = %s,
|
|
reason = %s, log_count = %s, last_seen = %s
|
|
WHERE id = %s
|
|
""", (
|
|
det['risk_score'], det['confidence'], det['anomaly_type'],
|
|
det['reason'], det['log_count'], det['last_seen'],
|
|
existing['id']
|
|
))
|
|
else:
|
|
# Inserisci nuovo
|
|
cursor.execute("""
|
|
INSERT INTO detections
|
|
(source_ip, risk_score, confidence, anomaly_type, reason,
|
|
log_count, first_seen, last_seen)
|
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
|
""", (
|
|
det['source_ip'], det['risk_score'], det['confidence'],
|
|
det['anomaly_type'], det['reason'], det['log_count'],
|
|
det['first_seen'], det['last_seen']
|
|
))
|
|
|
|
conn.commit()
|
|
|
|
# Auto-block se richiesto
|
|
blocked_count = 0
|
|
if request.auto_block and detections:
|
|
# Fetch routers abilitati
|
|
cursor.execute("SELECT * FROM routers WHERE enabled = true")
|
|
routers = cursor.fetchall()
|
|
|
|
if routers:
|
|
for det in detections:
|
|
if det['risk_score'] >= 80: # Solo rischio CRITICO
|
|
# Controlla whitelist
|
|
cursor.execute(
|
|
"SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
|
|
(det['source_ip'],)
|
|
)
|
|
if cursor.fetchone():
|
|
continue # Skip whitelisted
|
|
|
|
# Blocca su tutti i router
|
|
results = await mikrotik_manager.block_ip_on_all_routers(
|
|
routers,
|
|
det['source_ip'],
|
|
comment=f"IDS: {det['anomaly_type']}"
|
|
)
|
|
|
|
if any(results.values()):
|
|
# Aggiorna detection
|
|
cursor.execute("""
|
|
UPDATE detections
|
|
SET blocked = true, blocked_at = NOW()
|
|
WHERE source_ip = %s
|
|
""", (det['source_ip'],))
|
|
blocked_count += 1
|
|
|
|
conn.commit()
|
|
|
|
cursor.close()
|
|
conn.close()
|
|
|
|
return {
|
|
"detections": detections,
|
|
"total": len(detections),
|
|
"blocked": blocked_count if request.auto_block else 0,
|
|
"message": f"Trovate {len(detections)} anomalie"
|
|
}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
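# Sketch of a periodic detection sweep a scheduler might run against this
# endpoint (illustrative; host and cadence are assumptions):
#
#   import time
#   import requests
#
#   while True:
#       resp = requests.post(
#           "http://localhost:8000/detect",
#           json={"hours_back": 1, "risk_threshold": 70.0, "auto_block": True},
#           timeout=60,
#       )
#       data = resp.json()
#       print(f"{data['total']} anomalies, {data['blocked']} blocked")
#       time.sleep(300)  # every 5 minutes

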
@app.post("/block-ip")
|
|
async def block_ip(request: BlockIPRequest):
|
|
"""Blocca manualmente un IP su tutti i router"""
|
|
try:
|
|
conn = get_db_connection()
|
|
cursor = conn.cursor(cursor_factory=RealDictCursor)
|
|
|
|
# Controlla whitelist
|
|
cursor.execute(
|
|
"SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
|
|
(request.ip_address,)
|
|
)
|
|
if cursor.fetchone():
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"IP {request.ip_address} è in whitelist"
|
|
)
|
|
|
|
# Fetch routers
|
|
cursor.execute("SELECT * FROM routers WHERE enabled = true")
|
|
routers = cursor.fetchall()
|
|
|
|
if not routers:
|
|
raise HTTPException(status_code=400, detail="Nessun router configurato")
|
|
|
|
# Blocca su tutti i router
|
|
results = await mikrotik_manager.block_ip_on_all_routers(
|
|
routers,
|
|
request.ip_address,
|
|
list_name=request.list_name,
|
|
comment=request.comment or "Manual block",
|
|
timeout_duration=request.timeout_duration
|
|
)
|
|
|
|
success_count = sum(1 for v in results.values() if v)
|
|
|
|
cursor.close()
|
|
conn.close()
|
|
|
|
return {
|
|
"ip_address": request.ip_address,
|
|
"blocked_on": success_count,
|
|
"total_routers": len(routers),
|
|
"results": results
|
|
}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.post("/unblock-ip")
|
|
async def unblock_ip(request: UnblockIPRequest):
|
|
"""Sblocca un IP da tutti i router"""
|
|
try:
|
|
conn = get_db_connection()
|
|
cursor = conn.cursor(cursor_factory=RealDictCursor)
|
|
|
|
# Fetch routers
|
|
cursor.execute("SELECT * FROM routers WHERE enabled = true")
|
|
routers = cursor.fetchall()
|
|
|
|
if not routers:
|
|
raise HTTPException(status_code=400, detail="Nessun router configurato")
|
|
|
|
# Sblocca da tutti i router
|
|
results = await mikrotik_manager.unblock_ip_on_all_routers(
|
|
routers,
|
|
request.ip_address,
|
|
list_name=request.list_name
|
|
)
|
|
|
|
success_count = sum(1 for v in results.values() if v)
|
|
|
|
# Aggiorna database
|
|
cursor.execute("""
|
|
UPDATE detections
|
|
SET blocked = false
|
|
WHERE source_ip = %s
|
|
""", (request.ip_address,))
|
|
conn.commit()
|
|
|
|
cursor.close()
|
|
conn.close()
|
|
|
|
return {
|
|
"ip_address": request.ip_address,
|
|
"unblocked_from": success_count,
|
|
"total_routers": len(routers),
|
|
"results": results
|
|
}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
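# Sketch of a manual block-then-unblock round trip (illustrative; assumes
# localhost:8000 and at least one enabled router row in the database):
#
#   import requests
#
#   base = "http://localhost:8000"
#   requests.post(f"{base}/block-ip",
#                 json={"ip_address": "203.0.113.7", "comment": "manual test",
#                       "timeout_duration": "10m"}, timeout=30)
#   requests.post(f"{base}/unblock-ip",
#                 json={"ip_address": "203.0.113.7"}, timeout=30)

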
@app.get("/stats")
|
|
async def get_stats():
|
|
"""Statistiche sistema"""
|
|
try:
|
|
conn = get_db_connection()
|
|
cursor = conn.cursor(cursor_factory=RealDictCursor)
|
|
|
|
# Log stats
|
|
cursor.execute("SELECT COUNT(*) as total FROM network_logs")
|
|
result = cursor.fetchone()
|
|
total_logs = result['total'] if result else 0
|
|
|
|
cursor.execute("""
|
|
SELECT COUNT(*) as recent FROM network_logs
|
|
WHERE timestamp >= NOW() - INTERVAL '1 hour'
|
|
""")
|
|
result = cursor.fetchone()
|
|
recent_logs = result['recent'] if result else 0
|
|
|
|
# Detection stats
|
|
cursor.execute("SELECT COUNT(*) as total FROM detections")
|
|
result = cursor.fetchone()
|
|
total_detections = result['total'] if result else 0
|
|
|
|
cursor.execute("""
|
|
SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
|
|
""")
|
|
result = cursor.fetchone()
|
|
blocked_ips = result['blocked'] if result else 0
|
|
|
|
# Router stats
|
|
cursor.execute("SELECT COUNT(*) as total FROM routers WHERE enabled = true")
|
|
result = cursor.fetchone()
|
|
active_routers = result['total'] if result else 0
|
|
|
|
# Latest training
|
|
cursor.execute("""
|
|
SELECT * FROM training_history
|
|
ORDER BY trained_at DESC LIMIT 1
|
|
""")
|
|
latest_training = cursor.fetchone()
|
|
|
|
cursor.close()
|
|
conn.close()
|
|
|
|
return {
|
|
"logs": {
|
|
"total": total_logs,
|
|
"last_hour": recent_logs
|
|
},
|
|
"detections": {
|
|
"total": total_detections,
|
|
"blocked": blocked_ips
|
|
},
|
|
"routers": {
|
|
"active": active_routers
|
|
},
|
|
"latest_training": dict(latest_training) if latest_training else None
|
|
}
|
|
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# Service Management Endpoints (secured with API key)

@app.get("/services/status")
async def get_services_status(authorized: bool = Security(verify_api_key)):
    """
    Check the status of all services managed by systemd.
    REQUIRES: a valid API key.
    """
    try:
        import subprocess

        services_to_check = {
            "ml_backend": "ids-ml-backend",
            "syslog_parser": "ids-syslog-parser"
        }

        results = {}

        for service_name, systemd_unit in services_to_check.items():
            try:
                # Check the systemd service status
                result = subprocess.run(
                    ["systemctl", "is-active", systemd_unit],
                    capture_output=True,
                    text=True,
                    timeout=5
                )

                is_active = result.stdout.strip() == "active"

                # Get more details if active
                details = {}
                if is_active:
                    status_result = subprocess.run(
                        ["systemctl", "status", systemd_unit, "--no-pager"],
                        capture_output=True,
                        text=True,
                        timeout=5
                    )
                    # Extract the PID from the status output if available
                    for line in status_result.stdout.split('\n'):
                        if 'Main PID:' in line:
                            pid = line.split('Main PID:')[1].strip().split()[0]
                            details['pid'] = pid
                            break

                results[service_name] = {
                    "running": is_active,
                    "systemd_unit": systemd_unit,
                    "details": details
                }

            except subprocess.TimeoutExpired:
                results[service_name] = {
                    "running": False,
                    "error": "Timeout checking service"
                }
            except FileNotFoundError:
                # systemctl not available (likely a development environment)
                results[service_name] = {
                    "running": False,
                    "error": "systemd not available"
                }
            except Exception as e:
                results[service_name] = {
                    "running": False,
                    "error": str(e)
                }

        return {
            "services": results,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


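# For reference, a minimal systemd unit of the kind this endpoint inspects
# (an illustrative sketch only; the real ids-ml-backend and ids-syslog-parser
# unit files live in the deployment, and the paths here are assumptions):
#
#   [Unit]
#   Description=IDS ML backend
#   After=network.target postgresql.service
#
#   [Service]
#   WorkingDirectory=/opt/ids/backend
#   ExecStart=/opt/ids/backend/venv/bin/python main.py
#   Restart=on-failure
#
#   [Install]
#   WantedBy=multi-user.target

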
if __name__ == "__main__":
    import uvicorn

    # Try to load an existing model
    ml_analyzer.load_model()

    print("🚀 Starting IDS API on http://0.0.0.0:8000")
    print("📚 Docs available at http://0.0.0.0:8000/docs")

    uvicorn.run(app, host="0.0.0.0", port=8000)
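# Note: when the app is launched via the uvicorn CLI instead of this block
# (e.g. "uvicorn main:app --host 0.0.0.0 --port 8000", module name assumed),
# ml_analyzer.load_model() is not called at startup; /detect falls back to
# loading the saved model on first use.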