Update the get_stats function in main.py to safely fetch and process counts from the database, preventing potential errors when no records are found for total logs, recent logs, detections, blocked IPs, and active routers. Replit-Commit-Author: Agent Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528 Replit-Commit-Checkpoint-Type: full_checkpoint Replit-Commit-Event-Id: 853b2085-c74d-4b3d-adeb-9db4276a24aa Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/1P26v7M
456 lines
14 KiB
Python
456 lines
14 KiB
Python
"""
IDS Backend FastAPI - Intrusion Detection System

Handles ML training, real-time detection, and communication with MikroTik routers.
"""
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional, Dict
|
|
from datetime import datetime, timedelta
|
|
import pandas as pd
|
|
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
import os
|
|
from dotenv import load_dotenv
|
|
import asyncio
|
|
|
|
from ml_analyzer import MLAnalyzer
|
|
from mikrotik_manager import MikroTikManager
|
|
|
|
# Read variables from a .env file into the process environment before
# anything below is constructed.
load_dotenv()

# The FastAPI application object every endpoint in this module is attached to.
app = FastAPI(title="IDS API", version="1.0.0")

# CORS: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-wide singletons shared by all request handlers.
ml_analyzer = MLAnalyzer(model_dir="models")
mikrotik_manager = MikroTikManager()
|
|
|
|
# Database connection
|
|
def get_db_connection():
    """Open and return a new psycopg2 connection.

    Connection parameters are read from the PGHOST, PGPORT, PGDATABASE,
    PGUSER and PGPASSWORD environment variables. The caller owns the
    returned connection and is responsible for closing it.
    """
    settings = {
        "host": os.getenv("PGHOST"),
        "port": os.getenv("PGPORT"),
        "database": os.getenv("PGDATABASE"),
        "user": os.getenv("PGUSER"),
        "password": os.getenv("PGPASSWORD"),
    }
    return psycopg2.connect(**settings)
|
|
|
|
# Pydantic models
|
|
class TrainRequest(BaseModel):
    # Request body for POST /train.

    # Upper bound on the number of log rows fetched for training.
    max_records: int = 10000

    # Only logs newer than this many hours are considered.
    hours_back: int = 24

    # Forwarded unchanged to ml_analyzer.train(contamination=...).
    contamination: float = 0.01
|
|
|
|
class DetectRequest(BaseModel):
    # Request body for POST /detect.

    # Upper bound on the number of log rows fetched for analysis.
    max_records: int = 5000

    # Only logs newer than this many hours are analysed.
    hours_back: int = 1

    # Forwarded unchanged to ml_analyzer.detect(risk_threshold=...).
    risk_threshold: float = 60.0

    # When true, critical detections are also blocked on the enabled routers.
    auto_block: bool = False
|
|
|
|
class BlockIPRequest(BaseModel):
    # Request body for POST /block-ip.

    # Address to block.
    ip_address: str

    # Forwarded as list_name to the MikroTik manager.
    list_name: str = "ddos_blocked"

    # Optional free-text comment; the endpoint substitutes "Manual block"
    # when this is empty or missing.
    comment: Optional[str] = None

    # Forwarded as timeout_duration to the MikroTik manager.
    timeout_duration: str = "1h"
|
|
|
|
class UnblockIPRequest(BaseModel):
    # Request body for POST /unblock-ip.

    # Address to unblock.
    ip_address: str

    # Forwarded as list_name to the MikroTik manager.
    list_name: str = "ddos_blocked"
|
|
|
|
|
|
# API Endpoints
|
|
|
|
@app.get("/")
async def root():
    """Return the service name, version, status and whether a model is loaded."""
    model_loaded = ml_analyzer.model is not None
    return {
        "service": "IDS API",
        "version": "1.0.0",
        "status": "running",
        "model_loaded": model_loaded,
    }
|
|
|
|
@app.get("/health")
async def health_check():
    """Check system health.

    Opens and immediately closes a database connection to verify
    connectivity, and reports whether an ML model is currently loaded.

    Returns:
        A dict with ``status`` ("healthy" when the database is reachable,
        "degraded" otherwise), ``database`` ("connected" or
        "error: <message>"), ``ml_model`` ("loaded" / "not_loaded") and an
        ISO-formatted ``timestamp``.
    """
    try:
        get_db_connection().close()
        db_status = "connected"
        db_ok = True
    except Exception as e:
        db_status = f"error: {str(e)}"
        db_ok = False

    return {
        # Previously this was always "healthy", even when the database was
        # unreachable; the overall status now reflects the failed check.
        "status": "healthy" if db_ok else "degraded",
        "database": db_status,
        "ml_model": "loaded" if ml_analyzer.model is not None else "not_loaded",
        "timestamp": datetime.now().isoformat(),
    }
|
|
|
|
@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Train the ML model on recent logs.

    The training is scheduled as a background task so the request returns
    immediately without waiting for it to finish.

    Args:
        request: Training parameters (record limit, look-back window in
            hours, contamination passed to the analyzer).
        background_tasks: FastAPI task queue the training job is added to.

    Returns:
        A confirmation dict echoing ``max_records`` and ``hours_back``.
    """
    def do_training():
        # Errors are printed rather than raised because this runs after the
        # response has been sent. The connection is always released in
        # ``finally`` -- including the early "no logs" return.
        conn = None
        cursor = None
        try:
            conn = get_db_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # Fetch recent logs
            min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
            query = """
                SELECT * FROM network_logs
                WHERE timestamp >= %s
                ORDER BY timestamp DESC
                LIMIT %s
            """
            cursor.execute(query, (min_timestamp, request.max_records))
            logs = cursor.fetchall()

            if not logs:
                print("[TRAIN] Nessun log trovato per training")
                return

            # Convert to DataFrame
            df = pd.DataFrame(logs)

            # Training
            result = ml_analyzer.train(df, contamination=request.contamination)

            # Persist the outcome in the training history table
            cursor.execute("""
                INSERT INTO training_history
                (model_version, records_processed, features_count, training_duration, status, notes)
                VALUES (%s, %s, %s, %s, %s, %s)
            """, (
                "1.0.0",
                result['records_processed'],
                result['features_count'],
                0,  # duration is not implemented yet
                result['status'],
                f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
            ))
            conn.commit()

            print(f"[TRAIN] Completato: {result}")

        except Exception as e:
            print(f"[TRAIN ERROR] {e}")

        finally:
            if cursor is not None:
                cursor.close()
            if conn is not None:
                conn.close()

    # Run in the background
    background_tasks.add_task(do_training)

    return {
        "message": "Training avviato in background",
        "max_records": request.max_records,
        "hours_back": request.hours_back
    }
|
|
|
|
@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
    """
    Detect anomalies in recent logs, optionally auto-blocking anomalous IPs.

    Fetches up to ``request.max_records`` rows from ``network_logs`` newer
    than ``request.hours_back`` hours and passes them to
    ``ml_analyzer.detect``. Each detection is stored: the most recent
    ``detections`` row for the same source IP is updated if one exists,
    otherwise a new row is inserted. When ``request.auto_block`` is set,
    detections with ``risk_score >= 80`` whose IP is not on the active
    whitelist are blocked on every enabled router.

    Returns:
        A dict with ``detections``, ``total``, ``blocked`` and ``message``;
        or only ``detections`` (empty) and ``message`` when there are no
        logs to analyse.

    Raises:
        HTTPException: 400 when no model is loaded and none can be loaded
            from disk; 500 on any other failure.
    """
    # Try to load a saved model before giving up.
    if ml_analyzer.model is None and not ml_analyzer.load_model():
        raise HTTPException(
            status_code=400,
            detail="Modello non addestrato. Esegui /train prima."
        )

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch recent logs
        min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
        query = """
            SELECT * FROM network_logs
            WHERE timestamp >= %s
            ORDER BY timestamp DESC
            LIMIT %s
        """
        cursor.execute(query, (min_timestamp, request.max_records))
        logs = cursor.fetchall()

        if not logs:
            # The connection is released by the ``finally`` clause below.
            return {"detections": [], "message": "Nessun log da analizzare"}

        # Convert to DataFrame
        df = pd.DataFrame(logs)

        # Detection
        detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)

        # Save detections to the database
        for det in detections:
            # Check whether a detection for this IP already exists
            cursor.execute(
                "SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
                (det['source_ip'],)
            )
            existing = cursor.fetchone()

            if existing:
                # Update the existing row
                cursor.execute("""
                    UPDATE detections
                    SET risk_score = %s, confidence = %s, anomaly_type = %s,
                        reason = %s, log_count = %s, last_seen = %s
                    WHERE id = %s
                """, (
                    det['risk_score'], det['confidence'], det['anomaly_type'],
                    det['reason'], det['log_count'], det['last_seen'],
                    existing['id']
                ))
            else:
                # Insert a new row
                cursor.execute("""
                    INSERT INTO detections
                    (source_ip, risk_score, confidence, anomaly_type, reason,
                     log_count, first_seen, last_seen)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    det['source_ip'], det['risk_score'], det['confidence'],
                    det['anomaly_type'], det['reason'], det['log_count'],
                    det['first_seen'], det['last_seen']
                ))

        conn.commit()

        # Auto-block if requested
        blocked_count = 0
        if request.auto_block and detections:
            # Fetch enabled routers
            cursor.execute("SELECT * FROM routers WHERE enabled = true")
            routers = cursor.fetchall()

            if routers:
                for det in detections:
                    if det['risk_score'] < 80:  # Only CRITICAL risk is blocked
                        continue

                    # Check whitelist
                    cursor.execute(
                        "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
                        (det['source_ip'],)
                    )
                    if cursor.fetchone():
                        continue  # Skip whitelisted

                    # Block on all routers
                    results = await mikrotik_manager.block_ip_on_all_routers(
                        routers,
                        det['source_ip'],
                        comment=f"IDS: {det['anomaly_type']}"
                    )

                    if any(results.values()):
                        # Mark the detection as blocked
                        cursor.execute("""
                            UPDATE detections
                            SET blocked = true, blocked_at = NOW()
                            WHERE source_ip = %s
                        """, (det['source_ip'],))
                        blocked_count += 1

                conn.commit()

        return {
            "detections": detections,
            "total": len(detections),
            # blocked_count is only ever incremented when auto_block is set.
            "blocked": blocked_count,
            "message": f"Trovate {len(detections)} anomalie"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Always release the database resources, on success, early return
        # or failure (the original leaked them on the latter two paths).
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
|
|
|
|
@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
    """Manually block an IP on all enabled routers.

    The request is rejected when the IP is on the active whitelist or when
    no enabled router exists.

    Returns:
        A dict with ``ip_address``, ``blocked_on`` (number of truthy entries
        in the per-router results), ``total_routers`` and the raw
        ``results`` returned by the MikroTik manager.

    Raises:
        HTTPException: 400 when the IP is whitelisted or no router is
            configured; 500 on any other failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Check whitelist
        cursor.execute(
            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
            (request.ip_address,)
        )
        if cursor.fetchone():
            raise HTTPException(
                status_code=400,
                detail=f"IP {request.ip_address} è in whitelist"
            )

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")

        # Block on all routers
        results = await mikrotik_manager.block_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name,
            comment=request.comment or "Manual block",
            timeout_duration=request.timeout_duration
        )

        success_count = sum(1 for v in results.values() if v)

        return {
            "ip_address": request.ip_address,
            "blocked_on": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        # Preserve the deliberate 4xx responses raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the database resources on every path; the original leaked
        # them whenever an exception (including the 400s) was raised.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
|
|
|
|
@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
    """Unblock an IP on all enabled routers.

    After the routers have been contacted, every ``detections`` row for the
    IP is marked as not blocked.

    Returns:
        A dict with ``ip_address``, ``unblocked_from`` (number of truthy
        entries in the per-router results), ``total_routers`` and the raw
        ``results`` returned by the MikroTik manager.

    Raises:
        HTTPException: 400 when no router is configured; 500 on any other
            failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")

        # Unblock on all routers
        results = await mikrotik_manager.unblock_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name
        )

        success_count = sum(1 for v in results.values() if v)

        # Update database
        cursor.execute("""
            UPDATE detections
            SET blocked = false
            WHERE source_ip = %s
        """, (request.ip_address,))
        conn.commit()

        return {
            "ip_address": request.ip_address,
            "unblocked_from": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        # Without this clause the 400 above was swallowed by the generic
        # handler below and reported to the client as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release the database resources on every path.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
|
|
|
|
@app.get("/stats")
async def get_stats():
    """System statistics.

    Returns:
        A dict with:
        ``logs`` -- total row count of ``network_logs`` and the count from
        the last hour;
        ``detections`` -- total row count of ``detections`` and how many
        are flagged as blocked;
        ``routers`` -- number of enabled routers;
        ``latest_training`` -- the most recent ``training_history`` row as
        a dict, or ``None`` when the table is empty.

    Raises:
        HTTPException: 500 on any failure.
    """
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        def fetch_count(sql, column):
            # Run a single-row COUNT query and read the aliased column,
            # falling back to 0 if no row comes back.
            cursor.execute(sql)
            row = cursor.fetchone()
            return row[column] if row else 0

        # Log stats
        total_logs = fetch_count(
            "SELECT COUNT(*) as total FROM network_logs", "total"
        )
        recent_logs = fetch_count("""
            SELECT COUNT(*) as recent FROM network_logs
            WHERE timestamp >= NOW() - INTERVAL '1 hour'
        """, "recent")

        # Detection stats
        total_detections = fetch_count(
            "SELECT COUNT(*) as total FROM detections", "total"
        )
        blocked_ips = fetch_count("""
            SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
        """, "blocked")

        # Router stats
        active_routers = fetch_count(
            "SELECT COUNT(*) as total FROM routers WHERE enabled = true", "total"
        )

        # Latest training
        cursor.execute("""
            SELECT * FROM training_history
            ORDER BY trained_at DESC LIMIT 1
        """)
        latest_training = cursor.fetchone()

        return {
            "logs": {
                "total": total_logs,
                "last_hour": recent_logs
            },
            "detections": {
                "total": total_detections,
                "blocked": blocked_ips
            },
            "routers": {
                "active": active_routers
            },
            "latest_training": dict(latest_training) if latest_training else None
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Release the database resources even when a query fails; the
        # original only closed them on the success path.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Try to load a previously saved model; the return value is ignored,
    # so the server starts either way.
    ml_analyzer.load_model()

    for banner in (
        "🚀 Starting IDS API on http://0.0.0.0:8000",
        "📚 Docs available at http://0.0.0.0:8000/docs",
    ):
        print(banner)

    uvicorn.run(app, host="0.0.0.0", port=8000)
|