12.1 章节概述
本章将深入学习两种重要的行为型设计模式:迭代器模式(Iterator Pattern)和访问者模式(Visitor Pattern)。这两种模式都涉及对象的遍历和操作,但解决的问题和应用场景有所不同。
12.1.1 学习目标
- 理解迭代器模式的核心思想和实现方式
- 掌握访问者模式的设计原理和应用场景
- 学会在实际项目中选择和应用这两种模式
- 了解两种模式的优缺点和最佳实践
- 掌握模式的组合使用技巧
12.1.2 应用场景预览
迭代器模式适用于: - 需要遍历聚合对象而不暴露其内部结构 - 为聚合对象提供多种遍历方式 - 为不同的聚合结构提供统一的遍历接口
访问者模式适用于: - 需要对对象结构中的元素执行很多不同操作 - 对象结构稳定但操作经常变化 - 需要在不修改元素类的情况下定义新操作
12.2 迭代器模式(Iterator Pattern)
12.2.1 模式定义
迭代器模式提供一种方法来访问聚合对象中的各个元素,而不需要暴露该对象的内部表示。它将遍历的责任从聚合对象中分离出来,并将其封装到一个迭代器对象中。
12.2.2 模式动机
在软件开发中,我们经常需要遍历各种数据结构(如列表、树、图等)。如果直接在客户端代码中实现遍历逻辑,会导致:
- 耦合度高:客户端代码与具体的数据结构紧密耦合
- 代码重复:相同的遍历逻辑在多处重复
- 难以扩展:添加新的遍历方式需要修改现有代码
- 封装性破坏:需要暴露聚合对象的内部结构
迭代器模式通过引入迭代器对象来解决这些问题。
12.2.3 模式结构
迭代器模式包含以下角色:
- Iterator(迭代器接口):定义访问和遍历元素的接口
- ConcreteIterator(具体迭代器):实现迭代器接口,跟踪遍历的当前位置
- Aggregate(聚合接口):定义创建迭代器对象的接口
- ConcreteAggregate(具体聚合):实现聚合接口,返回具体迭代器实例
12.2.4 Python实现示例:自定义集合遍历
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Iterator as TypingIterator
import random
# 迭代器接口
class Iterator(ABC):
"""迭代器抽象接口"""
@abstractmethod
def has_next(self) -> bool:
"""检查是否还有下一个元素"""
pass
@abstractmethod
def next(self) -> Any:
"""获取下一个元素"""
pass
@abstractmethod
def reset(self):
"""重置迭代器到初始位置"""
pass
# 聚合接口
class Aggregate(ABC):
"""聚合对象抽象接口"""
@abstractmethod
def create_iterator(self) -> Iterator:
"""创建迭代器"""
pass
# 书籍类
class Book:
"""书籍类"""
def __init__(self, title: str, author: str, isbn: str, year: int):
self.title = title
self.author = author
self.isbn = isbn
self.year = year
def __str__(self):
return f"《{self.title}》 - {self.author} ({self.year})"
def __repr__(self):
return f"Book('{self.title}', '{self.author}', '{self.isbn}', {self.year})"
# 书架类(具体聚合)
class BookShelf(Aggregate):
"""书架类 - 具体聚合对象"""
def __init__(self):
self._books: List[Book] = []
def add_book(self, book: Book):
"""添加书籍"""
self._books.append(book)
print(f"添加书籍: {book}")
def remove_book(self, index: int) -> Optional[Book]:
"""移除书籍"""
if 0 <= index < len(self._books):
book = self._books.pop(index)
print(f"移除书籍: {book}")
return book
return None
def get_book(self, index: int) -> Optional[Book]:
"""获取指定位置的书籍"""
if 0 <= index < len(self._books):
return self._books[index]
return None
def get_count(self) -> int:
"""获取书籍数量"""
return len(self._books)
def create_iterator(self) -> Iterator:
"""创建顺序迭代器"""
return BookShelfIterator(self)
def create_reverse_iterator(self) -> Iterator:
"""创建逆序迭代器"""
return ReverseBookShelfIterator(self)
def create_random_iterator(self) -> Iterator:
"""创建随机迭代器"""
return RandomBookShelfIterator(self)
def create_filtered_iterator(self, filter_func) -> Iterator:
"""创建过滤迭代器"""
return FilteredBookShelfIterator(self, filter_func)
# 顺序迭代器
class BookShelfIterator(Iterator):
"""书架顺序迭代器"""
def __init__(self, bookshelf: BookShelf):
self._bookshelf = bookshelf
self._index = 0
def has_next(self) -> bool:
return self._index < self._bookshelf.get_count()
def next(self) -> Book:
if not self.has_next():
raise StopIteration("没有更多书籍")
book = self._bookshelf.get_book(self._index)
self._index += 1
return book
def reset(self):
self._index = 0
# 逆序迭代器
class ReverseBookShelfIterator(Iterator):
"""书架逆序迭代器"""
def __init__(self, bookshelf: BookShelf):
self._bookshelf = bookshelf
self._index = bookshelf.get_count() - 1
def has_next(self) -> bool:
return self._index >= 0
def next(self) -> Book:
if not self.has_next():
raise StopIteration("没有更多书籍")
book = self._bookshelf.get_book(self._index)
self._index -= 1
return book
def reset(self):
self._index = self._bookshelf.get_count() - 1
# 随机迭代器
class RandomBookShelfIterator(Iterator):
"""书架随机迭代器"""
def __init__(self, bookshelf: BookShelf):
self._bookshelf = bookshelf
self._indices = list(range(bookshelf.get_count()))
random.shuffle(self._indices)
self._current = 0
def has_next(self) -> bool:
return self._current < len(self._indices)
def next(self) -> Book:
if not self.has_next():
raise StopIteration("没有更多书籍")
index = self._indices[self._current]
book = self._bookshelf.get_book(index)
self._current += 1
return book
def reset(self):
random.shuffle(self._indices)
self._current = 0
# 过滤迭代器
class FilteredBookShelfIterator(Iterator):
"""书架过滤迭代器"""
def __init__(self, bookshelf: BookShelf, filter_func):
self._bookshelf = bookshelf
self._filter_func = filter_func
self._index = 0
self._find_next_valid_index()
def _find_next_valid_index(self):
"""找到下一个符合条件的索引"""
while (self._index < self._bookshelf.get_count() and
not self._filter_func(self._bookshelf.get_book(self._index))):
self._index += 1
def has_next(self) -> bool:
return self._index < self._bookshelf.get_count()
def next(self) -> Book:
if not self.has_next():
raise StopIteration("没有更多符合条件的书籍")
book = self._bookshelf.get_book(self._index)
self._index += 1
self._find_next_valid_index()
return book
def reset(self):
self._index = 0
self._find_next_valid_index()
# Python风格的迭代器适配器
class PythonIteratorAdapter:
"""Python风格迭代器适配器"""
def __init__(self, iterator: Iterator):
self._iterator = iterator
def __iter__(self):
self._iterator.reset()
return self
def __next__(self):
if self._iterator.has_next():
return self._iterator.next()
else:
raise StopIteration
# 迭代器工厂
class IteratorFactory:
"""迭代器工厂"""
@staticmethod
def create_iterator(bookshelf: BookShelf, iterator_type: str, **kwargs) -> Iterator:
"""创建指定类型的迭代器"""
if iterator_type == "sequential":
return bookshelf.create_iterator()
elif iterator_type == "reverse":
return bookshelf.create_reverse_iterator()
elif iterator_type == "random":
return bookshelf.create_random_iterator()
elif iterator_type == "filtered":
filter_func = kwargs.get('filter_func')
if not filter_func:
raise ValueError("过滤迭代器需要提供filter_func参数")
return bookshelf.create_filtered_iterator(filter_func)
else:
raise ValueError(f"不支持的迭代器类型: {iterator_type}")
# 迭代器使用示例
def demonstrate_iterator_pattern():
"""演示迭代器模式"""
print("=== 迭代器模式演示 ===")
# 创建书架并添加书籍
bookshelf = BookShelf()
books = [
Book("设计模式", "GoF", "978-0201633610", 1994),
Book("重构", "Martin Fowler", "978-0201485677", 1999),
Book("代码大全", "Steve McConnell", "978-0735619678", 2004),
Book("程序员修炼之道", "Andrew Hunt", "978-0201616224", 1999),
Book("敏捷软件开发", "Robert C. Martin", "978-0135974445", 2002)
]
for book in books:
bookshelf.add_book(book)
print(f"\n书架共有 {bookshelf.get_count()} 本书\n")
# 1. 顺序遍历
print("--- 顺序遍历 ---")
iterator = bookshelf.create_iterator()
while iterator.has_next():
book = iterator.next()
print(f" {book}")
# 2. 逆序遍历
print("\n--- 逆序遍历 ---")
reverse_iterator = bookshelf.create_reverse_iterator()
while reverse_iterator.has_next():
book = reverse_iterator.next()
print(f" {book}")
# 3. 随机遍历
print("\n--- 随机遍历 ---")
random_iterator = bookshelf.create_random_iterator()
while random_iterator.has_next():
book = random_iterator.next()
print(f" {book}")
# 4. 过滤遍历(只显示1999年以后的书)
print("\n--- 过滤遍历(2000年以后的书籍) ---")
filter_func = lambda book: book.year >= 2000
filtered_iterator = bookshelf.create_filtered_iterator(filter_func)
while filtered_iterator.has_next():
book = filtered_iterator.next()
print(f" {book}")
# 5. 使用Python风格的迭代器
print("\n--- Python风格迭代器 ---")
python_iterator = PythonIteratorAdapter(bookshelf.create_iterator())
for book in python_iterator:
print(f" {book}")
# 6. 使用工厂创建迭代器
print("\n--- 使用工厂创建迭代器 ---")
factory_iterator = IteratorFactory.create_iterator(
bookshelf, "filtered",
filter_func=lambda book: "Martin" in book.author
)
print("作者名字包含'Martin'的书籍:")
while factory_iterator.has_next():
book = factory_iterator.next()
print(f" {book}")
if __name__ == "__main__":
demonstrate_iterator_pattern()
12.2.5 Java实现示例:树结构遍历
import java.util.*;
import java.util.function.Consumer;
// 迭代器接口
interface TreeIterator<T> {
boolean hasNext();
T next();
void reset();
}
// 树节点类
class TreeNode<T> {
private T data;
private List<TreeNode<T>> children;
private TreeNode<T> parent;
public TreeNode(T data) {
this.data = data;
this.children = new ArrayList<>();
this.parent = null;
}
public void addChild(TreeNode<T> child) {
children.add(child);
child.parent = this;
}
public void removeChild(TreeNode<T> child) {
children.remove(child);
child.parent = null;
}
// Getters
public T getData() { return data; }
public List<TreeNode<T>> getChildren() { return new ArrayList<>(children); }
public TreeNode<T> getParent() { return parent; }
public boolean isLeaf() { return children.isEmpty(); }
public boolean isRoot() { return parent == null; }
@Override
public String toString() {
return data.toString();
}
}
// 树结构类
class Tree<T> {
private TreeNode<T> root;
public Tree(T rootData) {
this.root = new TreeNode<>(rootData);
}
public TreeNode<T> getRoot() {
return root;
}
// 创建不同类型的迭代器
public TreeIterator<T> createDepthFirstIterator() {
return new DepthFirstIterator<>(root);
}
public TreeIterator<T> createBreadthFirstIterator() {
return new BreadthFirstIterator<>(root);
}
public TreeIterator<T> createLeafIterator() {
return new LeafIterator<>(root);
}
public TreeIterator<T> createLevelIterator(int level) {
return new LevelIterator<>(root, level);
}
// 添加节点的便捷方法
public TreeNode<T> addNode(TreeNode<T> parent, T data) {
TreeNode<T> newNode = new TreeNode<>(data);
parent.addChild(newNode);
return newNode;
}
// 打印树结构
public void printTree() {
printTree(root, "", true);
}
private void printTree(TreeNode<T> node, String prefix, boolean isLast) {
System.out.println(prefix + (isLast ? "└── " : "├── ") + node.getData());
List<TreeNode<T>> children = node.getChildren();
for (int i = 0; i < children.size(); i++) {
boolean isLastChild = (i == children.size() - 1);
String newPrefix = prefix + (isLast ? " " : "│ ");
printTree(children.get(i), newPrefix, isLastChild);
}
}
}
// 深度优先迭代器
class DepthFirstIterator<T> implements TreeIterator<T> {
private Stack<TreeNode<T>> stack;
private TreeNode<T> root;
public DepthFirstIterator(TreeNode<T> root) {
this.root = root;
this.stack = new Stack<>();
reset();
}
@Override
public boolean hasNext() {
return !stack.isEmpty();
}
@Override
public T next() {
if (!hasNext()) {
throw new NoSuchElementException("没有更多元素");
}
TreeNode<T> current = stack.pop();
// 将子节点逆序压入栈(保证从左到右的遍历顺序)
List<TreeNode<T>> children = current.getChildren();
for (int i = children.size() - 1; i >= 0; i--) {
stack.push(children.get(i));
}
return current.getData();
}
@Override
public void reset() {
stack.clear();
if (root != null) {
stack.push(root);
}
}
}
// 广度优先迭代器
class BreadthFirstIterator<T> implements TreeIterator<T> {
private Queue<TreeNode<T>> queue;
private TreeNode<T> root;
public BreadthFirstIterator(TreeNode<T> root) {
this.root = root;
this.queue = new LinkedList<>();
reset();
}
@Override
public boolean hasNext() {
return !queue.isEmpty();
}
@Override
public T next() {
if (!hasNext()) {
throw new NoSuchElementException("没有更多元素");
}
TreeNode<T> current = queue.poll();
// 将所有子节点加入队列
queue.addAll(current.getChildren());
return current.getData();
}
@Override
public void reset() {
queue.clear();
if (root != null) {
queue.offer(root);
}
}
}
// 叶子节点迭代器
class LeafIterator<T> implements TreeIterator<T> {
private List<TreeNode<T>> leafNodes;
private int currentIndex;
private TreeNode<T> root;
public LeafIterator(TreeNode<T> root) {
this.root = root;
this.leafNodes = new ArrayList<>();
this.currentIndex = 0;
reset();
}
@Override
public boolean hasNext() {
return currentIndex < leafNodes.size();
}
@Override
public T next() {
if (!hasNext()) {
throw new NoSuchElementException("没有更多叶子节点");
}
return leafNodes.get(currentIndex++).getData();
}
@Override
public void reset() {
leafNodes.clear();
currentIndex = 0;
collectLeafNodes(root);
}
private void collectLeafNodes(TreeNode<T> node) {
if (node == null) return;
if (node.isLeaf()) {
leafNodes.add(node);
} else {
for (TreeNode<T> child : node.getChildren()) {
collectLeafNodes(child);
}
}
}
}
// 指定层级迭代器
class LevelIterator<T> implements TreeIterator<T> {
private List<TreeNode<T>> levelNodes;
private int currentIndex;
private TreeNode<T> root;
private int targetLevel;
public LevelIterator(TreeNode<T> root, int level) {
this.root = root;
this.targetLevel = level;
this.levelNodes = new ArrayList<>();
this.currentIndex = 0;
reset();
}
@Override
public boolean hasNext() {
return currentIndex < levelNodes.size();
}
@Override
public T next() {
if (!hasNext()) {
throw new NoSuchElementException("没有更多指定层级的节点");
}
return levelNodes.get(currentIndex++).getData();
}
@Override
public void reset() {
levelNodes.clear();
currentIndex = 0;
collectLevelNodes(root, 0);
}
private void collectLevelNodes(TreeNode<T> node, int currentLevel) {
if (node == null) return;
if (currentLevel == targetLevel) {
levelNodes.add(node);
} else if (currentLevel < targetLevel) {
for (TreeNode<T> child : node.getChildren()) {
collectLevelNodes(child, currentLevel + 1);
}
}
}
}
// 迭代器工具类
class IteratorUtils {
public static <T> void forEach(TreeIterator<T> iterator, Consumer<T> action) {
iterator.reset();
while (iterator.hasNext()) {
action.accept(iterator.next());
}
}
public static <T> List<T> toList(TreeIterator<T> iterator) {
List<T> result = new ArrayList<>();
iterator.reset();
while (iterator.hasNext()) {
result.add(iterator.next());
}
return result;
}
public static <T> int count(TreeIterator<T> iterator) {
int count = 0;
iterator.reset();
while (iterator.hasNext()) {
iterator.next();
count++;
}
return count;
}
}
// 演示类
public class TreeIteratorDemo {
public static void main(String[] args) {
System.out.println("=== 树结构迭代器模式演示 ===");
// 创建树结构
Tree<String> tree = new Tree<>("根节点");
TreeNode<String> root = tree.getRoot();
// 构建树结构
TreeNode<String> node1 = tree.addNode(root, "节点1");
TreeNode<String> node2 = tree.addNode(root, "节点2");
TreeNode<String> node3 = tree.addNode(root, "节点3");
tree.addNode(node1, "节点1.1");
tree.addNode(node1, "节点1.2");
TreeNode<String> node21 = tree.addNode(node2, "节点2.1");
tree.addNode(node2, "节点2.2");
tree.addNode(node2, "节点2.3");
tree.addNode(node21, "节点2.1.1");
tree.addNode(node21, "节点2.1.2");
tree.addNode(node3, "节点3.1");
// 打印树结构
System.out.println("\n--- 树结构 ---");
tree.printTree();
// 深度优先遍历
System.out.println("\n--- 深度优先遍历 ---");
TreeIterator<String> dfsIterator = tree.createDepthFirstIterator();
IteratorUtils.forEach(dfsIterator, node -> System.out.println(" " + node));
// 广度优先遍历
System.out.println("\n--- 广度优先遍历 ---");
TreeIterator<String> bfsIterator = tree.createBreadthFirstIterator();
IteratorUtils.forEach(bfsIterator, node -> System.out.println(" " + node));
// 叶子节点遍历
System.out.println("\n--- 叶子节点遍历 ---");
TreeIterator<String> leafIterator = tree.createLeafIterator();
IteratorUtils.forEach(leafIterator, node -> System.out.println(" " + node));
// 指定层级遍历
System.out.println("\n--- 第2层节点遍历 ---");
TreeIterator<String> levelIterator = tree.createLevelIterator(2);
IteratorUtils.forEach(levelIterator, node -> System.out.println(" " + node));
// 统计信息
System.out.println("\n--- 统计信息 ---");
System.out.println("总节点数: " + IteratorUtils.count(tree.createDepthFirstIterator()));
System.out.println("叶子节点数: " + IteratorUtils.count(tree.createLeafIterator()));
System.out.println("第2层节点数: " + IteratorUtils.count(tree.createLevelIterator(2)));
}
}
12.2.6 迭代器模式的优缺点
优点: 1. 支持多种遍历方式:可以为同一聚合对象提供多种遍历算法 2. 简化聚合接口:聚合对象不需要提供遍历相关的方法 3. 封装性好:客户端不需要了解聚合对象的内部结构 4. 一致性:为不同的聚合结构提供统一的遍历接口 5. 并发安全:每个迭代器维护自己的遍历状态
缺点: 1. 增加类的数量:每种遍历方式都需要一个迭代器类 2. 性能开销:相比直接访问可能有额外的性能开销 3. 复杂性增加:对于简单的遍历需求可能过于复杂
12.2.7 适用场景
- 访问聚合对象内容而不暴露内部表示
- 支持对聚合对象的多种遍历
- 为遍历不同的聚合结构提供统一接口
- 需要在遍历过程中修改遍历算法
12.3 访问者模式(Visitor Pattern)
12.3.1 模式定义
访问者模式表示一个作用于某对象结构中的各元素的操作。它让你可以在不改变各元素类的前提下定义作用于这些元素的新操作。
12.3.2 模式动机
在面向对象系统中,我们经常遇到这样的情况:
- 对象结构稳定:对象的类层次结构很少改变
- 操作经常变化:需要在这些对象上定义新的操作
- 避免污染:不希望在元素类中添加新的操作方法
- 操作相关性:相关的操作希望集中在一起
访问者模式通过将操作从对象结构中分离出来,解决了这些问题。
12.3.3 模式结构
访问者模式包含以下角色:
- Visitor(访问者接口):声明访问各种元素的方法
- ConcreteVisitor(具体访问者):实现访问者接口,定义具体的操作
- Element(元素接口):定义接受访问者的方法
- ConcreteElement(具体元素):实现元素接口,调用访问者的相应方法
- ObjectStructure(对象结构):包含元素的容器,提供遍历元素的方法
12.3.4 Python实现示例:文件系统操作
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import time
from datetime import datetime
# 访问者接口
class FileSystemVisitor(ABC):
"""文件系统访问者接口"""
@abstractmethod
def visit_file(self, file_element: 'FileElement'):
"""访问文件"""
pass
@abstractmethod
def visit_directory(self, directory_element: 'DirectoryElement'):
"""访问目录"""
pass
@abstractmethod
def visit_symlink(self, symlink_element: 'SymlinkElement'):
"""访问符号链接"""
pass
# 元素接口
class FileSystemElement(ABC):
"""文件系统元素接口"""
def __init__(self, name: str, path: str, size: int, modified_time: float):
self.name = name
self.path = path
self.size = size
self.modified_time = modified_time
@abstractmethod
def accept(self, visitor: FileSystemVisitor):
"""接受访问者"""
pass
def get_modified_datetime(self) -> datetime:
"""获取修改时间的datetime对象"""
return datetime.fromtimestamp(self.modified_time)
def __str__(self):
return f"{self.name} ({self.size} bytes)"
# 文件元素
class FileElement(FileSystemElement):
"""文件元素"""
def __init__(self, name: str, path: str, size: int, modified_time: float, file_type: str):
super().__init__(name, path, size, modified_time)
self.file_type = file_type
def accept(self, visitor: FileSystemVisitor):
visitor.visit_file(self)
def get_extension(self) -> str:
"""获取文件扩展名"""
return os.path.splitext(self.name)[1].lower()
# 目录元素
class DirectoryElement(FileSystemElement):
"""目录元素"""
def __init__(self, name: str, path: str, size: int, modified_time: float):
super().__init__(name, path, size, modified_time)
self.children: List[FileSystemElement] = []
def add_child(self, child: FileSystemElement):
"""添加子元素"""
self.children.append(child)
def remove_child(self, child: FileSystemElement):
"""移除子元素"""
if child in self.children:
self.children.remove(child)
def accept(self, visitor: FileSystemVisitor):
visitor.visit_directory(self)
# 递归访问子元素
for child in self.children:
child.accept(visitor)
def get_child_count(self) -> int:
"""获取子元素数量"""
return len(self.children)
# 符号链接元素
class SymlinkElement(FileSystemElement):
"""符号链接元素"""
def __init__(self, name: str, path: str, size: int, modified_time: float, target: str):
super().__init__(name, path, size, modified_time)
self.target = target
def accept(self, visitor: FileSystemVisitor):
visitor.visit_symlink(self)
# 具体访问者:文件统计
class FileStatisticsVisitor(FileSystemVisitor):
"""文件统计访问者"""
def __init__(self):
self.reset()
def reset(self):
"""重置统计数据"""
self.file_count = 0
self.directory_count = 0
self.symlink_count = 0
self.total_size = 0
self.file_types: Dict[str, int] = {}
self.largest_file: FileElement = None
self.oldest_file: FileSystemElement = None
self.newest_file: FileSystemElement = None
def visit_file(self, file_element: FileElement):
self.file_count += 1
self.total_size += file_element.size
# 统计文件类型
extension = file_element.get_extension()
if not extension:
extension = "无扩展名"
self.file_types[extension] = self.file_types.get(extension, 0) + 1
# 记录最大文件
if self.largest_file is None or file_element.size > self.largest_file.size:
self.largest_file = file_element
# 记录最旧和最新文件
self._update_time_records(file_element)
def visit_directory(self, directory_element: DirectoryElement):
self.directory_count += 1
self._update_time_records(directory_element)
def visit_symlink(self, symlink_element: SymlinkElement):
self.symlink_count += 1
self._update_time_records(symlink_element)
def _update_time_records(self, element: FileSystemElement):
"""更新时间记录"""
if self.oldest_file is None or element.modified_time < self.oldest_file.modified_time:
self.oldest_file = element
if self.newest_file is None or element.modified_time > self.newest_file.modified_time:
self.newest_file = element
def get_statistics(self) -> Dict[str, Any]:
"""获取统计结果"""
return {
'file_count': self.file_count,
'directory_count': self.directory_count,
'symlink_count': self.symlink_count,
'total_size': self.total_size,
'average_file_size': self.total_size / self.file_count if self.file_count > 0 else 0,
'file_types': self.file_types,
'largest_file': self.largest_file,
'oldest_file': self.oldest_file,
'newest_file': self.newest_file
}
def print_statistics(self):
"""打印统计结果"""
stats = self.get_statistics()
print("=== 文件系统统计 ===")
print(f"文件数量: {stats['file_count']}")
print(f"目录数量: {stats['directory_count']}")
print(f"符号链接数量: {stats['symlink_count']}")
print(f"总大小: {stats['total_size']:,} bytes ({stats['total_size']/1024/1024:.2f} MB)")
if stats['file_count'] > 0:
print(f"平均文件大小: {stats['average_file_size']:.2f} bytes")
print("\n文件类型分布:")
for file_type, count in sorted(stats['file_types'].items()):
print(f" {file_type}: {count} 个文件")
if stats['largest_file']:
print(f"\n最大文件: {stats['largest_file'].name} ({stats['largest_file'].size:,} bytes)")
if stats['oldest_file']:
print(f"最旧文件: {stats['oldest_file'].name} ({stats['oldest_file'].get_modified_datetime()})")
if stats['newest_file']:
print(f"最新文件: {stats['newest_file'].name} ({stats['newest_file'].get_modified_datetime()})")
# 具体访问者:文件搜索
class FileSearchVisitor(FileSystemVisitor):
"""文件搜索访问者"""
def __init__(self, search_criteria: Dict[str, Any]):
self.search_criteria = search_criteria
self.results: List[FileSystemElement] = []
def visit_file(self, file_element: FileElement):
if self._matches_criteria(file_element):
self.results.append(file_element)
def visit_directory(self, directory_element: DirectoryElement):
if self._matches_criteria(directory_element):
self.results.append(directory_element)
def visit_symlink(self, symlink_element: SymlinkElement):
if self._matches_criteria(symlink_element):
self.results.append(symlink_element)
def _matches_criteria(self, element: FileSystemElement) -> bool:
"""检查元素是否匹配搜索条件"""
# 名称匹配
if 'name_contains' in self.search_criteria:
if self.search_criteria['name_contains'].lower() not in element.name.lower():
return False
# 大小范围
if 'min_size' in self.search_criteria:
if element.size < self.search_criteria['min_size']:
return False
if 'max_size' in self.search_criteria:
if element.size > self.search_criteria['max_size']:
return False
# 文件类型(仅对文件有效)
if 'file_extension' in self.search_criteria and isinstance(element, FileElement):
if element.get_extension() != self.search_criteria['file_extension']:
return False
# 修改时间范围
if 'modified_after' in self.search_criteria:
if element.modified_time < self.search_criteria['modified_after']:
return False
if 'modified_before' in self.search_criteria:
if element.modified_time > self.search_criteria['modified_before']:
return False
return True
def get_results(self) -> List[FileSystemElement]:
"""获取搜索结果"""
return self.results.copy()
def print_results(self):
"""打印搜索结果"""
print(f"\n=== 搜索结果 ({len(self.results)} 个匹配项) ===")
for element in self.results:
element_type = type(element).__name__.replace('Element', '')
print(f" [{element_type}] {element.path} - {element.size:,} bytes")
# 具体访问者:文件备份
class FileBackupVisitor(FileSystemVisitor):
"""文件备份访问者"""
def __init__(self, backup_criteria: Dict[str, Any]):
self.backup_criteria = backup_criteria
self.backup_list: List[FileElement] = []
self.total_backup_size = 0
def visit_file(self, file_element: FileElement):
if self._should_backup(file_element):
self.backup_list.append(file_element)
self.total_backup_size += file_element.size
def visit_directory(self, directory_element: DirectoryElement):
# 目录本身不需要备份,但会遍历其子元素
pass
def visit_symlink(self, symlink_element: SymlinkElement):
# 符号链接通常不需要备份
pass
def _should_backup(self, file_element: FileElement) -> bool:
"""判断文件是否需要备份"""
# 检查文件扩展名
if 'include_extensions' in self.backup_criteria:
extensions = self.backup_criteria['include_extensions']
if file_element.get_extension() not in extensions:
return False
if 'exclude_extensions' in self.backup_criteria:
extensions = self.backup_criteria['exclude_extensions']
if file_element.get_extension() in extensions:
return False
# 检查文件大小
if 'max_file_size' in self.backup_criteria:
if file_element.size > self.backup_criteria['max_file_size']:
return False
# 检查修改时间
if 'modified_since' in self.backup_criteria:
if file_element.modified_time < self.backup_criteria['modified_since']:
return False
return True
def get_backup_plan(self) -> Dict[str, Any]:
"""获取备份计划"""
return {
'files_to_backup': self.backup_list.copy(),
'file_count': len(self.backup_list),
'total_size': self.total_backup_size
}
def print_backup_plan(self):
"""打印备份计划"""
plan = self.get_backup_plan()
print(f"\n=== 备份计划 ===")
print(f"需要备份的文件数量: {plan['file_count']}")
print(f"总备份大小: {plan['total_size']:,} bytes ({plan['total_size']/1024/1024:.2f} MB)")
print("\n备份文件列表:")
for file_element in plan['files_to_backup']:
print(f" {file_element.path} - {file_element.size:,} bytes")
# 具体访问者:文件清理
class FileCleanupVisitor(FileSystemVisitor):
"""文件清理访问者"""
def __init__(self, cleanup_criteria: Dict[str, Any]):
self.cleanup_criteria = cleanup_criteria
self.cleanup_list: List[FileSystemElement] = []
self.total_cleanup_size = 0
def visit_file(self, file_element: FileElement):
if self._should_cleanup(file_element):
self.cleanup_list.append(file_element)
self.total_cleanup_size += file_element.size
def visit_directory(self, directory_element: DirectoryElement):
# 检查空目录
if ('remove_empty_dirs' in self.cleanup_criteria and
self.cleanup_criteria['remove_empty_dirs'] and
directory_element.get_child_count() == 0):
self.cleanup_list.append(directory_element)
def visit_symlink(self, symlink_element: SymlinkElement):
# 检查损坏的符号链接
if ('remove_broken_symlinks' in self.cleanup_criteria and
self.cleanup_criteria['remove_broken_symlinks']):
# 这里简化处理,实际应该检查链接目标是否存在
self.cleanup_list.append(symlink_element)
def _should_cleanup(self, file_element: FileElement) -> bool:
"""判断文件是否需要清理"""
current_time = time.time()
# 检查临时文件
if ('remove_temp_files' in self.cleanup_criteria and
self.cleanup_criteria['remove_temp_files']):
temp_extensions = ['.tmp', '.temp', '.bak', '.old']
if file_element.get_extension() in temp_extensions:
return True
# 检查旧文件
if 'older_than_days' in self.cleanup_criteria:
days = self.cleanup_criteria['older_than_days']
age_threshold = current_time - (days * 24 * 60 * 60)
if file_element.modified_time < age_threshold:
return True
# 检查大文件
if 'larger_than_mb' in self.cleanup_criteria:
size_threshold = self.cleanup_criteria['larger_than_mb'] * 1024 * 1024
if file_element.size > size_threshold:
return True
return False
def get_cleanup_plan(self) -> Dict[str, Any]:
"""获取清理计划"""
return {
'items_to_cleanup': self.cleanup_list.copy(),
'item_count': len(self.cleanup_list),
'total_size_saved': self.total_cleanup_size
}
def print_cleanup_plan(self):
"""打印清理计划"""
plan = self.get_cleanup_plan()
print(f"\n=== 清理计划 ===")
print(f"需要清理的项目数量: {plan['item_count']}")
print(f"可释放空间: {plan['total_size_saved']:,} bytes ({plan['total_size_saved']/1024/1024:.2f} MB)")
print("\n清理项目列表:")
for item in plan['items_to_cleanup']:
item_type = type(item).__name__.replace('Element', '')
print(f" [{item_type}] {item.path} - {item.size:,} bytes")
# 文件系统构建器(用于演示)
class FileSystemBuilder:
"""文件系统构建器"""
@staticmethod
def create_sample_filesystem() -> DirectoryElement:
"""创建示例文件系统"""
current_time = time.time()
# 根目录
root = DirectoryElement("project", "/project", 0, current_time)
# 源代码目录
src_dir = DirectoryElement("src", "/project/src", 0, current_time - 86400)
root.add_child(src_dir)
# Python文件
src_dir.add_child(FileElement("main.py", "/project/src/main.py", 2048, current_time - 3600, "Python"))
src_dir.add_child(FileElement("utils.py", "/project/src/utils.py", 1024, current_time - 7200, "Python"))
src_dir.add_child(FileElement("config.py", "/project/src/config.py", 512, current_time - 86400*2, "Python"))
# 文档目录
docs_dir = DirectoryElement("docs", "/project/docs", 0, current_time - 86400*3)
root.add_child(docs_dir)
docs_dir.add_child(FileElement("README.md", "/project/docs/README.md", 4096, current_time - 86400, "Markdown"))
docs_dir.add_child(FileElement("API.md", "/project/docs/API.md", 8192, current_time - 86400*2, "Markdown"))
# 测试目录
tests_dir = DirectoryElement("tests", "/project/tests", 0, current_time - 86400*5)
root.add_child(tests_dir)
tests_dir.add_child(FileElement("test_main.py", "/project/tests/test_main.py", 1536, current_time - 86400*3, "Python"))
tests_dir.add_child(FileElement("test_utils.py", "/project/tests/test_utils.py", 1024, current_time - 86400*4, "Python"))
# 临时文件
root.add_child(FileElement("temp.tmp", "/project/temp.tmp", 256, current_time - 86400*10, "Temporary"))
root.add_child(FileElement("backup.bak", "/project/backup.bak", 10485760, current_time - 86400*30, "Backup"))
# 配置文件
root.add_child(FileElement("config.json", "/project/config.json", 1024, current_time - 86400*7, "JSON"))
root.add_child(FileElement("requirements.txt", "/project/requirements.txt", 512, current_time - 86400*5, "Text"))
# 符号链接
root.add_child(SymlinkElement("link_to_src", "/project/link_to_src", 0, current_time, "/project/src"))
# 空目录
empty_dir = DirectoryElement("empty", "/project/empty", 0, current_time - 86400*20)
root.add_child(empty_dir)
return root
# 演示函数
def demonstrate_visitor_pattern():
"""演示访问者模式"""
print("=== 访问者模式演示:文件系统操作 ===")
# 创建示例文件系统
filesystem = FileSystemBuilder.create_sample_filesystem()
# 1. 文件统计
print("\n1. 文件系统统计")
stats_visitor = FileStatisticsVisitor()
filesystem.accept(stats_visitor)
stats_visitor.print_statistics()
# 2. 文件搜索
print("\n\n2. 文件搜索")
# 搜索Python文件
search_criteria = {
'file_extension': '.py',
'min_size': 1000
}
search_visitor = FileSearchVisitor(search_criteria)
filesystem.accept(search_visitor)
print("搜索条件: Python文件且大小 >= 1000 bytes")
search_visitor.print_results()
# 搜索最近修改的文件
recent_time = time.time() - 86400 * 2 # 2天前
search_criteria = {
'modified_after': recent_time
}
search_visitor = FileSearchVisitor(search_criteria)
filesystem.accept(search_visitor)
print("\n搜索条件: 最近2天修改的文件")
search_visitor.print_results()
# 3. 备份计划
print("\n\n3. 备份计划")
backup_criteria = {
'include_extensions': ['.py', '.md', '.json'],
'max_file_size': 5 * 1024 * 1024, # 5MB
'modified_since': time.time() - 86400 * 7 # 一周内修改的
}
backup_visitor = FileBackupVisitor(backup_criteria)
filesystem.accept(backup_visitor)
backup_visitor.print_backup_plan()
# 4. 清理计划
print("\n\n4. 清理计划")
cleanup_criteria = {
'remove_temp_files': True,
'remove_empty_dirs': True,
'older_than_days': 15,
'larger_than_mb': 5
}
cleanup_visitor = FileCleanupVisitor(cleanup_criteria)
filesystem.accept(cleanup_visitor)
cleanup_visitor.print_cleanup_plan()
if __name__ == "__main__":
demonstrate_visitor_pattern()
12.3.5 Java实现示例:编译器AST处理
import java.util.*;
import java.util.function.Function;
// 访问者接口
interface ASTVisitor {
void visitNumberNode(NumberNode node);
void visitBinaryOpNode(BinaryOpNode node);
void visitUnaryOpNode(UnaryOpNode node);
void visitVariableNode(VariableNode node);
void visitFunctionCallNode(FunctionCallNode node);
}
// AST节点接口
abstract class ASTNode {
public abstract void accept(ASTVisitor visitor);
public abstract String toString();
}
// 数字节点
class NumberNode extends ASTNode {
private double value;
public NumberNode(double value) {
this.value = value;
}
public double getValue() {
return value;
}
@Override
public void accept(ASTVisitor visitor) {
visitor.visitNumberNode(this);
}
@Override
public String toString() {
return String.valueOf(value);
}
}
// 二元操作节点
class BinaryOpNode extends ASTNode {
private ASTNode left;
private ASTNode right;
private String operator;
public BinaryOpNode(ASTNode left, String operator, ASTNode right) {
this.left = left;
this.operator = operator;
this.right = right;
}
public ASTNode getLeft() { return left; }
public ASTNode getRight() { return right; }
public String getOperator() { return operator; }
@Override
public void accept(ASTVisitor visitor) {
visitor.visitBinaryOpNode(this);
}
@Override
public String toString() {
return "(" + left + " " + operator + " " + right + ")";
}
}
// 一元操作节点
class UnaryOpNode extends ASTNode {
private ASTNode operand;
private String operator;
public UnaryOpNode(String operator, ASTNode operand) {
this.operator = operator;
this.operand = operand;
}
public ASTNode getOperand() { return operand; }
public String getOperator() { return operator; }
@Override
public void accept(ASTVisitor visitor) {
visitor.visitUnaryOpNode(this);
}
@Override
public String toString() {
return operator + operand;
}
}
// 变量节点
class VariableNode extends ASTNode {
private String name;
public VariableNode(String name) {
this.name = name;
}
public String getName() {
return name;
}
@Override
public void accept(ASTVisitor visitor) {
visitor.visitVariableNode(this);
}
@Override
public String toString() {
return name;
}
}
// 函数调用节点
class FunctionCallNode extends ASTNode {
private String functionName;
private List<ASTNode> arguments;
public FunctionCallNode(String functionName, List<ASTNode> arguments) {
this.functionName = functionName;
this.arguments = new ArrayList<>(arguments);
}
public String getFunctionName() { return functionName; }
public List<ASTNode> getArguments() { return new ArrayList<>(arguments); }
@Override
public void accept(ASTVisitor visitor) {
visitor.visitFunctionCallNode(this);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(functionName).append("(");
for (int i = 0; i < arguments.size(); i++) {
if (i > 0) sb.append(", ");
sb.append(arguments.get(i));
}
sb.append(")");
return sb.toString();
}
}
// 具体访问者:表达式求值
class EvaluatorVisitor implements ASTVisitor {
private Stack<Double> valueStack;
private Map<String, Double> variables;
private Map<String, Function<List<Double>, Double>> functions;
public EvaluatorVisitor() {
this.valueStack = new Stack<>();
this.variables = new HashMap<>();
this.functions = new HashMap<>();
// 注册内置函数
functions.put("sin", args -> Math.sin(args.get(0)));
functions.put("cos", args -> Math.cos(args.get(0)));
functions.put("sqrt", args -> Math.sqrt(args.get(0)));
functions.put("pow", args -> Math.pow(args.get(0), args.get(1)));
functions.put("max", args -> Collections.max(args));
functions.put("min", args -> Collections.min(args));
}
public void setVariable(String name, double value) {
variables.put(name, value);
}
public void setFunction(String name, Function<List<Double>, Double> function) {
functions.put(name, function);
}
public double evaluate(ASTNode node) {
valueStack.clear();
node.accept(this);
if (valueStack.isEmpty()) {
throw new RuntimeException("求值失败:栈为空");
}
return valueStack.pop();
}
@Override
public void visitNumberNode(NumberNode node) {
valueStack.push(node.getValue());
}
@Override
public void visitBinaryOpNode(BinaryOpNode node) {
// 先求值子节点
node.getLeft().accept(this);
node.getRight().accept(this);
// 从栈中取出操作数
double right = valueStack.pop();
double left = valueStack.pop();
// 执行操作
double result;
switch (node.getOperator()) {
case "+":
result = left + right;
break;
case "-":
result = left - right;
break;
case "*":
result = left * right;
break;
case "/":
if (right == 0) {
throw new ArithmeticException("除零错误");
}
result = left / right;
break;
case "^":
result = Math.pow(left, right);
break;
default:
throw new UnsupportedOperationException("不支持的操作符: " + node.getOperator());
}
valueStack.push(result);
}
@Override
public void visitUnaryOpNode(UnaryOpNode node) {
// 先求值操作数
node.getOperand().accept(this);
double operand = valueStack.pop();
double result;
switch (node.getOperator()) {
case "-":
result = -operand;
break;
case "+":
result = operand;
break;
default:
throw new UnsupportedOperationException("不支持的一元操作符: " + node.getOperator());
}
valueStack.push(result);
}
@Override
public void visitVariableNode(VariableNode node) {
String name = node.getName();
if (!variables.containsKey(name)) {
throw new RuntimeException("未定义的变量: " + name);
}
valueStack.push(variables.get(name));
}
@Override
public void visitFunctionCallNode(FunctionCallNode node) {
String functionName = node.getFunctionName();
if (!functions.containsKey(functionName)) {
throw new RuntimeException("未定义的函数: " + functionName);
}
// 求值所有参数
List<Double> args = new ArrayList<>();
for (ASTNode arg : node.getArguments()) {
arg.accept(this);
args.add(valueStack.pop());
}
// 调用函数
double result = functions.get(functionName).apply(args);
valueStack.push(result);
}
}
// 具体访问者:代码生成
class CodeGeneratorVisitor implements ASTVisitor {
private StringBuilder code;
private String targetLanguage;
private int indentLevel;
public CodeGeneratorVisitor(String targetLanguage) {
this.targetLanguage = targetLanguage;
this.code = new StringBuilder();
this.indentLevel = 0;
}
public String generateCode(ASTNode node) {
code.setLength(0);
indentLevel = 0;
node.accept(this);
return code.toString();
}
private void append(String text) {
code.append(text);
}
@Override
public void visitNumberNode(NumberNode node) {
append(String.valueOf(node.getValue()));
}
@Override
public void visitBinaryOpNode(BinaryOpNode node) {
if (targetLanguage.equals("C")) {
append("(");
node.getLeft().accept(this);
append(" " + convertOperator(node.getOperator()) + " ");
node.getRight().accept(this);
append(")");
} else if (targetLanguage.equals("Python")) {
append("(");
node.getLeft().accept(this);
String op = node.getOperator().equals("^") ? " ** " : " " + node.getOperator() + " ";
append(op);
node.getRight().accept(this);
append(")");
}
}
@Override
public void visitUnaryOpNode(UnaryOpNode node) {
append("(" + node.getOperator());
node.getOperand().accept(this);
append(")");
}
@Override
public void visitVariableNode(VariableNode node) {
append(node.getName());
}
@Override
public void visitFunctionCallNode(FunctionCallNode node) {
String functionName = convertFunctionName(node.getFunctionName());
append(functionName + "(");
List<ASTNode> args = node.getArguments();
for (int i = 0; i < args.size(); i++) {
if (i > 0) append(", ");
args.get(i).accept(this);
}
append(")");
}
private String convertOperator(String operator) {
if (targetLanguage.equals("C") && operator.equals("^")) {
return "pow";
}
return operator;
}
private String convertFunctionName(String functionName) {
if (targetLanguage.equals("C")) {
switch (functionName) {
case "sqrt": return "sqrt";
case "sin": return "sin";
case "cos": return "cos";
case "pow": return "pow";
default: return functionName;
}
}
return functionName;
}
}
// 具体访问者:语法分析
class SyntaxAnalyzerVisitor implements ASTVisitor {
private List<String> errors;
private Set<String> declaredVariables;
private Set<String> usedVariables;
private Set<String> availableFunctions;
public SyntaxAnalyzerVisitor() {
this.errors = new ArrayList<>();
this.declaredVariables = new HashSet<>();
this.usedVariables = new HashSet<>();
this.availableFunctions = new HashSet<>();
// 添加内置函数
availableFunctions.addAll(Arrays.asList("sin", "cos", "sqrt", "pow", "max", "min"));
}
public void declareVariable(String name) {
declaredVariables.add(name);
}
public void declareFunction(String name) {
availableFunctions.add(name);
}
public List<String> analyze(ASTNode node) {
errors.clear();
usedVariables.clear();
node.accept(this);
// 检查未声明的变量
for (String var : usedVariables) {
if (!declaredVariables.contains(var)) {
errors.add("未声明的变量: " + var);
}
}
return new ArrayList<>(errors);
}
@Override
public void visitNumberNode(NumberNode node) {
// 数字节点无需检查
}
@Override
public void visitBinaryOpNode(BinaryOpNode node) {
node.getLeft().accept(this);
node.getRight().accept(this);
// 检查除零
if (node.getOperator().equals("/") && node.getRight() instanceof NumberNode) {
NumberNode rightNum = (NumberNode) node.getRight();
if (rightNum.getValue() == 0) {
errors.add("除零错误");
}
}
}
@Override
public void visitUnaryOpNode(UnaryOpNode node) {
node.getOperand().accept(this);
}
@Override
public void visitVariableNode(VariableNode node) {
usedVariables.add(node.getName());
}
@Override
public void visitFunctionCallNode(FunctionCallNode node) {
String functionName = node.getFunctionName();
if (!availableFunctions.contains(functionName)) {
errors.add("未定义的函数: " + functionName);
}
// 检查参数
for (ASTNode arg : node.getArguments()) {
arg.accept(this);
}
// 检查参数数量(简化检查)
int argCount = node.getArguments().size();
switch (functionName) {
case "sin":
case "cos":
case "sqrt":
if (argCount != 1) {
errors.add(functionName + " 函数需要1个参数,但提供了 " + argCount + " 个");
}
break;
case "pow":
if (argCount != 2) {
errors.add("pow 函数需要2个参数,但提供了 " + argCount + " 个");
}
break;
}
}
}
// 具体访问者:优化器
class OptimizerVisitor implements ASTVisitor {
private Stack<ASTNode> nodeStack;
private boolean optimized;
public OptimizerVisitor() {
this.nodeStack = new Stack<>();
this.optimized = false;
}
public ASTNode optimize(ASTNode node) {
nodeStack.clear();
optimized = false;
node.accept(this);
return nodeStack.isEmpty() ? node : nodeStack.pop();
}
public boolean wasOptimized() {
return optimized;
}
@Override
public void visitNumberNode(NumberNode node) {
nodeStack.push(node);
}
@Override
public void visitBinaryOpNode(BinaryOpNode node) {
// 递归优化子节点
node.getLeft().accept(this);
ASTNode optimizedLeft = nodeStack.pop();
node.getRight().accept(this);
ASTNode optimizedRight = nodeStack.pop();
// 常量折叠
if (optimizedLeft instanceof NumberNode && optimizedRight instanceof NumberNode) {
NumberNode leftNum = (NumberNode) optimizedLeft;
NumberNode rightNum = (NumberNode) optimizedRight;
double result;
switch (node.getOperator()) {
case "+":
result = leftNum.getValue() + rightNum.getValue();
break;
case "-":
result = leftNum.getValue() - rightNum.getValue();
break;
case "*":
result = leftNum.getValue() * rightNum.getValue();
break;
case "/":
if (rightNum.getValue() == 0) {
// 保持原样,让运行时处理除零错误
nodeStack.push(new BinaryOpNode(optimizedLeft, node.getOperator(), optimizedRight));
return;
}
result = leftNum.getValue() / rightNum.getValue();
break;
case "^":
result = Math.pow(leftNum.getValue(), rightNum.getValue());
break;
default:
nodeStack.push(new BinaryOpNode(optimizedLeft, node.getOperator(), optimizedRight));
return;
}
nodeStack.push(new NumberNode(result));
optimized = true;
return;
}
// 代数简化
if (node.getOperator().equals("+")) {
// x + 0 = x
if (optimizedRight instanceof NumberNode && ((NumberNode) optimizedRight).getValue() == 0) {
nodeStack.push(optimizedLeft);
optimized = true;
return;
}
// 0 + x = x
if (optimizedLeft instanceof NumberNode && ((NumberNode) optimizedLeft).getValue() == 0) {
nodeStack.push(optimizedRight);
optimized = true;
return;
}
} else if (node.getOperator().equals("*")) {
// x * 1 = x
if (optimizedRight instanceof NumberNode && ((NumberNode) optimizedRight).getValue() == 1) {
nodeStack.push(optimizedLeft);
optimized = true;
return;
}
// 1 * x = x
if (optimizedLeft instanceof NumberNode && ((NumberNode) optimizedLeft).getValue() == 1) {
nodeStack.push(optimizedRight);
optimized = true;
return;
}
// x * 0 = 0
if ((optimizedRight instanceof NumberNode && ((NumberNode) optimizedRight).getValue() == 0) ||
(optimizedLeft instanceof NumberNode && ((NumberNode) optimizedLeft).getValue() == 0)) {
nodeStack.push(new NumberNode(0));
optimized = true;
return;
}
}
nodeStack.push(new BinaryOpNode(optimizedLeft, node.getOperator(), optimizedRight));
}
@Override
public void visitUnaryOpNode(UnaryOpNode node) {
node.getOperand().accept(this);
ASTNode optimizedOperand = nodeStack.pop();
// 常量折叠
if (optimizedOperand instanceof NumberNode) {
NumberNode operandNum = (NumberNode) optimizedOperand;
double result;
switch (node.getOperator()) {
case "-":
result = -operandNum.getValue();
break;
case "+":
result = operandNum.getValue();
break;
default:
nodeStack.push(new UnaryOpNode(node.getOperator(), optimizedOperand));
return;
}
nodeStack.push(new NumberNode(result));
optimized = true;
return;
}
nodeStack.push(new UnaryOpNode(node.getOperator(), optimizedOperand));
}
@Override
public void visitVariableNode(VariableNode node) {
nodeStack.push(node);
}
@Override
public void visitFunctionCallNode(FunctionCallNode node) {
List<ASTNode> optimizedArgs = new ArrayList<>();
for (ASTNode arg : node.getArguments()) {
arg.accept(this);
optimizedArgs.add(nodeStack.pop());
}
nodeStack.push(new FunctionCallNode(node.getFunctionName(), optimizedArgs));
}
}
// AST构建器
class ASTBuilder {
public static ASTNode buildSampleExpression() {
// 构建表达式: (x + 2) * (y - 1) + sin(z)
ASTNode x = new VariableNode("x");
ASTNode two = new NumberNode(2);
ASTNode xPlus2 = new BinaryOpNode(x, "+", two);
ASTNode y = new VariableNode("y");
ASTNode one = new NumberNode(1);
ASTNode yMinus1 = new BinaryOpNode(y, "-", one);
ASTNode multiply = new BinaryOpNode(xPlus2, "*", yMinus1);
ASTNode z = new VariableNode("z");
ASTNode sinZ = new FunctionCallNode("sin", Arrays.asList(z));
return new BinaryOpNode(multiply, "+", sinZ);
}
public static ASTNode buildOptimizationExample() {
// 构建表达式: (x + 0) * 1 - 0 * y
ASTNode x = new VariableNode("x");
ASTNode zero1 = new NumberNode(0);
ASTNode xPlusZero = new BinaryOpNode(x, "+", zero1);
ASTNode one = new NumberNode(1);
ASTNode multiplyOne = new BinaryOpNode(xPlusZero, "*", one);
ASTNode zero2 = new NumberNode(0);
ASTNode y = new VariableNode("y");
ASTNode zeroTimesY = new BinaryOpNode(zero2, "*", y);
return new BinaryOpNode(multiplyOne, "-", zeroTimesY);
}
}
// 演示类
public class ASTVisitorDemo {
public static void main(String[] args) {
System.out.println("=== AST访问者模式演示 ===");
// 构建示例表达式
ASTNode expression = ASTBuilder.buildSampleExpression();
System.out.println("原始表达式: " + expression);
// 1. 表达式求值
System.out.println("\n--- 表达式求值 ---");
EvaluatorVisitor evaluator = new EvaluatorVisitor();
evaluator.setVariable("x", 3.0);
evaluator.setVariable("y", 2.0);
evaluator.setVariable("z", Math.PI / 6); // 30度
double result = evaluator.evaluate(expression);
System.out.println("当 x=3, y=2, z=π/6 时,结果 = " + result);
// 2. 代码生成
System.out.println("\n--- 代码生成 ---");
CodeGeneratorVisitor cGenerator = new CodeGeneratorVisitor("C");
String cCode = cGenerator.generateCode(expression);
System.out.println("C代码: " + cCode);
CodeGeneratorVisitor pyGenerator = new CodeGeneratorVisitor("Python");
String pythonCode = pyGenerator.generateCode(expression);
System.out.println("Python代码: " + pythonCode);
// 3. 语法分析
System.out.println("\n--- 语法分析 ---");
SyntaxAnalyzerVisitor analyzer = new SyntaxAnalyzerVisitor();
analyzer.declareVariable("x");
analyzer.declareVariable("y");
// 故意不声明变量z来测试错误检测
List<String> errors = analyzer.analyze(expression);
if (errors.isEmpty()) {
System.out.println("语法检查通过");
} else {
System.out.println("发现语法错误:");
for (String error : errors) {
System.out.println(" - " + error);
}
}
// 4. 表达式优化
System.out.println("\n--- 表达式优化 ---");
ASTNode optimizationExample = ASTBuilder.buildOptimizationExample();
System.out.println("优化前: " + optimizationExample);
OptimizerVisitor optimizer = new OptimizerVisitor();
ASTNode optimized = optimizer.optimize(optimizationExample);
System.out.println("优化后: " + optimized);
System.out.println("是否进行了优化: " + optimizer.wasOptimized());
// 5. 多次优化
System.out.println("\n--- 多次优化 ---");
ASTNode current = optimizationExample;
int round = 1;
while (true) {
OptimizerVisitor roundOptimizer = new OptimizerVisitor();
ASTNode newOptimized = roundOptimizer.optimize(current);
if (!roundOptimizer.wasOptimized()) {
System.out.println("第" + round + "轮: 无法进一步优化");
break;
}
System.out.println("第" + round + "轮优化: " + newOptimized);
current = newOptimized;
round++;
if (round > 10) { // 防止无限循环
System.out.println("达到最大优化轮数");
break;
}
}
System.out.println("最终优化结果: " + current);
}
}
12.3.6 访问者模式的优缺点
优点: 1. 易于添加新操作:添加新的访问者类即可增加新操作,无需修改元素类 2. 集中相关操作:相关的操作被集中在一个访问者类中 3. 符合开闭原则:对扩展开放,对修改封闭 4. 类型安全:编译时就能检查操作的正确性 5. 数据与算法分离:将数据结构与作用于其上的操作分离
缺点: 1. 增加新元素困难:添加新的元素类需要修改所有访问者接口 2. 破坏封装性:元素类可能需要暴露内部状态给访问者 3. 依赖具体类:访问者通常依赖于具体的元素类 4. 复杂性增加:引入了额外的抽象层次
12.3.7 适用场景
- 对象结构稳定但操作经常变化
- 需要对对象结构中的元素执行很多不同操作
- 这些操作不相关,不希望污染元素类
- 需要在运行时确定具体的操作
12.4 迭代器模式与访问者模式对比
12.4.1 相同点
- 行为型模式:都属于行为型设计模式
- 遍历相关:都涉及对对象集合的遍历
- 解耦设计:都将特定行为从主要对象中分离出来
- 扩展性好:都支持在不修改原有代码的情况下添加新功能
12.4.2 不同点
方面 | 迭代器模式 | 访问者模式 |
---|---|---|
主要目的 | 提供统一的遍历接口 | 在不修改类的前提下添加新操作 |
关注点 | 如何遍历 | 遍历时做什么 |
元素访问 | 顺序访问元素 | 根据元素类型执行不同操作 |
操作定义 | 在客户端定义 | 在访问者中定义 |
类型安全 | 运行时确定 | 编译时确定 |
扩展方向 | 易于添加新的遍历方式 | 易于添加新的操作 |
对象结构 | 可以是同质的 | 通常是异质的 |
状态维护 | 维护遍历状态 | 不维护遍历状态 |
12.4.3 选择指南
选择迭代器模式当: - 需要为聚合对象提供多种遍历方式 - 希望隐藏聚合对象的内部结构 - 需要支持并发遍历 - 遍历算法比较复杂
选择访问者模式当: - 对象结构稳定但操作经常变化 - 需要对不同类型的元素执行不同操作 - 操作之间相关性强,希望集中管理 - 需要在编译时确保类型安全
12.4.4 组合使用示例
from abc import ABC, abstractmethod
from typing import List, Iterator as TypingIterator
# 组合使用迭代器模式和访问者模式
class CompositeElement(ABC):
"""复合元素接口"""
@abstractmethod
def accept(self, visitor: 'ElementVisitor'):
"""接受访问者"""
pass
@abstractmethod
def create_iterator(self) -> 'ElementIterator':
"""创建迭代器"""
pass
class ElementVisitor(ABC):
"""元素访问者接口"""
@abstractmethod
def visit_text_element(self, element: 'TextElement'):
pass
@abstractmethod
def visit_image_element(self, element: 'ImageElement'):
pass
@abstractmethod
def visit_container_element(self, element: 'ContainerElement'):
pass
class ElementIterator(ABC):
"""元素迭代器接口"""
@abstractmethod
def has_next(self) -> bool:
pass
@abstractmethod
def next(self) -> CompositeElement:
pass
@abstractmethod
def reset(self):
pass
# 具体元素类
class TextElement(CompositeElement):
def __init__(self, content: str):
self.content = content
def accept(self, visitor: ElementVisitor):
visitor.visit_text_element(self)
def create_iterator(self) -> ElementIterator:
return SingleElementIterator(self)
class ImageElement(CompositeElement):
def __init__(self, src: str, alt: str):
self.src = src
self.alt = alt
def accept(self, visitor: ElementVisitor):
visitor.visit_image_element(self)
def create_iterator(self) -> ElementIterator:
return SingleElementIterator(self)
class ContainerElement(CompositeElement):
def __init__(self, tag: str):
self.tag = tag
self.children: List[CompositeElement] = []
def add_child(self, child: CompositeElement):
self.children.append(child)
def accept(self, visitor: ElementVisitor):
visitor.visit_container_element(self)
# 使用迭代器遍历子元素
iterator = self.create_iterator()
while iterator.has_next():
child = iterator.next()
if child != self: # 避免无限递归
child.accept(visitor)
def create_iterator(self) -> ElementIterator:
return ContainerElementIterator(self.children)
# 具体迭代器
class SingleElementIterator(ElementIterator):
def __init__(self, element: CompositeElement):
self.element = element
self.returned = False
def has_next(self) -> bool:
return not self.returned
def next(self) -> CompositeElement:
if not self.has_next():
raise StopIteration()
self.returned = True
return self.element
def reset(self):
self.returned = False
class ContainerElementIterator(ElementIterator):
def __init__(self, children: List[CompositeElement]):
self.children = children
self.index = 0
def has_next(self) -> bool:
return self.index < len(self.children)
def next(self) -> CompositeElement:
if not self.has_next():
raise StopIteration()
element = self.children[self.index]
self.index += 1
return element
def reset(self):
self.index = 0
# 具体访问者:HTML生成器
class HTMLGeneratorVisitor(ElementVisitor):
def __init__(self):
self.html = []
self.indent_level = 0
def visit_text_element(self, element: TextElement):
self._add_line(element.content)
def visit_image_element(self, element: ImageElement):
self._add_line(f'<img src="{element.src}" alt="{element.alt}" />')
def visit_container_element(self, element: ContainerElement):
self._add_line(f'<{element.tag}>')
self.indent_level += 1
# 注意:子元素的访问由accept方法中的迭代器处理
self.indent_level -= 1
self._add_line(f'</{element.tag}>')
def _add_line(self, content: str):
indent = ' ' * self.indent_level
self.html.append(indent + content)
def get_html(self) -> str:
return '\n'.join(self.html)
# 具体访问者:统计访问者
class StatisticsVisitor(ElementVisitor):
def __init__(self):
self.text_count = 0
self.image_count = 0
self.container_count = 0
self.total_text_length = 0
def visit_text_element(self, element: TextElement):
self.text_count += 1
self.total_text_length += len(element.content)
def visit_image_element(self, element: ImageElement):
self.image_count += 1
def visit_container_element(self, element: ContainerElement):
self.container_count += 1
def get_statistics(self) -> dict:
return {
'text_elements': self.text_count,
'image_elements': self.image_count,
'container_elements': self.container_count,
'total_text_length': self.total_text_length,
'average_text_length': self.total_text_length / self.text_count if self.text_count > 0 else 0
}
# 演示组合使用
def demonstrate_combined_patterns():
print("=== 迭代器模式与访问者模式组合使用 ===")
# 构建文档结构
document = ContainerElement("html")
head = ContainerElement("head")
head.add_child(TextElement("<title>示例页面</title>"))
document.add_child(head)
body = ContainerElement("body")
body.add_child(TextElement("<h1>欢迎来到我的网站</h1>"))
body.add_child(TextElement("<p>这是一个段落。</p>"))
body.add_child(ImageElement("logo.png", "网站Logo"))
footer = ContainerElement("footer")
footer.add_child(TextElement("<p>版权所有 © 2024</p>"))
body.add_child(footer)
document.add_child(body)
# 使用HTML生成器访问者
print("\n--- HTML生成 ---")
html_generator = HTMLGeneratorVisitor()
document.accept(html_generator)
print(html_generator.get_html())
# 使用统计访问者
print("\n--- 文档统计 ---")
stats_visitor = StatisticsVisitor()
document.accept(stats_visitor)
stats = stats_visitor.get_statistics()
for key, value in stats.items():
print(f"{key}: {value}")
# 演示迭代器的独立使用
print("\n--- 独立使用迭代器遍历body元素 ---")
body_iterator = body.create_iterator()
while body_iterator.has_next():
child = body_iterator.next()
print(f"子元素类型: {type(child).__name__}")
if __name__ == "__main__":
demonstrate_combined_patterns()
12.5 高级应用技巧
12.5.1 迭代器模式高级技巧
1. 惰性求值迭代器
class LazyIterator:
"""惰性求值迭代器"""
def __init__(self, data_source, transform_func=None):
self.data_source = data_source
self.transform_func = transform_func or (lambda x: x)
self.index = 0
def __iter__(self):
return self
def __next__(self):
if self.index >= len(self.data_source):
raise StopIteration
# 惰性计算:只在需要时才进行转换
item = self.data_source[self.index]
result = self.transform_func(item)
self.index += 1
return result
2. 缓存迭代器
class CachedIterator:
"""带缓存的迭代器"""
def __init__(self, source_iterator, cache_size=100):
self.source_iterator = source_iterator
self.cache = []
self.cache_size = cache_size
self.position = 0
def has_next(self):
return self.position < len(self.cache) or self._load_next()
def next(self):
if not self.has_next():
raise StopIteration
item = self.cache[self.position]
self.position += 1
return item
def _load_next(self):
try:
if len(self.cache) < self.cache_size:
item = next(self.source_iterator)
self.cache.append(item)
return True
except StopIteration:
pass
return False
12.5.2 访问者模式高级技巧
1. 双分派访问者
class DoubleDispatchVisitor:
"""双分派访问者"""
def visit(self, element, context=None):
method_name = f"visit_{element.__class__.__name__.lower()}"
method = getattr(self, method_name, self.generic_visit)
return method(element, context)
def generic_visit(self, element, context=None):
"""默认访问方法"""
pass
2. 参数化访问者
class ParameterizedVisitor:
"""参数化访问者"""
def __init__(self, **parameters):
self.parameters = parameters
def visit_with_params(self, element, **additional_params):
# 合并参数
all_params = {**self.parameters, **additional_params}
return self.visit(element, all_params)
12.5.3 性能优化技巧
1. 迭代器池化
class IteratorPool:
"""迭代器对象池"""
def __init__(self, iterator_class, pool_size=10):
self.iterator_class = iterator_class
self.pool = []
self.pool_size = pool_size
def get_iterator(self, *args, **kwargs):
if self.pool:
iterator = self.pool.pop()
iterator.reset(*args, **kwargs)
return iterator
else:
return self.iterator_class(*args, **kwargs)
def return_iterator(self, iterator):
if len(self.pool) < self.pool_size:
self.pool.append(iterator)
2. 访问者缓存
class CachedVisitor:
"""带缓存的访问者"""
def __init__(self):
self.cache = {}
def visit(self, element):
element_id = id(element)
if element_id in self.cache:
return self.cache[element_id]
result = self._do_visit(element)
self.cache[element_id] = result
return result
def _do_visit(self, element):
# 实际的访问逻辑
pass
12.6 实际应用案例
12.6.1 数据库查询结果迭代器
class DatabaseResultIterator:
"""数据库查询结果迭代器"""
def __init__(self, connection, query, batch_size=1000):
self.connection = connection
self.query = query
self.batch_size = batch_size
self.cursor = None
self.current_batch = []
self.batch_index = 0
self.total_fetched = 0
def __enter__(self):
self.cursor = self.connection.cursor()
self.cursor.execute(self.query)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.cursor:
self.cursor.close()
def __iter__(self):
return self
def __next__(self):
if self.batch_index >= len(self.current_batch):
self._fetch_next_batch()
if not self.current_batch:
raise StopIteration
row = self.current_batch[self.batch_index]
self.batch_index += 1
return row
def _fetch_next_batch(self):
self.current_batch = self.cursor.fetchmany(self.batch_size)
self.batch_index = 0
self.total_fetched += len(self.current_batch)
12.6.2 Web爬虫数据处理访问者
class WebPageElement:
"""网页元素基类"""
def __init__(self, url, content):
self.url = url
self.content = content
self.timestamp = time.time()
def accept(self, visitor):
visitor.visit(self)
class HTMLPageElement(WebPageElement):
def __init__(self, url, html_content, title, links):
super().__init__(url, html_content)
self.title = title
self.links = links
class ImageElement(WebPageElement):
def __init__(self, url, image_data, format, size):
super().__init__(url, image_data)
self.format = format
self.size = size
class DataExtractionVisitor:
"""数据提取访问者"""
def __init__(self):
self.extracted_data = []
def visit(self, element):
if isinstance(element, HTMLPageElement):
self._extract_html_data(element)
elif isinstance(element, ImageElement):
self._extract_image_data(element)
def _extract_html_data(self, element):
data = {
'type': 'html',
'url': element.url,
'title': element.title,
'link_count': len(element.links),
'content_length': len(element.content),
'timestamp': element.timestamp
}
self.extracted_data.append(data)
def _extract_image_data(self, element):
data = {
'type': 'image',
'url': element.url,
'format': element.format,
'size': element.size,
'timestamp': element.timestamp
}
self.extracted_data.append(data)
12.7 本章总结
12.7.1 核心概念回顾
迭代器模式: - 提供统一的遍历接口,隐藏聚合对象的内部结构 - 支持多种遍历方式,如顺序、逆序、随机、过滤等 - 将遍历责任从聚合对象中分离出来
访问者模式: - 在不修改元素类的前提下定义新操作 - 将数据结构与作用于其上的操作分离 - 支持对不同类型元素执行不同操作
12.7.2 最佳实践
- 合理选择模式:根据问题特点选择合适的模式
- 性能考虑:注意迭代器的内存使用和访问者的计算复杂度
- 异常处理:确保迭代器和访问者能正确处理异常情况
- 线程安全:在多线程环境中注意同步问题
- 组合使用:两种模式可以有效组合使用
12.7.3 实际应用建议
迭代器模式适用于:
- 数据库结果集遍历
- 文件系统遍历
- 集合框架设计
- 大数据流处理
访问者模式适用于:
- 编译器设计(AST处理)
- 文档处理系统
- 图形处理系统
- 数据分析工具
12.7.4 注意事项
迭代器模式:
- 避免在遍历过程中修改聚合对象
- 注意迭代器的生命周期管理
- 考虑并发访问的安全性
访问者模式:
- 确保访问者接口的稳定性
- 注意元素类的封装性
- 避免访问者之间的相互依赖
12.8 练习题
12.8.1 基础练习
迭代器实现:为二叉树实现前序、中序、后序三种遍历迭代器。
访问者应用:设计一个图形系统,包含圆形、矩形、三角形等图形,实现面积计算、周长计算、绘制等访问者。
组合使用:结合迭代器和访问者模式,实现一个简单的文件系统分析工具。
12.8.2 进阶练习
性能优化:实现一个支持惰性求值的迭代器,用于处理大型数据集。
多态访问者:设计一个支持多种输出格式(JSON、XML、CSV)的数据导出访问者。
状态感知迭代器:实现一个能够记住遍历历史并支持回退操作的迭代器。
12.8.3 思考题
设计权衡:在什么情况下应该选择迭代器模式而不是简单的for循环?
扩展性分析:访问者模式在面对频繁的元素类型变化时有什么局限性?如何改进?
模式演化:随着函数式编程的流行,这两种模式是否还有存在的必要?为什么?
下一章预告
在下一章中,我们将学习备忘录模式与解释器模式,探讨如何保存和恢复对象状态,以及如何构建简单的领域特定语言。这两种模式在撤销/重做功能和规则引擎中有着重要应用。 “`