Updates IP address handling to include CIDR notation for more comprehensive network range matching, enhances database schema with INET/CIDR types, and refactors logic for accurate IP detection and whitelisting. Replit-Commit-Author: Agent Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528 Replit-Commit-Checkpoint-Type: intermediate_checkpoint Replit-Commit-Event-Id: 49a5a4b7-82b5-4dd4-84c1-9f0e855bea8a Replit-Commit-Screenshot-Url: https://storage.googleapis.com/screenshot-production-us-central1/449cf7c4-c97a-45ae-8234-e5c5b8d6a84f/7a657272-55ba-4a79-9a2e-f1ed9bc7a528/qHCi0Qg
355 lines
13 KiB
Python
Executable File
355 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Merge Logic for Public Lists Integration
|
|
Implements priority: Manual Whitelist > Public Whitelist > Public Blacklist
|
|
"""
|
|
import os
|
|
import psycopg2
|
|
from typing import Dict, Set, Optional
|
|
from datetime import datetime
|
|
import logging
|
|
import ipaddress
|
|
|
|
# Configure the root logger so INFO-and-above records are emitted when this
# module is imported or run as a script.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def ip_matches_cidr(ip_address: str, cidr_range: Optional[str]) -> bool:
    """
    Report whether ip_address falls inside cidr_range.

    A missing or empty cidr_range is treated as "no range restriction" and
    yields True, so callers that hold an optional range and do not want a
    blanket match must check for its presence themselves.

    The range is parsed with strict=False, so host bits set in the network
    part (e.g. "10.1.2.3/24") are accepted. Unparseable input is logged as a
    warning and reported as a non-match.
    """
    if cidr_range:
        try:
            candidate_net = ipaddress.ip_network(cidr_range, strict=False)
            candidate_ip = ipaddress.ip_address(ip_address)
            return candidate_ip in candidate_net
        except (ValueError, TypeError):
            logger.warning(f"Invalid IP/CIDR: {ip_address}/{cidr_range}")
            return False

    # No range supplied: the caller is expected to have done the exact match.
    return True
|
|
|
|
|
|
class MergeLogic:
    """
    Merge logic between manual entries and public whitelist/blacklist feeds.

    Priority (highest first):
    manual whitelist > public whitelist > public blacklist.

    Every public method opens its own connection through get_db_connection()
    and closes it before returning.
    """

    def __init__(self, database_url: str):
        """Remember the connection string handed to psycopg2.connect()."""
        self.database_url = database_url

    def get_db_connection(self):
        """Open and return a new database connection; the caller closes it."""
        return psycopg2.connect(self.database_url)

    def _fetch_ip_set(self, query: str) -> Set[str]:
        """Run a parameterless SELECT and return column 0 of each row as a set."""
        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(query)
                return {row[0] for row in cur.fetchall()}
        finally:
            conn.close()

    def get_all_whitelisted_ips(self) -> Set[str]:
        """
        Return the distinct ip_address values of every active whitelist row.

        Manual and public entries are returned together; no priority is
        applied by this query.
        """
        return self._fetch_ip_set("""
            SELECT DISTINCT ip_address
            FROM whitelist
            WHERE active = true
        """)

    def get_public_blacklist_ips(self) -> Set[str]:
        """Return the distinct ip_address values of active public blacklist rows."""
        return self._fetch_ip_set("""
            SELECT DISTINCT ip_address
            FROM public_blacklist_ips
            WHERE is_active = true
        """)

    @staticmethod
    def _entry_matches(ip_address: str, entry_ip, entry_cidr) -> bool:
        """Return True if ip_address equals entry_ip or lies inside entry_cidr."""
        if entry_ip is not None and str(entry_ip) == ip_address:
            return True
        # ip_matches_cidr() answers True when it is given no range, so it must
        # only be consulted for entries that actually carry one. Calling it
        # unconditionally would make a single plain-IP entry match every
        # address.
        if not entry_cidr:
            return False
        return ip_matches_cidr(ip_address, str(entry_cidr))

    def _whitelist_rows_match(self, ip_address: str, rows) -> bool:
        """Return True if any whitelist row (ip_address in column 0) covers ip_address."""
        for row in rows:
            entry_ip = row[0]
            # The whitelist has no separate range column in these queries; an
            # entry written with a "/" is itself the CIDR range.
            if entry_ip is not None and '/' in str(entry_ip):
                entry_cidr = entry_ip
            else:
                entry_cidr = None
            if self._entry_matches(ip_address, entry_ip, entry_cidr):
                return True
        return False

    def should_block_ip(self, ip_address: str) -> tuple[bool, str]:
        """
        Determine if an IP should be blocked based on the merge priority.

        Returns:
            (should_block, reason) where reason is one of "manual_whitelist",
            "public_whitelist", "public_blacklist:<id>" or "not_listed".

        Priority:
        1. Manual whitelist (exact or CIDR) -> don't block (highest priority)
        2. Public whitelist (exact or CIDR) -> don't block
        3. Public blacklist (exact or CIDR) -> block
        4. Not in any list -> don't block (only ML decides)

        Each call loads all active rows of the consulted lists and matches
        them in Python.
        """
        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                # 1. Manual whitelist (highest priority).
                cur.execute("""
                    SELECT ip_address, list_id FROM whitelist
                    WHERE active = true
                    AND source = 'manual'
                """)
                if self._whitelist_rows_match(ip_address, cur.fetchall()):
                    return (False, "manual_whitelist")

                # 2. Public whitelist: any source other than 'manual'.
                cur.execute("""
                    SELECT ip_address, list_id FROM whitelist
                    WHERE active = true
                    AND source != 'manual'
                """)
                if self._whitelist_rows_match(ip_address, cur.fetchall()):
                    return (False, "public_whitelist")

                # 3. Public blacklist: exact IP, or containment in cidr_range
                #    when the row has one.
                cur.execute("""
                    SELECT id, ip_address, cidr_range FROM public_blacklist_ips
                    WHERE is_active = true
                """)
                for bl_id, bl_ip, bl_cidr in cur.fetchall():
                    if self._entry_matches(ip_address, bl_ip, bl_cidr):
                        return (True, f"public_blacklist:{bl_id}")

                # 4. Not in any list.
                return (False, "not_listed")
        finally:
            conn.close()

    def create_detection_from_blacklist(
        self,
        ip_address: str,
        blacklist_id: str,
        risk_score: int = 75
    ) -> Optional[str]:
        """
        Create a detection record for a public-blacklist IP.

        The record is only created when should_block_ip() says the IP is
        blocked, i.e. no whitelist entry overrides the blacklist.

        Returns:
            The id of the new detection, the id of an already existing
            'public_blacklist' detection for this IP, or None when the IP is
            not blocked or the insert fails (failures are logged and rolled
            back, not raised).
        """
        should_block, reason = self.should_block_ip(ip_address)

        if not should_block:
            logger.info(f"IP {ip_address} not blocked - reason: {reason}")
            return None

        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                # Reuse an existing detection instead of inserting a duplicate.
                cur.execute("""
                    SELECT id FROM detections
                    WHERE source_ip = %s
                    AND detection_source = 'public_blacklist'
                    LIMIT 1
                """, (ip_address,))

                existing = cur.fetchone()
                if existing:
                    logger.info(f"Detection already exists for {ip_address}")
                    return existing[0]

                # Create new detection.
                cur.execute("""
                    INSERT INTO detections (
                        source_ip,
                        risk_score,
                        anomaly_type,
                        detection_source,
                        blacklist_id,
                        detected_at,
                        blocked
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s)
                    RETURNING id
                """, (
                    ip_address,
                    # Sent as text, matching the '75' literal used by the
                    # bulk insert in sync_public_blacklist_detections().
                    str(risk_score),
                    'public_blacklist',
                    'public_blacklist',
                    blacklist_id,
                    datetime.utcnow(),
                    False  # Will be blocked by auto-block service if risk_score >= 80
                ))

                result = cur.fetchone()
                if not result:
                    logger.error(f"Failed to get detection ID after insert for {ip_address}")
                    return None

                detection_id = result[0]
                conn.commit()

                logger.info(f"Created detection {detection_id} for blacklisted IP {ip_address}")
                return detection_id
        except Exception as e:
            conn.rollback()
            logger.error(f"Failed to create detection for {ip_address}: {e}")
            return None
        finally:
            conn.close()

    def cleanup_invalid_detections(self) -> int:
        """
        Remove 'public_blacklist' detections for IPs that are now whitelisted.

        CIDR-aware: the DELETE compares detections.source_ip (cast to inet)
        against whitelist.ip_inet for equality and network containment.
        Any active whitelist row (manual or public) overrides the blacklist.

        Returns:
            The number of deleted rows, or 0 if the statement failed (the
            failure is logged and rolled back, not raised).
        """
        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                # Delete detections for IPs in whitelist ranges (CIDR-aware).
                cur.execute("""
                    DELETE FROM detections d
                    WHERE d.detection_source = 'public_blacklist'
                    AND EXISTS (
                        SELECT 1 FROM whitelist wl
                        WHERE wl.active = true
                        AND wl.ip_inet IS NOT NULL
                        AND (
                            d.source_ip::inet = wl.ip_inet
                            OR d.source_ip::inet <<= wl.ip_inet
                        )
                    )
                """)
                deleted = cur.rowcount
                conn.commit()

                if deleted > 0:
                    logger.info(f"Cleaned up {deleted} detections for whitelisted IPs (CIDR-aware)")

                return deleted
        except Exception as e:
            conn.rollback()
            logger.error(f"Failed to cleanup detections: {e}")
            return 0
        finally:
            conn.close()

    def sync_public_blacklist_detections(self) -> Dict[str, int]:
        """
        Sync detections with the current public blacklist using bulk SQL.

        Steps:
        1. Remove 'public_blacklist' detections for IPs that are now
           whitelisted (cleanup_invalid_detections()).
        2. Insert one detection per active blacklist row that has ip_inet
           set, is not covered by any active whitelist row and has no
           existing 'public_blacklist' detection.

        Detections for IPs that merely left the blacklist are not removed
        here.

        Returns:
            A dict with keys 'created', 'cleaned' and 'skipped_whitelisted'.
            NOTE: 'skipped_whitelisted' is never computed by this method and
            is always 0. On failure the error is logged with its traceback,
            the insert is rolled back and the stats gathered so far are
            returned.
        """
        stats = {
            'created': 0,
            'cleaned': 0,
            'skipped_whitelisted': 0
        }

        # Cleanup whitelisted IPs first (priority). It uses its own
        # connection, so run it before opening ours rather than holding an
        # idle connection open meanwhile.
        stats['cleaned'] = self.cleanup_invalid_detections()

        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                # Bulk create detections with CIDR-aware matching.
                # Uses PostgreSQL INET operators for network containment.
                # Priority: Manual whitelist > Public whitelist > Blacklist
                cur.execute("""
                    INSERT INTO detections (
                        source_ip,
                        risk_score,
                        anomaly_type,
                        detection_source,
                        blacklist_id,
                        detected_at,
                        blocked
                    )
                    SELECT DISTINCT
                        bl.ip_address,
                        '75',
                        'public_blacklist',
                        'public_blacklist',
                        bl.id,
                        NOW(),
                        false
                    FROM public_blacklist_ips bl
                    WHERE bl.is_active = true
                    AND bl.ip_inet IS NOT NULL
                    -- Priority 1: Exclude if in manual whitelist (highest priority)
                    AND NOT EXISTS (
                        SELECT 1 FROM whitelist wl
                        WHERE wl.active = true
                        AND wl.source = 'manual'
                        AND wl.ip_inet IS NOT NULL
                        AND (
                            bl.ip_inet = wl.ip_inet
                            OR bl.ip_inet <<= wl.ip_inet
                        )
                    )
                    -- Priority 2: Exclude if in public whitelist
                    AND NOT EXISTS (
                        SELECT 1 FROM whitelist wl
                        WHERE wl.active = true
                        AND wl.source != 'manual'
                        AND wl.ip_inet IS NOT NULL
                        AND (
                            bl.ip_inet = wl.ip_inet
                            OR bl.ip_inet <<= wl.ip_inet
                        )
                    )
                    -- Avoid duplicate detections
                    AND NOT EXISTS (
                        SELECT 1 FROM detections d
                        WHERE d.source_ip = bl.ip_address
                        AND d.detection_source = 'public_blacklist'
                    )
                    RETURNING id
                """)

                created_ids = cur.fetchall()
                stats['created'] = len(created_ids)
                conn.commit()

                logger.info(f"Bulk sync complete: {stats}")
                return stats
        except Exception as e:
            conn.rollback()
            # logger.exception records the traceback through the logging
            # configuration instead of writing it straight to stderr.
            logger.exception(f"Failed to sync detections: {e}")
            return stats
        finally:
            conn.close()
|
|
|
|
|
|
def main():
    """
    Command-line entry point: run one merge-logic sync and print a summary.

    Reads the connection string from the DATABASE_URL environment variable.
    Returns 1 (after logging an error) when it is missing, otherwise 0.
    """
    db_url = os.environ.get('DATABASE_URL')

    if db_url:
        summary = MergeLogic(db_url).sync_public_blacklist_detections()
        rule = '=' * 60

        print(f"\n{rule}")
        print("MERGE LOGIC SYNC COMPLETED")
        print(f"{rule}")
        print(f"Created detections: {summary['created']}")
        print(f"Cleaned invalid detections: {summary['cleaned']}")
        print(f"Skipped (whitelisted): {summary['skipped_whitelisted']}")
        print(f"{rule}\n")

        return 0

    logger.error("DATABASE_URL environment variable not set")
    return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. SystemExit
    # is used directly because the bare exit() helper is injected by the
    # `site` module for interactive use and is not guaranteed to exist
    # (e.g. under `python -S`).
    raise SystemExit(main())
|