#!/usr/bin/env python3 """ Merge Logic for Public Lists Integration Implements priority: Manual Whitelist > Public Whitelist > Public Blacklist """ import os import psycopg2 from typing import Dict, Set, Optional from datetime import datetime import logging import ipaddress logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def ip_matches_cidr(ip_address: str, cidr_range: Optional[str]) -> bool: """ Check if IP address matches CIDR range Returns True if cidr_range is None (exact match) or if IP is in range """ if not cidr_range: return True # Exact match handling try: ip = ipaddress.ip_address(ip_address) network = ipaddress.ip_network(cidr_range, strict=False) return ip in network except (ValueError, TypeError): logger.warning(f"Invalid IP/CIDR: {ip_address}/{cidr_range}") return False class MergeLogic: """ Handles merge logic between manual entries and public lists Priority: Manual whitelist > Public whitelist > Public blacklist """ def __init__(self, database_url: str): self.database_url = database_url def get_db_connection(self): """Create database connection""" return psycopg2.connect(self.database_url) def get_all_whitelisted_ips(self) -> Set[str]: """ Get all whitelisted IPs (manual + public) Manual whitelist has higher priority than public whitelist """ conn = self.get_db_connection() try: with conn.cursor() as cur: cur.execute(""" SELECT DISTINCT ip_address FROM whitelist WHERE active = true """) return {row[0] for row in cur.fetchall()} finally: conn.close() def get_public_blacklist_ips(self) -> Set[str]: """Get all active public blacklist IPs""" conn = self.get_db_connection() try: with conn.cursor() as cur: cur.execute(""" SELECT DISTINCT ip_address FROM public_blacklist_ips WHERE is_active = true """) return {row[0] for row in cur.fetchall()} finally: conn.close() def should_block_ip(self, ip_address: str) -> tuple[bool, str]: """ Determine if IP should be blocked based on merge logic Returns: (should_block, reason) Priority: 1. Manual whitelist (exact or CIDR) → DON'T block (highest priority) 2. Public whitelist (exact or CIDR) → DON'T block 3. Public blacklist (exact or CIDR) → DO block 4. Not in any list → DON'T block (only ML decides) """ conn = self.get_db_connection() try: with conn.cursor() as cur: # Check manual whitelist (highest priority) - exact + CIDR matching cur.execute(""" SELECT ip_address, list_id FROM whitelist WHERE active = true AND source = 'manual' """) for row in cur.fetchall(): wl_ip, wl_cidr = row[0], None # Check if whitelist entry has CIDR notation if '/' in wl_ip: wl_cidr = wl_ip if wl_ip == ip_address or ip_matches_cidr(ip_address, wl_cidr): return (False, "manual_whitelist") # Check public whitelist (any source except 'manual') - exact + CIDR cur.execute(""" SELECT ip_address, list_id FROM whitelist WHERE active = true AND source != 'manual' """) for row in cur.fetchall(): wl_ip, wl_cidr = row[0], None if '/' in wl_ip: wl_cidr = wl_ip if wl_ip == ip_address or ip_matches_cidr(ip_address, wl_cidr): return (False, "public_whitelist") # Check public blacklist - exact + CIDR matching cur.execute(""" SELECT id, ip_address, cidr_range FROM public_blacklist_ips WHERE is_active = true """) for row in cur.fetchall(): bl_id, bl_ip, bl_cidr = row # Match exact IP or check if IP is in CIDR range if bl_ip == ip_address or ip_matches_cidr(ip_address, bl_cidr): return (True, f"public_blacklist:{bl_id}") # Not in any list return (False, "not_listed") finally: conn.close() def create_detection_from_blacklist( self, ip_address: str, blacklist_id: str, risk_score: int = 75 ) -> Optional[str]: """ Create detection record for public blacklist IP Only if not whitelisted (priority check) """ should_block, reason = self.should_block_ip(ip_address) if not should_block: logger.info(f"IP {ip_address} not blocked - reason: {reason}") return None conn = self.get_db_connection() try: with conn.cursor() as cur: # Check if detection already exists cur.execute(""" SELECT id FROM detections WHERE source_ip = %s AND detection_source = 'public_blacklist' LIMIT 1 """, (ip_address,)) existing = cur.fetchone() if existing: logger.info(f"Detection already exists for {ip_address}") return existing[0] # Create new detection cur.execute(""" INSERT INTO detections ( source_ip, risk_score, anomaly_type, detection_source, blacklist_id, detected_at, blocked ) VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING id """, ( ip_address, str(risk_score), 'public_blacklist', 'public_blacklist', blacklist_id, datetime.utcnow(), False # Will be blocked by auto-block service if risk_score >= 80 )) result = cur.fetchone() if not result: logger.error(f"Failed to get detection ID after insert for {ip_address}") return None detection_id = result[0] conn.commit() logger.info(f"Created detection {detection_id} for blacklisted IP {ip_address}") return detection_id except Exception as e: conn.rollback() logger.error(f"Failed to create detection for {ip_address}: {e}") return None finally: conn.close() def cleanup_invalid_detections(self) -> int: """ Remove detections for IPs that are now whitelisted Respects priority: manual/public whitelist overrides blacklist """ conn = self.get_db_connection() try: with conn.cursor() as cur: # Delete detections for whitelisted IPs cur.execute(""" DELETE FROM detections WHERE detection_source = 'public_blacklist' AND source_ip IN ( SELECT ip_address FROM whitelist WHERE active = true ) """) deleted = cur.rowcount conn.commit() if deleted > 0: logger.info(f"Cleaned up {deleted} detections for whitelisted IPs") return deleted except Exception as e: conn.rollback() logger.error(f"Failed to cleanup detections: {e}") return 0 finally: conn.close() def sync_public_blacklist_detections(self) -> Dict[str, int]: """ Sync detections with current public blacklist state using BULK operations Creates detections for blacklisted IPs (if not whitelisted) Removes detections for IPs no longer blacklisted or now whitelisted """ stats = { 'created': 0, 'cleaned': 0, 'skipped_whitelisted': 0 } conn = self.get_db_connection() try: with conn.cursor() as cur: # Cleanup whitelisted IPs first (priority) stats['cleaned'] = self.cleanup_invalid_detections() # Bulk create detections for blacklisted IPs (excluding whitelisted) # MVP: Exact IP matching (CIDR expansion in future iteration) # Note: CIDR ranges stored but not yet matched - requires schema optimization cur.execute(""" INSERT INTO detections ( source_ip, risk_score, anomaly_type, detection_source, blacklist_id, detected_at, blocked ) SELECT DISTINCT bl.ip_address, '75', 'public_blacklist', 'public_blacklist', bl.id, NOW(), false FROM public_blacklist_ips bl WHERE bl.is_active = true -- Priority: Manual whitelist > Public whitelist > Blacklist AND NOT EXISTS ( SELECT 1 FROM whitelist wl WHERE wl.ip_address = bl.ip_address AND wl.active = true ) -- Avoid duplicate detections AND NOT EXISTS ( SELECT 1 FROM detections d WHERE d.source_ip = bl.ip_address AND d.detection_source = 'public_blacklist' ) RETURNING id """) created_ids = cur.fetchall() stats['created'] = len(created_ids) conn.commit() logger.info(f"Bulk sync complete: {stats}") return stats except Exception as e: conn.rollback() logger.error(f"Failed to sync detections: {e}") import traceback traceback.print_exc() return stats finally: conn.close() def main(): """Run merge logic sync""" database_url = os.environ.get('DATABASE_URL') if not database_url: logger.error("DATABASE_URL environment variable not set") return 1 merge = MergeLogic(database_url) stats = merge.sync_public_blacklist_detections() print(f"\n{'='*60}") print("MERGE LOGIC SYNC COMPLETED") print(f"{'='*60}") print(f"Created detections: {stats['created']}") print(f"Cleaned invalid detections: {stats['cleaned']}") print(f"Skipped (whitelisted): {stats['skipped_whitelisted']}") print(f"{'='*60}\n") return 0 if __name__ == "__main__": exit(main())