|
|
@@ -0,0 +1,122 @@
|
|
|
+package utils
|
|
|
+
|
|
|
+import (
|
|
|
+ "database/sql"
|
|
|
+ "encoding/csv"
|
|
|
+ "fmt"
|
|
|
+ "os"
|
|
|
+ "path/filepath"
|
|
|
+
|
|
|
+ _ "github.com/lib/pq"
|
|
|
+)
|
|
|
+
|
|
|
+// ConnectDB 连接数据库
|
|
|
+func ConnectDB(config map[string]string) (*sql.DB, error) {
|
|
|
+ // 构建连接字符串
|
|
|
+ connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
|
|
+ config["host"], config["port"], config["username"], config["password"], config["database"])
|
|
|
+
|
|
|
+ db, err := sql.Open("postgres", connStr)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("数据库连接失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 测试连接
|
|
|
+ err = db.Ping()
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("数据库连接测试失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return db, nil
|
|
|
+}
|
|
|
+
|
|
|
+// ExportTableToCSV 导出表数据到CSV
|
|
|
+func ExportTableToCSV(db *sql.DB, dbName string, tableName string) error {
|
|
|
+ // 查询表数据
|
|
|
+ query := fmt.Sprintf("SELECT * FROM %s ORDER BY id DESC", tableName)
|
|
|
+ rows, err := db.Query(query)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("查询表 %s 失败: %v", tableName, err)
|
|
|
+ }
|
|
|
+ defer rows.Close()
|
|
|
+
|
|
|
+ // 获取列名
|
|
|
+ columns, err := rows.Columns()
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("获取列名失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 创建输出目录
|
|
|
+ outputDir := filepath.Join("output", dbName)
|
|
|
+ if err := os.MkdirAll(outputDir, 0755); err != nil {
|
|
|
+ return fmt.Errorf("创建输出目录失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 创建CSV文件
|
|
|
+ filePath := filepath.Join(outputDir, tableName+".csv")
|
|
|
+ file, err := os.Create(filePath)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("创建CSV文件失败: %v", err)
|
|
|
+ }
|
|
|
+ defer file.Close()
|
|
|
+
|
|
|
+ writer := csv.NewWriter(file)
|
|
|
+ defer writer.Flush()
|
|
|
+
|
|
|
+ // 写入列名
|
|
|
+ if err := writer.Write(columns); err != nil {
|
|
|
+ return fmt.Errorf("写入列名失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 准备接收数据的切片
|
|
|
+ values := make([]interface{}, len(columns))
|
|
|
+ valuePtrs := make([]interface{}, len(columns))
|
|
|
+ for i := range columns {
|
|
|
+ valuePtrs[i] = &values[i]
|
|
|
+ }
|
|
|
+
|
|
|
+ // 读取数据并写入CSV
|
|
|
+ recordCount := 0
|
|
|
+ for rows.Next() {
|
|
|
+ err := rows.Scan(valuePtrs...)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("读取数据失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 转换数据为字符串
|
|
|
+ record := make([]string, len(columns))
|
|
|
+ for i, val := range values {
|
|
|
+ if val == nil {
|
|
|
+ record[i] = ""
|
|
|
+ } else {
|
|
|
+ record[i] = fmt.Sprintf("%v", val)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 写入CSV
|
|
|
+ if err := writer.Write(record); err != nil {
|
|
|
+ return fmt.Errorf("写入CSV数据失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ recordCount++
|
|
|
+ }
|
|
|
+
|
|
|
+ // 检查遍历过程中是否有错误
|
|
|
+ if err := rows.Err(); err != nil {
|
|
|
+ return fmt.Errorf("遍历数据失败: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Printf("表 %s 导出完成,共 %d 条数据,文件位置: %s\n", tableName, recordCount, filePath)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+// ExportAllTables 导出所有表到CSV
|
|
|
+func ExportAllTables(db *sql.DB, dbName string, tables []string) error {
|
|
|
+ for _, table := range tables {
|
|
|
+ fmt.Printf("正在导出表: %s\n", table)
|
|
|
+ if err := ExportTableToCSV(db, dbName, table); err != nil {
|
|
|
+ return fmt.Errorf("导出表 %s 失败: %v", table, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|