🏗️ 实战项目:生产级 REST API
从零构建一个生产级任务管理系统,综合运用全部 9 章知识——分层架构、JWT 认证、PostgreSQL + Redis、Docker 部署,以及完整测试覆盖。
项目:TaskFlow — 任务管理系统 API
功能包括:用户注册/登录(JWT)、CRUD 任务、分页查询、PostgreSQL 持久化、Redis 缓存、Docker Compose 一键部署。
1. 项目架构设计
目录结构
taskflow/
cmd/
server/
main.go // 程序入口,组装依赖
internal/ // 内部包,外部不可导入
handler/
user.go // 用户注册/登录 Handler
task.go // 任务 CRUD Handler
middleware.go // JWT 认证中间件
service/
user.go // 用户业务逻辑
task.go // 任务业务逻辑
repository/
user_pg.go // PostgreSQL 用户数据访问
task_pg.go // PostgreSQL 任务数据访问
cache/
redis.go // Redis 缓存层
domain/
user.go // User 领域模型 + 接口
task.go // Task 领域模型 + 接口
errors.go // 业务错误类型
pkg/ // 可被外部使用的公共包
jwt/ // JWT 工具
respond/ // 统一响应格式
config/
config.go // 配置结构体
config.yaml // 默认配置
migrations/ // 数据库迁移文件
Dockerfile
docker-compose.yml
go.mod
分层架构图
┌─────────────────────────────────────────────┐
│ HTTP 客户端(浏览器 / App / curl) │
└───────────────────┬─────────────────────────┘
│ HTTP Request
┌───────────────────▼─────────────────────────┐
│ Handler 层(internal/handler) │
│ 参数绑定 · 鉴权中间件 · 响应格式化 │
└───────────────────┬─────────────────────────┘
│ 调用 Service 接口
┌───────────────────▼─────────────────────────┐
│ Service 层(internal/service) │
│ 业务规则 · 权限校验 · 事务编排 │
└──────────┬────────────────────┬─────────────┘
│ 调用 Repository │ 调用 Cache
┌──────────▼──────┐ ┌──────────▼──────────┐
│ Repository 层 │ │ Cache 层 │
│ internal/repo │ │ internal/cache │
└──────────┬──────┘ └──────────┬───────────┘
│ │
┌──────────▼──────┐ ┌──────────▼───────────┐
│ PostgreSQL │ │ Redis │
└─────────────────┘ └──────────────────────┘
API 端点设计
| 方法 | 路径 | 描述 | 认证 |
|---|---|---|---|
| POST | /api/v1/auth/register | 用户注册 | 否 |
| POST | /api/v1/auth/login | 用户登录,返回 JWT | 否 |
| GET | /api/v1/tasks | 获取当前用户任务列表(分页) | JWT |
| POST | /api/v1/tasks | 创建任务 | JWT |
| GET | /api/v1/tasks/:id | 获取单个任务 | JWT |
| PUT | /api/v1/tasks/:id | 更新任务 | JWT |
| DELETE | /api/v1/tasks/:id | 删除任务 | JWT |
| GET | /health | 健康检查 | 否 |
2. 项目初始化与配置管理
go.mod 与依赖
go mod init github.com/myorg/taskflow
# Core dependencies
go get github.com/gin-gonic/gin@latest        # web framework
go get github.com/jackc/pgx/v5@latest         # PostgreSQL driver (no CGO)
go get github.com/redis/go-redis/v9@latest    # Redis client
go get github.com/golang-jwt/jwt/v5@latest    # JWT
go get github.com/spf13/viper@latest          # configuration management
go get golang.org/x/crypto@latest             # bcrypt password hashing
go get github.com/google/uuid@latest          # UUID generation
# Test dependencies
# testify already ships the assert, require and mock packages
# (github.com/stretchr/testify/mock); there is no separate mock module to fetch.
go get github.com/stretchr/testify@latest
配置管理
// config/config.go
package config
import (
	"fmt"
	"strings"

	"github.com/spf13/viper"
)
// Config is the root of the application configuration. Each field is decoded
// from the section of the configuration named by its mapstructure tag.
type Config struct {
	Server   ServerConfig   `mapstructure:"server"`
	Database DatabaseConfig `mapstructure:"database"`
	Redis    RedisConfig    `mapstructure:"redis"`
	JWT      JWTConfig      `mapstructure:"jwt"`
}
// ServerConfig holds the HTTP listener settings.
type ServerConfig struct {
	Host string `mapstructure:"host"`
	Port int    `mapstructure:"port"`
	// Timeouts are expressed in whole seconds (see the *_sec config keys).
	ReadTimeout  int `mapstructure:"read_timeout_sec"`
	WriteTimeout int `mapstructure:"write_timeout_sec"`
}
// DatabaseConfig holds the PostgreSQL connection settings.
type DatabaseConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	User     string `mapstructure:"user"`
	Password string `mapstructure:"password"`
	DBName   string `mapstructure:"dbname"`
	SSLMode  string `mapstructure:"sslmode"`
	MaxConns int    `mapstructure:"max_conns"`
}

// DSN renders the settings as a keyword/value connection string, including
// the pool_max_conns setting.
//
// String values that are empty or contain whitespace, single quotes or
// backslashes are single-quoted and escaped; otherwise an empty password
// would swallow the following key/value pair and a password containing a
// space would be split in two. Plain values are emitted unquoted, so the
// output for ordinary settings is unchanged.
func (d DatabaseConfig) DSN() string {
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s pool_max_conns=%d",
		quoteDSNValue(d.Host), d.Port, quoteDSNValue(d.User), quoteDSNValue(d.Password),
		quoteDSNValue(d.DBName), quoteDSNValue(d.SSLMode), d.MaxConns)
}

// quoteDSNValue returns v unchanged when it is safe to embed in a
// keyword/value connection string, and a single-quoted, backslash-escaped
// form otherwise.
func quoteDSNValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\r\n\v\f'\\") {
		return v
	}
	escaped := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v)
	return "'" + escaped + "'"
}
// RedisConfig holds the Redis connection settings.
type RedisConfig struct {
	Addr     string `mapstructure:"addr"`
	Password string `mapstructure:"password"`
	DB       int    `mapstructure:"db"`
}
// JWTConfig holds the token signing secret and the token lifetime in hours.
type JWTConfig struct {
	Secret     string `mapstructure:"secret"`
	ExpireHour int    `mapstructure:"expire_hour"`
}
// Load reads the YAML configuration file at path and decodes it into a Config.
//
// Values from the file can be overridden by environment variables named
// TASKFLOW_<SECTION>_<KEY>; for example TASKFLOW_JWT_SECRET overrides
// jwt.secret. It returns a wrapped error when the file cannot be read or
// decoded.
func Load(path string) (*Config, error) {
	v := viper.New()
	v.SetConfigFile(path)
	v.SetConfigType("yaml")

	v.SetEnvPrefix("TASKFLOW")
	// Nested keys are addressed as "jwt.secret". Without this replacer the
	// lookup would use the literal name TASKFLOW_JWT.SECRET, so variables such
	// as TASKFLOW_JWT_SECRET would never be picked up.
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()

	if err := v.ReadInConfig(); err != nil {
		return nil, fmt.Errorf("读取配置文件失败: %w", err)
	}

	var cfg Config
	if err := v.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("解析配置失败: %w", err)
	}
	return &cfg, nil
}
# config/config.yaml
server:
  host: "0.0.0.0"
  port: 8080
  read_timeout_sec: 10
  write_timeout_sec: 10
database:
  host: "localhost"
  port: 5432
  user: "taskflow"
  password: "secret"
  dbname: "taskflow"
  sslmode: "disable"
  max_conns: 20
redis:
  addr: "localhost:6379"
  password: ""
  db: 0
jwt:
  secret: "change-me-in-production" # override through an environment variable in production
  expire_hour: 24
3. 数据库设计与 Repository 层
建表 SQL
-- migrations/001_init.sql
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

CREATE TABLE users (
    id         BIGSERIAL    PRIMARY KEY,
    uuid       UUID         NOT NULL DEFAULT uuid_generate_v4() UNIQUE,
    name       VARCHAR(100) NOT NULL,
    -- The UNIQUE constraint already creates an index on email,
    -- so no separate CREATE INDEX is needed for lookups by email.
    email      VARCHAR(255) NOT NULL UNIQUE,
    password   VARCHAR(255) NOT NULL, -- bcrypt hash
    created_at TIMESTAMPTZ  NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ  NOT NULL DEFAULT NOW()
);

CREATE TYPE task_status AS ENUM ('todo', 'in_progress', 'done');
CREATE TYPE task_priority AS ENUM ('low', 'medium', 'high');

CREATE TABLE tasks (
    id          BIGSERIAL     PRIMARY KEY,
    user_id     BIGINT        NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    title       VARCHAR(500)  NOT NULL,
    description TEXT,
    status      task_status   NOT NULL DEFAULT 'todo',
    priority    task_priority NOT NULL DEFAULT 'medium',
    due_date    TIMESTAMPTZ,
    created_at  TIMESTAMPTZ   NOT NULL DEFAULT NOW(),
    updated_at  TIMESTAMPTZ   NOT NULL DEFAULT NOW()
);

-- Serves the per-user list query, which filters on user_id and orders by created_at DESC.
CREATE INDEX idx_tasks_user_created ON tasks(user_id, created_at DESC);
-- Composite index for filtering by user and status; its leading column also
-- covers plain user_id lookups, so a single-column user_id index is not needed.
CREATE INDEX idx_tasks_status ON tasks(user_id, status);
领域模型与接口定义
// internal/domain/task.go
package domain
import (
"context"
"time"
)
// TaskStatus is the workflow state of a task.
type TaskStatus string

// TaskPriority is the urgency level of a task.
type TaskPriority string

// Allowed status and priority values; the string values are what is stored
// and serialized.
const (
	StatusTodo       TaskStatus   = "todo"
	StatusInProgress TaskStatus   = "in_progress"
	StatusDone       TaskStatus   = "done"
	PriorityLow      TaskPriority = "low"
	PriorityMedium   TaskPriority = "medium"
	PriorityHigh     TaskPriority = "high"
)
// Task is the task domain model and its JSON representation.
type Task struct {
	ID          int64        `json:"id"`
	UserID      int64        `json:"user_id"`
	Title       string       `json:"title"`
	Description string       `json:"description"`
	Status      TaskStatus   `json:"status"`
	Priority    TaskPriority `json:"priority"`
	// DueDate is optional; a nil value is omitted from the JSON output.
	DueDate   *time.Time `json:"due_date,omitempty"`
	CreatedAt time.Time  `json:"created_at"`
	UpdatedAt time.Time  `json:"updated_at"`
}
// ListParams carries the paging and filter options for listing tasks.
type ListParams struct {
	UserID   int64
	Status   TaskStatus // empty string means no status filter
	Page     int        // 1-based page number
	PageSize int        // items per page, at most 100
}
// PagedResult is one page of items of type T together with paging metadata.
type PagedResult[T any] struct {
	Data       []T   `json:"data"`
	Total      int64 `json:"total"`
	Page       int   `json:"page"`
	PageSize   int   `json:"page_size"`
	TotalPages int64 `json:"total_pages"`
}
// TaskRepository is the data-access interface for tasks.
type TaskRepository interface {
	Create(ctx context.Context, task *Task) error
	FindByID(ctx context.Context, id int64) (*Task, error)
	FindByUserID(ctx context.Context, params ListParams) (*PagedResult[Task], error)
	Update(ctx context.Context, task *Task) error
	Delete(ctx context.Context, id int64) error
}
// TaskService is the business-logic interface for tasks. Every method takes
// the ID of the acting user in addition to the task data.
type TaskService interface {
	Create(ctx context.Context, userID int64, input CreateTaskInput) (*Task, error)
	List(ctx context.Context, userID int64, params ListParams) (*PagedResult[Task], error)
	Get(ctx context.Context, userID, taskID int64) (*Task, error)
	Update(ctx context.Context, userID, taskID int64, input UpdateTaskInput) (*Task, error)
	Delete(ctx context.Context, userID, taskID int64) error
}
// CreateTaskInput is the request payload for creating a task.
//
// Title is required and limited to 500 characters. Priority is optional, but
// when present it must be one of the defined priority values so that an
// unknown value is rejected during request binding instead of surfacing as a
// database error.
type CreateTaskInput struct {
	Title       string       `json:"title" binding:"required,max=500"`
	Description string       `json:"description"`
	Priority    TaskPriority `json:"priority" binding:"omitempty,oneof=low medium high"`
	DueDate     *time.Time   `json:"due_date"`
}
// UpdateTaskInput is the request payload for a partial task update. A nil
// field means "leave unchanged".
//
// Fields that are present are validated during request binding: Title is
// limited to 500 characters, and Status and Priority must be one of the
// defined values, so invalid input is rejected before it reaches the database.
type UpdateTaskInput struct {
	Title       *string       `json:"title" binding:"omitempty,max=500"`
	Description *string       `json:"description"`
	Status      *TaskStatus   `json:"status" binding:"omitempty,oneof=todo in_progress done"`
	Priority    *TaskPriority `json:"priority" binding:"omitempty,oneof=low medium high"`
	DueDate     *time.Time    `json:"due_date"`
}
// internal/domain/errors.go
package domain
import "errors"
// Sentinel business errors. Callers compare against them with errors.Is.
var (
	ErrNotFound   = errors.New("资源不存在")   // resource does not exist
	ErrForbidden  = errors.New("无权操作此资源") // caller may not act on this resource
	ErrConflict   = errors.New("资源已存在")   // resource already exists
	ErrBadRequest = errors.New("请求参数错误")  // invalid request parameters
)
PostgreSQL Repository 实现
// internal/repository/task_pg.go
package repository
import (
"context"
"errors"
"math"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/myorg/taskflow/internal/domain"
)
// taskRepository implements domain.TaskRepository on top of a pgx connection pool.
type taskRepository struct {
	pool *pgxpool.Pool
}
// NewTaskRepository returns a domain.TaskRepository that runs its queries on pool.
func NewTaskRepository(pool *pgxpool.Pool) domain.TaskRepository {
	repo := new(taskRepository)
	repo.pool = pool
	return repo
}
// Create inserts task and writes the generated ID and timestamps back into it.
func (r *taskRepository) Create(ctx context.Context, task *domain.Task) error {
	const insertTask = `
INSERT INTO tasks (user_id, title, description, status, priority, due_date)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, created_at, updated_at`

	row := r.pool.QueryRow(ctx, insertTask,
		task.UserID,
		task.Title,
		task.Description,
		task.Status,
		task.Priority,
		task.DueDate,
	)
	return row.Scan(&task.ID, &task.CreatedAt, &task.UpdatedAt)
}
// FindByID loads the task with the given id.
//
// It returns (nil, nil) when no row matches, leaving it to the caller to turn
// that into a not-found error. On any other failure it returns (nil, err)
// rather than a partially populated task.
func (r *taskRepository) FindByID(ctx context.Context, id int64) (*domain.Task, error) {
	// description is nullable in the schema; COALESCE keeps a NULL from
	// failing the scan into a plain string field.
	const query = `
		SELECT id, user_id, title, COALESCE(description, ''), status, priority, due_date, created_at, updated_at
		FROM tasks WHERE id = $1`

	task := &domain.Task{}
	err := r.pool.QueryRow(ctx, query, id).Scan(
		&task.ID, &task.UserID, &task.Title, &task.Description,
		&task.Status, &task.Priority, &task.DueDate,
		&task.CreatedAt, &task.UpdatedAt,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	return task, nil
}
// FindByUserID returns one page of the tasks owned by params.UserID, newest
// first, optionally filtered by params.Status.
//
// A PageSize outside 1..100 falls back to 20 and a Page below 1 falls back to
// 1. The result carries the total number of matching rows and the derived
// page count. Data is always a non-nil slice so it serializes as [] rather
// than null when the page is empty.
func (r *taskRepository) FindByUserID(ctx context.Context, params domain.ListParams) (*domain.PagedResult[domain.Task], error) {
	if params.PageSize <= 0 || params.PageSize > 100 {
		params.PageSize = 20
	}
	if params.Page <= 0 {
		params.Page = 1
	}
	offset := (params.Page - 1) * params.PageSize

	// Only placeholder positions are spliced into the SQL text; every value
	// travels as a bound argument. The status filter, when present, is always
	// the second argument, which shifts LIMIT/OFFSET by one position.
	args := []any{params.UserID}
	where := "user_id = $1"
	limitArg, offsetArg := "$2", "$3"
	if params.Status != "" {
		args = append(args, params.Status)
		where += " AND status = $2"
		limitArg, offsetArg = "$3", "$4"
	}

	// Count first so the caller can render paging controls.
	var total int64
	countQuery := "SELECT COUNT(*) FROM tasks WHERE " + where
	if err := r.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, err
	}

	// Then fetch the requested page. The id tiebreaker keeps the ordering
	// stable across pages when rows share a created_at value; COALESCE guards
	// the scan against a NULL description.
	args = append(args, params.PageSize, offset)
	dataQuery := `
		SELECT id, user_id, title, COALESCE(description, ''), status, priority, due_date, created_at, updated_at
		FROM tasks WHERE ` + where + `
		ORDER BY created_at DESC, id DESC
		LIMIT ` + limitArg + ` OFFSET ` + offsetArg

	rows, err := r.pool.Query(ctx, dataQuery, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	tasks := make([]domain.Task, 0, params.PageSize)
	for rows.Next() {
		var t domain.Task
		if err := rows.Scan(
			&t.ID, &t.UserID, &t.Title, &t.Description,
			&t.Status, &t.Priority, &t.DueDate,
			&t.CreatedAt, &t.UpdatedAt,
		); err != nil {
			return nil, err
		}
		tasks = append(tasks, t)
	}
	// Next returning false can also mean the iteration failed midway.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return &domain.PagedResult[domain.Task]{
		Data:       tasks,
		Total:      total,
		Page:       params.Page,
		PageSize:   params.PageSize,
		TotalPages: int64(math.Ceil(float64(total) / float64(params.PageSize))),
	}, nil
}
// Update writes the mutable fields of task to the row with task.ID and
// stores the refreshed updated_at value back into task.
func (r *taskRepository) Update(ctx context.Context, task *domain.Task) error {
	const updateTask = `
UPDATE tasks
SET title=$2, description=$3, status=$4, priority=$5, due_date=$6, updated_at=NOW()
WHERE id=$1
RETURNING updated_at`

	row := r.pool.QueryRow(ctx, updateTask,
		task.ID,
		task.Title,
		task.Description,
		task.Status,
		task.Priority,
		task.DueDate,
	)
	return row.Scan(&task.UpdatedAt)
}
// Delete removes the task row with the given id.
func (r *taskRepository) Delete(ctx context.Context, id int64) error {
	if _, err := r.pool.Exec(ctx, "DELETE FROM tasks WHERE id=$1", id); err != nil {
		return err
	}
	return nil
}
4. Service 层与业务逻辑
// internal/service/task.go
package service
import (
"context"
"fmt"
"github.com/myorg/taskflow/internal/domain"
)
// taskService implements domain.TaskService using a repository and a cache.
type taskService struct {
	repo  domain.TaskRepository
	cache TaskCache // cache abstraction
}
// TaskCache is the cache abstraction used by the task service; being an
// interface makes it easy to mock in tests.
type TaskCache interface {
	GetTaskList(ctx context.Context, key string) (*domain.PagedResult[domain.Task], error)
	SetTaskList(ctx context.Context, key string, result *domain.PagedResult[domain.Task]) error
	DeleteUserCache(ctx context.Context, userID int64) error
}
// NewTaskService wires a repository and a cache into a domain.TaskService.
func NewTaskService(repo domain.TaskRepository, cache TaskCache) domain.TaskService {
	svc := taskService{repo: repo, cache: cache}
	return &svc
}
// Create stores a new task for userID. New tasks start in the todo status,
// and an empty priority defaults to medium. After a successful insert the
// user's cached task lists are invalidated.
func (s *taskService) Create(ctx context.Context, userID int64, input domain.CreateTaskInput) (*domain.Task, error) {
	priority := input.Priority
	if priority == "" {
		priority = domain.PriorityMedium
	}

	task := &domain.Task{
		UserID:      userID,
		Title:       input.Title,
		Description: input.Description,
		Status:      domain.StatusTodo,
		Priority:    priority,
		DueDate:     input.DueDate,
	}

	err := s.repo.Create(ctx, task)
	if err != nil {
		return nil, fmt.Errorf("创建任务失败: %w", err)
	}

	// Invalidation is best-effort: its error is deliberately ignored so a
	// cache failure does not fail the write.
	_ = s.cache.DeleteUserCache(ctx, userID)
	return task, nil
}
// List returns one page of userID's tasks, served from the cache when
// possible and from the repository otherwise.
//
// Paging values are normalised (PageSize outside 1..100 becomes 20, Page
// below 1 becomes 1) before the cache key is derived. Without this,
// equivalent requests such as page=0 and page=1, or any oversized page_size,
// would each get their own cache entry, letting callers create an unbounded
// number of keys for the same data.
func (s *taskService) List(ctx context.Context, userID int64, params domain.ListParams) (*domain.PagedResult[domain.Task], error) {
	params.UserID = userID
	if params.PageSize <= 0 || params.PageSize > 100 {
		params.PageSize = 20
	}
	if params.Page <= 0 {
		params.Page = 1
	}

	// Try the cache first; any cache error is treated as a miss.
	cacheKey := fmt.Sprintf("tasks:user:%d:page:%d:size:%d:status:%s",
		userID, params.Page, params.PageSize, params.Status)
	if cached, err := s.cache.GetTaskList(ctx, cacheKey); err == nil && cached != nil {
		return cached, nil
	}

	// Cache miss: read from the database.
	result, err := s.repo.FindByUserID(ctx, params)
	if err != nil {
		return nil, err
	}

	// Populate the cache without blocking the response. The goroutine uses a
	// background context rather than the request context; the write is
	// best-effort and its error is ignored.
	go func() {
		_ = s.cache.SetTaskList(context.Background(), cacheKey, result)
	}()
	return result, nil
}
// Get returns the task with taskID if it belongs to userID. It yields
// domain.ErrNotFound when the repository finds nothing and
// domain.ErrForbidden when the task is owned by another user.
func (s *taskService) Get(ctx context.Context, userID, taskID int64) (*domain.Task, error) {
	found, err := s.repo.FindByID(ctx, taskID)
	switch {
	case err != nil:
		return nil, err
	case found == nil:
		return nil, domain.ErrNotFound
	case found.UserID != userID:
		// Ownership check: users may only read their own tasks.
		return nil, domain.ErrForbidden
	}
	return found, nil
}
// Update applies a partial update to one of userID's tasks. Only the non-nil
// fields of input are copied onto the stored task before it is saved; the
// user's cached task lists are invalidated afterwards.
func (s *taskService) Update(ctx context.Context, userID, taskID int64, input domain.UpdateTaskInput) (*domain.Task, error) {
	// Get also performs the ownership check.
	current, err := s.Get(ctx, userID, taskID)
	if err != nil {
		return nil, err
	}

	if title := input.Title; title != nil {
		current.Title = *title
	}
	if desc := input.Description; desc != nil {
		current.Description = *desc
	}
	if status := input.Status; status != nil {
		current.Status = *status
	}
	if prio := input.Priority; prio != nil {
		current.Priority = *prio
	}
	if due := input.DueDate; due != nil {
		current.DueDate = due
	}

	err = s.repo.Update(ctx, current)
	if err != nil {
		return nil, err
	}

	// Best-effort invalidation; the error is deliberately ignored.
	_ = s.cache.DeleteUserCache(ctx, userID)
	return current, nil
}
// Delete removes one of userID's tasks and then invalidates the user's
// cached task lists.
func (s *taskService) Delete(ctx context.Context, userID, taskID int64) error {
	// Get is called only for its existence and ownership checks.
	_, err := s.Get(ctx, userID, taskID)
	if err != nil {
		return err
	}

	err = s.repo.Delete(ctx, taskID)
	if err != nil {
		return err
	}

	// Best-effort invalidation; the error is deliberately ignored.
	_ = s.cache.DeleteUserCache(ctx, userID)
	return nil
}
5. HTTP Handler 层
JWT 中间件与统一响应
// pkg/respond/respond.go
package respond
import (
"net/http"
"github.com/gin-gonic/gin"
)
// Response is the uniform JSON envelope returned by every endpoint.
type Response struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	// Data is omitted from the JSON output when empty.
	Data any `json:"data,omitempty"`
}
// send writes a Response envelope with the given HTTP status.
func send(c *gin.Context, status, code int, message string, data any) {
	c.JSON(status, Response{Code: code, Message: message, Data: data})
}

// OK replies 200 with data and code 0.
func OK(c *gin.Context, data any) {
	send(c, http.StatusOK, 0, "ok", data)
}

// Created replies 201 with data and code 0.
func Created(c *gin.Context, data any) {
	send(c, http.StatusCreated, 0, "created", data)
}

// BadRequest replies 400 with msg as the message.
func BadRequest(c *gin.Context, msg string) {
	send(c, http.StatusBadRequest, 400, msg, nil)
}

// Unauthorized replies 401.
func Unauthorized(c *gin.Context) {
	send(c, http.StatusUnauthorized, 401, "未授权", nil)
}

// Forbidden replies 403.
func Forbidden(c *gin.Context) {
	send(c, http.StatusForbidden, 403, "禁止访问", nil)
}

// NotFound replies 404.
func NotFound(c *gin.Context) {
	send(c, http.StatusNotFound, 404, "资源不存在", nil)
}

// InternalError replies 500 with a generic message.
func InternalError(c *gin.Context) {
	send(c, http.StatusInternalServerError, 500, "服务器内部错误", nil)
}
// internal/handler/middleware.go
package handler
import (
"strings"
"github.com/gin-gonic/gin"
"github.com/myorg/taskflow/pkg/jwt"
"github.com/myorg/taskflow/pkg/respond"
)
// userIDKey is the gin context key under which the authenticated user ID is stored.
const userIDKey = "userID"

// JWTMiddleware validates the Bearer token in the Authorization header and
// stores the user ID from its claims in the gin context. Requests with a
// missing, malformed or unverifiable token are answered with 401 and aborted.
func JWTMiddleware(jwtSvc *jwt.Service) gin.HandlerFunc {
	const scheme = "Bearer "

	reject := func(c *gin.Context) {
		respond.Unauthorized(c)
		c.Abort()
	}

	return func(c *gin.Context) {
		header := c.GetHeader("Authorization")
		// An empty header also fails the prefix check.
		if !strings.HasPrefix(header, scheme) {
			reject(c)
			return
		}

		claims, err := jwtSvc.Verify(strings.TrimPrefix(header, scheme))
		if err != nil {
			reject(c)
			return
		}

		c.Set(userIDKey, claims.UserID)
		c.Next()
	}
}
// GetUserID returns the user ID stored in the gin context. The boolean is
// false when no value is stored or the stored value is not an int64.
func GetUserID(c *gin.Context) (int64, bool) {
	if stored, found := c.Get(userIDKey); found {
		id, isInt64 := stored.(int64)
		return id, isInt64
	}
	return 0, false
}
Task Handler
// internal/handler/task.go
package handler
import (
"errors"
"strconv"
"github.com/gin-gonic/gin"
"github.com/myorg/taskflow/internal/domain"
"github.com/myorg/taskflow/pkg/respond"
)
// TaskHandler exposes the task service over HTTP.
type TaskHandler struct {
	service domain.TaskService
}
// NewTaskHandler returns a TaskHandler that delegates to svc.
func NewTaskHandler(svc domain.TaskService) *TaskHandler {
	h := new(TaskHandler)
	h.service = svc
	return h
}
// RegisterRoutes mounts the task endpoints on rg. The group is expected to
// already have the JWT middleware attached.
func (h *TaskHandler) RegisterRoutes(rg *gin.RouterGroup) {
	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{"GET", "", h.List},
		{"POST", "", h.Create},
		{"GET", "/:id", h.Get},
		{"PUT", "/:id", h.Update},
		{"DELETE", "/:id", h.Delete},
	}
	for _, rt := range routes {
		rg.Handle(rt.method, rt.path, rt.handler)
	}
}
// Create handles POST /api/v1/tasks.
//
// It replies 401 when no user ID is present in the context, 400 when the JSON
// body fails binding, and 201 with the created task on success. Service
// errors go through handleServiceError, like the other handlers, so domain
// errors map to their proper status instead of always becoming 500.
func (h *TaskHandler) Create(c *gin.Context) {
	userID, ok := GetUserID(c)
	if !ok {
		respond.Unauthorized(c)
		return
	}

	var input domain.CreateTaskInput
	if err := c.ShouldBindJSON(&input); err != nil {
		respond.BadRequest(c, err.Error())
		return
	}

	task, err := h.service.Create(c.Request.Context(), userID, input)
	if err != nil {
		h.handleServiceError(c, err)
		return
	}
	respond.Created(c, task)
}
// List handles GET /api/v1/tasks?page=1&page_size=20&status=todo.
//
// It replies 401 when no user ID is present in the context and 400 when the
// status filter is not one of the defined statuses; an unknown status would
// otherwise reach the database and come back as a 500. Service errors go
// through handleServiceError, like the other handlers.
func (h *TaskHandler) List(c *gin.Context) {
	userID, ok := GetUserID(c)
	if !ok {
		respond.Unauthorized(c)
		return
	}

	status := domain.TaskStatus(c.Query("status"))
	switch status {
	case "", domain.StatusTodo, domain.StatusInProgress, domain.StatusDone:
		// Empty means "no filter".
	default:
		respond.BadRequest(c, "无效的任务状态")
		return
	}

	// Parse errors are deliberately ignored: a non-numeric value yields 0,
	// which is passed on unchanged.
	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))

	params := domain.ListParams{
		Status:   status,
		Page:     page,
		PageSize: pageSize,
	}
	result, err := h.service.List(c.Request.Context(), userID, params)
	if err != nil {
		h.handleServiceError(c, err)
		return
	}
	respond.OK(c, result)
}
// Get handles GET /api/v1/tasks/:id and replies with the requested task.
// It replies 401 without a user ID in the context and 400 when :id is not
// an integer.
func (h *TaskHandler) Get(c *gin.Context) {
	uid, authenticated := GetUserID(c)
	if !authenticated {
		respond.Unauthorized(c)
		return
	}

	id, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		respond.BadRequest(c, "无效的任务 ID")
		return
	}

	reqCtx := c.Request.Context()
	found, svcErr := h.service.Get(reqCtx, uid, id)
	if svcErr != nil {
		h.handleServiceError(c, svcErr)
		return
	}
	respond.OK(c, found)
}
// Update handles PUT /api/v1/tasks/:id and replies with the updated task.
// It replies 401 without a user ID in the context and 400 when :id is not
// an integer or the JSON body fails binding.
func (h *TaskHandler) Update(c *gin.Context) {
	uid, authenticated := GetUserID(c)
	if !authenticated {
		respond.Unauthorized(c)
		return
	}

	id, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		respond.BadRequest(c, "无效的任务 ID")
		return
	}

	var payload domain.UpdateTaskInput
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		respond.BadRequest(c, bindErr.Error())
		return
	}

	reqCtx := c.Request.Context()
	updated, svcErr := h.service.Update(reqCtx, uid, id, payload)
	if svcErr != nil {
		h.handleServiceError(c, svcErr)
		return
	}
	respond.OK(c, updated)
}
// Delete handles DELETE /api/v1/tasks/:id and replies 200 with no data on
// success. It replies 401 without a user ID in the context and 400 when :id
// is not an integer.
func (h *TaskHandler) Delete(c *gin.Context) {
	uid, authenticated := GetUserID(c)
	if !authenticated {
		respond.Unauthorized(c)
		return
	}

	id, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		respond.BadRequest(c, "无效的任务 ID")
		return
	}

	svcErr := h.service.Delete(c.Request.Context(), uid, id)
	if svcErr != nil {
		h.handleServiceError(c, svcErr)
		return
	}
	respond.OK(c, nil)
}
// handleServiceError maps a service error to an HTTP reply: ErrNotFound
// becomes 404, ErrForbidden becomes 403, and anything else becomes 500.
func (h *TaskHandler) handleServiceError(c *gin.Context, err error) {
	if errors.Is(err, domain.ErrNotFound) {
		respond.NotFound(c)
		return
	}
	if errors.Is(err, domain.ErrForbidden) {
		respond.Forbidden(c)
		return
	}
	respond.InternalError(c)
}
路由组装(main.go)
// cmd/server/main.go
package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/jackc/pgx/v5/pgxpool"
goredis "github.com/redis/go-redis/v9"
"github.com/myorg/taskflow/config"
"github.com/myorg/taskflow/internal/cache"
"github.com/myorg/taskflow/internal/handler"
"github.com/myorg/taskflow/internal/repository"
"github.com/myorg/taskflow/internal/service"
"github.com/myorg/taskflow/pkg/jwt"
)
// main loads the configuration, wires the dependencies, starts the HTTP
// server and shuts it down gracefully on SIGINT or SIGTERM.
//
// When invoked as `server healthcheck` it does not start a server. It probes
// the /health endpoint of the instance already listening on the configured
// port and exits non-zero on failure. Without this mode, running the binary
// as a container health probe would try to bind the same port again and
// always fail.
func main() {
	cfg, err := config.Load("config/config.yaml")
	if err != nil {
		log.Fatalf("加载配置失败: %v", err)
	}

	if len(os.Args) > 1 && os.Args[1] == "healthcheck" {
		if err := probeHealth(cfg.Server.Port); err != nil {
			log.Fatalf("健康检查失败: %v", err)
		}
		return
	}

	// ── PostgreSQL connection pool ─────────────────────
	pool, err := pgxpool.New(context.Background(), cfg.Database.DSN())
	if err != nil {
		log.Fatalf("连接数据库失败: %v", err)
	}
	defer pool.Close()

	// ── Redis ──────────────────────────────────────────
	rdb := goredis.NewClient(&goredis.Options{
		Addr:     cfg.Redis.Addr,
		Password: cfg.Redis.Password,
		DB:       cfg.Redis.DB,
	})
	defer rdb.Close()

	// ── Manual dependency injection ────────────────────
	jwtSvc := jwt.NewService(cfg.JWT.Secret, time.Duration(cfg.JWT.ExpireHour)*time.Hour)
	taskRepo := repository.NewTaskRepository(pool)
	userRepo := repository.NewUserRepository(pool)
	taskCache := cache.NewRedisTaskCache(rdb)
	taskSvc := service.NewTaskService(taskRepo, taskCache)
	userSvc := service.NewUserService(userRepo, jwtSvc)
	taskHandler := handler.NewTaskHandler(taskSvc)
	userHandler := handler.NewUserHandler(userSvc)

	// ── Gin routes ─────────────────────────────────────
	r := gin.New()
	r.Use(gin.Recovery())
	r.Use(gin.Logger())
	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})
	v1 := r.Group("/api/v1")
	{
		auth := v1.Group("/auth")
		auth.POST("/register", userHandler.Register)
		auth.POST("/login", userHandler.Login)

		// Routes below require authentication.
		tasks := v1.Group("/tasks")
		tasks.Use(handler.JWTMiddleware(jwtSvc))
		taskHandler.RegisterRoutes(tasks)
	}

	// ── Graceful shutdown ──────────────────────────────
	srv := &http.Server{
		Addr:         fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
		Handler:      r,
		ReadTimeout:  time.Duration(cfg.Server.ReadTimeout) * time.Second,
		WriteTimeout: time.Duration(cfg.Server.WriteTimeout) * time.Second,
	}
	go func() {
		log.Printf("服务启动,监听 %s", srv.Addr)
		// ErrServerClosed is the expected result of Shutdown, not a failure.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("服务器错误: %v", err)
		}
	}()

	// Block until SIGINT or SIGTERM arrives.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	log.Println("收到关机信号,优雅停止...")

	// Give in-flight requests up to 10 seconds to finish.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		log.Fatalf("强制关机: %v", err)
	}
	log.Println("服务器已停止")
}

// probeHealth issues GET /health against the local instance on port and
// returns an error unless it answers 200 within the client timeout.
func probeHealth(port int) error {
	client := &http.Client{Timeout: 2 * time.Second}
	resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/health", port))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("健康检查返回状态码 %d", resp.StatusCode)
	}
	return nil
}
6. Redis 缓存层
// internal/cache/redis.go
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/myorg/taskflow/internal/domain"
"github.com/myorg/taskflow/internal/service"
)
const (
	// taskListTTL is the expiry applied to cached task lists.
	taskListTTL = 5 * time.Minute
	// userCacheGlob is the key pattern matching every cached list of one
	// user; %d is replaced by the user ID.
	userCacheGlob = "tasks:user:%d:*"
)
// redisTaskCache implements service.TaskCache on top of a Redis client.
type redisTaskCache struct {
	rdb *redis.Client
}
// NewRedisTaskCache returns a service.TaskCache that stores entries in rdb.
func NewRedisTaskCache(rdb *redis.Client) service.TaskCache {
	cache := new(redisTaskCache)
	cache.rdb = rdb
	return cache
}
// GetTaskList fetches and JSON-decodes the task list stored under key.
// A missing key is reported as an error (redis.Nil), which callers treat as
// a cache miss.
func (c *redisTaskCache) GetTaskList(ctx context.Context, key string) (*domain.PagedResult[domain.Task], error) {
	raw, err := c.rdb.Get(ctx, key).Bytes()
	if err != nil {
		return nil, err
	}

	page := new(domain.PagedResult[domain.Task])
	if decodeErr := json.Unmarshal(raw, page); decodeErr != nil {
		return nil, decodeErr
	}
	return page, nil
}
// SetTaskList JSON-encodes result and stores it under key with the
// taskListTTL expiry.
func (c *redisTaskCache) SetTaskList(ctx context.Context, key string, result *domain.PagedResult[domain.Task]) error {
	payload, err := json.Marshal(result)
	if err != nil {
		return err
	}
	status := c.rdb.Set(ctx, key, payload, taskListTTL)
	return status.Err()
}
// DeleteUserCache invalidates every cached task list of userID by walking
// the keyspace with SCAN and deleting each batch of matching keys.
func (c *redisTaskCache) DeleteUserCache(ctx context.Context, userID int64) error {
	pattern := fmt.Sprintf(userCacheGlob, userID)

	for cursor := uint64(0); ; {
		batch, next, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result()
		if err != nil {
			return err
		}
		if len(batch) > 0 {
			if delErr := c.rdb.Del(ctx, batch...).Err(); delErr != nil {
				return delErr
			}
		}
		// A zero cursor marks the end of the iteration.
		if next == 0 {
			return nil
		}
		cursor = next
	}
}
Cache-Aside 模式的工作原理
读取时:先查缓存 → 命中则返回 → 未命中则查 DB → 结果写入缓存。写入时:先更新 DB → 使相关缓存失效(而非更新缓存,以降低并发写入导致缓存与数据库不一致的风险)。使缓存失效比更新缓存更安全,尤其在高并发场景下;但它并不能完全消除竞态——例如某个读请求可能在失效之后才把旧结果写回缓存——因此仍需依靠缓存的 TTL(本项目为 5 分钟)来限制脏数据的存活时间。
7. 测试覆盖
// internal/handler/task_test.go
package handler_test
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/myorg/taskflow/internal/domain"
"github.com/myorg/taskflow/internal/handler"
)
// ─── Mock TaskService ──────────────────────────────────────────
// mockTaskService is a hand-written domain.TaskService stub. Each test sets
// only the function fields it needs; calling a method whose field is nil
// panics.
type mockTaskService struct {
	createFn func(ctx context.Context, userID int64, input domain.CreateTaskInput) (*domain.Task, error)
	listFn   func(ctx context.Context, userID int64, params domain.ListParams) (*domain.PagedResult[domain.Task], error)
	getFn    func(ctx context.Context, userID, taskID int64) (*domain.Task, error)
}
// Create forwards to createFn.
func (m *mockTaskService) Create(ctx context.Context, userID int64, input domain.CreateTaskInput) (*domain.Task, error) {
	task, err := m.createFn(ctx, userID, input)
	return task, err
}

// List forwards to listFn.
func (m *mockTaskService) List(ctx context.Context, userID int64, p domain.ListParams) (*domain.PagedResult[domain.Task], error) {
	page, err := m.listFn(ctx, userID, p)
	return page, err
}

// Get forwards to getFn.
func (m *mockTaskService) Get(ctx context.Context, userID, taskID int64) (*domain.Task, error) {
	task, err := m.getFn(ctx, userID, taskID)
	return task, err
}

// Update is a no-op stub.
func (m *mockTaskService) Update(_ context.Context, _, _ int64, _ domain.UpdateTaskInput) (*domain.Task, error) {
	return nil, nil
}

// Delete is a no-op stub.
func (m *mockTaskService) Delete(_ context.Context, _, _ int64) error {
	return nil
}
// ─── 辅助函数 ─────────────────────────────────────────────────
// setupRouter builds a test-mode Gin engine with h mounted under /tasks.
// A stub middleware stores userID in the context so the handlers can be
// exercised without JWT verification.
func setupRouter(h *handler.TaskHandler, userID int64) *gin.Engine {
	gin.SetMode(gin.TestMode)

	injectUser := func(c *gin.Context) {
		c.Set("userID", userID)
		c.Next()
	}

	engine := gin.New()
	engine.Use(injectUser)
	taskGroup := engine.Group("/tasks")
	h.RegisterRoutes(taskGroup)
	return engine
}
// ─── 测试用例 ─────────────────────────────────────────────────
// TestTaskHandler_Create_Success posts a valid task and expects a 201 reply
// whose envelope has code 0 and echoes the submitted title.
func TestTaskHandler_Create_Success(t *testing.T) {
	svc := &mockTaskService{
		createFn: func(_ context.Context, userID int64, input domain.CreateTaskInput) (*domain.Task, error) {
			return &domain.Task{
				ID:     1,
				UserID: userID,
				Title:  input.Title,
				Status: domain.StatusTodo,
			}, nil
		},
	}
	h := handler.NewTaskHandler(svc)
	router := setupRouter(h, 42)

	body := `{"title":"完成项目报告","priority":"high"}`
	req := httptest.NewRequest(http.MethodPost, "/tasks", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusCreated, rec.Code)

	// Fail with a clear message, rather than panic, when the body is not the
	// expected JSON envelope.
	var resp map[string]any
	if !assert.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) {
		return
	}
	assert.Equal(t, float64(0), resp["code"])

	data, ok := resp["data"].(map[string]any)
	if !assert.True(t, ok, "data should be a JSON object") {
		return
	}
	assert.Equal(t, "完成项目报告", data["title"])
}
func TestTaskHandler_Get_Forbidden(t *testing.T) {
svc := &mockTaskService{
getFn: func(_ context.Context, _, _ int64) (*domain.Task, error) {
return nil, domain.ErrForbidden // 模拟无权访问
},
}
h := handler.NewTaskHandler(svc)
router := setupRouter(h, 42)
req := httptest.NewRequest(http.MethodGet, "/tasks/99", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusForbidden, rec.Code)
}
func TestTaskHandler_List_WithPagination(t *testing.T) {
svc := &mockTaskService{
listFn: func(_ context.Context, _ int64, params domain.ListParams) (*domain.PagedResult[domain.Task], error) {
assert.Equal(t, 2, params.Page)
assert.Equal(t, 5, params.PageSize)
return &domain.PagedResult[domain.Task]{
Data: []domain.Task{{ID: 6, Title: "Task 6"}},
Total: 10, Page: 2, PageSize: 5, TotalPages: 2,
}, nil
},
}
h := handler.NewTaskHandler(svc)
router := setupRouter(h, 1)
req := httptest.NewRequest(http.MethodGet, "/tasks?page=2&page_size=5", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
8. Docker 部署
多阶段构建 Dockerfile
# Dockerfile
# ── Build stage ────────────────────────────────────────────────
FROM golang:1.22-alpine AS builder
# Version string passed to the linker below. It must be declared as a build
# argument, otherwise ${VERSION} expands to an empty string.
# Override with: docker build --build-arg VERSION=1.2.3 .
ARG VERSION=dev
WORKDIR /app
# Copy the dependency manifests first so the module download layer is cached.
COPY go.mod go.sum ./
RUN go mod download
# Copy the full source tree and build a static binary.
COPY . .
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \
    go build -ldflags="-s -w -X main.version=${VERSION}" \
    -o bin/server ./cmd/server
# ── Runtime stage: minimal image ───────────────────────────────
FROM gcr.io/distroless/static-debian12
WORKDIR /app
# Ship only the binary and the configuration file.
COPY --from=builder /app/bin/server .
COPY --from=builder /app/config/config.yaml config/
# Run as a non-root user (security best practice).
USER nonroot:nonroot
EXPOSE 8080
# The probe runs the server binary itself with a "healthcheck" argument, so
# the binary must implement that subcommand.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD ["/app/server", "healthcheck"]
ENTRYPOINT ["/app/server"]
docker-compose.yml
# docker-compose.yml
version: '3.9'
services:
  app:
    build:
      context: .
      dockerfile: Dockerfile
    ports:
      - "8080:8080"
    environment:
      TASKFLOW_DATABASE_HOST: postgres
      TASKFLOW_DATABASE_PASSWORD: ${DB_PASSWORD:-secret}
      TASKFLOW_JWT_SECRET: ${JWT_SECRET:-change-me}
      TASKFLOW_REDIS_ADDR: redis:6379
    depends_on:
      postgres:
        condition: service_healthy
      redis:
        condition: service_healthy
    restart: unless-stopped
  postgres:
    image: postgres:16-alpine
    environment:
      POSTGRES_USER: taskflow
      POSTGRES_PASSWORD: ${DB_PASSWORD:-secret}
      POSTGRES_DB: taskflow
    volumes:
      - postgres_data:/var/lib/postgresql/data
      # Init scripts run only when the data directory is empty (first start);
      # later schema changes need a real migration step.
      - ./migrations:/docker-entrypoint-initdb.d
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U taskflow"]
      interval: 10s
      timeout: 5s
      retries: 5
    ports:
      - "5432:5432" # exposed for development; can be removed in production
  redis:
    image: redis:7-alpine
    command: redis-server --appendonly yes # enable AOF persistence
    volumes:
      - redis_data:/data
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 3s
      retries: 5
volumes:
  postgres_data:
  redis_data:
部署命令
# Copy the environment template and fill it in
cp .env.example .env
# Edit .env and set DB_PASSWORD and JWT_SECRET
# Start everything (the image is built on first run)
docker compose up -d
# Follow the application logs
docker compose logs -f app
# Health check
curl http://localhost:8080/health
# Stop and clean up
docker compose down
生产部署清单
发布前确保:① JWT_SECRET 使用强随机字符串(至少 32 字节)② 数据库密码通过 Secrets 管理器注入(如 AWS Secrets Manager / HashiCorp Vault)③ 开启 HTTPS(反向代理层)④ 配置日志收集(如 ELK / Loki)⑤ 设置 Prometheus 指标端点和告警规则 ⑥ 准备好回滚方案。