ids.alfacom.it/python_ml/main.py
marco370 c31e1ca838 Improve training history logging and file management
Enhance error handling in Python ML backend for training and update script location.

Replit-Commit-Author: Agent
Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528
Replit-Commit-Checkpoint-Type: full_checkpoint
Replit-Commit-Event-Id: b7099249-7827-46da-bdf9-2ff1d9c07b6c
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/VDRknFA
2025-11-22 10:21:29 +00:00

577 lines
18 KiB
Python

"""
IDS Backend FastAPI - Intrusion Detection System
Gestisce training ML, detection real-time e comunicazione con router MikroTik
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, Security, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, Dict
from datetime import datetime, timedelta
import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import asyncio
import secrets
from ml_analyzer import MLAnalyzer
from mikrotik_manager import MikroTikManager
# Load environment variables
load_dotenv()
# API Key Security
# Name of the HTTP header carrying the internal-service API key.
API_KEY_NAME = "X-API-Key"
# auto_error=False: a missing header reaches verify_api_key as None instead of
# FastAPI rejecting the request outright (required for the no-key dev fallback
# implemented in verify_api_key).
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
def get_api_key():
    """Read the internal-service API key from the environment (None if unset)."""
    return os.environ.get("IDS_API_KEY")
async def verify_api_key(api_key: str = Security(api_key_header)):
    """Validate the X-API-Key header for internal service communication.

    When no key is configured in the environment (development setup),
    every request is allowed for backward compatibility; otherwise the
    header must match the configured key exactly.
    """
    expected_key = get_api_key()
    if not expected_key:
        # No key configured: open access (dev / backward compatibility).
        return True
    if api_key == expected_key:
        return True
    raise HTTPException(
        status_code=403,
        detail="Invalid or missing API key"
    )
# FastAPI application exposing the IDS ML/detection endpoints.
app = FastAPI(title="IDS API", version="1.0.0")
# CORS: fully open (all origins/methods/headers).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — confirm this is acceptable for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global singletons shared by all endpoints.
ml_analyzer = MLAnalyzer(model_dir="models")  # ML model wrapper, persisted under ./models
mikrotik_manager = MikroTikManager()  # handles communication with MikroTik routers
# Database connection
def get_db_connection():
    """Open a new PostgreSQL connection from the PG* environment variables.

    The caller owns the connection and is responsible for closing it.
    """
    params = {
        "host": os.getenv("PGHOST"),
        "port": os.getenv("PGPORT"),
        "database": os.getenv("PGDATABASE"),
        "user": os.getenv("PGUSER"),
        "password": os.getenv("PGPASSWORD"),
    }
    return psycopg2.connect(**params)
# Pydantic models
class TrainRequest(BaseModel):
    """Request body for POST /train."""
    max_records: int = 10000  # max log rows fetched for training
    hours_back: int = 24  # training window, in hours
    contamination: float = 0.01  # expected anomaly fraction passed to the analyzer
class DetectRequest(BaseModel):
    """Request body for POST /detect."""
    max_records: int = 5000  # max log rows analyzed
    hours_back: int = 1  # analysis window, in hours
    risk_threshold: float = 60.0  # minimum risk score reported as a detection
    auto_block: bool = False  # when true, auto-block CRITICAL (score >= 80) IPs
class BlockIPRequest(BaseModel):
    """Request body for POST /block-ip."""
    ip_address: str  # IP to block on all enabled routers
    list_name: str = "ddos_blocked"  # MikroTik address-list name
    comment: Optional[str] = None  # free-text note stored on the router entry
    timeout_duration: str = "1h"  # entry lifetime, MikroTik duration syntax
class UnblockIPRequest(BaseModel):
    """Request body for POST /unblock-ip."""
    ip_address: str  # IP to remove from the address list
    list_name: str = "ddos_blocked"  # MikroTik address-list name
# API Endpoints
@app.get("/")
async def root():
return {
"service": "IDS API",
"version": "1.0.0",
"status": "running",
"model_loaded": ml_analyzer.model is not None
}
@app.get("/health")
async def health_check():
"""Check system health"""
try:
conn = get_db_connection()
conn.close()
db_status = "connected"
except Exception as e:
db_status = f"error: {str(e)}"
return {
"status": "healthy",
"database": db_status,
"ml_model": "loaded" if ml_analyzer.model is not None else "not_loaded",
"timestamp": datetime.now().isoformat()
}
@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
"""
Addestra il modello ML sui log recenti
Esegue in background per non bloccare l'API
"""
def do_training():
conn = None
cursor = None
try:
print("[TRAIN] Inizio training...")
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Fetch logs recenti
min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
query = """
SELECT * FROM network_logs
WHERE timestamp >= %s
ORDER BY timestamp DESC
LIMIT %s
"""
cursor.execute(query, (min_timestamp, request.max_records))
logs = cursor.fetchall()
if not logs:
print("[TRAIN] Nessun log trovato per training")
return
print(f"[TRAIN] Trovati {len(logs)} log per training")
# Converti in DataFrame
df = pd.DataFrame(logs)
# Training
print("[TRAIN] Addestramento modello...")
result = ml_analyzer.train(df, contamination=request.contamination)
print(f"[TRAIN] Modello addestrato: {result}")
# Salva nel database
print("[TRAIN] Salvataggio training history nel database...")
cursor.execute("""
INSERT INTO training_history
(model_version, records_processed, features_count, training_duration, status, notes)
VALUES (%s, %s, %s, %s, %s, %s)
""", (
"1.0.0",
result['records_processed'],
result['features_count'],
0, # duration non ancora implementato
result['status'],
f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
))
conn.commit()
print("[TRAIN] ✅ Training history salvato con successo!")
cursor.close()
conn.close()
print(f"[TRAIN] Completato: {result}")
except Exception as e:
print(f"[TRAIN ERROR] ❌ Errore durante training: {e}")
import traceback
traceback.print_exc()
if conn:
conn.rollback()
finally:
if cursor:
cursor.close()
if conn:
conn.close()
# Esegui in background
background_tasks.add_task(do_training)
return {
"message": "Training avviato in background",
"max_records": request.max_records,
"hours_back": request.hours_back
}
@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
"""
Rileva anomalie nei log recenti
Opzionalmente blocca automaticamente IP anomali
"""
if ml_analyzer.model is None:
# Prova a caricare modello salvato
if not ml_analyzer.load_model():
raise HTTPException(
status_code=400,
detail="Modello non addestrato. Esegui /train prima."
)
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Fetch logs recenti
min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
query = """
SELECT * FROM network_logs
WHERE timestamp >= %s
ORDER BY timestamp DESC
LIMIT %s
"""
cursor.execute(query, (min_timestamp, request.max_records))
logs = cursor.fetchall()
if not logs:
return {"detections": [], "message": "Nessun log da analizzare"}
# Converti in DataFrame
df = pd.DataFrame(logs)
# Detection
detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)
# Salva detections nel database
for det in detections:
# Controlla se già esiste
cursor.execute(
"SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
(det['source_ip'],)
)
existing = cursor.fetchone()
if existing:
# Aggiorna esistente
cursor.execute("""
UPDATE detections
SET risk_score = %s, confidence = %s, anomaly_type = %s,
reason = %s, log_count = %s, last_seen = %s
WHERE id = %s
""", (
det['risk_score'], det['confidence'], det['anomaly_type'],
det['reason'], det['log_count'], det['last_seen'],
existing['id']
))
else:
# Inserisci nuovo
cursor.execute("""
INSERT INTO detections
(source_ip, risk_score, confidence, anomaly_type, reason,
log_count, first_seen, last_seen)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
""", (
det['source_ip'], det['risk_score'], det['confidence'],
det['anomaly_type'], det['reason'], det['log_count'],
det['first_seen'], det['last_seen']
))
conn.commit()
# Auto-block se richiesto
blocked_count = 0
if request.auto_block and detections:
# Fetch routers abilitati
cursor.execute("SELECT * FROM routers WHERE enabled = true")
routers = cursor.fetchall()
if routers:
for det in detections:
if det['risk_score'] >= 80: # Solo rischio CRITICO
# Controlla whitelist
cursor.execute(
"SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
(det['source_ip'],)
)
if cursor.fetchone():
continue # Skip whitelisted
# Blocca su tutti i router
results = await mikrotik_manager.block_ip_on_all_routers(
routers,
det['source_ip'],
comment=f"IDS: {det['anomaly_type']}"
)
if any(results.values()):
# Aggiorna detection
cursor.execute("""
UPDATE detections
SET blocked = true, blocked_at = NOW()
WHERE source_ip = %s
""", (det['source_ip'],))
blocked_count += 1
conn.commit()
cursor.close()
conn.close()
return {
"detections": detections,
"total": len(detections),
"blocked": blocked_count if request.auto_block else 0,
"message": f"Trovate {len(detections)} anomalie"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
"""Blocca manualmente un IP su tutti i router"""
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Controlla whitelist
cursor.execute(
"SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
(request.ip_address,)
)
if cursor.fetchone():
raise HTTPException(
status_code=400,
detail=f"IP {request.ip_address} è in whitelist"
)
# Fetch routers
cursor.execute("SELECT * FROM routers WHERE enabled = true")
routers = cursor.fetchall()
if not routers:
raise HTTPException(status_code=400, detail="Nessun router configurato")
# Blocca su tutti i router
results = await mikrotik_manager.block_ip_on_all_routers(
routers,
request.ip_address,
list_name=request.list_name,
comment=request.comment or "Manual block",
timeout_duration=request.timeout_duration
)
success_count = sum(1 for v in results.values() if v)
cursor.close()
conn.close()
return {
"ip_address": request.ip_address,
"blocked_on": success_count,
"total_routers": len(routers),
"results": results
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
"""Sblocca un IP da tutti i router"""
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Fetch routers
cursor.execute("SELECT * FROM routers WHERE enabled = true")
routers = cursor.fetchall()
if not routers:
raise HTTPException(status_code=400, detail="Nessun router configurato")
# Sblocca da tutti i router
results = await mikrotik_manager.unblock_ip_on_all_routers(
routers,
request.ip_address,
list_name=request.list_name
)
success_count = sum(1 for v in results.values() if v)
# Aggiorna database
cursor.execute("""
UPDATE detections
SET blocked = false
WHERE source_ip = %s
""", (request.ip_address,))
conn.commit()
cursor.close()
conn.close()
return {
"ip_address": request.ip_address,
"unblocked_from": success_count,
"total_routers": len(routers),
"results": results
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stats")
async def get_stats():
"""Statistiche sistema"""
try:
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=RealDictCursor)
# Log stats
cursor.execute("SELECT COUNT(*) as total FROM network_logs")
result = cursor.fetchone()
total_logs = result['total'] if result else 0
cursor.execute("""
SELECT COUNT(*) as recent FROM network_logs
WHERE timestamp >= NOW() - INTERVAL '1 hour'
""")
result = cursor.fetchone()
recent_logs = result['recent'] if result else 0
# Detection stats
cursor.execute("SELECT COUNT(*) as total FROM detections")
result = cursor.fetchone()
total_detections = result['total'] if result else 0
cursor.execute("""
SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
""")
result = cursor.fetchone()
blocked_ips = result['blocked'] if result else 0
# Router stats
cursor.execute("SELECT COUNT(*) as total FROM routers WHERE enabled = true")
result = cursor.fetchone()
active_routers = result['total'] if result else 0
# Latest training
cursor.execute("""
SELECT * FROM training_history
ORDER BY trained_at DESC LIMIT 1
""")
latest_training = cursor.fetchone()
cursor.close()
conn.close()
return {
"logs": {
"total": total_logs,
"last_hour": recent_logs
},
"detections": {
"total": total_detections,
"blocked": blocked_ips
},
"routers": {
"active": active_routers
},
"latest_training": dict(latest_training) if latest_training else None
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Service Management Endpoints (Secured with API Key)
@app.get("/services/status")
async def get_services_status(authorized: bool = Security(verify_api_key)):
"""
Verifica status di tutti i servizi gestiti da systemd
RICHIEDE: API Key valida per accesso
"""
try:
import subprocess
services_to_check = {
"ml_backend": "ids-ml-backend",
"syslog_parser": "ids-syslog-parser"
}
results = {}
for service_name, systemd_unit in services_to_check.items():
try:
# Check systemd service status
result = subprocess.run(
["systemctl", "is-active", systemd_unit],
capture_output=True,
text=True,
timeout=5
)
is_active = result.stdout.strip() == "active"
# Get more details if active
details = {}
if is_active:
status_result = subprocess.run(
["systemctl", "status", systemd_unit, "--no-pager"],
capture_output=True,
text=True,
timeout=5
)
# Extract PID from status if available
for line in status_result.stdout.split('\n'):
if 'Main PID:' in line:
pid = line.split('Main PID:')[1].strip().split()[0]
details['pid'] = pid
break
results[service_name] = {
"running": is_active,
"systemd_unit": systemd_unit,
"details": details
}
except subprocess.TimeoutExpired:
results[service_name] = {
"running": False,
"error": "Timeout checking service"
}
except FileNotFoundError:
# systemctl not available (likely development environment)
results[service_name] = {
"running": False,
"error": "systemd not available"
}
except Exception as e:
results[service_name] = {
"running": False,
"error": str(e)
}
return {
"services": results,
"timestamp": datetime.now().isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
# Prova a caricare modello esistente
ml_analyzer.load_model()
print("🚀 Starting IDS API on http://0.0.0.0:8000")
print("📚 Docs available at http://0.0.0.0:8000/docs")
uvicorn.run(app, host="0.0.0.0", port=8000)