package cc.redberry.transformation.ec;

import cc.redberry.concurrent.ConcurrentGrowingList;
import cc.redberry.concurrent.InputPort;
import cc.redberry.core.context.CC;
import cc.redberry.core.indexmapping.IndexMappingBuffer;
import cc.redberry.core.indexmapping.IndexMappingBufferTester;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorNumber;
import cc.redberry.transformation.concurrent.CollectIP;
import cc.redberry.transformation.concurrent.CollectIPFactory;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;

/**
 * An {@link InputPort} that collects {@link Split}s by their factor.
 *
 * <p>Each incoming split carries a {@code factor} and a {@code summand}. Splits whose factors
 * match (according to {@link #compare(Tensor, Tensor)}) share one {@link CollectIP}, created
 * lazily from the supplied {@link CollectIPFactory}. When a match carries a negative sign, the
 * summand is multiplied by {@code -1} before being forwarded. {@link #result()} then assembles
 * the sum of {@code cip.result() * factor} over all distinct factors.
 *
 * <p>Factors are first bucketed by {@code hashCode()}, so the index-mapping search only runs
 * against factors with the same hash. The shared state lives in a {@link ConcurrentHashMap} of
 * {@link ConcurrentGrowingList}s, which suggests {@link #put(Split)} is meant to be called from
 * several threads. NOTE(review): confirm the thread-safety guarantees of
 * {@code ConcurrentGrowingList}.
 */
public class EqualCollectInputPort implements InputPort<Split> {
    private final CollectIPFactory cipFactory;

    /** Buckets of distinct factors, keyed by {@code factor.hashCode()}. */
    private final ConcurrentHashMap<Integer, ConcurrentGrowingList<FactorNode>> nodes =
            new ConcurrentHashMap<>();

    /**
     * A per-thread spare list offered to {@link #nodes} whenever a bucket may need creating.
     *
     * <p>It is removed from the thread-local (so a fresh one is created lazily) only when the
     * map actually adopts it. Hitting an existing bucket therefore allocates nothing.
     */
    private final ThreadLocal<ConcurrentGrowingList<FactorNode>> threadLocalList =
            new ThreadLocal<ConcurrentGrowingList<FactorNode>>() {
                @Override
                protected ConcurrentGrowingList<FactorNode> initialValue() {
                    return new ConcurrentGrowingList<>();
                }
            };

    /** A distinct factor together with the collector that accumulates its summands. */
    private static final class FactorNode {
        final CollectIP cip;
        final Tensor factor;

        FactorNode(CollectIP cip, Tensor factor) {
            this.cip = cip;
            this.factor = factor;
        }
    }

    /**
     * Creates a port that obtains one collector per distinct factor from {@code collectIPFactory}.
     *
     * @param collectIPFactory factory for the per-factor {@link CollectIP}s
     */
    public EqualCollectInputPort(CollectIPFactory collectIPFactory) {
        this.cipFactory = collectIPFactory;
    }

    /**
     * Routes {@code split.summand} to the collector of the first stored factor that matches
     * {@code split.factor}, negating the summand when the match has a negative sign.
     *
     * <p>If no stored factor matches, appends a new {@link FactorNode} for {@code split.factor}
     * and sends the summand to its fresh collector.
     *
     * @param split the factor/summand pair to collect
     * @throws InterruptedException propagated from the underlying collector calls
     */
    @Override
    public void put(Split split) throws InterruptedException {
        final Integer hash = split.factor.hashCode();

        // Offer this thread's spare list as the bucket for this hash. If the map adopts it,
        // drop it from the thread-local so that the next new bucket gets a fresh list.
        final ConcurrentGrowingList<FactorNode> spare = threadLocalList.get();
        ConcurrentGrowingList<FactorNode> bucket = nodes.putIfAbsent(hash, spare);
        if (bucket == null) {
            bucket = spare;
            threadLocalList.remove();
        }

        final ConcurrentGrowingList<FactorNode>.GrowingIterator iterator = bucket.iterator();
        FactorNode ownNode = null; // created lazily, at most once, and reused on retries
        FactorNode candidate;
        Boolean negative;
        do {
            candidate = iterator.next();
            if (candidate == null) {
                // End of the filled part of the bucket: try to claim this slot for our factor.
                if (ownNode == null) {
                    ownNode = new FactorNode(cipFactory.create(), split.factor);
                }
                candidate = iterator.set(ownNode);
                if (candidate == null) {
                    // Slot claimed: this split introduced a new distinct factor.
                    ownNode.cip.put(split.summand);
                    return;
                }
                // set() returned an existing node (presumably another thread filled the slot
                // first). Compare against that node instead.
            }
            negative = compare(split.factor, candidate.factor);
        } while (negative == null);

        if (negative) {
            candidate.cip.put(new Product(TensorNumber.createMINUSONE(), split.summand));
        } else {
            candidate.cip.put(split.summand);
        }
    }

    /**
     * Builds the collected expression: the sum, over every distinct factor, of
     * {@code cip.result() * factor}.
     *
     * <p>NOTE(review): presumably meant to be called after all {@link #put(Split)} calls have
     * finished. Nodes appended concurrently may or may not be included.
     *
     * @return {@code equivalent()} of the assembled {@link Sum}
     * @throws InterruptedException propagated from {@code CollectIP.result()}
     */
    public Tensor result() throws InterruptedException {
        Sum sum = new Sum();
        for (ConcurrentGrowingList<FactorNode> bucket : nodes.values()) {
            ConcurrentGrowingList<FactorNode>.GrowingIterator iterator = bucket.iterator();
            // GrowingIterator.next() yields null past the last filled slot (put relies on the
            // same contract), so stop there. The previous `while (true)` never terminated.
            FactorNode node;
            while ((node = iterator.next()) != null) {
                sum.add(new Product(node.cip.result(), node.factor));
            }
        }
        return sum.equivalent();
    }

    /**
     * Looks for an index mapping from {@code tensor} onto {@code tensor2} that is accepted by an
     * {@link IndexMappingBufferTester} built from {@code tensor}'s free indices.
     *
     * @param tensor the incoming factor
     * @param tensor2 an already stored factor
     * @return {@code null} if no such mapping exists; otherwise the mapping's signum.
     *     {@link #put(Split)} treats {@code true} as "equal up to a factor of -1".
     */
    private static Boolean compare(Tensor tensor, Tensor tensor2) {
        final int[] freeIndices = tensor.getIndices().getFreeIndices().getAllIndices().copy();
        for (int i = 0; i < freeIndices.length; ++i) {
            freeIndices[i] = IndicesUtils.getNameWithType(freeIndices[i]);
        }
        // NOTE(review): the tester presumably pins the free indices (using the metric), so only
        // mappings of contracted indices remain. Confirm against IndexMappingBufferTester.
        final IndexMappingBuffer mapping = IndexMappings.createPort(
                new IndexMappingBufferTester(freeIndices, CC.withMetric()), tensor, tensor2).take();
        if (mapping == null) {
            return null;
        }
        mapping.removeContracted();
        // Expected to leave nothing behind. This is what the decompiled
        // `$assertionsDisabled` construct encoded.
        assert mapping.isEmpty();
        return Boolean.valueOf(mapping.getSignum());
    }
}
