Skip to content

Commit

Permalink
Fix track matching regression (#5571)
Browse files Browse the repository at this point in the history
## Problem
A regression was introduced when adjusting the track matching logic to
use `lapjv` instead of `munkres`. The `lapjv` algorithm returns `-1` for
unmatched items, which wasn't being handled correctly in the matching
logic. This caused incorrect track assignments when importing new music.

## Solution
- Modified the mapping creation to filter out unmatched items (where
index is `-1`)
- Updated test case to properly catch this scenario
  • Loading branch information
snejus authored Jan 4, 2025
2 parents f91f096 + ef902ea commit c01d059
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 112 deletions.
16 changes: 11 additions & 5 deletions beets/autotag/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,21 @@ def assign_items(
objects. These "extra" objects occur when there is an unequal number
of objects of the two types.
"""
log.debug("Computing track assignment...")
# Construct the cost matrix.
costs = [[float(track_distance(i, t)) for t in tracks] for i in items]
# Find a minimum-cost bipartite matching.
log.debug("Computing track assignment...")
cost, _, assigned_idxs = lap.lapjv(np.array(costs), extend_cost=True)
# Assign items to tracks
_, _, assigned_item_idxs = lap.lapjv(np.array(costs), extend_cost=True)
log.debug("...done.")

# Produce the output matching.
mapping = {items[i]: tracks[t] for (t, i) in enumerate(assigned_idxs)}
# Each item in `assigned_item_idxs` list corresponds to a track in the
# `tracks` list. Each value is either an index into the assigned item in
# `items` list, or -1 if that track has no match.
mapping = {
items[iidx]: t
for iidx, t in zip(assigned_item_idxs, tracks)
if iidx != -1
}
extra_items = list(set(items) - mapping.keys())
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
Expand Down
42 changes: 22 additions & 20 deletions beets/test/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

import beets
import beets.plugins
from beets import autotag, config, importer, logging, util
from beets import autotag, importer, logging, util
from beets.autotag.hooks import AlbumInfo, TrackInfo
from beets.importer import ImportSession
from beets.library import Album, Item, Library
Expand Down Expand Up @@ -153,12 +153,27 @@ def check_reflink_support(path: str) -> bool:
return reflink.supported_at(path)


class ConfigMixin:
@cached_property
def config(self) -> beets.IncludeLazyConfig:
"""Base beets configuration for tests."""
config = beets.config
config.sources = []
config.read(user=False, defaults=True)

config["plugins"] = []
config["verbose"] = 1
config["ui"]["color"] = False
config["threaded"] = False
return config


NEEDS_REFLINK = unittest.skipUnless(
check_reflink_support(gettempdir()), "no reflink support for libdir"
)


class TestHelper(_common.Assertions):
class TestHelper(_common.Assertions, ConfigMixin):
"""Helper mixin for high-level cli and plugin tests.
This mixin provides methods to isolate beets' global state provide
Expand All @@ -184,8 +199,6 @@ def setup_beets(self):
- ``libdir`` Path to a subfolder of ``temp_dir``, containing the
library's media files. Same as ``config['directory']``.
- ``config`` The global configuration used by beets.
- ``lib`` Library instance created with the settings from
``config``.
Expand All @@ -202,15 +215,6 @@ def setup_beets(self):
)
self.env_patcher.start()

self.config = beets.config
self.config.sources = []
self.config.read(user=False, defaults=True)

self.config["plugins"] = []
self.config["verbose"] = 1
self.config["ui"]["color"] = False
self.config["threaded"] = False

self.libdir = os.path.join(self.temp_dir, b"libdir")
os.mkdir(syspath(self.libdir))
self.config["directory"] = os.fsdecode(self.libdir)
Expand All @@ -229,8 +233,6 @@ def teardown_beets(self):
self.io.restore()
self.lib._close()
self.remove_temp_dir()
beets.config.clear()
beets.config._materialized = False

# Library fixtures methods

Expand Down Expand Up @@ -452,7 +454,7 @@ def setUp(self):
self.i = _common.item(self.lib)


class PluginMixin:
class PluginMixin(ConfigMixin):
plugin: ClassVar[str]
preload_plugin: ClassVar[bool] = True

Expand All @@ -473,7 +475,7 @@ def load_plugins(self, *plugins: str) -> None:
"""
# FIXME this should eventually be handled by a plugin manager
plugins = (self.plugin,) if hasattr(self, "plugin") else plugins
beets.config["plugins"] = plugins
self.config["plugins"] = plugins
beets.plugins.load_plugins(plugins)
beets.plugins.find_plugins()

Expand All @@ -494,7 +496,7 @@ def unload_plugins(self) -> None:
# FIXME this should eventually be handled by a plugin manager
for plugin_class in beets.plugins._instances:
plugin_class.listeners = None
beets.config["plugins"] = []
self.config["plugins"] = []
beets.plugins._classes = set()
beets.plugins._instances = {}
Item._types = getattr(Item, "_original_types", {})
Expand All @@ -504,7 +506,7 @@ def unload_plugins(self) -> None:

@contextmanager
def configure_plugin(self, config: Any):
beets.config[self.plugin].set(config)
self.config[self.plugin].set(config)
self.load_plugins(self.plugin)

yield
Expand Down Expand Up @@ -624,7 +626,7 @@ def _get_import_session(self, import_dir: bytes) -> ImportSession:
def setup_importer(
self, import_dir: bytes | None = None, **kwargs
) -> ImportSession:
config["import"].set_args({**self.default_import_config, **kwargs})
self.config["import"].set_args({**self.default_import_config, **kwargs})
self.importer = self._get_import_session(import_dir or self.import_dir)
return self.importer

Expand Down
128 changes: 41 additions & 87 deletions test/test_autotag.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from beets.autotag import AlbumInfo, TrackInfo, correct_list_fields, match
from beets.autotag.hooks import Distance, string_dist
from beets.library import Item
from beets.test.helper import BeetsTestCase
from beets.test.helper import BeetsTestCase, ConfigMixin
from beets.util import plurality


Expand Down Expand Up @@ -498,85 +498,46 @@ def test_per_medium_track_numbers(self):
assert dist == 0


class AssignmentTest(unittest.TestCase):
def item(self, title, track):
return Item(
title=title,
track=track,
mb_trackid="",
mb_albumid="",
mb_artistid="",
)

def test_reorder_when_track_numbers_incorrect(self):
items = []
items.append(self.item("one", 1))
items.append(self.item("three", 2))
items.append(self.item("two", 3))
trackinfo = []
trackinfo.append(TrackInfo(title="one"))
trackinfo.append(TrackInfo(title="two"))
trackinfo.append(TrackInfo(title="three"))
mapping, extra_items, extra_tracks = match.assign_items(
items, trackinfo
)
assert extra_items == []
assert extra_tracks == []
assert mapping == {
items[0]: trackinfo[0],
items[1]: trackinfo[2],
items[2]: trackinfo[1],
}
class TestAssignment(ConfigMixin):
A = "one"
B = "two"
C = "three"

@pytest.fixture(autouse=True)
def _setup_config(self):
self.config["match"]["track_length_grace"] = 10
self.config["match"]["track_length_max"] = 30

@pytest.mark.parametrize(
# 'expected' is a tuple of expected (mapping, extra_items, extra_tracks)
"item_titles, track_titles, expected",
[
# items ordering gets corrected
([A, C, B], [A, B, C], ({A: A, B: B, C: C}, [], [])),
# unmatched tracks are returned as 'extra_tracks'
# the first track is unmatched
([B, C], [A, B, C], ({B: B, C: C}, [], [A])),
# the middle track is unmatched
([A, C], [A, B, C], ({A: A, C: C}, [], [B])),
# the last track is unmatched
([A, B], [A, B, C], ({A: A, B: B}, [], [C])),
# unmatched items are returned as 'extra_items'
([A, C, B], [A, C], ({A: A, C: C}, [B], [])),
],
)
def test_assign_tracks(self, item_titles, track_titles, expected):
expected_mapping, expected_extra_items, expected_extra_tracks = expected

def test_order_works_with_invalid_track_numbers(self):
items = []
items.append(self.item("one", 1))
items.append(self.item("three", 1))
items.append(self.item("two", 1))
trackinfo = []
trackinfo.append(TrackInfo(title="one"))
trackinfo.append(TrackInfo(title="two"))
trackinfo.append(TrackInfo(title="three"))
mapping, extra_items, extra_tracks = match.assign_items(
items, trackinfo
)
assert extra_items == []
assert extra_tracks == []
assert mapping == {
items[0]: trackinfo[0],
items[1]: trackinfo[2],
items[2]: trackinfo[1],
}
items = [Item(title=title) for title in item_titles]
tracks = [TrackInfo(title=title) for title in track_titles]

def test_order_works_with_missing_tracks(self):
items = []
items.append(self.item("one", 1))
items.append(self.item("three", 3))
trackinfo = []
trackinfo.append(TrackInfo(title="one"))
trackinfo.append(TrackInfo(title="two"))
trackinfo.append(TrackInfo(title="three"))
mapping, extra_items, extra_tracks = match.assign_items(
items, trackinfo
)
assert extra_items == []
assert extra_tracks == [trackinfo[1]]
assert mapping == {items[0]: trackinfo[0], items[1]: trackinfo[2]}
mapping, extra_items, extra_tracks = match.assign_items(items, tracks)

def test_order_works_with_extra_tracks(self):
items = []
items.append(self.item("one", 1))
items.append(self.item("two", 2))
items.append(self.item("three", 3))
trackinfo = []
trackinfo.append(TrackInfo(title="one"))
trackinfo.append(TrackInfo(title="three"))
mapping, extra_items, extra_tracks = match.assign_items(
items, trackinfo
)
assert extra_items == [items[1]]
assert extra_tracks == []
assert mapping == {items[0]: trackinfo[0], items[2]: trackinfo[1]}
assert (
{i.title: t.title for i, t in mapping.items()},
[i.title for i in extra_items],
[t.title for t in extra_tracks],
) == (expected_mapping, expected_extra_items, expected_extra_tracks)

def test_order_works_when_track_names_are_entirely_wrong(self):
# A real-world test case contributed by a user.
Expand All @@ -587,9 +548,6 @@ def item(i, length):
title=f"ben harper - Burn to Shine {i}",
track=i,
length=length,
mb_trackid="",
mb_albumid="",
mb_artistid="",
)

items = []
Expand Down Expand Up @@ -623,13 +581,9 @@ def info(index, title, length):
trackinfo.append(info(11, "Beloved One", 243.733))
trackinfo.append(info(12, "In the Lord's Arms", 186.13300000000001))

mapping, extra_items, extra_tracks = match.assign_items(
items, trackinfo
)
assert extra_items == []
assert extra_tracks == []
for item, info in mapping.items():
assert items.index(item) == trackinfo.index(info)
expected = dict(zip(items, trackinfo)), [], []

assert match.assign_items(items, trackinfo) == expected


class ApplyTestUtil:
Expand Down

0 comments on commit c01d059

Please sign in to comment.