| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- package utils
- import (
- "database/sql"
- "encoding/csv"
- "fmt"
- "os"
- "path/filepath"
- _ "github.com/lib/pq"
- )
- // ConnectDB 连接数据库
- func ConnectDB(config map[string]string) (*sql.DB, error) {
- // 构建连接字符串
- connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
- config["host"], config["port"], config["username"], config["password"], config["database"])
- db, err := sql.Open("postgres", connStr)
- if err != nil {
- return nil, fmt.Errorf("数据库连接失败: %v", err)
- }
- // 测试连接
- err = db.Ping()
- if err != nil {
- return nil, fmt.Errorf("数据库连接测试失败: %v", err)
- }
- return db, nil
- }
- // ExportTableToCSV 导出表数据到CSV
- func ExportTableToCSV(db *sql.DB, dbName string, tableName string) error {
- // 查询表数据
- query := fmt.Sprintf("SELECT * FROM %s ORDER BY id DESC", tableName)
- rows, err := db.Query(query)
- if err != nil {
- return fmt.Errorf("查询表 %s 失败: %v", tableName, err)
- }
- defer rows.Close()
- // 获取列名
- columns, err := rows.Columns()
- if err != nil {
- return fmt.Errorf("获取列名失败: %v", err)
- }
- // 创建输出目录
- outputDir := filepath.Join("output", dbName)
- if err := os.MkdirAll(outputDir, 0755); err != nil {
- return fmt.Errorf("创建输出目录失败: %v", err)
- }
- // 创建CSV文件
- filePath := filepath.Join(outputDir, tableName+".csv")
- file, err := os.Create(filePath)
- if err != nil {
- return fmt.Errorf("创建CSV文件失败: %v", err)
- }
- defer file.Close()
- writer := csv.NewWriter(file)
- defer writer.Flush()
- // 写入列名
- if err := writer.Write(columns); err != nil {
- return fmt.Errorf("写入列名失败: %v", err)
- }
- // 准备接收数据的切片
- values := make([]interface{}, len(columns))
- valuePtrs := make([]interface{}, len(columns))
- for i := range columns {
- valuePtrs[i] = &values[i]
- }
- // 读取数据并写入CSV
- recordCount := 0
- for rows.Next() {
- err := rows.Scan(valuePtrs...)
- if err != nil {
- return fmt.Errorf("读取数据失败: %v", err)
- }
- // 转换数据为字符串
- record := make([]string, len(columns))
- for i, val := range values {
- if val == nil {
- record[i] = ""
- } else {
- record[i] = fmt.Sprintf("%v", val)
- }
- }
- // 写入CSV
- if err := writer.Write(record); err != nil {
- return fmt.Errorf("写入CSV数据失败: %v", err)
- }
- recordCount++
- }
- // 检查遍历过程中是否有错误
- if err := rows.Err(); err != nil {
- return fmt.Errorf("遍历数据失败: %v", err)
- }
- fmt.Printf("表 %s 导出完成,共 %d 条数据,文件位置: %s\n", tableName, recordCount, filePath)
- return nil
- }
- // ExportAllTables 导出所有表到CSV
- func ExportAllTables(db *sql.DB, dbName string, tables []string) error {
- for _, table := range tables {
- fmt.Printf("正在导出表: %s\n", table)
- if err := ExportTableToCSV(db, dbName, table); err != nil {
- return fmt.Errorf("导出表 %s 失败: %v", table, err)
- }
- }
- return nil
- }
|