"""
IDS Backend FastAPI - Intrusion Detection System

Handles ML training, real-time detection and communication with MikroTik
routers.
"""
from fastapi import FastAPI, HTTPException, BackgroundTasks, Security, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from typing import List, Optional, Dict
from datetime import datetime, timedelta
import pandas as pd
import psycopg2
from psycopg2.extras import RealDictCursor
import os
from dotenv import load_dotenv
import asyncio
import secrets

from ml_analyzer import MLAnalyzer
from ml_hybrid_detector import MLHybridDetector
from mikrotik_manager import MikroTikManager
from ip_geolocation import get_geo_service

# Load environment variables
load_dotenv()

# API Key Security
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)


def get_api_key():
    """Get the expected API key from the environment (None if unset)."""
    return os.getenv("IDS_API_KEY")


async def verify_api_key(api_key: str = Security(api_key_header)):
    """Verify the API key for internal service communication.

    Raises HTTP 403 when a key is configured and the request's key does not
    match. When no key is configured (development), access is allowed for
    backward compatibility.
    """
    expected_key = get_api_key()

    # In development without API key, allow access (backward compatibility)
    if not expected_key:
        return True

    if not api_key or api_key != expected_key:
        raise HTTPException(
            status_code=403,
            detail="Invalid or missing API key"
        )
    return True


app = FastAPI(title="IDS API", version="2.0.0")

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances - try the hybrid detector first, fall back to legacy.
# Exactly one of ml_detector / ml_analyzer is non-None.
USE_HYBRID_DETECTOR = os.getenv("USE_HYBRID_DETECTOR", "true").lower() == "true"

if USE_HYBRID_DETECTOR:
    print("[ML] Using Hybrid ML Detector (Extended Isolation Forest + Feature Selection)")
    ml_detector = MLHybridDetector(model_dir="models")
    # Try to load an existing model from disk
    if not ml_detector.load_models():
        print("[ML] No hybrid model found, will use on-demand training")
    ml_analyzer = None  # Legacy disabled
else:
    print("[ML] Using Legacy ML Analyzer (standard Isolation Forest)")
    ml_analyzer = MLAnalyzer(model_dir="models")
    ml_detector = None

mikrotik_manager = MikroTikManager()


def _legacy_model_loaded() -> bool:
    """Return True when the legacy analyzer exists and has a model in memory.

    FIX: the original code dereferenced ``ml_analyzer.model`` unconditionally,
    which raised AttributeError whenever the hybrid detector was active
    (``ml_analyzer`` is None in that mode — the default).
    """
    return ml_analyzer is not None and ml_analyzer.model is not None


def _require_legacy_analyzer() -> None:
    """Raise HTTP 503 when the legacy analyzer is disabled (hybrid mode).

    The /train and /detect endpoints below are implemented against the legacy
    MLAnalyzer only; failing fast with a clear error replaces the original
    behavior of crashing with AttributeError.
    """
    if ml_analyzer is None:
        raise HTTPException(
            status_code=503,
            detail="Legacy ML analyzer disabled (hybrid detector active)"
        )


# Database connection
def get_db_connection():
    """Open a new PostgreSQL connection using PG* environment variables."""
    return psycopg2.connect(
        host=os.getenv("PGHOST"),
        port=os.getenv("PGPORT"),
        database=os.getenv("PGDATABASE"),
        user=os.getenv("PGUSER"),
        password=os.getenv("PGPASSWORD")
    )


# Pydantic models
class TrainRequest(BaseModel):
    """Parameters for /train: window and volume of logs, IF contamination."""
    max_records: int = 10000
    hours_back: int = 24
    contamination: float = 0.01


class DetectRequest(BaseModel):
    """Parameters for /detect: window, volume, risk threshold, auto-block."""
    max_records: int = 5000
    hours_back: int = 1
    risk_threshold: float = 60.0
    auto_block: bool = False


class BlockIPRequest(BaseModel):
    """Manual block request: target IP, address-list name, comment, timeout."""
    ip_address: str
    list_name: str = "ddos_blocked"
    comment: Optional[str] = None
    timeout_duration: str = "1h"


class UnblockIPRequest(BaseModel):
    """Manual unblock request: target IP and address-list name."""
    ip_address: str
    list_name: str = "ddos_blocked"


# API Endpoints
@app.get("/")
async def root():
    """Service banner with model-load status."""
    return {
        "service": "IDS API",
        # FIX: was hard-coded "1.0.0", out of sync with the app's declared
        # version ("2.0.0" above).
        "version": "2.0.0",
        "status": "running",
        # FIX: safe in hybrid mode (no AttributeError on ml_analyzer=None)
        "model_loaded": _legacy_model_loaded()
    }


@app.get("/health")
async def health_check():
    """Check system health: database reachability and ML model status."""
    try:
        conn = get_db_connection()
        conn.close()
        db_status = "connected"
    except Exception as e:
        db_status = f"error: {str(e)}"

    return {
        "status": "healthy",
        "database": db_status,
        # FIX: safe in hybrid mode (no AttributeError on ml_analyzer=None)
        "ml_model": "loaded" if _legacy_model_loaded() else "not_loaded",
        "timestamp": datetime.now().isoformat()
    }


@app.post("/train")
async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Train the ML model on recent logs.
    Runs in the background so the API is not blocked.

    Raises HTTP 503 when the legacy analyzer is disabled (hybrid mode) —
    previously the background task crashed silently with AttributeError.
    """
    _require_legacy_analyzer()

    def do_training():
        # Background worker: fetch recent logs, train, persist history.
        conn = None
        cursor = None
        try:
            print("[TRAIN] Inizio training...")
            conn = get_db_connection()
            cursor = conn.cursor(cursor_factory=RealDictCursor)

            # Fetch recent logs
            min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
            query = """
                SELECT * FROM network_logs
                WHERE timestamp >= %s
                ORDER BY timestamp DESC
                LIMIT %s
            """
            cursor.execute(query, (min_timestamp, request.max_records))
            logs = cursor.fetchall()

            if not logs:
                print("[TRAIN] Nessun log trovato per training")
                return

            print(f"[TRAIN] Trovati {len(logs)} log per training")

            # Convert to DataFrame
            df = pd.DataFrame(logs)

            # Training
            print("[TRAIN] Addestramento modello...")
            result = ml_analyzer.train(df, contamination=request.contamination)
            print(f"[TRAIN] Modello addestrato: {result}")

            # Persist training history
            print("[TRAIN] Salvataggio training history nel database...")
            cursor.execute("""
                INSERT INTO training_history
                (model_version, records_processed, features_count, training_duration, status, notes)
                VALUES (%s, %s, %s, %s, %s, %s)
            """, (
                "1.0.0",
                result['records_processed'],
                result['features_count'],
                0,  # duration not implemented yet
                result['status'],
                f"Anomalie: {result['anomalies_detected']}/{result['unique_ips']}"
            ))
            conn.commit()
            print("[TRAIN] โœ… Training history salvato con successo!")

            print(f"[TRAIN] Completato: {result}")
        except Exception as e:
            print(f"[TRAIN ERROR] โŒ Errore durante training: {e}")
            import traceback
            traceback.print_exc()
            if conn:
                conn.rollback()
        finally:
            # Always release DB resources, success or failure
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    # Run in the background
    background_tasks.add_task(do_training)

    return {
        "message": "Training avviato in background",
        "max_records": request.max_records,
        "hours_back": request.hours_back
    }


@app.post("/detect")
async def detect_anomalies(request: DetectRequest):
    """
    Detect anomalies in recent logs.
    Optionally auto-blocks anomalous IPs (risk >= 80, not whitelisted).
    """
    _require_legacy_analyzer()

    if ml_analyzer.model is None:
        # Try to load a previously saved model
        if not ml_analyzer.load_model():
            raise HTTPException(
                status_code=400,
                detail="Modello non addestrato. Esegui /train prima."
            )

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch recent logs
        min_timestamp = datetime.now() - timedelta(hours=request.hours_back)
        query = """
            SELECT * FROM network_logs
            WHERE timestamp >= %s
            ORDER BY timestamp DESC
            LIMIT %s
        """
        cursor.execute(query, (min_timestamp, request.max_records))
        logs = cursor.fetchall()

        if not logs:
            return {"detections": [], "message": "Nessun log da analizzare"}

        # Convert to DataFrame
        df = pd.DataFrame(logs)

        # Detection
        detections = ml_analyzer.detect(df, risk_threshold=request.risk_threshold)

        # Geolocation lookup service - batch async for performance
        geo_service = get_geo_service()

        # Unique IP list for the batch lookup
        unique_ips = list(set(det['source_ip'] for det in detections))

        # Async batch lookup (fast - all in parallel)
        geo_results = await geo_service.lookup_batch_async(unique_ips)

        # Persist detections to the database
        for det in detections:
            # Geo info from the batch results (may be None)
            geo_info = geo_results.get(det['source_ip'])

            # Check whether a detection for this IP already exists
            cursor.execute(
                "SELECT id FROM detections WHERE source_ip = %s ORDER BY detected_at DESC LIMIT 1",
                (det['source_ip'],)
            )
            existing = cursor.fetchone()

            if existing:
                # Update the existing row (with geo info)
                cursor.execute("""
                    UPDATE detections
                    SET risk_score = %s, confidence = %s, anomaly_type = %s,
                        reason = %s, log_count = %s, last_seen = %s,
                        country = %s, country_code = %s, city = %s,
                        organization = %s, as_number = %s, as_name = %s, isp = %s
                    WHERE id = %s
                """, (
                    det['risk_score'],
                    det['confidence'],
                    det['anomaly_type'],
                    det['reason'],
                    det['log_count'],
                    det['last_seen'],
                    geo_info.get('country') if geo_info else None,
                    geo_info.get('country_code') if geo_info else None,
                    geo_info.get('city') if geo_info else None,
                    geo_info.get('organization') if geo_info else None,
                    geo_info.get('as_number') if geo_info else None,
                    geo_info.get('as_name') if geo_info else None,
                    geo_info.get('isp') if geo_info else None,
                    existing['id']
                ))
            else:
                # Insert a new row (with geo info)
                cursor.execute("""
                    INSERT INTO detections
                    (source_ip, risk_score, confidence, anomaly_type, reason,
                     log_count, first_seen, last_seen,
                     country, country_code, city, organization, as_number, as_name, isp)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """, (
                    det['source_ip'],
                    det['risk_score'],
                    det['confidence'],
                    det['anomaly_type'],
                    det['reason'],
                    det['log_count'],
                    det['first_seen'],
                    det['last_seen'],
                    geo_info.get('country') if geo_info else None,
                    geo_info.get('country_code') if geo_info else None,
                    geo_info.get('city') if geo_info else None,
                    geo_info.get('organization') if geo_info else None,
                    geo_info.get('as_number') if geo_info else None,
                    geo_info.get('as_name') if geo_info else None,
                    geo_info.get('isp') if geo_info else None
                ))

        conn.commit()

        # Auto-block if requested
        blocked_count = 0
        if request.auto_block and detections:
            # Fetch enabled routers
            cursor.execute("SELECT * FROM routers WHERE enabled = true")
            routers = cursor.fetchall()

            if routers:
                for det in detections:
                    if det['risk_score'] >= 80:  # CRITICAL risk only
                        # Check whitelist
                        cursor.execute(
                            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
                            (det['source_ip'],)
                        )
                        if cursor.fetchone():
                            continue  # Skip whitelisted

                        # Block on all routers
                        results = await mikrotik_manager.block_ip_on_all_routers(
                            routers,
                            det['source_ip'],
                            comment=f"IDS: {det['anomaly_type']}"
                        )

                        if any(results.values()):
                            # Mark the detection as blocked
                            cursor.execute("""
                                UPDATE detections
                                SET blocked = true, blocked_at = NOW()
                                WHERE source_ip = %s
                            """, (det['source_ip'],))
                            blocked_count += 1

                conn.commit()

        return {
            "detections": detections,
            "total": len(detections),
            "blocked": blocked_count if request.auto_block else 0,
            "message": f"Trovate {len(detections)} anomalie"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: the original leaked the connection on the early return and on
        # every exception path; always close here instead.
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@app.post("/block-ip")
async def block_ip(request: BlockIPRequest):
    """Manually block an IP on all enabled routers (whitelist-aware)."""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Check whitelist
        cursor.execute(
            "SELECT id FROM whitelist WHERE ip_address = %s AND active = true",
            (request.ip_address,)
        )
        if cursor.fetchone():
            raise HTTPException(
                status_code=400,
                detail=f"IP {request.ip_address} รจ in whitelist"
            )

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")

        # Block on all routers
        results = await mikrotik_manager.block_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name,
            comment=request.comment or "Manual block",
            timeout_duration=request.timeout_duration
        )

        success_count = sum(1 for v in results.values() if v)

        return {
            "ip_address": request.ip_address,
            "blocked_on": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: the original leaked the connection when the whitelist/no-router
        # HTTPException was raised before close().
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@app.post("/unblock-ip")
async def unblock_ip(request: UnblockIPRequest):
    """Unblock an IP from all enabled routers and mark it unblocked in DB."""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Fetch routers
        cursor.execute("SELECT * FROM routers WHERE enabled = true")
        routers = cursor.fetchall()

        if not routers:
            raise HTTPException(status_code=400, detail="Nessun router configurato")

        # Unblock from all routers
        results = await mikrotik_manager.unblock_ip_on_all_routers(
            routers,
            request.ip_address,
            list_name=request.list_name
        )

        success_count = sum(1 for v in results.values() if v)

        # Update database
        cursor.execute("""
            UPDATE detections
            SET blocked = false
            WHERE source_ip = %s
        """, (request.ip_address,))
        conn.commit()

        return {
            "ip_address": request.ip_address,
            "unblocked_from": success_count,
            "total_routers": len(routers),
            "results": results
        }

    except HTTPException:
        # FIX: the original swallowed its own 400 ("Nessun router configurato")
        # and re-raised it as a generic 500; let HTTPException pass through,
        # consistent with /block-ip.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: always release the connection (was leaked on exceptions)
        if cursor:
            cursor.close()
        if conn:
            conn.close()


@app.get("/stats")
async def get_stats():
    """System statistics: log volume, detections, routers, latest training."""
    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor(cursor_factory=RealDictCursor)

        # Log stats
        cursor.execute("SELECT COUNT(*) as total FROM network_logs")
        result = cursor.fetchone()
        total_logs = result['total'] if result else 0

        cursor.execute("""
            SELECT COUNT(*) as recent FROM network_logs
            WHERE timestamp >= NOW() - INTERVAL '1 hour'
        """)
        result = cursor.fetchone()
        recent_logs = result['recent'] if result else 0

        # Detection stats
        cursor.execute("SELECT COUNT(*) as total FROM detections")
        result = cursor.fetchone()
        total_detections = result['total'] if result else 0

        cursor.execute("""
            SELECT COUNT(*) as blocked FROM detections WHERE blocked = true
        """)
        result = cursor.fetchone()
        blocked_ips = result['blocked'] if result else 0

        # Router stats
        cursor.execute("SELECT COUNT(*) as total FROM routers WHERE enabled = true")
        result = cursor.fetchone()
        active_routers = result['total'] if result else 0

        # Latest training
        cursor.execute("""
            SELECT * FROM training_history
            ORDER BY trained_at DESC
            LIMIT 1
        """)
        latest_training = cursor.fetchone()

        return {
            "logs": {
                "total": total_logs,
                "last_hour": recent_logs
            },
            "detections": {
                "total": total_detections,
                "blocked": blocked_ips
            },
            "routers": {
                "active": active_routers
            },
            "latest_training": dict(latest_training) if latest_training else None
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # FIX: always release the connection (was leaked on exceptions)
        if cursor:
            cursor.close()
        if conn:
            conn.close()


# Service Management Endpoints (Secured with API Key)
@app.get("/services/status")
async def get_services_status(authorized: bool = Security(verify_api_key)):
    """
    Check the status of all systemd-managed services.
    REQUIRES: a valid API key for access.
    """
    try:
        import subprocess

        services_to_check = {
            "ml_backend": "ids-ml-backend",
            "syslog_parser": "ids-syslog-parser"
        }

        results = {}
        for service_name, systemd_unit in services_to_check.items():
            try:
                # Check systemd service status (list args, shell=False: safe)
                result = subprocess.run(
                    ["systemctl", "is-active", systemd_unit],
                    capture_output=True,
                    text=True,
                    timeout=5
                )
                is_active = result.stdout.strip() == "active"

                # Get more details if active
                details = {}
                if is_active:
                    status_result = subprocess.run(
                        ["systemctl", "status", systemd_unit, "--no-pager"],
                        capture_output=True,
                        text=True,
                        timeout=5
                    )
                    # Extract PID from status if available
                    for line in status_result.stdout.split('\n'):
                        if 'Main PID:' in line:
                            pid = line.split('Main PID:')[1].strip().split()[0]
                            details['pid'] = pid
                            break

                results[service_name] = {
                    "running": is_active,
                    "systemd_unit": systemd_unit,
                    "details": details
                }

            except subprocess.TimeoutExpired:
                results[service_name] = {
                    "running": False,
                    "error": "Timeout checking service"
                }
            except FileNotFoundError:
                # systemctl not available (likely development environment)
                results[service_name] = {
                    "running": False,
                    "error": "systemd not available"
                }
            except Exception as e:
                results[service_name] = {
                    "running": False,
                    "error": str(e)
                }

        return {
            "services": results,
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    # Try to load an existing legacy model.
    # FIX: guarded — ml_analyzer is None in hybrid mode and the original
    # crashed at startup with AttributeError.
    if ml_analyzer is not None:
        ml_analyzer.load_model()

    print("๐Ÿš€ Starting IDS API on http://0.0.0.0:8000")
    print("๐Ÿ“š Docs available at http://0.0.0.0:8000/docs")

    uvicorn.run(app, host="0.0.0.0", port=8000)