diff --git a/lib/job-iteration/active_record_cursor.rb b/lib/job-iteration/active_record_cursor.rb index 10a8f4ef..f753b511 100644 --- a/lib/job-iteration/active_record_cursor.rb +++ b/lib/job-iteration/active_record_cursor.rb @@ -18,23 +18,16 @@ def initialize end end - def initialize(relation, columns = nil, position = nil) - @columns = if columns - Array(columns) - else - Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } - end + def initialize(relation, columns, position = nil) + @columns = columns self.position = Array.wrap(position) raise ArgumentError, "Must specify at least one column" if columns.empty? - if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") } - raise ArgumentError, "You need to specify fully-qualified columns if you join a table" - end if relation.arel.orders.present? || relation.arel.taken.present? raise ConditionNotSupportedError end - @base_relation = relation.reorder(@columns.join(",")) + @base_relation = relation.reorder(*@columns) @reached_end = false end @@ -54,17 +47,15 @@ def position=(position) def update_from_record(record) self.position = @columns.map do |column| - method = column.to_s.split(".").last - - if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && method == "id" + if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && column.name == "id" record.id_value else - record.send(method.to_sym) + record.send(column.name) end end end - def next_batch(batch_size) + def next_batch(batch_size, database_role: nil) return if @reached_end relation = @base_relation.limit(batch_size) @@ -74,7 +65,13 @@ def next_batch(batch_size) end records = relation.uncached do - relation.to_a + if database_role.present? + ActiveRecord::Base.connected_to(role: database_role) do + relation.to_a + end + else + relation.to_a + end end update_from_record(records.last) unless records.empty? @@ -89,14 +86,14 @@ def conditions i = @position.size - 1 column = @columns[i] conditions = if @columns.size == @position.size - "#{column} > ?" + column.gt(@position[i]) else - "#{column} >= ?" + column.gteq(@position[i]) end while i > 0 i -= 1 column = @columns[i] - conditions = "#{column} > ? OR (#{column} = ? AND (#{conditions}))" + conditions = column.gt(@position[i]).or(column.eq(@position[i]).and(conditions)) end ret = @position.reduce([conditions]) { |params, value| params << value << value } ret.pop diff --git a/lib/job-iteration/active_record_enumerator.rb b/lib/job-iteration/active_record_enumerator.rb index 363a4ecf..12a52096 100644 --- a/lib/job-iteration/active_record_enumerator.rb +++ b/lib/job-iteration/active_record_enumerator.rb @@ -7,13 +7,20 @@ module JobIteration class ActiveRecordEnumerator SQL_DATETIME_WITH_NSEC = "%Y-%m-%d %H:%M:%S.%N" - def initialize(relation, columns: nil, batch_size: 100, cursor: nil) + def initialize(relation, columns: nil, batch_size: 100, cursor: nil, database_role: nil) @relation = relation @batch_size = batch_size + @database_role = database_role @columns = if columns - Array(columns) + Array(columns).map do |column| + if column.is_a?(Arel::Attributes::Attribute) + column + else + relation.arel_table[column.to_sym] + end + end else - Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } + Array(relation.primary_key).map { |pk| relation.arel_table[pk.to_sym] } end @cursor = cursor end @@ -31,7 +38,7 @@ def records def batches cursor = finder_cursor Enumerator.new(method(:size)) do |yielder| - while (records = cursor.next_batch(@batch_size)) + while (records = cursor.next_batch(@batch_size, database_role: @database_role)) yielder.yield(records, cursor_value(records.last)) if records.any? end end @@ -45,7 +52,7 @@ def size def cursor_value(record) positions = @columns.map do |column| - attribute_name = column.to_s.split(".").last + attribute_name = column.name.to_sym column_value(record, attribute_name) end return positions.first if positions.size == 1 @@ -58,8 +65,8 @@ def finder_cursor end def column_value(record, attribute) - value = record.read_attribute(attribute.to_sym) - case record.class.columns_hash.fetch(attribute).type + value = record.read_attribute(attribute) + case record.class.columns_hash.fetch(attribute.to_s).type when :datetime value.strftime(SQL_DATETIME_WITH_NSEC) else