Replit-Commit-Author: Agent Replit-Commit-Session-Id: 7a657272-55ba-4a79-9a2e-f1ed9bc7a528 Replit-Commit-Checkpoint-Type: full_checkpoint Replit-Commit-Event-Id: 1c71ce6e-1a3e-4f53-bb5d-77cdd22b8ea3
646 lines
27 KiB
Python
646 lines
27 KiB
Python
import pandas as pd
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.sql import text
|
|
from sklearn.ensemble import IsolationForest
|
|
from sklearn.neighbors import LocalOutlierFactor
|
|
from sklearn.svm import OneClassSVM
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from joblib import dump, load
|
|
import logging
|
|
import gc
|
|
import os
|
|
import time
|
|
from datetime import datetime, timedelta
|
|
import numpy as np
|
|
import argparse # Aggiunto per gestire gli argomenti da linea di comando
|
|
import sys
|
|
import traceback
|
|
|
|
# Configurazione del logging migliorata
|
|
logging.basicConfig(
|
|
level=logging.DEBUG, # Cambiato da INFO a DEBUG
|
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
|
handlers=[
|
|
logging.StreamHandler(sys.stdout),
|
|
logging.FileHandler('analisys_debug.log') # File di log separato
|
|
]
|
|
)
|
|
|
|
# Cartella per i modelli
|
|
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
|
|
try:
|
|
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
logging.debug(f"Directory models creata/verificata: {MODEL_DIR}")
|
|
except Exception as e:
|
|
logging.error(f"Errore nella creazione della directory models: {e}")
|
|
# Fallback alla directory corrente
|
|
MODEL_DIR = os.path.join(os.getcwd(), 'models')
|
|
try:
|
|
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
logging.debug(f"Directory models creata come fallback in: {MODEL_DIR}")
|
|
except Exception as e2:
|
|
logging.error(f"Impossibile creare la directory models anche come fallback: {e2}")
|
|
MODEL_DIR = '.' # Usa la directory corrente come ultima risorsa
|
|
|
|
# Percorsi dei modelli
|
|
IF_MODEL_PATH = os.path.join(MODEL_DIR, 'isolation_forest.joblib')
|
|
LOF_MODEL_PATH = os.path.join(MODEL_DIR, 'lof.joblib')
|
|
SVM_MODEL_PATH = os.path.join(MODEL_DIR, 'svm.joblib')
|
|
ENSEMBLE_MODEL_PATH = os.path.join(MODEL_DIR, 'ensemble_weights.joblib')
|
|
PREPROCESSOR_PATH = os.path.join(MODEL_DIR, 'preprocessor.joblib')
|
|
ACCUMULATED_DATA_PATH = os.path.join(MODEL_DIR, 'accumulated_data.pkl')
|
|
LAST_TRAINING_PATH = os.path.join(MODEL_DIR, 'last_training.txt')
|
|
|
|
# Parametri di configurazione
|
|
TRAINING_FREQUENCY_HOURS = 12 # Riaddestra ogni 12 ore
|
|
CONTINUOUS_LEARNING = True
|
|
|
|
def extract_time_features(df):
|
|
"""
|
|
Estrae caratteristiche temporali dai dati
|
|
"""
|
|
logging.info("Estrazione delle caratteristiche temporali...")
|
|
|
|
# Converti timestamp in ora del giorno e giorno della settimana
|
|
df['hour_of_day'] = df['Timestamp'].dt.hour
|
|
df['day_of_week'] = df['Timestamp'].dt.dayofweek
|
|
|
|
# Calcola il tempo tra eventi consecutivi per lo stesso IP
|
|
ip_features = pd.DataFrame()
|
|
|
|
if 'IndirizzoIP' in df.columns:
|
|
# Ordina per IP e timestamp
|
|
df_sorted = df.sort_values(['IndirizzoIP', 'Timestamp'])
|
|
|
|
# Per ogni IP, calcola il tempo tra eventi consecutivi
|
|
ip_groups = df_sorted.groupby('IndirizzoIP')
|
|
|
|
# Inizializza colonne per le nuove caratteristiche
|
|
df['time_since_last'] = np.nan
|
|
df['events_last_hour'] = 0
|
|
df['events_last_day'] = 0
|
|
|
|
for ip, group in ip_groups:
|
|
if len(group) > 1:
|
|
# Calcola il tempo tra eventi consecutivi
|
|
group = group.copy()
|
|
group['time_since_last'] = group['Timestamp'].diff().dt.total_seconds()
|
|
|
|
# Aggiorna il DataFrame originale
|
|
df.loc[group.index, 'time_since_last'] = group['time_since_last']
|
|
|
|
# Conta eventi nell'ultima ora e giorno per ogni IP
|
|
for idx, row in group.iterrows():
|
|
current_time = row['Timestamp']
|
|
one_hour_ago = current_time - timedelta(hours=1)
|
|
one_day_ago = current_time - timedelta(days=1)
|
|
|
|
# Conta eventi nell'ultima ora
|
|
events_last_hour = len(group[(group['Timestamp'] > one_hour_ago) &
|
|
(group['Timestamp'] <= current_time)])
|
|
|
|
# Conta eventi nell'ultimo giorno
|
|
events_last_day = len(group[(group['Timestamp'] > one_day_ago) &
|
|
(group['Timestamp'] <= current_time)])
|
|
|
|
df.loc[idx, 'events_last_hour'] = events_last_hour
|
|
df.loc[idx, 'events_last_day'] = events_last_day
|
|
|
|
# Estrai statistiche per IP
|
|
ip_stats = ip_groups.agg({
|
|
'time_since_last': ['mean', 'std', 'min', 'max'],
|
|
'events_last_hour': 'max',
|
|
'events_last_day': 'max'
|
|
})
|
|
|
|
# Rinomina le colonne
|
|
ip_stats.columns = ['_'.join(col).strip() for col in ip_stats.columns.values]
|
|
|
|
# Resetta l'indice per avere IndirizzoIP come colonna
|
|
ip_stats = ip_stats.reset_index()
|
|
|
|
# Merge con il DataFrame originale
|
|
df = df.merge(ip_stats, on='IndirizzoIP', how='left')
|
|
|
|
logging.info("Caratteristiche temporali estratte con successo.")
|
|
return df
|
|
|
|
def connect_to_database():
|
|
"""
|
|
Connette al database MySQL usando le credenziali da variabili d'ambiente
|
|
"""
|
|
try:
|
|
logging.info("Connessione al database...")
|
|
|
|
db_user = os.environ.get('MYSQL_USER', 'root')
|
|
db_password = os.environ.get('MYSQL_PASSWORD', 'Hdgtejskjjc0-')
|
|
db_host = os.environ.get('MYSQL_HOST', 'localhost')
|
|
db_name = os.environ.get('MYSQL_DATABASE', 'LOG_MIKROTIK')
|
|
|
|
connection_string = f"mysql+mysqlconnector://{db_user}:{db_password}@{db_host}/{db_name}"
|
|
engine = create_engine(connection_string)
|
|
|
|
return engine
|
|
except Exception as e:
|
|
logging.error(f"Errore nella connessione al database: {e}")
|
|
return None
|
|
|
|
def extract_new_data(engine, window_minutes=90):
|
|
"""
|
|
Estrae nuovi dati dal database per l'addestramento del modello
|
|
"""
|
|
try:
|
|
logging.info(f"Estrazione dei dati degli ultimi {window_minutes} minuti...")
|
|
|
|
with engine.connect() as conn:
|
|
query = text("""
|
|
SELECT ip, port, COUNT(*) as count
|
|
FROM logs
|
|
WHERE timestamp >= DATE_SUB(NOW(), INTERVAL :window MINUTE)
|
|
GROUP BY ip, port
|
|
""")
|
|
|
|
df = pd.read_sql(query, conn, params={"window": window_minutes})
|
|
logging.info(f"Estratti {len(df)} record dal database")
|
|
return df
|
|
except Exception as e:
|
|
logging.error(f"Errore nell'estrazione dei dati: {e}")
|
|
return pd.DataFrame()
|
|
|
|
def save_model_timestamp():
|
|
"""
|
|
Salva il timestamp dell'ultimo addestramento del modello
|
|
"""
|
|
try:
|
|
engine = connect_to_database()
|
|
if not engine:
|
|
return False
|
|
|
|
with engine.connect() as conn:
|
|
# Crea la tabella se non esiste
|
|
create_table_query = text("""
|
|
CREATE TABLE IF NOT EXISTS model_metadata (
|
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
model_name VARCHAR(50) NOT NULL,
|
|
last_trained TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
model_path VARCHAR(255),
|
|
UNIQUE KEY unique_model (model_name)
|
|
)
|
|
""")
|
|
conn.execute(create_table_query)
|
|
|
|
# Aggiorna o inserisci il timestamp
|
|
upsert_query = text("""
|
|
INSERT INTO model_metadata (model_name, last_trained, model_path)
|
|
VALUES ('ensemble', NOW(), :model_path)
|
|
ON DUPLICATE KEY UPDATE last_trained = NOW(), model_path = :model_path
|
|
""")
|
|
|
|
conn.execute(upsert_query, {"model_path": ENSEMBLE_MODEL_PATH})
|
|
|
|
logging.info("Timestamp di addestramento del modello salvato con successo")
|
|
return True
|
|
except Exception as e:
|
|
logging.error(f"Errore nel salvare il timestamp di addestramento: {e}")
|
|
return False
|
|
|
|
def needs_training(force_training=False):
|
|
"""
|
|
Verifica se il modello deve essere riaddestrato (ogni 12 ore)
|
|
"""
|
|
if force_training:
|
|
logging.info("Riaddestramento forzato richiesto.")
|
|
return True
|
|
|
|
try:
|
|
engine = connect_to_database()
|
|
if not engine:
|
|
return True
|
|
|
|
with engine.connect() as conn:
|
|
# Verifica se la tabella esiste
|
|
try:
|
|
query = text("""
|
|
SELECT last_trained
|
|
FROM model_metadata
|
|
WHERE model_name = 'ensemble'
|
|
""")
|
|
|
|
result = conn.execute(query).fetchone()
|
|
|
|
if not result:
|
|
logging.info("Nessun dato di addestramento precedente trovato, riaddestramento necessario")
|
|
return True
|
|
|
|
last_trained = result[0]
|
|
now = datetime.now()
|
|
|
|
# Se l'ultimo addestramento è più vecchio di 12 ore, riaddestra
|
|
hours_diff = (now - last_trained).total_seconds() / 3600
|
|
|
|
if hours_diff >= 12:
|
|
logging.info(f"Ultimo addestramento: {last_trained}, {hours_diff:.1f} ore fa. Riaddestramento necessario")
|
|
return True
|
|
else:
|
|
logging.info(f"Ultimo addestramento: {last_trained}, {hours_diff:.1f} ore fa. Riaddestramento non necessario")
|
|
return False
|
|
|
|
except Exception as e:
|
|
logging.warning(f"Errore nel controllo della tabella model_metadata: {e}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logging.error(f"Errore nella verifica del bisogno di riaddestramento: {e}")
|
|
return True
|
|
|
|
def update_last_training_time():
|
|
"""
|
|
Aggiorna il timestamp dell'ultimo addestramento
|
|
"""
|
|
with open(LAST_TRAINING_PATH, 'w') as f:
|
|
f.write(datetime.now().isoformat())
|
|
|
|
def train_models(X):
|
|
"""
|
|
Addestra più modelli e li combina in un ensemble
|
|
"""
|
|
logging.info("Addestramento dei modelli di anomaly detection...")
|
|
|
|
# Isolation Forest
|
|
isolation_forest = IsolationForest(n_estimators=100, contamination=0.01, random_state=42, verbose=0)
|
|
isolation_forest.fit(X)
|
|
logging.info("Isolation Forest addestrato.")
|
|
|
|
# Local Outlier Factor
|
|
lof = LocalOutlierFactor(n_neighbors=20, contamination=0.01, novelty=True)
|
|
lof.fit(X)
|
|
logging.info("Local Outlier Factor addestrato.")
|
|
|
|
# One-Class SVM (più lento, usa solo un sottoinsieme dei dati se necessario)
|
|
max_svm_samples = min(10000, X.shape[0])
|
|
svm_indices = np.random.choice(X.shape[0], max_svm_samples, replace=False)
|
|
svm = OneClassSVM(kernel='rbf', gamma='scale', nu=0.01)
|
|
svm.fit(X[svm_indices])
|
|
logging.info("One-Class SVM addestrato.")
|
|
|
|
# Pesi per l'ensemble (possono essere ottimizzati con validation)
|
|
ensemble_weights = {
|
|
'isolation_forest': 0.5,
|
|
'lof': 0.3,
|
|
'svm': 0.2
|
|
}
|
|
|
|
# Salva i modelli
|
|
dump(isolation_forest, IF_MODEL_PATH)
|
|
dump(lof, LOF_MODEL_PATH)
|
|
dump(svm, SVM_MODEL_PATH)
|
|
dump(ensemble_weights, ENSEMBLE_MODEL_PATH)
|
|
|
|
logging.info("Tutti i modelli salvati con successo.")
|
|
|
|
# Aggiorna il timestamp dell'ultimo addestramento
|
|
update_last_training_time()
|
|
|
|
return isolation_forest, lof, svm, ensemble_weights
|
|
|
|
# Funzione per testare la connessione al database
|
|
def test_database_connection(conn_string=None):
|
|
"""
|
|
Testa la connessione al database e visualizza informazioni sulle tabelle
|
|
"""
|
|
if conn_string is None:
|
|
conn_string = 'mysql+mysqlconnector://root:Hdgtejskjjc0-@localhost/LOG_MIKROTIK'
|
|
|
|
try:
|
|
logging.debug(f"Test di connessione al database con: {conn_string}")
|
|
engine = create_engine(conn_string)
|
|
with engine.connect() as conn:
|
|
# Test semplice
|
|
result = conn.execute(text("SELECT 1")).fetchone()
|
|
if result and result[0] == 1:
|
|
logging.debug("Test connessione di base superato!")
|
|
|
|
# Elenca le tabelle
|
|
tables = conn.execute(text("SHOW TABLES")).fetchall()
|
|
table_names = [t[0] for t in tables]
|
|
logging.debug(f"Tabelle disponibili: {table_names}")
|
|
|
|
# Verifica tabella Fibra
|
|
if 'Fibra' in table_names:
|
|
# Ottieni informazioni sulla tabella
|
|
columns = conn.execute(text("DESCRIBE Fibra")).fetchall()
|
|
logging.debug(f"Struttura tabella Fibra: {[c[0] for c in columns]}")
|
|
|
|
# Conta i record
|
|
count = conn.execute(text("SELECT COUNT(*) FROM Fibra")).fetchone()[0]
|
|
logging.debug(f"Tabella Fibra contiene {count} record")
|
|
|
|
# Visualizza alcuni dati di esempio
|
|
sample = conn.execute(text("SELECT * FROM Fibra LIMIT 3")).fetchall()
|
|
if sample:
|
|
logging.debug(f"Esempio di dati da Fibra: {sample[0]}")
|
|
else:
|
|
logging.error("Tabella Fibra non trovata!")
|
|
|
|
return True
|
|
return False
|
|
except Exception as e:
|
|
logging.error(f"Errore nel test di connessione al database: {e}")
|
|
logging.error(traceback.format_exc())
|
|
return False
|
|
|
|
|
|
def extract_data_sample(engine, table='Fibra', limit=5):
|
|
"""
|
|
Estrae un campione di dati per test/debug
|
|
"""
|
|
try:
|
|
query = f"""
|
|
SELECT * FROM {table}
|
|
ORDER BY ID DESC
|
|
LIMIT {limit}
|
|
"""
|
|
sample = pd.read_sql(query, engine)
|
|
logging.debug(f"Estratto campione da {table}: {len(sample)} record")
|
|
logging.debug(f"Colonne: {sample.columns.tolist()}")
|
|
return sample
|
|
except Exception as e:
|
|
logging.error(f"Errore nell'estrazione del campione: {e}")
|
|
return pd.DataFrame()
|
|
|
|
def main():
|
|
# Parsing degli argomenti da linea di comando
|
|
parser = argparse.ArgumentParser(description='Addestra modelli di anomaly detection')
|
|
parser.add_argument('--force-training', action='store_true', help='Forza il riaddestramento dei modelli')
|
|
parser.add_argument('--test', action='store_true', help='Esegue test diagnostici senza addestramento')
|
|
parser.add_argument('--time-window', type=str, default='12 HOUR', help='Finestra temporale per i dati (es. "12 HOUR", "1 DAY")')
|
|
args = parser.parse_args()
|
|
|
|
logging.info(f"Opzioni: force_training={args.force_training}, test={args.test}, time_window={args.time_window}")
|
|
|
|
# Se è richiesto solo il test
|
|
if args.test:
|
|
logging.info("MODALITÀ TEST - esecuzione diagnostica")
|
|
|
|
# Test connessione al database
|
|
if not test_database_connection():
|
|
logging.error("Test database fallito - impossibile continuare")
|
|
return False
|
|
|
|
# Connessione al database per test aggiuntivi
|
|
try:
|
|
engine = create_engine('mysql+mysqlconnector://root:Hdgtejskjjc0-@localhost/LOG_MIKROTIK')
|
|
logging.info("Test: estrazione di un campione di dati...")
|
|
sample = extract_data_sample(engine)
|
|
if not sample.empty:
|
|
logging.info(f"Test di estrazione dati riuscito. Colonne: {sample.columns.tolist()}")
|
|
logging.info(f"Esempio record: {sample.iloc[0].to_dict()}")
|
|
else:
|
|
logging.error("Impossibile estrarre dati di esempio")
|
|
|
|
# Verifica percorsi
|
|
logging.info(f"Test: controllo percorsi per i modelli...")
|
|
logging.info(f"MODEL_DIR esiste: {os.path.exists(MODEL_DIR)}")
|
|
if os.path.exists(ACCUMULATED_DATA_PATH):
|
|
logging.info(f"Dati accumulati esistenti: {os.path.getsize(ACCUMULATED_DATA_PATH)/1024:.2f} KB")
|
|
|
|
logging.info("Modalità test completata")
|
|
return True
|
|
except Exception as e:
|
|
logging.error(f"Errore durante i test: {e}")
|
|
logging.error(traceback.format_exc())
|
|
return False
|
|
|
|
# 1. Caricamento dei dati accumulati
|
|
if os.path.exists(ACCUMULATED_DATA_PATH):
|
|
logging.info("Caricamento dei dati accumulati...")
|
|
try:
|
|
accumulated_data = pd.read_pickle(ACCUMULATED_DATA_PATH)
|
|
logging.info(f"Dati accumulati caricati: {len(accumulated_data)} record.")
|
|
logging.debug(f"Colonne nei dati accumulati: {accumulated_data.columns.tolist()}")
|
|
except Exception as e:
|
|
logging.error(f"Errore nel caricamento dei dati accumulati: {e}")
|
|
logging.error(traceback.format_exc())
|
|
accumulated_data = pd.DataFrame()
|
|
logging.info("Creazione di un nuovo dataset a causa dell'errore.")
|
|
else:
|
|
accumulated_data = pd.DataFrame()
|
|
logging.info("Nessun dato accumulato trovato. Creazione di un nuovo dataset.")
|
|
|
|
# 2. Connessione al database
|
|
logging.info("Connessione al database...")
|
|
try:
|
|
engine = create_engine('mysql+mysqlconnector://root:Hdgtejskjjc0-@localhost/LOG_MIKROTIK')
|
|
logging.info("Connessione stabilita.")
|
|
except Exception as e:
|
|
logging.error(f"Errore di connessione al database: {e}")
|
|
logging.error(traceback.format_exc())
|
|
return False
|
|
|
|
# 3. Estrazione dei dati
|
|
query_time_window = args.time_window # Usa il valore dall'argomento
|
|
|
|
if args.force_training:
|
|
# Se richiesto l'addestramento forzato, prendi più dati storici
|
|
logging.info(f"Addestramento forzato: estrazione dati degli ultimi {query_time_window}...")
|
|
else:
|
|
logging.info(f"Estrazione dei dati degli ultimi {query_time_window}...")
|
|
|
|
try:
|
|
query = f"""
|
|
SELECT ID, Data, Ora, Host, IndirizzoIP, Messaggio1, Messaggio2, Messaggio3, Messaggio4
|
|
FROM Fibra
|
|
WHERE CONCAT(Data, ' ', Ora) >= NOW() - INTERVAL {query_time_window}
|
|
LIMIT 50000
|
|
"""
|
|
|
|
logging.debug(f"Query SQL: {query}")
|
|
new_data = pd.read_sql(query, engine)
|
|
logging.info(f"Dati estratti: {len(new_data)} record.")
|
|
|
|
if not new_data.empty:
|
|
logging.debug(f"Intervallo date: {new_data['Data'].min()} - {new_data['Data'].max()}")
|
|
except Exception as e:
|
|
logging.error(f"Errore nell'estrazione dei dati: {e}")
|
|
logging.error(traceback.format_exc())
|
|
new_data = pd.DataFrame()
|
|
logging.error("Impossibile estrarre nuovi dati.")
|
|
|
|
has_data_to_process = False
|
|
|
|
if new_data.empty:
|
|
logging.info("Nessun nuovo dato da aggiungere.")
|
|
# Verifica se ci sono dati già accumulati e se è richiesto l'addestramento forzato
|
|
if args.force_training:
|
|
if not accumulated_data.empty:
|
|
logging.info("Riaddestramento forzato richiesto con dati accumulati esistenti.")
|
|
has_data_to_process = True
|
|
else:
|
|
# Prova a cercare dati più vecchi dalla tabella
|
|
logging.info("Nessun dato recente trovato. Tentativo di estrazione dati più vecchi...")
|
|
fallback_query = """
|
|
SELECT ID, Data, Ora, Host, IndirizzoIP, Messaggio1, Messaggio2, Messaggio3, Messaggio4
|
|
FROM Fibra
|
|
ORDER BY ID DESC
|
|
LIMIT 10000
|
|
"""
|
|
new_data = pd.read_sql(fallback_query, engine)
|
|
logging.info(f"Dati più vecchi estratti: {len(new_data)} record.")
|
|
|
|
if not new_data.empty:
|
|
accumulated_data = new_data
|
|
has_data_to_process = True
|
|
else:
|
|
logging.warning("Impossibile trovare dati nella tabella Fibra. Addestramento impossibile.")
|
|
else:
|
|
# 4. Aggiunta del nuovo blocco ai dati accumulati
|
|
if accumulated_data.empty:
|
|
accumulated_data = new_data
|
|
else:
|
|
accumulated_data = pd.concat([accumulated_data, new_data], ignore_index=True)
|
|
has_data_to_process = True
|
|
|
|
# Salva i dati accumulati
|
|
accumulated_data.to_pickle(ACCUMULATED_DATA_PATH)
|
|
logging.info(f"Dati accumulati salvati: {len(accumulated_data)} record.")
|
|
|
|
# Procedi solo se ci sono dati da elaborare
|
|
if has_data_to_process:
|
|
logging.info(f"Elaborazione di {len(accumulated_data)} record...")
|
|
|
|
# Verifica che i dati contengano le colonne necessarie
|
|
required_columns = ['Data', 'Ora', 'Host', 'IndirizzoIP', 'Messaggio1', 'Messaggio2', 'Messaggio3', 'Messaggio4']
|
|
missing_columns = [col for col in required_columns if col not in accumulated_data.columns]
|
|
|
|
if missing_columns:
|
|
logging.warning(f"I dati accumulati non contengono le colonne necessarie: {missing_columns}")
|
|
logging.warning("Impossibile procedere con il riaddestramento. È necessario generare nuovi dati.")
|
|
|
|
# Verifica quali colonne sono presenti
|
|
logging.info(f"Colonne presenti nei dati: {accumulated_data.columns.tolist()}")
|
|
|
|
# Se i dati hanno già la struttura elaborata (senza le colonne originali)
|
|
if 'Host' in accumulated_data.columns and 'IndirizzoIP' in accumulated_data.columns and 'Messaggio' in accumulated_data.columns:
|
|
logging.info("I dati sembrano essere già stati preprocessati.")
|
|
|
|
# Controlla se ci sono le caratteristiche temporali necessarie
|
|
time_features = ['hour_of_day', 'day_of_week', 'events_last_hour', 'events_last_day']
|
|
has_time_features = all(col in accumulated_data.columns for col in time_features)
|
|
|
|
if has_time_features:
|
|
logging.info("Le caratteristiche temporali sono presenti. Posso procedere con l'addestramento del modello.")
|
|
# Salta alla parte di codifica delle variabili categoriali
|
|
goto_preprocessing = True
|
|
else:
|
|
logging.warning("Mancano le caratteristiche temporali. Non posso procedere con l'addestramento.")
|
|
goto_preprocessing = False
|
|
else:
|
|
logging.warning("I dati non hanno la struttura corretta. Non posso procedere con l'addestramento.")
|
|
goto_preprocessing = False
|
|
|
|
if not goto_preprocessing:
|
|
logging.info("Salto il riaddestramento. Si prega di generare nuovi dati con 'python3 ddetect.py'.")
|
|
return
|
|
else:
|
|
goto_preprocessing = False
|
|
# 5. Preprocessing dei dati accumulati
|
|
accumulated_data['Data'] = pd.to_datetime(accumulated_data['Data'], errors='coerce')
|
|
accumulated_data['Ora'] = pd.to_timedelta(accumulated_data['Ora'].astype(str), errors='coerce')
|
|
accumulated_data.dropna(subset=['Data', 'Ora'], inplace=True)
|
|
accumulated_data['Timestamp'] = accumulated_data['Data'] + accumulated_data['Ora']
|
|
|
|
# 6. Rimozione dei dati più vecchi di 5 ore
|
|
threshold_time = datetime.now() - timedelta(hours=5)
|
|
accumulated_data = accumulated_data[accumulated_data['Timestamp'] >= threshold_time]
|
|
|
|
# 7. Unione dei messaggi
|
|
accumulated_data['Messaggio'] = accumulated_data[['Messaggio1', 'Messaggio2', 'Messaggio3', 'Messaggio4']].fillna('').agg(' '.join, axis=1)
|
|
|
|
# Mantieni temporaneamente le colonne originali per l'estrazione delle caratteristiche temporali
|
|
temp_columns = accumulated_data[['Messaggio1', 'Messaggio2', 'Messaggio3', 'Messaggio4', 'Data', 'Ora']].copy()
|
|
|
|
# 8. Estrai caratteristiche temporali avanzate
|
|
accumulated_data = extract_time_features(accumulated_data)
|
|
|
|
# Ora possiamo eliminare le colonne originali
|
|
accumulated_data.drop(columns=['Messaggio1', 'Messaggio2', 'Messaggio3', 'Messaggio4', 'Data', 'Ora'], inplace=True)
|
|
gc.collect()
|
|
|
|
# 9. Codifica delle variabili categoriali
|
|
logging.info("Codifica delle variabili categoriali...")
|
|
from category_encoders import HashingEncoder
|
|
|
|
# Encoder separati per 'Host' e 'IndirizzoIP'
|
|
he_host = HashingEncoder(n_components=8, hash_method='md5')
|
|
X_host = he_host.fit_transform(accumulated_data['Host'].astype(str))
|
|
|
|
he_ip = HashingEncoder(n_components=8, hash_method='md5')
|
|
X_ip = he_ip.fit_transform(accumulated_data['IndirizzoIP'].astype(str))
|
|
|
|
accumulated_data.drop(columns=['Host', 'IndirizzoIP'], inplace=True)
|
|
gc.collect()
|
|
|
|
# 10. Trasformazione TF-IDF
|
|
logging.info("Trasformazione dei messaggi con TF-IDF...")
|
|
vectorizer = TfidfVectorizer(max_features=500)
|
|
X_messages = vectorizer.fit_transform(accumulated_data['Messaggio'])
|
|
accumulated_data.drop(columns=['Messaggio'], inplace=True)
|
|
gc.collect()
|
|
|
|
# 11. Creazione del DataFrame delle caratteristiche
|
|
logging.info("Creazione del DataFrame delle caratteristiche...")
|
|
from scipy.sparse import hstack
|
|
from scipy import sparse
|
|
|
|
# Converti X_host e X_ip in matrici sparse e assicurati che i tipi siano compatibili
|
|
X_host_sparse = sparse.csr_matrix(X_host).astype('float64')
|
|
X_ip_sparse = sparse.csr_matrix(X_ip).astype('float64')
|
|
X_messages = X_messages.astype('float64')
|
|
|
|
# Estrai caratteristiche temporali numeriche
|
|
time_features = accumulated_data[['hour_of_day', 'day_of_week', 'events_last_hour',
|
|
'events_last_day', 'time_since_last_mean',
|
|
'time_since_last_std', 'time_since_last_min',
|
|
'time_since_last_max']].fillna(0)
|
|
X_time = sparse.csr_matrix(time_features.values).astype('float64')
|
|
|
|
X = hstack([X_host_sparse, X_ip_sparse, X_messages, X_time]).tocsr()
|
|
del X_host, X_ip, X_host_sparse, X_ip_sparse, X_messages, X_time
|
|
gc.collect()
|
|
|
|
# 12. Verifica se è necessario riaddestrare il modello
|
|
if needs_training(args.force_training):
|
|
logging.info("Riaddestrando i modelli...")
|
|
isolation_forest, lof, svm, ensemble_weights = train_models(X)
|
|
|
|
# Salvataggio nel database del timestamp di addestramento
|
|
save_model_timestamp()
|
|
else:
|
|
logging.info("Utilizzo dei modelli esistenti.")
|
|
# Carica i modelli esistenti se non è necessario riaddestrarli
|
|
if os.path.exists(IF_MODEL_PATH) and os.path.exists(LOF_MODEL_PATH) and os.path.exists(SVM_MODEL_PATH):
|
|
isolation_forest = load(IF_MODEL_PATH)
|
|
lof = load(LOF_MODEL_PATH)
|
|
svm = load(SVM_MODEL_PATH)
|
|
|
|
if os.path.exists(ENSEMBLE_MODEL_PATH):
|
|
ensemble_weights = load(ENSEMBLE_MODEL_PATH)
|
|
else:
|
|
ensemble_weights = {'isolation_forest': 0.5, 'lof': 0.3, 'svm': 0.2}
|
|
|
|
logging.info("Modelli caricati con successo.")
|
|
else:
|
|
logging.info("Modelli non trovati, addestramento in corso...")
|
|
isolation_forest, lof, svm, ensemble_weights = train_models(X)
|
|
|
|
# 13. Salvataggio del modello e dei dati
|
|
logging.info("Salvataggio dei modelli, encoder e dati accumulati...")
|
|
dump(he_host, 'hashing_encoder_host.joblib')
|
|
dump(he_ip, 'hashing_encoder_ip.joblib')
|
|
dump(vectorizer, 'tfidf_vectorizer.joblib')
|
|
accumulated_data.to_pickle(ACCUMULATED_DATA_PATH)
|
|
logging.info("Salvataggio completato.")
|
|
|
|
if __name__ == '__main__':
|
|
main()
|