diff --git a/channels/layers.py b/channels/layers.py index 12bbd2b8..48f7baca 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -198,13 +198,13 @@ def __init__( group_expiry=86400, capacity=100, channel_capacity=None, - **kwargs + **kwargs, ): super().__init__( expiry=expiry, capacity=capacity, channel_capacity=channel_capacity, - **kwargs + **kwargs, ) self.channels = {} self.groups = {} @@ -225,13 +225,14 @@ async def send(self, channel, message): # name in message assert "__asgi_channel__" not in message - queue = self.channels.setdefault(channel, asyncio.Queue()) - # Are we full - if queue.qsize() >= self.capacity: - raise ChannelFull(channel) - + queue = self.channels.setdefault( + channel, asyncio.Queue(maxsize=self.get_capacity(channel)) + ) # Add message - await queue.put((time.time() + self.expiry, deepcopy(message))) + try: + queue.put_nowait((time.time() + self.expiry, deepcopy(message))) + except asyncio.queues.QueueFull: + raise ChannelFull(channel) async def receive(self, channel): """ @@ -242,14 +243,16 @@ async def receive(self, channel): assert self.valid_channel_name(channel) self._clean_expired() - queue = self.channels.setdefault(channel, asyncio.Queue()) + queue = self.channels.setdefault( + channel, asyncio.Queue(maxsize=self.get_capacity(channel)) + ) # Do a plain direct receive try: _, message = await queue.get() finally: if queue.empty(): - del self.channels[channel] + self.channels.pop(channel, None) return message @@ -279,19 +282,17 @@ def _clean_expired(self): self._remove_from_groups(channel) # Is the channel now empty and needs deleting? if queue.empty(): - del self.channels[channel] + self.channels.pop(channel, None) # Group Expiration timeout = int(time.time()) - self.group_expiry - for group in self.groups: - for channel in list(self.groups.get(group, set())): - # If join time is older than group_expiry end the group membership - if ( - self.groups[group][channel] - and int(self.groups[group][channel]) < timeout - ): + for channels in self.groups.values(): + for name, timestamp in list(channels.items()): + # If join time is older than group_expiry + # end the group membership + if timestamp and timestamp < timeout: # Delete from group - del self.groups[group][channel] + channels.pop(name, None) # Flush extension @@ -308,8 +309,7 @@ def _remove_from_groups(self, channel): Removes a channel from all groups. Used when a message on it expires. """ for channels in self.groups.values(): - if channel in channels: - del channels[channel] + channels.pop(channel, None) # Groups extension @@ -329,11 +329,13 @@ async def group_discard(self, group, channel): assert self.valid_channel_name(channel), "Invalid channel name" assert self.valid_group_name(group), "Invalid group name" # Remove from group set - if group in self.groups: - if channel in self.groups[group]: - del self.groups[group][channel] - if not self.groups[group]: - del self.groups[group] + group_channels = self.groups.get(group, None) + if group_channels: + # remove channel if in group + group_channels.pop(channel, None) + # is group now empty? If yes remove it + if not group_channels: + self.groups.pop(group, None) async def group_send(self, group, message): # Check types @@ -341,10 +343,15 @@ async def group_send(self, group, message): assert self.valid_group_name(group), "Invalid group name" # Run clean self._clean_expired() + # Send to each channel - for channel in self.groups.get(group, set()): + ops = [] + if group in self.groups: + for channel in self.groups[group].keys(): + ops.append(asyncio.create_task(self.send(channel, message))) + for send_result in asyncio.as_completed(ops): try: - await self.send(channel, message) + await send_result except ChannelFull: pass diff --git a/tests/test_inmemorychannel.py b/tests/test_inmemorychannel.py index 3f05ed7e..4ba4bfab 100644 --- a/tests/test_inmemorychannel.py +++ b/tests/test_inmemorychannel.py @@ -26,9 +26,36 @@ async def test_send_receive(channel_layer): await channel_layer.send( "test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"} ) + await channel_layer.send( + "test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"} + ) message = await channel_layer.receive("test-channel-1") assert message["type"] == "test.message" assert message["text"] == "Ahoy-hoy!" + # not removed because not empty + assert "test-channel-1" in channel_layer.channels + message = await channel_layer.receive("test-channel-1") + assert message["type"] == "test.message" + assert message["text"] == "Ahoy-hoy!" + # removed because empty + assert "test-channel-1" not in channel_layer.channels + + +@pytest.mark.asyncio +async def test_race_empty(channel_layer): + """ + Makes sure the race is handled gracefully. + """ + receive_task = asyncio.create_task(channel_layer.receive("test-channel-1")) + await asyncio.sleep(0.1) + await channel_layer.send( + "test-channel-1", {"type": "test.message", "text": "Ahoy-hoy!"} + ) + del channel_layer.channels["test-channel-1"] + await asyncio.sleep(0.1) + message = await receive_task + assert message["type"] == "test.message" + assert message["text"] == "Ahoy-hoy!" @pytest.mark.asyncio @@ -62,7 +89,6 @@ async def test_multi_send_receive(channel_layer): """ Tests overlapping sends and receives, and ordering. """ - channel_layer = InMemoryChannelLayer() await channel_layer.send("test-channel-3", {"type": "message.1"}) await channel_layer.send("test-channel-3", {"type": "message.2"}) await channel_layer.send("test-channel-3", {"type": "message.3"}) @@ -76,7 +102,6 @@ async def test_groups_basic(channel_layer): """ Tests basic group operation. """ - channel_layer = InMemoryChannelLayer() await channel_layer.group_add("test-group", "test-gr-chan-1") await channel_layer.group_add("test-group", "test-gr-chan-2") await channel_layer.group_add("test-group", "test-gr-chan-3") @@ -97,7 +122,6 @@ async def test_groups_channel_full(channel_layer): """ Tests that group_send ignores ChannelFull """ - channel_layer = InMemoryChannelLayer() await channel_layer.group_add("test-group", "test-gr-chan-1") await channel_layer.group_send("test-group", {"type": "message.1"}) await channel_layer.group_send("test-group", {"type": "message.1"})