package edu.iu.dsc.tws.task.impl;

import edu.iu.dsc.tws.api.comms.CommunicationContext;
import edu.iu.dsc.tws.api.compute.graph.ComputeGraph;
import edu.iu.dsc.tws.api.compute.graph.Edge;
import edu.iu.dsc.tws.api.compute.graph.Vertex;
import edu.iu.dsc.tws.task.impl.ops.AbstractOpsConfig;
import edu.iu.dsc.tws.task.impl.ops.AllGatherConfig;
import edu.iu.dsc.tws.task.impl.ops.AllReduceConfig;
import edu.iu.dsc.tws.task.impl.ops.BroadcastConfig;
import edu.iu.dsc.tws.task.impl.ops.DirectConfig;
import edu.iu.dsc.tws.task.impl.ops.GatherConfig;
import edu.iu.dsc.tws.task.impl.ops.JoinConfig;
import edu.iu.dsc.tws.task.impl.ops.KeyedGatherConfig;
import edu.iu.dsc.tws.task.impl.ops.KeyedPartitionConfig;
import edu.iu.dsc.tws.task.impl.ops.KeyedReduceConfig;
import edu.iu.dsc.tws.task.impl.ops.PartitionConfig;
import edu.iu.dsc.tws.task.impl.ops.ReduceConfig;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/iu/dsc/tws/task/impl/ComputeConnection.class */
/**
 * Collects the incoming edges of a single task ({@code nodeName}) and, on {@link #build},
 * adds them to a {@link ComputeGraph}.
 *
 * <p>Edges are keyed first by source task name and then by edge name. Communication operations
 * created through the factory methods ({@link #broadcast}, {@link #reduce}, ...) are remembered
 * and connected automatically at build time unless an edge with the same name from the same
 * source has already been registered.
 */
public class ComputeConnection {
    /** Name of the task that receives every edge held by this connection. */
    private final String nodeName;

    /** source task name -> (edge name -> edge). */
    private final Map<String, Map<String, Edge>> inputs = new HashMap<>();

    /** source task name -> op configs awaiting automatic connection in {@link #build}. */
    private final Map<String, Set<AbstractOpsConfig>> autoConnectConfig = new HashMap<>();

    /**
     * Creates a connection whose edges all target {@code nodeName}.
     *
     * <p>NOTE(review): decompiler metadata says this was package-private originally; it is kept
     * public so existing callers keep compiling.
     *
     * @param nodeName name of the target task
     */
    public ComputeConnection(String nodeName) {
        this.nodeName = nodeName;
    }

    /**
     * Registers an edge arriving from {@code source}.
     *
     * <p>NOTE(review): decompiler metadata says this was package-private originally.
     *
     * @param source name of the source task
     * @param edge edge to register; its name must be unique among edges from {@code source}
     * @throws RuntimeException if an edge with the same name from {@code source} already exists
     */
    public void putEdgeFromSource(String source, Edge edge) {
        Map<String, Edge> edgesFromSource = this.inputs.computeIfAbsent(source, k -> new HashMap<>());
        if (edgesFromSource.containsKey(edge.getName())) {
            throw new RuntimeException("Edges from the same source should be unique. Found "
                + edge.getName() + " already defined from source " + source);
        }
        edgesFromSource.put(edge.getName(), edge);
    }

    /** Remembers {@code config} so {@link #build} can connect it if the user did not. */
    private void addToAutoConfig(String source, AbstractOpsConfig config) {
        this.autoConnectConfig.computeIfAbsent(source, k -> new HashSet<>()).add(config);
    }

    /** Registers {@code config} for auto-connection under {@code source} and returns it. */
    private <T extends AbstractOpsConfig> T register(String source, T config) {
        addToAutoConfig(source, config);
        return config;
    }

    /**
     * Starts configuring a broadcast from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public BroadcastConfig broadcast(String source) {
        return register(source, new BroadcastConfig(source, this));
    }

    /**
     * Starts configuring a reduce from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public ReduceConfig reduce(String source) {
        return register(source, new ReduceConfig(source, this));
    }

    /**
     * Starts configuring a keyed reduce from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public KeyedReduceConfig keyedReduce(String source) {
        return register(source, new KeyedReduceConfig(source, this));
    }

    /**
     * Starts configuring a gather from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public GatherConfig gather(String source) {
        return register(source, new GatherConfig(source, this));
    }

    /**
     * Starts configuring a keyed gather from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public KeyedGatherConfig keyedGather(String source) {
        return register(source, new KeyedGatherConfig(source, this));
    }

    /**
     * Starts configuring an inner join of {@code leftSource} and {@code rightSource}.
     * The config is registered for auto-connection under the left source only.
     *
     * @param leftSource name of the left source task
     * @param rightSource name of the right source task
     * @return the join config
     */
    public JoinConfig innerJoin(String leftSource, String rightSource) {
        return register(leftSource,
            new JoinConfig(leftSource, rightSource, this, CommunicationContext.JoinType.INNER));
    }

    /**
     * Starts configuring a full outer join of {@code leftSource} and {@code rightSource}.
     * The config is registered for auto-connection under the left source only.
     *
     * @param leftSource name of the left source task
     * @param rightSource name of the right source task
     * @return the join config
     */
    public JoinConfig fullOuterJoin(String leftSource, String rightSource) {
        return register(leftSource,
            new JoinConfig(leftSource, rightSource, this, CommunicationContext.JoinType.FULL_OUTER));
    }

    /**
     * Starts configuring a left outer join of {@code leftSource} and {@code rightSource}.
     * The config is registered for auto-connection under the left source only.
     *
     * @param leftSource name of the left source task
     * @param rightSource name of the right source task
     * @return the join config
     */
    public JoinConfig leftOuterJoin(String leftSource, String rightSource) {
        return register(leftSource,
            new JoinConfig(leftSource, rightSource, this, CommunicationContext.JoinType.LEFT));
    }

    /**
     * Starts configuring a right outer join of {@code leftSource} and {@code rightSource}.
     * The config is registered for auto-connection under the left source only.
     *
     * @param leftSource name of the left source task
     * @param rightSource name of the right source task
     * @return the join config
     */
    public JoinConfig rightOuterJoin(String leftSource, String rightSource) {
        return register(leftSource,
            new JoinConfig(leftSource, rightSource, this, CommunicationContext.JoinType.RIGHT));
    }

    /**
     * Starts configuring a partition from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public PartitionConfig partition(String source) {
        return register(source, new PartitionConfig(source, this));
    }

    /**
     * Starts configuring a keyed partition from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public KeyedPartitionConfig keyedPartition(String source) {
        return register(source, new KeyedPartitionConfig(source, this));
    }

    /**
     * Starts configuring an all-reduce from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public AllReduceConfig allreduce(String source) {
        return register(source, new AllReduceConfig(source, this));
    }

    /**
     * Starts configuring an all-gather from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public AllGatherConfig allgather(String source) {
        return register(source, new AllGatherConfig(source, this));
    }

    /**
     * Starts configuring a direct edge from {@code source} into this task.
     *
     * @param source name of the source task
     * @return the config, auto-connected at build time if not connected explicitly
     */
    public DirectConfig direct(String source) {
        return register(source, new DirectConfig(source, this));
    }

    /**
     * Calls {@code connect()} on every pending op config whose edge name is not yet registered
     * from its source, then forgets all pending configs.
     */
    private void doAutoConnect() {
        this.autoConnectConfig.forEach((source, configs) -> configs.forEach(config -> {
            Map<String, Edge> existing = this.inputs.get(source);
            // Skip configs the user already connected explicitly (edge name already present).
            boolean alreadyConnected = existing != null && existing.containsKey(config.getEdgeName());
            if (!alreadyConnected) {
                config.connect();
            }
        }));
        this.autoConnectConfig.clear();
    }

    /**
     * Auto-connects pending ops and adds every registered edge (source task -> this task)
     * to {@code computeGraph}.
     *
     * <p>NOTE(review): decompiler metadata says this was package-private originally.
     *
     * @param computeGraph graph that must already contain this task and all source tasks
     * @throws RuntimeException if this task or a source task is not a vertex of the graph
     */
    public void build(ComputeGraph computeGraph) {
        doAutoConnect();

        // The target vertex is the same for every edge; look it up once. The null check stays
        // inside the loop so a connection with no edges never fails.
        Vertex target = computeGraph.vertex(this.nodeName);
        for (Map.Entry<String, Map<String, Edge>> entry : this.inputs.entrySet()) {
            String sourceName = entry.getKey();
            for (Edge edge : entry.getValue().values()) {
                if (target == null) {
                    throw new RuntimeException("Failed to connect non-existing task: " + this.nodeName);
                }
                Vertex sourceVertex = computeGraph.vertex(sourceName);
                if (sourceVertex == null) {
                    throw new RuntimeException("Failed to connect non-existing task: " + sourceName);
                }
                computeGraph.addTaskEdge(sourceVertex, target, edge);
            }
        }
    }
}
