package ai.knowly.langtoch.capability.dag;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.reflect.TypeToken;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:ai/knowly/langtoch/capability/dag/CapabilityDAG.class */
/**
 * A directed graph of {@link Node}s that is executed in topological order.
 *
 * <p>Each node is registered together with the class its input values must be instances of. Every
 * id in a node's out-degree names a downstream node; when a node finishes, its output is appended
 * to the inputs of each of those downstream nodes. {@link #process(Map)} returns the outputs of the
 * nodes whose out-degree is empty ("end nodes").
 *
 * <p>Instances hold mutable per-run state and perform no synchronization.
 */
public class CapabilityDAG {
    /** Registered nodes, keyed by node id. */
    private final HashMap<String, Node<?, ?>> nodes = new HashMap<>();
    /** Inputs collected during the current run, keyed by the id of the consuming node. */
    private final Multimap<String, Object> inputMap = ArrayListMultimap.create();
    /** Outputs produced during the most recent run, keyed by the id of the producing node. */
    private final HashMap<String, Object> outputMap = new HashMap<>();
    /** Reverse edges: target node id -> ids of the nodes listing it in their out-degree (one entry per edge). */
    private final Multimap<String, String> inDegreeMap = ArrayListMultimap.create();
    /** Expected input type per node id, checked for every value routed to that node. */
    private final HashMap<String, TypeToken<?>> inputTypes = new HashMap<>();

    /**
     * Registers {@code node}, replacing any node previously registered under the same id.
     *
     * <p>Edges are taken from the node's out-degree. Target nodes may be added later, but must all be
     * registered before {@link #process(Map)} is called.
     *
     * @param node the node to register
     * @param cls the class every input value routed to this node must be an instance of
     */
    public <I, O> void addNode(Node<I, O> node, Class<I> cls) {
        String id = node.getId();
        Node<?, ?> previous = this.nodes.put(id, node);
        if (previous != null) {
            // Drop the replaced node's edges; otherwise they would keep inflating the in-degree of
            // its old targets and make the topological sort report a spurious cycle.
            for (String oldTarget : previous.getOutDegree()) {
                this.inDegreeMap.remove(oldTarget, id);
            }
        }
        this.inputTypes.put(id, TypeToken.of(cls));
        for (String target : node.getOutDegree()) {
            this.inDegreeMap.put(target, id);
        }
    }

    /**
     * Runs every registered node once, in topological order, and returns the end-node outputs.
     *
     * <p>Each entry of {@code map} supplies one initial input value to the node with that id. Inputs
     * and outputs left over from any previous call are discarded before the run starts.
     *
     * @param map initial inputs, keyed by the id of the node that should receive them
     * @return a new mutable map from each end-node id (a node with an empty out-degree) to its output
     * @throws IllegalArgumentException if a key is not a registered node id, or a value is not an
     *     instance of that node's expected input type
     * @throws NullPointerException if an initial input or a node's output is {@code null}
     * @throws IllegalStateException if the graph contains a cycle or an edge to an unregistered node
     */
    public Map<String, Object> process(Map<String, Object> map) {
        // Per-run state must not leak between calls: stale inputs would be fed to nodes again.
        this.inputMap.clear();
        this.outputMap.clear();
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            setInitialInput(entry.getKey(), entry.getValue());
        }
        for (String id : topologicalSort()) {
            Node<?, ?> node = this.nodes.get(id);
            Object output = processNode(node, this.inputMap.get(id));
            addOutput(id, output);
            for (String target : node.getOutDegree()) {
                addInput(target, output);
            }
        }
        Map<String, Object> endOutputs = new HashMap<>();
        for (String endId : getEndNodeIds()) {
            endOutputs.put(endId, this.outputMap.get(endId));
        }
        return endOutputs;
    }

    /** Invokes {@code node} on the inputs collected for it; the generic signature captures the node's wildcards. */
    private <I, O> O processNode(Node<I, O> node, Collection<Object> collection) {
        return node.process(collection);
    }

    /**
     * Returns the output produced by node {@code str} during the most recent {@link #process(Map)}
     * call, or {@code null} if no output was recorded for that id.
     */
    public Object getOutput(String str) {
        return this.outputMap.get(str);
    }

    /** Returns the ids of all registered nodes whose out-degree is empty. */
    private List<String> getEndNodeIds() {
        List<String> endIds = new ArrayList<>();
        for (Node<?, ?> node : this.nodes.values()) {
            if (node.getOutDegree().isEmpty()) {
                endIds.add(node.getId());
            }
        }
        return endIds;
    }

    /** Validates and records a caller-supplied initial input for node {@code str}. */
    private void setInitialInput(String str, Object obj) {
        checkInput(str, obj);
        this.inputMap.put(str, obj);
    }

    /** Records the output of node {@code str} for the current run. */
    private void addOutput(String str, Object obj) {
        this.outputMap.put(str, obj);
    }

    /** Validates and appends an upstream node's output to the inputs of node {@code str}. */
    private void addInput(String str, Object obj) {
        checkInput(str, obj);
        this.inputMap.put(str, obj);
    }

    /**
     * Checks that {@code nodeId} is registered and that {@code value} is a non-null instance of its
     * expected input type.
     *
     * @throws IllegalArgumentException if the node is unknown or the value has the wrong type
     * @throws NullPointerException if {@code value} is null
     */
    private void checkInput(String nodeId, Object value) {
        TypeToken<?> expected = this.inputTypes.get(nodeId);
        if (expected == null) {
            throw new IllegalArgumentException("Unknown node id: " + nodeId);
        }
        Objects.requireNonNull(value, "Input for node " + nodeId + " is null.");
        if (!expected.isSupertypeOf(value.getClass())) {
            throw new IllegalArgumentException("Input type for node " + nodeId + " does not match the expected type.");
        }
    }

    /**
     * Orders all registered node ids so that every node comes after each node that has an edge to it
     * (Kahn's algorithm: repeatedly emit nodes whose remaining in-degree is zero).
     *
     * @throws IllegalStateException if an out-degree names an unregistered node, or the graph has a cycle
     */
    private List<String> topologicalSort() {
        Map<String, Integer> remainingInDegree = new HashMap<>();
        Deque<String> ready = new ArrayDeque<>();
        for (String id : this.nodes.keySet()) {
            int inDegree = this.inDegreeMap.get(id).size();
            remainingInDegree.put(id, inDegree);
            if (inDegree == 0) {
                ready.add(id);
            }
        }
        List<String> order = new ArrayList<>(this.nodes.size());
        while (!ready.isEmpty()) {
            String id = ready.poll();
            order.add(id);
            for (String target : this.nodes.get(id).getOutDegree()) {
                Integer current = remainingInDegree.get(target);
                if (current == null) {
                    throw new IllegalStateException("Node " + id + " has an edge to unregistered node " + target + ".");
                }
                int next = current - 1;
                remainingInDegree.put(target, next);
                if (next == 0) {
                    ready.add(target);
                }
            }
        }
        // Nodes left unemitted still had incoming edges, which is only possible on a cycle.
        if (order.size() != this.nodes.size()) {
            throw new IllegalStateException("The graph contains a cycle and cannot be topologically sorted.");
        }
        return order;
    }
}
