database.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package utils
  2. import (
  3. "database/sql"
  4. "encoding/csv"
  5. "fmt"
  6. "os"
  7. "path/filepath"
  8. _ "github.com/lib/pq"
  9. )
  10. // ConnectDB 连接数据库
  11. func ConnectDB(config map[string]string) (*sql.DB, error) {
  12. // 构建连接字符串
  13. connStr := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
  14. config["host"], config["port"], config["username"], config["password"], config["database"])
  15. db, err := sql.Open("postgres", connStr)
  16. if err != nil {
  17. return nil, fmt.Errorf("数据库连接失败: %v", err)
  18. }
  19. // 测试连接
  20. err = db.Ping()
  21. if err != nil {
  22. return nil, fmt.Errorf("数据库连接测试失败: %v", err)
  23. }
  24. return db, nil
  25. }
  26. // ExportTableToCSV 导出表数据到CSV
  27. func ExportTableToCSV(db *sql.DB, dbName string, tableName string) error {
  28. // 查询表数据
  29. query := fmt.Sprintf("SELECT * FROM %s ORDER BY id DESC", tableName)
  30. rows, err := db.Query(query)
  31. if err != nil {
  32. return fmt.Errorf("查询表 %s 失败: %v", tableName, err)
  33. }
  34. defer rows.Close()
  35. // 获取列名
  36. columns, err := rows.Columns()
  37. if err != nil {
  38. return fmt.Errorf("获取列名失败: %v", err)
  39. }
  40. // 创建输出目录
  41. outputDir := filepath.Join("output", dbName)
  42. if err := os.MkdirAll(outputDir, 0755); err != nil {
  43. return fmt.Errorf("创建输出目录失败: %v", err)
  44. }
  45. // 创建CSV文件
  46. filePath := filepath.Join(outputDir, tableName+".csv")
  47. file, err := os.Create(filePath)
  48. if err != nil {
  49. return fmt.Errorf("创建CSV文件失败: %v", err)
  50. }
  51. defer file.Close()
  52. writer := csv.NewWriter(file)
  53. defer writer.Flush()
  54. // 写入列名
  55. if err := writer.Write(columns); err != nil {
  56. return fmt.Errorf("写入列名失败: %v", err)
  57. }
  58. // 准备接收数据的切片
  59. values := make([]interface{}, len(columns))
  60. valuePtrs := make([]interface{}, len(columns))
  61. for i := range columns {
  62. valuePtrs[i] = &values[i]
  63. }
  64. // 读取数据并写入CSV
  65. recordCount := 0
  66. for rows.Next() {
  67. err := rows.Scan(valuePtrs...)
  68. if err != nil {
  69. return fmt.Errorf("读取数据失败: %v", err)
  70. }
  71. // 转换数据为字符串
  72. record := make([]string, len(columns))
  73. for i, val := range values {
  74. if val == nil {
  75. record[i] = ""
  76. } else {
  77. record[i] = fmt.Sprintf("%v", val)
  78. }
  79. }
  80. // 写入CSV
  81. if err := writer.Write(record); err != nil {
  82. return fmt.Errorf("写入CSV数据失败: %v", err)
  83. }
  84. recordCount++
  85. }
  86. // 检查遍历过程中是否有错误
  87. if err := rows.Err(); err != nil {
  88. return fmt.Errorf("遍历数据失败: %v", err)
  89. }
  90. fmt.Printf("表 %s 导出完成,共 %d 条数据,文件位置: %s\n", tableName, recordCount, filePath)
  91. return nil
  92. }
  93. // ExportAllTables 导出所有表到CSV
  94. func ExportAllTables(db *sql.DB, dbName string, tables []string) error {
  95. for _, table := range tables {
  96. fmt.Printf("正在导出表: %s\n", table)
  97. if err := ExportTableToCSV(db, dbName, table); err != nil {
  98. return fmt.Errorf("导出表 %s 失败: %v", table, err)
  99. }
  100. }
  101. return nil
  102. }