Skip to content

Commit

Permalink
[BUG] Make memberlist use ips for routing (#3405)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- The memberlist uses k8s DNS-based routing for the stateful set.
However, this can result in delays when the routing updates. This change
propagates each member's ip in the memberlist, which is much faster to route on.
   - This change is backwards compatible
- CRDs fields are optional by default, so not setting the ip is fine if
CRD is updated
       - CRD fields can be set in the code without the CRD being updated
- The python frontend will use the ip if present, otherwise it falls
back to the id
- The go code will read the ip as "" anywhere it expects it to be set if
the old version of the CR is used
 - New functionality
   - ...

## Test plan
*How are these changes tested?*
Added a test which adds data to a collection, kills the query service
pods and waits for them to be ready, simulating a roll out of sorts.
Then it issues another query and makes sure the query succeeds with
updated routing. Before this change, this test failed.
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored Jan 8, 2025
1 parent f987ba6 commit 74146be
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 38 deletions.
6 changes: 1 addition & 5 deletions chromadb/execution/executor/distributed.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing import Dict, Optional

import grpc
from overrides import overrides

from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.errors import VersionMismatchError
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert

from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
Expand Down Expand Up @@ -170,6 +166,6 @@ def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub:
channel = grpc.insecure_channel(grpc_url)
interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
channel = grpc.intercept_channel(channel, *interceptors)
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel) # type: ignore[no-untyped-call]
self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)

return self._grpc_stub_pool[grpc_url]
9 changes: 8 additions & 1 deletion chromadb/segment/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, List

from overrides import EnforceOverrides, overrides
Expand All @@ -22,7 +23,13 @@ def register_updated_segment_callback(
pass


Memberlist = List[str]
@dataclass
class Member:
    """A single entry of the memberlist.

    A member is identified by ``id``; ``ip`` is optional routing information.
    An empty ``ip`` means "no address known", in which case consumers fall
    back to addressing the member by its id.
    """

    # Identifier of the member; this is the value members are keyed and
    # compared by.
    id: str
    # IP address of the member. Defaults to "" so that members coming from an
    # older memberlist format (ids only) can still be represented.
    ip: str = ""


Memberlist = List[Member]


class MemberlistProvider(Component, EnforceOverrides):
Expand Down
36 changes: 30 additions & 6 deletions chromadb/segment/impl/distributed/segment_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chromadb.config import System
from chromadb.segment.distributed import (
Member,
Memberlist,
MemberlistProvider,
SegmentDirectory,
Expand Down Expand Up @@ -35,7 +36,11 @@ class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):

def __init__(self, system: System):
super().__init__(system)
self._memberlist = ["a", "b", "c"]
self._memberlist = [
Member(id="a", ip="10.0.0.1"),
Member(id="b", ip="10.0.0.2"),
Member(id="c", ip="10.0.0.3"),
]

@override
def get_memberlist(self) -> Memberlist:
Expand Down Expand Up @@ -203,7 +208,12 @@ def _parse_response_memberlist(
) -> Memberlist:
if "members" not in api_response_spec:
return []
return [m["member_id"] for m in api_response_spec["members"]]
parsed = []
for m in api_response_spec["members"]:
id = m["member_id"]
ip = m["member_ip"] if "member_ip" in m else ""
parsed.append(Member(id=id, ip=ip))
return parsed

def _notify(self, memberlist: Memberlist) -> None:
for callback in self.callbacks:
Expand Down Expand Up @@ -245,11 +255,23 @@ def get_segment_endpoint(self, segment: Segment) -> str:
raise ValueError("Memberlist is not initialized")
# Query to the same collection should end up on the same endpoint
assignment = assign(
segment["collection"].hex, self._curr_memberlist, murmur3hasher, 1
segment["collection"].hex,
[m.id for m in self._curr_memberlist],
murmur3hasher,
1,
)[0]
service_name = self.extract_service_name(assignment)
assignment = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable
return assignment

# If the memberlist has an ip, use it, otherwise use the member id with the headless service
# this is for backwards compatibility with the old memberlist which only had ids
for member in self._curr_memberlist:
if member.id == assignment:
if member.ip is not None and member.ip != "":
endpoint = f"{member.ip}:50051"
return endpoint

endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable
return endpoint

@override
def register_updated_segment_callback(
Expand All @@ -263,7 +285,9 @@ def register_updated_segment_callback(
)
def _update_memberlist(self, memberlist: Memberlist) -> None:
with self._curr_memberlist_mutex:
add_attributes_to_current_span({"new_memberlist": memberlist})
add_attributes_to_current_span(
{"new_memberlist": [m.id for m in memberlist]}
)
self._curr_memberlist = memberlist

def extract_service_name(self, pod_name: str) -> Optional[str]:
Expand Down
13 changes: 8 additions & 5 deletions chromadb/segment/impl/manager/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,29 @@
from chromadb.segment.distributed import SegmentDirectory
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import Collection, CollectionAndSegments, Operation, Segment, SegmentScope
from chromadb.types import (
Collection,
Operation,
Segment,
SegmentScope,
)


class DistributedSegmentManager(SegmentManager):
_sysdb: SysDB
_system: System
_opentelemetry_client: OpenTelemetryClient
_instances: Dict[UUID, SegmentImplementation]
_segment_directory: SegmentDirectory
_lock: Lock
# _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub

def __init__(self, system: System):
super().__init__(system)
self._sysdb = self.require(SysDB)
self._segment_directory = self.require(SegmentDirectory)
self._system = system
self._opentelemetry_client = system.require(OpenTelemetryClient)
self._instances = {}
self._lock = Lock()

Expand Down Expand Up @@ -77,6 +78,8 @@ def prepare_segments_for_new_collection(

@override
def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
# TODO: this should be a pass, delete_collection is expected to delete segments in
# distributed
segments = self._sysdb.get_segments(collection=collection_id)
return [s["id"] for s in segments]

Expand Down
74 changes: 74 additions & 0 deletions chromadb/test/distributed/test_reroute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Sequence
from chromadb.test.conftest import (
reset,
skip_if_not_cluster,
)
from chromadb.api import ClientAPI
from kubernetes import client as k8s_client, config
import time


@skip_if_not_cluster()
def test_reroute(
    client: ClientAPI,
) -> None:
    """Check that queries still succeed after all query-service pods are replaced.

    The test adds a few embeddings to a fresh collection, issues a query,
    deletes every pod matching the query-service label selector, waits for the
    same number of pods with entirely new UIDs to appear and reach the
    ``Running`` phase, and then issues the same query again.

    Raises:
        AssertionError: if no query-service pods are found, or if either
            waiting phase exceeds its timeout.
    """
    namespace = "chroma"
    label_selector = "app=query-service"
    # Each waiting phase below gets its own budget of this many seconds.
    timeout_secs = 10

    reset(client)
    collection = client.create_collection(
        name="test",
        metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128},
    )

    ids = [str(i) for i in range(10)]
    embeddings: list[Sequence[float]] = [
        [float(i), float(i), float(i)] for i in range(10)
    ]
    collection.add(ids=ids, embeddings=embeddings)
    # Query once before the restart so the pre-restart path is exercised.
    collection.query(query_embeddings=[embeddings[0]])

    # Restart the query service using the k8s api, in order to trigger a
    # reroute of the query service.
    config.load_kube_config()
    v1 = k8s_client.CoreV1Api()
    # Find all pods with the label "app=query-service"
    pods = v1.list_namespaced_pod(namespace, label_selector=label_selector).items
    assert len(pods) > 0

    # Remember the UIDs of the current pods, then restart them by deletion.
    old_uids = {pod.metadata.uid for pod in pods}
    for pod in pods:
        v1.delete_namespaced_pod(pod.metadata.name, pod.metadata.namespace)

    # Wait until we have len(old_uids) pods, none of which has an old UID.
    deadline = time.monotonic() + timeout_secs
    while True:
        pods = v1.list_namespaced_pod(namespace, label_selector=label_selector).items
        new_uids = {pod.metadata.uid for pod in pods}
        if len(new_uids) == len(old_uids) and new_uids.isdisjoint(old_uids):
            break
        if time.monotonic() > deadline:
            # Raise explicitly: a bare `assert False` is stripped under -O.
            raise AssertionError("Timed out waiting for new pods to start")
        time.sleep(1)

    # Wait for the query service pods to be running, or timeout. The deadline
    # is restarted here so this phase does not inherit whatever time the
    # previous wait already consumed.
    # NOTE: pod phase "Running" is used as the readiness signal.
    deadline = time.monotonic() + timeout_secs
    while True:
        pods = v1.list_namespaced_pod(namespace, label_selector=label_selector).items
        if all(pod.status.phase == "Running" for pod in pods):
            break
        if time.monotonic() > deadline:
            raise AssertionError("Timed out waiting for new pods to be ready")
        time.sleep(1)

    # Brief pause before re-querying; presumably gives the frontend time to
    # pick up the updated memberlist -- TODO confirm.
    time.sleep(1)
    collection.query(query_embeddings=[embeddings[0]])
13 changes: 8 additions & 5 deletions chromadb/test/segment/distributed/test_memberlist_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Tests the CustomResourceMemberlist provider
from dataclasses import asdict
import threading
from chromadb.test.conftest import skip_if_not_cluster
from kubernetes import client, config
from chromadb.config import System, Settings
from chromadb.segment.distributed import Memberlist
from chromadb.segment.distributed import Memberlist, Member
from chromadb.segment.impl.distributed.segment_directory import (
CustomResourceMemberlistProvider,
KUBERNETES_GROUP,
Expand All @@ -17,12 +18,12 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe
config.load_config()
api_instance = client.CustomObjectsApi()

members = [{"member_id": f"test-{i}"} for i in range(1, n + 1)]
members = [Member(id=f"test-{i}", ip=f"10.0.0.{i}") for i in range(1, n + 1)]

body = {
"kind": "MemberList",
"metadata": {"name": memberlist_name},
"spec": {"members": members},
"spec": {"members": [{"member_id": m.id, "member_ip": m.ip} for m in members]},
}

_ = api_instance.patch_namespaced_custom_object(
Expand All @@ -34,11 +35,13 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe
body=body,
)

return [m["member_id"] for m in members]
return members


def compare_memberlists(m1: Memberlist, m2: Memberlist) -> bool:
return sorted(m1) == sorted(m2)
m1_as_dict = sorted([asdict(m) for m in m1], key=lambda x: x["id"])
m2_as_dict = sorted([asdict(m) for m in m2], key=lambda x: x["id"])
return m1_as_dict == m2_as_dict


@skip_if_not_cluster()
Expand Down
10 changes: 10 additions & 0 deletions go/pkg/memberlist_manager/memberlist_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ func memberlistSame(oldMemberlist Memberlist, newMemberlist Memberlist) bool {
if len(oldMemberlist) != len(newMemberlist) {
return false
}
oldMemberlistIps := make(map[string]string)
for _, member := range oldMemberlist {
oldMemberlistIps[member.id] = member.ip
}
for _, member := range newMemberlist {
if ip, ok := oldMemberlistIps[member.id]; !ok || ip != member.ip {
return false
}
}

// use a map to check if the new memberlist contains all the old members
newMemberlistMap := make(map[string]bool)
for _, member := range newMemberlist {
Expand Down
26 changes: 13 additions & 13 deletions go/pkg/memberlist_manager/memberlist_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestNodeWatcher(t *testing.T) {
t.Fatalf("Error getting node status: %v", err)
}

return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}})
return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}})
}, 10, 1*time.Second)
if !ok {
t.Fatalf("Node status did not update after adding a pod")
Expand Down Expand Up @@ -83,7 +83,7 @@ func TestNodeWatcher(t *testing.T) {
if err != nil {
t.Fatalf("Error getting node status: %v", err)
}
return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}})
return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}})
}, 10, 1*time.Second)
if !ok {
t.Fatalf("Node status did not update after adding a not ready pod")
Expand All @@ -108,13 +108,13 @@ func TestMemberlistStore(t *testing.T) {
assert.Equal(t, Memberlist{}, memberlist)

// Add a member to the memberlist
memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}, "0")
memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}, "0")
memberlist, _, err = memberlist_store.GetMemberlist(context.Background())
if err != nil {
t.Fatalf("Error getting memberlist: %v", err)
}
// assert the memberlist has the correct members
if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}) {
if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}) {
t.Fatalf("Memberlist did not update after adding a member")
}
}
Expand Down Expand Up @@ -184,7 +184,7 @@ func TestMemberlistManager(t *testing.T) {

// Get the memberlist
ok := retryUntilCondition(func() bool {
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}})
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}})
}, 30, 1*time.Second)
if !ok {
t.Fatalf("Memberlist did not update after adding a pod")
Expand All @@ -195,7 +195,7 @@ func TestMemberlistManager(t *testing.T) {

// Get the memberlist
ok = retryUntilCondition(func() bool {
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}})
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}, Member{id: "test-pod-1", ip: "10.0.0.50"}})
}, 30, 1*time.Second)
if !ok {
t.Fatalf("Memberlist did not update after adding a pod")
Expand All @@ -206,7 +206,7 @@ func TestMemberlistManager(t *testing.T) {

// Get the memberlist
ok = retryUntilCondition(func() bool {
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1"}})
return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1", ip: "10.0.0.50"}})
}, 30, 1*time.Second)
if !ok {
t.Fatalf("Memberlist did not update after deleting a pod")
Expand All @@ -217,23 +217,23 @@ func TestMemberlistSame(t *testing.T) {
memberlist := Memberlist{}
assert.True(t, memberlistSame(memberlist, memberlist))

newMemberlist := Memberlist{Member{id: "test-pod-0"}}
newMemberlist := Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}}
assert.False(t, memberlistSame(memberlist, newMemberlist))
assert.False(t, memberlistSame(newMemberlist, memberlist))
assert.True(t, memberlistSame(newMemberlist, newMemberlist))

memberlist = Memberlist{Member{id: "test-pod-1"}}
memberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}}
assert.False(t, memberlistSame(newMemberlist, memberlist))
assert.False(t, memberlistSame(memberlist, newMemberlist))
assert.True(t, memberlistSame(memberlist, memberlist))

memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
newMemberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
newMemberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
assert.True(t, memberlistSame(memberlist, newMemberlist))
assert.True(t, memberlistSame(newMemberlist, memberlist))

memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
newMemberlist = Memberlist{Member{id: "test-pod-1"}, Member{id: "test-pod-0"}}
memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
newMemberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}, Member{id: "test-pod-0", ip: "10.0.0.1"}}
assert.True(t, memberlistSame(memberlist, newMemberlist))
assert.True(t, memberlistSame(newMemberlist, memberlist))
}
Expand Down
Loading

0 comments on commit 74146be

Please sign in to comment.