package utils import ( "database/sql" "encoding/csv" "fmt" "os" "path/filepath" _ "github.com/lib/pq" ) // ConnectDB 连接数据库 func ConnectDB(config map[string]string) (*sql.DB, error) { // 构建连接字符串 connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", config["host"], config["port"], config["username"], config["password"], config["database"]) db, err := sql.Open("postgres", connStr) if err != nil { return nil, fmt.Errorf("数据库连接失败: %v", err) } // 测试连接 err = db.Ping() if err != nil { return nil, fmt.Errorf("数据库连接测试失败: %v", err) } return db, nil } // ExportTableToCSV 导出表数据到CSV func ExportTableToCSV(db *sql.DB, dbName string, tableName string) error { // 查询表数据 query := fmt.Sprintf("SELECT * FROM %s ORDER BY id DESC", tableName) rows, err := db.Query(query) if err != nil { return fmt.Errorf("查询表 %s 失败: %v", tableName, err) } defer rows.Close() // 获取列名 columns, err := rows.Columns() if err != nil { return fmt.Errorf("获取列名失败: %v", err) } // 创建输出目录 outputDir := filepath.Join("output", dbName) if err := os.MkdirAll(outputDir, 0755); err != nil { return fmt.Errorf("创建输出目录失败: %v", err) } // 创建CSV文件 filePath := filepath.Join(outputDir, tableName+".csv") file, err := os.Create(filePath) if err != nil { return fmt.Errorf("创建CSV文件失败: %v", err) } defer file.Close() writer := csv.NewWriter(file) defer writer.Flush() // 写入列名 if err := writer.Write(columns); err != nil { return fmt.Errorf("写入列名失败: %v", err) } // 准备接收数据的切片 values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } // 读取数据并写入CSV recordCount := 0 for rows.Next() { err := rows.Scan(valuePtrs...) if err != nil { return fmt.Errorf("读取数据失败: %v", err) } // 转换数据为字符串 record := make([]string, len(columns)) for i, val := range values { if val == nil { record[i] = "" } else { record[i] = fmt.Sprintf("%v", val) } } // 写入CSV if err := writer.Write(record); err != nil { return fmt.Errorf("写入CSV数据失败: %v", err) } recordCount++ } // 检查遍历过程中是否有错误 if err := rows.Err(); err != nil { return fmt.Errorf("遍历数据失败: %v", err) } fmt.Printf("表 %s 导出完成,共 %d 条数据,文件位置: %s\n", tableName, recordCount, filePath) return nil } // ExportAllTables 导出所有表到CSV func ExportAllTables(db *sql.DB, dbName string, tables []string) error { for _, table := range tables { fmt.Printf("正在导出表: %s\n", table) if err := ExportTableToCSV(db, dbName, table); err != nil { return fmt.Errorf("导出表 %s 失败: %v", table, err) } } return nil }