Integrates IP geolocation and Autonomous System (AS) information into detection records: the backend now performs asynchronous batch lookups for efficiency, and the frontend displays the new data. The change includes database schema updates and the creation of a new IP geolocation service.
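The new `ip_geolocation` module itself is not included in this file. For reference, the sketch below illustrates the interface that `main.py` relies on — `get_geo_service()` returning an object with an async `lookup_batch_async()` method, and result keys matching the columns written in `/detect` (`country`, `country_code`, `city`, `organization`, `as_number`, `as_name`, `isp`). The provider (ip-api.com), `aiohttp` usage, and in-memory cache are assumptions for illustration only; the real module may differ.

```python
# Illustrative sketch of ip_geolocation.py's interface (not the actual module).
import asyncio
from typing import Dict, List, Optional

import aiohttp


class IPGeolocationService:
    """Minimal async batch IP geolocation lookup (illustrative only)."""

    def __init__(self) -> None:
        # Simple in-memory cache so repeated detections don't re-query the provider
        self._cache: Dict[str, Optional[dict]] = {}

    async def _lookup_one(self, session: aiohttp.ClientSession, ip: str) -> Optional[dict]:
        if ip in self._cache:
            return self._cache[ip]
        try:
            # ip-api.com is one free provider; any geo/ASN source could be used here
            async with session.get(
                f"http://ip-api.com/json/{ip}",
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                data = await resp.json()
        except Exception:
            data = None
        info = None
        if data and data.get("status") == "success":
            as_field = (data.get("as") or "").split(" ", 1)  # e.g. "AS15169 Google LLC"
            info = {
                "country": data.get("country"),
                "country_code": data.get("countryCode"),
                "city": data.get("city"),
                "organization": data.get("org"),
                "as_number": as_field[0] if as_field[0] else None,
                "as_name": as_field[1] if len(as_field) > 1 else None,
                "isp": data.get("isp"),
            }
        self._cache[ip] = info
        return info

    async def lookup_batch_async(self, ips: List[str]) -> Dict[str, Optional[dict]]:
        """Look up all IPs concurrently and return {ip: geo_info_or_None}."""
        async with aiohttp.ClientSession() as session:
            results = await asyncio.gather(*(self._lookup_one(session, ip) for ip in ips))
        return dict(zip(ips, results))


_service: Optional[IPGeolocationService] = None


def get_geo_service() -> IPGeolocationService:
    """Return a shared service instance, as main.py expects."""
    global _service
    if _service is None:
        _service = IPGeolocationService()
    return _service
```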
"""
|
|
IDS Backend FastAPI - Intrusion Detection System
|
|
Gestisce training ML, detection real-time e comunicazione con router MikroTik
|
|
"""
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Security, Header
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import APIKeyHeader
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional, Dict
|
|
from datetime import datetime, timedelta
|
|
import pandas as pd
|
|
import psycopg2
|
|
from psycopg2.extras import RealDictCursor
|
|
import os
|
|
from dotenv import load_dotenv
|
|
import asyncio
|
|
import secrets
|
|
|
|
from ml_analyzer import MLAnalyzer
|
|
from mikrotik_manager import MikroTikManager
|
|
from ip_geolocation import get_geo_service
|
|
|
|
# Load environment variables
|
|
load_dotenv()
|
|
|
|

# API Key Security
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

def get_api_key():
    """Get API key from environment"""
    return os.getenv("IDS_API_KEY")

async def verify_api_key(api_key: str = Security(api_key_header)):
    """Verify API key for internal service communication"""
    expected_key = get_api_key()

    # In development without API key, allow access (backward compatibility)
    if not expected_key:
        return True

    # Constant-time comparison to avoid leaking key material via timing
    if not api_key or not secrets.compare_digest(api_key, expected_key):
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API key"
        )
    return True

app = FastAPI(title="IDS API", version="1.0.0")

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances
ml_analyzer = MLAnalyzer(model_dir="models")
mikrotik_manager = MikroTikManager()

# Database connection
def get_db_connection():
    return psycopg2.connect(
        host=os.getenv("PGHOST"),
        port=os.getenv("PGPORT"),
        database=os.getenv("PGDATABASE"),
        user=os.getenv("PGUSER"),
        password=os.getenv("PGPASSWORD")
    )

# Pydantic models
class TrainRequest(BaseModel):
    max_records: int = 10000
    hours_back: int = 24
    contamination: float = 0.01

class DetectRequest(BaseModel):
    max_records: int = 5000
    hours_back: int = 1
    risk_threshold: float = 60.0
    auto_block: bool = False

class BlockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"
    comment: Optional[str] = None
    timeout_duration: str = "1h"

class UnblockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"


# API Endpoints

@app.get("/")
async def root():
    return {
        "service": "IDS API",
        "version": "1.0.0",
        "status": "running",
        "model_loaded": ml_analyzer.model is not None
    }

@app.get("/health")
async def health_check():
    """Check system health"""
    try:
        conn = get_db_connection()
        conn.close()
        db_status = "connected"
    except Exception as e:
        db_status = f"error: {str(e)}"

    return {
        "status": "healthy",
        "database": db_status,
        "ml_model": "loaded" if ml_analyzer.model is not None else "not_loaded",
        "timestamp": datetime.now().isoformat()
    }

@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Train the ML model on recent logs.
    Runs in the background so the API is not blocked.
    """
    def do_training():
        conn = None
        cursor = None
        try:
            print("[TRAIN] Starting training...")
            conn = get_db_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # Fetch recent logs
            min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
            query = """
                SELECT * FROM network_logs
                WHERE timestamp >= %s
                ORDER BY timestamp DESC
                LIMIT %s
            """
            cursor.execute(query, (min_timestamp, request.max_records))
            logs = cursor.fetchall()

            if not logs:
                print("[TRAIN] No logs found for training")
                return

            print(f"[TRAIN] Found {len(logs)} logs for training")

            # Convert to DataFrame
            df = pd.DataFrame(logs)

            # Training
            print("[TRAIN] Training model...")
            result = ml_analyzer.train(df, contamination=request.contamination)
            print(f"[TRAIN] Model trained: {result}")

            # Save to database
            print("[TRAIN] Saving training history to database...")
            cursor.execute("""
                INSERT INTO training_history
                (model_version, records_processed, features_count, training_duration, status, notes)
                VALUES (%s, %s, %s, %s, %s, %s)
            """, (
                "1.0.0",
                result['records_processed'],
                result['features_count'],
                0,  # duration not implemented yet
                result['status'],
                f"Anomalies: {result['anomalies_detected']}/{result['unique_ips']}"
            ))
            conn.commit()
            print("[TRAIN] ✅ Training history saved successfully!")

            cursor.close()
            conn.close()

            print(f"[TRAIN] Completed: {result}")

        except Exception as e:
            print(f"[TRAIN ERROR] ❌ Error during training: {e}")
            import traceback
            traceback.print_exc()
            if conn:
                conn.rollback()
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    # Run in background
    background_tasks.add_task(do_training)

    return {
        "message": "Training started in background",
        "max_records": request.max_records,
        "hours_back": request.hours_back
    }

@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
    """
    Detect anomalies in recent logs.
    Optionally auto-blocks anomalous IPs.
    """
    if ml_analyzer.model is None:
        # Try to load a saved model
        if not ml_analyzer.load_model():
            raise HTTPException(
                status_code=400,
                detail="Model not trained. Run /train first."
            )

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch recent logs
        min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
        query = """
            SELECT * FROM network_logs
            WHERE timestamp >= %s
            ORDER BY timestamp DESC
            LIMIT %s
        """
        cursor.execute(query, (min_timestamp, request.max_records))
        logs = cursor.fetchall()

        if not logs:
            return {"detections": [], "message": "No logs to analyze"}

        # Convert to DataFrame
        df = pd.DataFrame(logs)

        # Detection
        detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)

        # Geolocation lookup service - async batch for performance
        geo_service = get_geo_service()

        # Extract the list of unique IPs for the batch lookup
        unique_ips = list(set(det['source_ip'] for det in detections))

        # Async batch lookup (fast - all requests run in parallel)
        geo_results = await geo_service.lookup_batch_async(unique_ips)

        # Save detections to the database
        for det in detections:
            # Get geo info from batch results
            geo_info = geo_results.get(det['source_ip'])

            # Check whether a detection for this IP already exists
            cursor.execute(
                "SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
                (det['source_ip'],)
            )
            existing = cursor.fetchone()

            if existing:
                # Update the existing row (with geo info)
                cursor.execute("""
                    UPDATE detections
                    SET risk_score = %s, confidence = %s, anomaly_type = %s,
                        reason = %s, log_count = %s, last_seen = %s,
                        country = %s, country_code = %s, city = %s,
                        organization = %s, as_number = %s, as_name = %s, isp = %s
                    WHERE id = %s
                """, (
                    det['risk_score'], det['confidence'], det['anomaly_type'],
                    det['reason'], det['log_count'], det['last_seen'],
                    geo_info.get('country') if geo_info else None,
                    geo_info.get('country_code') if geo_info else None,
                    geo_info.get('city') if geo_info else None,
                    geo_info.get('organization') if geo_info else None,
                    geo_info.get('as_number') if geo_info else None,
                    geo_info.get('as_name') if geo_info else None,
                    geo_info.get('isp') if geo_info else None,
                    existing['id']
                ))
            else:
                # Insert a new row (with geo info)
                cursor.execute("""
                    INSERT INTO detections
                    (source_ip, risk_score, confidence, anomaly_type, reason,
                     log_count, first_seen, last_seen,
                     country, country_code, city, organization, as_number, as_name, isp)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    det['source_ip'], det['risk_score'], det['confidence'],
                    det['anomaly_type'], det['reason'], det['log_count'],
                    det['first_seen'], det['last_seen'],
                    geo_info.get('country') if geo_info else None,
                    geo_info.get('country_code') if geo_info else None,
                    geo_info.get('city') if geo_info else None,
                    geo_info.get('organization') if geo_info else None,
                    geo_info.get('as_number') if geo_info else None,
                    geo_info.get('as_name') if geo_info else None,
                    geo_info.get('isp') if geo_info else None
                ))

        conn.commit()

        # Auto-block if requested
        blocked_count = 0
        if request.auto_block and detections:
            # Fetch enabled routers
            cursor.execute("SELECT * FROM routers WHERE enabled = true")
            routers = cursor.fetchall()

            if routers:
                for det in detections:
                    if det['risk_score'] >= 80:  # CRITICAL risk only
                        # Check whitelist
                        cursor.execute(
                            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
                            (det['source_ip'],)
                        )
                        if cursor.fetchone():
                            continue  # Skip whitelisted

                        # Block on all routers
                        results = await mikrotik_manager.block_ip_on_all_routers(
                            routers,
                            det['source_ip'],
                            comment=f"IDS: {det['anomaly_type']}"
                        )

                        if any(results.values()):
                            # Update the detection record
                            cursor.execute("""
                                UPDATE detections
                                SET blocked = true, blocked_at = NOW()
                                WHERE source_ip = %s
                            """, (det['source_ip'],))
                            blocked_count += 1

                conn.commit()

        return {
            "detections": detections,
            "total": len(detections),
            "blocked": blocked_count if request.auto_block else 0,
            "message": f"Found {len(detections)} anomalies"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release the database connection, even on errors
        if cursor:
            cursor.close()
        if conn:
            conn.close()

@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
    """Manually block an IP on all routers"""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Check whitelist
        cursor.execute(
            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
            (request.ip_address,)
        )
        if cursor.fetchone():
            raise HTTPException(
                status_code=400,
                detail=f"IP {request.ip_address} is whitelisted"
            )

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="No routers configured")

        # Block on all routers
        results = await mikrotik_manager.block_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name,
            comment=request.comment or "Manual block",
            timeout_duration=request.timeout_duration
        )

        success_count = sum(1 for v in results.values() if v)

        return {
            "ip_address": request.ip_address,
            "blocked_on": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release the database connection
        if cursor:
            cursor.close()
        if conn:
            conn.close()

@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
    """Unblock an IP on all routers"""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="No routers configured")

        # Unblock on all routers
        results = await mikrotik_manager.unblock_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name
        )

        success_count = sum(1 for v in results.values() if v)

        # Update database
        cursor.execute("""
            UPDATE detections
            SET blocked = false
            WHERE source_ip = %s
        """, (request.ip_address,))
        conn.commit()

        return {
            "ip_address": request.ip_address,
            "unblocked_from": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        # Propagate expected client errors (e.g. no routers configured) unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release the database connection
        if cursor:
            cursor.close()
        if conn:
            conn.close()

@app.get("/stats")
async def get_stats():
    """System statistics"""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Log stats
        cursor.execute("SELECT COUNT(*) as total FROM network_logs")
        result = cursor.fetchone()
        total_logs = result['total'] if result else 0

        cursor.execute("""
            SELECT COUNT(*) as recent FROM network_logs
            WHERE timestamp >= NOW() - INTERVAL '1 hour'
        """)
        result = cursor.fetchone()
        recent_logs = result['recent'] if result else 0

        # Detection stats
        cursor.execute("SELECT COUNT(*) as total FROM detections")
        result = cursor.fetchone()
        total_detections = result['total'] if result else 0

        cursor.execute("""
            SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
        """)
        result = cursor.fetchone()
        blocked_ips = result['blocked'] if result else 0

        # Router stats
        cursor.execute("SELECT COUNT(*) as total FROM routers WHERE enabled = true")
        result = cursor.fetchone()
        active_routers = result['total'] if result else 0

        # Latest training
        cursor.execute("""
            SELECT * FROM training_history
            ORDER BY trained_at DESC LIMIT 1
        """)
        latest_training = cursor.fetchone()

        return {
            "logs": {
                "total": total_logs,
                "last_hour": recent_logs
            },
            "detections": {
                "total": total_detections,
                "blocked": blocked_ips
            },
            "routers": {
                "active": active_routers
            },
            "latest_training": dict(latest_training) if latest_training else None
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release the database connection
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# Service Management Endpoints (Secured with API Key)

@app.get("/services/status")
async def get_services_status(authorized: bool = Security(verify_api_key)):
    """
    Check the status of all services managed by systemd.
    REQUIRES: a valid API key.
    """
    try:
        import subprocess

        services_to_check = {
            "ml_backend": "ids-ml-backend",
            "syslog_parser": "ids-syslog-parser"
        }

        results = {}

        for service_name, systemd_unit in services_to_check.items():
            try:
                # Check systemd service status
                result = subprocess.run(
                    ["systemctl", "is-active", systemd_unit],
                    capture_output=True,
                    text=True,
                    timeout=5
                )

                is_active = result.stdout.strip() == "active"

                # Get more details if active
                details = {}
                if is_active:
                    status_result = subprocess.run(
                        ["systemctl", "status", systemd_unit, "--no-pager"],
                        capture_output=True,
                        text=True,
                        timeout=5
                    )
                    # Extract PID from status if available
                    for line in status_result.stdout.split('\n'):
                        if 'Main PID:' in line:
                            pid = line.split('Main PID:')[1].strip().split()[0]
                            details['pid'] = pid
                            break

                results[service_name] = {
                    "running": is_active,
                    "systemd_unit": systemd_unit,
                    "details": details
                }

            except subprocess.TimeoutExpired:
                results[service_name] = {
                    "running": False,
                    "error": "Timeout checking service"
                }
            except FileNotFoundError:
                # systemctl not available (likely development environment)
                results[service_name] = {
                    "running": False,
                    "error": "systemd not available"
                }
            except Exception as e:
                results[service_name] = {
                    "running": False,
                    "error": str(e)
                }

        return {
            "services": results,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    # Try to load an existing model
    ml_analyzer.load_model()

    print("🚀 Starting IDS API on http://0.0.0.0:8000")
    print("📚 Docs available at http://0.0.0.0:8000/docs")

    uvicorn.run(app, host="0.0.0.0", port=8000)
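
For a quick end-to-end check of the endpoints above, the API can be exercised from a separate process. This is a usage sketch, not part of `main.py`: it assumes the backend is reachable at `http://localhost:8000`, that the `requests` package is installed, and that `IDS_API_KEY` is exported when calling the secured `/services/status` endpoint (the other endpoints do not require the key when none is configured).

```python
# Usage sketch: call the IDS API from another process (assumed host/port).
import os
import requests

BASE_URL = "http://localhost:8000"

# Start a background training run on the last 24 hours of logs
print(requests.post(f"{BASE_URL}/train", json={"max_records": 10000, "hours_back": 24}).json())

# Run detection; auto_block=False so anomalous IPs are only reported, not blocked
detect = requests.post(
    f"{BASE_URL}/detect",
    json={"max_records": 5000, "hours_back": 1, "risk_threshold": 60.0, "auto_block": False},
).json()
print(detect["message"])

# The secured endpoint expects the shared key in the X-API-Key header
headers = {"X-API-Key": os.getenv("IDS_API_KEY", "")}
print(requests.get(f"{BASE_URL}/services/status", headers=headers).json())
```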