package org.apache.tez.dag.app.dag.impl;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.event.DrainDispatcher;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.util.Clock;
import org.apache.hadoop.yarn.util.SystemClock;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
import org.apache.tez.dag.api.oldrecords.TaskState;
import org.apache.tez.dag.app.AppContext;
import org.apache.tez.dag.app.ContainerContext;
import org.apache.tez.dag.app.TaskAttemptListener;
import org.apache.tez.dag.app.TaskHeartbeatHandler;
import org.apache.tez.dag.app.dag.TaskAttemptStateInternal;
import org.apache.tez.dag.app.dag.TaskStateInternal;
import org.apache.tez.dag.app.dag.Vertex;
import org.apache.tez.dag.app.dag.event.DAGEventType;
import org.apache.tez.dag.app.dag.event.TaskAttemptEvent;
import org.apache.tez.dag.app.dag.event.TaskAttemptEventType;
import org.apache.tez.dag.app.dag.event.TaskEvent;
import org.apache.tez.dag.app.dag.event.TaskEventRecoverTask;
import org.apache.tez.dag.app.dag.event.TaskEventType;
import org.apache.tez.dag.app.dag.event.VertexEventType;
import org.apache.tez.dag.app.rm.container.AMContainer;
import org.apache.tez.dag.history.events.TaskAttemptFinishedEvent;
import org.apache.tez.dag.history.events.TaskAttemptStartedEvent;
import org.apache.tez.dag.history.events.TaskStartedEvent;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/tez/dag/app/dag/impl/TestTaskRecovery.class */
public class TestTaskRecovery {
    private static final Log LOG = LogFactory.getLog(TestTaskImpl.class);
    private int taskCounter = 0;
    private int taskAttemptCounter = 0;
    private Configuration conf;
    private TaskAttemptListener taskAttemptListener;
    private TaskHeartbeatHandler taskHeartbeatHandler;
    private Credentials credentials;
    private Clock clock;
    private ApplicationId appId;
    private TezDAGID dagId;
    private TezVertexID vertexId;
    private Vertex vertex;
    private AppContext appContext;
    private Resource taskResource;
    private Map<String, LocalResource> localResources;
    private Map<String, String> environment;
    private String javaOpts;
    private boolean leafVertex;
    private ContainerContext containerContext;
    private ContainerId mockContainerId;
    private Container mockContainer;
    private AMContainer mockAMContainer;
    private NodeId mockNodeId;
    private TaskImpl task;
    private DrainDispatcher dispatcher;

    /* loaded from: input_file:org/apache/tez/dag/app/dag/impl/TestTaskRecovery$TaskAttemptEventHandler.class */
    class TaskAttemptEventHandler implements EventHandler<TaskAttemptEvent> {
        TaskAttemptEventHandler() {
        }

        public void handle(TaskAttemptEvent taskAttemptEvent) {
            TestTaskRecovery.this.task.getAttempt(taskAttemptEvent.getTaskAttemptID()).handle(taskAttemptEvent);
        }
    }

    /* loaded from: input_file:org/apache/tez/dag/app/dag/impl/TestTaskRecovery$TaskEventHandler.class */
    class TaskEventHandler implements EventHandler<TaskEvent> {
        TaskEventHandler() {
        }

        public void handle(TaskEvent taskEvent) {
            TestTaskRecovery.this.task.handle(taskEvent);
        }
    }

    @Before
    public void setUp() {
        this.conf = new Configuration();
        this.taskAttemptListener = (TaskAttemptListener) Mockito.mock(TaskAttemptListener.class);
        this.taskHeartbeatHandler = (TaskHeartbeatHandler) Mockito.mock(TaskHeartbeatHandler.class);
        this.credentials = new Credentials();
        this.clock = new SystemClock();
        this.appId = ApplicationId.newInstance(System.currentTimeMillis(), 1);
        this.dagId = TezDAGID.getInstance(this.appId, 1);
        this.vertexId = TezVertexID.getInstance(this.dagId, 1);
        this.vertex = (Vertex) Mockito.mock(Vertex.class, Mockito.RETURNS_DEEP_STUBS);
        Mockito.when(this.vertex.getProcessorDescriptor().getClassName()).thenReturn("");
        this.appContext = (AppContext) Mockito.mock(AppContext.class, Mockito.RETURNS_DEEP_STUBS);
        this.mockContainerId = (ContainerId) Mockito.mock(ContainerId.class);
        this.mockContainer = (Container) Mockito.mock(Container.class);
        this.mockAMContainer = (AMContainer) Mockito.mock(AMContainer.class);
        this.mockNodeId = (NodeId) Mockito.mock(NodeId.class);
        Mockito.when(this.mockContainer.getId()).thenReturn(this.mockContainerId);
        Mockito.when(this.mockContainer.getNodeId()).thenReturn(this.mockNodeId);
        Mockito.when(this.mockAMContainer.getContainer()).thenReturn(this.mockContainer);
        Mockito.when(this.appContext.getAllContainers().get(this.mockContainerId)).thenReturn(this.mockAMContainer);
        Mockito.when(this.appContext.getCurrentDAG().getVertex((TezVertexID) Matchers.any(TezVertexID.class))).thenReturn(this.vertex);
        Mockito.when(this.vertex.getProcessorDescriptor().getClassName()).thenReturn("");
        this.taskResource = Resource.newInstance(1024, 1);
        this.localResources = new HashMap();
        this.environment = new HashMap();
        this.javaOpts = "";
        this.leafVertex = false;
        this.containerContext = new ContainerContext(this.localResources, this.credentials, this.environment, this.javaOpts);
        this.dispatcher = new DrainDispatcher();
        this.dispatcher.register(DAGEventType.class, (EventHandler) Mockito.mock(EventHandler.class));
        this.dispatcher.register(VertexEventType.class, (EventHandler) Mockito.mock(EventHandler.class));
        this.dispatcher.register(TaskEventType.class, new TaskEventHandler());
        this.dispatcher.register(TaskAttemptEventType.class, new TaskAttemptEventHandler());
        this.dispatcher.init(new Configuration());
        this.dispatcher.start();
        this.task = new TaskImpl(this.vertexId, 1, this.dispatcher.getEventHandler(), this.conf, this.taskAttemptListener, this.clock, this.taskHeartbeatHandler, this.appContext, this.leafVertex, this.taskResource, this.containerContext);
    }

    @Test
    public void testTaskRecovery1() {
        TezTaskID newTaskID = getNewTaskID();
        TezTaskID newTaskID2 = getNewTaskID();
        int i = this.conf.getInt("tez.am.task.max.failed.attempts", 4);
        this.task.restoreFromEvent(new TaskStartedEvent(newTaskID2, "v1", 0L, 0L));
        for (int i2 = 0; i2 < i; i2++) {
            TezTaskAttemptID newTaskAttemptID = getNewTaskAttemptID(newTaskID);
            this.task.restoreFromEvent(new TaskAttemptStartedEvent(newTaskAttemptID, "v1", 0L, this.mockContainerId, this.mockNodeId, "", ""));
            this.task.restoreFromEvent(new TaskAttemptFinishedEvent(newTaskAttemptID, "v1", 0L, 0L, TaskAttemptState.KILLED, "", (TezCounters) null));
        }
        Assert.assertEquals(i, this.task.getAttempts().size());
        Assert.assertEquals(0L, this.task.failedAttempts);
        this.task.handle(new TaskEventRecoverTask(newTaskID));
        Assert.assertEquals(TaskStateInternal.RUNNING, this.task.getInternalState());
        Assert.assertEquals(i + 1, this.task.getAttempts().size());
    }

    @Test
    public void testTaskRecovery2() {
        TezTaskID newTaskID = getNewTaskID();
        TezTaskID newTaskID2 = getNewTaskID();
        int i = this.conf.getInt("tez.am.task.max.failed.attempts", 4);
        this.task.restoreFromEvent(new TaskStartedEvent(newTaskID2, "v1", 0L, 0L));
        for (int i2 = 0; i2 < i; i2++) {
            TezTaskAttemptID newTaskAttemptID = getNewTaskAttemptID(newTaskID);
            this.task.restoreFromEvent(new TaskAttemptStartedEvent(newTaskAttemptID, "v1", 0L, this.mockContainerId, this.mockNodeId, "", ""));
            this.task.restoreFromEvent(new TaskAttemptFinishedEvent(newTaskAttemptID, "v1", 0L, 0L, TaskAttemptState.FAILED, "", (TezCounters) null));
        }
        Assert.assertEquals(i, this.task.getAttempts().size());
        Assert.assertEquals(i, this.task.failedAttempts);
        this.task.handle(new TaskEventRecoverTask(newTaskID));
        Assert.assertEquals(TaskStateInternal.FAILED, this.task.getInternalState());
        Assert.assertEquals(i, this.task.getAttempts().size());
    }

    @Test
    public void testTaskRecovery3() throws InterruptedException {
        TezTaskID newTaskID = getNewTaskID();
        TezTaskID newTaskID2 = getNewTaskID();
        int i = this.conf.getInt("tez.am.task.max.failed.attempts", 4);
        this.task.restoreFromEvent(new TaskStartedEvent(newTaskID2, "v1", 0L, 0L));
        for (int i2 = 0; i2 < i - 1; i2++) {
            TezTaskAttemptID newTaskAttemptID = getNewTaskAttemptID(newTaskID);
            this.task.restoreFromEvent(new TaskAttemptStartedEvent(newTaskAttemptID, "v1", 0L, this.mockContainerId, this.mockNodeId, "", ""));
            this.task.restoreFromEvent(new TaskAttemptFinishedEvent(newTaskAttemptID, "v1", 0L, 0L, TaskAttemptState.FAILED, "", (TezCounters) null));
        }
        Assert.assertEquals(i - 1, this.task.getAttempts().size());
        Assert.assertEquals(i - 1, this.task.failedAttempts);
        TezTaskAttemptID newTaskAttemptID2 = getNewTaskAttemptID(newTaskID);
        Assert.assertEquals(TaskState.RUNNING, this.task.restoreFromEvent(new TaskAttemptStartedEvent(newTaskAttemptID2, "v1", 0L, this.mockContainerId, this.mockNodeId, "", "")));
        Assert.assertEquals(TaskAttemptStateInternal.NEW, this.task.getAttempt(newTaskAttemptID2).getInternalState());
        Assert.assertEquals(i, this.task.getAttempts().size());
        this.task.handle(new TaskEventRecoverTask(newTaskID));
        this.dispatcher.await();
        Assert.assertEquals(TaskStateInternal.RUNNING, this.task.getInternalState());
        Assert.assertEquals(TaskAttemptStateInternal.KILLED, this.task.getAttempt(newTaskAttemptID2).getInternalState());
        Assert.assertEquals(i - 1, this.task.failedAttempts);
        Assert.assertEquals(i + 1, this.task.getAttempts().size());
    }

    private TezTaskID getNewTaskID() {
        TezVertexID tezVertexID = this.vertexId;
        int i = this.taskCounter + 1;
        this.taskCounter = i;
        return TezTaskID.getInstance(tezVertexID, i);
    }

    private TezTaskAttemptID getNewTaskAttemptID(TezTaskID tezTaskID) {
        int i = this.taskAttemptCounter;
        this.taskAttemptCounter = i + 1;
        return TezTaskAttemptID.getInstance(tezTaskID, i);
    }
}
