upupa_dataist_ir/internal/store/sqlite.go

340 lines
7.0 KiB
Go

// upupa_dataist_ir/internal/store/sqlite.go
package store
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"strings"
"upupa_dataist_ir/pkg/models"
_ "github.com/mattn/go-sqlite3"
)
// SQLiteStore persists collections and JSON-encoded records in a SQLite
// database reached through database/sql.
type SQLiteStore struct {
	db *sql.DB
}

// NewSQLiteStore opens the SQLite database at dbPath and pings it to make sure
// it is reachable.
//
// The go-sqlite3 DSN option "_loc=UTC" (time location used by the driver for
// timestamps) is appended to dbPath. It is joined with "&" when dbPath already
// carries a query string (e.g. "file:app.db?cache=shared") and with "?"
// otherwise, so existing options are preserved.
//
// If the ping fails, the *sql.DB is closed before returning so a failed call
// does not leak the connection pool. Returned errors wrap the underlying
// driver error (use errors.Is / errors.As).
//
// NOTE(review): SQLite only enforces the FOREIGN KEY ... ON DELETE CASCADE
// declared in InitSchema when foreign keys are enabled per connection (for
// go-sqlite3, the "_foreign_keys=on" DSN option). This constructor does not
// enable them — confirm whether cascading deletes are expected.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	sep := "?"
	if strings.Contains(dbPath, "?") {
		sep = "&"
	}
	db, err := sql.Open("sqlite3", dbPath+sep+"_loc=UTC")
	if err != nil {
		return nil, fmt.Errorf("opening sqlite database %q: %w", dbPath, err)
	}
	if err := db.Ping(); err != nil {
		// Best-effort cleanup; the Ping error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("connecting to sqlite database %q: %w", dbPath, err)
	}
	return &SQLiteStore{db: db}, nil
}
// InitSchema creates the collections and records tables, plus an index on
// records(collection_id), if they do not already exist. The whole script is
// handed to the driver in a single Exec call.
func (s *SQLiteStore) InitSchema() error {
	const schemaDDL = `
CREATE TABLE IF NOT EXISTS collections (
id TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
system BOOLEAN DEFAULT FALSE
);
CREATE TABLE IF NOT EXISTS records (
id TEXT PRIMARY KEY,
collection_id TEXT NOT NULL,
created DATETIME DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
updated DATETIME DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
data TEXT NOT NULL,
FOREIGN KEY (collection_id) REFERENCES collections (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_records_collection ON records (collection_id);
`
	if _, execErr := s.db.Exec(schemaDDL); execErr != nil {
		return execErr
	}
	return nil
}
// GetByName looks up the collection with the given unique name.
//
// When no collection matches, the error is sql.ErrNoRows, returned unwrapped
// so existing `err == sql.ErrNoRows` checks keep working. On any error the
// returned Collection is the zero value rather than a partially scanned one.
func (s *SQLiteStore) GetByName(name string) (models.Collection, error) {
	const query = "SELECT id, name, system FROM collections WHERE name = ?"
	var c models.Collection
	if err := s.db.QueryRow(query, name).Scan(&c.ID, &c.Name, &c.System); err != nil {
		return models.Collection{}, err
	}
	return c, nil
}
// CreateCollection inserts c into the collections table using c's own ID,
// Name and System values. It returns c unchanged together with whatever error
// the insert reported (nil on success).
func (s *SQLiteStore) CreateCollection(c models.Collection) (models.Collection, error) {
	const insertSQL = `INSERT INTO collections (id, name, system) VALUES (?, ?, ?)`
	_, execErr := s.db.Exec(insertSQL, c.ID, c.Name, c.System)
	return c, execErr
}
// CreateRecord stores r.Data as JSON under r.ID in collection r.CollectionID.
// It then reads the row back through GetByID, so the returned record carries
// the created/updated values filled in by the column defaults. If encoding or
// the insert fails, r is returned unchanged alongside the error.
func (s *SQLiteStore) CreateRecord(r models.Record) (models.Record, error) {
	payload, encErr := marshalRecordData(r.Data)
	if encErr != nil {
		return r, encErr
	}
	const insertSQL = `INSERT INTO records
(id, collection_id, data)
VALUES (?, ?, ?)`
	if _, execErr := s.db.Exec(insertSQL, r.ID, r.CollectionID, payload); execErr != nil {
		return r, execErr
	}
	return s.GetByID(r.CollectionID, r.ID)
}
// GetAll returns one page of records from collectionID, together with the
// total number of records matching the filters (ignoring limit/offset).
//
// filters maps keys of the form "field[op]" (parsed by parseFilterKey) to the
// value to compare against. Keys that fail to parse are logged and skipped.
// orderBy is handed to buildOrderClause; when empty, records are ordered by
// created DESC. A limit of 0 skips the page query and returns an empty
// (non-nil) slice plus the total count.
func (s *SQLiteStore) GetAll(collectionID string, filters map[string]interface{}, orderBy string, limit, offset int) ([]models.Record, int, error) {
	args := []interface{}{collectionID}
	var where strings.Builder
	for key, value := range filters {
		field, op, err := parseFilterKey(key)
		if err != nil {
			log.Printf("Ignoring invalid filter key: %s, Error: %v", key, err)
			continue
		}
		// getFieldCaster turns releaseDate into JULIANDAY(...), a REAL. The
		// bound value must go through the same conversion; otherwise a date
		// string binds as TEXT, and SQLite orders every REAL before every
		// TEXT, so date comparisons would never match as intended.
		// JULIANDAY also accepts a numeric Julian day, so such callers are
		// unaffected.
		placeholder := "?"
		if field == "releaseDate" {
			placeholder = "JULIANDAY(?)"
		}
		where.WriteString(" AND " + getFieldCaster(field) + " " + op + " " + placeholder)
		// args order matches placeholder order because both are appended in
		// the same iteration, even though map iteration order is random.
		args = append(args, value)
	}
	filterClause := where.String()

	countQuery := "SELECT COUNT(*) FROM records WHERE collection_id = ?" + filterClause
	var totalCount int
	if err := s.db.QueryRow(countQuery, args...).Scan(&totalCount); err != nil {
		return nil, 0, err
	}
	if limit == 0 {
		return []models.Record{}, totalCount, nil
	}

	sortClause := " ORDER BY created DESC"
	if orderBy != "" {
		sortClause = buildOrderClause(orderBy)
	}
	pageQuery := "SELECT id, collection_id, created, updated, data FROM records WHERE collection_id = ?" +
		filterClause + sortClause + fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset)

	rows, err := s.db.Query(pageQuery, args...)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	records := []models.Record{}
	for rows.Next() {
		var (
			r       models.Record
			dataStr string
		)
		if err := rows.Scan(&r.ID, &r.CollectionID, &r.Created, &r.Updated, &dataStr); err != nil {
			return nil, 0, err
		}
		data, err := unmarshalRecordData(dataStr)
		if err != nil {
			return nil, 0, err
		}
		r.Data = data
		records = append(records, r)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return records, totalCount, nil
}
// GetByID fetches record recordID from collection collectionID and decodes
// its JSON data column.
//
// When no such record exists, the error is sql.ErrNoRows, returned unwrapped
// so existing `err == sql.ErrNoRows` checks keep working. On any error
// (missing row, scan failure, or undecodable JSON) the returned Record is the
// zero value rather than a partially populated one.
func (s *SQLiteStore) GetByID(collectionID, recordID string) (models.Record, error) {
	query := `SELECT id, collection_id, created, updated, data
FROM records
WHERE collection_id = ? AND id = ?`
	var (
		r       models.Record
		dataStr string
	)
	row := s.db.QueryRow(query, collectionID, recordID)
	if err := row.Scan(&r.ID, &r.CollectionID, &r.Created, &r.Updated, &dataStr); err != nil {
		return models.Record{}, err
	}
	data, err := unmarshalRecordData(dataStr)
	if err != nil {
		return models.Record{}, err
	}
	r.Data = data
	return r, nil
}
// Update replaces the JSON data of the record identified by r.CollectionID and
// r.ID and sets its updated column to the current time as formatted by
// strftime. It returns sql.ErrNoRows when no row matched.
func (s *SQLiteStore) Update(r models.Record) error {
	payload, err := marshalRecordData(r.Data)
	if err != nil {
		return err
	}
	res, err := s.db.Exec(
		`UPDATE records SET data = ?, updated = (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')) WHERE collection_id = ? AND id = ?`,
		payload, r.CollectionID, r.ID,
	)
	if err != nil {
		return err
	}
	switch n, affErr := res.RowsAffected(); {
	case affErr != nil:
		return affErr
	case n == 0:
		return sql.ErrNoRows
	default:
		return nil
	}
}
// Delete removes record recordID from collection collectionID. It returns
// sql.ErrNoRows when nothing was deleted.
func (s *SQLiteStore) Delete(collectionID, recordID string) error {
	const deleteSQL = `DELETE FROM records WHERE collection_id = ? AND id = ?`
	res, err := s.db.Exec(deleteSQL, collectionID, recordID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err == nil && n == 0 {
		err = sql.ErrNoRows
	}
	return err
}
// buildOrderClause converts an orderBy spec into an SQL ORDER BY clause with a
// leading space. "field" sorts ascending and "-field" descending, using the
// same JSON extraction/casting as filters (getFieldCaster). Surrounding
// whitespace is ignored. An empty spec, or a bare "-" with no field name,
// falls back to " ORDER BY created DESC" instead of producing an invalid JSON
// path.
func buildOrderClause(orderBy string) string {
	const defaultOrder = " ORDER BY created DESC"
	field := strings.TrimSpace(orderBy)
	direction := "ASC"
	if strings.HasPrefix(field, "-") {
		field = strings.TrimSpace(field[1:])
		direction = "DESC"
	}
	if field == "" {
		return defaultOrder
	}
	return fmt.Sprintf(" ORDER BY %s %s", getFieldCaster(field), direction)
}

// getFieldCaster returns the SQL expression that reads field from the JSON
// "data" column:
//   - "price" and "stock" are CAST to REAL so they compare numerically;
//   - "releaseDate" is wrapped in JULIANDAY so dates compare as day numbers;
//   - every other field is returned as json_extract yields it.
//
// field is embedded inside a single-quoted SQL string literal and originates
// from caller-supplied filter keys and orderBy specs. Any single quote is
// therefore doubled so it cannot terminate the literal and inject SQL. A
// malformed path now surfaces as a SQLite JSON-path error instead.
func getFieldCaster(field string) string {
	path := "'$." + strings.ReplaceAll(field, "'", "''") + "'"
	switch field {
	case "price", "stock":
		return "CAST(json_extract(data, " + path + ") AS REAL)"
	case "releaseDate":
		return "JULIANDAY(json_extract(data, " + path + "))"
	default:
		return "json_extract(data, " + path + ")"
	}
}
// validOperators maps the operator codes accepted inside filter-key brackets
// to the SQL comparison operator each one stands for.
var validOperators = map[string]string{
	"eq":  "=",
	"neq": "!=",
	"gt":  ">",
	"gte": ">=",
	"lt":  "<",
	"lte": "<=",
}

// parseFilterKey splits a filter key of the form "field[op]" into the field
// name and the SQL operator that validOperators maps op to.
//
// The key's first ']' must be its final character. op is the non-empty text
// between that ']' and the last '[' preceding it, and everything before that
// '[' is the field, which must be non-empty. Checks run in the order format →
// operator → field, and the first failure is returned.
func parseFilterKey(key string) (field string, op string, err error) {
	closeIdx := strings.IndexByte(key, ']')
	openIdx := -1
	if closeIdx >= 0 {
		openIdx = strings.LastIndexByte(key[:closeIdx], '[')
	}
	wellFormed := openIdx >= 0 && closeIdx > openIdx+1 && closeIdx == len(key)-1
	if !wellFormed {
		return "", "", fmt.Errorf("invalid filter format: %s. Expected field[op]", key)
	}

	code := key[openIdx+1 : closeIdx]
	sqlOp, known := validOperators[code]
	if !known {
		return "", "", fmt.Errorf("invalid operator: %s", code)
	}

	if openIdx == 0 {
		return "", "", fmt.Errorf("field name cannot be empty")
	}
	return key[:openIdx], sqlOp, nil
}
// marshalRecordData encodes a record's data map as a JSON string for the
// records.data column. A nil map encodes as "null". Encoding failures are
// wrapped with context and come with an empty string.
func marshalRecordData(data map[string]interface{}) (string, error) {
	encoded, err := json.Marshal(data)
	if err == nil {
		return string(encoded), nil
	}
	return "", fmt.Errorf("failed to marshal record data: %w", err)
}
// unmarshalRecordData decodes the JSON stored in records.data back into a map.
// The JSON literal "null" yields a nil map and no error. Decoding failures are
// wrapped with context and come with a nil map.
func unmarshalRecordData(dataStr string) (map[string]interface{}, error) {
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(dataStr), &decoded); err != nil {
		return nil, fmt.Errorf("failed to unmarshal record data: %w", err)
	}
	return decoded, nil
}