package cc.redberry.transformation;

import cc.redberry.core.number.ComplexElement;
import cc.redberry.core.tensor.AbstractScalarFunction;
import cc.redberry.core.tensor.Derivative;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.TensorIterator;
import cc.redberry.core.tensor.TensorNumber;
import cc.redberry.core.tensor.testing.TTest;
import java.util.ArrayList;

/* loaded from: input_file:cc/redberry/transformation/GetScalarDerivative.class */
/**
 * Expands a scalar {@link Derivative} by differentiating its target with
 * respect to each of its variations in turn.
 *
 * <p>The class has no instance fields and a private constructor; use
 * {@link #INSTANCE}.
 */
public class GetScalarDerivative {
    /** The only instance of this class. Declared final so it cannot be reassigned by callers. */
    public static final GetScalarDerivative INSTANCE = new GetScalarDerivative();

    private GetScalarDerivative() {
    }

    /**
     * Evaluates the derivative when {@code tensor} passes
     * {@code TTest.testIsScalar} and is a {@link Derivative}.
     *
     * <p>A clone of the derivative's target is differentiated once per
     * variation, for indices {@code 0 .. getDerivativeOrder() - 1}. As soon as
     * one step yields no derivative, a zero {@link TensorNumber} is returned.
     *
     * @param tensor the tensor to transform
     * @return the evaluated derivative, a zero tensor number, or
     *         {@code tensor} itself when it is not a scalar derivative
     */
    public Tensor transform(Tensor tensor) {
        // Keep the scalar test first: evaluation order matches the original guard.
        if (!(TTest.testIsScalar(tensor) && (tensor instanceof Derivative))) {
            return tensor;
        }
        Derivative derivative = (Derivative) tensor;
        Tensor result = derivative.getTarget().mo6clone();
        int order = derivative.getDerivativeOrder();
        for (int i = 0; i < order; i++) {
            result = getDerivative(result, derivative.getVariation(i));
            if (result == null) {
                // null from getDerivative means "no derivative found" and is mapped to zero here.
                return TensorNumber.createZERO();
            }
        }
        return result;
    }

    /**
     * Differentiates {@code tensor} with respect to {@code simpleTensor}.
     *
     * <p>Handled cases, checked in this order: {@link Sum}, {@link Product},
     * exactly {@link SimpleTensor}, exactly {@link TensorField}, and
     * {@link AbstractScalarFunction}. The {@code SimpleTensor} and
     * {@code TensorField} branches compare the runtime class exactly, so
     * instances of other subclasses do not enter them.
     *
     * @param tensor       the expression to differentiate
     * @param simpleTensor the variable to differentiate by
     * @return the derivative, or {@code null} when none of the rules produced
     *         one (the caller {@link #transform} turns this into zero)
     */
    private Tensor getDerivative(Tensor tensor, SimpleTensor simpleTensor) {
        if (tensor instanceof Sum) {
            return differentiateSum((Sum) tensor, simpleTensor);
        }
        if (tensor instanceof Product) {
            return differentiateProduct((Product) tensor, simpleTensor);
        }
        if (tensor.getClass() == SimpleTensor.class) {
            // Same name as the variable: derivative is one; otherwise no derivative.
            boolean sameName = ((SimpleTensor) tensor).getName() == simpleTensor.getName();
            return sameName ? TensorNumber.createONE() : null;
        }
        if (tensor.getClass() == TensorField.class) {
            // The field is left as a symbolic derivative if any argument has a derivative.
            for (Tensor arg : ((TensorField) tensor).getArgs()) {
                if (getDerivative(arg, simpleTensor) != null) {
                    return Derivative.createFromInversed(tensor, simpleTensor);
                }
            }
            return null;
        }
        if (!(tensor instanceof AbstractScalarFunction)) {
            return null;
        }
        // Chain rule: function.derivative() times the derivative of the inner tensor.
        AbstractScalarFunction function = (AbstractScalarFunction) tensor;
        Tensor inner = getDerivative(function.getInnerTensor(), simpleTensor);
        if (inner == null) {
            return null;
        }
        if (isOne(inner)) {
            return function.derivative();
        }
        return new Product(function.derivative(), inner);
    }

    /**
     * Differentiates a sum term by term, collecting only the terms that have a
     * derivative.
     *
     * @return {@code equivalent()} of the collected sum, or {@code null} when
     *         no term produced a derivative
     */
    private Tensor differentiateSum(Sum sum, SimpleTensor simpleTensor) {
        TensorIterator terms = sum.iterator();
        Sum collected = new Sum();
        while (terms.hasNext()) {
            Tensor termDerivative = getDerivative(terms.next(), simpleTensor);
            if (termDerivative != null) {
                collected.add(termDerivative);
            }
        }
        return collected.isEmpty() ? null : collected.equivalent();
    }

    /**
     * Differentiates a product: for each factor that has a derivative, a clone
     * of the product is taken, that factor is removed, and its derivative is
     * added back unless it is the number one.
     *
     * @return the single resulting term, a {@link Sum} of the terms when there
     *         are several, or {@code null} when no factor produced a derivative
     */
    private Tensor differentiateProduct(Product product, SimpleTensor simpleTensor) {
        ArrayList<Tensor> terms = new ArrayList<Tensor>();
        for (int i = 0; i < product.size(); i++) {
            Tensor factorDerivative = getDerivative(product.getElements().get(i), simpleTensor);
            if (factorDerivative == null) {
                continue;
            }
            Product term = product.mo6clone();
            term.getElements().remove(i);
            if (!isOne(factorDerivative)) {
                term.add(factorDerivative);
            }
            terms.add(term.equivalent());
        }
        if (terms.isEmpty()) {
            return null;
        }
        if (terms.size() == 1) {
            return terms.get(0);
        }
        return new Sum(terms);
    }

    /** Returns true when {@code tensor} is a {@link TensorNumber} whose value equals {@code ComplexElement.ONE}. */
    private static boolean isOne(Tensor tensor) {
        if (!(tensor instanceof TensorNumber)) {
            return false;
        }
        return ((TensorNumber) tensor).getValue().equals(ComplexElement.ONE);
    }
}
