ids.alfacom.it/python_ml/main.py
marco370 7ba65c9d96 Fix errors when retrieving statistics by handling empty results
Update the get_stats function in main.py to safely fetch and process counts from the database, preventing potential errors when no records are found for total logs, recent logs, detections, blocked IPs, and active routers.

Replit-Commit-Author: Agent
Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528
Replit-Commit-Checkpoint-Type: full_checkpoint
Replit-Commit-Event-Id: 853b2085-c74d-4b3d-adeb-9db4276a24aa
Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/1P26v7M
2025-11-17 18:18:30 +00:00

456 lines
14 KiB
Python

"""
IDS Backend FastAPI - Intrusion Detection System
Handles ML training, real-time detection and communication with MikroTik routers.
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict
from datetime import datetime, timedelta
import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import asyncio
from ml_analyzer import MLAnalyzer
from mikrotik_manager import MikroTikManager

# Load environment variables (PG* database credentials) from a local .env file
load_dotenv()

app = FastAPI(title="IDS API", version="1.0.0")

# CORS: wide-open policy (any origin/method/header).
# NOTE(review): acceptable for an internal dashboard, but tighten
# allow_origins before exposing this API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global singletons shared by all endpoints
ml_analyzer = MLAnalyzer(model_dir="models")
mikrotik_manager = MikroTikManager()
# Database connection
def get_db_connection():
    """Open and return a new PostgreSQL connection.

    Credentials come from the PG* environment variables loaded via dotenv.
    Callers are responsible for closing the returned connection.
    """
    params = {
        "host": os.getenv("PGHOST"),
        "port": os.getenv("PGPORT"),
        "database": os.getenv("PGDATABASE"),
        "user": os.getenv("PGUSER"),
        "password": os.getenv("PGPASSWORD"),
    }
    return psycopg2.connect(**params)
# Pydantic models
class TrainRequest(BaseModel):
    # Request body for POST /train
    max_records: int = 10000  # cap on rows fetched from network_logs
    hours_back: int = 24  # size of the training time window, in hours
    contamination: float = 0.01  # expected anomaly fraction passed to the ML model — semantics defined in MLAnalyzer, confirm there
class DetectRequest(BaseModel):
    # Request body for POST /detect
    max_records: int = 5000  # cap on rows analyzed
    hours_back: int = 1  # size of the analysis time window, in hours
    risk_threshold: float = 60.0  # minimum risk score for a detection to be reported
    auto_block: bool = False  # if True, auto-block IPs with risk_score >= 80
class BlockIPRequest(BaseModel):
    # Request body for POST /block-ip
    ip_address: str
    list_name: str = "ddos_blocked"  # MikroTik address-list name
    comment: Optional[str] = None  # comment attached to the entry on the router
    timeout_duration: str = "1h"  # RouterOS timeout string for the block
class UnblockIPRequest(BaseModel):
    # Request body for POST /unblock-ip
    ip_address: str
    list_name: str = "ddos_blocked"  # MikroTik address-list to remove the IP from
# API Endpoints
@app.get("/")
async def root():
    """Service banner: name, version, status and whether an ML model is loaded."""
    model_ready = ml_analyzer.model is not None
    return {
        "service": "IDS API",
        "version": "1.0.0",
        "status": "running",
        "model_loaded": model_ready,
    }
@app.get("/health")
async def health_check():
    """Report system health: database reachability and ML model state."""
    try:
        # A successful connect+close proves the database is reachable
        get_db_connection().close()
    except Exception as exc:
        db_status = f"error: {str(exc)}"
    else:
        db_status = "connected"

    model_state = "loaded" if ml_analyzer.model is not None else "not_loaded"
    return {
        "status": "healthy",
        "database": db_status,
        "ml_model": model_state,
        "timestamp": datetime.now().isoformat(),
    }
@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Train the ML model on recent logs.

    The actual training runs as a background task so this endpoint returns
    immediately. The outcome is recorded in the training_history table.
    """
    def do_training():
        # Background worker: fetch logs, train the model, persist the result.
        conn = None
        cursor = None
        try:
            conn = get_db_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)
            # Fetch recent logs inside the requested time window
            min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
            query = """
                SELECT * FROM network_logs
                WHERE timestamp >= %s
                ORDER BY timestamp DESC
                LIMIT %s
            """
            cursor.execute(query, (min_timestamp, request.max_records))
            logs = cursor.fetchall()
            if not logs:
                print("[TRAIN] Nessun log trovato per training")
                return
            # Convert rows to a DataFrame for the analyzer
            df = pd.DataFrame(logs)
            result = ml_analyzer.train(df, contamination=request.contamination)
            # Record the training run
            cursor.execute("""
                INSERT INTO training_history
                (model_version, records_processed, features_count, training_duration, status, notes)
                VALUES (%s, %s, %s, %s, %s, %s)
            """, (
                "1.0.0",
                result['records_processed'],
                result['features_count'],
                0,  # duration not implemented yet
                result['status'],
                f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
            ))
            conn.commit()
            print(f"[TRAIN] Completato: {result}")
        except Exception as e:
            print(f"[TRAIN ERROR] {e}")
        finally:
            # Fix: always release DB resources — the original leaked the
            # connection whenever training or the INSERT raised.
            if cursor is not None:
                cursor.close()
            if conn is not None:
                conn.close()

    # Run training without blocking the API response
    background_tasks.add_task(do_training)
    return {
        "message": "Training avviato in background",
        "max_records": request.max_records,
        "hours_back": request.hours_back
    }
@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
    """
    Detect anomalies in recent logs.

    Optionally auto-blocks anomalous IPs (risk_score >= 80, not whitelisted)
    on every enabled router. Raises 400 when no trained model is available
    and 500 on any database/ML failure.
    """
    if ml_analyzer.model is None:
        # Try to load a previously saved model before giving up
        if not ml_analyzer.load_model():
            raise HTTPException(
                status_code=400,
                detail="Modello non addestrato. Esegui /train prima."
            )
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)
        # Fetch recent logs inside the requested time window
        min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
        query = """
            SELECT * FROM network_logs
            WHERE timestamp >= %s
            ORDER BY timestamp DESC
            LIMIT %s
        """
        cursor.execute(query, (min_timestamp, request.max_records))
        logs = cursor.fetchall()
        if not logs:
            return {"detections": [], "message": "Nessun log da analizzare"}
        # Run ML detection over the logs
        df = pd.DataFrame(logs)
        detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)
        # Upsert each detection: refresh the most recent row for the IP if
        # one exists, otherwise insert a new one.
        for det in detections:
            cursor.execute(
                "SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
                (det['source_ip'],)
            )
            existing = cursor.fetchone()
            if existing:
                cursor.execute("""
                    UPDATE detections
                    SET risk_score = %s, confidence = %s, anomaly_type = %s,
                        reason = %s, log_count = %s, last_seen = %s
                    WHERE id = %s
                """, (
                    det['risk_score'], det['confidence'], det['anomaly_type'],
                    det['reason'], det['log_count'], det['last_seen'],
                    existing['id']
                ))
            else:
                cursor.execute("""
                    INSERT INTO detections
                    (source_ip, risk_score, confidence, anomaly_type, reason,
                     log_count, first_seen, last_seen)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    det['source_ip'], det['risk_score'], det['confidence'],
                    det['anomaly_type'], det['reason'], det['log_count'],
                    det['first_seen'], det['last_seen']
                ))
        conn.commit()
        # Auto-block if requested
        blocked_count = 0
        if request.auto_block and detections:
            cursor.execute("SELECT * FROM routers WHERE enabled = true")
            routers = cursor.fetchall()
            if routers:
                for det in detections:
                    if det['risk_score'] >= 80:  # CRITICAL risk only
                        # Never block whitelisted IPs
                        cursor.execute(
                            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
                            (det['source_ip'],)
                        )
                        if cursor.fetchone():
                            continue
                        # Block on every enabled router
                        results = await mikrotik_manager.block_ip_on_all_routers(
                            routers,
                            det['source_ip'],
                            comment=f"IDS: {det['anomaly_type']}"
                        )
                        if any(results.values()):
                            # Mark the detection as blocked
                            cursor.execute("""
                                UPDATE detections
                                SET blocked = true, blocked_at = NOW()
                                WHERE source_ip = %s
                            """, (det['source_ip'],))
                            blocked_count += 1
            conn.commit()
        return {
            "detections": detections,
            "total": len(detections),
            "blocked": blocked_count if request.auto_block else 0,
            "message": f"Trovate {len(detections)} anomalie"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Fix: release DB resources on every path — the original leaked the
        # connection whenever any query or the detection step raised.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
    """
    Manually block an IP on all enabled routers.

    Raises 400 when the IP is whitelisted or no router is configured,
    500 on unexpected failures.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)
        # Refuse to block whitelisted IPs
        cursor.execute(
            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
            (request.ip_address,)
        )
        if cursor.fetchone():
            raise HTTPException(
                status_code=400,
                detail=f"IP {request.ip_address} è in whitelist"
            )
        # Fetch enabled routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()
        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")
        # Block on every router
        results = await mikrotik_manager.block_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name,
            comment=request.comment or "Manual block",
            timeout_duration=request.timeout_duration
        )
        success_count = sum(1 for v in results.values() if v)
        return {
            "ip_address": request.ip_address,
            "blocked_on": success_count,
            "total_routers": len(routers),
            "results": results
        }
    except HTTPException:
        # Re-raise intended client errors untouched
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Fix: close DB resources on every path — the original leaked the
        # connection whenever an HTTPException was raised before cursor.close().
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
    """
    Unblock an IP from all enabled routers and mark it unblocked in the DB.

    Raises 400 when no router is configured, 500 on unexpected failures.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)
        # Fetch enabled routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()
        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")
        # Unblock from every router
        results = await mikrotik_manager.unblock_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name
        )
        success_count = sum(1 for v in results.values() if v)
        # Mark the detection rows for this IP as no longer blocked
        cursor.execute("""
            UPDATE detections
            SET blocked = false
            WHERE source_ip = %s
        """, (request.ip_address,))
        conn.commit()
        return {
            "ip_address": request.ip_address,
            "unblocked_from": success_count,
            "total_routers": len(routers),
            "results": results
        }
    except HTTPException:
        # Fix: the original's generic `except Exception` swallowed the
        # intended 400 ("Nessun router configurato") and returned a 500.
        # Re-raise here, consistent with block_ip.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Fix: close DB resources on every path (original leaked on errors)
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.get("/stats")
async def get_stats():
    """
    System statistics: log volume, detection counts, active routers and
    the latest training run. Raises 500 on any database failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        def _count(sql: str, key: str) -> int:
            # Run a COUNT(*) query and return the named column, defaulting
            # to 0 if no row comes back (defensive, matching the earlier
            # "handle empty results" fix).
            cursor.execute(sql)
            row = cursor.fetchone()
            return row[key] if row else 0

        # Log stats
        total_logs = _count("SELECT COUNT(*) as total FROM network_logs", "total")
        recent_logs = _count(
            """
            SELECT COUNT(*) as recent FROM network_logs
            WHERE timestamp >= NOW() - INTERVAL '1 hour'
            """,
            "recent",
        )
        # Detection stats
        total_detections = _count("SELECT COUNT(*) as total FROM detections", "total")
        blocked_ips = _count(
            "SELECT COUNT(*) as blocked FROM detections WHERE blocked = true",
            "blocked",
        )
        # Router stats
        active_routers = _count(
            "SELECT COUNT(*) as total FROM routers WHERE enabled = true",
            "total",
        )
        # Latest training run, if any
        cursor.execute("""
            SELECT * FROM training_history
            ORDER BY trained_at DESC LIMIT 1
        """)
        latest_training = cursor.fetchone()
        return {
            "logs": {
                "total": total_logs,
                "last_hour": recent_logs
            },
            "detections": {
                "total": total_detections,
                "blocked": blocked_ips
            },
            "routers": {
                "active": active_routers
            },
            "latest_training": dict(latest_training) if latest_training else None
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Fix: release DB resources on every path — the original leaked the
        # connection whenever any of the queries raised.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
if __name__ == "__main__":
    import uvicorn
    # Try to load a previously saved model so /detect works without retraining
    ml_analyzer.load_model()
    print("🚀 Starting IDS API on http://0.0.0.0:8000")
    print("📚 Docs available at http://0.0.0.0:8000/docs")
    uvicorn.run(app, host="0.0.0.0", port=8000)