3.1 Spark SQL Overview

What Is Spark SQL

Spark SQL is the Apache Spark module for working with structured data. It provides a programming abstraction called the DataFrame and can also serve as a distributed SQL query engine. Spark SQL blends relational processing with Spark's functional programming API, so data can be processed through SQL queries, the DataFrame API, or the Dataset API (the typed Dataset API is available only in Scala and Java; PySpark exposes the untyped DataFrame API).
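
To make the choice between SQL and the DataFrame API concrete before the larger demo class below, here is a minimal, self-contained sketch; the application name, view name, and sample rows are invented for illustration and are not part of this chapter's dataset.

from pyspark.sql import SparkSession

# A local session and a tiny in-memory dataset, purely for illustration
spark = SparkSession.builder.appName("sql_vs_dataframe_sketch").getOrCreate()

df = spark.createDataFrame(
    [("Alice", 75000), ("Bob", 85000), ("Eve", 88000)],
    ["name", "salary"]
)

# Option 1: register a temporary view and use plain SQL
df.createOrReplaceTempView("people")
spark.sql("SELECT name, salary FROM people WHERE salary > 80000").show()

# Option 2: express the same query with the DataFrame API
df.filter(df.salary > 80000).select("name", "salary").show()

spark.stop()

Both forms are turned into the same logical plan and pass through the Catalyst optimizer, so the choice is largely a matter of style and context.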

Core Features of Spark SQL

class SparkSQLOverview:
    """
    Spark SQL概述和核心特性演示
    """
    
    def __init__(self):
        self.setup_spark_session()
    
    def setup_spark_session(self):
        """
        创建SparkSession
        """
        from pyspark.sql import SparkSession
        
        # SparkSession是Spark SQL的入口点
        self.spark = SparkSession.builder \
            .appName("SparkSQLOverview") \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
            .getOrCreate()
        
        print("SparkSession创建成功")
        print(f"Spark版本: {self.spark.version}")
        print(f"应用名称: {self.spark.sparkContext.appName}")
    
    def demonstrate_core_features(self):
        """
        演示Spark SQL的核心特性
        """
        print("\nSpark SQL核心特性演示:")
        print("=" * 25)
        
        # 1. 统一的数据访问
        print("\n1. 统一的数据访问:")
        
        # 创建示例数据
        data = [
            (1, "Alice", 25, "Engineer", 75000),
            (2, "Bob", 30, "Manager", 85000),
            (3, "Charlie", 35, "Director", 95000),
            (4, "Diana", 28, "Engineer", 78000),
            (5, "Eve", 32, "Manager", 88000)
        ]
        
        columns = ["id", "name", "age", "position", "salary"]
        
        # 创建DataFrame
        df = self.spark.createDataFrame(data, columns)
        
        print("DataFrame创建成功:")
        df.show()
        
        # 2. SQL查询支持
        print("\n2. SQL查询支持:")
        
        # 注册临时视图
        df.createOrReplaceTempView("employees")
        
        # 使用SQL查询
        sql_result = self.spark.sql("""
            SELECT position, 
                   COUNT(*) as count,
                   AVG(salary) as avg_salary,
                   MAX(age) as max_age
            FROM employees 
            GROUP BY position
            ORDER BY avg_salary DESC
        """)
        
        print("SQL查询结果:")
        sql_result.show()
        
        # 3. DataFrame API
        print("\n3. DataFrame API:")
        
        # 使用DataFrame API实现相同的查询
        from pyspark.sql.functions import count, avg, max as spark_max
        
        api_result = df.groupBy("position") \
                      .agg(count("*").alias("count"),
                           avg("salary").alias("avg_salary"),
                           spark_max("age").alias("max_age")) \
                      .orderBy("avg_salary", ascending=False)
        
        print("DataFrame API查询结果:")
        api_result.show()
        
        # 4. 多种数据源支持
        print("\n4. 多种数据源支持:")
        
        # 保存为不同格式
        temp_path = "/tmp/spark_demo"
        
        # JSON格式
        df.write.mode("overwrite").json(f"{temp_path}/employees.json")
        print("✓ 保存为JSON格式")
        
        # Parquet格式
        df.write.mode("overwrite").parquet(f"{temp_path}/employees.parquet")
        print("✓ 保存为Parquet格式")
        
        # CSV格式
        df.write.mode("overwrite").option("header", "true").csv(f"{temp_path}/employees.csv")
        print("✓ 保存为CSV格式")
        
        # 读取不同格式
        json_df = self.spark.read.json(f"{temp_path}/employees.json")
        parquet_df = self.spark.read.parquet(f"{temp_path}/employees.parquet")
        csv_df = self.spark.read.option("header", "true").option("inferSchema", "true").csv(f"{temp_path}/employees.csv")
        
        print("\n从不同格式读取数据:")
        print(f"JSON记录数: {json_df.count()}")
        print(f"Parquet记录数: {parquet_df.count()}")
        print(f"CSV记录数: {csv_df.count()}")
        
        # 5. 性能优化
        print("\n5. 性能优化特性:")
        
        # 查看执行计划
        print("\n查询执行计划:")
        optimized_query = df.filter(df.salary > 80000).select("name", "position", "salary")
        optimized_query.explain(True)
        
        return {
            'dataframe_count': df.count(),
            'sql_result_count': sql_result.count(),
            'api_result_count': api_result.count()
        }
    
    def compare_with_rdd(self):
        """
        比较DataFrame与RDD的差异
        """
        print("\n\nDataFrame vs RDD比较:")
        print("=" * 25)
        
        import time
        
        # 创建测试数据
        data = [(i, f"user_{i}", i % 100, i * 1000) for i in range(100000)]
        columns = ["id", "name", "age", "salary"]
        
        # DataFrame方式
        print("\n1. DataFrame方式:")
        start_time = time.time()
        
        df = self.spark.createDataFrame(data, columns)
        result_df = df.filter(df.salary > 50000) \
                     .groupBy("age") \
                     .agg({"salary": "avg"}) \
                     .orderBy("age")
        
        df_count = result_df.count()
        df_time = time.time() - start_time
        
        print(f"DataFrame处理时间: {df_time:.3f} 秒")
        print(f"结果记录数: {df_count}")
        
        # RDD方式
        print("\n2. RDD方式:")
        start_time = time.time()
        
        rdd = self.spark.sparkContext.parallelize(data)
        result_rdd = rdd.filter(lambda x: x[3] > 50000) \
                       .map(lambda x: (x[2], x[3])) \
                       .groupByKey() \
                       .mapValues(lambda values: sum(values) / len(list(values))) \
                       .sortByKey()
        
        rdd_count = result_rdd.count()
        rdd_time = time.time() - start_time
        
        print(f"RDD处理时间: {rdd_time:.3f} 秒")
        print(f"结果记录数: {rdd_count}")
        
        # 性能对比
        print("\n3. 性能对比:")
        if df_time < rdd_time:
            improvement = ((rdd_time - df_time) / rdd_time) * 100
            print(f"DataFrame比RDD快 {improvement:.1f}%")
        else:
            degradation = ((df_time - rdd_time) / rdd_time) * 100
            print(f"DataFrame比RDD慢 {degradation:.1f}%")
        
        # 特性对比
        print("\n4. 特性对比:")
        comparison = {
            "特性": ["类型安全", "性能优化", "易用性", "表达能力", "调试难度"],
            "DataFrame": ["编译时检查", "Catalyst优化", "高级API", "SQL+API", "容易"],
            "RDD": ["运行时检查", "手动优化", "底层API", "函数式", "困难"]
        }
        
        for i, feature in enumerate(comparison["特性"]):
            print(f"{feature:<10}: DataFrame({comparison['DataFrame'][i]:<12}) vs RDD({comparison['RDD'][i]})")
        
        return {
            'df_time': df_time,
            'rdd_time': rdd_time,
            'df_count': df_count,
            'rdd_count': rdd_count
        }
    
    def visualize_spark_sql_architecture(self):
        """
        可视化Spark SQL架构
        """
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
        
        # 1. Spark SQL架构图
        ax1.set_xlim(0, 10)
        ax1.set_ylim(0, 10)
        ax1.set_title('Spark SQL架构', fontsize=14, fontweight='bold')
        
        # 绘制架构层次
        layers = [
            {"name": "SQL/DataFrame/Dataset API", "y": 8.5, "color": "lightblue"},
            {"name": "Catalyst Optimizer", "y": 7, "color": "lightgreen"},
            {"name": "Tungsten Execution Engine", "y": 5.5, "color": "lightyellow"},
            {"name": "Spark Core (RDD)", "y": 4, "color": "lightcoral"},
            {"name": "Cluster Manager", "y": 2.5, "color": "lightgray"}
        ]
        
        for layer in layers:
            rect = patches.Rectangle((1, layer["y"]-0.4), 8, 0.8, 
                                   linewidth=1, edgecolor='black', 
                                   facecolor=layer["color"], alpha=0.7)
            ax1.add_patch(rect)
            ax1.text(5, layer["y"], layer["name"], ha='center', va='center', 
                    fontsize=10, fontweight='bold')
        
        # 添加箭头表示数据流
        for i in range(len(layers)-1):
            ax1.arrow(5, layers[i]["y"]-0.5, 0, -0.6, head_width=0.2, 
                     head_length=0.1, fc='red', ec='red')
        
        # 添加数据源
        data_sources = ["JSON", "Parquet", "CSV", "JDBC", "Hive"]
        for i, source in enumerate(data_sources):
            x = 1.5 + i * 1.5
            rect = patches.Rectangle((x-0.3, 0.5), 0.6, 0.8, 
                                   linewidth=1, edgecolor='blue', 
                                   facecolor='lightsteelblue', alpha=0.7)
            ax1.add_patch(rect)
            ax1.text(x, 0.9, source, ha='center', va='center', fontsize=8)
            ax1.arrow(x, 1.4, 0, 0.8, head_width=0.1, head_length=0.1, 
                     fc='blue', ec='blue', alpha=0.7)
        
        ax1.text(5, 1.5, "数据源", ha='center', va='center', 
                fontsize=12, fontweight='bold')
        
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax1.spines['top'].set_visible(False)
        ax1.spines['right'].set_visible(False)
        ax1.spines['bottom'].set_visible(False)
        ax1.spines['left'].set_visible(False)
        
        # 2. 查询执行流程
        ax2.set_xlim(0, 10)
        ax2.set_ylim(0, 10)
        ax2.set_title('查询执行流程', fontsize=14, fontweight='bold')
        
        # 执行步骤
        steps = [
            {"name": "SQL/DataFrame\n查询", "pos": (2, 9), "color": "lightblue"},
            {"name": "解析\n(Parser)", "pos": (2, 7.5), "color": "lightgreen"},
            {"name": "分析\n(Analyzer)", "pos": (2, 6), "color": "lightyellow"},
            {"name": "逻辑优化\n(Optimizer)", "pos": (5, 6), "color": "lightcoral"},
            {"name": "物理计划\n(Planner)", "pos": (8, 6), "color": "lightgray"},
            {"name": "代码生成\n(CodeGen)", "pos": (8, 4.5), "color": "lightsteelblue"},
            {"name": "执行\n(Execute)", "pos": (8, 3), "color": "lightpink"},
            {"name": "结果", "pos": (5, 1.5), "color": "lightcyan"}
        ]
        
        # 绘制步骤
        for step in steps:
            circle = patches.Circle(step["pos"], 0.8, linewidth=1, 
                                  edgecolor='black', facecolor=step["color"], alpha=0.7)
            ax2.add_patch(circle)
            ax2.text(step["pos"][0], step["pos"][1], step["name"], 
                    ha='center', va='center', fontsize=9, fontweight='bold')
        
        # 绘制箭头
        arrows = [
            ((2, 8.2), (2, 7.8)),  # 查询 -> 解析
            ((2, 6.8), (2, 6.7)),  # 解析 -> 分析
            ((2.8, 6), (4.2, 6)),  # 分析 -> 优化
            ((5.8, 6), (7.2, 6)),  # 优化 -> 计划
            ((8, 5.3), (8, 5.2)),  # 计划 -> 代码生成
            ((8, 3.8), (8, 3.7)),  # 代码生成 -> 执行
            ((7.2, 3), (5.8, 1.5)) # 执行 -> 结果
        ]
        
        for start, end in arrows:
            ax2.annotate('', xy=end, xytext=start,
                        arrowprops=dict(arrowstyle='->', lw=1.5, color='red'))
        
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax2.spines['top'].set_visible(False)
        ax2.spines['right'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.spines['left'].set_visible(False)
        
        plt.tight_layout()
        plt.show()

# Spark SQL概述演示
spark_sql_overview = SparkSQLOverview()

print("Spark SQL概述")
print("=" * 15)

# 演示核心特性
core_features_result = spark_sql_overview.demonstrate_core_features()

# 比较DataFrame和RDD
comparison_result = spark_sql_overview.compare_with_rdd()

# 可视化架构
spark_sql_overview.visualize_spark_sql_architecture()

print("\n概述总结:")
print("=" * 10)
print(f"✓ DataFrame记录数: {core_features_result['dataframe_count']}")
print(f"✓ SQL查询结果数: {core_features_result['sql_result_count']}")
print(f"✓ API查询结果数: {core_features_result['api_result_count']}")
print(f"✓ DataFrame处理时间: {comparison_result['df_time']:.3f} 秒")
print(f"✓ RDD处理时间: {comparison_result['rdd_time']:.3f} 秒")

print("\nSpark SQL优势:")
print("- 统一的数据访问接口")
print("- 强大的查询优化器(Catalyst)")
print("- 多种数据源支持")
print("- SQL和编程API的结合")
print("- 更好的性能和易用性")

3.2 DataFrame Basics

DataFrame Concepts and Features

The DataFrame is the core abstraction of Spark SQL: a distributed collection of data organized into named columns. Conceptually it is equivalent to a table in a relational database or to a data frame in R/Python (e.g. pandas), but with far more optimization under the hood, such as the Catalyst query optimizer and the Tungsten execution engine.
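
As a quick illustration of the table analogy, the following sketch (column names and rows are invented for this example) declares a schema of named, typed columns, inspects it, and pulls the rows back to the driver as a local pandas data frame via toPandas(), assuming pandas is installed locally.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.appName("dataframe_schema_sketch").getOrCreate()

# Named, typed columns -- much like a relational table definition
schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True)
])

df = spark.createDataFrame([("Alice", 25), ("Bob", 30)], schema)

df.printSchema()   # column names and types
df.show()          # distributed rows rendered as a small table

# The rows are partitioned across executors; toPandas() collects them
# into a single local pandas DataFrame on the driver
local_pdf = df.toPandas()
print(type(local_pdf))

spark.stop()

Unlike a pandas data frame, a Spark DataFrame is lazily evaluated and distributed: nothing is computed until an action such as show() or toPandas() is called.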

class DataFrameBasics:
    """
    DataFrame基础操作演示
    """
    
    def __init__(self):
        from pyspark.sql import SparkSession
        
        self.spark = SparkSession.builder \
            .appName("DataFrameBasics") \
            .getOrCreate()
    
    def create_dataframes(self):
        """
        演示DataFrame的多种创建方式
        """
        print("\nDataFrame创建方式:")
        print("=" * 20)
        
        # 1. 从Python数据结构创建
        print("\n1. 从Python列表创建:")
        
        data = [
            ("Alice", 25, "Engineer", 75000, "2020-01-15"),
            ("Bob", 30, "Manager", 85000, "2019-03-20"),
            ("Charlie", 35, "Director", 95000, "2018-07-10"),
            ("Diana", 28, "Engineer", 78000, "2021-02-28"),
            ("Eve", 32, "Manager", 88000, "2020-11-05")
        ]
        
        columns = ["name", "age", "position", "salary", "hire_date"]
        
        df1 = self.spark.createDataFrame(data, columns)
        print("从列表创建的DataFrame:")
        df1.show()
        
        # 2. 从字典列表创建
        print("\n2. 从字典列表创建:")
        
        dict_data = [
            {"name": "Frank", "age": 29, "position": "Analyst", "salary": 65000},
            {"name": "Grace", "age": 31, "position": "Engineer", "salary": 76000},
            {"name": "Henry", "age": 27, "position": "Intern", "salary": 45000}
        ]
        
        df2 = self.spark.createDataFrame(dict_data)
        print("从字典列表创建的DataFrame:")
        df2.show()
        
        # 3. 使用Schema定义
        print("\n3. 使用Schema定义:")
        
        from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType
        from pyspark.sql.functions import to_date
        
        # 定义Schema
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("position", StringType(), True),
            StructField("salary", DoubleType(), True),
            StructField("hire_date", StringType(), True)
        ])
        
        df3 = self.spark.createDataFrame(data, schema)
        
        # 转换日期类型
        df3 = df3.withColumn("hire_date", to_date(df3.hire_date, "yyyy-MM-dd"))
        
        print("使用Schema创建的DataFrame:")
        df3.show()
        df3.printSchema()
        
        # 4. 从RDD创建
        print("\n4. 从RDD创建:")
        
        rdd = self.spark.sparkContext.parallelize(data)
        df4 = rdd.toDF(columns)
        
        print("从RDD创建的DataFrame:")
        df4.show()
        
        # 5. 从外部数据源创建(模拟)
        print("\n5. 从外部数据源创建:")
        
        # 保存为临时文件
        temp_path = "/tmp/sample_data"
        df1.write.mode("overwrite").option("header", "true").csv(temp_path)
        
        # 从CSV读取
        df5 = self.spark.read.option("header", "true") \
                           .option("inferSchema", "true") \
                           .csv(temp_path)
        
        print("从CSV文件创建的DataFrame:")
        df5.show()
        
        return df3  # 返回带有正确Schema的DataFrame
    
    def explore_dataframe_structure(self, df):
        """
        探索DataFrame结构
        """
        print("\n\nDataFrame结构探索:")
        print("=" * 20)
        
        # 1. 基本信息
        print("\n1. 基本信息:")
        print(f"行数: {df.count()}")
        print(f"列数: {len(df.columns)}")
        print(f"列名: {df.columns}")
        
        # 2. Schema信息
        print("\n2. Schema信息:")
        df.printSchema()
        
        # 3. 数据类型
        print("\n3. 数据类型:")
        for col_name, col_type in df.dtypes:
            print(f"  {col_name}: {col_type}")
        
        # 4. 统计信息
        print("\n4. 统计信息:")
        df.describe().show()
        
        # 5. 数值列的详细统计
        print("\n5. 数值列详细统计:")
        numeric_cols = [col for col, dtype in df.dtypes if dtype in ['int', 'bigint', 'double', 'float']]
        
        if numeric_cols:
            from pyspark.sql.functions import mean, stddev, min as spark_min, max as spark_max, count
            
            stats_df = df.select(
                [mean(col).alias(f"{col}_mean") for col in numeric_cols] +
                [stddev(col).alias(f"{col}_stddev") for col in numeric_cols] +
                [spark_min(col).alias(f"{col}_min") for col in numeric_cols] +
                [spark_max(col).alias(f"{col}_max") for col in numeric_cols]
            )
            
            stats_df.show()
        
        # 6. 空值检查
        print("\n6. 空值检查:")
        from pyspark.sql.functions import col, isnan, when, count as spark_count
        
        # isnan()只适用于float/double列,其他类型(字符串、日期等)只检查isNull()
        float_cols = {c for c, t in df.dtypes if t in ('float', 'double')}
        null_counts = df.select([
            spark_count(when(col(c).isNull() | isnan(col(c)), c)).alias(c) if c in float_cols
            else spark_count(when(col(c).isNull(), c)).alias(c)
            for c in df.columns
        ])
        null_counts.show()
        
        # 7. 唯一值统计
        print("\n7. 唯一值统计:")
        for column in df.columns:
            unique_count = df.select(column).distinct().count()
            print(f"  {column}: {unique_count} 个唯一值")
        
        return {
            'row_count': df.count(),
            'column_count': len(df.columns),
            'numeric_columns': numeric_cols
        }
    
    def basic_operations(self, df):
        """
        DataFrame基本操作
        """
        print("\n\nDataFrame基本操作:")
        print("=" * 20)
        
        # 1. 选择列
        print("\n1. 选择列:")
        
        # 选择单列
        print("选择单列(name):")
        df.select("name").show()
        
        # 选择多列
        print("选择多列(name, age, salary):")
        df.select("name", "age", "salary").show()
        
        # 使用列表选择
        print("使用列表选择:")
        df.select(["name", "position"]).show()
        
        # 2. 过滤行
        print("\n2. 过滤行:")
        
        # 单条件过滤
        print("年龄大于30的员工:")
        df.filter(df.age > 30).show()
        
        # 多条件过滤
        print("年龄大于25且薪资大于75000的员工:")
        df.filter((df.age > 25) & (df.salary > 75000)).show()
        
        # 使用SQL表达式过滤
        print("使用SQL表达式过滤:")
        df.filter("position = 'Engineer' OR position = 'Manager'").show()
        
        # 3. 添加新列
        print("\n3. 添加新列:")
        
        from pyspark.sql.functions import col, when, year, current_date
        
        # 添加年薪列
        df_with_annual = df.withColumn("annual_salary", col("salary") * 12)
        
        # 添加薪资等级列
        df_with_grade = df_with_annual.withColumn(
            "salary_grade",
            when(col("salary") >= 90000, "High")
            .when(col("salary") >= 75000, "Medium")
            .otherwise("Low")
        )
        
        # 添加工作年限列(基于入职日期)
        df_with_experience = df_with_grade.withColumn(
            "years_of_service",
            (year(current_date()) - year(col("hire_date")))
        )
        
        print("添加新列后的DataFrame:")
        df_with_experience.show()
        
        # 4. 重命名列
        print("\n4. 重命名列:")
        
        df_renamed = df_with_experience.withColumnRenamed("name", "employee_name") \
                                      .withColumnRenamed("age", "employee_age")
        
        print("重命名列后的DataFrame:")
        df_renamed.select("employee_name", "employee_age", "position").show()
        
        # 5. 删除列
        print("\n5. 删除列:")
        
        df_dropped = df_with_experience.drop("annual_salary")
        
        print("删除annual_salary列后:")
        print(f"原列数: {len(df_with_experience.columns)}")
        print(f"新列数: {len(df_dropped.columns)}")
        print(f"剩余列: {df_dropped.columns}")
        
        # 6. 排序
        print("\n6. 排序:")
        
        # 按单列排序
        print("按薪资升序排序:")
        df.orderBy("salary").show()
        
        # 按多列排序
        print("按职位和薪资排序:")
        df.orderBy(["position", "salary"], ascending=[True, False]).show()
        
        # 7. 去重
        print("\n7. 去重:")
        
        # 整行去重
        print(f"原始行数: {df.count()}")
        print(f"去重后行数: {df.distinct().count()}")
        
        # 按特定列去重
        print("按职位去重:")
        df.select("position").distinct().show()
        
        return df_with_experience
    
    def demonstrate_column_operations(self, df):
        """
        演示列操作
        """
        print("\n\n列操作演示:")
        print("=" * 15)
        
        from pyspark.sql.functions import (
            col, lit, concat, upper, lower, length, substring,
            round as spark_round, abs as spark_abs, sqrt, 
            when, regexp_replace, split
        )
        
        # 1. 字符串操作
        print("\n1. 字符串操作:")
        
        string_ops_df = df.select(
            col("name"),
            upper(col("name")).alias("name_upper"),
            lower(col("name")).alias("name_lower"),
            length(col("name")).alias("name_length"),
            substring(col("name"), 1, 3).alias("name_prefix"),
            concat(col("name"), lit(" - "), col("position")).alias("name_position")
        )
        
        print("字符串操作结果:")
        string_ops_df.show(truncate=False)
        
        # 2. 数值操作
        print("\n2. 数值操作:")
        
        numeric_ops_df = df.select(
            col("name"),
            col("salary"),
            spark_round(col("salary") / 1000, 2).alias("salary_k"),
            (col("salary") * 0.8).alias("after_tax_salary"),
            sqrt(col("salary")).alias("salary_sqrt"),
            spark_abs(col("age") - 30).alias("age_diff_from_30")
        )
        
        print("数值操作结果:")
        numeric_ops_df.show()
        
        # 3. 条件操作
        print("\n3. 条件操作:")
        
        conditional_df = df.select(
            col("name"),
            col("age"),
            col("salary"),
            when(col("age") < 30, "Young")
            .when(col("age") < 35, "Middle")
            .otherwise("Senior").alias("age_group"),
            
            when(col("salary") > 85000, col("salary") * 1.1)
            .otherwise(col("salary")).alias("adjusted_salary")
        )
        
        print("条件操作结果:")
        conditional_df.show()
        
        # 4. 正则表达式操作
        print("\n4. 正则表达式操作:")
        
        # 添加一些包含特殊字符的数据进行演示
        special_data = [
            ("John-Doe", "john.doe@email.com", "123-456-7890"),
            ("Jane_Smith", "jane_smith@company.org", "987-654-3210"),
            ("Bob O'Connor", "bob.oconnor@test.net", "555-123-4567")
        ]
        
        special_df = self.spark.createDataFrame(special_data, ["full_name", "email", "phone"])
        
        regex_df = special_df.select(
            col("full_name"),
            regexp_replace(col("full_name"), "[-_']", " ").alias("cleaned_name"),
            regexp_replace(col("email"), "@.*", "").alias("username"),
            regexp_replace(col("phone"), "-", "").alias("phone_no_dash"),
            split(col("email"), "@").alias("email_parts")
        )
        
        print("正则表达式操作结果:")
        regex_df.show(truncate=False)
        
        return conditional_df
    
    def visualize_dataframe_operations(self, df):
        """
        可视化DataFrame操作
        """
        import matplotlib.pyplot as plt
        import numpy as np
        
        # 收集数据用于可视化
        data_for_viz = df.select("name", "age", "salary", "position").collect()
        
        names = [row.name for row in data_for_viz]
        ages = [row.age for row in data_for_viz]
        salaries = [row.salary for row in data_for_viz]
        positions = [row.position for row in data_for_viz]
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        
        # 1. 薪资分布
        ax1 = axes[0, 0]
        bars = ax1.bar(names, salaries, color=['skyblue', 'lightgreen', 'lightcoral', 'gold', 'lightpink'])
        ax1.set_title('员工薪资分布', fontweight='bold')
        ax1.set_xlabel('员工姓名')
        ax1.set_ylabel('薪资')
        ax1.tick_params(axis='x', rotation=45)
        
        # 添加数值标签
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 1000,
                    f'${height:,.0f}', ha='center', va='bottom', fontsize=9)
        
        # 2. 年龄vs薪资散点图
        ax2 = axes[0, 1]
        colors = ['red' if pos == 'Director' else 'blue' if pos == 'Manager' else 'green' 
                 for pos in positions]
        scatter = ax2.scatter(ages, salaries, c=colors, s=100, alpha=0.7)
        ax2.set_title('年龄 vs 薪资', fontweight='bold')
        ax2.set_xlabel('年龄')
        ax2.set_ylabel('薪资')
        
        # 添加标签
        for i, name in enumerate(names):
            ax2.annotate(name, (ages[i], salaries[i]), xytext=(5, 5), 
                        textcoords='offset points', fontsize=8)
        
        # 添加图例
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', 
                                     markersize=8, label='Director'),
                          plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', 
                                     markersize=8, label='Manager'),
                          plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='green', 
                                     markersize=8, label='Engineer')]
        ax2.legend(handles=legend_elements)
        
        # 3. 职位分布饼图
        ax3 = axes[1, 0]
        position_counts = {}
        for pos in positions:
            position_counts[pos] = position_counts.get(pos, 0) + 1
        
        labels = list(position_counts.keys())
        sizes = list(position_counts.values())
        colors_pie = ['gold', 'lightcoral', 'lightskyblue', 'lightgreen']
        
        wedges, texts, autotexts = ax3.pie(sizes, labels=labels, colors=colors_pie[:len(labels)], 
                                          autopct='%1.1f%%', startangle=90)
        ax3.set_title('职位分布', fontweight='bold')
        
        # 4. DataFrame操作流程图
        ax4 = axes[1, 1]
        ax4.axis('off')
        
        # 创建操作流程
        operations = [
            "DataFrame基本操作流程",
            "=" * 25,
            "",
            "1. 数据创建",
            "   ├─ createDataFrame()",
            "   ├─ read.csv/json/parquet()",
            "   └─ sql()",
            "",
            "2. 数据探索",
            "   ├─ show() - 显示数据",
            "   ├─ printSchema() - 查看结构",
            "   ├─ describe() - 统计信息",
            "   └─ count() - 计数",
            "",
            "3. 数据转换",
            "   ├─ select() - 选择列",
            "   ├─ filter() - 过滤行",
            "   ├─ withColumn() - 添加列",
            "   ├─ drop() - 删除列",
            "   └─ orderBy() - 排序",
            "",
            "4. 数据聚合",
            "   ├─ groupBy() - 分组",
            "   ├─ agg() - 聚合函数",
            "   └─ pivot() - 透视",
            "",
            "5. 数据输出",
            "   ├─ show() - 显示结果",
            "   ├─ collect() - 收集到Driver",
            "   └─ write() - 保存文件"
        ]
        
        y_pos = 0.95
        for line in operations:
            if line.startswith("="):
                ax4.text(0.05, y_pos, line, fontsize=10, fontweight='bold', 
                        transform=ax4.transAxes)
            elif line.startswith("DataFrame基本操作流程"):
                ax4.text(0.05, y_pos, line, fontsize=12, fontweight='bold', 
                        transform=ax4.transAxes)
            elif line.startswith(("1.", "2.", "3.", "4.", "5.")):
                ax4.text(0.05, y_pos, line, fontsize=10, fontweight='bold', 
                        color='blue', transform=ax4.transAxes)
            else:
                ax4.text(0.05, y_pos, line, fontsize=9, fontfamily='monospace',
                        transform=ax4.transAxes)
            y_pos -= 0.032
        
        plt.tight_layout()
        plt.show()

# DataFrame基础演示
df_basics = DataFrameBasics()

print("DataFrame基础操作")
print("=" * 18)

# 创建DataFrame
df = df_basics.create_dataframes()

# 探索DataFrame结构
structure_info = df_basics.explore_dataframe_structure(df)

# 基本操作
enhanced_df = df_basics.basic_operations(df)

# 列操作
column_ops_df = df_basics.demonstrate_column_operations(df)

# 可视化
df_basics.visualize_dataframe_operations(df)

print("\nDataFrame基础总结:")
print("=" * 18)
print(f"✓ 总行数: {structure_info['row_count']}")
print(f"✓ 总列数: {structure_info['column_count']}")
print(f"✓ 数值列数: {len(structure_info['numeric_columns'])}")
print(f"✓ 数值列: {structure_info['numeric_columns']}")

print("\n核心概念:")
print("- DataFrame是分布式的结构化数据集")
print("- 支持丰富的数据操作和转换")
print("- 提供类型安全和查询优化")
print("- 可以与SQL无缝集成")
print("- 支持多种数据源和格式")

3.3 Advanced DataFrame Operations

Aggregation and Grouping Operations

class DataFrameAdvancedOperations:
    """
    DataFrame高级操作演示
    """
    
    def __init__(self):
        from pyspark.sql import SparkSession
        
        self.spark = SparkSession.builder \
            .appName("DataFrameAdvanced") \
            .getOrCreate()
        
        # 创建更大的示例数据集
        self.df = self.create_sample_dataset()
    
    def create_sample_dataset(self):
        """
        创建示例数据集
        """
        import random
        from datetime import datetime, timedelta
        
        # 生成更多样化的员工数据
        departments = ["Engineering", "Sales", "Marketing", "HR", "Finance"]
        positions = {
            "Engineering": ["Junior Engineer", "Senior Engineer", "Tech Lead", "Engineering Manager"],
            "Sales": ["Sales Rep", "Senior Sales Rep", "Sales Manager", "Sales Director"],
            "Marketing": ["Marketing Specialist", "Marketing Manager", "Marketing Director"],
            "HR": ["HR Specialist", "HR Manager", "HR Director"],
            "Finance": ["Financial Analyst", "Senior Analyst", "Finance Manager", "CFO"]
        }
        
        cities = ["New York", "San Francisco", "Los Angeles", "Chicago", "Boston", "Seattle"]
        
        data = []
        base_date = datetime(2018, 1, 1)
        
        for i in range(100):
            dept = random.choice(departments)
            position = random.choice(positions[dept])
            
            # 根据职位级别设置薪资范围
            if "Director" in position or "CFO" in position:
                salary_range = (120000, 180000)
            elif "Manager" in position or "Lead" in position:
                salary_range = (90000, 130000)
            elif "Senior" in position:
                salary_range = (70000, 100000)
            else:
                salary_range = (50000, 80000)
            
            hire_date = base_date + timedelta(days=random.randint(0, 1825))  # 5年范围
            
            data.append((
                i + 1,  # employee_id
                f"Employee_{i+1}",  # name
                random.randint(22, 60),  # age
                dept,  # department
                position,  # position
                random.randint(*salary_range),  # salary
                random.choice(cities),  # city
                hire_date.strftime("%Y-%m-%d"),  # hire_date
                random.choice(["Male", "Female"]),  # gender
                random.choice(["Bachelor", "Master", "PhD"])  # education
            ))
        
        columns = ["employee_id", "name", "age", "department", "position", 
                  "salary", "city", "hire_date", "gender", "education"]
        
        df = self.spark.createDataFrame(data, columns)
        
        # 转换数据类型
        from pyspark.sql.functions import to_date
        df = df.withColumn("hire_date", to_date(df.hire_date, "yyyy-MM-dd"))
        
        print("创建示例数据集:")
        print(f"总记录数: {df.count()}")
        df.show(10)
        
        return df
    
    def demonstrate_groupby_operations(self):
        """
        演示分组操作
        """
        print("\n\n分组操作演示:")
        print("=" * 15)
        
        from pyspark.sql.functions import (
            count, sum as spark_sum, avg, max as spark_max, min as spark_min,
            stddev, collect_list, collect_set, first, last
        )
        
        # 1. 基本分组统计
        print("\n1. 按部门分组统计:")
        
        dept_stats = self.df.groupBy("department") \
            .agg(
                count("*").alias("employee_count"),
                avg("salary").alias("avg_salary"),
                spark_max("salary").alias("max_salary"),
                spark_min("salary").alias("min_salary"),
                stddev("salary").alias("salary_stddev")
            ) \
            .orderBy("avg_salary", ascending=False)
        
        dept_stats.show()
        
        # 2. 多列分组
        print("\n2. 按部门和城市分组:")
        
        dept_city_stats = self.df.groupBy("department", "city") \
            .agg(
                count("*").alias("count"),
                avg("salary").alias("avg_salary")
            ) \
            .orderBy(["department", "avg_salary"], ascending=[True, False])
        
        dept_city_stats.show()
        
        # 3. 条件聚合
        print("\n3. 条件聚合:")
        
        from pyspark.sql.functions import when, col
        
        conditional_agg = self.df.groupBy("department") \
            .agg(
                count("*").alias("total_employees"),
                spark_sum(when(col("gender") == "Male", 1).otherwise(0)).alias("male_count"),
                spark_sum(when(col("gender") == "Female", 1).otherwise(0)).alias("female_count"),
                spark_sum(when(col("salary") > 80000, 1).otherwise(0)).alias("high_salary_count"),
                avg(when(col("age") < 30, col("salary"))).alias("young_avg_salary")
            )
        
        conditional_agg.show()
        
        # 4. 收集聚合
        print("\n4. 收集聚合:")
        
        collection_agg = self.df.groupBy("department") \
            .agg(
                collect_list("position").alias("all_positions"),
                collect_set("position").alias("unique_positions"),
                collect_set("city").alias("cities")
            )
        
        collection_agg.show(truncate=False)
        
        # 5. 窗口函数准备
        print("\n5. 分组内排名:")
        
        from pyspark.sql.window import Window
        from pyspark.sql.functions import row_number, rank, dense_rank
        
        # 定义窗口
        window_spec = Window.partitionBy("department").orderBy(col("salary").desc())
        
        ranked_df = self.df.select(
            "name", "department", "position", "salary",
            row_number().over(window_spec).alias("row_number"),
            rank().over(window_spec).alias("rank"),
            dense_rank().over(window_spec).alias("dense_rank")
        )
        
        print("每个部门薪资排名前3:")
        ranked_df.filter(col("row_number") <= 3).show()
        
        return dept_stats
    
    def demonstrate_window_functions(self):
        """
        演示窗口函数
        """
        print("\n\n窗口函数演示:")
        print("=" * 15)
        
        from pyspark.sql.window import Window
        from pyspark.sql.functions import (
            row_number, rank, dense_rank, lag, lead,
            count, sum as spark_sum, avg, min as spark_min, max as spark_max,
            percent_rank, ntile, col
        )
        
        # 1. 排名函数
        print("\n1. 排名函数:")
        
        # 按部门分区,按薪资排序的窗口
        salary_window = Window.partitionBy("department").orderBy(col("salary").desc())
        
        ranking_df = self.df.select(
            "name", "department", "salary",
            row_number().over(salary_window).alias("row_num"),
            rank().over(salary_window).alias("rank"),
            dense_rank().over(salary_window).alias("dense_rank"),
            percent_rank().over(salary_window).alias("percent_rank"),
            ntile(4).over(salary_window).alias("quartile")
        )
        
        ranking_df.filter(col("department") == "Engineering").show()
        
        # 2. 偏移函数
        print("\n2. 偏移函数:")
        
        # 按入职日期排序的窗口
        date_window = Window.partitionBy("department").orderBy("hire_date")
        
        offset_df = self.df.select(
            "name", "department", "hire_date", "salary",
            lag("salary", 1).over(date_window).alias("prev_salary"),
            lead("salary", 1).over(date_window).alias("next_salary"),
            lag("hire_date", 1).over(date_window).alias("prev_hire_date")
        )
        
        # 计算薪资变化
        from pyspark.sql.functions import when, isnan, isnull
        
        salary_change_df = offset_df.withColumn(
            "salary_change",
            when(col("prev_salary").isNotNull(), 
                 col("salary") - col("prev_salary")).otherwise(0)
        )
        
        print("薪资变化分析:")
        salary_change_df.filter(col("department") == "Engineering").show()
        
        # 3. 聚合窗口函数
        print("\n3. 聚合窗口函数:")
        
        # 累计和运行平均
        cumulative_window = Window.partitionBy("department").orderBy("hire_date") \
                                 .rowsBetween(Window.unboundedPreceding, Window.currentRow)
        
        # 移动平均(前后各2行)
        moving_window = Window.partitionBy("department").orderBy("hire_date") \
                             .rowsBetween(-2, 2)
        
        aggregation_df = self.df.select(
            "name", "department", "hire_date", "salary",
            spark_sum("salary").over(cumulative_window).alias("cumulative_salary"),
            avg("salary").over(cumulative_window).alias("running_avg_salary"),
            avg("salary").over(moving_window).alias("moving_avg_salary"),
            spark_max("salary").over(cumulative_window).alias("max_salary_so_far")
        )
        
        print("累计和移动统计:")
        aggregation_df.filter(col("department") == "Sales").show()
        
        # 4. 百分位数和分位数
        print("\n4. 分位数分析:")
        
        # 计算每个部门的薪资分位数
        quartile_window = Window.partitionBy("department").orderBy("salary")
        
        quartile_df = self.df.select(
            "name", "department", "salary",
            ntile(4).over(quartile_window).alias("salary_quartile"),
            percent_rank().over(quartile_window).alias("salary_percentile")
        )
        
        # 统计每个分位数的情况
        quartile_stats = quartile_df.groupBy("department", "salary_quartile") \
            .agg(
                count("*").alias("count"),
                avg("salary").alias("avg_salary"),
                spark_min("salary").alias("min_salary"),
                spark_max("salary").alias("max_salary")
            ) \
            .orderBy("department", "salary_quartile")
        
        print("薪资分位数统计:")
        quartile_stats.show()
        
        return ranking_df
    
    def demonstrate_join_operations(self):
        """
        演示连接操作
        """
        print("\n\n连接操作演示:")
        print("=" * 15)
        
        # 创建部门信息表
        dept_data = [
            ("Engineering", "Tech", "John Smith", 50),
            ("Sales", "Business", "Jane Doe", 30),
            ("Marketing", "Business", "Bob Johnson", 25),
            ("HR", "Support", "Alice Brown", 15),
            ("Finance", "Support", "Charlie Wilson", 20),
            ("Operations", "Support", "Diana Davis", 10)  # 这个部门在员工表中不存在
        ]
        
        dept_columns = ["department", "division", "manager", "budget_millions"]
        dept_df = self.spark.createDataFrame(dept_data, dept_columns)
        
        print("部门信息表:")
        dept_df.show()
        
        # 创建项目信息表
        project_data = [
            (1, "Project Alpha", "Engineering", 100),
            (2, "Project Beta", "Engineering", 150),
            (3, "Sales Campaign Q1", "Sales", 75),
            (4, "Marketing Blitz", "Marketing", 50),
            (5, "HR System Upgrade", "HR", 25),
            (6, "Financial Audit", "Finance", 30),
            (7, "Research Project", "R&D", 200)  # R&D部门不存在
        ]
        
        project_columns = ["project_id", "project_name", "department", "budget_k"]
        project_df = self.spark.createDataFrame(project_data, project_columns)
        
        print("项目信息表:")
        project_df.show()
        
        # 1. 内连接 (Inner Join)
        print("\n1. 内连接 - 员工和部门信息:")
        
        inner_join = self.df.join(dept_df, "department", "inner") \
            .select("name", "department", "position", "salary", "division", "manager")
        
        print(f"员工表记录数: {self.df.count()}")
        print(f"部门表记录数: {dept_df.count()}")
        print(f"内连接结果记录数: {inner_join.count()}")
        inner_join.show(10)
        
        # 2. 左外连接 (Left Outer Join)
        print("\n2. 左外连接 - 保留所有员工:")
        
        left_join = self.df.join(dept_df, "department", "left_outer") \
            .select("name", "department", "position", "salary", "division", "manager")
        
        print(f"左外连接结果记录数: {left_join.count()}")
        
        # 检查是否有员工没有对应的部门信息
        missing_dept_info = left_join.filter(col("division").isNull())
        print(f"没有部门信息的员工数: {missing_dept_info.count()}")
        
        # 3. 右外连接 (Right Outer Join)
        print("\n3. 右外连接 - 保留所有部门:")
        
        right_join = self.df.join(dept_df, "department", "right_outer") \
            .select("department", "division", "manager", "name", "position")
        
        print(f"右外连接结果记录数: {right_join.count()}")
        
        # 检查是否有部门没有员工
        empty_depts = right_join.filter(col("name").isNull())
        print("没有员工的部门:")
        empty_depts.show()
        
        # 4. 全外连接 (Full Outer Join)
        print("\n4. 全外连接:")
        
        full_join = self.df.join(dept_df, "department", "full_outer") \
            .select("name", "department", "position", "division", "manager")
        
        print(f"全外连接结果记录数: {full_join.count()}")
        
        # 5. 多表连接
        print("\n5. 多表连接 - 员工、部门、项目:")
        
        # 先连接员工和部门
        emp_dept = self.df.join(dept_df, "department", "inner")
        
        # 再连接项目
        multi_join = emp_dept.join(project_df, "department", "inner") \
            .select("name", "department", "position", "salary", 
                   "division", "project_name", "budget_k")
        
        print("员工参与的项目:")
        multi_join.show()
        
        # 6. 自连接
        print("\n6. 自连接 - 查找同部门同事:")
        
        # 为了避免列名冲突,给其中一个表起别名
        emp1 = self.df.alias("emp1")
        emp2 = self.df.alias("emp2")
        
        colleagues = emp1.join(
            emp2, 
            (col("emp1.department") == col("emp2.department")) & 
            (col("emp1.employee_id") != col("emp2.employee_id")),
            "inner"
        ).select(
            col("emp1.name").alias("employee1"),
            col("emp2.name").alias("employee2"),
            col("emp1.department").alias("department")
        )
        
        print("同部门同事关系(前20条):")
        colleagues.show(20)
        
        # 7. 连接性能优化
        print("\n7. 连接性能分析:")
        
        # 广播连接(适用于小表)
        from pyspark.sql.functions import broadcast
        
        broadcast_join = self.df.join(broadcast(dept_df), "department", "inner")
        
        print("广播连接执行计划:")
        broadcast_join.explain()
        
        return inner_join
    
    def demonstrate_pivot_operations(self):
        """
        演示透视操作
        """
        print("\n\n透视操作演示:")
        print("=" * 15)
        
        from pyspark.sql.functions import count, avg, sum as spark_sum
        
        # 1. 基本透视
        print("\n1. 基本透视 - 部门vs城市的员工数量:")
        
        pivot_count = self.df.groupBy("department") \
            .pivot("city") \
            .count() \
            .fillna(0)
        
        pivot_count.show()
        
        # 2. 透视聚合
        print("\n2. 透视聚合 - 部门vs教育程度的平均薪资:")
        
        pivot_salary = self.df.groupBy("department") \
            .pivot("education") \
            .avg("salary") \
            .fillna(0)
        
        # 重命名列以便理解
        for col_name in pivot_salary.columns:
            if col_name != "department":
                pivot_salary = pivot_salary.withColumnRenamed(
                    col_name, f"avg_salary_{col_name}"
                )
        
        pivot_salary.show()
        
        # 3. 多值透视
        print("\n3. 多值透视 - 性别vs部门的统计:")
        
        multi_pivot = self.df.groupBy("gender") \
            .pivot("department") \
            .agg(
                count("*").alias("count"),
                avg("salary").alias("avg_salary")
            )
        
        multi_pivot.show()
        
        # 4. 逆透视 (Unpivot) - 手动实现
        print("\n4. 逆透视操作:")
        
        # 首先创建一个透视表
        dept_city_pivot = self.df.groupBy("department") \
            .pivot("city") \
            .count() \
            .fillna(0)
        
        print("原始透视表:")
        dept_city_pivot.show()
        
        # 逆透视:将列转换回行
        from pyspark.sql.functions import expr, col
        
        # 获取城市列名(除了department)
        city_columns = [c for c in dept_city_pivot.columns if c != "department"]
        
        # 创建逆透视表达式(stack()的参数里不能写别名,输出列名统一放在末尾的as (city, count)中)
        unpivot_exprs = []
        for city in city_columns:
            unpivot_exprs.append(f"'{city}', `{city}`")
        
        unpivot_expr = f"stack({len(city_columns)}, {', '.join(unpivot_exprs)}) as (city, count)"
        
        unpivoted = dept_city_pivot.select("department", expr(unpivot_expr)) \
            .filter("count > 0")
        
        print("逆透视结果:")
        unpivoted.show()
        
        return pivot_count
    
    def visualize_advanced_operations(self):
        """
        可视化高级操作结果
        """
        import matplotlib.pyplot as plt
        import numpy as np
        from pyspark.sql.functions import count, avg
        
        # 收集数据用于可视化
        dept_stats = self.df.groupBy("department") \
            .agg(
                count("*").alias("emp_count"),
                avg("salary").alias("avg_salary"),
                avg("age").alias("avg_age")
            ).collect()
        
        departments = [row.department for row in dept_stats]
        # Row是tuple的子类,row.count会取到tuple的count方法,因此这里把计数列命名为emp_count
        counts = [row.emp_count for row in dept_stats]
        avg_salaries = [row.avg_salary for row in dept_stats]
        avg_ages = [row.avg_age for row in dept_stats]
        
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        
        # 1. 部门员工数量
        ax1 = axes[0, 0]
        bars1 = ax1.bar(departments, counts, color='skyblue', alpha=0.7)
        ax1.set_title('各部门员工数量', fontweight='bold')
        ax1.set_xlabel('部门')
        ax1.set_ylabel('员工数量')
        ax1.tick_params(axis='x', rotation=45)
        
        # 添加数值标签
        for bar in bars1:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                    f'{int(height)}', ha='center', va='bottom')
        
        # 2. 部门平均薪资
        ax2 = axes[0, 1]
        bars2 = ax2.bar(departments, avg_salaries, color='lightgreen', alpha=0.7)
        ax2.set_title('各部门平均薪资', fontweight='bold')
        ax2.set_xlabel('部门')
        ax2.set_ylabel('平均薪资')
        ax2.tick_params(axis='x', rotation=45)
        
        # 添加数值标签
        for bar in bars2:
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height + 1000,
                    f'${height:,.0f}', ha='center', va='bottom')
        
        # 3. 薪资vs年龄散点图(按部门着色)
        ax3 = axes[1, 0]
        
        # 收集所有员工数据
        all_data = self.df.select("department", "age", "salary").collect()
        
        dept_colors = {'Engineering': 'red', 'Sales': 'blue', 'Marketing': 'green', 
                      'HR': 'orange', 'Finance': 'purple'}
        
        for dept in departments:
            dept_data = [row for row in all_data if row.department == dept]
            ages = [row.age for row in dept_data]
            salaries = [row.salary for row in dept_data]
            
            ax3.scatter(ages, salaries, c=dept_colors.get(dept, 'gray'), 
                       label=dept, alpha=0.6, s=50)
        
        ax3.set_title('年龄 vs 薪资(按部门分类)', fontweight='bold')
        ax3.set_xlabel('年龄')
        ax3.set_ylabel('薪资')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        
        # 4. DataFrame操作复杂度对比
        ax4 = axes[1, 1]
        
        operations = ['Select', 'Filter', 'GroupBy', 'Join', 'Window', 'Pivot']
        complexity = [1, 2, 3, 4, 5, 4]  # 相对复杂度
        performance = [10, 8, 6, 4, 3, 5]  # 相对性能(越高越好)
        
        x = np.arange(len(operations))
        width = 0.35
        
        bars1 = ax4.bar(x - width/2, complexity, width, label='复杂度', 
                        color='lightcoral', alpha=0.7)
        bars2 = ax4.bar(x + width/2, performance, width, label='性能', 
                        color='lightblue', alpha=0.7)
        
        ax4.set_title('DataFrame操作复杂度 vs 性能', fontweight='bold')
        ax4.set_xlabel('操作类型')
        ax4.set_ylabel('评分')
        ax4.set_xticks(x)
        ax4.set_xticklabels(operations)
        ax4.legend()
        ax4.grid(True, alpha=0.3)
        
        # 添加数值标签
        for bars in [bars1, bars2]:
            for bar in bars:
                height = bar.get_height()
                ax4.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                        f'{height}', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.show()

# DataFrame高级操作演示
df_advanced = DataFrameAdvancedOperations()

print("\nDataFrame高级操作")
print("=" * 18)

# 分组操作
dept_stats = df_advanced.demonstrate_groupby_operations()

# 窗口函数
ranking_df = df_advanced.demonstrate_window_functions()

# 连接操作
join_result = df_advanced.demonstrate_join_operations()

# 透视操作
pivot_result = df_advanced.demonstrate_pivot_operations()

# 可视化
df_advanced.visualize_advanced_operations()

print("\n高级操作总结:")
print("=" * 15)
print(f"✓ 部门统计结果数: {dept_stats.count()}")
print(f"✓ 排名结果数: {ranking_df.count()}")
print(f"✓ 连接结果数: {join_result.count()}")
print(f"✓ 透视结果数: {pivot_result.count()}")

print("\n关键技术点:")
print("- groupBy()和agg()进行数据聚合")
print("- 窗口函数实现复杂的分析需求")
print("- 多种连接类型处理关联数据")
print("- 透视操作进行数据重塑")
print("- 性能优化和执行计划分析")

3.4 SQL Queries

Spark SQL Fundamentals

class SparkSQLQueries:
    """
    Spark SQL查询演示
    """
    
    def __init__(self):
        from pyspark.sql import SparkSession
        
        self.spark = SparkSession.builder \
            .appName("SparkSQL") \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
            .getOrCreate()
        
        # 创建示例数据
        self.setup_sample_data()
    
    def setup_sample_data(self):
        """
        设置示例数据
        """
        import random
        from datetime import datetime, timedelta
        
        # 员工数据
        employees_data = []
        departments = ["Engineering", "Sales", "Marketing", "HR", "Finance"]
        
        for i in range(50):
            employees_data.append((
                i + 1,
                f"Employee_{i+1}",
                random.choice(departments),
                random.randint(50000, 150000),
                random.randint(25, 60),
                random.choice(["Male", "Female"]),
                (datetime.now() - timedelta(days=random.randint(30, 1825))).strftime("%Y-%m-%d")
            ))
        
        employees_columns = ["emp_id", "name", "department", "salary", "age", "gender", "hire_date"]
        self.employees_df = self.spark.createDataFrame(employees_data, employees_columns)
        
        # 注册为临时视图
        self.employees_df.createOrReplaceTempView("employees")
        
        # 部门数据
        departments_data = [
            ("Engineering", "Technology", "John Smith", "New York"),
            ("Sales", "Business", "Jane Doe", "Chicago"),
            ("Marketing", "Business", "Bob Johnson", "Los Angeles"),
            ("HR", "Support", "Alice Brown", "Boston"),
            ("Finance", "Support", "Charlie Wilson", "Seattle")
        ]
        
        departments_columns = ["dept_name", "division", "manager", "location"]
        self.departments_df = self.spark.createDataFrame(departments_data, departments_columns)
        self.departments_df.createOrReplaceTempView("departments")
        
        # 项目数据
        projects_data = [
            (1, "Project Alpha", "Engineering", 100000, "2023-01-01", "2023-12-31"),
            (2, "Project Beta", "Engineering", 150000, "2023-03-01", "2024-02-29"),
            (3, "Sales Campaign Q1", "Sales", 75000, "2023-01-01", "2023-03-31"),
            (4, "Marketing Blitz", "Marketing", 50000, "2023-06-01", "2023-08-31"),
            (5, "HR System Upgrade", "HR", 25000, "2023-04-01", "2023-09-30")
        ]
        
        projects_columns = ["project_id", "project_name", "department", "budget", "start_date", "end_date"]
        self.projects_df = self.spark.createDataFrame(projects_data, projects_columns)
        self.projects_df.createOrReplaceTempView("projects")
        
        print("示例数据已创建并注册为临时视图")
        print(f"员工表: {self.employees_df.count()} 条记录")
        print(f"部门表: {self.departments_df.count()} 条记录")
        print(f"项目表: {self.projects_df.count()} 条记录")
    
    def demonstrate_basic_sql_queries(self):
        """
        演示基础SQL查询
        """
        print("\n\n基础SQL查询演示:")
        print("=" * 18)
        
        # 1. 简单查询
        print("\n1. 简单SELECT查询:")
        
        result1 = self.spark.sql("""
            SELECT name, department, salary
            FROM employees
            WHERE salary > 80000
            ORDER BY salary DESC
            LIMIT 10
        """)
        
        print("高薪员工(薪资 > 80000):")
        result1.show()
        
        # 2. 聚合查询
        print("\n2. 聚合查询:")
        
        result2 = self.spark.sql("""
            SELECT 
                department,
                COUNT(*) as employee_count,
                AVG(salary) as avg_salary,
                MAX(salary) as max_salary,
                MIN(salary) as min_salary,
                STDDEV(salary) as salary_stddev
            FROM employees
            GROUP BY department
            ORDER BY avg_salary DESC
        """)
        
        print("部门统计:")
        result2.show()
        
        # 3. 条件聚合
        print("\n3. 条件聚合:")
        
        result3 = self.spark.sql("""
            SELECT 
                department,
                COUNT(*) as total_employees,
                SUM(CASE WHEN gender = 'Male' THEN 1 ELSE 0 END) as male_count,
                SUM(CASE WHEN gender = 'Female' THEN 1 ELSE 0 END) as female_count,
                SUM(CASE WHEN salary > 100000 THEN 1 ELSE 0 END) as high_salary_count,
                AVG(CASE WHEN age < 35 THEN salary END) as young_avg_salary
            FROM employees
            GROUP BY department
        """)
        
        print("条件聚合统计:")
        result3.show()
        
        # 4. 子查询
        print("\n4. 子查询:")
        
        result4 = self.spark.sql("""
            SELECT 
                name, 
                department, 
                salary,
                (salary - dept_avg.avg_salary) as salary_diff_from_avg
            FROM employees e
            JOIN (
                SELECT department, AVG(salary) as avg_salary
                FROM employees
                GROUP BY department
            ) dept_avg ON e.department = dept_avg.department
            WHERE salary > dept_avg.avg_salary
            ORDER BY salary_diff_from_avg DESC
        """)
        
        print("薪资高于部门平均值的员工:")
        result4.show()
        
        return result1
    
    def demonstrate_advanced_sql_queries(self):
        """
        演示高级SQL查询
        """
        print("\n\n高级SQL查询演示:")
        print("=" * 18)
        
        # 1. 窗口函数
        print("\n1. 窗口函数:")
        
        result1 = self.spark.sql("""
            SELECT 
                name,
                department,
                salary,
                ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary DESC) as rank_in_dept,
                RANK() OVER (PARTITION BY department ORDER BY salary DESC) as rank_with_ties,
                DENSE_RANK() OVER (PARTITION BY department ORDER BY salary DESC) as dense_rank,
                PERCENT_RANK() OVER (PARTITION BY department ORDER BY salary) as percentile,
                LAG(salary, 1) OVER (PARTITION BY department ORDER BY hire_date) as prev_salary,
                LEAD(salary, 1) OVER (PARTITION BY department ORDER BY hire_date) as next_salary
            FROM employees
        """)
        
        print("窗口函数示例:")
        result1.filter("department = 'Engineering'").show()
        
        # 2. CTE (Common Table Expression)
        print("\n2. 公共表表达式 (CTE):")
        
        result2 = self.spark.sql("""
            WITH dept_stats AS (
                SELECT 
                    department,
                    AVG(salary) as avg_salary,
                    COUNT(*) as emp_count
                FROM employees
                GROUP BY department
            ),
            high_performing_depts AS (
                SELECT department
                FROM dept_stats
                WHERE avg_salary > 80000 AND emp_count >= 5
            )
            SELECT 
                e.name,
                e.department,
                e.salary,
                ds.avg_salary as dept_avg_salary,
                (e.salary / ds.avg_salary) as salary_ratio
            FROM employees e
            JOIN dept_stats ds ON e.department = ds.department
            JOIN high_performing_depts hpd ON e.department = hpd.department
            ORDER BY salary_ratio DESC
        """)
        
        print("高绩效部门员工分析:")
        result2.show()
        
        # 3. 复杂连接查询
        print("\n3. 复杂连接查询:")
        
        result3 = self.spark.sql("""
            SELECT 
                e.name,
                e.department,
                e.salary,
                d.manager,
                d.location,
                p.project_name,
                p.budget as project_budget,
                CASE 
                    WHEN e.salary > 100000 THEN 'Senior'
                    WHEN e.salary > 70000 THEN 'Mid-level'
                    ELSE 'Junior'
                END as level
            FROM employees e
            LEFT JOIN departments d ON e.department = d.dept_name
            LEFT JOIN projects p ON e.department = p.department
            WHERE e.salary > 60000
            ORDER BY e.salary DESC, p.budget DESC
        """)
        
        print("员工项目分配详情:")
        result3.show()
        
        # 4. 数据透视
        print("\n4. 数据透视 (使用SQL):")
        
        result4 = self.spark.sql("""
            SELECT 
                department,
                SUM(CASE WHEN gender = 'Male' THEN 1 ELSE 0 END) as male_count,
                SUM(CASE WHEN gender = 'Female' THEN 1 ELSE 0 END) as female_count,
                SUM(CASE WHEN age < 30 THEN 1 ELSE 0 END) as young_count,
                SUM(CASE WHEN age BETWEEN 30 AND 45 THEN 1 ELSE 0 END) as middle_age_count,
                SUM(CASE WHEN age > 45 THEN 1 ELSE 0 END) as senior_count
            FROM employees
            GROUP BY department
            ORDER BY department
        """)
        
        print("部门人员结构透视:")
        result4.show()
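
        # 补充示例(非原文内容,仅作对比参考):同样的透视也可以用DataFrame API的pivot()实现,
        # 这里假设以gender为透视列统计各部门人数
        pivot_df = self.spark.table("employees") \
            .groupBy("department") \
            .pivot("gender") \
            .count()
        print("\nDataFrame API pivot结果:")
        pivot_df.show()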
        
        return result1
    
    def demonstrate_sql_functions(self):
        """
        演示SQL内置函数
        """
        print("\n\nSQL内置函数演示:")
        print("=" * 18)
        
        # 1. 字符串函数
        print("\n1. 字符串函数:")
        
        result1 = self.spark.sql("""
            SELECT 
                name,
                UPPER(name) as name_upper,
                LOWER(name) as name_lower,
                LENGTH(name) as name_length,
                SUBSTRING(name, 1, 3) as name_prefix,
                CONCAT(name, ' - ', department) as name_dept,
                REPLACE(name, 'Employee_', 'Emp') as name_short
            FROM employees
            LIMIT 5
        """)
        
        print("字符串函数示例:")
        result1.show(truncate=False)
        
        # 2. 数值函数
        print("\n2. 数值函数:")
        
        result2 = self.spark.sql("""
            SELECT 
                name,
                salary,
                ROUND(salary / 12, 2) as monthly_salary,
                CEIL(salary / 1000) as salary_k_ceil,
                FLOOR(salary / 1000) as salary_k_floor,
                ABS(salary - 75000) as salary_diff_from_75k,
                POWER(salary / 1000, 2) as salary_k_squared
            FROM employees
            WHERE department = 'Engineering'
            LIMIT 5
        """)
        
        print("数值函数示例:")
        result2.show()
        
        # 3. 日期函数
        print("\n3. 日期函数:")
        
        result3 = self.spark.sql("""
            SELECT 
                name,
                hire_date,
                YEAR(hire_date) as hire_year,
                MONTH(hire_date) as hire_month,
                DAYOFWEEK(hire_date) as hire_day_of_week,
                DATEDIFF(CURRENT_DATE(), hire_date) as days_since_hire,
                DATE_ADD(hire_date, 365) as first_anniversary,
                DATE_FORMAT(hire_date, 'yyyy-MM') as hire_year_month
            FROM employees
            LIMIT 5
        """)
        
        print("日期函数示例:")
        result3.show()
        
        # 4. 条件函数
        print("\n4. 条件函数:")
        
        result4 = self.spark.sql("""
            SELECT 
                name,
                salary,
                age,
                CASE 
                    WHEN salary > 120000 THEN 'Executive'
                    WHEN salary > 90000 THEN 'Senior'
                    WHEN salary > 60000 THEN 'Mid-level'
                    ELSE 'Junior'
                END as salary_grade,
                IF(age > 40, 'Experienced', 'Young Professional') as experience_level,
                COALESCE(NULL, salary, 0) as salary_with_default,
                NULLIF(salary, 0) as salary_null_if_zero
            FROM employees
            LIMIT 10
        """)
        
        print("条件函数示例:")
        result4.show()
        
        return result1
    
    def demonstrate_performance_optimization(self):
        """
        演示SQL性能优化
        """
        print("\n\nSQL性能优化演示:")
        print("=" * 18)
        
        # 1. 查询计划分析
        print("\n1. 查询执行计划:")
        
        query = """
            SELECT 
                e.department,
                COUNT(*) as emp_count,
                AVG(e.salary) as avg_salary
            FROM employees e
            JOIN departments d ON e.department = d.dept_name
            GROUP BY e.department
            ORDER BY avg_salary DESC
        """
        
        result = self.spark.sql(query)
        
        print("查询结果:")
        result.show()
        
        print("\n执行计划:")
        result.explain(True)
        
        # 2. 缓存优化
        print("\n2. 缓存优化:")
        
        # 缓存常用表
        self.spark.sql("CACHE TABLE employees")
        self.spark.sql("CACHE TABLE departments")
        
        print("表已缓存")
        
        # 查看缓存状态(SHOW TABLES只列出表名,不含缓存信息,改用catalog.isCached检查)
        print("当前缓存状态:")
        for table_name in ["employees", "departments"]:
            is_cached = self.spark.catalog.isCached(table_name)
            print(f"- {table_name}: {'已缓存' if is_cached else '未缓存'}")
        
        # 3. 分区优化
        print("\n3. 分区优化建议:")
        
        partition_query = """
            SELECT 
                department,
                COUNT(*) as count
            FROM employees
            GROUP BY department
        """
        
        partition_result = self.spark.sql(partition_query)
        print("按部门分区的数据分布:")
        partition_result.show()
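
        # 补充示例(非原文内容):可以按连接/聚合键对DataFrame重新分区,减少后续shuffle
        emp_df = self.spark.table("employees")
        repartitioned_df = emp_df.repartition("department")
        print(f"按department重分区后的分区数: {repartitioned_df.rdd.getNumPartitions()}")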
        
        # 4. 广播连接优化
        print("\n4. 广播连接优化:")
        
        broadcast_query = """
            SELECT /*+ BROADCAST(d) */
                e.name,
                e.salary,
                d.manager,
                d.location
            FROM employees e
            JOIN departments d ON e.department = d.dept_name
            WHERE e.salary > 80000
        """
        
        broadcast_result = self.spark.sql(broadcast_query)
        print("广播连接结果:")
        broadcast_result.show()
        
        print("\n广播连接执行计划:")
        broadcast_result.explain()
        
        return result
    
    def cleanup(self):
        """
        清理资源
        """
        # 清除缓存
        self.spark.sql("UNCACHE TABLE IF EXISTS employees")
        self.spark.sql("UNCACHE TABLE IF EXISTS departments")
        self.spark.sql("UNCACHE TABLE IF EXISTS projects")
        
        print("缓存已清理")

# SQL查询演示
sql_demo = SparkSQLQueries()

print("\nSpark SQL查询演示")
print("=" * 20)

# 基础查询
basic_result = sql_demo.demonstrate_basic_sql_queries()

# 高级查询
advanced_result = sql_demo.demonstrate_advanced_sql_queries()

# 内置函数
functions_result = sql_demo.demonstrate_sql_functions()

# 性能优化
performance_result = sql_demo.demonstrate_performance_optimization()

print("\nSQL查询总结:")
print("=" * 15)
print(f"✓ 基础查询演示完成")
print(f"✓ 高级查询演示完成")
print(f"✓ 内置函数演示完成")
print(f"✓ 性能优化演示完成")

print("\n关键技术点:")
print("- 使用createOrReplaceTempView()注册临时视图")
print("- 支持标准SQL语法和高级特性")
print("- 窗口函数和CTE提供强大分析能力")
print("- 查询优化器自动优化执行计划")
print("- 缓存和广播连接提升性能")

# 清理资源
sql_demo.cleanup()

3.5 数据源操作

多种数据源支持

class DataSourceOperations:
    """
    数据源操作演示
    """
    
    def __init__(self):
        from pyspark.sql import SparkSession
        
        self.spark = SparkSession.builder \
            .appName("DataSources") \
            .config("spark.sql.adaptive.enabled", "true") \
            .getOrCreate()
        
        # 创建示例数据目录
        import os
        self.data_dir = "spark_data_examples"
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)
    
    def demonstrate_file_formats(self):
        """
        演示不同文件格式的读写
        """
        print("\n\n文件格式操作演示:")
        print("=" * 18)
        
        # 创建示例数据
        sample_data = [
            (1, "Alice", 25, "Engineering", 75000),
            (2, "Bob", 30, "Sales", 65000),
            (3, "Charlie", 35, "Marketing", 70000),
            (4, "Diana", 28, "HR", 60000),
            (5, "Eve", 32, "Finance", 80000)
        ]
        
        columns = ["id", "name", "age", "department", "salary"]
        df = self.spark.createDataFrame(sample_data, columns)
        
        print("原始数据:")
        df.show()
        
        # 1. CSV格式
        print("\n1. CSV格式操作:")
        
        csv_path = f"{self.data_dir}/employees.csv"
        
        # 写入CSV
        df.coalesce(1).write \
            .mode("overwrite") \
            .option("header", "true") \
            .csv(csv_path)
        
        # 读取CSV
        csv_df = self.spark.read \
            .option("header", "true") \
            .option("inferSchema", "true") \
            .csv(csv_path)
        
        print("从CSV读取的数据:")
        csv_df.show()
        csv_df.printSchema()
        
        # 2. JSON格式
        print("\n2. JSON格式操作:")
        
        json_path = f"{self.data_dir}/employees.json"
        
        # 写入JSON
        df.coalesce(1).write \
            .mode("overwrite") \
            .json(json_path)
        
        # 读取JSON
        json_df = self.spark.read.json(json_path)
        
        print("从JSON读取的数据:")
        json_df.show()
        json_df.printSchema()
        
        # 3. Parquet格式
        print("\n3. Parquet格式操作:")
        
        parquet_path = f"{self.data_dir}/employees.parquet"
        
        # 写入Parquet
        df.write \
            .mode("overwrite") \
            .parquet(parquet_path)
        
        # 读取Parquet
        parquet_df = self.spark.read.parquet(parquet_path)
        
        print("从Parquet读取的数据:")
        parquet_df.show()
        parquet_df.printSchema()
        
        # 4. 分区写入
        print("\n4. 分区写入:")
        
        partitioned_path = f"{self.data_dir}/employees_partitioned"
        
        # 按部门分区写入
        df.write \
            .mode("overwrite") \
            .partitionBy("department") \
            .parquet(partitioned_path)
        
        # 读取分区数据
        partitioned_df = self.spark.read.parquet(partitioned_path)
        
        print("分区数据:")
        partitioned_df.show()
        
        # 查看分区信息
        print("分区信息:")
        import os
        for root, dirs, files in os.walk(partitioned_path):
            level = root.replace(partitioned_path, '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files:
                print(f"{subindent}{file}")
        
        return df
    
    def demonstrate_database_operations(self):
        """
        演示数据库操作(模拟)
        """
        print("\n\n数据库操作演示:")
        print("=" * 18)
        
        # 注意:这里演示的是概念,实际使用需要真实的数据库连接
        
        # 1. JDBC连接配置示例
        print("\n1. JDBC连接配置:")
        
        jdbc_config = {
            "url": "jdbc:postgresql://localhost:5432/mydb",
            "driver": "org.postgresql.Driver",
            "user": "username",
            "password": "password"
        }
        
        print("JDBC配置示例:")
        for key, value in jdbc_config.items():
            if key != "password":
                print(f"  {key}: {value}")
            else:
                print(f"  {key}: {'*' * len(value)}")
        
        # 2. 读取数据库表的代码示例
        print("\n2. 读取数据库表:")
        
        read_code = r'''
# 从数据库读取数据
df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/mydb") \
    .option("dbtable", "employees") \
    .option("user", "username") \
    .option("password", "password") \
    .option("driver", "org.postgresql.Driver") \
    .load()

# 使用查询读取
query_df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/mydb") \
    .option("query", "SELECT * FROM employees WHERE salary > 70000") \
    .option("user", "username") \
    .option("password", "password") \
    .load()
        '''
        
        print("代码示例:")
        print(read_code)
        
        # 3. 写入数据库的代码示例
        print("\n3. 写入数据库:")
        
        write_code = r'''
# 写入数据库
df.write \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/mydb") \
    .option("dbtable", "employees_backup") \
    .option("user", "username") \
    .option("password", "password") \
    .option("driver", "org.postgresql.Driver") \
    .mode("overwrite") \
    .save()

# 追加模式写入
df.write \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/mydb") \
    .option("dbtable", "employees_log") \
    .option("user", "username") \
    .option("password", "password") \
    .mode("append") \
    .save()
        '''
        
        print("代码示例:")
        print(write_code)
        
        # 4. 数据库连接优化
        print("\n4. 数据库连接优化:")
        
        optimization_tips = [
            "使用连接池减少连接开销",
            "设置合适的fetchsize提高读取性能",
            "使用分区读取大表数据",
            "批量写入提高写入性能",
            "使用预处理语句防止SQL注入"
        ]
        
        print("优化建议:")
        for i, tip in enumerate(optimization_tips, 1):
            print(f"  {i}. {tip}")
        
        optimized_code = r'''
# 优化的数据库读取
df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:postgresql://localhost:5432/mydb") \
    .option("dbtable", "employees") \
    .option("user", "username") \
    .option("password", "password") \
    .option("driver", "org.postgresql.Driver") \
    .option("fetchsize", "1000") \
    .option("partitionColumn", "id") \
    .option("lowerBound", "1") \
    .option("upperBound", "10000") \
    .option("numPartitions", "4") \
    .load()
        '''
        
        print("\n优化代码示例:")
        print(optimized_code)
    
    def demonstrate_streaming_sources(self):
        """
        演示流数据源(概念演示)
        """
        print("\n\n流数据源演示:")
        print("=" * 15)
        
        # 1. 文件流
        print("\n1. 文件流监控:")
        
        file_stream_code = r'''
# 监控目录中的新文件
file_stream = spark.readStream \
    .format("csv") \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .schema(schema) \
    .load("/path/to/streaming/directory")

# 处理流数据
processed_stream = file_stream \
    .filter(col("salary") > 70000) \
    .groupBy("department") \
    .agg(avg("salary").alias("avg_salary"))

# 输出流
query = processed_stream.writeStream \
    .outputMode("complete") \
    .format("console") \
    .trigger(processingTime='10 seconds') \
    .start()
        '''
        
        print("文件流代码示例:")
        print(file_stream_code)
        
        # 2. Kafka流
        print("\n2. Kafka流:")
        
        kafka_stream_code = r'''
# 从Kafka读取流数据
kafka_stream = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "employee_updates") \
    .option("startingOffsets", "latest") \
    .load()

# 解析Kafka消息
parsed_stream = kafka_stream.select(
    col("key").cast("string"),
    from_json(col("value").cast("string"), schema).alias("data")
).select("key", "data.*")

# 写入Kafka
output_query = processed_data.writeStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("topic", "processed_employees") \
    .option("checkpointLocation", "/path/to/checkpoint") \
    .start()
        '''
        
        print("Kafka流代码示例:")
        print(kafka_stream_code)
        
        # 3. Socket流
        print("\n3. Socket流:")
        
        socket_stream_code = r'''
# 从Socket读取流数据
socket_stream = spark.readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

# 处理文本数据
word_counts = socket_stream \
    .select(explode(split(col("value"), " ")).alias("word")) \
    .groupBy("word") \
    .count()

# 输出到控制台
query = word_counts.writeStream \
    .outputMode("complete") \
    .format("console") \
    .start()
        '''
        
        print("Socket流代码示例:")
        print(socket_stream_code)
    
    def demonstrate_data_source_options(self):
        """
        演示数据源选项和配置
        """
        print("\n\n数据源选项演示:")
        print("=" * 18)
        
        # 创建示例数据
        sample_data = [
            (1, "Alice", 25, "Engineering", 75000),
            (2, "Bob", 30, "Sales", 65000),
            (3, "Charlie", 35, "Marketing", 70000)
        ]
        
        columns = ["id", "name", "age", "department", "salary"]
        df = self.spark.createDataFrame(sample_data, columns)
        
        # 1. CSV选项
        print("\n1. CSV读写选项:")
        
        csv_options = {
            "header": "true",
            "inferSchema": "true",
            "sep": ",",
            "quote": '"',
            "escape": "\\",
            "nullValue": "NULL",
            "dateFormat": "yyyy-MM-dd",
            "timestampFormat": "yyyy-MM-dd HH:mm:ss",
            "encoding": "UTF-8"
        }
        
        print("CSV选项:")
        for key, value in csv_options.items():
            print(f"  {key}: {value}")
        
        # 写入带选项的CSV
        csv_path = f"{self.data_dir}/employees_with_options.csv"
        
        df.coalesce(1).write \
            .mode("overwrite") \
            .options(**csv_options) \
            .csv(csv_path)
        
        # 读取带选项的CSV
        csv_df = self.spark.read \
            .options(**csv_options) \
            .csv(csv_path)
        
        print("\n读取结果:")
        csv_df.show()
        
        # 2. JSON选项
        print("\n2. JSON读写选项:")
        
        json_options = {
            "multiLine": "true",
            "allowComments": "true",
            "allowUnquotedFieldNames": "true",
            "allowSingleQuotes": "true",
            "allowNumericLeadingZeros": "true"
        }
        
        print("JSON选项:")
        for key, value in json_options.items():
            print(f"  {key}: {value}")
        
        # 3. Parquet选项
        print("\n3. Parquet读写选项:")
        
        parquet_options = {
            "compression": "snappy",  # gzip, lzo, brotli, lz4, zstd
            "mergeSchema": "true"
        }
        
        print("Parquet选项:")
        for key, value in parquet_options.items():
            print(f"  {key}: {value}")
        
        # 写入压缩的Parquet
        parquet_path = f"{self.data_dir}/employees_compressed.parquet"
        
        df.write \
            .mode("overwrite") \
            .options(**parquet_options) \
            .parquet(parquet_path)
        
        print("\n压缩Parquet文件已创建")
        
        # 4. 性能对比
        print("\n4. 文件格式性能对比:")
        
        import time
        import os
        
        formats = ["csv", "json", "parquet"]
        performance_data = []
        
        for fmt in formats:
            path = f"{self.data_dir}/employees.{fmt}"
            
            # 写入时间
            start_time = time.time()
            if fmt == "csv":
                df.coalesce(1).write.mode("overwrite").option("header", "true").csv(path)
            elif fmt == "json":
                df.coalesce(1).write.mode("overwrite").json(path)
            else:
                df.write.mode("overwrite").parquet(path)
            write_time = time.time() - start_time
            
            # 读取时间
            start_time = time.time()
            if fmt == "csv":
                read_df = self.spark.read.option("header", "true").option("inferSchema", "true").csv(path)
            elif fmt == "json":
                read_df = self.spark.read.json(path)
            else:
                read_df = self.spark.read.parquet(path)
            read_df.count()  # 触发实际读取
            read_time = time.time() - start_time
            
            # 文件大小
            file_size = 0
            for root, dirs, files in os.walk(path):
                for file in files:
                    file_path = os.path.join(root, file)
                    file_size += os.path.getsize(file_path)
            
            performance_data.append((fmt, write_time, read_time, file_size))
        
        print("\n性能对比结果:")
        print(f"{'格式':<10} {'写入时间(s)':<12} {'读取时间(s)':<12} {'文件大小(bytes)':<15}")
        print("-" * 55)
        for fmt, write_t, read_t, size in performance_data:
            print(f"{fmt:<10} {write_t:<12.4f} {read_t:<12.4f} {size:<15}")
        
        return df
    
    def cleanup(self):
        """
        清理示例文件
        """
        import shutil
        import os
        
        if os.path.exists(self.data_dir):
            shutil.rmtree(self.data_dir)
            print(f"\n已清理示例数据目录: {self.data_dir}")

# 数据源操作演示
data_source_demo = DataSourceOperations()

print("\n数据源操作演示")
print("=" * 18)

# 文件格式操作
file_result = data_source_demo.demonstrate_file_formats()

# 数据库操作
data_source_demo.demonstrate_database_operations()

# 流数据源
data_source_demo.demonstrate_streaming_sources()

# 数据源选项
options_result = data_source_demo.demonstrate_data_source_options()

print("\n数据源操作总结:")
print("=" * 18)
print(f"✓ 文件格式操作演示完成")
print(f"✓ 数据库操作概念演示完成")
print(f"✓ 流数据源概念演示完成")
print(f"✓ 数据源选项演示完成")

print("\n关键技术点:")
print("- 支持CSV、JSON、Parquet等多种文件格式")
print("- JDBC连接支持各种关系型数据库")
print("- 流处理支持实时数据处理")
print("- 丰富的读写选项优化性能")
print("- 分区和压缩提高存储效率")

# 清理资源
data_source_demo.cleanup()

3.6 性能优化

Spark SQL性能优化策略

class SparkSQLPerformanceOptimizer:
    """
    Spark SQL性能优化演示
    """
    
    def __init__(self):
        from pyspark.sql import SparkSession
        
        self.spark = SparkSession.builder \
            .appName("SparkSQLOptimization") \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
            .config("spark.sql.adaptive.skewJoin.enabled", "true") \
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
            .getOrCreate()
        
        # 创建大数据集用于性能测试
        self.setup_large_datasets()
    
    def setup_large_datasets(self):
        """
        创建大数据集用于性能测试
        """
        import random
        from datetime import datetime, timedelta
        
        print("创建大数据集用于性能测试...")
        
        # 创建大型员工数据集
        employees_data = []
        departments = ["Engineering", "Sales", "Marketing", "HR", "Finance", "Operations", "Legal", "Support"]
        locations = ["New York", "San Francisco", "Chicago", "Boston", "Seattle", "Austin", "Denver", "Atlanta"]
        
        for i in range(100000):  # 10万条记录
            employees_data.append((
                i + 1,
                f"Employee_{i+1}",
                random.choice(departments),
                random.randint(40000, 200000),
                random.randint(22, 65),
                random.choice(["Male", "Female"]),
                random.choice(locations),
                (datetime.now() - timedelta(days=random.randint(30, 3650))).strftime("%Y-%m-%d")
            ))
        
        employees_columns = ["emp_id", "name", "department", "salary", "age", "gender", "location", "hire_date"]
        self.large_employees_df = self.spark.createDataFrame(employees_data, employees_columns)
        self.large_employees_df.createOrReplaceTempView("large_employees")
        
        # 创建订单数据集
        orders_data = []
        for i in range(500000):  # 50万条订单记录
            orders_data.append((
                i + 1,
                random.randint(1, 100000),  # 员工ID
                random.randint(100, 10000),  # 订单金额
                random.choice(["Pending", "Completed", "Cancelled", "Shipped"]),
                (datetime.now() - timedelta(days=random.randint(1, 365))).strftime("%Y-%m-%d")
            ))
        
        orders_columns = ["order_id", "emp_id", "amount", "status", "order_date"]
        self.orders_df = self.spark.createDataFrame(orders_data, orders_columns)
        self.orders_df.createOrReplaceTempView("orders")
        
        print(f"大型员工数据集: {self.large_employees_df.count():,} 条记录")
        print(f"订单数据集: {self.orders_df.count():,} 条记录")
    
    def demonstrate_caching_optimization(self):
        """
        演示缓存优化
        """
        print("\n\n缓存优化演示:")
        print("=" * 15)
        
        import time
        
        # 1. 不使用缓存的查询
        print("\n1. 不使用缓存的重复查询:")
        
        query = """
            SELECT 
                department,
                COUNT(*) as emp_count,
                AVG(salary) as avg_salary,
                MAX(salary) as max_salary
            FROM large_employees
            WHERE salary > 60000
            GROUP BY department
            ORDER BY avg_salary DESC
        """
        
        # 第一次查询
        start_time = time.time()
        result1 = self.spark.sql(query)
        result1.show()
        first_query_time = time.time() - start_time
        
        # 第二次查询(无缓存)
        start_time = time.time()
        result2 = self.spark.sql(query)
        result2.show()
        second_query_time = time.time() - start_time
        
        print(f"第一次查询时间: {first_query_time:.2f}秒")
        print(f"第二次查询时间: {second_query_time:.2f}秒")
        
        # 2. 使用缓存的查询
        print("\n2. 使用缓存的重复查询:")
        
        # 缓存表
        self.spark.sql("CACHE TABLE large_employees")
        
        # 第一次查询(触发缓存)
        start_time = time.time()
        result3 = self.spark.sql(query)
        result3.show()
        first_cached_time = time.time() - start_time
        
        # 第二次查询(使用缓存)
        start_time = time.time()
        result4 = self.spark.sql(query)
        result4.show()
        second_cached_time = time.time() - start_time
        
        print(f"第一次缓存查询时间: {first_cached_time:.2f}秒")
        print(f"第二次缓存查询时间: {second_cached_time:.2f}秒")
        
        # 3. 缓存性能对比
        print("\n3. 缓存性能对比:")
        print(f"无缓存平均时间: {(first_query_time + second_query_time) / 2:.2f}秒")
        print(f"有缓存平均时间: {(first_cached_time + second_cached_time) / 2:.2f}秒")
        print(f"性能提升: {((first_query_time + second_query_time) / 2) / ((first_cached_time + second_cached_time) / 2):.2f}倍")
        
        # 4. 查看缓存状态(SHOW TABLES不包含缓存信息,用catalog.isCached检查)
        print("\n4. 缓存状态:")
        for table_name in ["large_employees", "orders"]:
            is_cached = self.spark.catalog.isCached(table_name)
            print(f"表 {table_name}: {'已缓存' if is_cached else '未缓存'}")
    
    def demonstrate_join_optimization(self):
        """
        演示连接优化
        """
        print("\n\n连接优化演示:")
        print("=" * 15)
        
        import time
        
        # 1. 普通连接
        print("\n1. 普通连接查询:")
        
        normal_join_query = """
            SELECT 
                e.name,
                e.department,
                e.salary,
                COUNT(o.order_id) as order_count,
                SUM(o.amount) as total_amount
            FROM large_employees e
            LEFT JOIN orders o ON e.emp_id = o.emp_id
            WHERE e.salary > 80000
            GROUP BY e.emp_id, e.name, e.department, e.salary
            ORDER BY total_amount DESC
            LIMIT 20
        """
        
        start_time = time.time()
        normal_result = self.spark.sql(normal_join_query)
        normal_result.show()
        normal_join_time = time.time() - start_time
        
        print(f"普通连接时间: {normal_join_time:.2f}秒")
        
        # 查看执行计划
        print("\n普通连接执行计划:")
        normal_result.explain()
        
        # 2. 广播连接优化
        print("\n2. 广播连接优化:")
        
        # 创建小表用于广播
        dept_info = [
            ("Engineering", "Technology", "High"),
            ("Sales", "Business", "Medium"),
            ("Marketing", "Business", "Medium"),
            ("HR", "Support", "Low"),
            ("Finance", "Support", "Medium"),
            ("Operations", "Operations", "Medium"),
            ("Legal", "Support", "Low"),
            ("Support", "Support", "Low")
        ]
        
        dept_columns = ["dept_name", "division", "priority"]
        dept_df = self.spark.createDataFrame(dept_info, dept_columns)
        dept_df.createOrReplaceTempView("dept_info")
        
        broadcast_join_query = """
            SELECT /*+ BROADCAST(d) */
                e.name,
                e.department,
                e.salary,
                d.division,
                d.priority,
                COUNT(o.order_id) as order_count
            FROM large_employees e
            JOIN dept_info d ON e.department = d.dept_name
            LEFT JOIN orders o ON e.emp_id = o.emp_id
            WHERE e.salary > 80000 AND d.priority = 'High'
            GROUP BY e.emp_id, e.name, e.department, e.salary, d.division, d.priority
            ORDER BY order_count DESC
            LIMIT 20
        """
        
        start_time = time.time()
        broadcast_result = self.spark.sql(broadcast_join_query)
        broadcast_result.show()
        broadcast_join_time = time.time() - start_time
        
        print(f"广播连接时间: {broadcast_join_time:.2f}秒")
        
        # 查看执行计划
        print("\n广播连接执行计划:")
        broadcast_result.explain()
        
        # 3. 分区连接优化
        print("\n3. 分区连接优化建议:")
        
        partition_tips = [
            "按连接键分区数据",
            "确保分区数量合理(通常为CPU核心数的2-3倍)",
            "避免数据倾斜",
            "使用bucketing预分区数据",
            "考虑使用范围分区"
        ]
        
        print("分区优化建议:")
        for i, tip in enumerate(partition_tips, 1):
            print(f"  {i}. {tip}")
    
    def demonstrate_predicate_pushdown(self):
        """
        演示谓词下推优化
        """
        print("\n\n谓词下推优化演示:")
        print("=" * 18)
        
        # 1. 低效查询(过滤在连接后)
        print("\n1. 低效查询(过滤在连接后):")
        
        inefficient_query = """
            SELECT 
                e.name,
                e.department,
                e.salary,
                o.amount,
                o.status
            FROM large_employees e
            JOIN orders o ON e.emp_id = o.emp_id
            WHERE e.salary > 100000 AND o.amount > 5000 AND o.status = 'Completed'
        """
        
        print("低效查询执行计划:")
        inefficient_result = self.spark.sql(inefficient_query)
        inefficient_result.explain()
        
        # 2. 高效查询(显式提前过滤;Catalyst优化器通常也会自动做谓词下推,这里显式写出便于对比执行计划)
        print("\n2. 高效查询(显式提前过滤,对应谓词下推):")
        
        efficient_query = """
            WITH filtered_employees AS (
                SELECT emp_id, name, department, salary
                FROM large_employees
                WHERE salary > 100000
            ),
            filtered_orders AS (
                SELECT emp_id, amount, status
                FROM orders
                WHERE amount > 5000 AND status = 'Completed'
            )
            SELECT 
                e.name,
                e.department,
                e.salary,
                o.amount,
                o.status
            FROM filtered_employees e
            JOIN filtered_orders o ON e.emp_id = o.emp_id
        """
        
        print("高效查询执行计划:")
        efficient_result = self.spark.sql(efficient_query)
        efficient_result.explain()
        
        # 3. 性能对比
        print("\n3. 性能对比:")
        
        import time
        
        # 低效查询时间
        start_time = time.time()
        inefficient_count = inefficient_result.count()
        inefficient_time = time.time() - start_time
        
        # 高效查询时间
        start_time = time.time()
        efficient_count = efficient_result.count()
        efficient_time = time.time() - start_time
        
        print(f"低效查询结果数: {inefficient_count:,}")
        print(f"低效查询时间: {inefficient_time:.2f}秒")
        print(f"高效查询结果数: {efficient_count:,}")
        print(f"高效查询时间: {efficient_time:.2f}秒")
        
        if efficient_time > 0:
            print(f"性能提升: {inefficient_time / efficient_time:.2f}倍")
    
    def demonstrate_columnar_optimization(self):
        """
        演示列式存储优化
        """
        print("\n\n列式存储优化演示:")
        print("=" * 18)
        
        # 1. 列选择优化
        print("\n1. 列选择优化:")
        
        # 低效:选择所有列
        print("低效查询(选择所有列):")
        inefficient_columns_query = """
            SELECT *
            FROM large_employees
            WHERE department = 'Engineering'
        """
        
        inefficient_columns_result = self.spark.sql(inefficient_columns_query)
        print(f"选择所有列的结果: {inefficient_columns_result.count():,} 行")
        
        # 高效:只选择需要的列
        print("\n高效查询(只选择需要的列):")
        efficient_columns_query = """
            SELECT name, salary
            FROM large_employees
            WHERE department = 'Engineering'
        """
        
        efficient_columns_result = self.spark.sql(efficient_columns_query)
        print(f"选择特定列的结果: {efficient_columns_result.count():,} 行")
        print("(两者行数相同,但只读取需要的列可减少扫描量,配合Parquet等列式格式效果更明显)")
        
        # 2. 聚合优化
        print("\n2. 聚合优化:")
        
        aggregation_query = """
            SELECT 
                department,
                location,
                COUNT(*) as emp_count,
                AVG(salary) as avg_salary,
                MIN(salary) as min_salary,
                MAX(salary) as max_salary
            FROM large_employees
            GROUP BY department, location
            HAVING COUNT(*) > 100
            ORDER BY avg_salary DESC
        """
        
        print("聚合查询结果:")
        agg_result = self.spark.sql(aggregation_query)
        agg_result.show(20)
        
        print("\n聚合查询执行计划:")
        agg_result.explain()
    
    def demonstrate_adaptive_query_execution(self):
        """
        演示自适应查询执行
        """
        print("\n\n自适应查询执行演示:")
        print("=" * 20)
        
        # 1. 自适应配置
        print("\n1. 自适应查询执行配置:")
        
        aqe_configs = {
            "spark.sql.adaptive.enabled": "true",
            "spark.sql.adaptive.coalescePartitions.enabled": "true",
            "spark.sql.adaptive.skewJoin.enabled": "true",
            "spark.sql.adaptive.localShuffleReader.enabled": "true",
            "spark.sql.adaptive.coalescePartitions.minPartitionNum": "1",
            "spark.sql.adaptive.advisoryPartitionSizeInBytes": "64MB"
        }
        
        print("自适应查询执行配置:")
        for config, value in aqe_configs.items():
            print(f"  {config}: {value}")
        
        # 2. 复杂查询测试
        print("\n2. 复杂查询测试:")
        
        complex_query = """
            WITH employee_stats AS (
                SELECT 
                    department,
                    location,
                    COUNT(*) as emp_count,
                    AVG(salary) as avg_salary
                FROM large_employees
                GROUP BY department, location
            ),
            order_stats AS (
                SELECT 
                    emp_id,
                    COUNT(*) as order_count,
                    SUM(amount) as total_amount
                FROM orders
                WHERE status = 'Completed'
                GROUP BY emp_id
            )
            SELECT 
                es.department,
                es.location,
                es.emp_count,
                es.avg_salary,
                AVG(os.order_count) as avg_orders_per_emp,
                AVG(os.total_amount) as avg_amount_per_emp
            FROM employee_stats es
            LEFT JOIN large_employees e ON es.department = e.department AND es.location = e.location
            LEFT JOIN order_stats os ON e.emp_id = os.emp_id
            GROUP BY es.department, es.location, es.emp_count, es.avg_salary
            HAVING es.emp_count > 50
            ORDER BY avg_amount_per_emp DESC NULLS LAST
        """
        
        print("执行复杂查询...")
        complex_result = self.spark.sql(complex_query)
        complex_result.show()
        
        print("\n复杂查询执行计划:")
        complex_result.explain(True)
    
    def demonstrate_memory_optimization(self):
        """
        演示内存优化
        """
        print("\n\n内存优化演示:")
        print("=" * 15)
        
        # 1. 内存配置建议
        print("\n1. 内存配置建议:")
        
        memory_configs = {
            "spark.executor.memory": "4g",
            "spark.executor.memoryFraction": "0.8",
            "spark.sql.execution.arrow.pyspark.enabled": "true",
            "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
            "spark.sql.adaptive.coalescePartitions.enabled": "true"
        }
        
        print("内存优化配置:")
        for config, value in memory_configs.items():
            print(f"  {config}: {value}")
        
        # 2. 内存使用监控
        print("\n2. 内存使用监控:")
        
        # 获取当前内存使用情况
        storage_level_info = [
            "MEMORY_ONLY: 仅内存存储,速度最快,放不下的分区会被丢弃并在需要时重算",
            "MEMORY_AND_DISK: 优先放内存,放不下的分区溢写到磁盘",
            "MEMORY_ONLY_SER: 以序列化形式存内存,更省空间但读取需反序列化",
            "DISK_ONLY: 仅磁盘存储,速度最慢,但几乎不占用内存"
        ]
        
        print("存储级别选择:")
        for info in storage_level_info:
            print(f"  - {info}")
        
        # 3. 垃圾回收优化
        print("\n3. 垃圾回收优化建议:")
        
        gc_tips = [
            "使用G1GC垃圾回收器",
            "设置合适的新生代大小",
            "监控GC频率和时间",
            "避免创建大量小对象",
            "使用对象池重用对象"
        ]
        
        print("垃圾回收优化:")
        for i, tip in enumerate(gc_tips, 1):
            print(f"  {i}. {tip}")
    
    def cleanup(self):
        """
        清理缓存和资源
        """
        # 清除所有缓存
        self.spark.sql("UNCACHE TABLE IF EXISTS large_employees")
        self.spark.sql("UNCACHE TABLE IF EXISTS orders")
        self.spark.sql("UNCACHE TABLE IF EXISTS dept_info")
        
        print("\n所有缓存已清理")

# 性能优化演示
optimizer = SparkSQLPerformanceOptimizer()

print("\nSpark SQL性能优化演示")
print("=" * 25)

# 缓存优化
optimizer.demonstrate_caching_optimization()

# 连接优化
optimizer.demonstrate_join_optimization()

# 谓词下推
optimizer.demonstrate_predicate_pushdown()

# 列式存储优化
optimizer.demonstrate_columnar_optimization()

# 自适应查询执行
optimizer.demonstrate_adaptive_query_execution()

# 内存优化
optimizer.demonstrate_memory_optimization()

print("\n性能优化总结:")
print("=" * 18)
print(f"✓ 缓存优化演示完成")
print(f"✓ 连接优化演示完成")
print(f"✓ 谓词下推优化演示完成")
print(f"✓ 列式存储优化演示完成")
print(f"✓ 自适应查询执行演示完成")
print(f"✓ 内存优化建议完成")

print("\n关键优化技术:")
print("- 合理使用缓存减少重复计算")
print("- 广播小表优化连接性能")
print("- 谓词下推减少数据传输")
print("- 列式存储提高查询效率")
print("- 自适应执行动态优化查询")
print("- 内存配置优化资源使用")

# 清理资源
optimizer.cleanup()

3.7 实际案例:电商数据分析系统

综合案例演示

class ECommerceAnalyticsSystem:
    """
    电商数据分析系统 - Spark SQL综合案例
    """
    
    def __init__(self):
        from pyspark.sql import SparkSession
        
        self.spark = SparkSession.builder \
            .appName("ECommerceAnalytics") \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
            .getOrCreate()
        
        # 创建电商数据
        self.create_ecommerce_data()
    
    def create_ecommerce_data(self):
        """
        创建电商业务数据
        """
        import random
        from datetime import datetime, timedelta
        from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
        
        print("创建电商业务数据...")
        
        # 1. 用户数据
        users_data = []
        cities = ["北京", "上海", "广州", "深圳", "杭州", "南京", "成都", "武汉"]
        age_groups = ["18-25", "26-35", "36-45", "46-55", "55+"]
        
        for i in range(10000):
            users_data.append((
                i + 1,
                f"user_{i+1}@email.com",
                random.choice(["男", "女"]),
                random.choice(age_groups),
                random.choice(cities),
                (datetime.now() - timedelta(days=random.randint(30, 1095))).strftime("%Y-%m-%d")
            ))
        
        users_schema = StructType([
            StructField("user_id", IntegerType(), True),
            StructField("email", StringType(), True),
            StructField("gender", StringType(), True),
            StructField("age_group", StringType(), True),
            StructField("city", StringType(), True),
            StructField("register_date", StringType(), True)
        ])
        
        self.users_df = self.spark.createDataFrame(users_data, users_schema)
        self.users_df.createOrReplaceTempView("users")
        
        # 2. 商品数据
        products_data = []
        categories = ["电子产品", "服装", "家居", "图书", "运动", "美妆", "食品", "母婴"]
        brands = ["品牌A", "品牌B", "品牌C", "品牌D", "品牌E"]
        
        for i in range(1000):
            products_data.append((
                i + 1,
                f"商品_{i+1}",
                random.choice(categories),
                random.choice(brands),
                round(random.uniform(10, 5000), 2),
                random.randint(0, 1000)
            ))
        
        products_schema = StructType([
            StructField("product_id", IntegerType(), True),
            StructField("product_name", StringType(), True),
            StructField("category", StringType(), True),
            StructField("brand", StringType(), True),
            StructField("price", DoubleType(), True),
            StructField("stock", IntegerType(), True)
        ])
        
        self.products_df = self.spark.createDataFrame(products_data, products_schema)
        self.products_df.createOrReplaceTempView("products")
        
        # 3. 订单数据
        orders_data = []
        statuses = ["已下单", "已支付", "已发货", "已完成", "已取消"]
        
        for i in range(50000):
            order_date = datetime.now() - timedelta(days=random.randint(1, 365))
            orders_data.append((
                i + 1,
                random.randint(1, 10000),  # user_id
                order_date.strftime("%Y-%m-%d"),
                random.choice(statuses),
                round(random.uniform(50, 2000), 2)
            ))
        
        orders_schema = StructType([
            StructField("order_id", IntegerType(), True),
            StructField("user_id", IntegerType(), True),
            StructField("order_date", StringType(), True),
            StructField("status", StringType(), True),
            StructField("total_amount", DoubleType(), True)
        ])
        
        self.orders_df = self.spark.createDataFrame(orders_data, orders_schema)
        self.orders_df.createOrReplaceTempView("orders")
        
        # 4. 订单详情数据
        order_items_data = []
        for i in range(100000):
            order_items_data.append((
                i + 1,
                random.randint(1, 50000),  # order_id
                random.randint(1, 1000),   # product_id
                random.randint(1, 5),      # quantity
                round(random.uniform(10, 500), 2)  # unit_price
            ))
        
        order_items_schema = StructType([
            StructField("item_id", IntegerType(), True),
            StructField("order_id", IntegerType(), True),
            StructField("product_id", IntegerType(), True),
            StructField("quantity", IntegerType(), True),
            StructField("unit_price", DoubleType(), True)
        ])
        
        self.order_items_df = self.spark.createDataFrame(order_items_data, order_items_schema)
        self.order_items_df.createOrReplaceTempView("order_items")
        
        print(f"用户数据: {self.users_df.count():,} 条")
        print(f"商品数据: {self.products_df.count():,} 条")
        print(f"订单数据: {self.orders_df.count():,} 条")
        print(f"订单详情数据: {self.order_items_df.count():,} 条")
    
    def analyze_user_behavior(self):
        """
        用户行为分析
        """
        print("\n\n用户行为分析:")
        print("=" * 15)
        
        # 1. 用户注册趋势
        print("\n1. 用户注册趋势:")
        
        registration_trend = self.spark.sql("""
            SELECT 
                YEAR(register_date) as year,
                MONTH(register_date) as month,
                COUNT(*) as new_users
            FROM users
            GROUP BY YEAR(register_date), MONTH(register_date)
            ORDER BY year, month
        """)
        
        print("用户注册趋势:")
        registration_trend.show()
        
        # 2. 用户地域分布
        print("\n2. 用户地域分布:")
        
        city_distribution = self.spark.sql("""
            SELECT 
                city,
                COUNT(*) as user_count,
                ROUND(COUNT(*) * 100.0 / (SELECT COUNT(*) FROM users), 2) as percentage
            FROM users
            GROUP BY city
            ORDER BY user_count DESC
        """)
        
        print("用户地域分布:")
        city_distribution.show()
        
        # 3. 用户年龄分布
        print("\n3. 用户年龄分布:")
        
        age_distribution = self.spark.sql("""
            SELECT 
                age_group,
                gender,
                COUNT(*) as user_count
            FROM users
            GROUP BY age_group, gender
            ORDER BY 
                CASE age_group
                    WHEN '18-25' THEN 1
                    WHEN '26-35' THEN 2
                    WHEN '36-45' THEN 3
                    WHEN '46-55' THEN 4
                    WHEN '55+' THEN 5
                END,
                gender
        """)
        
        print("用户年龄分布:")
        age_distribution.show()
        
        # 4. 活跃用户分析
        print("\n4. 活跃用户分析:")
        
        active_users = self.spark.sql("""
            WITH user_orders AS (
                SELECT 
                    u.user_id,
                    u.city,
                    u.age_group,
                    COUNT(o.order_id) as order_count,
                    SUM(o.total_amount) as total_spent,
                    MAX(o.order_date) as last_order_date,
                    DATEDIFF(CURRENT_DATE(), MAX(o.order_date)) as days_since_last_order
                FROM users u
                LEFT JOIN orders o ON u.user_id = o.user_id
                GROUP BY u.user_id, u.city, u.age_group
            )
            SELECT 
                CASE 
                    WHEN order_count = 0 THEN '未购买用户'
                    WHEN days_since_last_order <= 30 THEN '活跃用户'
                    WHEN days_since_last_order <= 90 THEN '一般用户'
                    ELSE '流失用户'
                END as user_type,
                COUNT(*) as user_count,
                AVG(order_count) as avg_orders,
                AVG(total_spent) as avg_spent
            FROM user_orders
            GROUP BY 
                CASE 
                    WHEN order_count = 0 THEN '未购买用户'
                    WHEN days_since_last_order <= 30 THEN '活跃用户'
                    WHEN days_since_last_order <= 90 THEN '一般用户'
                    ELSE '流失用户'
                END
            ORDER BY user_count DESC
        """)
        
        print("活跃用户分析:")
        active_users.show()
        
        return registration_trend
    
    def analyze_sales_performance(self):
        """
        销售业绩分析
        """
        print("\n\n销售业绩分析:")
        print("=" * 15)
        
        # 1. 月度销售趋势
        print("\n1. 月度销售趋势:")
        
        monthly_sales = self.spark.sql("""
            SELECT 
                YEAR(order_date) as year,
                MONTH(order_date) as month,
                COUNT(DISTINCT order_id) as order_count,
                COUNT(DISTINCT user_id) as customer_count,
                SUM(total_amount) as total_revenue,
                AVG(total_amount) as avg_order_value
            FROM orders
            WHERE status IN ('已支付', '已发货', '已完成')
            GROUP BY YEAR(order_date), MONTH(order_date)
            ORDER BY year, month
        """)
        
        print("月度销售趋势:")
        monthly_sales.show()
        
        # 2. 商品类别销售分析
        print("\n2. 商品类别销售分析:")
        
        category_sales = self.spark.sql("""
            SELECT 
                p.category,
                COUNT(DISTINCT oi.order_id) as order_count,
                SUM(oi.quantity) as total_quantity,
                SUM(oi.quantity * oi.unit_price) as total_revenue,
                AVG(oi.unit_price) as avg_price
            FROM order_items oi
            JOIN products p ON oi.product_id = p.product_id
            JOIN orders o ON oi.order_id = o.order_id
            WHERE o.status IN ('已支付', '已发货', '已完成')
            GROUP BY p.category
            ORDER BY total_revenue DESC
        """)
        
        print("商品类别销售分析:")
        category_sales.show()
        
        # 3. 品牌销售排行
        print("\n3. 品牌销售排行:")
        
        brand_ranking = self.spark.sql("""
            SELECT 
                p.brand,
                p.category,
                COUNT(DISTINCT oi.order_id) as order_count,
                SUM(oi.quantity) as total_quantity,
                SUM(oi.quantity * oi.unit_price) as total_revenue,
                RANK() OVER (PARTITION BY p.category ORDER BY SUM(oi.quantity * oi.unit_price) DESC) as category_rank
            FROM order_items oi
            JOIN products p ON oi.product_id = p.product_id
            JOIN orders o ON oi.order_id = o.order_id
            WHERE o.status IN ('已支付', '已发货', '已完成')
            GROUP BY p.brand, p.category
            ORDER BY p.category, category_rank
        """)
        
        print("品牌销售排行:")
        brand_ranking.show(20)
        
        # 4. 热销商品分析
        print("\n4. 热销商品分析:")
        
        hot_products = self.spark.sql("""
            SELECT 
                p.product_name,
                p.category,
                p.brand,
                p.price,
                SUM(oi.quantity) as total_sold,
                COUNT(DISTINCT oi.order_id) as order_count,
                SUM(oi.quantity * oi.unit_price) as total_revenue,
                ROUND(SUM(oi.quantity * oi.unit_price) / SUM(oi.quantity), 2) as avg_selling_price
            FROM order_items oi
            JOIN products p ON oi.product_id = p.product_id
            JOIN orders o ON oi.order_id = o.order_id
            WHERE o.status IN ('已支付', '已发货', '已完成')
            GROUP BY p.product_id, p.product_name, p.category, p.brand, p.price
            ORDER BY total_sold DESC
            LIMIT 20
        """)
        
        print("热销商品TOP20:")
        hot_products.show(truncate=False)
        
        return monthly_sales
    
    def analyze_customer_segments(self):
        """
        客户细分分析
        """
        print("\n\n客户细分分析:")
        print("=" * 15)
        
        # 1. RFM分析
        print("\n1. RFM客户价值分析:")
        
        rfm_analysis = self.spark.sql("""
            WITH customer_rfm AS (
                SELECT 
                    u.user_id,
                    u.city,
                    u.age_group,
                    DATEDIFF(CURRENT_DATE(), MAX(o.order_date)) as recency,
                    COUNT(o.order_id) as frequency,
                    SUM(o.total_amount) as monetary
                FROM users u
                JOIN orders o ON u.user_id = o.user_id
                WHERE o.status IN ('已支付', '已发货', '已完成')
                GROUP BY u.user_id, u.city, u.age_group
            ),
            rfm_scores AS (
                SELECT 
                    *,
                    CASE 
                        WHEN recency <= 30 THEN 5
                        WHEN recency <= 60 THEN 4
                        WHEN recency <= 90 THEN 3
                        WHEN recency <= 180 THEN 2
                        ELSE 1
                    END as r_score,
                    CASE 
                        WHEN frequency >= 10 THEN 5
                        WHEN frequency >= 7 THEN 4
                        WHEN frequency >= 4 THEN 3
                        WHEN frequency >= 2 THEN 2
                        ELSE 1
                    END as f_score,
                    CASE 
                        WHEN monetary >= 5000 THEN 5
                        WHEN monetary >= 2000 THEN 4
                        WHEN monetary >= 1000 THEN 3
                        WHEN monetary >= 500 THEN 2
                        ELSE 1
                    END as m_score
                FROM customer_rfm
            )
            SELECT 
                CASE 
                    WHEN r_score >= 4 AND f_score >= 4 AND m_score >= 4 THEN '重要价值客户'
                    WHEN r_score >= 4 AND f_score >= 3 THEN '重要发展客户'
                    WHEN r_score >= 3 AND f_score >= 4 AND m_score >= 4 THEN '重要保持客户'
                    WHEN r_score >= 3 AND f_score >= 3 THEN '重要挽留客户'
                    WHEN f_score >= 4 AND m_score >= 4 THEN '一般价值客户'
                    WHEN f_score >= 3 THEN '一般发展客户'
                    WHEN m_score >= 4 THEN '一般保持客户'
                    ELSE '一般挽留客户'
                END as customer_segment,
                COUNT(*) as customer_count,
                AVG(recency) as avg_recency,
                AVG(frequency) as avg_frequency,
                AVG(monetary) as avg_monetary
            FROM rfm_scores
            GROUP BY 
                CASE 
                    WHEN r_score >= 4 AND f_score >= 4 AND m_score >= 4 THEN '重要价值客户'
                    WHEN r_score >= 4 AND f_score >= 3 THEN '重要发展客户'
                    WHEN r_score >= 3 AND f_score >= 4 AND m_score >= 4 THEN '重要保持客户'
                    WHEN r_score >= 3 AND f_score >= 3 THEN '重要挽留客户'
                    WHEN f_score >= 4 AND m_score >= 4 THEN '一般价值客户'
                    WHEN f_score >= 3 THEN '一般发展客户'
                    WHEN m_score >= 4 THEN '一般保持客户'
                    ELSE '一般挽留客户'
                END
            ORDER BY avg_monetary DESC
        """)
        
        print("RFM客户细分:")
        rfm_analysis.show(truncate=False)
        
        # 2. 地域客户分析
        print("\n2. 地域客户价值分析:")
        
        city_customer_analysis = self.spark.sql("""
            SELECT 
                u.city,
                COUNT(DISTINCT u.user_id) as total_customers,
                COUNT(DISTINCT o.user_id) as paying_customers,
                ROUND(COUNT(DISTINCT o.user_id) * 100.0 / COUNT(DISTINCT u.user_id), 2) as conversion_rate,
                COUNT(o.order_id) as total_orders,
                SUM(o.total_amount) as total_revenue,
                AVG(o.total_amount) as avg_order_value,
                SUM(o.total_amount) / COUNT(DISTINCT u.user_id) as revenue_per_user
            FROM users u
            LEFT JOIN orders o ON u.user_id = o.user_id AND o.status IN ('已支付', '已发货', '已完成')
            GROUP BY u.city
            ORDER BY total_revenue DESC
        """)
        
        print("地域客户价值分析:")
        city_customer_analysis.show()
        
        return rfm_analysis
    
    def generate_business_insights(self):
        """
        生成业务洞察报告
        """
        print("\n\n业务洞察报告:")
        print("=" * 15)
        
        # 1. 整体业务概况
        print("\n1. 整体业务概况:")
        
        business_overview = self.spark.sql("""
            -- 先把用户数/商品数各自聚合为单行再连接,避免与订单记录相乘导致金额被放大
            SELECT 
                u.total_users,
                COUNT(DISTINCT o.user_id) as paying_users,
                COUNT(DISTINCT o.order_id) as total_orders,
                p.total_products,
                SUM(o.total_amount) as total_revenue,
                AVG(o.total_amount) as avg_order_value,
                SUM(o.total_amount) / COUNT(DISTINCT o.user_id) as avg_customer_value
            FROM orders o
            CROSS JOIN (SELECT COUNT(*) as total_users FROM users) u
            CROSS JOIN (SELECT COUNT(*) as total_products FROM products) p
            WHERE o.status IN ('已支付', '已发货', '已完成')
            GROUP BY u.total_users, p.total_products
        """)
        
        print("业务概况:")
        business_overview.show()
        
        # 2. 关键业务指标
        print("\n2. 关键业务指标:")
        
        key_metrics = self.spark.sql("""
            WITH monthly_metrics AS (
                SELECT 
                    YEAR(order_date) as year,
                    MONTH(order_date) as month,
                    COUNT(DISTINCT order_id) as orders,
                    COUNT(DISTINCT user_id) as customers,
                    SUM(total_amount) as revenue
                FROM orders
                WHERE status IN ('已支付', '已发货', '已完成')
                GROUP BY YEAR(order_date), MONTH(order_date)
            )
            SELECT 
                year,
                month,
                orders,
                customers,
                revenue,
                LAG(revenue, 1) OVER (ORDER BY year, month) as prev_month_revenue,
                ROUND((revenue - LAG(revenue, 1) OVER (ORDER BY year, month)) * 100.0 / 
                      LAG(revenue, 1) OVER (ORDER BY year, month), 2) as revenue_growth_rate
            FROM monthly_metrics
            ORDER BY year, month
        """)
        
        print("月度关键指标:")
        key_metrics.show()
        
        # 3. 商品推荐
        print("\n3. 商品推荐分析:")
        
        product_recommendations = self.spark.sql("""
            WITH product_pairs AS (
                SELECT 
                    oi1.product_id as product_a,
                    oi2.product_id as product_b,
                    COUNT(*) as co_purchase_count
                FROM order_items oi1
                JOIN order_items oi2 ON oi1.order_id = oi2.order_id AND oi1.product_id < oi2.product_id
                JOIN orders o ON oi1.order_id = o.order_id
                WHERE o.status IN ('已支付', '已发货', '已完成')
                GROUP BY oi1.product_id, oi2.product_id
                HAVING COUNT(*) >= 5
            )
            SELECT 
                p1.product_name as product_a_name,
                p1.category as category_a,
                p2.product_name as product_b_name,
                p2.category as category_b,
                pp.co_purchase_count,
                ROUND(pp.co_purchase_count * 100.0 / (
                    SELECT COUNT(DISTINCT order_id) 
                    FROM order_items 
                    WHERE product_id = pp.product_a
                ), 2) as recommendation_strength
            FROM product_pairs pp
            JOIN products p1 ON pp.product_a = p1.product_id
            JOIN products p2 ON pp.product_b = p2.product_id
            ORDER BY co_purchase_count DESC
            LIMIT 20
        """)
        
        print("商品关联推荐TOP20:")
        product_recommendations.show(truncate=False)
        
        return business_overview
    
    def cleanup(self):
        """
        清理资源
        """
        # 清除临时视图
        self.spark.sql("DROP VIEW IF EXISTS users")
        self.spark.sql("DROP VIEW IF EXISTS products")
        self.spark.sql("DROP VIEW IF EXISTS orders")
        self.spark.sql("DROP VIEW IF EXISTS order_items")
        
        print("\n电商分析系统资源已清理")

# 电商数据分析系统演示
ecommerce_system = ECommerceAnalyticsSystem()

print("\n电商数据分析系统演示")
print("=" * 25)

# 用户行为分析
user_analysis = ecommerce_system.analyze_user_behavior()

# 销售业绩分析
sales_analysis = ecommerce_system.analyze_sales_performance()

# 客户细分分析
customer_analysis = ecommerce_system.analyze_customer_segments()

# 业务洞察报告
business_insights = ecommerce_system.generate_business_insights()

print("\n电商分析系统总结:")
print("=" * 20)
print(f"✓ 用户行为分析完成")
print(f"✓ 销售业绩分析完成")
print(f"✓ 客户细分分析完成")
print(f"✓ 业务洞察报告完成")

print("\n分析价值:")
print("- 深入了解用户行为模式")
print("- 识别高价值客户群体")
print("- 优化商品推荐策略")
print("- 制定精准营销方案")
print("- 提升业务决策质量")

# 清理资源
ecommerce_system.cleanup()

3.8 本章小结

核心概念回顾

本章深入学习了Spark SQL和DataFrame的核心技术:

Spark SQL基础

    • Spark SQL架构和查询执行流程
    • 统一数据访问和多数据源支持
    • SQL查询优化器和执行引擎
    • 与传统RDD的性能对比

DataFrame操作

    • DataFrame创建和基本操作
    • 聚合、分组和窗口函数
    • 连接操作和性能优化
    • 透视和数据重塑操作

SQL查询技术

    • 基础和高级SQL查询语法
    • 窗口函数和公共表表达式(CTE)
    • 内置函数和条件表达式
    • 查询性能优化技巧

数据源操作

    • 多种文件格式支持(CSV、JSON、Parquet)
    • 数据库连接和JDBC操作
    • 流数据源和实时处理
    • 数据源选项和配置优化

性能优化策略

    • 缓存和持久化优化
    • 连接优化和广播变量
    • 谓词下推和列式存储
    • 自适应查询执行(AQE)
    • 内存管理和垃圾回收

实践技能总结

通过本章学习,你已经掌握了:

  1. 数据处理能力

    • 使用DataFrame API进行复杂数据操作
    • 编写高效的SQL查询语句
    • 处理多种数据源和格式
    • 实现数据清洗和转换
  2. 性能优化技能

    • 识别和解决性能瓶颈
    • 合理使用缓存策略
    • 优化连接和聚合操作
    • 配置Spark SQL参数
  3. 业务分析能力

    • 设计数据分析架构
    • 实现复杂业务逻辑
    • 生成业务洞察报告
    • 支持数据驱动决策

最佳实践

  1. 查询优化

    • 尽早过滤数据减少数据量
    • 合理使用分区和索引
    • 选择合适的连接策略
    • 避免不必要的数据shuffle
  2. 资源管理

    • 根据数据量配置内存
    • 合理设置并行度
    • 监控资源使用情况
    • 及时清理缓存
  3. 代码规范

    • 使用有意义的表名和列名
    • 添加适当的注释
    • 模块化复杂查询
    • 处理异常情况

下一章预告

下一章我们将学习Spark Streaming流处理,内容包括:

    • 流处理基础概念
    • DStream和Structured Streaming
    • 实时数据处理架构
    • 窗口操作和状态管理
    • 容错和检查点机制
    • 实时分析案例实战

练习题

  1. 基础练习

    • 创建包含学生信息的DataFrame,实现成绩统计分析
    • 使用SQL查询实现复杂的数据聚合操作
    • 比较不同文件格式的读写性能
  2. 进阶练习

    • 设计一个日志分析系统,实现实时监控
    • 优化大表连接查询的性能
    • 实现自定义聚合函数
  3. 项目练习

    • 构建完整的数据仓库ETL流程
    • 设计实时推荐系统
    • 实现多维数据分析平台

通过本章的学习,你已经具备了使用Spark SQL进行大规模数据分析的能力,为后续的流处理和机器学习奠定了坚实基础。