diff --git a/.golangci.yml b/.golangci.yml index fd2e9ac..e3cf2f1 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,16 +5,76 @@ run: linters: enable: + # complexity - gocyclo - - misspell + + # correctness - bodyclose + - errorlint + - exhaustive + - nilerr + - noctx + + # security + - gosec + + # performance + - intrange + - perfsprint - prealloc + - unconvert + + # code hygiene + - goconst + - misspell + - protogetter + - revive + - unparam + settings: gocyclo: min-complexity: 15 + goconst: + min-len: 3 + min-occurrences: 3 + exhaustive: + default-signifies-exhaustive: true + gosec: + excludes: + - G104 # unhandled errors — covered by errcheck (default linter) + - G115 # integer overflow conversion — too noisy for blockchain height math + - G304 # file inclusion via variable — expected for config loading + - G306 # file permissions > 0600 — config templates, not secrets + revive: + rules: + - name: blank-imports + - name: context-as-argument + - name: dot-imports + - name: error-return + - name: error-strings + - name: error-naming + - name: exported + disabled: true + - name: increment-decrement + - name: var-naming + - name: range + - name: receiver-naming + - name: indent-error-flow + - name: empty-block + - name: superfluous-else + - name: unreachable-code + - name: redefines-builtin-id + exclusions: + rules: + # Standard Go package names are fine for internal packages. + - linters: + - revive + text: "avoid.*package names" paths: - vendor + - ".*\\.pb\\.go$" + - ".*\\.pb\\.gw\\.go$" formatters: enable: diff --git a/README.md b/README.md index 03d28e4..08ea9e5 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ just check # tidy + lint + test + build ## Documentation - [Running Apex](docs/running.md) -- setup, configuration, Docker, CLI +- [API Compatibility Policy](docs/api-compat.md) -- JSON-RPC compatibility boundary vs gRPC evolution ## License diff --git a/cmd/apex-loadtest/main.go b/cmd/apex-loadtest/main.go new file mode 100644 index 0000000..ad7e6b5 --- /dev/null +++ b/cmd/apex-loadtest/main.go @@ -0,0 +1,397 @@ +// apex-loadtest tests JSON-RPC throughput and WebSocket subscriber limits +// against a running apex indexer. +// +// Usage: +// +// apex-loadtest -addr stg-devnet-collect:8080 -rpc-concurrency 50 -rpc-duration 10s +// apex-loadtest -addr stg-devnet-collect:8080 -subscribers 500 -sub-duration 30s +// apex-loadtest -addr stg-devnet-collect:8080 -all +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net/http" + "os" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" +) + +func main() { + addr := flag.String("addr", "localhost:8080", "apex JSON-RPC address (host:port)") + runAll := flag.Bool("all", false, "run all tests") + + // RPC throughput flags + rpcConcurrency := flag.Int("rpc-concurrency", 50, "number of concurrent RPC workers") + rpcDuration := flag.Duration("rpc-duration", 10*time.Second, "duration of RPC throughput test") + rpcHeight := flag.Uint64("rpc-height", 0, "height to query (0 = fetch from health)") + + // Subscriber flags + subCount := flag.Int("subscribers", 200, "number of WebSocket subscribers to open") + subDuration := flag.Duration("sub-duration", 30*time.Second, "how long to hold subscribers open") + subBatch := flag.Int("sub-batch", 50, "subscribers to add per batch") + + flag.Parse() + + if !*runAll && flag.NArg() == 0 && !isFlagSet("rpc-concurrency") && !isFlagSet("subscribers") { + *runAll = true + } + + height := *rpcHeight + if height == 0 { + h, err := fetchCurrentHeight(*addr) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get current height: %v\n", err) + os.Exit(1) + } + height = h + fmt.Printf("using height %d from health endpoint\n\n", height) + } + + if *runAll || isFlagSet("rpc-concurrency") || isFlagSet("rpc-duration") { + runRPCThroughput(*addr, height, *rpcConcurrency, *rpcDuration) + fmt.Println() + } + + if *runAll || isFlagSet("subscribers") || isFlagSet("sub-duration") { + runSubscriberTest(*addr, *subCount, *subBatch, *subDuration) + } +} + +func isFlagSet(name string) bool { + found := false + flag.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} + +// --- RPC Throughput Test --- + +type rpcRequest struct { + Jsonrpc string `json:"jsonrpc"` + Method string `json:"method"` + Params []any `json:"params"` + ID int `json:"id"` +} + +func runRPCThroughput(addr string, height uint64, concurrency int, duration time.Duration) { + fmt.Printf("=== RPC Throughput Test ===\n") + fmt.Printf("target: %s, concurrency: %d, duration: %s\n", addr, concurrency, duration) + fmt.Printf("methods: header.GetByHeight, header.LocalHead, blob.GetAll\n\n") + + methods := []struct { + name string + params []any + }{ + {"header.GetByHeight", []any{height}}, + {"header.LocalHead", []any{}}, + {"blob.GetAll", []any{height, nil}}, + } + + for _, m := range methods { + stats := runMethodBench(addr, m.name, m.params, concurrency, duration) + printStats(m.name, stats) + } +} + +type benchStats struct { + total int64 + errors int64 + duration time.Duration + latP50 time.Duration + latP95 time.Duration + latP99 time.Duration + latMax time.Duration +} + +type sub struct { + conn *websocket.Conn + id int +} + +func runMethodBench(addr, method string, params []any, concurrency int, duration time.Duration) benchStats { + body, _ := json.Marshal(rpcRequest{ + Jsonrpc: "2.0", + Method: method, + Params: params, + ID: 1, + }) + + url := "http://" + addr + client := &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: concurrency * 2, + MaxIdleConnsPerHost: concurrency * 2, + IdleConnTimeout: 90 * time.Second, + }, + Timeout: 10 * time.Second, + } + + var total, errCount atomic.Int64 + latencies := make([]time.Duration, 0, 100000) + var latMu sync.Mutex + + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + var wg sync.WaitGroup + start := time.Now() + + for range concurrency { + wg.Add(1) + go func() { + defer wg.Done() + for { + if ctx.Err() != nil { + return + } + + t0 := time.Now() + req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if reqErr != nil { + errCount.Add(1) + total.Add(1) + continue + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) //nolint:gosec // URL from CLI flag + lat := time.Since(t0) + + if err != nil { + if ctx.Err() != nil { + return + } + errCount.Add(1) + total.Add(1) + continue + } + if _, copyErr := io.Copy(io.Discard, resp.Body); copyErr != nil { + errCount.Add(1) + } + if closeErr := resp.Body.Close(); closeErr != nil { + errCount.Add(1) + } + + total.Add(1) + if resp.StatusCode != 200 { + errCount.Add(1) + } + + latMu.Lock() + latencies = append(latencies, lat) + latMu.Unlock() + } + }() + } + + wg.Wait() + elapsed := time.Since(start) + + sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] }) + + stats := benchStats{ + total: total.Load(), + errors: errCount.Load(), + duration: elapsed, + } + if n := len(latencies); n > 0 { + stats.latP50 = latencies[n*50/100] + stats.latP95 = latencies[n*95/100] + stats.latP99 = latencies[n*99/100] + stats.latMax = latencies[n-1] + } + return stats +} + +func printStats(method string, s benchStats) { + rps := float64(s.total) / s.duration.Seconds() + fmt.Printf(" %-25s %6d reqs %8.1f req/s err=%d p50=%s p95=%s p99=%s max=%s\n", + method, s.total, rps, s.errors, + s.latP50.Round(100*time.Microsecond), + s.latP95.Round(100*time.Microsecond), + s.latP99.Round(100*time.Microsecond), + s.latMax.Round(100*time.Microsecond), + ) +} + +// --- Subscriber Test --- + +func runSubscriberTest(addr string, target, batch int, duration time.Duration) { + fmt.Printf("=== Subscriber Stress Test ===\n") + fmt.Printf("target: %s, max subscribers: %d, batch: %d, hold: %s\n\n", addr, target, batch, duration) + + wsURL := "ws://" + addr + + var subs []sub + var connected, failed, eventsReceived atomic.Int64 + + // Open subscribers in batches + for i := 0; i < target; i += batch { + end := i + batch + if end > target { + end = target + } + + batchSubs, batchConnected, batchFailed := openSubscriberBatch(wsURL, i, end) + subs = append(subs, batchSubs...) + connected.Add(batchConnected) + failed.Add(batchFailed) + + fmt.Printf(" batch %d-%d: connected=%d failed=%d\n", + i, end-1, connected.Load(), failed.Load()) + } + + totalConnected := connected.Load() + fmt.Printf("\n total connected: %d / %d\n", totalConnected, target) + + if totalConnected == 0 { + fmt.Println(" no subscribers connected, skipping event collection") + return + } + + // Read events from all subscribers for the duration + readCtx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + var wg sync.WaitGroup + for _, s := range subs { + wg.Add(1) + go func(s sub) { + defer wg.Done() + if err := s.conn.SetReadDeadline(time.Now().Add(duration + 5*time.Second)); err != nil { + return + } + for { + if readCtx.Err() != nil { + return + } + _, _, err := s.conn.ReadMessage() + if err != nil { + return + } + eventsReceived.Add(1) + } + }(s) + } + + // Print progress every 5s + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + go func() { + for { + select { + case <-readCtx.Done(): + return + case <-ticker.C: + fmt.Printf(" events received: %d\n", eventsReceived.Load()) + } + } + }() + + wg.Wait() + + fmt.Printf("\n final: %d events from %d subscribers over %s\n", + eventsReceived.Load(), totalConnected, duration) + fmt.Printf(" avg events/subscriber: %.1f\n", + float64(eventsReceived.Load())/float64(totalConnected)) + + // Cleanup + for _, s := range subs { + _ = s.conn.Close() + } +} + +func openSubscriberBatch(wsURL string, start, end int) ([]sub, int64, int64) { + var ( + wg sync.WaitGroup + mu sync.Mutex + subs []sub + connected atomic.Int64 + failed atomic.Int64 + ) + + for j := start; j < end; j++ { + wg.Add(1) + id := j + go func() { + defer wg.Done() + conn, err := dialAndSubscribe(wsURL, id) + if err != nil { + failed.Add(1) + if id < 5 || id%100 == 0 { + fmt.Fprintf(os.Stderr, " sub %d: dial failed: %v\n", id, err) + } + return + } + connected.Add(1) + mu.Lock() + subs = append(subs, sub{conn: conn, id: id}) + mu.Unlock() + }() + } + wg.Wait() + return subs, connected.Load(), failed.Load() +} + +func dialAndSubscribe(wsURL string, id int) (*websocket.Conn, error) { + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + conn, resp, err := dialer.Dial(wsURL, nil) + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + if err != nil { + return nil, err + } + + req := rpcRequest{ + Jsonrpc: "2.0", + Method: "header.Subscribe", + Params: []any{}, + ID: id + 1, + } + if err := conn.WriteJSON(req); err != nil { + closeErr := conn.Close() + return nil, errors.Join(err, closeErr) + } + return conn, nil +} + +// --- Helpers --- + +func fetchCurrentHeight(addr string) (uint64, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://"+addr+"/health", nil) + if err != nil { + return 0, err + } + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) //nolint:gosec // URL from CLI flag + if err != nil { + return 0, err + } + defer func() { _ = resp.Body.Close() }() + + var health struct { + LatestHeight uint64 `json:"latest_height"` + } + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + return 0, err + } + return health.LatestHeight, nil +} diff --git a/cmd/apex/client.go b/cmd/apex/client.go index e35b7b0..9a49330 100644 --- a/cmd/apex/client.go +++ b/cmd/apex/client.go @@ -63,7 +63,7 @@ func (c *rpcClient) call(ctx context.Context, method string, params ...any) (jso } req.Header.Set("Content-Type", "application/json") - resp, err := c.client.Do(req) + resp, err := c.client.Do(req) //nolint:gosec // URL comes from user-configured --url flag, not untrusted input if err != nil { return nil, fmt.Errorf("send request: %w", err) } @@ -93,7 +93,7 @@ func (c *rpcClient) fetchHealth(ctx context.Context) (json.RawMessage, error) { return nil, fmt.Errorf("create request: %w", err) } - resp, err := c.client.Do(req) + resp, err := c.client.Do(req) //nolint:gosec // URL comes from user-configured --url flag, not untrusted input if err != nil { return nil, fmt.Errorf("send request: %w", err) } diff --git a/cmd/apex/main.go b/cmd/apex/main.go index 33d5d79..418f345 100644 --- a/cmd/apex/main.go +++ b/cmd/apex/main.go @@ -32,6 +32,8 @@ import ( // Set via ldflags at build time. var version = "dev" +const dataSourceTypeApp = "app" + func main() { if err := rootCmd().Execute(); err != nil { os.Exit(1) @@ -112,7 +114,7 @@ func startCmd() *cobra.Command { Str("version", version). Str("datasource_type", cfg.DataSource.Type). Int("namespaces", len(cfg.DataSource.Namespaces)) - if cfg.DataSource.Type == "app" { + if cfg.DataSource.Type == dataSourceTypeApp { startLog = startLog.Str("app_grpc_addr", cfg.DataSource.CelestiaAppGRPCAddr) } else { startLog = startLog.Str("node_url", cfg.DataSource.CelestiaNodeURL) @@ -168,7 +170,7 @@ func setupProfiling(cfg *config.Config) *profile.Server { func openDataSource(ctx context.Context, cfg *config.Config) (fetch.DataFetcher, fetch.ProofForwarder, error) { switch cfg.DataSource.Type { - case "app": + case dataSourceTypeApp: appFetcher, err := fetch.NewCelestiaAppFetcher(cfg.DataSource.CelestiaAppGRPCAddr, cfg.DataSource.AuthToken, log.Logger) if err != nil { return nil, nil, fmt.Errorf("create celestia-app fetcher: %w", err) @@ -215,7 +217,7 @@ func persistNamespaces(ctx context.Context, db store.Store, namespaces []types.N } func maybeBackfillSourceOption(cfg *config.Config, logger zerolog.Logger) (syncer.Option, func(), error) { - if cfg.DataSource.Type != "app" || cfg.DataSource.BackfillSource != "db" { + if cfg.DataSource.Type != dataSourceTypeApp || cfg.DataSource.BackfillSource != "db" { return nil, nil, nil } @@ -328,7 +330,7 @@ func runIndexer(ctx context.Context, cfg *config.Config) error { // Start gRPC server. grpcSrv := grpcapi.NewServer(svc, log.Logger) - lis, err := net.Listen("tcp", cfg.RPC.GRPCListenAddr) + lis, err := (&net.ListenConfig{}).Listen(ctx, "tcp", cfg.RPC.GRPCListenAddr) if err != nil { _ = httpSrv.Close() return fmt.Errorf("listen gRPC: %w", err) diff --git a/config/config.go b/config/config.go index 1616945..3bf23bb 100644 --- a/config/config.go +++ b/config/config.go @@ -29,7 +29,7 @@ type DataSourceConfig struct { CelestiaAppDBPath string `yaml:"celestia_app_db_path"` // required when backfill_source=db CelestiaAppDBBackend string `yaml:"celestia_app_db_backend"` // auto|pebble|leveldb CelestiaAppDBLayout string `yaml:"celestia_app_db_layout"` // auto|v1|v2 - AuthToken string `yaml:"-"` // populated only via APEX_AUTH_TOKEN env var + AuthToken string `yaml:"-"` //nolint:gosec // populated only via APEX_AUTH_TOKEN env var; not a hardcoded credential Namespaces []string `yaml:"namespaces"` } diff --git a/config/load.go b/config/load.go index 2bbbdae..b449975 100644 --- a/config/load.go +++ b/config/load.go @@ -10,6 +10,8 @@ import ( "gopkg.in/yaml.v3" ) +const configValueAuto = "auto" + // Generate writes a default config file with comments to the given path. // Returns an error if the file already exists. func Generate(path string) error { @@ -149,34 +151,34 @@ func validateDataSource(ds *DataSourceConfig) error { switch ds.Type { case "node", "": if ds.CelestiaNodeURL == "" { - return fmt.Errorf("data_source.celestia_node_url is required for type \"node\"") + return errors.New("data_source.celestia_node_url is required for type \"node\"") } case "app": if ds.CelestiaAppGRPCAddr == "" { - return fmt.Errorf("data_source.celestia_app_grpc_addr is required for type \"app\"") + return errors.New("data_source.celestia_app_grpc_addr is required for type \"app\"") } if ds.BackfillSource == "" { ds.BackfillSource = "rpc" } if ds.CelestiaAppDBBackend == "" { - ds.CelestiaAppDBBackend = "auto" + ds.CelestiaAppDBBackend = configValueAuto } if ds.CelestiaAppDBLayout == "" { - ds.CelestiaAppDBLayout = "auto" + ds.CelestiaAppDBLayout = configValueAuto } switch ds.BackfillSource { case "rpc": case "db": if ds.CelestiaAppDBPath == "" { - return fmt.Errorf("data_source.celestia_app_db_path is required when data_source.backfill_source is \"db\"") + return errors.New("data_source.celestia_app_db_path is required when data_source.backfill_source is \"db\"") } switch ds.CelestiaAppDBBackend { - case "auto", "pebble", "leveldb": + case configValueAuto, "pebble", "leveldb": default: return fmt.Errorf("data_source.celestia_app_db_backend %q is invalid; must be auto|pebble|leveldb", ds.CelestiaAppDBBackend) } switch ds.CelestiaAppDBLayout { - case "auto", "v1", "v2": + case configValueAuto, "v1", "v2": default: return fmt.Errorf("data_source.celestia_app_db_layout %q is invalid; must be auto|v1|v2", ds.CelestiaAppDBLayout) } @@ -198,23 +200,23 @@ func validateStorage(s *StorageConfig) error { switch s.Type { case "s3": if s.S3 == nil { - return fmt.Errorf("storage.s3 is required when storage.type is \"s3\"") + return errors.New("storage.s3 is required when storage.type is \"s3\"") } if s.S3.Bucket == "" { - return fmt.Errorf("storage.s3.bucket is required") + return errors.New("storage.s3.bucket is required") } if s.S3.Region == "" && s.S3.Endpoint == "" { - return fmt.Errorf("storage.s3.region is required (unless endpoint is set)") + return errors.New("storage.s3.region is required (unless endpoint is set)") } if s.S3.ChunkSize == 0 { s.S3.ChunkSize = 64 } if s.S3.ChunkSize < 0 { - return fmt.Errorf("storage.s3.chunk_size must be positive") + return errors.New("storage.s3.chunk_size must be positive") } case "sqlite", "": if s.DBPath == "" { - return fmt.Errorf("storage.db_path is required") + return errors.New("storage.db_path is required") } default: return fmt.Errorf("storage.type %q is invalid; must be \"sqlite\" or \"s3\"", s.Type) @@ -256,50 +258,50 @@ func validate(cfg *Config) error { func validateRPC(rpc *RPCConfig) error { if rpc.ListenAddr == "" { - return fmt.Errorf("rpc.listen_addr is required") + return errors.New("rpc.listen_addr is required") } if rpc.GRPCListenAddr == "" { - return fmt.Errorf("rpc.grpc_listen_addr is required") + return errors.New("rpc.grpc_listen_addr is required") } if rpc.ReadTimeout < 0 { - return fmt.Errorf("rpc.read_timeout must be non-negative") + return errors.New("rpc.read_timeout must be non-negative") } if rpc.WriteTimeout < 0 { - return fmt.Errorf("rpc.write_timeout must be non-negative") + return errors.New("rpc.write_timeout must be non-negative") } return nil } func validateSync(sync *SyncConfig) error { if sync.BatchSize <= 0 { - return fmt.Errorf("sync.batch_size must be positive") + return errors.New("sync.batch_size must be positive") } if sync.Concurrency <= 0 { - return fmt.Errorf("sync.concurrency must be positive") + return errors.New("sync.concurrency must be positive") } return nil } func validateSubscription(sub *SubscriptionConfig) error { if sub.BufferSize <= 0 { - return fmt.Errorf("subscription.buffer_size must be positive") + return errors.New("subscription.buffer_size must be positive") } if sub.MaxSubscribers <= 0 { - return fmt.Errorf("subscription.max_subscribers must be positive") + return errors.New("subscription.max_subscribers must be positive") } return nil } func validateMetrics(m *MetricsConfig) error { if m.Enabled && m.ListenAddr == "" { - return fmt.Errorf("metrics.listen_addr is required when metrics are enabled") + return errors.New("metrics.listen_addr is required when metrics are enabled") } return nil } func validateProfiling(p *ProfilingConfig) error { if p.Enabled && p.ListenAddr == "" { - return fmt.Errorf("profiling.listen_addr is required when profiling is enabled") + return errors.New("profiling.listen_addr is required when profiling is enabled") } return nil } diff --git a/docs/api-compat.md b/docs/api-compat.md new file mode 100644 index 0000000..9aad45d --- /dev/null +++ b/docs/api-compat.md @@ -0,0 +1,40 @@ +# API Compatibility Policy + +Apex exposes two external API surfaces with different compatibility rules. + +## JSON-RPC + +JSON-RPC is the compatibility surface. + +- Method names and core response shapes are intended to remain compatible with `celestia-node`. +- Compatibility takes priority over API cleanup when the two are in tension. +- Behavioral quirks inherited for compatibility should be preserved unless upstream compatibility requirements change. +- Operational protections for JSON-RPC should be added outside the method contract when possible: + - auth + - reverse-proxy rate limits + - request size and timeout limits + +## gRPC + +gRPC is Apex-owned and may evolve independently. + +- gRPC may add stricter validation and explicit limits. +- gRPC should prefer stable, well-specified semantics over mirroring JSON-RPC quirks. +- Transport-level improvements should land in gRPC first when they do not belong to the shared domain model. + +Current intentional differences: + +- `BlobService.GetAll` enforces a namespace cap; JSON-RPC `blob.GetAll` does not. +- gRPC blob subscriptions emit only matching blob events; JSON-RPC remains compatibility-oriented. + +## Internal Service Boundary + +Shared correctness belongs in the internal service and storage layers. + +- Store invariants, sync correctness, and read semantics should not diverge by transport. +- Compatibility shims belong at the JSON-RPC adapter edge. +- gRPC handlers should call service methods instead of reaching directly into store or fetcher internals. + +## Testing Rule + +When a transport difference is intentional, it should be documented here and pinned by tests. diff --git a/docs/running.md b/docs/running.md index e267577..153de64 100644 --- a/docs/running.md +++ b/docs/running.md @@ -164,6 +164,9 @@ Celestia-node compatible methods: ### gRPC (port 9090) +gRPC is Apex-owned and may evolve independently from the JSON-RPC compatibility layer. +See [API Compatibility Policy](api-compat.md) for the boundary. + - `apex.v1.BlobService`: Get, GetAll, GetByCommitment, Subscribe (server-streaming) - `apex.v1.HeaderService`: GetByHeight, LocalHead, NetworkHead, Subscribe (server-streaming) diff --git a/pkg/api/grpc/blob_service.go b/pkg/api/grpc/blob_service.go index 0783b3d..b7f9bac 100644 --- a/pkg/api/grpc/blob_service.go +++ b/pkg/api/grpc/blob_service.go @@ -1,7 +1,6 @@ package grpcapi import ( - "bytes" "context" "errors" "fmt" @@ -23,31 +22,27 @@ type BlobServiceServer struct { } func (s *BlobServiceServer) Get(ctx context.Context, req *pb.GetRequest) (*pb.GetResponse, error) { - ns, err := bytesToNamespace(req.Namespace) + ns, err := bytesToNamespace(req.GetNamespace()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid namespace: %v", err) } - blobs, err := s.svc.Store().GetBlobs(ctx, ns, req.Height, req.Height, 0, 0) + b, err := s.svc.GetBlob(ctx, req.GetHeight(), ns, req.GetCommitment()) if err != nil { - return nil, status.Errorf(codes.Internal, "get blobs: %v", err) - } - - for i := range blobs { - if bytes.Equal(blobs[i].Commitment, req.Commitment) { - return &pb.GetResponse{Blob: blobToProto(&blobs[i])}, nil + if errors.Is(err, store.ErrNotFound) { + return nil, status.Error(codes.NotFound, store.ErrNotFound.Error()) } + return nil, status.Errorf(codes.Internal, "get blob: %v", err) } - - return nil, status.Error(codes.NotFound, store.ErrNotFound.Error()) + return &pb.GetResponse{Blob: blobToProto(b)}, nil } func (s *BlobServiceServer) GetByCommitment(ctx context.Context, req *pb.GetByCommitmentRequest) (*pb.GetByCommitmentResponse, error) { - if len(req.Commitment) == 0 { + if len(req.GetCommitment()) == 0 { return nil, status.Error(codes.InvalidArgument, "commitment is required") } - b, err := s.svc.Store().GetBlobByCommitment(ctx, req.Commitment) + b, err := s.svc.GetBlobByCommitment(ctx, req.GetCommitment()) if err != nil { if errors.Is(err, store.ErrNotFound) { return nil, status.Error(codes.NotFound, store.ErrNotFound.Error()) @@ -60,12 +55,12 @@ func (s *BlobServiceServer) GetByCommitment(ctx context.Context, req *pb.GetByCo func (s *BlobServiceServer) GetAll(ctx context.Context, req *pb.GetAllRequest) (*pb.GetAllResponse, error) { const maxNamespaces = 16 - if len(req.Namespaces) > maxNamespaces { - return nil, status.Errorf(codes.InvalidArgument, "too many namespaces: %d (max %d)", len(req.Namespaces), maxNamespaces) + if len(req.GetNamespaces()) > maxNamespaces { + return nil, status.Errorf(codes.InvalidArgument, "too many namespaces: %d (max %d)", len(req.GetNamespaces()), maxNamespaces) } - nsList := make([]types.Namespace, len(req.Namespaces)) - for i, nsBytes := range req.Namespaces { + nsList := make([]types.Namespace, len(req.GetNamespaces())) + for i, nsBytes := range req.GetNamespaces() { ns, err := bytesToNamespace(nsBytes) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid namespace %d: %v", i, err) @@ -73,25 +68,9 @@ func (s *BlobServiceServer) GetAll(ctx context.Context, req *pb.GetAllRequest) ( nsList[i] = ns } - allBlobs := make([]types.Blob, 0, len(nsList)*8) - for _, ns := range nsList { - blobs, err := s.svc.Store().GetBlobs(ctx, ns, req.Height, req.Height, 0, 0) - if err != nil { - return nil, status.Errorf(codes.Internal, "get blobs: %v", err) - } - allBlobs = append(allBlobs, blobs...) - } - - // Apply pagination to the aggregate result. - if req.Offset > 0 { - if int(req.Offset) >= len(allBlobs) { - allBlobs = nil - } else { - allBlobs = allBlobs[req.Offset:] - } - } - if req.Limit > 0 && int(req.Limit) < len(allBlobs) { - allBlobs = allBlobs[:req.Limit] + allBlobs, err := s.svc.GetAllBlobs(ctx, req.GetHeight(), nsList, int(req.GetLimit()), int(req.GetOffset())) + if err != nil { + return nil, status.Errorf(codes.Internal, "get blobs: %v", err) } pbBlobs := make([]*pb.Blob, len(allBlobs)) @@ -103,7 +82,7 @@ func (s *BlobServiceServer) GetAll(ctx context.Context, req *pb.GetAllRequest) ( } func (s *BlobServiceServer) Subscribe(req *pb.BlobServiceSubscribeRequest, stream grpc.ServerStreamingServer[pb.BlobServiceSubscribeResponse]) error { - ns, err := bytesToNamespace(req.Namespace) + ns, err := bytesToNamespace(req.GetNamespace()) if err != nil { return status.Errorf(codes.InvalidArgument, "invalid namespace: %v", err) } @@ -123,6 +102,9 @@ func (s *BlobServiceServer) Subscribe(req *pb.BlobServiceSubscribeRequest, strea if !ok { return nil } + if len(ev.Blobs) == 0 { + continue + } pbBlobs := make([]*pb.Blob, len(ev.Blobs)) for i := range ev.Blobs { pbBlobs[i] = blobToProto(&ev.Blobs[i]) diff --git a/pkg/api/grpc/header_service.go b/pkg/api/grpc/header_service.go index e1ef40d..994f1e6 100644 --- a/pkg/api/grpc/header_service.go +++ b/pkg/api/grpc/header_service.go @@ -22,10 +22,10 @@ type HeaderServiceServer struct { } func (s *HeaderServiceServer) GetByHeight(ctx context.Context, req *pb.GetByHeightRequest) (*pb.GetByHeightResponse, error) { - hdr, err := s.svc.Store().GetHeader(ctx, req.Height) + hdr, err := s.svc.GetHeaderByHeight(ctx, req.GetHeight()) if err != nil { if errors.Is(err, store.ErrNotFound) { - return nil, status.Errorf(codes.NotFound, "header at height %d not found", req.Height) + return nil, status.Errorf(codes.NotFound, "header at height %d not found", req.GetHeight()) } return nil, status.Errorf(codes.Internal, "get header: %v", err) } @@ -33,25 +33,18 @@ func (s *HeaderServiceServer) GetByHeight(ctx context.Context, req *pb.GetByHeig } func (s *HeaderServiceServer) LocalHead(ctx context.Context, _ *pb.LocalHeadRequest) (*pb.LocalHeadResponse, error) { - ss, err := s.svc.Store().GetSyncState(ctx) + hdr, err := s.svc.GetLocalHead(ctx) if err != nil { if errors.Is(err, store.ErrNotFound) { - return nil, status.Errorf(codes.NotFound, "no sync state available") + return nil, status.Errorf(codes.NotFound, "no local head available") } - return nil, status.Errorf(codes.Internal, "get sync state: %v", err) - } - hdr, err := s.svc.Store().GetHeader(ctx, ss.LatestHeight) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return nil, status.Errorf(codes.NotFound, "header at height %d not found", ss.LatestHeight) - } - return nil, status.Errorf(codes.Internal, "get header: %v", err) + return nil, status.Errorf(codes.Internal, "get local head: %v", err) } return &pb.LocalHeadResponse{Header: headerToProto(hdr)}, nil } func (s *HeaderServiceServer) NetworkHead(ctx context.Context, _ *pb.NetworkHeadRequest) (*pb.NetworkHeadResponse, error) { - hdr, err := s.svc.Fetcher().GetNetworkHead(ctx) + hdr, err := s.svc.GetNetworkHead(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "get network head: %v", err) } diff --git a/pkg/api/grpc/server_test.go b/pkg/api/grpc/server_test.go index 52c09e6..f0fb109 100644 --- a/pkg/api/grpc/server_test.go +++ b/pkg/api/grpc/server_test.go @@ -10,7 +10,9 @@ import ( "github.com/rs/zerolog" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" "github.com/evstack/apex/pkg/api" pb "github.com/evstack/apex/pkg/api/grpc/gen/apex/v1" @@ -136,7 +138,7 @@ func startTestServer(t *testing.T, svc *api.Service) pb.BlobServiceClient { t.Helper() srv := NewServer(svc, zerolog.Nop()) - lis, err := net.Listen("tcp", "127.0.0.1:0") + lis, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } @@ -157,7 +159,7 @@ func startTestHeaderServer(t *testing.T, svc *api.Service) pb.HeaderServiceClien t.Helper() srv := NewServer(svc, zerolog.Nop()) - lis, err := net.Listen("tcp", "127.0.0.1:0") + lis, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } @@ -195,11 +197,11 @@ func TestGRPCBlobGet(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } - if resp.Blob.Height != 10 { - t.Errorf("Height = %d, want 10", resp.Blob.Height) + if resp.GetBlob().GetHeight() != 10 { + t.Errorf("Height = %d, want 10", resp.GetBlob().GetHeight()) } - if string(resp.Blob.Data) != "d1" { - t.Errorf("Data = %q, want %q", resp.Blob.Data, "d1") + if string(resp.GetBlob().GetData()) != "d1" { + t.Errorf("Data = %q, want %q", resp.GetBlob().GetData(), "d1") } } @@ -223,8 +225,32 @@ func TestGRPCBlobGetAll(t *testing.T) { if err != nil { t.Fatalf("GetAll: %v", err) } - if len(resp.Blobs) != 2 { - t.Errorf("got %d blobs, want 2", len(resp.Blobs)) + if len(resp.GetBlobs()) != 2 { + t.Errorf("got %d blobs, want 2", len(resp.GetBlobs())) + } +} + +func TestGRPCBlobGetAllRejectsTooManyNamespaces(t *testing.T) { + st := newMockStore() + notifier := api.NewNotifier(16, 1024, zerolog.Nop()) + svc := api.NewService(st, &mockFetcher{}, nil, notifier, zerolog.Nop()) + client := startTestServer(t, svc) + + namespaces := make([][]byte, 0, 17) + for i := range 17 { + ns := testNamespace(byte(i + 1)) + namespaces = append(namespaces, ns[:]) + } + + _, err := client.GetAll(context.Background(), &pb.GetAllRequest{ + Height: 10, + Namespaces: namespaces, + }) + if err == nil { + t.Fatal("expected GetAll to reject too many namespaces") + } + if code := status.Code(err); code != codes.InvalidArgument { + t.Fatalf("GetAll error code = %v, want InvalidArgument", code) } } @@ -247,11 +273,11 @@ func TestGRPCBlobGetByCommitment(t *testing.T) { if err != nil { t.Fatalf("GetByCommitment: %v", err) } - if resp.Blob.Height != 10 { - t.Errorf("Height = %d, want 10", resp.Blob.Height) + if resp.GetBlob().GetHeight() != 10 { + t.Errorf("Height = %d, want 10", resp.GetBlob().GetHeight()) } - if string(resp.Blob.Data) != "d1" { - t.Errorf("Data = %q, want %q", resp.Blob.Data, "d1") + if string(resp.GetBlob().GetData()) != "d1" { + t.Errorf("Data = %q, want %q", resp.GetBlob().GetData(), "d1") } } @@ -274,11 +300,11 @@ func TestGRPCHeaderGetByHeight(t *testing.T) { if err != nil { t.Fatalf("GetByHeight: %v", err) } - if resp.Header.Height != 42 { - t.Errorf("Height = %d, want 42", resp.Header.Height) + if resp.GetHeader().GetHeight() != 42 { + t.Errorf("Height = %d, want 42", resp.GetHeader().GetHeight()) } - if string(resp.Header.Hash) != "hash" { - t.Errorf("Hash = %q, want %q", resp.Header.Hash, "hash") + if string(resp.GetHeader().GetHash()) != "hash" { + t.Errorf("Hash = %q, want %q", resp.GetHeader().GetHash(), "hash") } } @@ -299,8 +325,8 @@ func TestGRPCHeaderLocalHead(t *testing.T) { if err != nil { t.Fatalf("LocalHead: %v", err) } - if resp.Header.Height != 100 { - t.Errorf("Height = %d, want 100", resp.Header.Height) + if resp.GetHeader().GetHeight() != 100 { + t.Errorf("Height = %d, want 100", resp.GetHeader().GetHeight()) } } @@ -321,8 +347,8 @@ func TestGRPCHeaderNetworkHead(t *testing.T) { if err != nil { t.Fatalf("NetworkHead: %v", err) } - if resp.Header.Height != 200 { - t.Errorf("Height = %d, want 200", resp.Header.Height) + if resp.GetHeader().GetHeight() != 200 { + t.Errorf("Height = %d, want 200", resp.GetHeader().GetHeight()) } } @@ -368,10 +394,69 @@ func TestGRPCBlobSubscribe(t *testing.T) { if err != nil { t.Fatalf("Recv: %v", err) } - if ev.Height != 1 { - t.Errorf("Height = %d, want 1", ev.Height) + if ev.GetHeight() != 1 { + t.Errorf("Height = %d, want 1", ev.GetHeight()) + } + if len(ev.GetBlobs()) != 1 { + t.Errorf("Blobs = %d, want 1", len(ev.GetBlobs())) + } +} + +func TestGRPCBlobSubscribeSkipsEmptyFilteredEvents(t *testing.T) { + st := newMockStore() + ns := testNamespace(1) + other := testNamespace(2) + + notifier := api.NewNotifier(16, 1024, zerolog.Nop()) + svc := api.NewService(st, &mockFetcher{}, nil, notifier, zerolog.Nop()) + client := startTestServer(t, svc) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := client.Subscribe(ctx, &pb.BlobServiceSubscribeRequest{ + Namespace: ns[:], + }) + if err != nil { + t.Fatalf("Subscribe: %v", err) + } + + deadline := time.After(5 * time.Second) + for notifier.SubscriberCount() == 0 { + select { + case <-deadline: + t.Fatal("timed out waiting for subscriber registration") + default: + time.Sleep(10 * time.Millisecond) + } + } + + notifier.Publish(api.HeightEvent{ + Height: 1, + Header: &types.Header{Height: 1}, + Blobs: []types.Blob{ + {Height: 1, Namespace: other, Data: []byte("ignore"), Index: 0}, + }, + }) + notifier.Publish(api.HeightEvent{ + Height: 2, + Header: &types.Header{Height: 2}, + Blobs: []types.Blob{ + {Height: 2, Namespace: ns, Data: []byte("deliver"), Index: 0}, + }, + }) + + ev, err := stream.Recv() + if err != nil { + t.Fatalf("Recv: %v", err) + } + if ev.GetHeight() != 2 { + t.Fatalf("Height = %d, want 2", ev.GetHeight()) + } + if len(ev.GetBlobs()) != 1 { + t.Fatalf("Blobs = %d, want 1", len(ev.GetBlobs())) } - if len(ev.Blobs) != 1 { - t.Errorf("Blobs = %d, want 1", len(ev.Blobs)) + if string(ev.GetBlobs()[0].GetData()) != "deliver" { + t.Fatalf("Data = %q, want %q", ev.GetBlobs()[0].GetData(), "deliver") } } diff --git a/pkg/api/jsonrpc/server_test.go b/pkg/api/jsonrpc/server_test.go index 388b02e..a4006cf 100644 --- a/pkg/api/jsonrpc/server_test.go +++ b/pkg/api/jsonrpc/server_test.go @@ -276,6 +276,39 @@ func TestJSONRPCBlobGetAll(t *testing.T) { } } +func TestJSONRPCBlobGetAllAllowsCompatibilitySizedNamespaceList(t *testing.T) { + st := newMockStore() + namespaces := make([][]byte, 0, 17) + for i := range 17 { + ns := testNamespace(byte(i + 1)) + namespaces = append(namespaces, ns[:]) + st.blobs[10] = append(st.blobs[10], types.Blob{ + Height: 10, + Namespace: ns, + Data: []byte{byte(i)}, + Commitment: []byte{byte(i)}, + Index: 0, + }) + } + + notifier := api.NewNotifier(16, 1024, zerolog.Nop()) + svc := api.NewService(st, &mockFetcher{}, nil, notifier, zerolog.Nop()) + srv := NewServer(svc, zerolog.Nop()) + + resp := doRPC(t, srv, "blob.GetAll", uint64(10), namespaces) + if resp.Error != nil { + t.Fatalf("RPC error: %s", resp.Error.Message) + } + + var blobs []json.RawMessage + if err := json.Unmarshal(resp.Result, &blobs); err != nil { + t.Fatalf("unmarshal blobs: %v", err) + } + if len(blobs) != 17 { + t.Fatalf("got %d blobs, want 17", len(blobs)) + } +} + func TestJSONRPCBlobGetByCommitment(t *testing.T) { st := newMockStore() ns := testNamespace(1) diff --git a/pkg/api/jsonrpc/stubs.go b/pkg/api/jsonrpc/stubs.go index 1c2ff98..91ade3c 100644 --- a/pkg/api/jsonrpc/stubs.go +++ b/pkg/api/jsonrpc/stubs.go @@ -3,12 +3,12 @@ package jsonrpc import ( "context" "encoding/json" - "fmt" + "errors" ) var ( - errNotSupported = fmt.Errorf("method not supported by apex indexer") - errReadOnly = fmt.Errorf("apex is a read-only indexer, blob submission not supported") + errNotSupported = errors.New("method not supported by apex indexer") + errReadOnly = errors.New("apex is a read-only indexer, blob submission not supported") ) // ShareStub holds stub methods for the share namespace. diff --git a/pkg/api/service.go b/pkg/api/service.go index 758b568..d39c456 100644 --- a/pkg/api/service.go +++ b/pkg/api/service.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "github.com/rs/zerolog" @@ -39,6 +40,15 @@ func NewService(s store.Store, f fetch.DataFetcher, proof fetch.ProofForwarder, // BlobGet returns a single blob matching the namespace and commitment at the // given height. Returns the blob as celestia-node compatible JSON. func (s *Service) BlobGet(ctx context.Context, height uint64, namespace types.Namespace, commitment []byte) (json.RawMessage, error) { + b, err := s.GetBlob(ctx, height, namespace, commitment) + if err != nil { + return nil, err + } + return MarshalBlob(b), nil +} + +// GetBlob returns a single blob matching the namespace and commitment. +func (s *Service) GetBlob(ctx context.Context, height uint64, namespace types.Namespace, commitment []byte) (*types.Blob, error) { blobs, err := s.store.GetBlobs(ctx, namespace, height, height, 0, 0) if err != nil { return nil, fmt.Errorf("get blobs: %w", err) @@ -46,7 +56,7 @@ func (s *Service) BlobGet(ctx context.Context, height uint64, namespace types.Na for i := range blobs { if bytes.Equal(blobs[i].Commitment, commitment) { - return MarshalBlob(&blobs[i]), nil + return &blobs[i], nil } } @@ -56,20 +66,52 @@ func (s *Service) BlobGet(ctx context.Context, height uint64, namespace types.Na // BlobGetByCommitment returns a blob matching the given commitment as JSON. // No height or namespace required — commitment is cryptographically unique. func (s *Service) BlobGetByCommitment(ctx context.Context, commitment []byte) (json.RawMessage, error) { + b, err := s.GetBlobByCommitment(ctx, commitment) + if err != nil { + return nil, err + } + return MarshalBlob(b), nil +} + +// GetBlobByCommitment returns a blob matching the given commitment. +func (s *Service) GetBlobByCommitment(ctx context.Context, commitment []byte) (*types.Blob, error) { if len(commitment) == 0 { - return nil, fmt.Errorf("commitment is required") + return nil, errors.New("commitment is required") } b, err := s.store.GetBlobByCommitment(ctx, commitment) if err != nil { return nil, fmt.Errorf("get blob by commitment: %w", err) } - return MarshalBlob(b), nil + return b, nil } // BlobGetAll returns all blobs for the given namespaces at the given height. // limit=0 means no limit; offset=0 means no offset. // Pagination is applied to the aggregate result across all namespaces. func (s *Service) BlobGetAll(ctx context.Context, height uint64, namespaces []types.Namespace, limit, offset int) (json.RawMessage, error) { + allBlobs, err := s.GetAllBlobs(ctx, height, namespaces, limit, offset) + if err != nil { + return nil, err + } + if len(allBlobs) == 0 { + return json.RawMessage("null"), nil + } + + result := make([]json.RawMessage, len(allBlobs)) + for i := range allBlobs { + result[i] = MarshalBlob(&allBlobs[i]) + } + + out, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("marshal blobs: %w", err) + } + return out, nil +} + +// GetAllBlobs returns all blobs for the given namespaces at the given height. +// Pagination is applied to the aggregate result across all namespaces. +func (s *Service) GetAllBlobs(ctx context.Context, height uint64, namespaces []types.Namespace, limit, offset int) ([]types.Blob, error) { allBlobs := make([]types.Blob, 0, len(namespaces)*8) // preallocate for typical workload for _, ns := range namespaces { blobs, err := s.store.GetBlobs(ctx, ns, height, height, 0, 0) @@ -90,27 +132,13 @@ func (s *Service) BlobGetAll(ctx context.Context, height uint64, namespaces []ty if limit > 0 && limit < len(allBlobs) { allBlobs = allBlobs[:limit] } - - if len(allBlobs) == 0 { - return json.RawMessage("null"), nil - } - - result := make([]json.RawMessage, len(allBlobs)) - for i := range allBlobs { - result[i] = MarshalBlob(&allBlobs[i]) - } - - out, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("marshal blobs: %w", err) - } - return out, nil + return allBlobs, nil } // BlobGetProof forwards a proof request to the upstream Celestia node. func (s *Service) BlobGetProof(ctx context.Context, height uint64, namespace, commitment []byte) (json.RawMessage, error) { if s.proof == nil { - return nil, fmt.Errorf("proof forwarding not available") + return nil, errors.New("proof forwarding not available") } return s.proof.GetProof(ctx, height, namespace, commitment) } @@ -118,7 +146,7 @@ func (s *Service) BlobGetProof(ctx context.Context, height uint64, namespace, co // BlobIncluded forwards an inclusion check to the upstream Celestia node. func (s *Service) BlobIncluded(ctx context.Context, height uint64, namespace []byte, proof json.RawMessage, commitment []byte) (bool, error) { if s.proof == nil { - return false, fmt.Errorf("proof forwarding not available") + return false, errors.New("proof forwarding not available") } return s.proof.Included(ctx, height, namespace, proof, commitment) } @@ -130,15 +158,33 @@ func (s *Service) BlobSubscribe(namespace types.Namespace) (*Subscription, error // HeaderGetByHeight returns the raw header JSON at the given height. func (s *Service) HeaderGetByHeight(ctx context.Context, height uint64) (json.RawMessage, error) { + hdr, err := s.GetHeaderByHeight(ctx, height) + if err != nil { + return nil, err + } + return hdr.RawHeader, nil +} + +// GetHeaderByHeight returns the stored header at the given height. +func (s *Service) GetHeaderByHeight(ctx context.Context, height uint64) (*types.Header, error) { hdr, err := s.store.GetHeader(ctx, height) if err != nil { return nil, fmt.Errorf("get header: %w", err) } - return hdr.RawHeader, nil + return hdr, nil } // HeaderLocalHead returns the header at the latest synced height. func (s *Service) HeaderLocalHead(ctx context.Context) (json.RawMessage, error) { + hdr, err := s.GetLocalHead(ctx) + if err != nil { + return nil, err + } + return hdr.RawHeader, nil +} + +// GetLocalHead returns the locally indexed head header. +func (s *Service) GetLocalHead(ctx context.Context) (*types.Header, error) { ss, err := s.store.GetSyncState(ctx) if err != nil { return nil, fmt.Errorf("get sync state: %w", err) @@ -147,16 +193,25 @@ func (s *Service) HeaderLocalHead(ctx context.Context) (json.RawMessage, error) if err != nil { return nil, fmt.Errorf("get header at height %d: %w", ss.LatestHeight, err) } - return hdr.RawHeader, nil + return hdr, nil } // HeaderNetworkHead returns the current network head from the upstream node. func (s *Service) HeaderNetworkHead(ctx context.Context) (json.RawMessage, error) { + hdr, err := s.GetNetworkHead(ctx) + if err != nil { + return nil, err + } + return hdr.RawHeader, nil +} + +// GetNetworkHead returns the current network head from the upstream node. +func (s *Service) GetNetworkHead(ctx context.Context) (*types.Header, error) { hdr, err := s.fetcher.GetNetworkHead(ctx) if err != nil { return nil, fmt.Errorf("get network head: %w", err) } - return hdr.RawHeader, nil + return hdr, nil } // HeaderSubscribe creates a subscription for all new headers. @@ -169,16 +224,6 @@ func (s *Service) Notifier() *Notifier { return s.notifier } -// Store returns the underlying store for direct access. -func (s *Service) Store() store.Store { - return s.store -} - -// Fetcher returns the underlying fetcher for direct access. -func (s *Service) Fetcher() fetch.DataFetcher { - return s.fetcher -} - // blobJSON is a struct-based representation for celestia-node compatible JSON. // Using a struct avoids the per-call map[string]any allocation that json.Marshal // requires for maps. diff --git a/pkg/backfill/db/source.go b/pkg/backfill/db/source.go index c448e22..2cfbc95 100644 --- a/pkg/backfill/db/source.go +++ b/pkg/backfill/db/source.go @@ -4,9 +4,11 @@ import ( "context" "encoding/hex" "encoding/json" + "errors" "fmt" "os" "path/filepath" + "strconv" "time" "github.com/cockroachdb/pebble" @@ -24,6 +26,10 @@ import ( const ( layoutV1 = "v1" layoutV2 = "v2" + + backendAuto = "auto" + backendPebble = "pebble" + backendLevelDB = "leveldb" ) // Config controls direct celestia-app DB reads. @@ -47,10 +53,10 @@ var _ backfill.Source = (*Source)(nil) // NewSource opens blockstore.db and auto-detects backend/layout when configured. func NewSource(cfg Config, log zerolog.Logger) (*Source, error) { if cfg.Backend == "" { - cfg.Backend = "auto" + cfg.Backend = backendAuto } if cfg.Layout == "" { - cfg.Layout = "auto" + cfg.Layout = backendAuto } dbPath, err := normalizePath(cfg.Path) @@ -80,7 +86,7 @@ func NewSource(cfg Config, log zerolog.Logger) (*Source, error) { func normalizePath(path string) (string, error) { if path == "" { - return "", fmt.Errorf("celestia-app db path is required") + return "", errors.New("celestia-app db path is required") } path = filepath.Clean(path) @@ -107,7 +113,7 @@ func normalizePath(path string) (string, error) { } func detectLayout(db kvDB, requested string) (layout string, version string, err error) { - if requested != "auto" && requested != layoutV1 && requested != layoutV2 { + if requested != backendAuto && requested != layoutV1 && requested != layoutV2 { return "", "", fmt.Errorf("invalid layout %q: must be auto|v1|v2", requested) } if requested == layoutV1 || requested == layoutV2 { @@ -162,7 +168,7 @@ func (s *Source) FetchHeight(_ context.Context, height uint64, namespaces []type } rawBlock := make([]byte, 0, meta.PartsTotal*65536) // 64KB per part estimate - for idx := uint32(0); idx < meta.PartsTotal; idx++ { + for idx := range meta.PartsTotal { partKey := blockPartKey(s.layout, height, idx) partRaw, err := s.db.Get(partKey) if err != nil { @@ -247,7 +253,7 @@ func decodeBlockMeta(raw []byte) (decodedMeta, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return decodedMeta{}, fmt.Errorf("invalid block meta tag") + return decodedMeta{}, errors.New("invalid block meta tag") } buf = buf[n:] @@ -262,7 +268,7 @@ func decodeBlockMeta(raw []byte) (decodedMeta, error) { blockIDBytes, n := protowire.ConsumeBytes(buf) if n < 0 { - return decodedMeta{}, fmt.Errorf("invalid block_id bytes") + return decodedMeta{}, errors.New("invalid block_id bytes") } buf = buf[n:] @@ -275,7 +281,7 @@ func decodeBlockMeta(raw []byte) (decodedMeta, error) { } if out.PartsTotal == 0 { - return decodedMeta{}, fmt.Errorf("missing part_set_header.total") + return decodedMeta{}, errors.New("missing part_set_header.total") } return out, nil } @@ -287,7 +293,7 @@ func decodeBlockID(raw []byte) ([]byte, uint32, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return nil, 0, fmt.Errorf("invalid block_id tag") + return nil, 0, errors.New("invalid block_id tag") } buf = buf[n:] if typ != protowire.BytesType { @@ -300,7 +306,7 @@ func decodeBlockID(raw []byte) ([]byte, uint32, error) { } val, n := protowire.ConsumeBytes(buf) if n < 0 { - return nil, 0, fmt.Errorf("invalid block_id bytes") + return nil, 0, errors.New("invalid block_id bytes") } buf = buf[n:] switch num { @@ -312,6 +318,7 @@ func decodeBlockID(raw []byte) ([]byte, uint32, error) { if err != nil { return nil, 0, err } + default: } } return hash, total, nil @@ -322,13 +329,13 @@ func decodePartSetHeaderTotal(raw []byte) (uint32, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return 0, fmt.Errorf("invalid part_set_header tag") + return 0, errors.New("invalid part_set_header tag") } buf = buf[n:] if num == 1 && typ == protowire.VarintType { v, n := protowire.ConsumeVarint(buf) if n < 0 { - return 0, fmt.Errorf("invalid part_set_header.total") + return 0, errors.New("invalid part_set_header.total") } return uint32(v), nil } @@ -338,7 +345,7 @@ func decodePartSetHeaderTotal(raw []byte) (uint32, error) { } buf = buf[n:] } - return 0, fmt.Errorf("missing part_set_header.total") + return 0, errors.New("missing part_set_header.total") } func decodePartBytes(raw []byte) ([]byte, error) { @@ -346,13 +353,13 @@ func decodePartBytes(raw []byte) ([]byte, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return nil, fmt.Errorf("invalid part tag") + return nil, errors.New("invalid part tag") } buf = buf[n:] if num == 2 && typ == protowire.BytesType { v, n := protowire.ConsumeBytes(buf) if n < 0 { - return nil, fmt.Errorf("invalid part.bytes") + return nil, errors.New("invalid part.bytes") } return append([]byte(nil), v...), nil } @@ -362,7 +369,7 @@ func decodePartBytes(raw []byte) ([]byte, error) { } buf = buf[n:] } - return nil, fmt.Errorf("missing part.bytes") + return nil, errors.New("missing part.bytes") } type decodedBlock struct { @@ -378,7 +385,7 @@ func decodeBlock(raw []byte) (decodedBlock, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return decodedBlock{}, fmt.Errorf("invalid block tag") + return decodedBlock{}, errors.New("invalid block tag") } buf = buf[n:] if typ != protowire.BytesType { @@ -407,6 +414,7 @@ func decodeBlock(raw []byte) (decodedBlock, error) { return decodedBlock{}, err } out.Txs = txs + default: } } return out, nil @@ -417,21 +425,21 @@ func decodeHeader(raw []byte, out *decodedBlock) error { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return fmt.Errorf("invalid header tag") + return errors.New("invalid header tag") } buf = buf[n:] switch { case num == 3 && typ == protowire.VarintType: v, n := protowire.ConsumeVarint(buf) if n < 0 { - return fmt.Errorf("invalid header.height") + return errors.New("invalid header.height") } out.Height = int64(v) buf = buf[n:] case num == 4 && typ == protowire.BytesType: tsRaw, n := protowire.ConsumeBytes(buf) if n < 0 { - return fmt.Errorf("invalid header.time") + return errors.New("invalid header.time") } t, err := decodeTimestamp(tsRaw) if err != nil { @@ -442,7 +450,7 @@ func decodeHeader(raw []byte, out *decodedBlock) error { case num == 7 && typ == protowire.BytesType: v, n := protowire.ConsumeBytes(buf) if n < 0 { - return fmt.Errorf("invalid header.data_hash") + return errors.New("invalid header.data_hash") } out.DataHash = append([]byte(nil), v...) buf = buf[n:] @@ -466,7 +474,7 @@ func decodeTimestamp(raw []byte) (time.Time, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return time.Time{}, fmt.Errorf("invalid timestamp tag") + return time.Time{}, errors.New("invalid timestamp tag") } buf = buf[n:] if typ != protowire.VarintType { @@ -479,7 +487,7 @@ func decodeTimestamp(raw []byte) (time.Time, error) { } v, n := protowire.ConsumeVarint(buf) if n < 0 { - return time.Time{}, fmt.Errorf("invalid timestamp varint") + return time.Time{}, errors.New("invalid timestamp varint") } buf = buf[n:] switch num { @@ -487,6 +495,7 @@ func decodeTimestamp(raw []byte) (time.Time, error) { seconds = int64(v) case 2: nanos = int64(v) + default: } } return time.Unix(seconds, nanos).UTC(), nil @@ -498,13 +507,13 @@ func decodeDataTxs(raw []byte) ([][]byte, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return nil, fmt.Errorf("invalid data tag") + return nil, errors.New("invalid data tag") } buf = buf[n:] if num == 1 && typ == protowire.BytesType { tx, n := protowire.ConsumeBytes(buf) if n < 0 { - return nil, fmt.Errorf("invalid tx bytes") + return nil, errors.New("invalid tx bytes") } txs = append(txs, append([]byte(nil), tx...)) buf = buf[n:] @@ -563,33 +572,33 @@ func probeKV(db kvDB) bool { func openKV(path, backend string) (kvDB, string, error) { switch backend { - case "pebble": + case backendPebble: db, err := pebble.Open(path, &pebble.Options{ReadOnly: true}) if err != nil { return nil, "", fmt.Errorf("open pebble blockstore %q: %w", path, err) } - return &pebbleDB{db: db}, "pebble", nil - case "leveldb": + return &pebbleDB{db: db}, backendPebble, nil + case backendLevelDB: db, err := leveldb.OpenFile(path, &ldbopt.Options{ReadOnly: true}) if err != nil { return nil, "", fmt.Errorf("open leveldb blockstore %q: %w", path, err) } - return &levelDB{db: db}, "leveldb", nil - case "auto": + return &levelDB{db: db}, backendLevelDB, nil + case backendAuto: // Pebble can sometimes open LevelDB files due to format compatibility. // After opening, probe for a known marker key to confirm the backend // is reading valid data before committing to it. if db, err := pebble.Open(path, &pebble.Options{ReadOnly: true}); err == nil { kv := &pebbleDB{db: db} if probeKV(kv) { - return kv, "pebble", nil + return kv, backendPebble, nil } _ = kv.Close() } if db, err := leveldb.OpenFile(path, &ldbopt.Options{ReadOnly: true}); err == nil { kv := &levelDB{db: db} if probeKV(kv) { - return kv, "leveldb", nil + return kv, backendLevelDB, nil } _ = kv.Close() } @@ -617,13 +626,13 @@ func (w *writableLevel) close() error { return w.db.Close() } func openWritable(path, backend string) (writableKV, error) { switch backend { - case "pebble": + case backendPebble: db, err := pebble.Open(path, &pebble.Options{}) if err != nil { return nil, err } return &writablePebble{db: db}, nil - case "leveldb": + case backendLevelDB: db, err := leveldb.OpenFile(path, nil) if err != nil { return nil, err @@ -721,14 +730,15 @@ func splitIntoParts(raw []byte, partSize int) [][]byte { // fields. The backfill source reads raw protobuf, so we build the JSON that // consumers (ev-node) expect rather than storing the full protobuf blob. func buildMinimalRawHeader(height uint64, t time.Time, dataHash, blockHash []byte) ([]byte, error) { + heightStr := strconv.FormatUint(height, 10) obj := map[string]any{ "header": map[string]any{ - "height": fmt.Sprintf("%d", height), + "height": heightStr, "time": t.Format(time.RFC3339Nano), "data_hash": hex.EncodeToString(dataHash), }, "commit": map[string]any{ - "height": fmt.Sprintf("%d", height), + "height": heightStr, "block_id": map[string]any{ "hash": hex.EncodeToString(blockHash), }, diff --git a/pkg/fetch/blobtx.go b/pkg/fetch/blobtx.go index ca176e5..26ac795 100644 --- a/pkg/fetch/blobtx.go +++ b/pkg/fetch/blobtx.go @@ -1,6 +1,7 @@ package fetch import ( + "errors" "fmt" "google.golang.org/protobuf/encoding/protowire" @@ -45,7 +46,7 @@ type parsedBlobTx struct { // inner_tx (length-prefixed) || blob1 (length-prefixed) || blob2 ... || 0x62 func parseBlobTx(raw []byte) (*parsedBlobTx, error) { if len(raw) == 0 { - return nil, fmt.Errorf("empty BlobTx") + return nil, errors.New("empty BlobTx") } if raw[len(raw)-1] != blobTxTypeID { return nil, fmt.Errorf("not a BlobTx: trailing byte 0x%02x, want 0x%02x", raw[len(raw)-1], blobTxTypeID) @@ -57,7 +58,7 @@ func parseBlobTx(raw []byte) (*parsedBlobTx, error) { // Read inner SDK tx (length-prefixed). innerTx, n := protowire.ConsumeBytes(data) if n < 0 { - return nil, fmt.Errorf("decode inner tx: invalid length prefix") + return nil, errors.New("decode inner tx: invalid length prefix") } data = data[n:] @@ -101,7 +102,7 @@ func parsePFBFromTx(txBytes []byte) (pfbData, error) { return pfbData{}, fmt.Errorf("extract tx body: %w", err) } if bodyBytes == nil { - return pfbData{}, fmt.Errorf("tx has no body") + return pfbData{}, errors.New("tx has no body") } // Iterate messages (field 1, repeated) in TxBody to find MsgPayForBlobs. @@ -117,7 +118,7 @@ func parsePFBFromTx(txBytes []byte) (pfbData, error) { return parseMsgPayForBlobs(value) } - return pfbData{}, fmt.Errorf("no MsgPayForBlobs found in tx") + return pfbData{}, errors.New("no MsgPayForBlobs found in tx") } // parseMsgPayForBlobs extracts signer and share_commitments from MsgPayForBlobs. @@ -130,7 +131,7 @@ func parseMsgPayForBlobs(data []byte) (pfbData, error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return pfbData{}, fmt.Errorf("invalid tag in MsgPayForBlobs") + return pfbData{}, errors.New("invalid tag in MsgPayForBlobs") } buf = buf[n:] @@ -146,6 +147,7 @@ func parseMsgPayForBlobs(data []byte) (pfbData, error) { result.Signer = append([]byte(nil), val...) case 4: // share_commitments (repeated bytes) result.ShareCommitments = append(result.ShareCommitments, append([]byte(nil), val...)) + default: } case protowire.VarintType: _, n := protowire.ConsumeVarint(buf) @@ -173,7 +175,7 @@ func parseAny(data []byte) (typeURL string, value []byte, err error) { for len(buf) > 0 { num, typ, n := protowire.ConsumeTag(buf) if n < 0 { - return "", nil, fmt.Errorf("invalid tag") + return "", nil, errors.New("invalid tag") } buf = buf[n:] @@ -198,6 +200,7 @@ func parseAny(data []byte) (typeURL string, value []byte, err error) { typeURL = string(val) case 2: value = val + default: } } return typeURL, value, nil @@ -212,7 +215,7 @@ func parseRawBlob(data []byte) (rawBlob, error) { for len(data) > 0 { num, typ, n := protowire.ConsumeTag(data) if n < 0 { - return rawBlob{}, fmt.Errorf("invalid proto tag") + return rawBlob{}, errors.New("invalid proto tag") } data = data[n:] @@ -230,6 +233,7 @@ func parseRawBlob(data []byte) (rawBlob, error) { b.Data = append([]byte(nil), val...) case 5: // signer (celestia-app v2+) b.Signer = append([]byte(nil), val...) + default: } case protowire.VarintType: val, n := protowire.ConsumeVarint(data) @@ -242,6 +246,7 @@ func parseRawBlob(data []byte) (rawBlob, error) { b.ShareVersion = uint32(val) case 4: // namespace_version b.NamespaceVersion = uint32(val) + default: } default: // Skip unknown wire types for forward compatibility. @@ -330,7 +335,7 @@ func extractBytesField(data []byte, target protowire.Number) ([]byte, error) { for len(data) > 0 { num, typ, n := protowire.ConsumeTag(data) if n < 0 { - return nil, fmt.Errorf("invalid tag") + return nil, errors.New("invalid tag") } data = data[n:] diff --git a/pkg/fetch/blobtx_test.go b/pkg/fetch/blobtx_test.go index 2bc7d93..591abc7 100644 --- a/pkg/fetch/blobtx_test.go +++ b/pkg/fetch/blobtx_test.go @@ -80,8 +80,8 @@ func buildTx(body []byte) []byte { // buildInnerSDKTx constructs a valid inner SDK tx containing MsgPayForBlobs. func buildInnerSDKTx(signer string, commitments [][]byte) []byte { pfb := buildMsgPayForBlobs(signer, commitments) - any := buildAny(msgPayForBlobsTypeURL, pfb) - body := buildTxBody(any) + anyPB := buildAny(msgPayForBlobsTypeURL, pfb) + body := buildTxBody(anyPB) return buildTx(body) } diff --git a/pkg/fetch/celestia_app.go b/pkg/fetch/celestia_app.go index 6479f43..7c00d0f 100644 --- a/pkg/fetch/celestia_app.go +++ b/pkg/fetch/celestia_app.go @@ -3,7 +3,9 @@ package fetch import ( "context" "encoding/json" + "errors" "fmt" + "strconv" "sync" "time" @@ -101,7 +103,7 @@ func (f *CelestiaAppFetcher) GetHeightData(ctx context.Context, height uint64, n return nil, nil, fmt.Errorf("get block at height %d: %w", height, err) } - hdr, err := mapBlockResponse(resp.BlockId, resp.Block) + hdr, err := mapBlockResponse(resp.GetBlockId(), resp.GetBlock()) if err != nil { return nil, nil, err } @@ -109,11 +111,11 @@ func (f *CelestiaAppFetcher) GetHeightData(ctx context.Context, height uint64, n if len(namespaces) == 0 { return hdr, nil, nil } - if resp.Block == nil || resp.Block.Data == nil { + if resp.GetBlock() == nil || resp.GetBlock().GetData() == nil { return hdr, nil, nil } - txs := resp.Block.Data.Txs + txs := resp.GetBlock().GetData().GetTxs() blobs, err := extractBlobsFromBlock(txs, namespaces, height) if err != nil { return nil, nil, fmt.Errorf("extract blobs at height %d: %w", height, err) @@ -130,7 +132,7 @@ func (f *CelestiaAppFetcher) GetNetworkHead(ctx context.Context) (*types.Header, if err != nil { return nil, fmt.Errorf("get latest block: %w", err) } - return mapBlockResponse(resp.BlockId, resp.Block) + return mapBlockResponse(resp.GetBlockId(), resp.GetBlock()) } // SubscribeHeaders polls GetLatestBlock at 1s intervals and emits new headers @@ -143,7 +145,7 @@ func (f *CelestiaAppFetcher) SubscribeHeaders(ctx context.Context) (<-chan *type if f.closed { f.mu.Unlock() cancel() - return nil, fmt.Errorf("fetcher is closed") + return nil, errors.New("fetcher is closed") } if f.cancelSub != nil { f.cancelSub() @@ -207,15 +209,15 @@ func (f *CelestiaAppFetcher) Close() error { // mapBlockResponse converts a gRPC block response into a types.Header. func mapBlockResponse(blockID *cometpb.BlockID, block *cometpb.Block) (*types.Header, error) { - if block == nil || block.Header == nil { - return nil, fmt.Errorf("nil block or header in response") + if block == nil || block.GetHeader() == nil { + return nil, errors.New("nil block or header in response") } if blockID == nil { - return nil, fmt.Errorf("nil block_id in response") + return nil, errors.New("nil block_id in response") } - hdr := block.Header - t := hdr.Time.AsTime() + hdr := block.GetHeader() + t := hdr.GetTime().AsTime() // Wrap in envelope matching the canonical shape used by celestia_node and // backfill: {"header": ..., "commit": ...}. The gRPC response does not @@ -224,7 +226,7 @@ func mapBlockResponse(blockID *cometpb.BlockID, block *cometpb.Block) (*types.He envelope := map[string]any{ "header": hdr, "commit": map[string]any{ - "height": fmt.Sprintf("%d", hdr.Height), + "height": strconv.FormatInt(hdr.GetHeight(), 10), "block_id": blockID, }, } @@ -234,9 +236,9 @@ func mapBlockResponse(blockID *cometpb.BlockID, block *cometpb.Block) (*types.He } return &types.Header{ - Height: uint64(hdr.Height), - Hash: blockID.Hash, - DataHash: hdr.DataHash, + Height: uint64(hdr.GetHeight()), + Hash: blockID.GetHash(), + DataHash: hdr.GetDataHash(), Time: t, RawHeader: raw, }, nil diff --git a/pkg/fetch/celestia_app_test.go b/pkg/fetch/celestia_app_test.go index 235f5ec..8e84214 100644 --- a/pkg/fetch/celestia_app_test.go +++ b/pkg/fetch/celestia_app_test.go @@ -24,13 +24,13 @@ type mockCometService struct { } func (m *mockCometService) GetBlockByHeight(_ context.Context, req *cometpb.GetBlockByHeightRequest) (*cometpb.GetBlockByHeightResponse, error) { - resp, ok := m.blocks[req.Height] + resp, ok := m.blocks[req.GetHeight()] if !ok { return &cometpb.GetBlockByHeightResponse{ BlockId: &cometpb.BlockID{Hash: []byte("default")}, Block: &cometpb.Block{ Header: &cometpb.Header{ - Height: req.Height, + Height: req.GetHeight(), Time: timestamppb.Now(), }, Data: &cometpb.Data{}, diff --git a/pkg/fetch/celestia_node.go b/pkg/fetch/celestia_node.go index 83083d8..acbb4f6 100644 --- a/pkg/fetch/celestia_node.go +++ b/pkg/fetch/celestia_node.go @@ -489,7 +489,7 @@ func retryDelay(attempt int) time.Duration { if jitterCap <= 0 { return base } - return base + time.Duration(rand.Int64N(int64(jitterCap))) + return base + time.Duration(rand.Int64N(int64(jitterCap))) //nolint:gosec // G404: weak random is fine for jitter } func isNotFoundErr(err error) bool { diff --git a/pkg/store/s3.go b/pkg/store/s3.go index d8e02c3..70c7e64 100644 --- a/pkg/store/s3.go +++ b/pkg/store/s3.go @@ -87,11 +87,18 @@ type S3Store struct { blobBuf map[blobChunkKey][]types.Blob headerBuf map[uint64][]*types.Header commitBuf []commitEntry + inflight flushBuffers nsMu sync.Mutex // guards PutNamespace read-modify-write (separate from buffer mu) flushMu sync.Mutex // serializes flush operations } +type flushBuffers struct { + blobBuf map[blobChunkKey][]types.Blob + headerBuf map[uint64][]*types.Header + commitBuf []commitEntry +} + // NewS3Store creates a new S3Store from the given config. func NewS3Store(ctx context.Context, cfg *config.S3Config) (*S3Store, error) { opts := []func(*awsconfig.LoadOptions) error{} @@ -203,6 +210,13 @@ func (s *S3Store) PutBlobs(ctx context.Context, blobs []types.Blob) error { for i := range blobs { b := &blobs[i] + exists, err := s.ensureBufferedBlobInvariant(b, s.blobBuf, s.inflight.blobBuf) + if err != nil { + return err + } + if exists { + continue + } key := blobChunkKey{namespace: b.Namespace, chunkStart: s.chunkStart(b.Height)} s.blobBuf[key] = append(s.blobBuf[key], *b) @@ -307,7 +321,10 @@ func (s *S3Store) GetBlobs(ctx context.Context, ns types.Namespace, startHeight, return nil, err } - allBlobs := deduplicateBlobs(append(buffered, s3Blobs...)) + allBlobs, err := mergeUniqueBlobs(append(buffered, s3Blobs...)) + if err != nil { + return nil, err + } sort.Slice(allBlobs, func(i, j int) bool { if allBlobs[i].Height != allBlobs[j].Height { @@ -325,16 +342,8 @@ func (s *S3Store) collectBufferedBlobs(ns types.Namespace, startHeight, endHeigh defer s.mu.Unlock() var result []types.Blob - for key, bufs := range s.blobBuf { - if key.namespace != ns { - continue - } - for i := range bufs { - if bufs[i].Height >= startHeight && bufs[i].Height <= endHeight { - result = append(result, bufs[i]) - } - } - } + result = collectBlobsInRange(result, s.blobBuf, ns, startHeight, endHeight) + result = collectBlobsInRange(result, s.inflight.blobBuf, ns, startHeight, endHeight) return result } @@ -388,18 +397,11 @@ func (s *S3Store) GetBlobByCommitment(ctx context.Context, commitment []byte) (* commitHex := hex.EncodeToString(commitment) s.mu.Lock() - for _, entry := range s.commitBuf { - if entry.commitmentHex == commitHex { - ns, err := types.NamespaceFromHex(entry.pointer.Namespace) - if err == nil { - if b := s.findBlobInBufferLocked(ns, entry.pointer.Height, entry.pointer.Index); b != nil { - s.mu.Unlock() - return b, nil - } - } - } - } + b := s.findCommitEntryBlobLocked(commitHex) s.mu.Unlock() + if b != nil { + return b, nil + } // Look up commitment index in S3. key := s.key("index", "commitments", commitHex+".json") @@ -549,15 +551,20 @@ func (s *S3Store) flush(ctx context.Context) error { blobBuf := s.blobBuf headerBuf := s.headerBuf commitBuf := s.commitBuf + if len(blobBuf) == 0 && len(headerBuf) == 0 && len(commitBuf) == 0 { + s.mu.Unlock() + return nil + } + s.inflight = flushBuffers{ + blobBuf: blobBuf, + headerBuf: headerBuf, + commitBuf: commitBuf, + } s.blobBuf = make(map[blobChunkKey][]types.Blob) s.headerBuf = make(map[uint64][]*types.Header) s.commitBuf = nil s.mu.Unlock() - if len(blobBuf) == 0 && len(headerBuf) == 0 && len(commitBuf) == 0 { - return nil - } - // Use a semaphore to bound concurrency. sem := make(chan struct{}, maxFlushConcurrency) var ( @@ -612,7 +619,13 @@ func (s *S3Store) flush(ctx context.Context) error { } wg.Wait() - return errors.Join(errs...) + if err := errors.Join(errs...); err != nil { + s.restoreInflight(blobBuf, headerBuf, commitBuf) + return err + } + + s.clearInflight() + return nil } func (s *S3Store) flushBlobChunk(ctx context.Context, key blobChunkKey, newBlobs []types.Blob) error { @@ -624,7 +637,7 @@ func (s *S3Store) flushBlobChunk(ctx context.Context, key blobChunkKey, newBlobs return fmt.Errorf("read blob chunk for merge: %w", err) } - var merged []types.Blob + merged := make([]types.Blob, 0, len(newBlobs)) if existing != nil { merged, err = decodeS3Blobs(existing) if err != nil { @@ -633,7 +646,10 @@ func (s *S3Store) flushBlobChunk(ctx context.Context, key blobChunkKey, newBlobs } merged = append(merged, newBlobs...) - merged = deduplicateBlobs(merged) + merged, err = mergeUniqueBlobs(merged) + if err != nil { + return err + } sort.Slice(merged, func(i, j int) bool { if merged[i].Height != merged[j].Height { return merged[i].Height < merged[j].Height @@ -698,12 +714,14 @@ func (s *S3Store) findBlobInBuffer(ns types.Namespace, height uint64, index int) } func (s *S3Store) findBlobInBufferLocked(ns types.Namespace, height uint64, index int) *types.Blob { - key := blobChunkKey{namespace: ns, chunkStart: s.chunkStart(height)} - for i := range s.blobBuf[key] { - b := &s.blobBuf[key][i] - if b.Height == height && b.Index == index { - cp := *b - return &cp + for _, buf := range []map[blobChunkKey][]types.Blob{s.blobBuf, s.inflight.blobBuf} { + key := blobChunkKey{namespace: ns, chunkStart: s.chunkStart(height)} + for i := range buf[key] { + b := &buf[key][i] + if b.Height == height && b.Index == index { + cp := *b + return &cp + } } } return nil @@ -714,15 +732,91 @@ func (s *S3Store) findHeaderInBuffer(height uint64) *types.Header { defer s.mu.Unlock() cs := s.chunkStart(height) - for _, h := range s.headerBuf[cs] { - if h.Height == height { - cp := *h - return &cp + for _, buf := range []map[uint64][]*types.Header{s.headerBuf, s.inflight.headerBuf} { + for _, h := range buf[cs] { + if h.Height == height { + cp := *h + return &cp + } } } return nil } +func collectBlobsInRange(result []types.Blob, buf map[blobChunkKey][]types.Blob, ns types.Namespace, startHeight, endHeight uint64) []types.Blob { + for key, blobs := range buf { + if key.namespace != ns { + continue + } + for i := range blobs { + if blobs[i].Height >= startHeight && blobs[i].Height <= endHeight { + result = append(result, blobs[i]) + } + } + } + return result +} + +func (s *S3Store) ensureBufferedBlobInvariant(b *types.Blob, bufs ...map[blobChunkKey][]types.Blob) (bool, error) { + for _, buf := range bufs { + key := blobChunkKey{namespace: b.Namespace, chunkStart: s.chunkStart(b.Height)} + if blobs, ok := buf[key]; ok { + for i := range blobs { + existing := &blobs[i] + if existing.Height == b.Height && existing.Namespace == b.Namespace && existing.Index == b.Index { + if !sameBlob(existing, b) { + return false, fmt.Errorf("blob conflict at height %d namespace %s index %d", b.Height, b.Namespace, b.Index) + } + return true, nil + } + if bytes.Equal(existing.Commitment, b.Commitment) && !sameBlob(existing, b) { + return false, fmt.Errorf("blob commitment conflict for %x", b.Commitment) + } + } + } + } + return false, nil +} + +// findCommitEntryBlobLocked searches both commitBuf and inflight.commitBuf +// for a matching commitment and returns the corresponding buffered blob. +// Caller must hold s.mu. +func (s *S3Store) findCommitEntryBlobLocked(commitHex string) *types.Blob { + for _, entries := range [2][]commitEntry{s.commitBuf, s.inflight.commitBuf} { + for _, entry := range entries { + if entry.commitmentHex == commitHex { + ns, err := types.NamespaceFromHex(entry.pointer.Namespace) + if err == nil { + if b := s.findBlobInBufferLocked(ns, entry.pointer.Height, entry.pointer.Index); b != nil { + return b + } + } + } + } + } + return nil +} + +func (s *S3Store) restoreInflight(blobBuf map[blobChunkKey][]types.Blob, headerBuf map[uint64][]*types.Header, commitBuf []commitEntry) { + s.mu.Lock() + defer s.mu.Unlock() + + for key, blobs := range blobBuf { + s.blobBuf[key] = append(blobs, s.blobBuf[key]...) + } + for cs, headers := range headerBuf { + s.headerBuf[cs] = append(headers, s.headerBuf[cs]...) + } + s.commitBuf = append(commitBuf, s.commitBuf...) + s.inflight = flushBuffers{} +} + +func (s *S3Store) clearInflight() { + s.mu.Lock() + s.inflight = flushBuffers{} + s.mu.Unlock() +} + // --- S3 helpers --- func (s *S3Store) getObject(ctx context.Context, key string) ([]byte, error) { @@ -835,23 +929,36 @@ func decodeS3Headers(data []byte) ([]types.Header, error) { // --- Deduplication --- -func deduplicateBlobs(blobs []types.Blob) []types.Blob { - type blobKey struct { +func mergeUniqueBlobs(blobs []types.Blob) ([]types.Blob, error) { + type positionKey struct { height uint64 namespace types.Namespace index int } - seen := make(map[blobKey]struct{}, len(blobs)) + byPosition := make(map[positionKey]types.Blob, len(blobs)) + byCommitment := make(map[string]types.Blob, len(blobs)) out := make([]types.Blob, 0, len(blobs)) + for _, b := range blobs { - k := blobKey{height: b.Height, namespace: b.Namespace, index: b.Index} - if _, ok := seen[k]; ok { + pos := positionKey{height: b.Height, namespace: b.Namespace, index: b.Index} + if existing, ok := byPosition[pos]; ok { + if !sameBlob(&existing, &b) { + return nil, fmt.Errorf("blob conflict at height %d namespace %s index %d", b.Height, b.Namespace, b.Index) + } continue } - seen[k] = struct{}{} + commitKey := string(b.Commitment) + if existing, ok := byCommitment[commitKey]; ok { + if !sameBlob(&existing, &b) { + return nil, fmt.Errorf("blob commitment conflict for %x", b.Commitment) + } + continue + } + byPosition[pos] = b + byCommitment[commitKey] = b out = append(out, b) } - return out + return out, nil } func deduplicateHeaders(headers []types.Header) []types.Header { diff --git a/pkg/store/s3_test.go b/pkg/store/s3_test.go index 8d64fd4..c161ad1 100644 --- a/pkg/store/s3_test.go +++ b/pkg/store/s3_test.go @@ -3,6 +3,7 @@ package store import ( "bytes" "context" + "errors" "io" "sync" "testing" @@ -16,12 +17,18 @@ import ( // mockS3Client is an in-memory S3 client for testing. type mockS3Client struct { - mu sync.RWMutex - objects map[string][]byte + mu sync.RWMutex + objects map[string][]byte + putErrByKey map[string]error + putCallsByKey map[string]int } func newMockS3Client() *mockS3Client { - return &mockS3Client{objects: make(map[string][]byte)} + return &mockS3Client{ + objects: make(map[string][]byte), + putErrByKey: make(map[string]error), + putCallsByKey: make(map[string]int), + } } func (m *mockS3Client) GetObject(_ context.Context, input *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { @@ -41,11 +48,18 @@ func (m *mockS3Client) PutObject(_ context.Context, input *s3.PutObjectInput, _ m.mu.Lock() defer m.mu.Unlock() + key := *input.Key + m.putCallsByKey[key]++ + if err, ok := m.putErrByKey[key]; ok { + delete(m.putErrByKey, key) + return nil, err + } + data, err := io.ReadAll(input.Body) if err != nil { return nil, err } - m.objects[*input.Key] = data + m.objects[key] = data return &s3.PutObjectOutput{}, nil } @@ -55,6 +69,12 @@ func (m *mockS3Client) objectCount() int { return len(m.objects) } +func (m *mockS3Client) failNextPut(key string, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.putErrByKey[key] = err +} + func newTestS3Store(t *testing.T) (*S3Store, *mockS3Client) { t.Helper() mock := newMockS3Client() @@ -102,7 +122,7 @@ func TestS3Store_PutBlobsAndGetBlob(t *testing.T) { // Not found. _, err = s.GetBlob(ctx, ns, 99, 0) - if err != ErrNotFound { + if !errors.Is(err, ErrNotFound) { t.Errorf("expected ErrNotFound, got %v", err) } } @@ -196,7 +216,7 @@ func TestS3Store_GetBlobByCommitment(t *testing.T) { // Not found. _, err = s.GetBlobByCommitment(ctx, []byte("nonexistent")) - if err != ErrNotFound { + if !errors.Is(err, ErrNotFound) { t.Errorf("expected ErrNotFound, got %v", err) } } @@ -245,7 +265,7 @@ func TestS3Store_PutHeaderAndGetHeader(t *testing.T) { // Not found. _, err = s.GetHeader(ctx, 999) - if err != ErrNotFound { + if !errors.Is(err, ErrNotFound) { t.Errorf("expected ErrNotFound, got %v", err) } } @@ -256,7 +276,7 @@ func TestS3Store_SyncState(t *testing.T) { // Initially not found. _, err := s.GetSyncState(ctx) - if err != ErrNotFound { + if !errors.Is(err, ErrNotFound) { t.Fatalf("expected ErrNotFound, got %v", err) } @@ -410,6 +430,45 @@ func TestS3Store_IdempotentPut(t *testing.T) { } } +func TestS3Store_RejectsConflictingBufferedBlob(t *testing.T) { + ctx := context.Background() + s, _ := newTestS3Store(t) + ns := testNamespace(1) + + original := types.Blob{Height: 3, Namespace: ns, Data: []byte("d1"), Commitment: []byte("c1"), Index: 0} + conflict := types.Blob{Height: 3, Namespace: ns, Data: []byte("d2"), Commitment: []byte("c2"), Index: 0} + + if err := s.PutBlobs(ctx, []types.Blob{original}); err != nil { + t.Fatalf("PutBlobs (original): %v", err) + } + if err := s.PutBlobs(ctx, []types.Blob{conflict}); err == nil { + t.Fatal("expected conflicting buffered blob insert to fail") + } +} + +func TestS3Store_RejectsConflictingPersistedBlob(t *testing.T) { + ctx := context.Background() + s, _ := newTestS3Store(t) + ns := testNamespace(1) + + original := types.Blob{Height: 3, Namespace: ns, Data: []byte("d1"), Commitment: []byte("c1"), Index: 0} + conflict := types.Blob{Height: 3, Namespace: ns, Data: []byte("d2"), Commitment: []byte("c2"), Index: 0} + + if err := s.PutBlobs(ctx, []types.Blob{original}); err != nil { + t.Fatalf("PutBlobs (original): %v", err) + } + if err := s.SetSyncState(ctx, types.SyncStatus{LatestHeight: 3}); err != nil { + t.Fatalf("SetSyncState (original): %v", err) + } + + if err := s.PutBlobs(ctx, []types.Blob{conflict}); err != nil { + t.Fatalf("PutBlobs (conflict): %v", err) + } + if err := s.SetSyncState(ctx, types.SyncStatus{LatestHeight: 3}); err == nil { + t.Fatal("expected persisted blob conflict to fail during flush") + } +} + func TestS3Store_ChunkBoundary(t *testing.T) { ctx := context.Background() s, _ := newTestS3Store(t) // chunkSize=4 @@ -512,3 +571,47 @@ func TestS3Store_BufferReadThrough(t *testing.T) { t.Errorf("blob data %q, want %q", gotB.Data, "buf") } } + +func TestS3Store_FlushFailureRetainsBufferedData(t *testing.T) { + ctx := context.Background() + s, mock := newTestS3Store(t) + ns := testNamespace(1) + + blob := types.Blob{ + Height: 1, + Namespace: ns, + Data: []byte("retry-me"), + Commitment: []byte("c1"), + Index: 0, + } + if err := s.PutBlobs(ctx, []types.Blob{blob}); err != nil { + t.Fatalf("PutBlobs: %v", err) + } + + blobKey := s.key("blobs", ns.String(), chunkFileName(s.chunkStart(blob.Height))) + mock.failNextPut(blobKey, errors.New("injected put failure")) + + if err := s.SetSyncState(ctx, types.SyncStatus{LatestHeight: blob.Height}); err == nil { + t.Fatal("expected SetSyncState to fail") + } + + got, err := s.GetBlob(ctx, ns, blob.Height, blob.Index) + if err != nil { + t.Fatalf("GetBlob after failed flush: %v", err) + } + if !bytes.Equal(got.Data, blob.Data) { + t.Errorf("got data %q, want %q", got.Data, blob.Data) + } + + if err := s.SetSyncState(ctx, types.SyncStatus{LatestHeight: blob.Height}); err != nil { + t.Fatalf("SetSyncState retry: %v", err) + } + + got, err = s.GetBlob(ctx, ns, blob.Height, blob.Index) + if err != nil { + t.Fatalf("GetBlob after retry: %v", err) + } + if !bytes.Equal(got.Data, blob.Data) { + t.Errorf("got data %q, want %q", got.Data, blob.Data) + } +} diff --git a/pkg/store/sqlite.go b/pkg/store/sqlite.go index 4acba92..8659966 100644 --- a/pkg/store/sqlite.go +++ b/pkg/store/sqlite.go @@ -1,6 +1,7 @@ package store import ( + "bytes" "context" "database/sql" "embed" @@ -12,7 +13,7 @@ import ( "github.com/evstack/apex/pkg/metrics" "github.com/evstack/apex/pkg/types" - _ "modernc.org/sqlite" + _ "modernc.org/sqlite" // registers sqlite driver ) //go:embed migrations/*.sql @@ -44,7 +45,9 @@ func Open(path string) (*SQLiteStore, error) { } writer.SetMaxOpenConns(1) - if err := configureSQLite(writer); err != nil { + ctx := context.Background() + + if err := configureSQLite(ctx, writer); err != nil { _ = writer.Close() return nil, fmt.Errorf("configure writer: %w", err) } @@ -56,7 +59,7 @@ func Open(path string) (*SQLiteStore, error) { } reader.SetMaxOpenConns(poolSize) - if err := configureSQLite(reader); err != nil { + if err := configureSQLite(ctx, reader); err != nil { _ = writer.Close() _ = reader.Close() return nil, fmt.Errorf("configure reader: %w", err) @@ -76,14 +79,14 @@ func (s *SQLiteStore) SetMetrics(m metrics.Recorder) { s.metrics = m } -func configureSQLite(db *sql.DB) error { - if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { +func configureSQLite(ctx context.Context, db *sql.DB) error { + if _, err := db.ExecContext(ctx, "PRAGMA journal_mode=WAL"); err != nil { return fmt.Errorf("set WAL mode: %w", err) } - if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil { + if _, err := db.ExecContext(ctx, "PRAGMA busy_timeout=5000"); err != nil { return fmt.Errorf("set busy_timeout: %w", err) } - if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil { + if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys=ON"); err != nil { return fmt.Errorf("set foreign_keys: %w", err) } return nil @@ -102,8 +105,10 @@ var allMigrations = []migrationStep{ } func (s *SQLiteStore) migrate() error { + ctx := context.Background() + var version int - if err := s.writer.QueryRow("PRAGMA user_version").Scan(&version); err != nil { + if err := s.writer.QueryRowContext(ctx, "PRAGMA user_version").Scan(&version); err != nil { return fmt.Errorf("read user_version: %w", err) } @@ -111,7 +116,7 @@ func (s *SQLiteStore) migrate() error { if version >= m.version { continue } - if err := s.applyMigration(m); err != nil { + if err := s.applyMigration(ctx, m); err != nil { return err } version = m.version @@ -120,22 +125,22 @@ func (s *SQLiteStore) migrate() error { return nil } -func (s *SQLiteStore) applyMigration(m migrationStep) error { +func (s *SQLiteStore) applyMigration(ctx context.Context, m migrationStep) error { ddl, err := migrations.ReadFile(m.file) if err != nil { return fmt.Errorf("read migration %d: %w", m.version, err) } - tx, err := s.writer.Begin() + tx, err := s.writer.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %d tx: %w", m.version, err) } defer tx.Rollback() //nolint:errcheck - if _, err := tx.Exec(string(ddl)); err != nil { + if _, err := tx.ExecContext(ctx, string(ddl)); err != nil { return fmt.Errorf("exec migration %d: %w", m.version, err) } - if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", m.version)); err != nil { + if _, err := tx.ExecContext(ctx, fmt.Sprintf("PRAGMA user_version = %d", m.version)); err != nil { return fmt.Errorf("set user_version to %d: %w", m.version, err) } @@ -165,6 +170,9 @@ func (s *SQLiteStore) PutBlobs(ctx context.Context, blobs []types.Blob) error { for i := range blobs { b := &blobs[i] + if err := ensureSQLiteBlobInvariant(ctx, tx, b); err != nil { + return err + } if _, err := stmt.ExecContext(ctx, b.Height, b.Namespace[:], b.Commitment, b.Data, b.ShareVersion, b.Signer, b.Index, ); err != nil { @@ -358,3 +366,62 @@ func scanBlobRow(rows *sql.Rows) (types.Blob, error) { copy(b.Namespace[:], nsBytes) return b, nil } + +func ensureSQLiteBlobInvariant(ctx context.Context, tx *sql.Tx, b *types.Blob) error { + existingByIndex, err := queryBlobByIndex(ctx, tx, b.Namespace, b.Height, b.Index) + if err != nil { + return err + } + if existingByIndex != nil && !sameBlob(existingByIndex, b) { + return fmt.Errorf("blob conflict at height %d namespace %s index %d", b.Height, b.Namespace, b.Index) + } + + existingByCommitment, err := queryBlobByCommitment(ctx, tx, b.Commitment) + if err != nil { + return err + } + if existingByCommitment != nil && !sameBlob(existingByCommitment, b) { + return fmt.Errorf("blob commitment conflict for %x", b.Commitment) + } + + return nil +} + +func queryBlobByIndex(ctx context.Context, tx *sql.Tx, ns types.Namespace, height uint64, index int) (*types.Blob, error) { + row := tx.QueryRowContext(ctx, + `SELECT height, namespace, commitment, data, share_version, signer, blob_index + FROM blobs WHERE namespace = ? AND height = ? AND blob_index = ?`, + ns[:], height, index) + b, err := scanBlob(row) + if errors.Is(err, ErrNotFound) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("query blob by index: %w", err) + } + return b, nil +} + +func queryBlobByCommitment(ctx context.Context, tx *sql.Tx, commitment []byte) (*types.Blob, error) { + row := tx.QueryRowContext(ctx, + `SELECT height, namespace, commitment, data, share_version, signer, blob_index + FROM blobs WHERE commitment = ? LIMIT 1`, commitment) + b, err := scanBlob(row) + if errors.Is(err, ErrNotFound) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("query blob by commitment: %w", err) + } + return b, nil +} + +func sameBlob(a, b *types.Blob) bool { + return a.Height == b.Height && + a.Namespace == b.Namespace && + a.Index == b.Index && + a.ShareVersion == b.ShareVersion && + bytes.Equal(a.Commitment, b.Commitment) && + bytes.Equal(a.Data, b.Data) && + bytes.Equal(a.Signer, b.Signer) +} diff --git a/pkg/store/sqlite_test.go b/pkg/store/sqlite_test.go index 5aa54f1..ceb2a78 100644 --- a/pkg/store/sqlite_test.go +++ b/pkg/store/sqlite_test.go @@ -205,6 +205,50 @@ func TestPutBlobsIdempotent(t *testing.T) { } } +func TestPutBlobsRejectsConflictingIndex(t *testing.T) { + s := openTestDB(t) + ctx := context.Background() + ns := testNamespace(1) + + original := types.Blob{ + Height: 10, Namespace: ns, Commitment: []byte("c1"), + Data: []byte("d1"), ShareVersion: 0, Index: 0, + } + conflict := types.Blob{ + Height: 10, Namespace: ns, Commitment: []byte("c2"), + Data: []byte("d2"), ShareVersion: 0, Index: 0, + } + + if err := s.PutBlobs(ctx, []types.Blob{original}); err != nil { + t.Fatalf("PutBlobs (original): %v", err) + } + if err := s.PutBlobs(ctx, []types.Blob{conflict}); err == nil { + t.Fatal("expected conflicting blob insert to fail") + } +} + +func TestPutBlobsRejectsConflictingCommitment(t *testing.T) { + s := openTestDB(t) + ctx := context.Background() + ns := testNamespace(1) + + original := types.Blob{ + Height: 10, Namespace: ns, Commitment: []byte("c1"), + Data: []byte("d1"), ShareVersion: 0, Index: 0, + } + conflict := types.Blob{ + Height: 10, Namespace: ns, Commitment: []byte("c1"), + Data: []byte("d1"), ShareVersion: 0, Index: 1, + } + + if err := s.PutBlobs(ctx, []types.Blob{original}); err != nil { + t.Fatalf("PutBlobs (original): %v", err) + } + if err := s.PutBlobs(ctx, []types.Blob{conflict}); err == nil { + t.Fatal("expected conflicting commitment insert to fail") + } +} + func TestPutHeaderIdempotent(t *testing.T) { s := openTestDB(t) ctx := context.Background() diff --git a/pkg/sync/subscription.go b/pkg/sync/subscription.go index ecb48ed..3394671 100644 --- a/pkg/sync/subscription.go +++ b/pkg/sync/subscription.go @@ -63,33 +63,56 @@ func (sm *SubscriptionManager) Run(ctx context.Context) error { Msg("streaming progress") processed = 0 case hdr, ok := <-ch: - if !ok { - // Channel closed (disconnect or ctx cancelled). - if ctx.Err() != nil { - return nil - } - return fmt.Errorf("header subscription closed unexpectedly") - } - - // Check contiguity. - if lastHeight > 0 && hdr.Height != lastHeight+1 { - sm.log.Warn(). - Uint64("expected", lastHeight+1). - Uint64("got", hdr.Height). - Msg("gap detected") - return ErrGapDetected + nextNetworkHeight, err := sm.handleHeader(ctx, hdr, ok, lastHeight, networkHeight, namespaces) + if err != nil { + return err } - - if err := sm.processHeader(ctx, hdr, namespaces, networkHeight); err != nil { - return fmt.Errorf("process height %d: %w", hdr.Height, err) + if !ok { + return nil // channel closed; handleHeader already logged/validated } - + networkHeight = nextNetworkHeight lastHeight = hdr.Height processed++ } } } +func (sm *SubscriptionManager) handleHeader(ctx context.Context, hdr *types.Header, ok bool, lastHeight, networkHeight uint64, namespaces []types.Namespace) (uint64, error) { + if !ok { + // Channel closed (disconnect or ctx cancelled). + if ctx.Err() != nil { + return networkHeight, nil //nolint:nilerr // context cancellation is a clean shutdown, not an error + } + return networkHeight, errors.New("header subscription closed unexpectedly") + } + + if err := sm.checkContiguous(lastHeight, hdr.Height); err != nil { + return networkHeight, err + } + + if hdr.Height > networkHeight { + networkHeight = hdr.Height + } + + if err := sm.processHeader(ctx, hdr, namespaces, networkHeight); err != nil { + return networkHeight, fmt.Errorf("process height %d: %w", hdr.Height, err) + } + + return networkHeight, nil +} + +func (sm *SubscriptionManager) checkContiguous(lastHeight, nextHeight uint64) error { + if lastHeight == 0 || nextHeight == lastHeight+1 { + return nil + } + + sm.log.Warn(). + Uint64("expected", lastHeight+1). + Uint64("got", nextHeight). + Msg("gap detected") + return ErrGapDetected +} + func (sm *SubscriptionManager) processHeader(ctx context.Context, hdr *types.Header, namespaces []types.Namespace, networkHeight uint64) error { if err := sm.store.PutHeader(ctx, hdr); err != nil { return fmt.Errorf("put header: %w", err) diff --git a/pkg/sync/subscription_test.go b/pkg/sync/subscription_test.go new file mode 100644 index 0000000..69342be --- /dev/null +++ b/pkg/sync/subscription_test.go @@ -0,0 +1,70 @@ +package syncer + +import ( + "context" + "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/evstack/apex/pkg/types" +) + +func TestSubscriptionManagerUpdatesNetworkHeightFromStream(t *testing.T) { + st := newMockStore() + ns := types.Namespace{0: 1} + if err := st.PutNamespace(context.Background(), ns); err != nil { + t.Fatalf("PutNamespace: %v", err) + } + if err := st.SetSyncState(context.Background(), types.SyncStatus{ + State: types.Streaming, + LatestHeight: 5, + NetworkHeight: 5, + }); err != nil { + t.Fatalf("SetSyncState: %v", err) + } + + ft := newMockFetcher(5) + subCh := make(chan *types.Header, 1) + ft.mu.Lock() + ft.subCh = subCh + ft.mu.Unlock() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + sm := &SubscriptionManager{ + store: st, + fetcher: ft, + log: zerolog.Nop(), + } + go func() { + done <- sm.Run(ctx) + }() + + subCh <- makeHeader(6) + + deadline := time.After(2 * time.Second) + for { + ss, err := st.GetSyncState(context.Background()) + if err == nil && ss.LatestHeight == 6 { + if ss.NetworkHeight != 6 { + t.Fatalf("NetworkHeight = %d, want 6", ss.NetworkHeight) + } + cancel() + break + } + + select { + case <-deadline: + t.Fatal("timed out waiting for sync state update") + default: + time.Sleep(10 * time.Millisecond) + } + } + + if err := <-done; err != nil { + t.Fatalf("Run: %v", err) + } +}