Skip to content

Commit

Permalink
KAFKA-15045: (KIP-924 pt. 14) Callback to TaskAssignor::onAssignmentC…
Browse files Browse the repository at this point in the history
…omputed (#16123)

This PR adds the logic and wiring necessary to make the callback to
TaskAssignor::onAssignmentComputed with the necessary parameters.

We also fixed some log statements in the actual assignment error
computation, as well as modified the ApplicationState::allTasks method
to return a Map instead of a Set of TaskInfos.

Reviewers: Anna Sophie Blee-Goldman <ableegoldman@apache.org>
  • Loading branch information
apourchet authored May 29, 2024
1 parent 862ea12 commit cc269b0
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package org.apache.kafka.streams.processor.assignment;

import java.util.Map;
import java.util.Set;
import org.apache.kafka.streams.errors.TaskAssignmentException;
import org.apache.kafka.streams.processor.TaskId;

/**
* A read-only metadata class representing the state of the application and the current rebalance.
Expand Down Expand Up @@ -48,5 +48,5 @@ public interface ApplicationState {
/**
* @return the set of all tasks in this topology which must be assigned
*/
Set<TaskInfo> allTasks();
Map<TaskId, TaskInfo> allTasks();
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ private void optimizeActive(final ApplicationState applicationState,

final Map<ProcessId, KafkaStreamsAssignment> currentAssignments = assignmentState.buildKafkaStreamsAssignments();

final Set<TaskId> statefulTasks = applicationState.allTasks().stream()
final Set<TaskId> statefulTasks = applicationState.allTasks().values().stream()
.filter(TaskInfo::isStateful)
.map(TaskInfo::id)
.collect(Collectors.toSet());
final Map<ProcessId, KafkaStreamsAssignment> optimizedAssignmentsForStatefulTasks = TaskAssignmentUtils.optimizeRackAwareActiveTasks(
applicationState, currentAssignments, new TreeSet<>(statefulTasks));

final Set<TaskId> statelessTasks = applicationState.allTasks().stream()
final Set<TaskId> statelessTasks = applicationState.allTasks().values().stream()
.filter(task -> !task.isStateful())
.map(TaskInfo::id)
.collect(Collectors.toSet());
Expand Down Expand Up @@ -126,8 +126,7 @@ private static void assignActive(final ApplicationState applicationState,
final AssignmentState assignmentState,
final boolean mustPreserveActiveTaskAssignment) {
final int totalCapacity = computeTotalProcessingThreads(clients);
final Set<TaskId> allTaskIds = applicationState.allTasks().stream()
.map(TaskInfo::id).collect(Collectors.toSet());
final Set<TaskId> allTaskIds = applicationState.allTasks().keySet();
final int taskCount = allTaskIds.size();
final int activeTasksPerThread = taskCount / totalCapacity;
final Set<TaskId> unassigned = new HashSet<>(allTaskIds);
Expand Down Expand Up @@ -172,7 +171,7 @@ private static void assignActive(final ApplicationState applicationState,

private static void assignStandby(final ApplicationState applicationState,
final AssignmentState assignmentState) {
final Set<TaskInfo> statefulTasks = applicationState.allTasks().stream()
final Set<TaskInfo> statefulTasks = applicationState.allTasks().values().stream()
.filter(TaskInfo::isStateful)
.collect(Collectors.toSet());
final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import org.apache.kafka.streams.processor.assignment.ProcessId;
import org.apache.kafka.streams.processor.assignment.TaskAssignor.TaskAssignment;
import org.apache.kafka.streams.processor.assignment.TaskTopicPartition;
import org.apache.kafka.streams.processor.internals.assignment.ApplicationStateImpl;
import org.apache.kafka.streams.processor.internals.assignment.DefaultApplicationState;
import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo;
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
Expand Down Expand Up @@ -191,6 +191,11 @@ public String toString() {
}
}

@FunctionalInterface
public interface UserTaskAssignmentListener {
void onAssignmentComputed(GroupAssignment assignment, GroupSubscription subscription);
}

// keep track of any future consumers in a "dummy" Client since we can't decipher their subscription
private static final UUID FUTURE_ID = randomUUID();

Expand Down Expand Up @@ -445,7 +450,7 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr

final boolean versionProbing =
checkMetadataVersions(minReceivedMetadataVersion, minSupportedMetadataVersion, futureMetadataVersion);
assignTasksToClients(fullMetadata, allSourceTopics, topicGroups,
final UserTaskAssignmentListener userTaskAssignmentListener = assignTasksToClients(fullMetadata, allSourceTopics, topicGroups,
clientMetadataMap, partitionsForTask, racksForProcessConsumer, statefulTasks);

// ---------------- Step Three ---------------- //
Expand Down Expand Up @@ -473,7 +478,9 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr
versionProbing
);

return new GroupAssignment(assignment);
final GroupAssignment groupAssignment = new GroupAssignment(assignment);
userTaskAssignmentListener.onAssignmentComputed(groupAssignment, groupSubscription);
return groupAssignment;
} catch (final MissingSourceTopicException e) {
log.error("Caught an error in the task assignment. Returning an error assignment.", e);
return new GroupAssignment(
Expand Down Expand Up @@ -561,7 +568,7 @@ private ApplicationState buildApplicationState(final TopologyMetadata topologyMe
);
}).collect(Collectors.toSet());

return new ApplicationStateImpl(
return new DefaultApplicationState(
assignmentConfigs.toPublicAssignmentConfigs(),
logicalTasks,
clientMetadataMap
Expand Down Expand Up @@ -720,7 +727,7 @@ private void checkAllPartitions(final Set<String> allSourceTopics,
* Assigns a set of tasks to each client (Streams instance) using the configured task assignor, and also
* populate the stateful tasks that have been assigned to the clients
*/
private void assignTasksToClients(final Cluster fullMetadata,
private UserTaskAssignmentListener assignTasksToClients(final Cluster fullMetadata,
final Set<String> allSourceTopics,
final Map<Subtopology, TopicsInfo> topicGroups,
final Map<UUID, ClientMetadata> clientMetadataMap,
Expand Down Expand Up @@ -768,15 +775,21 @@ private void assignTasksToClients(final Cluster fullMetadata,

final Optional<org.apache.kafka.streams.processor.assignment.TaskAssignor> userTaskAssignor =
userTaskAssignorSupplier.get();
UserTaskAssignmentListener userTaskAssignmentListener = (GroupAssignment assignment, GroupSubscription subscription) -> { };
if (userTaskAssignor.isPresent()) {
final ApplicationState applicationState = buildApplicationState(
taskManager.topologyMetadata(),
clientMetadataMap,
topicGroups,
fullMetadata
);
final TaskAssignment taskAssignment = userTaskAssignor.get().assign(applicationState);
final org.apache.kafka.streams.processor.assignment.TaskAssignor assignor = userTaskAssignor.get();
final TaskAssignment taskAssignment = assignor.assign(applicationState);
processStreamsPartitionAssignment(clientMetadataMap, taskAssignment);
final AssignmentError assignmentError = validateTaskAssignment(applicationState, taskAssignment);
userTaskAssignmentListener = (GroupAssignment assignment, GroupSubscription subscription) -> {
assignor.onAssignmentComputed(assignment, subscription, assignmentError);
};
} else {
final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful);
final RackAwareTaskAssignor rackAwareTaskAssignor = new RackAwareTaskAssignor(
Expand Down Expand Up @@ -817,6 +830,7 @@ private void assignTasksToClients(final Cluster fullMetadata,
.sorted(comparingByKey())
.map(entry -> entry.getKey() + "=" + entry.getValue().currentAssignment())
.collect(Collectors.joining(Utils.NL)));
return userTaskAssignmentListener;
}

private TaskAssignor createTaskAssignor(final boolean lagComputationSuccessful) {
Expand Down Expand Up @@ -1546,33 +1560,36 @@ private void maybeScheduleFollowupRebalance(final long encodedNextScheduledRebal
private AssignmentError validateTaskAssignment(final ApplicationState applicationState,
final TaskAssignment taskAssignment) {
final Collection<KafkaStreamsAssignment> assignments = taskAssignment.assignment();
final Set<TaskId> activeTasksInOutput = new HashSet<>();
final Set<TaskId> standbyTasksInOutput = new HashSet<>();
final Map<TaskId, ProcessId> activeTasksInOutput = new HashMap<>();
final Map<TaskId, ProcessId> standbyTasksInOutput = new HashMap<>();
for (final KafkaStreamsAssignment assignment : assignments) {
final Set<TaskId> tasksForAssignment = new HashSet<>();
for (final KafkaStreamsAssignment.AssignedTask task : assignment.assignment()) {
if (activeTasksInOutput.contains(task.id()) && task.type() == KafkaStreamsAssignment.AssignedTask.Type.ACTIVE) {
log.error("Assignment is invalid: an active task was assigned multiple times: {}", task.id());
if (activeTasksInOutput.containsKey(task.id()) && task.type() == KafkaStreamsAssignment.AssignedTask.Type.ACTIVE) {
log.error("Assignment is invalid: active task {} was assigned to multiple KafkaStreams clients: {} and {}",
task.id(), assignment.processId().id(), activeTasksInOutput.get(task.id()).id());
return AssignmentError.ACTIVE_TASK_ASSIGNED_MULTIPLE_TIMES;
}

if (tasksForAssignment.contains(task.id())) {
log.error("Assignment is invalid: both an active and standby assignment of a task were assigned to the same client: {}", task.id());
log.error("Assignment is invalid: both an active and standby copy of task {} were assigned to KafkaStreams client {}",
task.id(), assignment.processId().id());
return AssignmentError.ACTIVE_AND_STANDBY_TASK_ASSIGNED_TO_SAME_KAFKASTREAMS;
}

tasksForAssignment.add(task.id());
if (task.type() == KafkaStreamsAssignment.AssignedTask.Type.ACTIVE) {
activeTasksInOutput.add(task.id());
activeTasksInOutput.put(task.id(), assignment.processId());
} else {
standbyTasksInOutput.add(task.id());
standbyTasksInOutput.put(task.id(), assignment.processId());
}
}
}

for (final TaskInfo task : applicationState.allTasks()) {
if (!task.isStateful() && standbyTasksInOutput.contains(task.id())) {
log.error("Assignment is invalid: a standby task was found for a stateless task: {}", task.id());
for (final TaskInfo task : applicationState.allTasks().values()) {
if (!task.isStateful() && standbyTasksInOutput.containsKey(task.id())) {
log.error("Assignment is invalid: standby task for stateless task {} was assigned to KafkaStreams client {}",
task.id(), standbyTasksInOutput.get(task.id()).id());
return AssignmentError.INVALID_STANDBY_TASK;
}
}
Expand All @@ -1583,24 +1600,24 @@ private AssignmentError validateTaskAssignment(final ApplicationState applicatio
for (final Map.Entry<ProcessId, KafkaStreamsState> entry : clientStates.entrySet()) {
final ProcessId processIdInInput = entry.getKey();
if (!clientsInOutput.contains(processIdInInput)) {
log.error("Assignment is invalid: one of the clients has no assignment: {}", processIdInInput.id());
log.error("Assignment is invalid: KafkaStreams client {} has no assignment", processIdInInput.id());
return AssignmentError.MISSING_PROCESS_ID;
}
}

for (final ProcessId processIdInOutput : clientsInOutput) {
if (!clientStates.containsKey(processIdInOutput)) {
log.error("Assignment is invalid: one of the clients in the assignment is unknown: {}", processIdInOutput.id());
log.error("Assignment is invalid: the KafkaStreams client {} is unknown", processIdInOutput.id());
return AssignmentError.UNKNOWN_PROCESS_ID;
}
}

final Set<TaskId> taskIdsInInput = applicationState.allTasks().stream().map(TaskInfo::id)
.collect(Collectors.toSet());
final Set<TaskId> taskIdsInInput = applicationState.allTasks().keySet();
for (final KafkaStreamsAssignment assignment : assignments) {
for (final KafkaStreamsAssignment.AssignedTask task : assignment.assignment()) {
if (!taskIdsInInput.contains(task.id())) {
log.error("Assignment is invalid: one of the tasks in the assignment is unknown: {}", task.id());
log.error("Assignment is invalid: task {} assigned to KafkaStreams client {} was unknown",
task.id(), assignment.processId().id());
return AssignmentError.UNKNOWN_TASK_ID;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;

import static java.util.Collections.unmodifiableSet;
import static java.util.Collections.unmodifiableMap;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.kafka.streams.processor.assignment.TaskInfo;
import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.ClientMetadata;
import org.apache.kafka.streams.processor.TaskId;
Expand All @@ -32,22 +33,29 @@
import org.apache.kafka.streams.processor.assignment.ProcessId;
import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor;

public class ApplicationStateImpl implements ApplicationState {
public class DefaultApplicationState implements ApplicationState {

private final AssignmentConfigs assignmentConfigs;
private final Set<TaskInfo> tasks;
private final Map<TaskId, TaskInfo> tasks;
private final Map<UUID, ClientMetadata> clientStates;

public ApplicationStateImpl(final AssignmentConfigs assignmentConfigs,
final Set<TaskInfo> tasks,
final Map<UUID, ClientMetadata> clientStates) {
private final Map<Boolean, Map<ProcessId, KafkaStreamsState>> cachedKafkaStreamStates;

public DefaultApplicationState(final AssignmentConfigs assignmentConfigs,
final Set<TaskInfo> tasks,
final Map<UUID, ClientMetadata> clientStates) {
this.assignmentConfigs = assignmentConfigs;
this.tasks = unmodifiableSet(tasks);
this.tasks = unmodifiableMap(tasks.stream().collect(Collectors.toMap(TaskInfo::id, task -> task)));
this.clientStates = clientStates;
this.cachedKafkaStreamStates = new HashMap<>();
}

@Override
public Map<ProcessId, KafkaStreamsState> kafkaStreamsStates(final boolean computeTaskLags) {
if (cachedKafkaStreamStates.containsKey(computeTaskLags)) {
return cachedKafkaStreamStates.get(computeTaskLags);
}

final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = new HashMap<>();
for (final Map.Entry<UUID, StreamsPartitionAssignor.ClientMetadata> clientEntry : clientStates.entrySet()) {
final ClientMetadata metadata = clientEntry.getValue();
Expand All @@ -68,6 +76,7 @@ public Map<ProcessId, KafkaStreamsState> kafkaStreamsStates(final boolean comput
kafkaStreamsStates.put(processId, kafkaStreamsState);
}

cachedKafkaStreamStates.put(computeTaskLags, kafkaStreamsStates);
return kafkaStreamsStates;
}

Expand All @@ -77,7 +86,7 @@ public AssignmentConfigs assignmentConfigs() {
}

@Override
public Set<TaskInfo> allTasks() {
public Map<TaskId, TaskInfo> allTasks() {
return tasks;
}
}

0 comments on commit cc269b0

Please sign in to comment.