From 925ef3b41bc78244bd711f2050fb90ef08d3b2dd Mon Sep 17 00:00:00 2001 From: Bui Quang Minh Date: Thu, 14 Nov 2024 14:02:46 +0700 Subject: [PATCH 1/2] consortium-v2: add unit test for FindAncientHeader We remove the hash check in FindAncientHeader to make it easier to write unit test. This is fine because the parents are guaranteed to be ordered and linked by the check when InsertChain. This is a preparation commit for refactoring this function. --- consensus/consortium/v2/snapshot.go | 3 +- consensus/consortium/v2/snapshot_test.go | 93 ++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 consensus/consortium/v2/snapshot_test.go diff --git a/consensus/consortium/v2/snapshot.go b/consensus/consortium/v2/snapshot.go index 322f42de5..664e69bba 100644 --- a/consensus/consortium/v2/snapshot.go +++ b/consensus/consortium/v2/snapshot.go @@ -437,8 +437,7 @@ func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHe index := sort.Search(len(candidateParents), func(i int) bool { return candidateParents[i].Number.Uint64() >= parentHeight }) - if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight && - candidateParents[index].Hash() == parentHash { + if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight { ancient = candidateParents[index] found = true } diff --git a/consensus/consortium/v2/snapshot_test.go b/consensus/consortium/v2/snapshot_test.go new file mode 100644 index 000000000..2de1ab932 --- /dev/null +++ b/consensus/consortium/v2/snapshot_test.go @@ -0,0 +1,93 @@ +package v2 + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/params" +) + +type mockChainReader struct { + headerMapping map[common.Hash]*types.Header +} + +func (chainReader *mockChainReader) Config() *params.ChainConfig { return nil } +func (chainReader *mockChainReader) CurrentHeader() *types.Header { return nil } +func (chainReader *mockChainReader) GetHeader(hash common.Hash, number uint64) *types.Header { + return chainReader.headerMapping[hash] +} +func (chainReader *mockChainReader) GetHeaderByNumber(number uint64) *types.Header { return nil } +func (chainReader *mockChainReader) GetHeaderByHash(hash common.Hash) *types.Header { return nil } +func (chainReader *mockChainReader) DB() ethdb.Database { return nil } +func (chainReader *mockChainReader) StateCache() state.Database { return nil } +func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent { return nil } + +func TestFindCheckpointHeader(t *testing.T) { + // Case 1: checkpoint header is at block 5 (in parent list) + parents := make([]*types.Header, 10) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} + } + + currentHeader := &types.Header{Number: big.NewInt(10)} + checkpointHeader := FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, nil, parents) + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) { + t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64()) + } + + // Case 2: checkpoint header is at 5 (lower than parent list) + // parent list ranges from [10, 20) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i + 10)), ParentHash: common.BigToHash(big.NewInt(int64(i + 10 - 1)))} + } + mockChain := mockChainReader{ + headerMapping: make(map[common.Hash]*types.Header), + } + // create mock chain 1 + for i := 5; i < 10; i++ { + mockChain.headerMapping[common.BigToHash(big.NewInt(int64(100+i)))] = &types.Header{ + Number: big.NewInt(int64(i)), + ParentHash: common.BigToHash(big.NewInt(int64(100 + i - 1))), + } + } + + // create mock chain 2 + for i := 5; i < 10; i++ { + mockChain.headerMapping[common.BigToHash(big.NewInt(int64(i)))] = &types.Header{ + Number: big.NewInt(int64(i)), + ParentHash: common.BigToHash(big.NewInt(int64(i - 1))), + } + } + + currentHeader = &types.Header{ParentHash: common.BigToHash(big.NewInt(19)), Number: big.NewInt(20)} + // Must traverse and get the correct header in chain 2 + checkpointHeader = FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, &mockChain, parents) + if checkpointHeader == nil { + t.Fatal("Failed to find checkpoint header") + } + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(4))) { + t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s", + 5, common.BigToHash(big.NewInt(int64(4))), + checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, + ) + } + + // Case 3: find checkpoint header with nil parent list + currentHeader = &types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))} + checkpointHeader = FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, &mockChain, nil) + // Must traverse and get the correct header in chain 1 + if checkpointHeader == nil { + t.Fatal("Failed to find checkpoint header") + } + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(104))) { + t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s", + 5, common.BigToHash(big.NewInt(int64(104))), + checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, + ) + } +} From 9e7c9e42675b133cdb9f0b8074a397f90d8ce466 Mon Sep 17 00:00:00 2001 From: Bui Quang Minh Date: Thu, 14 Nov 2024 14:12:17 +0700 Subject: [PATCH 2/2] consortium-v2/snapshot: make FindAncientHeader more readable This commit refactors FindAncientHeader, changes its name to findAncestorHeader, adds some comments and unit test to make the code more readable. --- consensus/consortium/v2/snapshot.go | 82 +++++++++++++++--------- consensus/consortium/v2/snapshot_test.go | 19 +++++- 2 files changed, 69 insertions(+), 32 deletions(-) diff --git a/consensus/consortium/v2/snapshot.go b/consensus/consortium/v2/snapshot.go index 664e69bba..610233d2d 100644 --- a/consensus/consortium/v2/snapshot.go +++ b/consensus/consortium/v2/snapshot.go @@ -16,6 +16,7 @@ import ( blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "github.com/hashicorp/golang-lru/arc/v2" ) @@ -243,7 +244,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea // Change the validator set base on the size of the validators set if number > 0 && number%s.config.EpochV2 == uint64(len(snap.validators())/2) { // Get the most recent checkpoint header - checkpointHeader := FindAncientHeader(header, uint64(len(snap.validators())/2), chain, parents) + checkpointHeader := findAncestorHeader(header, number-uint64(len(snap.validators())/2), chain, parents) if checkpointHeader == nil { return nil, consensus.ErrUnknownAncestor } @@ -420,35 +421,58 @@ func (s *Snapshot) IsRecentlySigned(validator common.Address) bool { return false } -// FindAncientHeader finds the most recent checkpoint header -// Travel through the candidateParents to find the ancient header. -// If all headers in candidateParents have the number is larger than the header number, -// the search function will return the index, but it is not valid if we check with the -// header since the number and hash is not equals. The candidateParents is -// only available when it downloads blocks from the network. -// Otherwise, the candidateParents is nil, and it will be found by header hash and number. -func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHeaderReader, candidateParents []*types.Header) *types.Header { - ancient := header - for i := uint64(1); i <= ite; i++ { - parentHash := ancient.ParentHash - parentHeight := ancient.Number.Uint64() - 1 - found := false - if len(candidateParents) > 0 { - index := sort.Search(len(candidateParents), func(i int) bool { - return candidateParents[i].Number.Uint64() >= parentHeight - }) - if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight { - ancient = candidateParents[index] - found = true - } - } - if !found { - ancient = chain.GetHeader(parentHash, parentHeight) - found = true +// findAncestorHeader traverses back to look for the requested ancestor header +// in parents list or in chaindata +// +// parents are guaranteed to be ordered and linked by the check when InsertChain +// +// There are 2 possible cases: +// Case 1: ancestor header is in parents list +// <- parents -> +// [ ancestorHeader ] +// +// Case 2: ancestor header's height is lower than parents list +// <- parents -> +// ancestorHeader ... [ ] + +func findAncestorHeader( + currentHeader *types.Header, + ancestorBlockNumber uint64, + chain consensus.ChainHeaderReader, + parents []*types.Header, +) *types.Header { + // Find the first header in parents list that is higher or equal to checkpoint block + index := sort.Search(len(parents), func(i int) bool { + return parents[i].Number.Uint64() >= ancestorBlockNumber + }) + + // This must not happen, checkpoint header's height cannot be higher the parents list + if len(parents) != 0 && index >= len(parents) { + log.Warn( + "Checkpoint header's height is higher than parents list", + "checkpointNumber", ancestorBlockNumber, + "last parent", parents[len(parents)-1].Number, + ) + return nil + } + + if len(parents) != 0 && parents[index].Number.Uint64() == ancestorBlockNumber { + // Case 1: checkpoint header is in parents list + return parents[index] + } else { + // Case 2: checkpoint header's height is lower than parents list + var headerIterator *types.Header + if len(parents) != 0 { + headerIterator = parents[0] + } else { + headerIterator = currentHeader } - if ancient == nil || !found { - return nil + for headerIterator.Number.Uint64() != ancestorBlockNumber { + headerIterator = chain.GetHeader(headerIterator.ParentHash, headerIterator.Number.Uint64()-1) + if headerIterator == nil { + return nil + } } + return headerIterator } - return ancient } diff --git a/consensus/consortium/v2/snapshot_test.go b/consensus/consortium/v2/snapshot_test.go index 2de1ab932..0732fa3e7 100644 --- a/consensus/consortium/v2/snapshot_test.go +++ b/consensus/consortium/v2/snapshot_test.go @@ -29,13 +29,14 @@ func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent func TestFindCheckpointHeader(t *testing.T) { // Case 1: checkpoint header is at block 5 (in parent list) + // parent list ranges from [0, 10) parents := make([]*types.Header, 10) for i := range parents { parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} } currentHeader := &types.Header{Number: big.NewInt(10)} - checkpointHeader := FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, nil, parents) + checkpointHeader := findAncestorHeader(currentHeader, 5, nil, parents) if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) { t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64()) } @@ -66,7 +67,7 @@ func TestFindCheckpointHeader(t *testing.T) { currentHeader = &types.Header{ParentHash: common.BigToHash(big.NewInt(19)), Number: big.NewInt(20)} // Must traverse and get the correct header in chain 2 - checkpointHeader = FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, &mockChain, parents) + checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, parents) if checkpointHeader == nil { t.Fatal("Failed to find checkpoint header") } @@ -79,7 +80,7 @@ func TestFindCheckpointHeader(t *testing.T) { // Case 3: find checkpoint header with nil parent list currentHeader = &types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))} - checkpointHeader = FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, &mockChain, nil) + checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, nil) // Must traverse and get the correct header in chain 1 if checkpointHeader == nil { t.Fatal("Failed to find checkpoint header") @@ -90,4 +91,16 @@ func TestFindCheckpointHeader(t *testing.T) { checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, ) } + + // Case 4: checkpoint header is higher than parent list, this must not happen + // but the function must not crash in this case + // parent list ranges from [0, 10) + parents = make([]*types.Header, 10) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} + } + checkpointHeader = findAncestorHeader(nil, 10, nil, parents) + if checkpointHeader != nil { + t.Fatalf("Expect %v checkpoint header, got %v", nil, checkpointHeader) + } }