8.1 Integrating with Relational Databases
8.1.1 Data Synchronization Strategies
The integrator below keeps relational tables (PostgreSQL or MySQL) and Neo4j in sync, supporting full, incremental, and trigger-driven real-time modes as well as bidirectional synchronization:
import psycopg2
import mysql.connector
from neo4j import GraphDatabase
import json
import logging
from typing import Dict, List, Any, Optional
from datetime import datetime
import threading
import time
from dataclasses import dataclass
from enum import Enum
class SyncDirection(Enum):
"""同步方向"""
RDBMS_TO_NEO4J = "rdbms_to_neo4j"
NEO4J_TO_RDBMS = "neo4j_to_rdbms"
BIDIRECTIONAL = "bidirectional"
class SyncMode(Enum):
"""同步模式"""
FULL = "full"
INCREMENTAL = "incremental"
REAL_TIME = "real_time"
@dataclass
class SyncConfig:
"""同步配置"""
source_table: str
target_label: str
    key_mapping: Dict[str, str]  # maps Neo4j property names to RDBMS column names
relationship_mapping: Optional[Dict[str, Any]] = None
sync_direction: SyncDirection = SyncDirection.RDBMS_TO_NEO4J
sync_mode: SyncMode = SyncMode.INCREMENTAL
batch_size: int = 1000
sync_interval: int = 300 # 秒
class Neo4jRDBMSIntegrator:
"""Neo4j与关系型数据库集成器"""
def __init__(self, neo4j_config: Dict[str, Any], rdbms_config: Dict[str, Any]):
# Neo4j连接
self.neo4j_driver = GraphDatabase.driver(
neo4j_config['uri'],
auth=(neo4j_config['username'], neo4j_config['password'])
)
# 关系型数据库连接
self.rdbms_config = rdbms_config
self.rdbms_type = rdbms_config['type'] # 'postgresql' or 'mysql'
self.rdbms_connection = None
self._connect_rdbms()
# 同步配置
self.sync_configs: List[SyncConfig] = []
self.sync_status = {}
self.is_running = False
# 配置日志
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def _connect_rdbms(self):
"""连接关系型数据库"""
try:
if self.rdbms_type == 'postgresql':
self.rdbms_connection = psycopg2.connect(
host=self.rdbms_config['host'],
port=self.rdbms_config['port'],
database=self.rdbms_config['database'],
user=self.rdbms_config['username'],
password=self.rdbms_config['password']
)
elif self.rdbms_type == 'mysql':
self.rdbms_connection = mysql.connector.connect(
host=self.rdbms_config['host'],
port=self.rdbms_config['port'],
database=self.rdbms_config['database'],
user=self.rdbms_config['username'],
password=self.rdbms_config['password']
)
self.logger.info(f"Connected to {self.rdbms_type} database")
except Exception as e:
self.logger.error(f"Failed to connect to RDBMS: {e}")
raise
def add_sync_config(self, config: SyncConfig):
"""添加同步配置"""
self.sync_configs.append(config)
self.sync_status[config.source_table] = {
'last_sync': None,
'records_synced': 0,
'errors': 0,
'status': 'configured'
}
self.logger.info(f"Added sync config for {config.source_table} -> {config.target_label}")
def start_sync(self):
"""启动同步服务"""
if self.is_running:
self.logger.warning("Sync service is already running")
return
self.is_running = True
def sync_loop():
while self.is_running:
try:
for config in self.sync_configs:
if config.sync_mode == SyncMode.REAL_TIME:
continue # 实时同步由触发器处理
self._perform_sync(config)
                    # Sleep until the next cycle (shortest configured interval, 60s default)
                    time.sleep(min((c.sync_interval for c in self.sync_configs), default=60))
except Exception as e:
self.logger.error(f"Sync loop error: {e}")
time.sleep(60) # 错误后等待1分钟
sync_thread = threading.Thread(target=sync_loop, daemon=True)
sync_thread.start()
self.logger.info("Sync service started")
def stop_sync(self):
"""停止同步服务"""
self.is_running = False
self.logger.info("Sync service stopped")
def _perform_sync(self, config: SyncConfig):
"""执行同步"""
try:
if config.sync_direction == SyncDirection.RDBMS_TO_NEO4J:
self._sync_rdbms_to_neo4j(config)
elif config.sync_direction == SyncDirection.NEO4J_TO_RDBMS:
self._sync_neo4j_to_rdbms(config)
elif config.sync_direction == SyncDirection.BIDIRECTIONAL:
self._sync_rdbms_to_neo4j(config)
self._sync_neo4j_to_rdbms(config)
# 更新同步状态
self.sync_status[config.source_table]['last_sync'] = datetime.now()
self.sync_status[config.source_table]['status'] = 'success'
except Exception as e:
self.logger.error(f"Sync failed for {config.source_table}: {e}")
self.sync_status[config.source_table]['errors'] += 1
self.sync_status[config.source_table]['status'] = 'error'
def _sync_rdbms_to_neo4j(self, config: SyncConfig):
"""从关系型数据库同步到Neo4j"""
cursor = self.rdbms_connection.cursor()
try:
            # Build the extraction query. Incremental sync assumes the source table
            # has an `updated_at` timestamp column.
            last_sync = self.sync_status[config.source_table]['last_sync']
            if config.sync_mode == SyncMode.FULL or not last_sync:
                query = f"SELECT * FROM {config.source_table}"
                cursor.execute(query)
            else:
                query = f"SELECT * FROM {config.source_table} WHERE updated_at > %s"
                cursor.execute(query, (last_sync,))
# 批量处理数据
batch = []
records_processed = 0
for row in cursor:
# 转换数据格式
if self.rdbms_type == 'postgresql':
columns = [desc[0] for desc in cursor.description]
record = dict(zip(columns, row))
else: # MySQL
record = dict(zip(cursor.column_names, row))
batch.append(record)
if len(batch) >= config.batch_size:
self._create_neo4j_nodes(batch, config)
records_processed += len(batch)
batch = []
# 处理剩余数据
if batch:
self._create_neo4j_nodes(batch, config)
records_processed += len(batch)
self.sync_status[config.source_table]['records_synced'] += records_processed
self.logger.info(f"Synced {records_processed} records from {config.source_table} to Neo4j")
finally:
cursor.close()
def _create_neo4j_nodes(self, records: List[Dict[str, Any]], config: SyncConfig):
"""在Neo4j中创建节点"""
with self.neo4j_driver.session() as session:
# 构建Cypher查询
properties = []
for neo4j_prop, rdbms_col in config.key_mapping.items():
                properties.append(f"{neo4j_prop}: record.{rdbms_col}")
cypher = f"""
UNWIND $records AS record
MERGE (n:{config.target_label} {{id: record.{list(config.key_mapping.values())[0]}}})
SET n += {{{', '.join(properties)}}}
"""
session.run(cypher, records=records)
# 创建关系(如果配置了)
if config.relationship_mapping:
self._create_relationships(records, config, session)
def _create_relationships(self, records: List[Dict[str, Any]], config: SyncConfig, session):
"""创建关系"""
for rel_config in config.relationship_mapping:
rel_type = rel_config['type']
source_key = rel_config['source_key']
target_key = rel_config['target_key']
target_label = rel_config['target_label']
cypher = f"""
UNWIND $records AS record
MATCH (source:{config.target_label} {{id: record.{source_key}}})
MATCH (target:{target_label} {{id: record.{target_key}}})
MERGE (source)-[:{rel_type}]->(target)
"""
session.run(cypher, records=records)
def _sync_neo4j_to_rdbms(self, config: SyncConfig):
"""从Neo4j同步到关系型数据库"""
with self.neo4j_driver.session() as session:
# 查询Neo4j数据
cypher = f"MATCH (n:{config.target_label}) RETURN n"
result = session.run(cypher)
cursor = self.rdbms_connection.cursor()
try:
batch = []
records_processed = 0
for record in result:
node = record['n']
# 转换数据格式
rdbms_record = {}
                    for neo4j_prop, rdbms_col in config.key_mapping.items():
                        rdbms_record[rdbms_col] = node.get(neo4j_prop)
batch.append(rdbms_record)
if len(batch) >= config.batch_size:
self._upsert_rdbms_records(batch, config, cursor)
records_processed += len(batch)
batch = []
# 处理剩余数据
if batch:
self._upsert_rdbms_records(batch, config, cursor)
records_processed += len(batch)
self.rdbms_connection.commit()
self.logger.info(f"Synced {records_processed} records from Neo4j to {config.source_table}")
finally:
cursor.close()
def _upsert_rdbms_records(self, records: List[Dict[str, Any]], config: SyncConfig, cursor):
"""在关系型数据库中插入或更新记录"""
for record in records:
columns = list(record.keys())
values = list(record.values())
placeholders = ['%s'] * len(values)
if self.rdbms_type == 'postgresql':
# PostgreSQL UPSERT
                conflict_column = list(config.key_mapping.values())[0]
update_clause = ', '.join([f"{col} = EXCLUDED.{col}" for col in columns if col != conflict_column])
query = f"""
INSERT INTO {config.source_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ({conflict_column}) DO UPDATE SET {update_clause}
"""
else:
# MySQL UPSERT
update_clause = ', '.join([f"{col} = VALUES({col})" for col in columns])
query = f"""
INSERT INTO {config.source_table} ({', '.join(columns)})
VALUES ({', '.join(placeholders)})
ON DUPLICATE KEY UPDATE {update_clause}
"""
cursor.execute(query, values)
def get_sync_status(self) -> Dict[str, Any]:
"""获取同步状态"""
return {
'is_running': self.is_running,
'total_configs': len(self.sync_configs),
'sync_details': self.sync_status
}
def manual_sync(self, table_name: str):
"""手动触发同步"""
config = next((c for c in self.sync_configs if c.source_table == table_name), None)
if config:
self._perform_sync(config)
self.logger.info(f"Manual sync completed for {table_name}")
else:
self.logger.error(f"No sync config found for {table_name}")
def close(self):
"""关闭连接"""
self.stop_sync()
if self.neo4j_driver:
self.neo4j_driver.close()
if self.rdbms_connection:
self.rdbms_connection.close()
self.logger.info("Connections closed")
# 使用示例
neo4j_config = {
'uri': 'bolt://localhost:7687',
'username': 'neo4j',
'password': 'password'
}
postgresql_config = {
'type': 'postgresql',
'host': 'localhost',
'port': 5432,
'database': 'myapp',
'username': 'postgres',
'password': 'password'
}
# 创建集成器
integrator = Neo4jRDBMSIntegrator(neo4j_config, postgresql_config)
# 配置用户表同步
user_sync_config = SyncConfig(
source_table='users',
target_label='User',
key_mapping={
'id': 'id',
'name': 'name',
'email': 'email',
'created_at': 'created_at'
},
sync_direction=SyncDirection.BIDIRECTIONAL,
sync_mode=SyncMode.INCREMENTAL,
batch_size=500
)
# 配置订单表同步(包含关系)
order_sync_config = SyncConfig(
source_table='orders',
target_label='Order',
key_mapping={
'id': 'id',
'user_id': 'user_id',
'total': 'total',
'status': 'status',
'created_at': 'created_at'
},
relationship_mapping=[
{
'type': 'PLACED_BY',
'source_key': 'id',
'target_key': 'user_id',
'target_label': 'User'
}
],
sync_direction=SyncDirection.RDBMS_TO_NEO4J,
sync_mode=SyncMode.INCREMENTAL
)
integrator.add_sync_config(user_sync_config)
integrator.add_sync_config(order_sync_config)
# 启动同步服务
integrator.start_sync()
# 检查同步状态
status = integrator.get_sync_status()
print(f"Sync service running: {status['is_running']}")
print(f"Total configurations: {status['total_configs']}")
# 手动触发同步
integrator.manual_sync('users')
# 等待一段时间后关闭
time.sleep(10)
integrator.close()
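When SyncMode.REAL_TIME is selected, the sync loop above deliberately skips the table and leaves the work to database-side triggers, which the code does not show. The sketch below is one possible PostgreSQL implementation built on a trigger plus LISTEN/NOTIFY; the trigger and channel names, the payload layout, and the reaction (simply re-running manual_sync for the table) are illustrative assumptions rather than part of the integrator's API.
import json
import select
import psycopg2

# Trigger definition (PostgreSQL 11+ syntax for EXECUTE FUNCTION).
USERS_CHANGE_TRIGGER_SQL = """
CREATE OR REPLACE FUNCTION notify_users_change() RETURNS trigger AS $$
BEGIN
    PERFORM pg_notify('users_changes',
                      json_build_object('op', TG_OP, 'row', row_to_json(NEW))::text);
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
DROP TRIGGER IF EXISTS users_change_trigger ON users;
CREATE TRIGGER users_change_trigger
AFTER INSERT OR UPDATE ON users
FOR EACH ROW EXECUTE FUNCTION notify_users_change();
"""

def listen_for_user_changes(pg_conn, integrator, table_name='users'):
    """Block on LISTEN/NOTIFY and push each change into Neo4j via the integrator."""
    pg_conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
    cursor = pg_conn.cursor()
    cursor.execute(USERS_CHANGE_TRIGGER_SQL)
    cursor.execute("LISTEN users_changes;")
    while True:
        # Wait up to 5 seconds for a notification before polling again
        if select.select([pg_conn], [], [], 5) == ([], [], []):
            continue
        pg_conn.poll()
        while pg_conn.notifies:
            notice = pg_conn.notifies.pop(0)
            change = json.loads(notice.payload)  # parsed {'op': ..., 'row': {...}} for finer-grained handling
            # Simplest possible reaction: re-run an incremental sync for this table
            integrator.manual_sync(table_name)

# pg_conn = psycopg2.connect(host='localhost', port=5432, database='myapp',
#                            user='postgres', password='password')
# listen_for_user_changes(pg_conn, integrator)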
8.2 Integrating with Big Data Platforms
8.2.1 Apache Spark Integration
The connector below uses the Neo4j Connector for Apache Spark to move data between Spark DataFrames and Neo4j and to run graph analyses at scale:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when, isnan, isnull
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from neo4j import GraphDatabase
import json
from typing import Dict, List, Any, Optional
import logging
class Neo4jSparkConnector:
"""Neo4j与Spark集成连接器"""
def __init__(self, spark_config: Dict[str, Any], neo4j_config: Dict[str, Any]):
# 初始化Spark会话
self.spark = SparkSession.builder \
.appName(spark_config.get('app_name', 'Neo4jSparkIntegration')) \
.config('spark.jars.packages', 'neo4j-contrib:neo4j-connector-apache-spark_2.12:4.1.5_for_spark_3') \
.getOrCreate()
# Neo4j配置
self.neo4j_config = neo4j_config
self.neo4j_driver = GraphDatabase.driver(
neo4j_config['uri'],
auth=(neo4j_config['username'], neo4j_config['password'])
)
# 配置日志
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def read_from_neo4j(self, cypher_query: str, parameters: Dict[str, Any] = None) -> 'DataFrame':
"""从Neo4j读取数据到Spark DataFrame"""
try:
# 使用Neo4j Spark连接器
df = self.spark.read.format("org.neo4j.spark.DataSource") \
.option("url", self.neo4j_config['uri']) \
.option("authentication.basic.username", self.neo4j_config['username']) \
.option("authentication.basic.password", self.neo4j_config['password']) \
.option("query", cypher_query)
if parameters:
for key, value in parameters.items():
df = df.option(f"query.{key}", value)
return df.load()
except Exception as e:
self.logger.error(f"Failed to read from Neo4j: {e}")
raise
def write_to_neo4j(self, df: 'DataFrame', node_label: str,
node_keys: List[str], mode: str = "Overwrite"):
"""将Spark DataFrame写入Neo4j"""
try:
df.write.format("org.neo4j.spark.DataSource") \
.option("url", self.neo4j_config['uri']) \
.option("authentication.basic.username", self.neo4j_config['username']) \
.option("authentication.basic.password", self.neo4j_config['password']) \
.option("labels", node_label) \
.option("node.keys", ",".join(node_keys)) \
.mode(mode) \
.save()
self.logger.info(f"Successfully wrote DataFrame to Neo4j with label {node_label}")
except Exception as e:
self.logger.error(f"Failed to write to Neo4j: {e}")
raise
def create_relationships_from_df(self, df: 'DataFrame',
source_label: str, target_label: str,
relationship_type: str,
source_key: str, target_key: str):
"""从DataFrame创建关系"""
try:
# 准备关系数据
rel_df = df.select(
col(source_key).alias("source.id"),
col(target_key).alias("target.id")
).filter(
col(source_key).isNotNull() & col(target_key).isNotNull()
)
rel_df.write.format("org.neo4j.spark.DataSource") \
.option("url", self.neo4j_config['uri']) \
.option("authentication.basic.username", self.neo4j_config['username']) \
.option("authentication.basic.password", self.neo4j_config['password']) \
.option("relationship", relationship_type) \
.option("relationship.source.labels", source_label) \
.option("relationship.source.save.mode", "Match") \
.option("relationship.target.labels", target_label) \
.option("relationship.target.save.mode", "Match") \
.mode("Append") \
.save()
self.logger.info(f"Successfully created {relationship_type} relationships")
except Exception as e:
self.logger.error(f"Failed to create relationships: {e}")
raise
def analyze_graph_with_spark(self, analysis_type: str) -> Dict[str, Any]:
"""使用Spark分析图数据"""
if analysis_type == "node_degree_distribution":
return self._analyze_node_degree_distribution()
elif analysis_type == "community_detection":
return self._analyze_communities()
elif analysis_type == "centrality_metrics":
return self._calculate_centrality_metrics()
else:
raise ValueError(f"Unsupported analysis type: {analysis_type}")
def _analyze_node_degree_distribution(self) -> Dict[str, Any]:
"""分析节点度分布"""
# 查询所有关系
relationships_query = """
MATCH (a)-[r]->(b)
RETURN id(a) as source_id, id(b) as target_id, type(r) as rel_type
"""
rel_df = self.read_from_neo4j(relationships_query)
# 计算出度
out_degree = rel_df.groupBy("source_id").count().withColumnRenamed("count", "out_degree")
# 计算入度
in_degree = rel_df.groupBy("target_id").count().withColumnRenamed("count", "in_degree")
# 合并度数据
degree_df = out_degree.join(
in_degree,
out_degree.source_id == in_degree.target_id,
"full_outer"
).select(
when(col("source_id").isNull(), col("target_id")).otherwise(col("source_id")).alias("node_id"),
when(col("out_degree").isNull(), 0).otherwise(col("out_degree")).alias("out_degree"),
when(col("in_degree").isNull(), 0).otherwise(col("in_degree")).alias("in_degree")
).withColumn("total_degree", col("out_degree") + col("in_degree"))
# 计算统计信息
degree_stats = degree_df.select(
"total_degree"
).describe().collect()
# 度分布
degree_distribution = degree_df.groupBy("total_degree").count().orderBy("total_degree").collect()
return {
'statistics': {row['summary']: float(row['total_degree']) for row in degree_stats},
'distribution': [(row['total_degree'], row['count']) for row in degree_distribution]
}
def _analyze_communities(self) -> Dict[str, Any]:
"""社区检测分析"""
# 使用Louvain算法进行社区检测
community_query = """
CALL gds.louvain.stream('myGraph')
YIELD nodeId, communityId
RETURN gds.util.asNode(nodeId).id as node_id, communityId
"""
try:
community_df = self.read_from_neo4j(community_query)
# 计算社区统计
community_stats = community_df.groupBy("communityId").count().orderBy(col("count").desc())
# 获取最大的几个社区
top_communities = community_stats.limit(10).collect()
return {
'total_communities': community_df.select("communityId").distinct().count(),
'top_communities': [(row['communityId'], row['count']) for row in top_communities],
'modularity': self._calculate_modularity(community_df)
}
except Exception as e:
self.logger.warning(f"Community detection failed, using simple clustering: {e}")
return self._simple_clustering_analysis()
def _simple_clustering_analysis(self) -> Dict[str, Any]:
"""简单聚类分析(当GDS不可用时)"""
# 基于连接模式的简单聚类
clustering_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-(m)
RETURN id(n) as node_id, count(r) as connections, collect(distinct id(m)) as neighbors
"""
cluster_df = self.read_from_neo4j(clustering_query)
# 基于连接数进行简单分组
connection_groups = cluster_df.groupBy(
when(col("connections") == 0, "isolated")
.when(col("connections") <= 2, "low_connectivity")
.when(col("connections") <= 5, "medium_connectivity")
.otherwise("high_connectivity")
.alias("group")
).count().collect()
return {
'connection_groups': [(row['group'], row['count']) for row in connection_groups]
}
def _calculate_centrality_metrics(self) -> Dict[str, Any]:
"""计算中心性指标"""
# 度中心性
degree_centrality_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
RETURN id(n) as node_id, count(r) as degree
ORDER BY degree DESC
LIMIT 20
"""
degree_df = self.read_from_neo4j(degree_centrality_query)
top_degree_nodes = degree_df.collect()
# 尝试计算PageRank(如果GDS可用)
try:
pagerank_query = """
CALL gds.pageRank.stream('myGraph')
YIELD nodeId, score
RETURN gds.util.asNode(nodeId).id as node_id, score
ORDER BY score DESC
LIMIT 20
"""
pagerank_df = self.read_from_neo4j(pagerank_query)
top_pagerank_nodes = pagerank_df.collect()
except Exception:
top_pagerank_nodes = []
return {
'top_degree_centrality': [(row['node_id'], row['degree']) for row in top_degree_nodes],
'top_pagerank': [(row['node_id'], float(row['score'])) for row in top_pagerank_nodes] if top_pagerank_nodes else []
}
def _calculate_modularity(self, community_df: 'DataFrame') -> float:
"""计算模块度"""
# 简化的模块度计算
try:
total_edges = self.read_from_neo4j("MATCH ()-[r]->() RETURN count(r) as edge_count").collect()[0]['edge_count']
# 计算社区内边数
intra_community_edges = 0
communities = community_df.select("communityId").distinct().collect()
for community in communities:
community_id = community['communityId']
community_nodes = community_df.filter(col("communityId") == community_id).select("node_id").collect()
node_ids = [row['node_id'] for row in community_nodes]
if len(node_ids) > 1:
# 查询社区内边数
intra_edges_query = f"""
MATCH (a)-[r]->(b)
WHERE id(a) IN {node_ids} AND id(b) IN {node_ids}
RETURN count(r) as intra_edges
"""
result = self.read_from_neo4j(intra_edges_query).collect()
if result:
intra_community_edges += result[0]['intra_edges']
# 简化的模块度计算
if total_edges > 0:
return (intra_community_edges / total_edges) - 0.5 # 简化公式
else:
return 0.0
except Exception:
return 0.0
def batch_process_large_dataset(self, input_path: str, output_label: str,
batch_size: int = 10000):
"""批量处理大型数据集"""
try:
# 读取大型数据集
large_df = self.spark.read.option("header", "true").csv(input_path)
# 数据清洗和转换
cleaned_df = large_df.filter(
~(col("id").isNull() | isnan(col("id")))
).dropDuplicates(["id"])
# 分批写入Neo4j
total_rows = cleaned_df.count()
num_batches = (total_rows + batch_size - 1) // batch_size
self.logger.info(f"Processing {total_rows} rows in {num_batches} batches")
            for i in range(num_batches):
                start_idx = i * batch_size
                # Slice out the current batch (DataFrame.offset requires Spark 3.4+)
                batch_df = cleaned_df.offset(start_idx).limit(batch_size)
# 写入Neo4j
self.write_to_neo4j(batch_df, output_label, ["id"], "Append")
self.logger.info(f"Processed batch {i + 1}/{num_batches}")
self.logger.info("Batch processing completed")
except Exception as e:
self.logger.error(f"Batch processing failed: {e}")
raise
def export_graph_to_parquet(self, output_path: str):
"""导出图数据到Parquet格式"""
try:
# 导出节点
nodes_query = "MATCH (n) RETURN id(n) as node_id, labels(n) as labels, properties(n) as properties"
nodes_df = self.read_from_neo4j(nodes_query)
nodes_df.write.mode("overwrite").parquet(f"{output_path}/nodes")
# 导出关系
relationships_query = """
MATCH (a)-[r]->(b)
RETURN id(a) as source_id, id(b) as target_id, type(r) as rel_type, properties(r) as properties
"""
rels_df = self.read_from_neo4j(relationships_query)
rels_df.write.mode("overwrite").parquet(f"{output_path}/relationships")
self.logger.info(f"Graph data exported to {output_path}")
except Exception as e:
self.logger.error(f"Export failed: {e}")
raise
def close(self):
"""关闭连接"""
if self.spark:
self.spark.stop()
if self.neo4j_driver:
self.neo4j_driver.close()
self.logger.info("Spark and Neo4j connections closed")
# 使用示例
spark_config = {
'app_name': 'Neo4jSparkAnalysis'
}
neo4j_config = {
'uri': 'bolt://localhost:7687',
'username': 'neo4j',
'password': 'password'
}
# 创建连接器
connector = Neo4jSparkConnector(spark_config, neo4j_config)
# 从Neo4j读取数据进行分析
user_df = connector.read_from_neo4j("MATCH (u:User) RETURN u.id as id, u.name as name, u.age as age")
user_df.show()
# 分析图结构
degree_analysis = connector.analyze_graph_with_spark("node_degree_distribution")
print(f"Degree distribution statistics: {degree_analysis['statistics']}")
# 社区检测
community_analysis = connector.analyze_graph_with_spark("community_detection")
print(f"Total communities found: {community_analysis.get('total_communities', 'N/A')}")
# 中心性分析
centrality_analysis = connector.analyze_graph_with_spark("centrality_metrics")
print(f"Top degree centrality nodes: {centrality_analysis['top_degree_centrality'][:5]}")
# 批量处理大数据集
# connector.batch_process_large_dataset('/path/to/large_dataset.csv', 'LargeDataNode')
# 导出图数据
# connector.export_graph_to_parquet('/path/to/output')
# 关闭连接
connector.close()
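The community-detection and PageRank queries above assume that a GDS in-memory projection named 'myGraph' already exists; nothing in the connector creates it. A minimal sketch of preparing that projection through the Python driver is shown below, assuming the Graph Data Science plugin with its 2.x procedure names:
def ensure_gds_projection(neo4j_driver, graph_name='myGraph'):
    """Create a whole-graph GDS projection if it is not already present."""
    with neo4j_driver.session() as session:
        record = session.run(
            "CALL gds.graph.exists($name) YIELD exists RETURN exists",
            name=graph_name,
        ).single()
        if not record['exists']:
            # Project every node and relationship; restrict labels/types for real workloads
            session.run("CALL gds.graph.project($name, '*', '*')", name=graph_name)

# ensure_gds_projection(connector.neo4j_driver)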
8.2.2 Kafka Stream Processing Integration
The streaming integration below consumes graph events from Kafka topics, applies them to Neo4j, and publishes change and analytics events back to Kafka:
from kafka import KafkaConsumer, KafkaProducer
from neo4j import GraphDatabase
import json
import threading
import time
from typing import Dict, List, Any, Callable, Optional
from datetime import datetime
import logging
from dataclasses import dataclass
from enum import Enum
class EventType(Enum):
"""事件类型"""
NODE_CREATE = "node_create"
NODE_UPDATE = "node_update"
NODE_DELETE = "node_delete"
RELATIONSHIP_CREATE = "relationship_create"
RELATIONSHIP_DELETE = "relationship_delete"
CUSTOM = "custom"
@dataclass
class GraphEvent:
"""图事件"""
event_type: EventType
timestamp: datetime
data: Dict[str, Any]
source: str
event_id: Optional[str] = None
class Neo4jKafkaStreaming:
"""Neo4j与Kafka流处理集成"""
def __init__(self, kafka_config: Dict[str, Any], neo4j_config: Dict[str, Any]):
# Kafka配置
self.kafka_config = kafka_config
self.bootstrap_servers = kafka_config['bootstrap_servers']
# Neo4j连接
self.neo4j_driver = GraphDatabase.driver(
neo4j_config['uri'],
auth=(neo4j_config['username'], neo4j_config['password'])
)
# 消费者和生产者
self.consumers: Dict[str, KafkaConsumer] = {}
self.producer = KafkaProducer(
bootstrap_servers=self.bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode('utf-8'),
key_serializer=lambda k: k.encode('utf-8') if k else None
)
# 事件处理器
self.event_handlers: Dict[EventType, List[Callable]] = {
event_type: [] for event_type in EventType
}
# 流处理状态
self.is_streaming = False
self.processing_threads: List[threading.Thread] = []
# 配置日志
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def add_event_handler(self, event_type: EventType, handler: Callable[[GraphEvent], None]):
"""添加事件处理器"""
self.event_handlers[event_type].append(handler)
self.logger.info(f"Added handler for {event_type.value} events")
def start_streaming(self, topics: List[str]):
"""启动流处理"""
if self.is_streaming:
self.logger.warning("Streaming is already running")
return
self.is_streaming = True
for topic in topics:
# 为每个主题创建消费者
consumer = KafkaConsumer(
topic,
bootstrap_servers=self.bootstrap_servers,
value_deserializer=lambda m: json.loads(m.decode('utf-8')),
key_deserializer=lambda k: k.decode('utf-8') if k else None,
group_id=f'neo4j_consumer_{topic}',
auto_offset_reset='latest'
)
self.consumers[topic] = consumer
# 启动消费线程
thread = threading.Thread(
target=self._consume_messages,
args=(topic, consumer),
daemon=True
)
thread.start()
self.processing_threads.append(thread)
self.logger.info(f"Started streaming for topics: {topics}")
def stop_streaming(self):
"""停止流处理"""
self.is_streaming = False
# 关闭所有消费者
for consumer in self.consumers.values():
consumer.close()
self.consumers.clear()
self.processing_threads.clear()
self.logger.info("Stopped streaming")
def _consume_messages(self, topic: str, consumer: KafkaConsumer):
"""消费消息"""
self.logger.info(f"Started consuming messages from topic: {topic}")
try:
for message in consumer:
if not self.is_streaming:
break
try:
# 解析事件
event_data = message.value
event = self._parse_event(event_data, topic)
if event:
# 处理事件
self._process_event(event)
except Exception as e:
self.logger.error(f"Error processing message from {topic}: {e}")
except Exception as e:
self.logger.error(f"Consumer error for topic {topic}: {e}")
finally:
self.logger.info(f"Stopped consuming from topic: {topic}")
def _parse_event(self, event_data: Dict[str, Any], topic: str) -> Optional[GraphEvent]:
"""解析事件数据"""
try:
event_type_str = event_data.get('event_type', 'custom')
event_type = EventType(event_type_str)
return GraphEvent(
event_type=event_type,
timestamp=datetime.fromisoformat(event_data.get('timestamp', datetime.now().isoformat())),
data=event_data.get('data', {}),
source=event_data.get('source', topic),
event_id=event_data.get('event_id')
)
except Exception as e:
self.logger.error(f"Failed to parse event: {e}")
return None
def _process_event(self, event: GraphEvent):
"""处理图事件"""
try:
# 执行注册的处理器
for handler in self.event_handlers[event.event_type]:
handler(event)
# 默认处理逻辑
if event.event_type == EventType.NODE_CREATE:
self._handle_node_create(event)
elif event.event_type == EventType.NODE_UPDATE:
self._handle_node_update(event)
elif event.event_type == EventType.NODE_DELETE:
self._handle_node_delete(event)
elif event.event_type == EventType.RELATIONSHIP_CREATE:
self._handle_relationship_create(event)
elif event.event_type == EventType.RELATIONSHIP_DELETE:
self._handle_relationship_delete(event)
except Exception as e:
self.logger.error(f"Error processing event {event.event_id}: {e}")
def _handle_node_create(self, event: GraphEvent):
"""处理节点创建事件"""
data = event.data
label = data.get('label', 'Node')
properties = data.get('properties', {})
with self.neo4j_driver.session() as session:
# 构建属性字符串
prop_items = [f"{k}: ${k}" for k in properties.keys()]
prop_str = "{" + ", ".join(prop_items) + "}"
cypher = f"CREATE (n:{label} {prop_str}) RETURN id(n) as node_id"
result = session.run(cypher, properties)
node_id = result.single()['node_id']
self.logger.info(f"Created node {node_id} with label {label}")
def _handle_node_update(self, event: GraphEvent):
"""处理节点更新事件"""
data = event.data
node_id = data.get('node_id')
properties = data.get('properties', {})
if not node_id:
self.logger.error("Node update event missing node_id")
return
with self.neo4j_driver.session() as session:
# 构建SET子句
set_items = [f"n.{k} = ${k}" for k in properties.keys()]
set_str = ", ".join(set_items)
cypher = f"MATCH (n) WHERE id(n) = $node_id SET {set_str}"
session.run(cypher, {**properties, 'node_id': node_id})
self.logger.info(f"Updated node {node_id}")
def _handle_node_delete(self, event: GraphEvent):
"""处理节点删除事件"""
data = event.data
node_id = data.get('node_id')
if not node_id:
self.logger.error("Node delete event missing node_id")
return
with self.neo4j_driver.session() as session:
cypher = "MATCH (n) WHERE id(n) = $node_id DETACH DELETE n"
session.run(cypher, {'node_id': node_id})
self.logger.info(f"Deleted node {node_id}")
def _handle_relationship_create(self, event: GraphEvent):
"""处理关系创建事件"""
data = event.data
source_id = data.get('source_id')
target_id = data.get('target_id')
rel_type = data.get('relationship_type', 'RELATED_TO')
properties = data.get('properties', {})
if not source_id or not target_id:
self.logger.error("Relationship create event missing source_id or target_id")
return
with self.neo4j_driver.session() as session:
# 构建属性字符串
if properties:
prop_items = [f"{k}: ${k}" for k in properties.keys()]
prop_str = "{" + ", ".join(prop_items) + "}"
else:
prop_str = ""
cypher = f"""
MATCH (a), (b)
WHERE id(a) = $source_id AND id(b) = $target_id
CREATE (a)-[r:{rel_type} {prop_str}]->(b)
RETURN id(r) as rel_id
"""
result = session.run(cypher, {
**properties,
'source_id': source_id,
'target_id': target_id
})
rel_id = result.single()['rel_id']
self.logger.info(f"Created relationship {rel_id} of type {rel_type}")
def _handle_relationship_delete(self, event: GraphEvent):
"""处理关系删除事件"""
data = event.data
rel_id = data.get('relationship_id')
if not rel_id:
self.logger.error("Relationship delete event missing relationship_id")
return
with self.neo4j_driver.session() as session:
cypher = "MATCH ()-[r]-() WHERE id(r) = $rel_id DELETE r"
session.run(cypher, {'rel_id': rel_id})
self.logger.info(f"Deleted relationship {rel_id}")
def publish_event(self, topic: str, event: GraphEvent, key: Optional[str] = None):
"""发布事件到Kafka"""
try:
event_data = {
'event_type': event.event_type.value,
'timestamp': event.timestamp.isoformat(),
'data': event.data,
'source': event.source,
'event_id': event.event_id
}
future = self.producer.send(topic, value=event_data, key=key)
future.get(timeout=10) # 等待发送完成
self.logger.info(f"Published event {event.event_id} to topic {topic}")
except Exception as e:
self.logger.error(f"Failed to publish event: {e}")
    def setup_change_data_capture(self):
        """Set up change data capture (CDC) triggers"""
        # Install APOC triggers that forward Neo4j changes to Kafka. This requires the
        # APOC trigger procedures, and the `apoc.util.kafka.send` calls below assume a
        # Kafka publishing procedure is available in the deployment (adjust to whatever
        # streams integration your Neo4j installation provides).
        with self.neo4j_driver.session() as session:
            try:
session.run("""
CALL apoc.trigger.add('node-created',
'UNWIND $createdNodes AS n
CALL apoc.util.kafka.send("neo4j-changes", null, {
event_type: "node_create",
timestamp: datetime().epochMillis,
data: {
node_id: id(n),
labels: labels(n),
properties: properties(n)
},
source: "neo4j-cdc"
})
RETURN count(*)',
{phase: 'after'})
""")
session.run("""
CALL apoc.trigger.add('relationship-created',
'UNWIND $createdRelationships AS r
CALL apoc.util.kafka.send("neo4j-changes", null, {
event_type: "relationship_create",
timestamp: datetime().epochMillis,
data: {
relationship_id: id(r),
source_id: id(startNode(r)),
target_id: id(endNode(r)),
relationship_type: type(r),
properties: properties(r)
},
source: "neo4j-cdc"
})
RETURN count(*)',
{phase: 'after'})
""")
self.logger.info("Change data capture triggers installed")
except Exception as e:
self.logger.warning(f"Failed to install CDC triggers (APOC may not be available): {e}")
def create_real_time_analytics_pipeline(self, analytics_topic: str):
"""创建实时分析管道"""
def analytics_handler(event: GraphEvent):
"""分析事件处理器"""
try:
# 实时计算图指标
if event.event_type in [EventType.NODE_CREATE, EventType.RELATIONSHIP_CREATE]:
metrics = self._calculate_real_time_metrics()
# 发布分析结果
analytics_event = GraphEvent(
event_type=EventType.CUSTOM,
timestamp=datetime.now(),
data={
'metrics': metrics,
'trigger_event': event.event_id
},
source='real-time-analytics'
)
self.publish_event(analytics_topic, analytics_event)
except Exception as e:
self.logger.error(f"Analytics handler error: {e}")
# 注册分析处理器
self.add_event_handler(EventType.NODE_CREATE, analytics_handler)
self.add_event_handler(EventType.RELATIONSHIP_CREATE, analytics_handler)
self.logger.info(f"Real-time analytics pipeline created for topic {analytics_topic}")
def _calculate_real_time_metrics(self) -> Dict[str, Any]:
"""计算实时图指标"""
with self.neo4j_driver.session() as session:
# 节点数量
node_count = session.run("MATCH (n) RETURN count(n) as count").single()['count']
# 关系数量
rel_count = session.run("MATCH ()-[r]->() RETURN count(r) as count").single()['count']
# 平均度数
avg_degree_result = session.run("""
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, count(r) as degree
RETURN avg(degree) as avg_degree
""").single()
avg_degree = avg_degree_result['avg_degree'] if avg_degree_result['avg_degree'] else 0
return {
'timestamp': datetime.now().isoformat(),
'node_count': node_count,
'relationship_count': rel_count,
'average_degree': float(avg_degree),
'density': (2.0 * rel_count) / (node_count * (node_count - 1)) if node_count > 1 else 0
}
def close(self):
"""关闭连接"""
self.stop_streaming()
if self.producer:
self.producer.close()
if self.neo4j_driver:
self.neo4j_driver.close()
self.logger.info("Kafka and Neo4j connections closed")
# 使用示例
kafka_config = {
'bootstrap_servers': ['localhost:9092']
}
neo4j_config = {
'uri': 'bolt://localhost:7687',
'username': 'neo4j',
'password': 'password'
}
# 创建流处理器
streaming = Neo4jKafkaStreaming(kafka_config, neo4j_config)
# 添加自定义事件处理器
def custom_node_handler(event: GraphEvent):
print(f"Custom handler: Node created with data {event.data}")
streaming.add_event_handler(EventType.NODE_CREATE, custom_node_handler)
# 设置变更数据捕获
streaming.setup_change_data_capture()
# 创建实时分析管道
streaming.create_real_time_analytics_pipeline('graph-analytics')
# 启动流处理
streaming.start_streaming(['graph-events', 'neo4j-changes'])
# 发布测试事件
test_event = GraphEvent(
event_type=EventType.NODE_CREATE,
timestamp=datetime.now(),
data={
'label': 'TestNode',
'properties': {
'name': 'Test Node',
'created_at': datetime.now().isoformat()
}
},
source='test',
event_id='test-001'
)
streaming.publish_event('graph-events', test_event)
# 运行一段时间后关闭
time.sleep(30)
streaming.close()
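The default _handle_node_create above uses CREATE, so a replayed or duplicated Kafka message produces a second node. One way to make consumption idempotent, hinted at in the exercises, is to MERGE on a business key carried in the event payload. The sketch below assumes the payload's properties include such a key (the 'id' default is an assumption):
def idempotent_node_handler_factory(driver, key_property: str = 'id'):
    """Build an event handler that upserts nodes instead of blindly creating them."""
    def handler(event: GraphEvent):
        data = event.data
        label = data.get('label', 'Node')
        properties = data.get('properties', {})
        key_value = properties.get(key_property)
        if key_value is None:
            return  # cannot deduplicate without a business key
        with driver.session() as session:
            session.run(
                f"MERGE (n:{label} {{{key_property}: $key_value}}) SET n += $props",
                key_value=key_value, props=properties,
            )
    return handler

# streaming.add_event_handler(EventType.NODE_CREATE,
#                             idempotent_node_handler_factory(streaming.neo4j_driver))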
8.3 Integrating with Machine Learning Platforms
8.3.1 Graph Neural Network Integration
The integration below extracts a graph from Neo4j into PyTorch Geometric, trains a GNN for node classification, and writes the learned embeddings back to Neo4j:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.data import Data
from neo4j import GraphDatabase
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from typing import Dict, List, Any, Tuple, Optional
import logging
from datetime import datetime
class Neo4jGNNIntegration:
"""Neo4j与图神经网络集成"""
def __init__(self, neo4j_config: Dict[str, Any]):
# Neo4j连接
self.driver = GraphDatabase.driver(
neo4j_config['uri'],
auth=(neo4j_config['username'], neo4j_config['password'])
)
# 数据预处理器
self.label_encoders = {}
self.feature_scaler = StandardScaler()
# 模型和数据
self.model = None
self.graph_data = None
self.node_mapping = {}
# 配置日志
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def extract_graph_data(self, node_query: str, edge_query: str,
feature_columns: List[str], target_column: Optional[str] = None) -> Data:
"""从Neo4j提取图数据"""
with self.driver.session() as session:
# 提取节点数据
self.logger.info("Extracting node data...")
node_result = session.run(node_query)
nodes_df = pd.DataFrame([record.data() for record in node_result])
# 提取边数据
self.logger.info("Extracting edge data...")
edge_result = session.run(edge_query)
edges_df = pd.DataFrame([record.data() for record in edge_result])
# 创建节点映射
unique_nodes = set(nodes_df['node_id'].unique()) | set(edges_df['source_id'].unique()) | set(edges_df['target_id'].unique())
self.node_mapping = {node_id: idx for idx, node_id in enumerate(sorted(unique_nodes))}
# 准备节点特征
num_nodes = len(self.node_mapping)
# 处理特征数据
if feature_columns:
# 确保所有节点都有特征数据
feature_data = []
for node_id in sorted(self.node_mapping.keys()):
node_row = nodes_df[nodes_df['node_id'] == node_id]
if not node_row.empty:
features = [node_row.iloc[0][col] for col in feature_columns]
else:
features = [0] * len(feature_columns) # 默认特征
feature_data.append(features)
# 标准化特征
feature_matrix = np.array(feature_data, dtype=np.float32)
feature_matrix = self.feature_scaler.fit_transform(feature_matrix)
x = torch.tensor(feature_matrix, dtype=torch.float)
else:
# 使用单位矩阵作为特征
x = torch.eye(num_nodes, dtype=torch.float)
# 处理边数据
edge_index = []
for _, row in edges_df.iterrows():
source_idx = self.node_mapping[row['source_id']]
target_idx = self.node_mapping[row['target_id']]
edge_index.append([source_idx, target_idx])
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
# 处理标签数据
y = None
if target_column and target_column in nodes_df.columns:
labels = []
for node_id in sorted(self.node_mapping.keys()):
node_row = nodes_df[nodes_df['node_id'] == node_id]
if not node_row.empty and pd.notna(node_row.iloc[0][target_column]):
labels.append(node_row.iloc[0][target_column])
else:
labels.append(-1) # 未标记节点
# 编码标签
valid_labels = [label for label in labels if label != -1]
if valid_labels:
if target_column not in self.label_encoders:
self.label_encoders[target_column] = LabelEncoder()
self.label_encoders[target_column].fit(valid_labels)
encoded_labels = []
for label in labels:
if label != -1:
encoded_labels.append(self.label_encoders[target_column].transform([label])[0])
else:
encoded_labels.append(-1)
y = torch.tensor(encoded_labels, dtype=torch.long)
# 创建PyTorch Geometric数据对象
self.graph_data = Data(x=x, edge_index=edge_index, y=y)
self.logger.info(f"Extracted graph with {num_nodes} nodes and {edge_index.size(1)} edges")
return self.graph_data
def create_gnn_model(self, model_type: str, input_dim: int, hidden_dim: int,
output_dim: int, num_layers: int = 2, dropout: float = 0.5) -> nn.Module:
"""创建图神经网络模型"""
if model_type.lower() == 'gcn':
model = GCNModel(input_dim, hidden_dim, output_dim, num_layers, dropout)
elif model_type.lower() == 'sage':
model = GraphSAGEModel(input_dim, hidden_dim, output_dim, num_layers, dropout)
elif model_type.lower() == 'gat':
model = GATModel(input_dim, hidden_dim, output_dim, num_layers, dropout)
else:
raise ValueError(f"Unsupported model type: {model_type}")
self.model = model
self.logger.info(f"Created {model_type.upper()} model with {sum(p.numel() for p in model.parameters())} parameters")
return model
def train_model(self, train_mask: torch.Tensor, val_mask: torch.Tensor,
epochs: int = 200, lr: float = 0.01, weight_decay: float = 5e-4) -> Dict[str, List[float]]:
"""训练模型"""
if self.model is None or self.graph_data is None:
raise ValueError("Model and graph data must be initialized first")
optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
train_losses = []
val_accuracies = []
self.model.train()
for epoch in range(epochs):
optimizer.zero_grad()
# 前向传播
out = self.model(self.graph_data.x, self.graph_data.edge_index)
# 计算训练损失
train_loss = criterion(out[train_mask], self.graph_data.y[train_mask])
# 反向传播
train_loss.backward()
optimizer.step()
# 验证
if epoch % 10 == 0:
self.model.eval()
with torch.no_grad():
val_out = self.model(self.graph_data.x, self.graph_data.edge_index)
val_pred = val_out[val_mask].argmax(dim=1)
val_acc = accuracy_score(self.graph_data.y[val_mask].cpu(), val_pred.cpu())
val_accuracies.append(val_acc)
self.model.train()
self.logger.info(f"Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f}")
train_losses.append(train_loss.item())
return {'train_losses': train_losses, 'val_accuracies': val_accuracies}
def predict_nodes(self, node_ids: List[Any]) -> Dict[Any, Any]:
"""预测节点标签"""
if self.model is None or self.graph_data is None:
raise ValueError("Model and graph data must be initialized first")
self.model.eval()
predictions = {}
with torch.no_grad():
out = self.model(self.graph_data.x, self.graph_data.edge_index)
probabilities = F.softmax(out, dim=1)
for node_id in node_ids:
if node_id in self.node_mapping:
node_idx = self.node_mapping[node_id]
pred_class = out[node_idx].argmax().item()
confidence = probabilities[node_idx].max().item()
# 解码预测结果
if hasattr(self, 'label_encoders') and self.label_encoders:
encoder = list(self.label_encoders.values())[0]
pred_label = encoder.inverse_transform([pred_class])[0]
else:
pred_label = pred_class
predictions[node_id] = {
'predicted_label': pred_label,
'confidence': confidence,
'probabilities': probabilities[node_idx].cpu().numpy().tolist()
}
else:
predictions[node_id] = {'error': 'Node not found in graph'}
return predictions
def generate_embeddings(self, layer_idx: int = -1) -> Dict[Any, np.ndarray]:
"""生成节点嵌入"""
if self.model is None or self.graph_data is None:
raise ValueError("Model and graph data must be initialized first")
self.model.eval()
embeddings = {}
with torch.no_grad():
# 获取指定层的嵌入
if hasattr(self.model, 'get_embeddings'):
node_embeddings = self.model.get_embeddings(self.graph_data.x, self.graph_data.edge_index, layer_idx)
else:
# 如果模型没有get_embeddings方法,使用最后一层前的输出
node_embeddings = self.model(self.graph_data.x, self.graph_data.edge_index)
# 映射回原始节点ID
for original_id, node_idx in self.node_mapping.items():
embeddings[original_id] = node_embeddings[node_idx].cpu().numpy()
return embeddings
def save_embeddings_to_neo4j(self, embeddings: Dict[Any, np.ndarray], property_prefix: str = 'embedding'):
"""将嵌入保存到Neo4j"""
with self.driver.session() as session:
for node_id, embedding in embeddings.items():
# 将嵌入转换为列表
embedding_list = embedding.tolist()
cypher = f"""
MATCH (n) WHERE id(n) = $node_id
SET n.{property_prefix} = $embedding
"""
session.run(cypher, {'node_id': node_id, 'embedding': embedding_list})
self.logger.info(f"Saved embeddings for {len(embeddings)} nodes to Neo4j")
def close(self):
"""关闭连接"""
if self.driver:
self.driver.close()
self.logger.info("Neo4j connection closed")
class GCNModel(nn.Module):
"""图卷积网络模型"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
num_layers: int = 2, dropout: float = 0.5):
super(GCNModel, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
# 构建层
self.convs = nn.ModuleList()
self.convs.append(GCNConv(input_dim, hidden_dim))
for _ in range(num_layers - 2):
self.convs.append(GCNConv(hidden_dim, hidden_dim))
self.convs.append(GCNConv(hidden_dim, output_dim))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return x
def get_embeddings(self, x, edge_index, layer_idx=-1):
"""获取指定层的嵌入"""
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i == len(self.convs) + layer_idx: # 支持负索引
return x
if i < len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
return x
class GraphSAGEModel(nn.Module):
"""GraphSAGE模型"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
num_layers: int = 2, dropout: float = 0.5):
super(GraphSAGEModel, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.convs = nn.ModuleList()
self.convs.append(SAGEConv(input_dim, hidden_dim))
for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_dim, hidden_dim))
self.convs.append(SAGEConv(hidden_dim, output_dim))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return x
def get_embeddings(self, x, edge_index, layer_idx=-1):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i == len(self.convs) + layer_idx:
return x
if i < len(self.convs) - 1:
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
return x
class GATModel(nn.Module):
"""图注意力网络模型"""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
num_layers: int = 2, dropout: float = 0.5, heads: int = 8):
super(GATModel, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.convs = nn.ModuleList()
self.convs.append(GATConv(input_dim, hidden_dim, heads=heads, dropout=dropout))
for _ in range(num_layers - 2):
self.convs.append(GATConv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout))
self.convs.append(GATConv(hidden_dim * heads, output_dim, heads=1, dropout=dropout))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return x
def get_embeddings(self, x, edge_index, layer_idx=-1):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i == len(self.convs) + layer_idx:
return x
if i < len(self.convs) - 1:
x = F.elu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
return x
# 使用示例
neo4j_config = {
'uri': 'bolt://localhost:7687',
'username': 'neo4j',
'password': 'password'
}
# 创建集成器
gnn_integration = Neo4jGNNIntegration(neo4j_config)
# 定义查询
node_query = """
MATCH (u:User)
RETURN id(u) as node_id, u.age as age, u.income as income, u.category as category
"""
edge_query = """
MATCH (u1:User)-[:FRIEND]->(u2:User)
RETURN id(u1) as source_id, id(u2) as target_id
"""
# 提取图数据
feature_columns = ['age', 'income']
target_column = 'category'
graph_data = gnn_integration.extract_graph_data(node_query, edge_query, feature_columns, target_column)
# 创建训练/验证掩码
num_nodes = graph_data.x.size(0)
labeled_nodes = (graph_data.y != -1).nonzero(as_tuple=True)[0]
train_nodes, val_nodes = train_test_split(labeled_nodes.numpy(), test_size=0.3, random_state=42)
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[train_nodes] = True
val_mask[val_nodes] = True
# 创建和训练模型
input_dim = graph_data.x.size(1)
hidden_dim = 64
output_dim = len(torch.unique(graph_data.y[graph_data.y != -1]))
model = gnn_integration.create_gnn_model('gcn', input_dim, hidden_dim, output_dim)
training_history = gnn_integration.train_model(train_mask, val_mask, epochs=100)
# 预测新节点
test_node_ids = [1, 2, 3] # 示例节点ID
predictions = gnn_integration.predict_nodes(test_node_ids)
print(f"Predictions: {predictions}")
# 生成和保存嵌入
embeddings = gnn_integration.generate_embeddings()
gnn_integration.save_embeddings_to_neo4j(embeddings)
# 关闭连接
gnn_integration.close()
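The chapter summary lists link prediction alongside node classification, but the example above only trains a classifier. A lightweight way to reuse its output for link prediction is to score candidate node pairs with the generated embeddings; the sketch below uses cosine similarity as a stand-in for a trained link-prediction decoder, and the candidate pairs are hypothetical:
import numpy as np

def score_candidate_links(embeddings, candidate_pairs):
    """Return {(source, target): cosine similarity} for candidate node pairs."""
    scores = {}
    for source_id, target_id in candidate_pairs:
        if source_id not in embeddings or target_id not in embeddings:
            continue
        a, b = embeddings[source_id], embeddings[target_id]
        denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0  # guard zero vectors
        scores[(source_id, target_id)] = float(np.dot(a, b) / denom)
    return scores

# link_scores = score_candidate_links(embeddings, [(1, 2), (1, 3)])
# print(sorted(link_scores.items(), key=lambda kv: kv[1], reverse=True)[:10])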
8.4 Chapter Summary
8.4.1 Core Knowledge Points
Relational database integration
- Data synchronization strategies: full, incremental, and real-time sync
- Bidirectional synchronization mechanisms
- Change data capture (CDC)
- Batch processing optimization
Big data platform integration
- Apache Spark integration for large-scale graph analysis
- Kafka stream processing for real-time graph events
- Distributed graph computation and analytics
- Real-time metric calculation and monitoring
Machine learning platform integration
- Graph neural network (GNN) integration
- Node classification and link prediction
- Graph embedding generation and application
- Model training and inference pipelines
8.4.2 Best Practices
Data synchronization
- Choose an appropriate sync mode and frequency
- Make write operations idempotent (see the sketch below)
- Monitor sync status and handle errors
- Optimize batch operation performance
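As a concrete illustration of idempotent writes, the sketch below upserts a batch of source rows with MERGE keyed on a business identifier, so re-running a failed or repeated sync cannot create duplicate nodes (the label and key names are illustrative):
from neo4j import GraphDatabase

def idempotent_upsert(driver, rows, label='User', key='id'):
    """Re-runnable batch upsert: MERGE on the business key, then overwrite properties."""
    cypher = (
        f"UNWIND $rows AS row "
        f"MERGE (n:{label} {{{key}: row.{key}}}) "
        f"ON CREATE SET n = row "
        f"ON MATCH SET n += row"
    )
    with driver.session() as session:
        session.run(cypher, rows=rows)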
Stream processing
- Design a clear, consistent event schema
- Implement fault tolerance and retry mechanisms (see the sketch below)
- Monitor stream-processing performance
- Handle backpressure and flow control
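A simple way to add fault tolerance to the event handlers above is a retry decorator with exponential backoff; the sketch below retries transient failures and re-raises after a configurable number of attempts so the caller can log or dead-letter the event:
import time
from functools import wraps

def with_retries(max_attempts: int = 3, base_delay: float = 0.5):
    """Wrap an event handler so transient Neo4j/Kafka errors are retried with backoff."""
    def decorator(handler):
        @wraps(handler)
        def wrapped(event):
            for attempt in range(1, max_attempts + 1):
                try:
                    return handler(event)
                except Exception:
                    if attempt == max_attempts:
                        raise
                    time.sleep(base_delay * (2 ** (attempt - 1)))
        return wrapped
    return decorator

# @with_retries(max_attempts=5)
# def resilient_node_handler(event): ...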
Machine learning
- Design graph feature engineering carefully
- Choose a GNN architecture suited to the task
- Manage model versions (see the sketch below)
- Monitor model performance and drift
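For model version management, a minimal approach is to checkpoint the weights together with a small metadata file; the sketch below (file naming is an assumption) pairs torch.save with a JSON sidecar recording the validation metrics:
import json
from datetime import datetime
import torch

def save_model_version(model, metrics, path_prefix='gnn_model'):
    """Persist model weights plus metadata so versions can be compared and rolled back."""
    version = datetime.now().strftime('%Y%m%d%H%M%S')
    torch.save(model.state_dict(), f"{path_prefix}_{version}.pt")
    with open(f"{path_prefix}_{version}.json", 'w') as f:
        json.dump({'version': version, 'metrics': metrics}, f)
    return version

# save_model_version(model, {'val_accuracy': training_history['val_accuracies'][-1]})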
8.4.3 Architecture Design Principles
Loose coupling
- Decouple systems with message queues
- Expose standardized interfaces
- Support asynchronous processing
Scalability
- Support horizontal scaling
- Implement load balancing
- Optimize resource utilization
Reliability
- Implement failover
- Guarantee data consistency
- Provide monitoring and alerting
8.4.4 Performance Optimization Strategies
Data transfer
- Batch operations to reduce network overhead
- Optimize compression and serialization
- Manage connection pools (see the sketch below)
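Connection pooling is configured on the Neo4j Python driver itself; the sketch below shows the relevant constructor options (the specific values are illustrative and should be tuned per workload):
from neo4j import GraphDatabase

pooled_driver = GraphDatabase.driver(
    'bolt://localhost:7687',
    auth=('neo4j', 'password'),
    max_connection_pool_size=50,        # cap concurrent Bolt connections
    connection_acquisition_timeout=30,  # seconds to wait for a free connection
    max_connection_lifetime=3600,       # recycle connections after an hour
)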
Computation
- Parallel and distributed processing
- Memory management and caching strategies
- GPU acceleration for ML workloads (see the sketch below)
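For GPU acceleration of the GNN workload in 8.3.1, it is usually enough to move the model and graph tensors to the GPU once; the rest of the training code is device-agnostic. The sketch below reuses the model and graph_data names from that example:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)            # GNN model from section 8.3.1
graph_data = graph_data.to(device)  # torch_geometric Data objects support .to(device)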
Storage
- Index strategy optimization (see the sketch below)
- Data partitioning and sharding
- Separation of hot and cold data
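Index strategy in Neo4j mostly means creating uniqueness constraints and indexes on the properties that MERGE and lookup queries filter on; the sketch below (Neo4j 4.4+/5 syntax) covers the keys used throughout this chapter:
def ensure_indexes(driver):
    """Create the constraints and indexes that the sync and lookup queries rely on."""
    statements = [
        # A uniqueness constraint also backs MERGE-on-id upserts with an index
        "CREATE CONSTRAINT user_id_unique IF NOT EXISTS FOR (u:User) REQUIRE u.id IS UNIQUE",
        "CREATE CONSTRAINT order_id_unique IF NOT EXISTS FOR (o:Order) REQUIRE o.id IS UNIQUE",
        # Range index for incremental-sync style lookups on timestamps
        "CREATE INDEX order_created_at IF NOT EXISTS FOR (o:Order) ON (o.created_at)",
    ]
    with driver.session() as session:
        for stmt in statements:
            session.run(stmt)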
8.5 Exercises
8.5.1 Basic Exercises
Data synchronization
- Build a simple PostgreSQL-to-Neo4j synchronization tool
- Support incremental sync and retry on error
- Add sync status monitoring
Stream processing
- Process graph events in real time with Kafka
- Handle node create, update, and delete events
- Add event deduplication and ordering guarantees
GNN
- Train a node classification model on social network data
- Predict user interests
- Generate user embeddings and store them in Neo4j
8.5.2 Advanced Exercises
Multi-source data integration
- Integrate several relational databases into Neo4j
- Implement data lineage tracking
- Handle data conflicts and consistency
Real-time recommendation system
- Combine Kafka and Neo4j to build a real-time recommender
- Update user behavior in real time
- Generate recommendations from the graph structure
Graph anomaly detection
- Detect anomalies in graphs with a GNN
- Identify anomalous users in a social network
- Implement online learning and model updates
8.5.3 Project Exercises
Enterprise data integration platform
- Design an integration platform that supports multiple data sources
- Provide a visual configuration interface
- Support data quality monitoring and governance
Intelligent graph analytics platform
- Build a graph analytics platform on Spark and Neo4j
- Support large-scale graph algorithm computation
- Provide interactive graph visualization
Graph machine learning platform
- Build an end-to-end graph machine learning platform
- Support multiple GNN models
- Cover model training, deployment, and monitoring
8.5.4 Discussion Questions
When integrating Neo4j with a relational database, how would you handle data consistency?
How do you choose a graph neural network architecture for a specific business problem?
In large-scale graph processing, how do you balance compute performance against resource consumption?
How would you design a scalable architecture for real-time graph event processing?
In a graph machine learning project, how do you evaluate and optimize model performance?
By the end of this chapter you should be able to integrate Neo4j with the rest of a technology stack and build sophisticated graph data processing and analytics systems; these integration techniques are an important foundation for modern data-driven applications.