diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py
index 3cf5c591c77..6bf795540c1 100644
--- a/chromadb/execution/executor/distributed.py
+++ b/chromadb/execution/executor/distributed.py
@@ -1,16 +1,12 @@
 from typing import Dict, Optional
-
 import grpc
 from overrides import overrides
-
 from chromadb.api.types import GetResult, Metadata, QueryResult
 from chromadb.config import System
-from chromadb.errors import VersionMismatchError
 from chromadb.execution.executor.abstract import Executor
 from chromadb.execution.expression.operator import Scan
 from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
 from chromadb.proto import convert
-
 from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
 from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
 from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
@@ -170,6 +166,6 @@ def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub:
         channel = grpc.insecure_channel(grpc_url)
         interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
         channel = grpc.intercept_channel(channel, *interceptors)
-        self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)  # type: ignore[no-untyped-call]
+        self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)

         return self._grpc_stub_pool[grpc_url]
diff --git a/chromadb/segment/distributed/__init__.py b/chromadb/segment/distributed/__init__.py
index 08efdafd18c..049c3b54b62 100644
--- a/chromadb/segment/distributed/__init__.py
+++ b/chromadb/segment/distributed/__init__.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from dataclasses import dataclass
 from typing import Any, Callable, List

 from overrides import EnforceOverrides, overrides
@@ -22,7 +23,13 @@ def register_updated_segment_callback(
         pass


-Memberlist = List[str]
+@dataclass
+class Member:
+    id: str
+    ip: str
+
+
+Memberlist = List[Member]


 class MemberlistProvider(Component, EnforceOverrides):
diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py
index 12a6b35fa7d..097b3b38566 100644
--- a/chromadb/segment/impl/distributed/segment_directory.py
+++ b/chromadb/segment/impl/distributed/segment_directory.py
@@ -8,6 +8,7 @@
 from chromadb.config import System
 from chromadb.segment.distributed import (
+    Member,
     Memberlist,
     MemberlistProvider,
     SegmentDirectory,
@@ -35,7 +36,11 @@ class MockMemberlistProvider(MemberlistProvider, EnforceOverrides):

     def __init__(self, system: System):
         super().__init__(system)
-        self._memberlist = ["a", "b", "c"]
+        self._memberlist = [
+            Member(id="a", ip="10.0.0.1"),
+            Member(id="b", ip="10.0.0.2"),
+            Member(id="c", ip="10.0.0.3"),
+        ]

     @override
     def get_memberlist(self) -> Memberlist:
@@ -203,7 +208,12 @@ def _parse_response_memberlist(
     ) -> Memberlist:
         if "members" not in api_response_spec:
             return []
-        return [m["member_id"] for m in api_response_spec["members"]]
+        parsed = []
+        for m in api_response_spec["members"]:
+            id = m["member_id"]
+            ip = m["member_ip"] if "member_ip" in m else ""
+            parsed.append(Member(id=id, ip=ip))
+        return parsed

     def _notify(self, memberlist: Memberlist) -> None:
         for callback in self.callbacks:
@@ -245,11 +255,23 @@ def get_segment_endpoint(self, segment: Segment) -> str:
             raise ValueError("Memberlist is not initialized")
         # Query to the same collection should end up on the same endpoint
         assignment = assign(
-            segment["collection"].hex, self._curr_memberlist, murmur3hasher, 1
+            segment["collection"].hex,
+            [m.id for m in self._curr_memberlist],
+            murmur3hasher,
+            1,
         )[0]
         service_name = self.extract_service_name(assignment)
-        assignment = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"  # TODO: make port configurable
-        return assignment
+
+        # If the member has an IP, use it; otherwise fall back to the member id
+        # plus the headless service. This is for backwards compatibility with
+        # the old memberlist, which only had ids.
+        for member in self._curr_memberlist:
+            if member.id == assignment:
+                if member.ip is not None and member.ip != "":
+                    endpoint = f"{member.ip}:50051"
+                    return endpoint
+
+        endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051"  # TODO: make port configurable
+        return endpoint

     @override
     def register_updated_segment_callback(
@@ -263,7 +285,9 @@ def register_updated_segment_callback(
         )

     def _update_memberlist(self, memberlist: Memberlist) -> None:
         with self._curr_memberlist_mutex:
-            add_attributes_to_current_span({"new_memberlist": memberlist})
+            add_attributes_to_current_span(
+                {"new_memberlist": [m.id for m in memberlist]}
+            )
             self._curr_memberlist = memberlist

     def extract_service_name(self, pod_name: str) -> Optional[str]:
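To make the new routing behavior concrete, here is a minimal, self-contained sketch of the fallback that `get_segment_endpoint` implements above. This is an illustration only: the rendezvous `assign` call is stubbed with a plain hash, the DNS suffix constants are placeholders rather than chromadb's real `KUBERNETES_NAMESPACE`/`HEADLESS_SERVICE` values, and `resolve_endpoint` is an invented name.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Member:
    id: str
    ip: str

def assign(key: str, member_ids: List[str]) -> str:
    # Stand-in for chromadb's rendezvous hashing: a deterministic pick per key.
    return sorted(member_ids, key=lambda m: hash((key, m)))[0]

def resolve_endpoint(collection_hex: str, memberlist: List[Member]) -> str:
    # Hashing still runs over member *ids*, so a collection maps to the same
    # member whether or not the memberlist carries IPs.
    assignment = assign(collection_hex, [m.id for m in memberlist])
    for member in memberlist:
        if member.id == assignment and member.ip:
            return f"{member.ip}:50051"  # new path: dial the pod IP directly
    # Old path, kept for backwards compatibility: headless-service DNS name.
    return f"{assignment}.query-service.chroma.svc.cluster.local:50051"

members = [Member(id="query-0", ip="10.0.0.1"), Member(id="query-1", ip="")]
print(resolve_endpoint("abc123", members))
```

Because only ids feed the hash, rolling out IP-bearing memberlists does not reshuffle existing collection-to-member assignments.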
segment["collection"].hex, + [m.id for m in self._curr_memberlist], + murmur3hasher, + 1, )[0] service_name = self.extract_service_name(assignment) - assignment = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable - return assignment + + # If the memberlist has an ip, use it, otherwise use the member id with the headless service + # this is for backwards compatibility with the old memberlist which only had ids + for member in self._curr_memberlist: + if member.id == assignment: + if member.ip is not None and member.ip != "": + endpoint = f"{member.ip}:50051" + return endpoint + + endpoint = f"{assignment}.{service_name}.{KUBERNETES_NAMESPACE}.{HEADLESS_SERVICE}:50051" # TODO: make port configurable + return endpoint @override def register_updated_segment_callback( @@ -263,7 +285,9 @@ def register_updated_segment_callback( ) def _update_memberlist(self, memberlist: Memberlist) -> None: with self._curr_memberlist_mutex: - add_attributes_to_current_span({"new_memberlist": memberlist}) + add_attributes_to_current_span( + {"new_memberlist": [m.id for m in memberlist]} + ) self._curr_memberlist = memberlist def extract_service_name(self, pod_name: str) -> Optional[str]: diff --git a/chromadb/segment/impl/manager/distributed.py b/chromadb/segment/impl/manager/distributed.py index 7acdaa4d860..4367ab1c44e 100644 --- a/chromadb/segment/impl/manager/distributed.py +++ b/chromadb/segment/impl/manager/distributed.py @@ -14,28 +14,29 @@ from chromadb.segment.distributed import SegmentDirectory from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams from chromadb.telemetry.opentelemetry import ( - OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) -from chromadb.types import Collection, CollectionAndSegments, Operation, Segment, SegmentScope +from chromadb.types import ( + Collection, + Operation, + Segment, + SegmentScope, +) class DistributedSegmentManager(SegmentManager): _sysdb: SysDB _system: System - _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _segment_directory: SegmentDirectory _lock: Lock - # _segment_server_stubs: Dict[str, SegmentServerStub] # grpc_url -> grpc stub def __init__(self, system: System): super().__init__(system) self._sysdb = self.require(SysDB) self._segment_directory = self.require(SegmentDirectory) self._system = system - self._opentelemetry_client = system.require(OpenTelemetryClient) self._instances = {} self._lock = Lock() @@ -77,6 +78,8 @@ def prepare_segments_for_new_collection( @override def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: + # TODO: this should be a pass, delete_collection is expected to delete segments in + # distributed segments = self._sysdb.get_segments(collection=collection_id) return [s["id"] for s in segments] diff --git a/chromadb/test/distributed/test_reroute.py b/chromadb/test/distributed/test_reroute.py new file mode 100644 index 00000000000..824664aded0 --- /dev/null +++ b/chromadb/test/distributed/test_reroute.py @@ -0,0 +1,74 @@ +from typing import Sequence +from chromadb.test.conftest import ( + reset, + skip_if_not_cluster, +) +from chromadb.api import ClientAPI +from kubernetes import client as k8s_client, config +import time + + +@skip_if_not_cluster() +def test_reroute( + client: ClientAPI, +) -> None: + reset(client) + collection = client.create_collection( + name="test", + metadata={"hnsw:construction_ef": 128, "hnsw:search_ef": 128, "hnsw:M": 128}, + ) + + ids = [str(i) for i in 
diff --git a/chromadb/test/segment/distributed/test_memberlist_provider.py b/chromadb/test/segment/distributed/test_memberlist_provider.py
index c97bcbd06cb..0422d84431a 100644
--- a/chromadb/test/segment/distributed/test_memberlist_provider.py
+++ b/chromadb/test/segment/distributed/test_memberlist_provider.py
@@ -1,9 +1,10 @@
 # Tests the CustomResourceMemberlist provider
+from dataclasses import asdict
 import threading
 from chromadb.test.conftest import skip_if_not_cluster
 from kubernetes import client, config
 from chromadb.config import System, Settings
-from chromadb.segment.distributed import Memberlist
+from chromadb.segment.distributed import Memberlist, Member
 from chromadb.segment.impl.distributed.segment_directory import (
     CustomResourceMemberlistProvider,
     KUBERNETES_GROUP,
@@ -17,12 +18,12 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe
     config.load_config()
     api_instance = client.CustomObjectsApi()

-    members = [{"member_id": f"test-{i}"} for i in range(1, n + 1)]
+    members = [Member(id=f"test-{i}", ip=f"10.0.0.{i}") for i in range(1, n + 1)]

     body = {
         "kind": "MemberList",
         "metadata": {"name": memberlist_name},
-        "spec": {"members": members},
+        "spec": {"members": [{"member_id": m.id, "member_ip": m.ip} for m in members]},
     }

     _ = api_instance.patch_namespaced_custom_object(
@@ -34,11 +35,13 @@ def update_memberlist(n: int, memberlist_name: str = "test-memberlist") -> Membe
         body=body,
     )

-    return [m["member_id"] for m in members]
+    return members


 def compare_memberlists(m1: Memberlist, m2: Memberlist) -> bool:
-    return sorted(m1) == sorted(m2)
+    m1_as_dict = sorted([asdict(m) for m in m1], key=lambda x: x["id"])
+    m2_as_dict = sorted([asdict(m) for m in m2], key=lambda x: x["id"])
+    return m1_as_dict == m2_as_dict


 @skip_if_not_cluster()
diff --git a/go/pkg/memberlist_manager/memberlist_manager.go b/go/pkg/memberlist_manager/memberlist_manager.go
index 24a54e5c8c8..588fceeaa26 100644
--- a/go/pkg/memberlist_manager/memberlist_manager.go
+++ b/go/pkg/memberlist_manager/memberlist_manager.go
@@ -128,6 +128,16 @@ func memberlistSame(oldMemberlist Memberlist, newMemberlist Memberlist) bool {
     if len(oldMemberlist) != len(newMemberlist) {
         return false
     }
+    oldMemberlistIps := make(map[string]string)
+    for _, member := range oldMemberlist {
+        oldMemberlistIps[member.id] = member.ip
+    }
+    for _, member := range newMemberlist {
+        if ip, ok := oldMemberlistIps[member.id]; !ok || ip != member.ip {
+            return false
+        }
+    }
+
     // use a map to check if the new memberlist contains all the old members
     newMemberlistMap := make(map[string]bool)
     for _, member := range newMemberlist {
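The Go change above tightens memberlist equality: two lists are now the same only if they agree on every (id, ip) pair, order-insensitively, so a pod keeping its name but coming back with a new IP counts as a membership change and gets propagated. A Python rendition of those semantics, for illustration only (the authoritative check is the Go `memberlistSame`; it reuses the `Member` dataclass introduced in chromadb/segment/distributed/__init__.py above):

```python
from chromadb.segment.distributed import Member

def memberlist_same(old: list[Member], new: list[Member]) -> bool:
    # Compare as sets of (id, ip) pairs; ordering is irrelevant.
    return {(m.id, m.ip) for m in old} == {(m.id, m.ip) for m in new}

a = [Member(id="pod-0", ip="10.0.0.1"), Member(id="pod-1", ip="10.0.0.2")]
b = [Member(id="pod-1", ip="10.0.0.2"), Member(id="pod-0", ip="10.0.0.1")]
assert memberlist_same(a, b)  # reordering is not a change
# An IP change alone now invalidates sameness.
assert not memberlist_same(a, [Member(id="pod-0", ip="10.0.0.9"),
                               Member(id="pod-1", ip="10.0.0.2")])
```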
diff --git a/go/pkg/memberlist_manager/memberlist_manager_test.go b/go/pkg/memberlist_manager/memberlist_manager_test.go
index 9fb9ff1e172..ccea2ee71aa 100644
--- a/go/pkg/memberlist_manager/memberlist_manager_test.go
+++ b/go/pkg/memberlist_manager/memberlist_manager_test.go
@@ -52,7 +52,7 @@ func TestNodeWatcher(t *testing.T) {
             t.Fatalf("Error getting node status: %v", err)
         }

-        return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}})
+        return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}})
     }, 10, 1*time.Second)
     if !ok {
         t.Fatalf("Node status did not update after adding a pod")
@@ -83,7 +83,7 @@ func TestNodeWatcher(t *testing.T) {
         if err != nil {
             t.Fatalf("Error getting node status: %v", err)
         }
-        return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0"}})
+        return reflect.DeepEqual(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}})
     }, 10, 1*time.Second)
     if !ok {
         t.Fatalf("Node status did not update after adding a not ready pod")
@@ -108,13 +108,13 @@ func TestMemberlistStore(t *testing.T) {
     assert.Equal(t, Memberlist{}, memberlist)

     // Add a member to the memberlist
-    memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}, "0")
+    memberlist_store.UpdateMemberlist(context.Background(), Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}, "0")
     memberlist, _, err = memberlist_store.GetMemberlist(context.Background())
     if err != nil {
         t.Fatalf("Error getting memberlist: %v", err)
     }
     // assert the memberlist has the correct members
-    if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}) {
+    if !memberlistSame(memberlist, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}) {
         t.Fatalf("Memberlist did not update after adding a member")
     }
 }
@@ -184,7 +184,7 @@ func TestMemberlistManager(t *testing.T) {

     // Get the memberlist
     ok := retryUntilCondition(func() bool {
-        return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}})
+        return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}})
     }, 30, 1*time.Second)
     if !ok {
         t.Fatalf("Memberlist did not update after adding a pod")
@@ -195,7 +195,7 @@ func TestMemberlistManager(t *testing.T) {

     // Get the memberlist
     ok = retryUntilCondition(func() bool {
-        return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}})
+        return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-0", ip: "10.0.0.49"}, Member{id: "test-pod-1", ip: "10.0.0.50"}})
     }, 30, 1*time.Second)
     if !ok {
         t.Fatalf("Memberlist did not update after adding a pod")
@@ -206,7 +206,7 @@ func TestMemberlistManager(t *testing.T) {

     // Get the memberlist
     ok = retryUntilCondition(func() bool {
-        return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1"}})
+        return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1", ip: "10.0.0.50"}})
     }, 30, 1*time.Second)
     if !ok {
         t.Fatalf("Memberlist did not update after deleting a pod")
@@ -217,23 +217,23 @@ func TestMemberlistSame(t *testing.T) {
     memberlist := Memberlist{}
     assert.True(t, memberlistSame(memberlist, memberlist))

-    newMemberlist := Memberlist{Member{id: "test-pod-0"}}
+    newMemberlist := Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}}
     assert.False(t, memberlistSame(memberlist, newMemberlist))
     assert.False(t, memberlistSame(newMemberlist, memberlist))
     assert.True(t, memberlistSame(newMemberlist, newMemberlist))

-    memberlist = Memberlist{Member{id: "test-pod-1"}}
+    memberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}}
     assert.False(t, memberlistSame(newMemberlist, memberlist))
     assert.False(t, memberlistSame(memberlist, newMemberlist))
     assert.True(t, memberlistSame(memberlist, memberlist))

-    memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
-    newMemberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
+    memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
+    newMemberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
     assert.True(t, memberlistSame(memberlist, newMemberlist))
     assert.True(t, memberlistSame(newMemberlist, memberlist))

-    memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}}
-    newMemberlist = Memberlist{Member{id: "test-pod-1"}, Member{id: "test-pod-0"}}
+    memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}}
+    newMemberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}, Member{id: "test-pod-0", ip: "10.0.0.1"}}
     assert.True(t, memberlistSame(memberlist, newMemberlist))
     assert.True(t, memberlistSame(newMemberlist, memberlist))
 }
"10.0.0.50"}}) }, 30, 1*time.Second) if !ok { t.Fatalf("Memberlist did not update after adding a pod") @@ -206,7 +206,7 @@ func TestMemberlistManager(t *testing.T) { // Get the memberlist ok = retryUntilCondition(func() bool { - return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1"}}) + return getMemberlistAndCompare(t, memberlistStore, Memberlist{Member{id: "test-pod-1", ip: "10.0.0.50"}}) }, 30, 1*time.Second) if !ok { t.Fatalf("Memberlist did not update after deleting a pod") @@ -217,23 +217,23 @@ func TestMemberlistSame(t *testing.T) { memberlist := Memberlist{} assert.True(t, memberlistSame(memberlist, memberlist)) - newMemberlist := Memberlist{Member{id: "test-pod-0"}} + newMemberlist := Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}} assert.False(t, memberlistSame(memberlist, newMemberlist)) assert.False(t, memberlistSame(newMemberlist, memberlist)) assert.True(t, memberlistSame(newMemberlist, newMemberlist)) - memberlist = Memberlist{Member{id: "test-pod-1"}} + memberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}} assert.False(t, memberlistSame(newMemberlist, memberlist)) assert.False(t, memberlistSame(memberlist, newMemberlist)) assert.True(t, memberlistSame(memberlist, memberlist)) - memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}} - newMemberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}} + memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}} + newMemberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}} assert.True(t, memberlistSame(memberlist, newMemberlist)) assert.True(t, memberlistSame(newMemberlist, memberlist)) - memberlist = Memberlist{Member{id: "test-pod-0"}, Member{id: "test-pod-1"}} - newMemberlist = Memberlist{Member{id: "test-pod-1"}, Member{id: "test-pod-0"}} + memberlist = Memberlist{Member{id: "test-pod-0", ip: "10.0.0.1"}, Member{id: "test-pod-1", ip: "10.0.0.2"}} + newMemberlist = Memberlist{Member{id: "test-pod-1", ip: "10.0.0.2"}, Member{id: "test-pod-0", ip: "10.0.0.1"}} assert.True(t, memberlistSame(memberlist, newMemberlist)) assert.True(t, memberlistSame(newMemberlist, memberlist)) } diff --git a/go/pkg/memberlist_manager/memberlist_store.go b/go/pkg/memberlist_manager/memberlist_store.go index 42a2efe4261..d7046205431 100644 --- a/go/pkg/memberlist_manager/memberlist_store.go +++ b/go/pkg/memberlist_manager/memberlist_store.go @@ -20,11 +20,13 @@ type IMemberlistStore interface { type Member struct { id string + ip string } // MarshalLogObject implements the zapcore.ObjectMarshaler interface func (m Member) MarshalLogObject(enc zapcore.ObjectEncoder) error { enc.AddString("id", m.id) + enc.AddString("ip", m.ip) return nil } @@ -80,7 +82,14 @@ func (s *CRMemberlistStore) GetMemberlist(ctx context.Context) (return_memberlis if !ok { return nil, "", errors.New("failed to cast member_id to string") } - memberlist = append(memberlist, Member{member_id}) + // If member_ip is in the CR, extract it, otherwise set it to empty string + // This is for backwards compatibility with older CRs that don't have member_ip + member_ip, ok := member_map["member_ip"].(string) + if !ok { + member_ip = "" + } + + memberlist = append(memberlist, Member{member_id, member_ip}) } return memberlist, unstrucuted.GetResourceVersion(), nil } @@ -107,6 +116,7 @@ func (list Memberlist) toCr(namespace string, memberlistName string, resourceVer for i, member := range list { members[i] = 
diff --git a/go/pkg/memberlist_manager/node_watcher.go b/go/pkg/memberlist_manager/node_watcher.go
index 4351255da95..a79b73d59f6 100644
--- a/go/pkg/memberlist_manager/node_watcher.go
+++ b/go/pkg/memberlist_manager/node_watcher.go
@@ -165,7 +165,7 @@ func (w *KubernetesWatcher) ListReadyMembers() (Memberlist, error) {
         for _, condition := range pod.Status.Conditions {
             if condition.Type == v1.PodReady {
                 if condition.Status == v1.ConditionTrue {
-                    memberlist = append(memberlist, Member{pod.Name})
+                    memberlist = append(memberlist, Member{pod.Name, pod.Status.PodIP})
                 }
                 break
             }
diff --git a/k8s/distributed-chroma/Chart.yaml b/k8s/distributed-chroma/Chart.yaml
index 72d420042f6..ab51db46b14 100644
--- a/k8s/distributed-chroma/Chart.yaml
+++ b/k8s/distributed-chroma/Chart.yaml
@@ -16,7 +16,7 @@ apiVersion: v2
 name: distributed-chroma
 description: A helm chart for distributed Chroma
 type: application
-version: 0.1.12
+version: 0.1.13
 appVersion: "0.4.24"
 keywords:
   - chroma
diff --git a/k8s/distributed-chroma/crds/memberlist_crd.yaml b/k8s/distributed-chroma/crds/memberlist_crd.yaml
index 9cde59ab468..51e34426db6 100644
--- a/k8s/distributed-chroma/crds/memberlist_crd.yaml
+++ b/k8s/distributed-chroma/crds/memberlist_crd.yaml
@@ -27,6 +27,8 @@ spec:
               properties:
                 member_id:
                   type: string
+                member_ip:
+                  type: string
       scope: Namespaced
       names:
         plural: memberlists
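With the CRD change in place, a MemberList custom resource can carry both fields per entry. Written as the Python dict that the tests above patch into the cluster (the member values here are illustrative), the new shape is:

```python
body = {
    "kind": "MemberList",
    "metadata": {"name": "test-memberlist"},
    "spec": {
        "members": [
            {"member_id": "query-0", "member_ip": "10.0.0.1"},
            # Older writers may omit member_ip; readers fall back to "".
            {"member_id": "query-1"},
        ]
    },
}
```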