# 并发编程
Goroutine基础
什么是Goroutine
创建和启动Goroutine
package main
import (
"fmt"
"time"
)
// 普通函数
func sayHello(name string) {
for i := 0; i < 5; i++ {
fmt.Printf("Hello %s! (%d)\n", name, i+1)
time.Sleep(100 * time.Millisecond)
}
}
// 计数函数
func count(name string) {
for i := 1; i <= 5; i++ {
fmt.Printf("%s: %d\n", name, i)
time.Sleep(200 * time.Millisecond)
}
}
func main() {
fmt.Println("=== 顺序执行 ===")
sayHello("Alice")
sayHello("Bob")
fmt.Println("\n=== 并发执行 ===")
// 启动goroutine
go sayHello("Charlie")
go sayHello("David")
// 主goroutine继续执行
sayHello("Main")
fmt.Println("\n=== 多个goroutine ===")
// 启动多个goroutine
go count("Goroutine-1")
go count("Goroutine-2")
go count("Goroutine-3")
// 等待goroutine完成
time.Sleep(2 * time.Second)
fmt.Println("程序结束")
}
go run goroutine_basic.go
匿名函数和闭包Goroutine
package main
import (
"fmt"
"time"
)
func main() {
fmt.Println("=== 匿名函数goroutine ===")
// 匿名函数goroutine
go func() {
for i := 0; i < 3; i++ {
fmt.Printf("匿名goroutine: %d\n", i)
time.Sleep(100 * time.Millisecond)
}
}()
// 带参数的匿名函数goroutine
message := "Hello from goroutine"
go func(msg string) {
for i := 0; i < 3; i++ {
fmt.Printf("%s: %d\n", msg, i)
time.Sleep(150 * time.Millisecond)
}
}(message)
fmt.Println("\n=== 闭包goroutine ===")
// 闭包goroutine(注意变量捕获)
for i := 0; i < 3; i++ {
// 错误的方式:所有goroutine可能看到相同的i值
go func() {
fmt.Printf("错误方式 - i: %d\n", i)
}()
// 正确的方式:传递参数
go func(index int) {
fmt.Printf("正确方式 - index: %d\n", index)
}(i)
// 另一种正确方式:在循环内声明变量
j := i
go func() {
fmt.Printf("另一种正确方式 - j: %d\n", j)
}()
}
// 等待所有goroutine完成
time.Sleep(1 * time.Second)
fmt.Println("\n=== 共享变量示例 ===")
counter := 0
// 启动多个goroutine修改共享变量(不安全)
for i := 0; i < 5; i++ {
go func(id int) {
for j := 0; j < 100; j++ {
counter++ // 竞态条件
}
fmt.Printf("Goroutine %d 完成\n", id)
}(i)
}
time.Sleep(500 * time.Millisecond)
fmt.Printf("最终计数器值: %d (期望: 500)\n", counter)
fmt.Println("程序结束")
}
go run goroutine_anonymous.go
Goroutine的生命周期
package main
import (
"fmt"
"runtime"
"time"
)
// 长时间运行的goroutine
func longRunningTask(id int, duration time.Duration) {
fmt.Printf("任务 %d 开始,预计运行 %v\n", id, duration)
start := time.Now()
for time.Since(start) < duration {
// 模拟工作
time.Sleep(100 * time.Millisecond)
// 检查是否应该退出(在实际应用中可能使用context)
if time.Since(start) > duration {
break
}
}
fmt.Printf("任务 %d 完成,实际运行时间: %v\n", id, time.Since(start))
}
// 监控goroutine数量
func monitorGoroutines() {
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for i := 0; i < 10; i++ {
select {
case <-ticker.C:
fmt.Printf("当前goroutine数量: %d\n", runtime.NumGoroutine())
}
}
}
func main() {
fmt.Printf("程序开始,初始goroutine数量: %d\n", runtime.NumGoroutine())
// 启动监控goroutine
go monitorGoroutines()
fmt.Println("\n=== 启动多个任务 ===")
// 启动多个不同持续时间的任务
go longRunningTask(1, 500*time.Millisecond)
go longRunningTask(2, 800*time.Millisecond)
go longRunningTask(3, 300*time.Millisecond)
go longRunningTask(4, 1*time.Second)
time.Sleep(100 * time.Millisecond)
fmt.Printf("启动任务后goroutine数量: %d\n", runtime.NumGoroutine())
// 等待所有任务完成
time.Sleep(1500 * time.Millisecond)
fmt.Printf("\n任务完成后goroutine数量: %d\n", runtime.NumGoroutine())
// 演示goroutine泄漏
fmt.Println("\n=== Goroutine泄漏示例 ===")
for i := 0; i < 5; i++ {
go func(id int) {
// 无限循环的goroutine(会造成泄漏)
for {
time.Sleep(1 * time.Second)
// 在实际应用中,这种goroutine应该有退出机制
}
}(i)
}
time.Sleep(200 * time.Millisecond)
fmt.Printf("创建无限循环goroutine后数量: %d\n", runtime.NumGoroutine())
fmt.Println("\n注意:程序结束时,所有goroutine都会被终止")
fmt.Println("在实际应用中,应该提供优雅的退出机制")
}
go run goroutine_lifecycle.go
Channel基础
Channel的创建和使用
基本Channel操作
package main
import (
"fmt"
"time"
)
func main() {
fmt.Println("=== 基本Channel操作 ===")
// 创建channel
ch := make(chan int)
// 启动goroutine发送数据
go func() {
fmt.Println("发送数据到channel")
ch <- 42
fmt.Println("数据已发送")
}()
// 从channel接收数据
fmt.Println("等待接收数据...")
value := <-ch
fmt.Printf("接收到数据: %d\n", value)
fmt.Println("\n=== 多次发送和接收 ===")
// 创建字符串channel
messages := make(chan string)
// 发送多个消息
go func() {
messages <- "Hello"
messages <- "World"
messages <- "from"
messages <- "Go"
close(messages) // 关闭channel
}()
// 接收所有消息
for msg := range messages {
fmt.Printf("接收到消息: %s\n", msg)
time.Sleep(100 * time.Millisecond)
}
fmt.Println("\n=== Channel方向 ===")
// 只发送channel
sendOnly := make(chan<- int)
// 只接收channel
receiveOnly := make(<-chan int)
fmt.Printf("只发送channel类型: %T\n", sendOnly)
fmt.Printf("只接收channel类型: %T\n", receiveOnly)
// 双向channel可以转换为单向channel
bidirectional := make(chan int)
go sender(bidirectional) // 传递给只发送函数
receiver(bidirectional) // 传递给只接收函数
}
// 只发送函数
func sender(ch chan<- int) {
for i := 1; i <= 3; i++ {
fmt.Printf("发送: %d\n", i)
ch <- i
time.Sleep(100 * time.Millisecond)
}
close(ch)
}
// 只接收函数
func receiver(ch <-chan int) {
for value := range ch {
fmt.Printf("接收: %d\n", value)
}
fmt.Println("接收完成")
}
go run channel_basic.go
缓冲Channel
package main
import (
"fmt"
"time"
)
func main() {
fmt.Println("=== 无缓冲Channel ===")
// 无缓冲channel(同步)
unbuffered := make(chan int)
go func() {
fmt.Println("准备发送到无缓冲channel")
unbuffered <- 1
fmt.Println("发送完成")
}()
time.Sleep(100 * time.Millisecond)
fmt.Println("准备接收")
value := <-unbuffered
fmt.Printf("接收到: %d\n", value)
fmt.Println("\n=== 缓冲Channel ===")
// 缓冲channel(异步)
buffered := make(chan int, 3)
// 发送数据(不会阻塞,直到缓冲区满)
fmt.Println("发送数据到缓冲channel")
buffered <- 1
fmt.Println("发送 1")
buffered <- 2
fmt.Println("发送 2")
buffered <- 3
fmt.Println("发送 3")
fmt.Printf("缓冲区长度: %d, 容量: %d\n", len(buffered), cap(buffered))
// 接收数据
fmt.Println("\n开始接收数据:")
for i := 0; i < 3; i++ {
value := <-buffered
fmt.Printf("接收到: %d\n", value)
fmt.Printf("缓冲区长度: %d\n", len(buffered))
}
fmt.Println("\n=== 缓冲Channel的阻塞行为 ===")
buffer := make(chan string, 2)
// 启动发送goroutine
go func() {
messages := []string{"msg1", "msg2", "msg3", "msg4"}
for i, msg := range messages {
fmt.Printf("准备发送: %s\n", msg)
buffer <- msg
fmt.Printf("已发送: %s (第%d个)\n", msg, i+1)
time.Sleep(200 * time.Millisecond)
}
close(buffer)
}()
// 延迟接收
time.Sleep(500 * time.Millisecond)
fmt.Println("开始接收:")
for msg := range buffer {
fmt.Printf("接收到: %s\n", msg)
time.Sleep(300 * time.Millisecond)
}
fmt.Println("\n=== 生产者-消费者模式 ===")
jobs := make(chan int, 5)
results := make(chan int, 5)
// 启动3个工作者
for w := 1; w <= 3; w++ {
go worker(w, jobs, results)
}
// 发送5个任务
for j := 1; j <= 5; j++ {
jobs <- j
}
close(jobs)
// 收集结果
for r := 1; r <= 5; r++ {
result := <-results
fmt.Printf("结果: %d\n", result)
}
}
func worker(id int, jobs <-chan int, results chan<- int) {
for job := range jobs {
fmt.Printf("工作者 %d 开始任务 %d\n", id, job)
time.Sleep(time.Duration(job*100) * time.Millisecond)
result := job * 2
fmt.Printf("工作者 %d 完成任务 %d,结果: %d\n", id, job, result)
results <- result
}
}
go run channel_buffered.go
Channel的关闭和检测
package main
import (
"fmt"
"time"
)
func main() {
fmt.Println("=== Channel关闭检测 ===")
ch := make(chan int, 3)
// 发送一些数据然后关闭
go func() {
for i := 1; i <= 3; i++ {
ch <- i
fmt.Printf("发送: %d\n", i)
}
fmt.Println("关闭channel")
close(ch)
}()
time.Sleep(100 * time.Millisecond)
// 方法1:使用range(推荐)
fmt.Println("\n使用range接收:")
for value := range ch {
fmt.Printf("接收到: %d\n", value)
}
fmt.Println("range循环结束(channel已关闭)")
// 重新创建channel进行其他测试
ch2 := make(chan int, 2)
go func() {
ch2 <- 10
ch2 <- 20
close(ch2)
}()
time.Sleep(100 * time.Millisecond)
// 方法2:使用ok检测
fmt.Println("\n使用ok检测:")
for {
value, ok := <-ch2
if !ok {
fmt.Println("channel已关闭")
break
}
fmt.Printf("接收到: %d\n", value)
}
fmt.Println("\n=== 从已关闭的channel读取 ===")
ch3 := make(chan string, 1)
ch3 <- "hello"
close(ch3)
// 从已关闭的channel读取剩余数据
value1, ok1 := <-ch3
fmt.Printf("第一次读取: %s, ok: %t\n", value1, ok1)
// 从已关闭且空的channel读取
value2, ok2 := <-ch3
fmt.Printf("第二次读取: %s, ok: %t\n", value2, ok2)
fmt.Println("\n=== 多个发送者,一个接收者 ===")
messages := make(chan string, 10)
done := make(chan bool)
// 启动多个发送者
for i := 1; i <= 3; i++ {
go sender(i, messages)
}
// 启动接收者
go func() {
count := 0
for msg := range messages {
fmt.Printf("接收到消息: %s\n", msg)
count++
if count == 9 { // 期望接收9条消息(每个发送者3条)
break
}
}
done <- true
}()
// 等待接收完成
<-done
close(messages)
fmt.Println("\n=== 错误处理:向已关闭的channel发送 ===")
ch4 := make(chan int)
close(ch4)
// 使用defer和recover捕获panic
defer func() {
if r := recover(); r != nil {
fmt.Printf("捕获到panic: %v\n", r)
}
}()
fmt.Println("尝试向已关闭的channel发送数据...")
ch4 <- 1 // 这会引发panic
}
func sender(id int, ch chan<- string) {
for i := 1; i <= 3; i++ {
message := fmt.Sprintf("发送者%d-消息%d", id, i)
ch <- message
time.Sleep(100 * time.Millisecond)
}
fmt.Printf("发送者 %d 完成\n", id)
}
go run channel_close.go
Select语句
Select基础
Select的基本用法
package main
import (
"fmt"
"time"
)
func main() {
fmt.Println("=== Select基本用法 ===")
ch1 := make(chan string)
ch2 := make(chan string)
// 启动两个goroutine发送数据
go func() {
time.Sleep(200 * time.Millisecond)
ch1 <- "来自channel 1的消息"
}()
go func() {
time.Sleep(100 * time.Millisecond)
ch2 <- "来自channel 2的消息"
}()
// 使用select等待第一个可用的channel
select {
case msg1 := <-ch1:
fmt.Printf("接收到: %s\n", msg1)
case msg2 := <-ch2:
fmt.Printf("接收到: %s\n", msg2)
}
fmt.Println("\n=== Select with timeout ===")
ch3 := make(chan string)
go func() {
time.Sleep(500 * time.Millisecond)
ch3 <- "延迟消息"
}()
select {
case msg := <-ch3:
fmt.Printf("接收到消息: %s\n", msg)
case <-time.After(300 * time.Millisecond):
fmt.Println("超时!没有接收到消息")
}
fmt.Println("\n=== Select with default ===")
ch4 := make(chan int, 1)
// 非阻塞发送
select {
case ch4 <- 42:
fmt.Println("成功发送到channel")
default:
fmt.Println("channel已满,无法发送")
}
// 非阻塞接收
select {
case value := <-ch4:
fmt.Printf("接收到值: %d\n", value)
default:
fmt.Println("channel为空,无法接收")
}
// 再次尝试非阻塞接收
select {
case value := <-ch4:
fmt.Printf("接收到值: %d\n", value)
default:
fmt.Println("channel为空,无法接收")
}
fmt.Println("\n=== 多路复用示例 ===")
// 创建多个channel
channels := make([]chan int, 3)
for i := range channels {
channels[i] = make(chan int)
}
// 启动发送者
for i, ch := range channels {
go func(id int, channel chan int) {
for j := 0; j < 3; j++ {
time.Sleep(time.Duration(id*100+j*50) * time.Millisecond)
channel <- id*10 + j
}
close(channel)
}(i, ch)
}
// 使用select接收所有消息
activeChannels := len(channels)
for activeChannels > 0 {
select {
case value, ok := <-channels[0]:
if ok {
fmt.Printf("从channel 0 接收: %d\n", value)
} else {
fmt.Println("Channel 0 已关闭")
channels[0] = nil
activeChannels--
}
case value, ok := <-channels[1]:
if ok {
fmt.Printf("从channel 1 接收: %d\n", value)
} else {
fmt.Println("Channel 1 已关闭")
channels[1] = nil
activeChannels--
}
case value, ok := <-channels[2]:
if ok {
fmt.Printf("从channel 2 接收: %d\n", value)
} else {
fmt.Println("Channel 2 已关闭")
channels[2] = nil
activeChannels--
}
}
}
fmt.Println("所有channel都已关闭")
}
go run select_basic.go
Select的高级应用
package main
import (
"fmt"
"math/rand"
"time"
)
// 工作任务结构
type Job struct {
ID int
Data string
}
// 结果结构
type Result struct {
JobID int
Output string
Error error
}
func main() {
fmt.Println("=== 工作池模式 ===")
jobs := make(chan Job, 10)
results := make(chan Result, 10)
done := make(chan bool)
// 启动3个工作者
for w := 1; w <= 3; w++ {
go worker(w, jobs, results)
}
// 发送任务
go func() {
for i := 1; i <= 5; i++ {
job := Job{ID: i, Data: fmt.Sprintf("task-%d", i)}
jobs <- job
fmt.Printf("发送任务: %+v\n", job)
}
close(jobs)
}()
// 收集结果
go func() {
for i := 0; i < 5; i++ {
result := <-results
fmt.Printf("收到结果: %+v\n", result)
}
done <- true
}()
// 等待完成或超时
select {
case <-done:
fmt.Println("所有任务完成")
case <-time.After(3 * time.Second):
fmt.Println("超时!")
}
fmt.Println("\n=== 心跳监控 ===")
heartbeat := make(chan string)
quit := make(chan bool)
// 启动心跳发送者
go func() {
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
heartbeat <- "ping"
case <-quit:
fmt.Println("心跳发送者退出")
return
}
}
}()
// 监控心跳
timeout := time.After(1 * time.Second)
heartbeatCount := 0
heartbeatLoop:
for {
select {
case beat := <-heartbeat:
heartbeatCount++
fmt.Printf("收到心跳 #%d: %s\n", heartbeatCount, beat)
if heartbeatCount >= 4 {
break heartbeatLoop
}
case <-timeout:
fmt.Println("心跳超时!")
break heartbeatLoop
}
}
quit <- true
fmt.Println("\n=== 扇入模式(Fan-in) ===")
// 创建多个输入channel
input1 := make(chan string)
input2 := make(chan string)
input3 := make(chan string)
// 启动多个生产者
go producer("Producer-1", input1, 3)
go producer("Producer-2", input2, 3)
go producer("Producer-3", input3, 3)
// 扇入:合并多个channel到一个
merged := fanIn(input1, input2, input3)
// 接收合并后的数据
for i := 0; i < 9; i++ {
msg := <-merged
fmt.Printf("合并接收: %s\n", msg)
}
fmt.Println("\n=== 扇出模式(Fan-out) ===")
source := make(chan int)
// 启动数据源
go func() {
for i := 1; i <= 6; i++ {
source <- i
time.Sleep(100 * time.Millisecond)
}
close(source)
}()
// 扇出:分发到多个处理器
out1, out2, out3 := fanOut(source)
// 启动多个消费者
go consumer("Consumer-1", out1)
go consumer("Consumer-2", out2)
go consumer("Consumer-3", out3)
time.Sleep(1 * time.Second)
}
func worker(id int, jobs <-chan Job, results chan<- Result) {
for job := range jobs {
fmt.Printf("工作者 %d 开始处理任务 %d\n", id, job.ID)
// 模拟工作
time.Sleep(time.Duration(rand.Intn(300)+100) * time.Millisecond)
result := Result{
JobID: job.ID,
Output: fmt.Sprintf("处理结果-%d", job.ID),
Error: nil,
}
results <- result
fmt.Printf("工作者 %d 完成任务 %d\n", id, job.ID)
}
}
func producer(name string, output chan<- string, count int) {
for i := 1; i <= count; i++ {
msg := fmt.Sprintf("%s-消息%d", name, i)
output <- msg
time.Sleep(time.Duration(rand.Intn(200)+50) * time.Millisecond)
}
close(output)
}
// 扇入:合并多个channel
func fanIn(inputs ...<-chan string) <-chan string {
output := make(chan string)
go func() {
defer close(output)
// 跟踪活跃的channel
activeChannels := len(inputs)
cases := make([]chan string, len(inputs))
copy(cases, inputs)
for activeChannels > 0 {
// 动态构建select cases
for i, ch := range cases {
if ch != nil {
select {
case msg, ok := <-ch:
if ok {
output <- msg
} else {
cases[i] = nil
activeChannels--
}
default:
// 非阻塞,继续下一个channel
}
}
}
time.Sleep(10 * time.Millisecond) // 避免忙等待
}
}()
return output
}
// 扇出:分发到多个channel
func fanOut(input <-chan int) (<-chan int, <-chan int, <-chan int) {
out1 := make(chan int)
out2 := make(chan int)
out3 := make(chan int)
go func() {
defer close(out1)
defer close(out2)
defer close(out3)
for value := range input {
// 复制到所有输出channel
out1 <- value
out2 <- value
out3 <- value
}
}()
return out1, out2, out3
}
func consumer(name string, input <-chan int) {
for value := range input {
fmt.Printf("%s 处理: %d\n", name, value)
time.Sleep(50 * time.Millisecond)
}
fmt.Printf("%s 完成\n", name)
}
go run select_advanced.go
同步原语
WaitGroup
WaitGroup基础使用
package main
import (
"fmt"
"sync"
"time"
)
func main() {
fmt.Println("=== WaitGroup基础使用 ===")
var wg sync.WaitGroup
// 启动3个goroutine
for i := 1; i <= 3; i++ {
wg.Add(1) // 增加等待计数
go func(id int) {
defer wg.Done() // 完成时减少计数
fmt.Printf("Goroutine %d 开始\n", id)
time.Sleep(time.Duration(id*200) * time.Millisecond)
fmt.Printf("Goroutine %d 完成\n", id)
}(i)
}
fmt.Println("等待所有goroutine完成...")
wg.Wait() // 等待计数归零
fmt.Println("所有goroutine已完成")
fmt.Println("\n=== 批量任务处理 ===")
tasks := []string{"任务A", "任务B", "任务C", "任务D", "任务E"}
var taskWg sync.WaitGroup
for _, task := range tasks {
taskWg.Add(1)
go func(taskName string) {
defer taskWg.Done()
processTask(taskName)
}(task)
}
taskWg.Wait()
fmt.Println("所有任务处理完成")
fmt.Println("\n=== 工作池模式 ===")
const numWorkers = 3
const numJobs = 10
jobs := make(chan int, numJobs)
var workerWg sync.WaitGroup
// 启动工作者
for w := 1; w <= numWorkers; w++ {
workerWg.Add(1)
go worker(w, jobs, &workerWg)
}
// 发送任务
for j := 1; j <= numJobs; j++ {
jobs <- j
}
close(jobs)
// 等待所有工作者完成
workerWg.Wait()
fmt.Println("所有工作者已完成")
fmt.Println("\n=== 分阶段处理 ===")
data := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
// 阶段1:数据预处理
var stage1Wg sync.WaitGroup
preprocessed := make(chan int, len(data))
for _, value := range data {
stage1Wg.Add(1)
go func(v int) {
defer stage1Wg.Done()
result := v * 2 // 预处理:乘以2
fmt.Printf("预处理: %d -> %d\n", v, result)
preprocessed <- result
}(value)
}
stage1Wg.Wait()
close(preprocessed)
fmt.Println("阶段1完成")
// 阶段2:数据处理
var stage2Wg sync.WaitGroup
processed := make(chan int, len(data))
for value := range preprocessed {
stage2Wg.Add(1)
go func(v int) {
defer stage2Wg.Done()
result := v + 10 // 处理:加10
fmt.Printf("处理: %d -> %d\n", v, result)
processed <- result
}(value)
}
stage2Wg.Wait()
close(processed)
fmt.Println("阶段2完成")
// 收集最终结果
var results []int
for result := range processed {
results = append(results, result)
}
fmt.Printf("最终结果: %v\n", results)
}
func processTask(taskName string) {
fmt.Printf("开始处理 %s\n", taskName)
// 模拟任务处理时间
time.Sleep(time.Duration(len(taskName)*100) * time.Millisecond)
fmt.Printf("%s 处理完成\n", taskName)
}
func worker(id int, jobs <-chan int, wg *sync.WaitGroup) {
defer wg.Done()
for job := range jobs {
fmt.Printf("工作者 %d 开始任务 %d\n", id, job)
time.Sleep(100 * time.Millisecond)
fmt.Printf("工作者 %d 完成任务 %d\n", id, job)
}
fmt.Printf("工作者 %d 退出\n", id)
}
go run waitgroup_basic.go
Mutex和RWMutex
互斥锁的使用
package main
import (
"fmt"
"sync"
"time"
)
// 共享资源
type Counter struct {
mu sync.Mutex
value int
}
// 安全的增加方法
func (c *Counter) Increment() {
c.mu.Lock()
defer c.mu.Unlock()
c.value++
}
// 安全的获取方法
func (c *Counter) Value() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.value
}
// 银行账户示例
type BankAccount struct {
mu sync.Mutex
balance float64
}
func (ba *BankAccount) Deposit(amount float64) {
ba.mu.Lock()
defer ba.mu.Unlock()
fmt.Printf("存款前余额: %.2f\n", ba.balance)
ba.balance += amount
fmt.Printf("存款 %.2f,余额: %.2f\n", amount, ba.balance)
}
func (ba *BankAccount) Withdraw(amount float64) bool {
ba.mu.Lock()
defer ba.mu.Unlock()
fmt.Printf("取款前余额: %.2f\n", ba.balance)
if ba.balance >= amount {
ba.balance -= amount
fmt.Printf("取款 %.2f,余额: %.2f\n", amount, ba.balance)
return true
}
fmt.Printf("余额不足,无法取款 %.2f\n", amount)
return false
}
func (ba *BankAccount) Balance() float64 {
ba.mu.Lock()
defer ba.mu.Unlock()
return ba.balance
}
func main() {
fmt.Println("=== 竞态条件演示 ===")
// 不安全的计数器
unsafeCounter := 0
var wg sync.WaitGroup
// 启动10个goroutine,每个增加1000次
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 1000; j++ {
unsafeCounter++ // 竞态条件
}
}()
}
wg.Wait()
fmt.Printf("不安全计数器结果: %d (期望: 10000)\n", unsafeCounter)
fmt.Println("\n=== 使用Mutex解决竞态条件 ===")
safeCounter := &Counter{}
// 启动10个goroutine,每个增加1000次
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 1000; j++ {
safeCounter.Increment()
}
}()
}
wg.Wait()
fmt.Printf("安全计数器结果: %d\n", safeCounter.Value())
fmt.Println("\n=== 银行账户示例 ===")
account := &BankAccount{balance: 1000.0}
// 并发存款和取款
for i := 0; i < 3; i++ {
wg.Add(2)
// 存款goroutine
go func(id int) {
defer wg.Done()
amount := float64(id*100 + 50)
account.Deposit(amount)
}(i)
// 取款goroutine
go func(id int) {
defer wg.Done()
amount := float64(id*80 + 30)
account.Withdraw(amount)
}(i)
}
wg.Wait()
fmt.Printf("最终余额: %.2f\n", account.Balance())
fmt.Println("\n=== 死锁演示(注释掉的代码) ===")
fmt.Println("// 以下代码会导致死锁,已注释")
fmt.Println("// var mu1, mu2 sync.Mutex")
fmt.Println("// go func() {")
fmt.Println("// mu1.Lock()")
fmt.Println("// time.Sleep(100 * time.Millisecond)")
fmt.Println("// mu2.Lock() // 等待mu2")
fmt.Println("// mu2.Unlock()")
fmt.Println("// mu1.Unlock()")
fmt.Println("// }()")
fmt.Println("// go func() {")
fmt.Println("// mu2.Lock()")
fmt.Println("// time.Sleep(100 * time.Millisecond)")
fmt.Println("// mu1.Lock() // 等待mu1")
fmt.Println("// mu1.Unlock()")
fmt.Println("// mu2.Unlock()")
fmt.Println("// }()")
fmt.Println("\n=== 避免死锁的方法 ===")
var mu1, mu2 sync.Mutex
// 方法1:总是以相同的顺序获取锁
for i := 0; i < 2; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// 总是先获取mu1,再获取mu2
mu1.Lock()
fmt.Printf("Goroutine %d 获取了 mu1\n", id)
time.Sleep(50 * time.Millisecond)
mu2.Lock()
fmt.Printf("Goroutine %d 获取了 mu2\n", id)
// 做一些工作
time.Sleep(50 * time.Millisecond)
mu2.Unlock()
fmt.Printf("Goroutine %d 释放了 mu2\n", id)
mu1.Unlock()
fmt.Printf("Goroutine %d 释放了 mu1\n", id)
}(i)
}
wg.Wait()
fmt.Println("避免死锁演示完成")
}
go run mutex_basic.go
读写锁的使用
package main
import (
"fmt"
"sync"
"time"
)
// 使用读写锁的缓存
type Cache struct {
mu sync.RWMutex
data map[string]string
}
func NewCache() *Cache {
return &Cache{
data: make(map[string]string),
}
}
// 读操作(可以并发)
func (c *Cache) Get(key string) (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
fmt.Printf("读取 key: %s\n", key)
value, exists := c.data[key]
return value, exists
}
// 写操作(独占)
func (c *Cache) Set(key, value string) {
c.mu.Lock()
defer c.mu.Unlock()
fmt.Printf("设置 %s = %s\n", key, value)
c.data[key] = value
}
// 删除操作(独占)
func (c *Cache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
fmt.Printf("删除 key: %s\n", key)
delete(c.data, key)
}
// 获取所有键(读操作)
func (c *Cache) Keys() []string {
c.mu.RLock()
defer c.mu.RUnlock()
keys := make([]string, 0, len(c.data))
for k := range c.data {
keys = append(keys, k)
}
return keys
}
// 配置管理器
type ConfigManager struct {
mu sync.RWMutex
config map[string]interface{}
}
func NewConfigManager() *ConfigManager {
return &ConfigManager{
config: make(map[string]interface{}),
}
}
func (cm *ConfigManager) GetString(key string) string {
cm.mu.RLock()
defer cm.mu.RUnlock()
if value, exists := cm.config[key]; exists {
if str, ok := value.(string); ok {
return str
}
}
return ""
}
func (cm *ConfigManager) GetInt(key string) int {
cm.mu.RLock()
defer cm.mu.RUnlock()
if value, exists := cm.config[key]; exists {
if num, ok := value.(int); ok {
return num
}
}
return 0
}
func (cm *ConfigManager) Set(key string, value interface{}) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.config[key] = value
fmt.Printf("配置更新: %s = %v\n", key, value)
}
func (cm *ConfigManager) GetAll() map[string]interface{} {
cm.mu.RLock()
defer cm.mu.RUnlock()
// 返回副本以避免外部修改
result := make(map[string]interface{})
for k, v := range cm.config {
result[k] = v
}
return result
}
func main() {
fmt.Println("=== 读写锁缓存示例 ===")
cache := NewCache()
var wg sync.WaitGroup
// 初始化一些数据
cache.Set("user:1", "Alice")
cache.Set("user:2", "Bob")
cache.Set("user:3", "Charlie")
// 启动多个读者
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 3; j++ {
key := fmt.Sprintf("user:%d", (j%3)+1)
value, exists := cache.Get(key)
if exists {
fmt.Printf("读者 %d: %s = %s\n", id, key, value)
} else {
fmt.Printf("读者 %d: %s 不存在\n", id, key)
}
time.Sleep(100 * time.Millisecond)
}
}(i)
}
// 启动写者
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(200 * time.Millisecond)
cache.Set("user:4", "David")
time.Sleep(200 * time.Millisecond)
cache.Set("user:1", "Alice Updated")
time.Sleep(200 * time.Millisecond)
cache.Delete("user:2")
}()
wg.Wait()
fmt.Printf("\n最终缓存键: %v\n", cache.Keys())
fmt.Println("\n=== 配置管理器示例 ===")
configMgr := NewConfigManager()
// 初始化配置
configMgr.Set("app_name", "MyApp")
configMgr.Set("port", 8080)
configMgr.Set("debug", true)
configMgr.Set("max_connections", 100)
// 启动多个配置读取者
for i := 0; i < 3; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < 5; j++ {
appName := configMgr.GetString("app_name")
port := configMgr.GetInt("port")
maxConn := configMgr.GetInt("max_connections")
fmt.Printf("读取者 %d: %s:%d (max_conn: %d)\n",
id, appName, port, maxConn)
time.Sleep(50 * time.Millisecond)
}
}(i)
}
// 配置更新者
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(100 * time.Millisecond)
configMgr.Set("port", 9090)
time.Sleep(150 * time.Millisecond)
configMgr.Set("max_connections", 200)
time.Sleep(100 * time.Millisecond)
configMgr.Set("app_name", "MyApp v2.0")
}()
wg.Wait()
fmt.Println("\n最终配置:")
finalConfig := configMgr.GetAll()
for k, v := range finalConfig {
fmt.Printf(" %s: %v\n", k, v)
}
fmt.Println("\n=== 性能对比:Mutex vs RWMutex ===")
// 测试读多写少的场景
testReadHeavyWorkload()
}
func testReadHeavyWorkload() {
const numReaders = 10
const numWriters = 1
const operations = 1000
// 使用普通Mutex的数据结构
type MutexData struct {
mu sync.Mutex
data map[string]int
}
// 使用RWMutex的数据结构
type RWMutexData struct {
mu sync.RWMutex
data map[string]int
}
// 测试Mutex
fmt.Println("测试 Mutex 性能...")
mutexData := &MutexData{data: make(map[string]int)}
start := time.Now()
var wg sync.WaitGroup
// 读者
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < operations; j++ {
mutexData.mu.Lock()
_ = mutexData.data["key"]
mutexData.mu.Unlock()
}
}()
}
// 写者
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < operations/10; j++ {
mutexData.mu.Lock()
mutexData.data["key"] = j
mutexData.mu.Unlock()
}
}()
}
wg.Wait()
mutexTime := time.Since(start)
// 测试RWMutex
fmt.Println("测试 RWMutex 性能...")
rwMutexData := &RWMutexData{data: make(map[string]int)}
start = time.Now()
// 读者
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < operations; j++ {
rwMutexData.mu.RLock()
_ = rwMutexData.data["key"]
rwMutexData.mu.RUnlock()
}
}()
}
// 写者
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < operations/10; j++ {
rwMutexData.mu.Lock()
rwMutexData.data["key"] = j
rwMutexData.mu.Unlock()
}
}()
}
wg.Wait()
rwMutexTime := time.Since(start)
fmt.Printf("Mutex 耗时: %v\n", mutexTime)
fmt.Printf("RWMutex 耗时: %v\n", rwMutexTime)
fmt.Printf("性能提升: %.2fx\n", float64(mutexTime)/float64(rwMutexTime))
}
go run rwmutex_basic.go
并发模式
常用并发模式
生产者-消费者模式
package main
import (
"fmt"
"math/rand"
"sync"
"time"
)
// 任务结构
type Task struct {
ID int
Data string
}
// 结果结构
type Result struct {
TaskID int
Output string
Worker int
}
func main() {
fmt.Println("=== 基本生产者-消费者模式 ===")
tasks := make(chan Task, 10)
results := make(chan Result, 10)
var wg sync.WaitGroup
// 启动3个消费者(工作者)
for i := 1; i <= 3; i++ {
wg.Add(1)
go consumer(i, tasks, results, &wg)
}
// 生产者:生成任务
go func() {
for i := 1; i <= 10; i++ {
task := Task{
ID: i,
Data: fmt.Sprintf("任务数据-%d", i),
}
tasks <- task
fmt.Printf("生产任务: %+v\n", task)
time.Sleep(100 * time.Millisecond)
}
close(tasks)
}()
// 结果收集器
go func() {
for result := range results {
fmt.Printf("收到结果: 任务%d 由工作者%d 完成,输出: %s\n",
result.TaskID, result.Worker, result.Output)
}
}()
wg.Wait()
close(results)
fmt.Println("\n=== 多生产者-多消费者模式 ===")
jobs := make(chan int, 20)
outputs := make(chan string, 20)
var producerWg sync.WaitGroup
var consumerWg sync.WaitGroup
// 启动3个生产者
for p := 1; p <= 3; p++ {
producerWg.Add(1)
go producer(p, jobs, &producerWg)
}
// 启动4个消费者
for c := 1; c <= 4; c++ {
consumerWg.Add(1)
go worker(c, jobs, outputs, &consumerWg)
}
// 等待所有生产者完成
go func() {
producerWg.Wait()
close(jobs)
}()
// 等待所有消费者完成
go func() {
consumerWg.Wait()
close(outputs)
}()
// 收集所有输出
var allOutputs []string
for output := range outputs {
allOutputs = append(allOutputs, output)
fmt.Printf("输出: %s\n", output)
}
fmt.Printf("总共处理了 %d 个任务\n", len(allOutputs))
}
func consumer(id int, tasks <-chan Task, results chan<- Result, wg *sync.WaitGroup) {
defer wg.Done()
for task := range tasks {
fmt.Printf("消费者 %d 开始处理任务 %d\n", id, task.ID)
// 模拟处理时间
time.Sleep(time.Duration(rand.Intn(300)+100) * time.Millisecond)
result := Result{
TaskID: task.ID,
Output: fmt.Sprintf("处理结果-%d", task.ID),
Worker: id,
}
results <- result
fmt.Printf("消费者 %d 完成任务 %d\n", id, task.ID)
}
fmt.Printf("消费者 %d 退出\n", id)
}
func producer(id int, jobs chan<- int, wg *sync.WaitGroup) {
defer wg.Done()
for i := 1; i <= 5; i++ {
job := id*10 + i
jobs <- job
fmt.Printf("生产者 %d 生产任务 %d\n", id, job)
time.Sleep(time.Duration(rand.Intn(200)+50) * time.Millisecond)
}
fmt.Printf("生产者 %d 完成\n", id)
}
func worker(id int, jobs <-chan int, outputs chan<- string, wg *sync.WaitGroup) {
defer wg.Done()
for job := range jobs {
fmt.Printf("工作者 %d 处理任务 %d\n", id, job)
time.Sleep(time.Duration(rand.Intn(200)+100) * time.Millisecond)
output := fmt.Sprintf("任务%d由工作者%d完成", job, id)
outputs <- output
}
fmt.Printf("工作者 %d 退出\n", id)
}
go run producer_consumer.go
管道模式
package main
import (
"fmt"
"time"
)
// 管道阶段1:数据生成
func generate(nums ...int) <-chan int {
out := make(chan int)
go func() {
defer close(out)
for _, n := range nums {
out <- n
fmt.Printf("生成: %d\n", n)
time.Sleep(50 * time.Millisecond)
}
}()
return out
}
// 管道阶段2:数据处理(平方)
func square(in <-chan int) <-chan int {
out := make(chan int)
go func() {
defer close(out)
for n := range in {
result := n * n
out <- result
fmt.Printf("平方: %d -> %d\n", n, result)
time.Sleep(100 * time.Millisecond)
}
}()
return out
}
// 管道阶段3:数据过滤(偶数)
func filterEven(in <-chan int) <-chan int {
out := make(chan int)
go func() {
defer close(out)
for n := range in {
if n%2 == 0 {
out <- n
fmt.Printf("过滤(偶数): %d\n", n)
} else {
fmt.Printf("过滤(奇数,丢弃): %d\n", n)
}
time.Sleep(50 * time.Millisecond)
}
}()
return out
}
// 管道阶段4:数据转换(转为字符串)
func toString(in <-chan int) <-chan string {
out := make(chan string)
go func() {
defer close(out)
for n := range in {
result := fmt.Sprintf("数字-%d", n)
out <- result
fmt.Printf("转换: %d -> %s\n", n, result)
time.Sleep(75 * time.Millisecond)
}
}()
return out
}
func main() {
fmt.Println("=== 基本管道模式 ===")
// 构建管道
numbers := generate(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
squared := square(numbers)
filtered := filterEven(squared)
strings := toString(filtered)
// 消费最终结果
fmt.Println("\n最终结果:")
for result := range strings {
fmt.Printf("接收: %s\n", result)
}
fmt.Println("\n=== 扇出-扇入管道模式 ===")
// 数据源
source := generate(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
// 扇出:分发到多个处理器
processor1 := square(source)
processor2 := square(source)
processor3 := square(source)
// 扇入:合并结果
merged := merge(processor1, processor2, processor3)
// 消费合并后的结果
fmt.Println("\n扇出-扇入结果:")
for result := range merged {
fmt.Printf("合并结果: %d\n", result)
}
fmt.Println("\n=== 带缓冲的管道 ===")
bufferedPipeline()
}
// 扇入:合并多个channel
func merge(channels ...<-chan int) <-chan int {
out := make(chan int)
// 为每个输入channel启动一个goroutine
for i, ch := range channels {
go func(id int, c <-chan int) {
for value := range c {
fmt.Printf("合并器从处理器%d接收: %d\n", id+1, value)
out <- value
}
}(i, ch)
}
// 启动一个goroutine来关闭输出channel
go func() {
// 等待所有输入channel关闭
// 注意:这是一个简化的实现,实际应用中可能需要更复杂的同步
time.Sleep(2 * time.Second)
close(out)
}()
return out
}
// 带缓冲的管道示例
func bufferedPipeline() {
// 创建带缓冲的管道阶段
stage1 := make(chan int, 5)
stage2 := make(chan int, 3)
stage3 := make(chan string, 2)
// 阶段1:数据生成
go func() {
defer close(stage1)
for i := 1; i <= 10; i++ {
stage1 <- i
fmt.Printf("阶段1生成: %d (缓冲: %d/%d)\n",
i, len(stage1), cap(stage1))
time.Sleep(50 * time.Millisecond)
}
}()
// 阶段2:数据处理
go func() {
defer close(stage2)
for n := range stage1 {
result := n * 3
stage2 <- result
fmt.Printf("阶段2处理: %d -> %d (缓冲: %d/%d)\n",
n, result, len(stage2), cap(stage2))
time.Sleep(100 * time.Millisecond)
}
}()
// 阶段3:数据转换
go func() {
defer close(stage3)
for n := range stage2 {
result := fmt.Sprintf("结果-%d", n)
stage3 <- result
fmt.Printf("阶段3转换: %d -> %s (缓冲: %d/%d)\n",
n, result, len(stage3), cap(stage3))
time.Sleep(150 * time.Millisecond)
}
}()
// 最终消费
for result := range stage3 {
fmt.Printf("最终输出: %s\n", result)
time.Sleep(200 * time.Millisecond)
}
}' > pipeline_pattern.go && go run pipeline_pattern.go
Context包的使用
Context基础
Context的创建和使用
echo 'package main
import (
"context"
"fmt"
"time"
)
func main() {
fmt.Println("=== Context基础使用 ===")
// 创建根context
ctx := context.Background()
fmt.Printf("根context: %v\n", ctx)
// 创建带超时的context
timeoutCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
// 启动一个长时间运行的任务
go longRunningTask(timeoutCtx, "任务1")
// 等待任务完成或超时
select {
case <-timeoutCtx.Done():
fmt.Printf("Context完成,原因: %v\n", timeoutCtx.Err())
case <-time.After(3 * time.Second):
fmt.Println("等待超时")
}
fmt.Println("\n=== Context取消 ===")
// 创建可取消的context
cancelCtx, cancelFunc := context.WithCancel(ctx)
// 启动任务
go longRunningTask(cancelCtx, "任务2")
// 1秒后取消
time.Sleep(1 * time.Second)
fmt.Println("取消任务...")
cancelFunc()
// 等待任务响应取消
time.Sleep(500 * time.Millisecond)
fmt.Println("\n=== Context传值 ===")
// 创建带值的context
type key string
userKey := key("user")
requestIDKey := key("requestID")
valueCtx := context.WithValue(ctx, userKey, "Alice")
valueCtx = context.WithValue(valueCtx, requestIDKey, "req-123")
// 传递context到函数
processRequest(valueCtx)
fmt.Println("\n=== Context截止时间 ===")
// 创建带截止时间的context
deadline := time.Now().Add(1500 * time.Millisecond)
deadlineCtx, cancel2 := context.WithDeadline(ctx, deadline)
defer cancel2()
fmt.Printf("设置截止时间: %v\n", deadline.Format("15:04:05.000"))
go longRunningTask(deadlineCtx, "任务3")
// 等待截止时间
<-deadlineCtx.Done()
fmt.Printf("截止时间到达,错误: %v\n", deadlineCtx.Err())
}
func longRunningTask(ctx context.Context, name string) {
fmt.Printf("%s 开始执行\n", name)
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
fmt.Printf("%s 被取消: %v\n", name, ctx.Err())
return
default:
fmt.Printf("%s 执行步骤 %d\n", name, i+1)
time.Sleep(300 * time.Millisecond)
}
}
fmt.Printf("%s 完成\n", name)
}
func processRequest(ctx context.Context) {
// 从context中获取值
user := ctx.Value(key("user"))
requestID := ctx.Value(key("requestID"))
fmt.Printf("处理请求 - 用户: %v, 请求ID: %v\n", user, requestID)
// 模拟处理时间
time.Sleep(200 * time.Millisecond)
fmt.Println("请求处理完成")
}
type key string' > context_basic.go && go run context_basic.go
Context在HTTP服务中的应用
echo 'package main
import (
"context"
"fmt"
"net/http"
"time"
)
// 模拟数据库查询
func queryDatabase(ctx context.Context, query string) (string, error) {
// 模拟数据库查询时间
select {
case <-time.After(2 * time.Second):
return fmt.Sprintf("查询结果: %s", query), nil
case <-ctx.Done():
return "", ctx.Err()
}
}
// 模拟外部API调用
func callExternalAPI(ctx context.Context, endpoint string) (string, error) {
// 创建带超时的子context
apiCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
select {
case <-time.After(1500 * time.Millisecond):
return fmt.Sprintf("API响应: %s", endpoint), nil
case <-apiCtx.Done():
return "", apiCtx.Err()
}
}
// HTTP处理器
func dataHandler(w http.ResponseWriter, r *http.Request) {
// 获取请求的context
ctx := r.Context()
// 添加请求ID到context
requestID := fmt.Sprintf("req-%d", time.Now().Unix())
ctx = context.WithValue(ctx, "requestID", requestID)
fmt.Printf("开始处理请求: %s\n", requestID)
// 并发执行多个操作
type result struct {
name string
data string
err error
}
results := make(chan result, 2)
// 数据库查询
go func() {
data, err := queryDatabase(ctx, "SELECT * FROM users")
results <- result{"database", data, err}
}()
// API调用
go func() {
data, err := callExternalAPI(ctx, "/api/external")
results <- result{"api", data, err}
}()
// 收集结果
var dbResult, apiResult result
for i := 0; i < 2; i++ {
select {
case res := <-results:
if res.name == "database" {
dbResult = res
} else {
apiResult = res
}
case <-ctx.Done():
http.Error(w, "请求被取消", http.StatusRequestTimeout)
fmt.Printf("请求 %s 被取消: %v\n", requestID, ctx.Err())
return
}
}
// 检查错误
if dbResult.err != nil {
http.Error(w, fmt.Sprintf("数据库错误: %v", dbResult.err), http.StatusInternalServerError)
return
}
if apiResult.err != nil {
http.Error(w, fmt.Sprintf("API错误: %v", apiResult.err), http.StatusInternalServerError)
return
}
// 返回成功响应
response := fmt.Sprintf("请求ID: %s\n数据库: %s\nAPI: %s",
requestID, dbResult.data, apiResult.data)
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(response))
fmt.Printf("请求 %s 处理完成\n", requestID)
}
func main() {
fmt.Println("=== HTTP服务器Context示例 ===")
fmt.Println("启动HTTP服务器...")
http.HandleFunc("/data", dataHandler)
server := &http.Server{
Addr: ":8080",
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
}
fmt.Println("服务器运行在 http://localhost:8080")
fmt.Println("访问 http://localhost:8080/data 测试Context")
fmt.Println("按 Ctrl+C 停止服务器")
if err := server.ListenAndServe(); err != nil {
fmt.Printf("服务器错误: %v\n", err)
}
}' > context_http.go
并发最佳实践
避免常见陷阱
竞态条件和数据竞争
echo 'package main
import (
"fmt"
"runtime"
"sync"
"sync/atomic"
"time"
)
// 错误示例:数据竞争
type UnsafeCounter struct {
count int
}
func (c *UnsafeCounter) Increment() {
c.count++ // 数据竞争
}
func (c *UnsafeCounter) Value() int {
return c.count // 数据竞争
}
// 正确示例1:使用Mutex
type SafeCounter struct {
mu sync.Mutex
count int
}
func (c *SafeCounter) Increment() {
c.mu.Lock()
defer c.mu.Unlock()
c.count++
}
func (c *SafeCounter) Value() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.count
}
// 正确示例2:使用原子操作
type AtomicCounter struct {
count int64
}
func (c *AtomicCounter) Increment() {
atomic.AddInt64(&c.count, 1)
}
func (c *AtomicCounter) Value() int64 {
return atomic.LoadInt64(&c.count)
}
func main() {
fmt.Println("=== 数据竞争演示 ===")
const numGoroutines = 100
const numIncrements = 1000
// 测试不安全的计数器
fmt.Println("\n测试不安全计数器:")
unsafeCounter := &UnsafeCounter{}
var wg sync.WaitGroup
start := time.Now()
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < numIncrements; j++ {
unsafeCounter.Increment()
}
}()
}
wg.Wait()
unsafeTime := time.Since(start)
fmt.Printf("不安全计数器结果: %d (期望: %d)\n",
unsafeCounter.Value(), numGoroutines*numIncrements)
fmt.Printf("耗时: %v\n", unsafeTime)
// 测试安全的计数器(Mutex)
fmt.Println("\n测试安全计数器(Mutex):")
safeCounter := &SafeCounter{}
start = time.Now()
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < numIncrements; j++ {
safeCounter.Increment()
}
}()
}
wg.Wait()
safeTime := time.Since(start)
fmt.Printf("安全计数器结果: %d\n", safeCounter.Value())
fmt.Printf("耗时: %v\n", safeTime)
// 测试原子计数器
fmt.Println("\n测试原子计数器:")
atomicCounter := &AtomicCounter{}
start = time.Now()
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < numIncrements; j++ {
atomicCounter.Increment()
}
}()
}
wg.Wait()
atomicTime := time.Since(start)
fmt.Printf("原子计数器结果: %d\n", atomicCounter.Value())
fmt.Printf("耗时: %v\n", atomicTime)
fmt.Println("\n=== 性能对比 ===")
fmt.Printf("不安全版本: %v\n", unsafeTime)
fmt.Printf("Mutex版本: %v (%.2fx slower)\n",
safeTime, float64(safeTime)/float64(unsafeTime))
fmt.Printf("原子版本: %v (%.2fx slower)\n",
atomicTime, float64(atomicTime)/float64(unsafeTime))
fmt.Println("\n=== 检测数据竞争 ===")
fmt.Println("使用 'go run -race' 命令可以检测数据竞争")
fmt.Println("例如: go run -race race_detection.go")
demonstrateRaceDetection()
}
func demonstrateRaceDetection() {
fmt.Println("\n演示竞态条件检测:")
var data int
var wg sync.WaitGroup
// 启动写者
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 5; i++ {
data = i // 写操作
time.Sleep(10 * time.Millisecond)
}
}()
// 启动读者
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 5; i++ {
fmt.Printf("读取到: %d\n", data) // 读操作
time.Sleep(10 * time.Millisecond)
}
}()
wg.Wait()
fmt.Println("\n注意:上面的代码存在数据竞争")
fmt.Println("运行 'go run -race' 会检测到这个问题")
fmt.Println("\n=== Goroutine泄漏检测 ===")
fmt.Printf("当前goroutine数量: %d\n", runtime.NumGoroutine())
// 创建一些可能泄漏的goroutine
ch := make(chan int)
for i := 0; i < 3; i++ {
go func(id int) {
// 这个goroutine会一直等待,造成泄漏
<-ch
fmt.Printf("Goroutine %d 完成\n", id)
}(i)
}
time.Sleep(100 * time.Millisecond)
fmt.Printf("创建goroutine后数量: %d\n", runtime.NumGoroutine())
// 正确的做法:关闭channel或发送信号
close(ch)
time.Sleep(100 * time.Millisecond)
fmt.Printf("关闭channel后数量: %d\n", runtime.NumGoroutine())
}
go run race_conditions.go
性能优化技巧
并发性能优化
package main
import (
"fmt"
"runtime"
"sync"
"time"
)
func main() {
fmt.Println("=== 并发性能优化技巧 ===")
// 1. 合理设置GOMAXPROCS
fmt.Printf("CPU核心数: %d\n", runtime.NumCPU())
fmt.Printf("当前GOMAXPROCS: %d\n", runtime.GOMAXPROCS(0))
// 2. 工作池大小优化
optimizeWorkerPoolSize()
// 3. 批处理优化
demonstrateBatching()
// 4. 缓冲区大小优化
optimizeBufferSize()
// 5. 避免过度并发
avoidExcessiveConcurrency()
}
func optimizeWorkerPoolSize() {
fmt.Println("\n=== 工作池大小优化 ===")
const numTasks = 1000
tasks := make([]int, numTasks)
for i := range tasks {
tasks[i] = i + 1
}
// 测试不同的工作池大小
workerCounts := []int{1, 2, 4, 8, 16, 32, 64}
for _, workerCount := range workerCounts {
duration := benchmarkWorkerPool(tasks, workerCount)
fmt.Printf("工作者数量: %2d, 耗时: %v\n", workerCount, duration)
}
}
func benchmarkWorkerPool(tasks []int, workerCount int) time.Duration {
start := time.Now()
jobs := make(chan int, len(tasks))
var wg sync.WaitGroup
// 启动工作者
for i := 0; i < workerCount; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for task := range jobs {
// 模拟CPU密集型工作
_ = fibonacci(20)
_ = task
}
}()
}
// 发送任务
for _, task := range tasks {
jobs <- task
}
close(jobs)
wg.Wait()
return time.Since(start)
}
func fibonacci(n int) int {
if n <= 1 {
return n
}
return fibonacci(n-1) + fibonacci(n-2)
}
func demonstrateBatching() {
fmt.Println("\n=== 批处理优化 ===")
const numItems = 10000
items := make([]int, numItems)
for i := range items {
items[i] = i + 1
}
// 单个处理
start := time.Now()
processItemsIndividually(items)
individualTime := time.Since(start)
// 批量处理
start = time.Now()
processItemsInBatches(items, 100)
batchTime := time.Since(start)
fmt.Printf("单个处理耗时: %v\n", individualTime)
fmt.Printf("批量处理耗时: %v\n", batchTime)
fmt.Printf("性能提升: %.2fx\n", float64(individualTime)/float64(batchTime))
}
func processItemsIndividually(items []int) {
ch := make(chan int)
var wg sync.WaitGroup
// 启动处理器
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for item := range ch {
// 模拟处理
time.Sleep(1 * time.Microsecond)
_ = item * 2
}
}()
}
// 发送单个项目
go func() {
for _, item := range items {
ch <- item
}
close(ch)
}()
wg.Wait()
}
func processItemsInBatches(items []int, batchSize int) {
ch := make(chan []int)
var wg sync.WaitGroup
// 启动处理器
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for batch := range ch {
// 批量处理
for _, item := range batch {
time.Sleep(1 * time.Microsecond)
_ = item * 2
}
}
}()
}
// 发送批次
go func() {
for i := 0; i < len(items); i += batchSize {
end := i + batchSize
if end > len(items) {
end = len(items)
}
batch := make([]int, end-i)
copy(batch, items[i:end])
ch <- batch
}
close(ch)
}()
wg.Wait()
}
func optimizeBufferSize() {
fmt.Println("\n=== 缓冲区大小优化 ===")
const numMessages = 10000
// 测试不同的缓冲区大小
bufferSizes := []int{0, 1, 10, 100, 1000}
for _, bufferSize := range bufferSizes {
duration := benchmarkBufferSize(numMessages, bufferSize)
fmt.Printf("缓冲区大小: %4d, 耗时: %v\n", bufferSize, duration)
}
}
func benchmarkBufferSize(numMessages, bufferSize int) time.Duration {
start := time.Now()
ch := make(chan int, bufferSize)
var wg sync.WaitGroup
// 消费者
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < numMessages; i++ {
<-ch
// 模拟处理时间
time.Sleep(1 * time.Microsecond)
}
}()
// 生产者
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < numMessages; i++ {
ch <- i
}
}()
wg.Wait()
return time.Since(start)
}
func avoidExcessiveConcurrency() {
fmt.Println("\n=== 避免过度并发 ===")
const numTasks = 1000
// 过度并发:为每个任务创建一个goroutine
start := time.Now()
var wg sync.WaitGroup
for i := 0; i < numTasks; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// 简单任务
time.Sleep(1 * time.Millisecond)
_ = id * 2
}(i)
}
wg.Wait()
excessiveTime := time.Since(start)
// 合理并发:使用工作池
start = time.Now()
jobs := make(chan int, numTasks)
// 启动合理数量的工作者
numWorkers := runtime.NumCPU()
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for job := range jobs {
time.Sleep(1 * time.Millisecond)
_ = job * 2
}
}()
}
// 发送任务
for i := 0; i < numTasks; i++ {
jobs <- i
}
close(jobs)
wg.Wait()
reasonableTime := time.Since(start)
fmt.Printf("过度并发(%d个goroutine): %v\n", numTasks, excessiveTime)
fmt.Printf("合理并发(%d个goroutine): %v\n", numWorkers, reasonableTime)
fmt.Printf("性能提升: %.2fx\n", float64(excessiveTime)/float64(reasonableTime))
fmt.Println("\n=== 并发最佳实践总结 ===")
fmt.Println("1. 根据CPU核心数设置合理的工作者数量")
fmt.Println("2. 使用适当大小的缓冲区减少阻塞")
fmt.Println("3. 批量处理减少goroutine创建开销")
fmt.Println("4. 避免为简单任务创建过多goroutine")
fmt.Println("5. 使用工作池模式管理并发")
fmt.Println("6. 及时关闭channel避免goroutine泄漏")
fmt.Println("7. 使用Context管理goroutine生命周期")
fmt.Println("8. 使用race detector检测数据竞争")
}
go run performance_optimization.go
总结
本章详细介绍了Go语言的并发编程,包括:
- Goroutine基础 - 轻量级线程的创建和管理
- Channel通信 - 安全的数据传递机制
- Select语句 - 多路复用和非阻塞操作
- 同步原语 - WaitGroup、Mutex、RWMutex的使用
- 并发模式 - 生产者-消费者、管道、扇入扇出等模式
- Context包 - 请求作用域和取消机制
- 最佳实践 - 性能优化和常见陷阱避免
关键要点
- 使用goroutine实现轻量级并发
- 通过channel进行安全的数据交换
- 合理使用同步原语保护共享资源
- 掌握常用的并发设计模式
- 使用Context管理goroutine生命周期
- 注意避免数据竞争和goroutine泄漏
下一步
下一章将学习Go语言的包和模块系统,了解如何组织和管理大型Go项目。