"""
IDS Backend FastAPI - Intrusion Detection System

Handles ML model training, real-time anomaly detection and communication
with MikroTik routers (block/unblock of anomalous IPs).
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, Security, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, Dict
from datetime import datetime, timedelta
import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import asyncio
import secrets

from ml_analyzer import MLAnalyzer
from mikrotik_manager import MikroTikManager

# Load environment variables
load_dotenv()

# API Key Security
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


def get_api_key():
    """Return the expected API key from the environment (None if unset)."""
    return os.getenv("IDS_API_KEY")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Verify the API key for internal service communication.

    Raises HTTPException(403) when a key is configured and the caller's
    key is missing or wrong. When no key is configured, access is allowed
    (development backward compatibility).
    """
    expected_key = get_api_key()

    # In development without API key, allow access (backward compatibility)
    if not expected_key:
        return True

    # compare_digest gives a constant-time comparison, avoiding timing attacks
    if not api_key or not secrets.compare_digest(api_key, expected_key):
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API key"
        )
    return True


app = FastAPI(title="IDS API", version="1.0.0")

# CORS
# NOTE(review): wildcard origins together with allow_credentials=True is
# rejected by browsers per the CORS spec; restrict origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances
ml_analyzer = MLAnalyzer(model_dir="models")
mikrotik_manager = MikroTikManager()


# Database connection
def get_db_connection():
    """Open a new PostgreSQL connection using PG* environment variables."""
    return psycopg2.connect(
        host=os.getenv("PGHOST"),
        port=os.getenv("PGPORT"),
        database=os.getenv("PGDATABASE"),
        user=os.getenv("PGUSER"),
        password=os.getenv("PGPASSWORD")
    )


# Pydantic models
class TrainRequest(BaseModel):
    max_records: int = 10000
    hours_back: int = 24
    contamination: float = 0.01


class DetectRequest(BaseModel):
    max_records: int = 5000
    hours_back: int = 1
    risk_threshold: float = 60.0
    auto_block: bool = False


class BlockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"
    comment: Optional[str] = None
    timeout_duration: str = "1h"


class UnblockIPRequest(BaseModel):
    ip_address: str
    list_name: str = "ddos_blocked"


# API Endpoints
@app.get("/")
async def root():
    """Service banner with model-load status."""
    return {
        "service": "IDS API",
        "version": "1.0.0",
        "status": "running",
        "model_loaded": ml_analyzer.model is not None
    }


@app.get("/health")
async def health_check():
    """Check system health (database reachability and ML model state)."""
    try:
        conn = get_db_connection()
        conn.close()
        db_status = "connected"
    except Exception as e:
        db_status = f"error: {str(e)}"

    return {
        "status": "healthy",
        "database": db_status,
        "ml_model": "loaded" if ml_analyzer.model is not None else "not_loaded",
        "timestamp": datetime.now().isoformat()
    }


@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Train the ML model on recent logs.
    Runs in background so the API is not blocked.
    """
    def do_training():
        conn = None
        cursor = None
        try:
            print("[TRAIN] Inizio training...")
            conn = get_db_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # Fetch recent logs
            min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
            query = """
                SELECT * FROM network_logs
                WHERE timestamp >= %s
                ORDER BY timestamp DESC
                LIMIT %s
            """
            cursor.execute(query, (min_timestamp, request.max_records))
            logs = cursor.fetchall()

            if not logs:
                print("[TRAIN] Nessun log trovato per training")
                return

            print(f"[TRAIN] Trovati {len(logs)} log per training")

            # Convert to DataFrame
            df = pd.DataFrame(logs)

            # Training
            print("[TRAIN] Addestramento modello...")
            result = ml_analyzer.train(df, contamination=request.contamination)
            print(f"[TRAIN] Modello addestrato: {result}")

            # Persist training history
            print("[TRAIN] Salvataggio training history nel database...")
            cursor.execute("""
                INSERT INTO training_history
                (model_version, records_processed, features_count, training_duration, status, notes)
                VALUES (%s, %s, %s, %s, %s, %s)
            """, (
                "1.0.0",
                result['records_processed'],
                result['features_count'],
                0,  # duration not implemented yet
                result['status'],
                f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
            ))
            conn.commit()
            print("[TRAIN] ✅ Training history salvato con successo!")
            print(f"[TRAIN] Completato: {result}")

        except Exception as e:
            print(f"[TRAIN ERROR] ❌ Errore durante training: {e}")
            import traceback
            traceback.print_exc()
            if conn:
                conn.rollback()
        finally:
            # Cleanup runs on both success and failure paths
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    # Run in background
    background_tasks.add_task(do_training)

    return {
        "message": "Training avviato in background",
        "max_records": request.max_records,
        "hours_back": request.hours_back
    }


@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
    """
    Detect anomalies in recent logs.
    Optionally auto-blocks anomalous IPs on all enabled routers.
    """
    if ml_analyzer.model is None:
        # Try to load a previously saved model
        if not ml_analyzer.load_model():
            raise HTTPException(
                status_code=400,
                detail="Modello non addestrato. Esegui /train prima."
            )

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch recent logs
        min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
        query = """
            SELECT * FROM network_logs
            WHERE timestamp >= %s
            ORDER BY timestamp DESC
            LIMIT %s
        """
        cursor.execute(query, (min_timestamp, request.max_records))
        logs = cursor.fetchall()

        if not logs:
            return {"detections": [], "message": "Nessun log da analizzare"}

        # Convert to DataFrame
        df = pd.DataFrame(logs)

        # Detection
        detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)

        # Persist detections (upsert keyed on the latest row per source IP)
        for det in detections:
            cursor.execute(
                "SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
                (det['source_ip'],)
            )
            existing = cursor.fetchone()

            if existing:
                # Update existing detection
                cursor.execute("""
                    UPDATE detections
                    SET risk_score = %s, confidence = %s, anomaly_type = %s,
                        reason = %s, log_count = %s, last_seen = %s
                    WHERE id = %s
                """, (
                    det['risk_score'], det['confidence'], det['anomaly_type'],
                    det['reason'], det['log_count'], det['last_seen'],
                    existing['id']
                ))
            else:
                # Insert new detection
                cursor.execute("""
                    INSERT INTO detections
                    (source_ip, risk_score, confidence, anomaly_type, reason,
                     log_count, first_seen, last_seen)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    det['source_ip'], det['risk_score'], det['confidence'],
                    det['anomaly_type'], det['reason'], det['log_count'],
                    det['first_seen'], det['last_seen']
                ))

        conn.commit()

        # Auto-block if requested
        blocked_count = 0
        if request.auto_block and detections:
            # Fetch enabled routers
            cursor.execute("SELECT * FROM routers WHERE enabled = true")
            routers = cursor.fetchall()

            if routers:
                for det in detections:
                    if det['risk_score'] >= 80:  # CRITICAL risk only
                        # Skip whitelisted IPs
                        cursor.execute(
                            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
                            (det['source_ip'],)
                        )
                        if cursor.fetchone():
                            continue

                        # Block on every router
                        results = await mikrotik_manager.block_ip_on_all_routers(
                            routers,
                            det['source_ip'],
                            comment=f"IDS: {det['anomaly_type']}"
                        )

                        if any(results.values()):
                            # Mark the detection as blocked
                            cursor.execute("""
                                UPDATE detections
                                SET blocked = true, blocked_at = NOW()
                                WHERE source_ip = %s
                            """, (det['source_ip'],))
                            blocked_count += 1

                conn.commit()

        return {
            "detections": detections,
            "total": len(detections),
            "blocked": blocked_count if request.auto_block else 0,
            "message": f"Trovate {len(detections)} anomalie"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always release DB resources, including the early-return and error paths
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
    """Manually block an IP on all enabled routers (whitelist-aware)."""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Refuse to block whitelisted IPs
        cursor.execute(
            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
            (request.ip_address,)
        )
        if cursor.fetchone():
            raise HTTPException(
                status_code=400,
                detail=f"IP {request.ip_address} è in whitelist"
            )

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")

        # Block on every router
        results = await mikrotik_manager.block_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name,
            comment=request.comment or "Manual block",
            timeout_duration=request.timeout_duration
        )

        success_count = sum(1 for v in results.values() if v)

        return {
            "ip_address": request.ip_address,
            "blocked_on": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        # Preserve intentional 4xx responses
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
    """Unblock an IP on all enabled routers and clear its blocked flag."""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")

        # Unblock on every router
        results = await mikrotik_manager.unblock_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name
        )

        success_count = sum(1 for v in results.values() if v)

        # Update database
        cursor.execute("""
            UPDATE detections
            SET blocked = false
            WHERE source_ip = %s
        """, (request.ip_address,))
        conn.commit()

        return {
            "ip_address": request.ip_address,
            "unblocked_from": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        # Without this, the 400 above would be re-wrapped as a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@app.get("/stats")
async def get_stats():
    """Aggregate system statistics: logs, detections, routers, last training."""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Log stats
        cursor.execute("SELECT COUNT(*) as total FROM network_logs")
        result = cursor.fetchone()
        total_logs = result['total'] if result else 0

        cursor.execute("""
            SELECT COUNT(*) as recent FROM network_logs
            WHERE timestamp >= NOW() - INTERVAL '1 hour'
        """)
        result = cursor.fetchone()
        recent_logs = result['recent'] if result else 0

        # Detection stats
        cursor.execute("SELECT COUNT(*) as total FROM detections")
        result = cursor.fetchone()
        total_detections = result['total'] if result else 0

        cursor.execute("""
            SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
        """)
        result = cursor.fetchone()
        blocked_ips = result['blocked'] if result else 0

        # Router stats
        cursor.execute("SELECT COUNT(*) as total FROM routers WHERE enabled = true")
        result = cursor.fetchone()
        active_routers = result['total'] if result else 0

        # Latest training
        cursor.execute("""
            SELECT * FROM training_history
            ORDER BY trained_at DESC
            LIMIT 1
        """)
        latest_training = cursor.fetchone()

        return {
            "logs": {
                "total": total_logs,
                "last_hour": recent_logs
            },
            "detections": {
                "total": total_detections,
                "blocked": blocked_ips
            },
            "routers": {
                "active": active_routers
            },
            "latest_training": dict(latest_training) if latest_training else None
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# Service Management Endpoints (Secured with API Key)
@app.get("/services/status")
async def get_services_status(authorized: bool = Security(verify_api_key)):
    """
    Check the status of all services managed by systemd.
    REQUIRES: valid API key for access.
    """
    try:
        import subprocess

        services_to_check = {
            "ml_backend": "ids-ml-backend",
            "syslog_parser": "ids-syslog-parser"
        }

        results = {}

        for service_name, systemd_unit in services_to_check.items():
            try:
                # Check systemd service status (list form: no shell injection)
                result = subprocess.run(
                    ["systemctl", "is-active", systemd_unit],
                    capture_output=True,
                    text=True,
                    timeout=5
                )

                is_active = result.stdout.strip() == "active"

                # Get more details if active
                details = {}
                if is_active:
                    status_result = subprocess.run(
                        ["systemctl", "status", systemd_unit, "--no-pager"],
                        capture_output=True,
                        text=True,
                        timeout=5
                    )
                    # Extract PID from status if available
                    for line in status_result.stdout.split('\n'):
                        if 'Main PID:' in line:
                            pid = line.split('Main PID:')[1].strip().split()[0]
                            details['pid'] = pid
                            break

                results[service_name] = {
                    "running": is_active,
                    "systemd_unit": systemd_unit,
                    "details": details
                }

            except subprocess.TimeoutExpired:
                results[service_name] = {
                    "running": False,
                    "error": "Timeout checking service"
                }
            except FileNotFoundError:
                # systemctl not available (likely development environment)
                results[service_name] = {
                    "running": False,
                    "error": "systemd not available"
                }
            except Exception as e:
                results[service_name] = {
                    "running": False,
                    "error": str(e)
                }

        return {
            "services": results,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    # Try to load an existing model before serving
    ml_analyzer.load_model()

    print("🚀 Starting IDS API on http://0.0.0.0:8000")
    print("📚 Docs available at http://0.0.0.0:8000/docs")
    uvicorn.run(app, host="0.0.0.0", port=8000)