package org.apache.tez.dag.app.dag.impl;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.app.dag.Task;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.impl.EventMetaData;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/tez/dag/app/dag/impl/TestEdge.class */
public class TestEdge {
    @Test(timeout = 5000)
    public void testCompositeEventHandling() {
        Edge edge = new Edge(EdgeProperty.create(EdgeProperty.DataMovementType.SCATTER_GATHER, EdgeProperty.DataSourceType.PERSISTED, EdgeProperty.SchedulingType.SEQUENTIAL, (OutputDescriptor) Mockito.mock(OutputDescriptor.class), (InputDescriptor) Mockito.mock(InputDescriptor.class)), (EventHandler) Mockito.mock(EventHandler.class));
        TezVertexID createVertexID = createVertexID(1);
        TezVertexID createVertexID2 = createVertexID(2);
        LinkedHashMap<TezTaskID, Task> mockTasks = mockTasks(createVertexID, 1);
        LinkedHashMap<TezTaskID, Task> mockTasks2 = mockTasks(createVertexID2, 5);
        TezTaskID next = mockTasks.keySet().iterator().next();
        Vertex mockVertex = mockVertex("src", createVertexID, mockTasks);
        Vertex mockVertex2 = mockVertex("dest", createVertexID2, mockTasks2);
        edge.setSourceVertex(mockVertex);
        edge.setDestinationVertex(mockVertex2);
        edge.initialize();
        TezTaskAttemptID createTAIDForTest = createTAIDForTest(next, 2);
        EventMetaData eventMetaData = new EventMetaData(EventMetaData.EventProducerConsumerType.OUTPUT, "consumerVertex", "producerVertex", createTAIDForTest);
        CompositeDataMovementEvent create = CompositeDataMovementEvent.create(0, mockTasks2.size(), ByteBuffer.wrap("bytes".getBytes()));
        create.setVersion(2);
        edge.sendTezEventToDestinationTasks(new TezEvent(create, eventMetaData));
        verifyEvents(createTAIDForTest, mockTasks2);
        resetTaskMocks(mockTasks2.values());
        for (int i = 0; i < mockTasks2.size(); i++) {
            DataMovementEvent create2 = DataMovementEvent.create(i, ByteBuffer.wrap("bytes".getBytes()));
            create2.setVersion(2);
            edge.sendTezEventToDestinationTasks(new TezEvent(create2, eventMetaData));
        }
        verifyEvents(createTAIDForTest, mockTasks2);
    }

    private void verifyEvents(TezTaskAttemptID tezTaskAttemptID, LinkedHashMap<TezTaskID, Task> linkedHashMap) {
        int i = 0;
        Iterator<Map.Entry<TezTaskID, Task>> it = linkedHashMap.entrySet().iterator();
        while (it.hasNext()) {
            Task value = it.next().getValue();
            ArgumentCaptor forClass = ArgumentCaptor.forClass(TezEvent.class);
            ((Task) Mockito.verify(value, Mockito.times(1))).registerTezEvent((TezEvent) forClass.capture());
            DataMovementEvent event = ((TezEvent) forClass.getValue()).getEvent();
            Assert.assertEquals(tezTaskAttemptID.getId(), event.getVersion());
            int i2 = i;
            i++;
            Assert.assertEquals(i2, event.getSourceIndex());
            Assert.assertEquals(tezTaskAttemptID.getTaskID().getId(), event.getTargetIndex());
            byte[] bArr = new byte[event.getUserPayload().limit() - event.getUserPayload().position()];
            event.getUserPayload().slice().get(bArr);
            Assert.assertTrue(Arrays.equals("bytes".getBytes(), bArr));
        }
    }

    private void resetTaskMocks(Collection<Task> collection) {
        for (Task task : collection) {
            TezTaskID taskId = task.getTaskId();
            Mockito.reset(new Task[]{task});
            ((Task) Mockito.doReturn(taskId).when(task)).getTaskId();
        }
    }

    private LinkedHashMap<TezTaskID, Task> mockTasks(TezVertexID tezVertexID, int i) {
        LinkedHashMap<TezTaskID, Task> linkedHashMap = new LinkedHashMap<>();
        for (int i2 = 0; i2 < i; i2++) {
            Task task = (Task) Mockito.mock(Task.class);
            TezTaskID tezTaskID = TezTaskID.getInstance(tezVertexID, i2);
            ((Task) Mockito.doReturn(tezTaskID).when(task)).getTaskId();
            linkedHashMap.put(tezTaskID, task);
        }
        return linkedHashMap;
    }

    private Vertex mockVertex(String str, TezVertexID tezVertexID, LinkedHashMap<TezTaskID, Task> linkedHashMap) {
        Vertex vertex = (Vertex) Mockito.mock(Vertex.class);
        ((Vertex) Mockito.doReturn(tezVertexID).when(vertex)).getVertexId();
        ((Vertex) Mockito.doReturn(str).when(vertex)).getName();
        ((Vertex) Mockito.doReturn(linkedHashMap).when(vertex)).getTasks();
        ((Vertex) Mockito.doReturn(Integer.valueOf(linkedHashMap.size())).when(vertex)).getTotalTasks();
        for (Map.Entry<TezTaskID, Task> entry : linkedHashMap.entrySet()) {
            ((Vertex) Mockito.doReturn(entry.getValue()).when(vertex)).getTask((TezTaskID) Matchers.eq(entry.getKey()));
            ((Vertex) Mockito.doReturn(entry.getValue()).when(vertex)).getTask(Matchers.eq(entry.getKey().getId()));
        }
        return vertex;
    }

    private TezVertexID createVertexID(int i) {
        return TezVertexID.getInstance(TezDAGID.getInstance("1000", 1, 1), i);
    }

    private TezTaskAttemptID createTAIDForTest(TezTaskID tezTaskID, int i) {
        return TezTaskAttemptID.getInstance(tezTaskID, i);
    }
}
