package ai.knowly.langtorch.capability.graph;

import com.google.auto.value.AutoValue;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.reflect.TypeToken;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

/**
 * A directed graph of {@link NodeAdapter}s that is executed in topological order.
 *
 * <p>Each node is identified by {@link NodeAdapter#getId()} and lists the ids of its downstream
 * nodes via {@link NodeAdapter#getOutDegree()}. During {@link #process(Map)} every node receives
 * the collection of values queued for it (initial inputs and/or outputs of its upstream nodes),
 * and its own output is forwarded to each downstream node. The outputs of nodes with no
 * downstream nodes form the result.
 *
 * <p>This class performs no synchronization; all state lives in plain mutable collections.
 */
@AutoValue
public abstract class CapabilityGraph {

    /** Creates an empty graph. */
    public static CapabilityGraph create() {
        return new AutoValue_CapabilityGraph(
                new HashMap<>(),
                ArrayListMultimap.create(),
                new HashMap<>(),
                ArrayListMultimap.create(),
                new HashMap<>());
    }

    // NOTE(review): the decompiler reports the accessors below were originally package-private;
    // they are left public here to keep the current interface unchanged.

    /** Registered nodes, keyed by node id. */
    public abstract HashMap<String, NodeAdapter<?, ?>> nodes();

    /** Values queued as input for each node during a run, keyed by node id. */
    public abstract Multimap<String, Object> inputMap();

    /** Output produced by each node during the most recent run, keyed by node id. */
    public abstract HashMap<String, Object> outputMap();

    /** For each node id, the ids of the nodes that declare an edge into it. */
    public abstract Multimap<String, String> inDegreeMap();

    /** Declared input type of each node, used to validate initial inputs; keyed by node id. */
    public abstract HashMap<String, TypeToken<?>> inputTypes();

    /**
     * Registers a node and records an incoming edge on each of its downstream nodes.
     *
     * <p>Registering an id that already exists replaces the previous node, including the edges it
     * declared.
     *
     * @param nodeAdapter the node to add
     * @param cls the type that initial inputs passed to this node must be assignable to
     */
    public <I, O> void addNode(NodeAdapter<I, O> nodeAdapter, Class<I> cls) {
        String nodeId = nodeAdapter.getId();
        NodeAdapter<?, ?> previous = nodes().put(nodeId, nodeAdapter);
        if (previous != null) {
            // Drop the replaced node's edges; otherwise in-degree counts stay inflated and
            // topologicalSort reports a cycle that does not exist.
            for (String downstreamId : previous.getOutDegree()) {
                inDegreeMap().remove(downstreamId, nodeId);
            }
        }
        inputTypes().put(nodeId, TypeToken.of(cls));
        for (String downstreamId : nodeAdapter.getOutDegree()) {
            inDegreeMap().put(downstreamId, nodeId);
        }
    }

    /**
     * Runs every node once in topological order and returns the outputs of the end nodes.
     *
     * <p>Inputs and outputs from any previous call are discarded before the run starts.
     *
     * @param map initial inputs keyed by node id; each value must be an instance of the type
     *     registered for that node in {@link #addNode}
     * @return the output of every node that has no downstream nodes, keyed by node id
     * @throws IllegalArgumentException if a key is not a registered node or a value has the wrong
     *     type
     * @throws NullPointerException if an initial input value is null
     * @throws IllegalStateException if the graph contains a cycle or an edge to an unregistered
     *     node
     * @throws ExecutionException if a node's processing throws it
     * @throws InterruptedException if a node's processing throws it
     */
    public Map<String, Object> process(Map<String, Object> map)
            throws ExecutionException, InterruptedException {
        // Validate everything before touching graph state so a bad call leaves it unchanged.
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            validateInitialInput(entry.getKey(), entry.getValue());
        }
        List<String> executionOrder = topologicalSort();

        // Forget the previous run; otherwise nodes would also receive stale inputs.
        inputMap().clear();
        outputMap().clear();
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            addInput(entry.getKey(), entry.getValue());
        }

        for (String nodeId : executionOrder) {
            NodeAdapter<?, ?> nodeAdapter = nodes().get(nodeId);
            // Pass a snapshot so the node neither sees nor mutates the internal multimap view.
            Object output = processNode(nodeAdapter, new ArrayList<>(inputMap().get(nodeId)));
            addOutput(nodeId, output);
            for (String downstreamId : nodeAdapter.getOutDegree()) {
                addInput(downstreamId, output);
            }
        }

        Map<String, Object> results = new HashMap<>();
        for (String endNodeId : getEndNodeIds()) {
            results.put(endNodeId, outputMap().get(endNodeId));
        }
        return results;
    }

    /** Invokes a single node; the type parameters capture the node's wildcard types. */
    private <I, O> O processNode(NodeAdapter<I, O> nodeAdapter, Collection<Object> collection)
            throws ExecutionException, InterruptedException {
        return nodeAdapter.process(collection);
    }

    /**
     * Returns the output produced by a node during the most recent {@link #process} call.
     *
     * @param str the node id
     * @return the node's output, or {@code null} if it has not produced one
     */
    public Object getOutput(String str) {
        return outputMap().get(str);
    }

    /** Returns the ids of all nodes that declare no downstream nodes. */
    private List<String> getEndNodeIds() {
        List<String> endNodeIds = new ArrayList<>();
        for (NodeAdapter<?, ?> nodeAdapter : nodes().values()) {
            if (nodeAdapter.getOutDegree().isEmpty()) {
                endNodeIds.add(nodeAdapter.getId());
            }
        }
        return endNodeIds;
    }

    /** Checks that an initial input targets a registered node and matches its declared type. */
    private void validateInitialInput(String nodeId, Object value) {
        TypeToken<?> expectedType = inputTypes().get(nodeId);
        if (expectedType == null) {
            throw new IllegalArgumentException("Unknown node " + nodeId + " in initial inputs");
        }
        Objects.requireNonNull(value, "Initial input for node " + nodeId + " must not be null");
        if (!expectedType.isSupertypeOf(value.getClass())) {
            throw new IllegalArgumentException(
                    "Input type for node " + nodeId + " does not match the expected type");
        }
    }

    /** Queues a value as input for the given node. */
    private void addInput(String nodeId, Object value) {
        inputMap().put(nodeId, value);
    }

    /** Records the output of the given node. */
    private void addOutput(String nodeId, Object value) {
        outputMap().put(nodeId, value);
    }

    /**
     * Orders all registered nodes so that every node comes after its upstream nodes (Kahn's
     * algorithm).
     *
     * @throws IllegalStateException if a node points to an unregistered node or the graph has a
     *     cycle
     */
    private List<String> topologicalSort() {
        Map<String, Integer> remainingInDegree = new HashMap<>();
        Deque<String> ready = new ArrayDeque<>();
        for (String nodeId : nodes().keySet()) {
            int inDegree = inDegreeMap().get(nodeId).size();
            remainingInDegree.put(nodeId, inDegree);
            if (inDegree == 0) {
                ready.offer(nodeId);
            }
        }

        List<String> order = new ArrayList<>(nodes().size());
        while (!ready.isEmpty()) {
            String nodeId = ready.poll();
            order.add(nodeId);
            for (String downstreamId : nodes().get(nodeId).getOutDegree()) {
                Integer count = remainingInDegree.get(downstreamId);
                if (count == null) {
                    throw new IllegalStateException(
                            "Node " + nodeId + " has an edge to unknown node " + downstreamId);
                }
                int remaining = count - 1;
                remainingInDegree.put(downstreamId, remaining);
                if (remaining == 0) {
                    ready.offer(downstreamId);
                }
            }
        }

        // Any node never reaching in-degree 0 sits on (or behind) a cycle.
        if (order.size() != nodes().size()) {
            throw new IllegalStateException("The graph contains a cycle");
        }
        return order;
    }
}
