package hex.Infogram;

import hex.Infogram.InfogramModel;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.ModelCategory;
import hex.gam.MatrixFrameUtils.GamUtils;
import hex.genmodel.utils.DistributionFamily;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import water.DKV;
import water.H2O;
import water.Key;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/Infogram/Infogram.class */
public class Infogram extends ModelBuilder<InfogramModel, InfogramModel.InfogramParameters, InfogramModel.InfogramModelOutput> {
    /** Scale factor 1/sqrt(2) multiplied onto the Euclidean norm used when computing the admissible index. */
    static final double NORMALIZE_ADMISSIBLE_INDEX = 1.0d / Math.sqrt(2.0d);
    // True when no protected columns are configured (assigned in validateInfoGramParameters).
    boolean _buildCore;
    // Predictor names returned by InfogramUtils.extractTopKPredictors.
    String[] _topKPredictors;
    // Frame returned by InfogramUtils.extractTrainingFrame; copied as the starting point of every per-model
    // training frame and removed from the DKV when buildModel finishes.
    Frame _baseOrSensitiveFrame;
    // Result of InfogramUtils.generateModelDescription for the top-K predictors and protected columns.
    String[] _modelDescription;
    // Number of models to build: one per top-K predictor plus one more.
    int _numModels;
    // Final CMI values derived from the raw ones via InfogramUtils.calculateFinalCMI (training / validation).
    double[] _cmi;
    double[] _cmiValid;
    // Cross-validation CMI: _cmiRawCV scaled by the reciprocal of its maximum (0 when the maximum is 0).
    double[] _cmiCV;
    // Mean CMI per model as reported by EstimateCMI on the training / validation score frames.
    double[] _cmiRaw;
    double[] _cmiRawValid;
    // Row-count weighted sum of the per-fold raw CMI values.
    double[] _cmiRawCV;
    // Column names matching the entries of _cmiCV / _cmiRawCV.
    String[] _columnsCV;
    // Variable-importance table taken from a trained model's output (see extractRelevance).
    TwoDimTable _varImp;
    // Not referenced in the visible part of this file.
    int _numPredictors;
    // Keys of the CMI/relevance frames for training, validation and cross-validation.
    Key<Frame> _cmiRelKey;
    Key<Frame> _cmiRelKeyValid;
    Key<Frame> _cmiRelKeyCV;
    // Keys of temporary frames created while building; removed via InfogramUtils.removeFromDKV.
    Key<Frame>[] _generatedFrameKeys;
    // Set to true at the end of cv_computeAndSetOptimalParameters.
    boolean _cvDone;
    // The model being built by the driver.
    private transient InfogramModel _model;
    // Smallest _nonZeroRows reported by EstimateCMI over the validation scorings of the most recent batch.
    long _validNonZeroNumRows;
    // Fold settings written back into _parms once cross-validation results have been processed.
    // NOTE(review): the place where these are captured is outside the visible code.
    int _nFoldOrig;
    Model.Parameters.FoldAssignmentScheme _foldAssignmentOrig;
    String _foldColumnOrig;

    /* loaded from: input_file:hex/Infogram/Infogram$InfogramDriver.class */
    private class InfogramDriver extends ModelBuilder<InfogramModel, InfogramModel.InfogramParameters, InfogramModel.InfogramModelOutput>.Driver {
        /** Creates a driver bound to the enclosing {@link Infogram} builder instance. */
        private InfogramDriver() {
            super(Infogram.this);
        }

        /**
         * Sets up the builder state needed before any model is trained: the array that tracks
         * temporary frame keys, the base (or protected-column) frame, the top-K predictor names,
         * the number of models to build and the model descriptions.
         */
        void prepareModelTrainingFrame() {
            final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;

            // Capacity of the key-tracking array is derived from top_n_features.
            final int keySlots = ((parms._top_n_features + 1) * 3) + 2;
            Infogram.this._generatedFrameKeys = new Key[keySlots];

            final String[] candidatePredictors =
                    InfogramUtils.extractPredictors(parms, Infogram.this._train, Infogram.this._foldColumnOrig);
            Infogram.this._baseOrSensitiveFrame =
                    InfogramUtils.extractTrainingFrame(parms, parms._protected_columns, 1.0d, parms.train().clone());
            parms.extraModelSpecificParams();

            final String[] topK = InfogramUtils.extractTopKPredictors(
                    parms, parms.train(), candidatePredictors, Infogram.this._generatedFrameKeys);
            Infogram.this._topKPredictors = topK;
            Infogram.this._numModels = 1 + topK.length;
            Infogram.this._modelDescription =
                    InfogramUtils.generateModelDescription(topK, parms._protected_columns);
        }

        /**
         * Driver entry point: runs expensive initialization on the enclosing builder, aborts with
         * an {@link H2OModelBuilderIllegalArgumentException} if any errors were recorded, and
         * otherwise starts the model build.
         */
        public void computeImpl() {
            final Infogram builder = Infogram.this;
            builder.init(true);
            final boolean hasErrors = builder.error_count() > 0;
            if (hasErrors) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
            }
            builder._job.update(0L, "Initializing model training");
            buildModel();
        }

        /**
         * Builds the infogram model: prepares the training frames, trains the per-predictor models,
         * copies CMI/relevance results into the model output and generates the CMI/relevance frames
         * for training, validation (when a validation frame is set) and cross-validation (when
         * {@code _cvDone} is set).
         *
         * <p>Cleanup (temporary frame removal, final model update/unlock and {@code Scope.exit})
         * runs exactly once in a {@code finally} block, whether or not the build succeeded.
         */
        public final void buildModel() {
            try {
                final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
                final boolean hasValidation = parms.valid() != null;
                prepareModelTrainingFrame();
                Infogram.this._model = new InfogramModel(Infogram.this.dest(), parms, new InfogramModel.InfogramModelOutput(Infogram.this)).delete_and_lock(Infogram.this._job);
                final InfogramModel.InfogramModelOutput output = (InfogramModel.InfogramModelOutput) Infogram.this._model._output;
                output._start_time = System.currentTimeMillis();
                Infogram.this._cmiRaw = new double[Infogram.this._numModels];
                if (parms.valid() != null) {
                    Infogram.this._cmiRawValid = new double[Infogram.this._numModels];
                }
                buildInfoGramsNRelevance(hasValidation);
                Infogram.this._job.update(1L, "finished building models for Infogram ...");
                output.setDistribution(parms._distribution);
                copyCMIRelevance(output);
                // copyCMIRelevance sorts the output arrays; keep the builder's view in sync with them.
                Infogram.this._cmi = output._cmi;
                if (hasValidation) {
                    copyCMIRelevanceValid(output);
                }
                Infogram.this._cmiRelKey = setCMIRelFrame(hasValidation);
                output.extractAdmissibleFeatures((Frame) DKV.getGet(Infogram.this._cmiRelKey), false, false);
                if (hasValidation) {
                    Infogram.this._cmiRelKeyValid = output._admissible_score_key_valid;
                    output.extractAdmissibleFeatures((Frame) DKV.getGet(Infogram.this._cmiRelKeyValid), true, false);
                    output._validNonZeroNumRows = Infogram.this._validNonZeroNumRows;
                }
                if (Infogram.this._cvDone) {
                    Infogram.this._cmiRelKeyCV = setCMIRelFrameCV();
                    output._admissible_score_key_xval = Infogram.this._cmiRelKeyCV;
                    output.extractAdmissibleFeatures((Frame) DKV.getGet(Infogram.this._cmiRelKeyCV), false, true);
                    // Write the remembered fold settings back into the parameters.
                    parms._nfolds = Infogram.this._nFoldOrig;
                    parms._fold_assignment = Infogram.this._foldAssignmentOrig;
                    parms._fold_column = Infogram.this._foldColumnOrig;
                }
                Infogram.this._job.update(1L, "Infogram building completed...");
                Infogram.this._model.update(Infogram.this._job._key);
            } finally {
                releaseTemporaryFramesAndUnlockModel();
            }
        }

        /**
         * Removes the temporary frames created during the build, performs the final model
         * update/unlock (when a model was created) and exits the current scope while keeping the
         * CMI/relevance frames alive.
         */
        private void releaseTemporaryFramesAndUnlockModel() {
            // The build may have failed before these were assigned; a NullPointerException here
            // would replace the exception that actually caused the failure.
            if (Infogram.this._baseOrSensitiveFrame != null) {
                DKV.remove(Infogram.this._baseOrSensitiveFrame._key);
            }
            if (Infogram.this._generatedFrameKeys != null) {
                InfogramUtils.removeFromDKV(Infogram.this._generatedFrameKeys);
            }
            final ArrayList<Key> keep = new ArrayList<>();
            if (Infogram.this._model != null) {
                GamUtils.keepFrameKeys(keep, new Key[]{Infogram.this._cmiRelKey});
                if (Infogram.this._cmiRelKeyValid != null) {
                    GamUtils.keepFrameKeys(keep, new Key[]{Infogram.this._cmiRelKeyValid});
                }
                if (Infogram.this._cmiRelKeyCV != null) {
                    GamUtils.keepFrameKeys(keep, new Key[]{Infogram.this._cmiRelKeyCV});
                }
                Infogram.this._model.update(Infogram.this._job._key);
                Infogram.this._model.unlock(Infogram.this._job);
            }
            Scope.exit(keep.toArray(new Key[0]));
        }

        /**
         * Fills the training-side result arrays of the model output (raw CMI, CMI, relevance,
         * admissible index, admissible flag, predictor names) from the builder state and the
         * variable-importance table.
         *
         * @param infogramModelOutput model output whose training-side fields are populated
         */
        private void copyCMIRelevance(InfogramModel.InfogramModelOutput infogramModelOutput) {
            final InfogramModel.InfogramModelOutput out = infogramModelOutput;
            final int cmiCount = Infogram.this._cmi.length;

            out._cmi_raw = new double[cmiCount];
            System.arraycopy(Infogram.this._cmiRaw, 0, out._cmi_raw, 0, cmiCount);
            out._admissible_index = new double[cmiCount];
            out._admissible = new double[cmiCount];
            out._cmi = (double[]) Infogram.this._cmi.clone();
            out._topKFeatures = (String[]) Infogram.this._topKPredictors.clone();
            out._all_predictor_names = (String[]) Infogram.this._topKPredictors.clone();

            final int varImpRows = Infogram.this._varImp.getRowDim();
            final List<String> varImpHeaders = new ArrayList<>(Arrays.asList(Infogram.this._varImp.getRowHeaders()));
            out._relevance = new double[varImpRows];
            copyGenerateAdmissibleIndex(
                    varImpRows,
                    varImpHeaders,
                    out._cmi,
                    out._cmi_raw,
                    out._relevance,
                    out._admissible_index,
                    out._admissible,
                    out._all_predictor_names);
        }

        /**
         * Validation-frame counterpart of the training copy: fills the {@code *_valid} result
         * arrays of the model output from the validation CMI values and the variable-importance
         * table. Reads {@code _topKFeatures} from the output, so that field must already be set.
         *
         * @param infogramModelOutput model output whose validation-side fields are populated
         */
        public void copyCMIRelevanceValid(InfogramModel.InfogramModelOutput infogramModelOutput) {
            final InfogramModel.InfogramModelOutput out = infogramModelOutput;
            final int cmiCount = Infogram.this._cmiValid.length;

            out._cmi_raw_valid = new double[cmiCount];
            System.arraycopy(Infogram.this._cmiRawValid, 0, out._cmi_raw_valid, 0, cmiCount);
            out._admissible_index_valid = new double[cmiCount];
            out._admissible_valid = new double[cmiCount];
            out._cmi_valid = (double[]) Infogram.this._cmiValid.clone();

            final int varImpRows = Infogram.this._varImp.getRowDim();
            final List<String> varImpHeaders = new ArrayList<>(Arrays.asList(Infogram.this._varImp.getRowHeaders()));
            out._all_predictor_names_valid = (String[]) out._topKFeatures.clone();
            out._relevance_valid = new double[varImpRows];
            copyGenerateAdmissibleIndex(
                    varImpRows,
                    varImpHeaders,
                    out._cmi_valid,
                    out._cmi_raw_valid,
                    out._relevance_valid,
                    out._admissible_index_valid,
                    out._admissible_valid,
                    out._all_predictor_names_valid);
        }

        /**
         * Looks up each predictor's relevance in the variable-importance table, derives the
         * admissible index and admissible flag from relevance and CMI, and finally reorders all
         * arrays by descending admissible index via {@code sortCMIRel}.
         *
         * <p>NOTE(review): the admissible index here is built from relevance and CMI directly,
         * whereas {@code setCMIRelFrameCV} builds it from {@code 1 - relevance} and
         * {@code 1 - cmi}; confirm which form is intended.
         *
         * @param i       number of rows to process
         * @param list    row headers of the variable-importance table
         * @param dArr    CMI values (input)
         * @param dArr2   raw CMI values (only reordered)
         * @param dArr3   relevance values (written)
         * @param dArr4   admissible index values (written)
         * @param dArr5   admissible flags, 0 or 1 (written)
         * @param strArr  predictor names aligned with the arrays above
         */
        public void copyGenerateAdmissibleIndex(int i, List<String> list, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4, double[] dArr5, String[] strArr) {
            final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
            final double[] cmi = dArr;
            final double[] cmiRaw = dArr2;
            final double[] relevance = dArr3;
            final double[] admissibleIndex = dArr4;
            final double[] admissible = dArr5;

            for (int row = 0; row < i; row++) {
                final int varImpRow = list.indexOf(strArr[row]);
                final Double importance = (Double) Infogram.this._varImp.get(varImpRow, 1);
                relevance[row] = importance.doubleValue();
                final double rel = relevance[row];
                final double cmiValue = cmi[row];
                admissibleIndex[row] = Infogram.NORMALIZE_ADMISSIBLE_INDEX * Math.sqrt((rel * rel) + (cmiValue * cmiValue));
                // Kept as "<" tests so that NaN inputs are still treated as passing the threshold.
                final boolean belowThreshold =
                        relevance[row] < parms._relevance_threshold || cmi[row] < parms._cmi_threshold;
                if (belowThreshold) {
                    admissible[row] = 0.0d;
                } else {
                    admissible[row] = 1.0d;
                }
            }

            final int[] order = new int[cmi.length];
            for (int k = 0; k < order.length; k++) {
                order[k] = k;
            }
            ArrayUtils.sort(order, admissibleIndex, -1, -1);
            InfogramModel.InfogramModelOutput.sortCMIRel(order, relevance, cmiRaw, cmi, strArr, admissibleIndex, admissible);
        }

        /**
         * Generates the CMI/relevance frame for the training results and, when requested, for the
         * validation results, recording both keys on the model output.
         *
         * @param z whether the validation-side frame should be generated as well
         * @return key of the training CMI/relevance frame
         */
        private Key<Frame> setCMIRelFrame(boolean z) {
            final InfogramModel.InfogramModelOutput output = (InfogramModel.InfogramModelOutput) Infogram.this._model._output;

            final Frame trainFrame = InfogramUtils.generateCMIRelevance(
                    output._all_predictor_names,
                    output._admissible,
                    output._admissible_index,
                    output._relevance,
                    output._cmi,
                    output._cmi_raw,
                    Infogram.this._buildCore);
            output._admissible_score_key = trainFrame._key;
            if (!z) {
                return trainFrame._key;
            }

            final Frame validFrame = InfogramUtils.generateCMIRelevance(
                    output._all_predictor_names_valid,
                    output._admissible_valid,
                    output._admissible_index_valid,
                    output._relevance_valid,
                    output._cmi_valid,
                    output._cmi_raw_valid,
                    Infogram.this._buildCore);
            output._admissible_score_key_valid = validFrame._key;
            return trainFrame._key;
        }

        /**
         * Realigns the cross-validation columns and CMI arrays with the order of the model
         * output's predictor names. Predictors without a cross-validation entry keep a
         * {@code null} name and zero CMI values.
         */
        private void cleanUpCV() {
            final String[] predictorNames = ((InfogramModel.InfogramModelOutput) Infogram.this._model._output)._all_predictor_names;
            final List<String> cvColumns = new ArrayList<>(Arrays.asList(Infogram.this._columnsCV));
            final int nPred = predictorNames.length;

            final String[] alignedNames = new String[nPred];
            final double[] alignedCmi = new double[nPred];
            final double[] alignedCmiRaw = new double[nPred];
            int position = 0;
            for (String name : predictorNames) {
                final int cvIndex = cvColumns.indexOf(name);
                if (cvIndex >= 0) {
                    alignedNames[position] = name;
                    alignedCmi[position] = Infogram.this._cmiCV[cvIndex];
                    alignedCmiRaw[position] = Infogram.this._cmiRawCV[cvIndex];
                }
                position++;
            }

            // The arrays above are freshly allocated and local, so they can be published directly.
            Infogram.this._columnsCV = alignedNames;
            Infogram.this._cmiCV = alignedCmi;
            Infogram.this._cmiRawCV = alignedCmiRaw;
        }

        /**
         * Builds the cross-validation CMI/relevance frame. Relevance is taken from the main
         * model output, CMI from the cross-validation aggregates (realigned by
         * {@code cleanUpCV}); the admissible index and flag are computed per predictor and all
         * arrays are then sorted by descending admissible index.
         *
         * @return key of the generated cross-validation CMI/relevance frame
         */
        private Key<Frame> setCMIRelFrameCV() {
            final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
            final InfogramModel.InfogramModelOutput output = (InfogramModel.InfogramModelOutput) Infogram.this._model._output;
            final String[] predictorNames = output._all_predictor_names;
            final double[] mainRelevance = output._relevance;
            final double[] relevance = new double[mainRelevance.length];
            final int nPred = predictorNames.length;
            final double[] admissibleIndex = new double[nPred];
            final double[] admissible = new double[nPred];

            cleanUpCV();

            for (int p = 0; p < nPred; p++) {
                relevance[p] = mainRelevance[p];
                final double relevanceGap = 1.0d - relevance[p];
                final double cmiGap = 1.0d - Infogram.this._cmiCV[p];
                admissibleIndex[p] = Math.sqrt((relevanceGap * relevanceGap) + (cmiGap * cmiGap)) * Infogram.NORMALIZE_ADMISSIBLE_INDEX;
                final boolean belowThreshold =
                        Infogram.this._cmiCV[p] < parms._cmi_threshold || relevance[p] < parms._relevance_threshold;
                if (belowThreshold) {
                    admissible[p] = 0.0d;
                } else {
                    admissible[p] = 1.0d;
                }
            }

            final int[] order = new int[relevance.length];
            for (int k = 0; k < order.length; k++) {
                order[k] = k;
            }
            Infogram.this._columnsCV = (String[]) predictorNames.clone();
            ArrayUtils.sort(order, admissibleIndex, -1, -1);
            InfogramModel.InfogramModelOutput.sortCMIRel(
                    order, relevance, Infogram.this._cmiRawCV, Infogram.this._cmiCV, Infogram.this._columnsCV, admissibleIndex, admissible);

            final Frame cvFrame = InfogramUtils.generateCMIRelevance(
                    Infogram.this._columnsCV, admissible, admissibleIndex, relevance, Infogram.this._cmiCV, Infogram.this._cmiRawCV, Infogram.this._buildCore);
            return cvFrame._key;
        }

        /**
         * Trains all infogram models in batches of {@code _nparallelism}, followed by one final
         * partial batch if any models remain, then converts the raw CMI values into the final
         * CMI values.
         *
         * @param z whether validation CMI values should be finalized as well
         */
        private void buildInfoGramsNRelevance(boolean z) {
            final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
            final int fullBatches = (int) Math.floor(Infogram.this._numModels / parms._nparallelism);
            final int lastModelIndex = Infogram.this._numModels - 1;

            int nextModel = 0;
            // A non-positive batch count simply skips this loop.
            for (int batch = 0; batch < fullBatches; batch++) {
                buildModelCMINRelevance(nextModel, parms._nparallelism, lastModelIndex);
                nextModel += parms._nparallelism;
                Infogram.this._job.update(parms._nparallelism, "in the middle of building infogram models.");
            }

            final int remaining = Infogram.this._numModels - nextModel;
            if (remaining > 0) {
                buildModelCMINRelevance(nextModel, remaining, lastModelIndex);
                Infogram.this._job.update(remaining, " building the final set of infogram models.");
            }

            Infogram.this._cmi = InfogramUtils.calculateFinalCMI(Infogram.this._cmiRaw, Infogram.this._buildCore);
            if (z) {
                Infogram.this._cmiValid = InfogramUtils.calculateFinalCMI(Infogram.this._cmiRawValid, Infogram.this._buildCore);
            }
        }

        /**
         * Trains one batch of models in parallel and records their CMI values. When the batch
         * reaches the last model index, relevance is extracted from the batch's final model.
         *
         * @param i  index of the first model in this batch
         * @param i2 number of models in this batch
         * @param i3 index of the last model overall
         */
        private void buildModelCMINRelevance(int i, int i2, int i3) {
            final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
            final boolean reachesLastModel = i + i2 >= i3;

            final Frame[] trainingFrames = buildTrainingFrames(i, i2, i3);
            final Model.Parameters[] modelParms = InfogramUtils.buildModelParameters(
                    trainingFrames, parms._infogram_algorithm_parameters, i2, parms._algorithm);
            final ModelBuilder[] builders = InfogramUtils.buildModelBuilders(modelParms);
            final ModelBuilder[] trained = ModelBuilderHelper.trainModelsParallel(builders, i2);

            if (reachesLastModel) {
                final int lastInBatch = i2 - 1;
                extractRelevance(trained[lastInBatch].get(), modelParms[lastInBatch]);
            }
            Infogram.this._validNonZeroNumRows = generateInfoGrams(trained, trainingFrames, i, i2);
        }

        /**
         * Creates the training frame for each model of a batch. Every frame starts as a copy of
         * the base/protected-column frame. In core mode, models before the last index get every
         * top-K predictor except their own, and the last model gets all of them; otherwise models
         * before the last index get their own predictor prepended. Each frame's key is recorded
         * for later cleanup and the frame is stored in the DKV.
         *
         * @param i  index of the first model in this batch
         * @param i2 number of models in this batch
         * @param i3 index of the last model overall
         * @return the training frames, one per model in the batch
         */
        private Frame[] buildTrainingFrames(int i, int i2, int i3) {
            final Frame[] frames = new Frame[i2];
            final Frame train = ((InfogramModel.InfogramParameters) Infogram.this._parms).train();
            int keySlot = InfogramUtils.findstart(Infogram.this._generatedFrameKeys);

            for (int offset = 0; offset < i2; offset++) {
                final int modelIndex = i + offset;
                final Frame frame = new Frame(Infogram.this._baseOrSensitiveFrame);
                frames[offset] = frame;

                if (!Infogram.this._buildCore) {
                    if (modelIndex < i3) {
                        final String predictor = Infogram.this._topKPredictors[modelIndex];
                        frame.prepend(predictor, train.vec(predictor));
                    }
                } else {
                    for (int p = 0; p < Infogram.this._topKPredictors.length; p++) {
                        final boolean leaveOneOut = modelIndex < i3 && p != modelIndex;
                        final boolean fullModel = modelIndex == i3;
                        if (leaveOneOut || fullModel) {
                            final String predictor = Infogram.this._topKPredictors[p];
                            frame.add(predictor, train.vec(predictor));
                        }
                    }
                }

                Infogram.this._generatedFrameKeys[keySlot++] = frame._key;
                DKV.put(frame);
            }
            return frames;
        }

        /**
         * Scores every model of a batch on its training frame (and on the validation frame when
         * one is set), estimates the mean CMI from each score frame and stores it in
         * {@code _cmiRaw} / {@code _cmiRawValid}. Keys of the score frames are recorded in
         * {@code _generatedFrameKeys} so they can be removed later.
         *
         * @param modelBuilderArr trained model builders of this batch
         * @param frameArr        training frames matching {@code modelBuilderArr}
         * @param i               index of the first model of this batch in the CMI arrays
         * @param i2              number of models in this batch
         * @return the smallest {@code _nonZeroRows} reported over the validation scorings, or
         *         {@code Long.MAX_VALUE} when no validation frame is set
         */
        private long generateInfoGrams(ModelBuilder[] modelBuilderArr, Frame[] frameArr, int i, int i2) {
            final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
            long minNonZeroRows = Long.MAX_VALUE;
            int keySlot = InfogramUtils.findstart(Infogram.this._generatedFrameKeys);
            for (int m = 0; m < i2; m++) {
                Model model = modelBuilderArr[m].get();
                int nclasses = model._output.nclasses();
                Frame trainFrame = frameArr[m];
                Frame score = model.score(trainFrame);
                score.add(parms._response_column, trainFrame.vec(parms._response_column));
                Scope.track_generic(model);
                if (model._parms._weights_column != null && Arrays.asList(trainFrame.names()).contains(model._parms._weights_column)) {
                    score.add(model._parms._weights_column, trainFrame.vec(model._parms._weights_column));
                }
                Infogram.this._generatedFrameKeys[keySlot++] = score._key;
                Infogram.this._cmiRaw[m + i] = ((EstimateCMI) new EstimateCMI(score, nclasses).doAll(score))._meanCMI;

                if (parms.valid() != null) {
                    Frame valid = parms.valid();
                    Frame scoreValid = model.score(valid);
                    scoreValid.add(parms._response_column, valid.vec(parms._response_column));
                    if (model._parms._weights_column != null) {
                        if (Arrays.asList(valid.names()).contains("__internal_cv_weights__")) {
                            scoreValid.add(model._parms._weights_column, valid.vec("__internal_cv_weights__"));
                        } else {
                            scoreValid.add(model._parms._weights_column, valid.vec(model._parms._weights_column));
                        }
                    }
                    // Store into the next free slot and then advance. Advancing first skipped a
                    // slot and let the next iteration overwrite this key, so the validation score
                    // frame was never cleaned up.
                    Infogram.this._generatedFrameKeys[keySlot++] = scoreValid._key;
                    EstimateCMI validCmi = (EstimateCMI) new EstimateCMI(scoreValid, nclasses).doAll(scoreValid);
                    Infogram.this._cmiRawValid[m + i] = validCmi._meanCMI;
                    minNonZeroRows = Math.min(minNonZeroRows, validCmi._nonZeroRows);
                }
            }
            return minNonZeroRows;
        }

        /**
         * Obtains the variable-importance table used for relevance. In core mode it is read
         * straight from the supplied model. Otherwise a new model is trained on a frame produced
         * by {@code InfogramUtils.subtractAdd2Frame} and its variable importances are used.
         *
         * @param model      trained model to read variable importances from in core mode
         * @param parameters model parameters reused (with a replaced training frame) otherwise
         */
        private void extractRelevance(Model model, Model.Parameters parameters) {
            if (!Infogram.this._buildCore) {
                final InfogramModel.InfogramParameters parms = (InfogramModel.InfogramParameters) Infogram.this._parms;
                final Frame relevanceFrame = InfogramUtils.subtractAdd2Frame(
                        Infogram.this._baseOrSensitiveFrame, parms.train(), parms._protected_columns, Infogram.this._topKPredictors);
                parameters._train = relevanceFrame._key;

                // Remember the temporary frame so it is removed together with the others.
                final int keySlot = InfogramUtils.findstart(Infogram.this._generatedFrameKeys);
                Infogram.this._generatedFrameKeys[keySlot] = relevanceFrame._key;

                Model relevanceModel = ModelBuilder.make(parameters).trainModel().get();
                Infogram.this._varImp = relevanceModel._output.getVariableImportances();
                Scope.track_generic(relevanceModel);
            } else {
                Infogram.this._varImp = model._output.getVariableImportances();
            }
        }
    }

    /**
     * Creates a builder with default {@link InfogramModel.InfogramParameters}.
     *
     * @param z flag forwarded unchanged to the {@code ModelBuilder} constructor
     */
    public Infogram(boolean z) {
        super(new InfogramModel.InfogramParameters(), z);
        // Fold-related bookkeeping starts out empty.
        _nFoldOrig = 0;
        _foldAssignmentOrig = null;
        _foldColumnOrig = null;
        // Nothing has been built or cross-validated yet.
        _baseOrSensitiveFrame = null;
        _cvDone = false;
    }

    /**
     * Creates a builder for the given parameters and runs the inexpensive part of
     * initialization ({@code init(false)}).
     *
     * @param infogramParameters parameters for the infogram build
     */
    public Infogram(InfogramModel.InfogramParameters infogramParameters) {
        super(infogramParameters);
        // Fold-related bookkeeping starts out empty.
        _nFoldOrig = 0;
        _foldAssignmentOrig = null;
        _foldColumnOrig = null;
        // Nothing has been built or cross-validated yet.
        _baseOrSensitiveFrame = null;
        _cvDone = false;
        init(false);
    }

    /**
     * Creates a builder for the given parameters and destination key and runs the inexpensive
     * part of initialization ({@code init(false)}).
     *
     * @param infogramParameters parameters for the infogram build
     * @param key                key forwarded to the {@code ModelBuilder} constructor
     */
    public Infogram(InfogramModel.InfogramParameters infogramParameters, Key<InfogramModel> key) {
        super(infogramParameters, key);
        // Fold-related bookkeeping starts out empty.
        _nFoldOrig = 0;
        _foldAssignmentOrig = null;
        _foldColumnOrig = null;
        // Nothing has been built or cross-validated yet.
        _baseOrSensitiveFrame = null;
        _cvDone = false;
        init(false);
    }

    /** Returns a new {@link InfogramDriver} that performs the actual model build. */
    protected ModelBuilder<InfogramModel, InfogramModel.InfogramParameters, InfogramModel.InfogramModelOutput>.Driver trainModelImpl() {
        return new InfogramDriver();
    }

    /**
     * Delegates to the two-argument {@code nModelsInParallel} overload with a second argument of 2.
     * NOTE(review): the second argument is presumably the default parallelism; confirm against ModelBuilder.
     */
    protected int nModelsInParallel(int i) {
        return nModelsInParallel(i, 2);
    }

    /**
     * Records an informational message about where cross-validation results are stored, fails
     * with an {@link H2OModelBuilderIllegalArgumentException} if the builder has recorded
     * errors, and otherwise runs the inherited cross-validation.
     */
    public void computeCrossValidation() {
        final String field = "cross-validation";
        final String message = "cross-validation infogram information is stored in frame with key labeled as admissible_score_key_cv and the admissible features in admissible_features_cv.";
        info(field, message);

        final boolean hasErrors = error_count() > 0;
        if (hasErrors) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        super.computeCrossValidation();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    public void cv_computeAndSetOptimalParameters(ModelBuilder[] modelBuilderArr) {
        int length = modelBuilderArr.length;
        ?? r0 = new double[length];
        ArrayList arrayList = new ArrayList();
        long[] jArr = new long[length];
        for (int i = 0; i < modelBuilderArr.length; i++) {
            InfogramModel infogramModel = modelBuilderArr[i].dest().get();
            Scope.track_generic(infogramModel);
            InfogramUtils.extractInfogramInfo(infogramModel, r0, arrayList, i);
            jArr[i] = ((InfogramModel.InfogramModelOutput) infogramModel._output)._validNonZeroNumRows;
        }
        calculateMeanInfogramInfo(r0, arrayList, jArr);
        for (ModelBuilder modelBuilder : modelBuilderArr) {
            InfogramModel infogramModel2 = ((Infogram) modelBuilder)._model;
            infogramModel2.write_lock(this._job);
            infogramModel2.update(this._job);
            infogramModel2.unlock(this._job);
        }
        this._cvDone = true;
    }

    /**
     * Aggregates per-fold raw CMI values into cross-validation CMI values. Each fold's values
     * are weighted by that fold's share of the total row count and summed per predictor name
     * into {@code _cmiRawCV}; {@code _cmiCV} is {@code _cmiRawCV} divided by its maximum (all
     * zeros when the maximum is 0). {@code _columnsCV} receives the predictor names in the same
     * order as the two arrays.
     *
     * @param dArr raw CMI values, one row per fold, aligned with the fold's names in {@code list}
     * @param list predictor names per fold
     * @param jArr row count per fold, used as the fold's weight
     */
    public void calculateMeanInfogramInfo(double[][] dArr, List<List<String>> list, long[] jArr) {
        final int nFolds = dArr.length;

        // Union of the predictor names seen in any fold.
        final HashSet<String> uniqueNames = new HashSet<>();
        for (List<String> foldColumns : list) {
            uniqueNames.addAll(foldColumns);
        }
        final List<String> allNames = new ArrayList<>(uniqueNames);
        final int nPreds = allNames.size();

        // Name -> position lookup, replacing a linear indexOf scan per value.
        final Map<String, Integer> positionByName = new HashMap<>();
        for (int p = 0; p < nPreds; p++) {
            positionByName.put(allNames.get(p), p);
        }

        this._cmiCV = new double[nPreds];
        this._cmiRawCV = new double[nPreds];
        final double oneOverTotalRows = 1.0d / ArrayUtils.sum(jArr);
        for (int fold = 0; fold < nFolds; fold++) {
            final List<String> foldColumns = list.get(fold);
            final double[] foldCmi = dArr[fold];
            final double weight = jArr[fold] * oneOverTotalRows;
            // Use this fold's own length: folds are not guaranteed to share fold 0's predictor count.
            for (int p = 0; p < foldCmi.length; p++) {
                final int position = positionByName.get(foldColumns.get(p));
                this._cmiRawCV[position] += foldCmi[p] * weight;
            }
        }

        final double maxCmi = ArrayUtils.maxValue(this._cmiRawCV);
        final double oneOverMaxCmi = maxCmi == 0.0d ? 0.0d : 1.0d / maxCmi;
        for (int p = 0; p < nPreds; p++) {
            this._cmiCV[p] = this._cmiRawCV[p] * oneOverMaxCmi;
        }
        this._columnsCV = allNames.toArray(new String[0]);
    }

    /** Returns the model categories this builder supports: binomial and multinomial. */
    public ModelCategory[] can_build() {
        final ModelCategory[] supported = {ModelCategory.Binomial, ModelCategory.Multinomial};
        return supported;
    }

    /** Always returns {@code true}: this builder reports itself as supervised. */
    public boolean isSupervised() {
        return true;
    }

    /** Always returns {@code false}: this builder reports no POJO support. */
    public boolean havePojo() {
        return false;
    }

    /** Always returns {@code false}: this builder reports no MOJO support. */
    public boolean haveMojo() {
        return false;
    }

    /**
     * Runs the inherited initialization and, only when {@code z} is {@code true}, the
     * infogram-specific parameter validation.
     *
     * @param z flag passed to the inherited {@code init}; also gates the infogram validation
     */
    public void init(boolean z) {
        super.init(z);
        if (!z) {
            return;
        }
        validateInfoGramParameters();
    }

    private void validateInfoGramParameters() {
        Frame train = ((InfogramModel.InfogramParameters) this._parms).train();
        if (!((InfogramModel.InfogramParameters) this._parms).train().vec(((InfogramModel.InfogramParameters) this._parms)._response_column).isCategorical()) {
            error("response_column", "Regression is not supported for Infogram. If you meant to do classification, convert your response column to categorical/factor type before calling Infogram.");
        }
        if (((InfogramModel.InfogramParameters) this._parms)._protected_columns != null) {
            HashSet hashSet = new HashSet(Arrays.asList(train.names()));
            for (String str : ((InfogramModel.InfogramParameters) this._parms)._protected_columns) {
                if (!hashSet.contains(str)) {
                    error("protected_columns", "protected_columns: " + str + " is not a valid column in the training dataset.");
                }
            }
        }
        this._buildCore = ((InfogramModel.InfogramParameters) this._parms)._protected_columns == null;
        if (this._buildCore) {
            if (((InfogramModel.InfogramParameters) this._parms)._net_information_threshold == -1.0d) {
                ((InfogramModel.InfogramParameters) this._parms)._cmi_threshold = 0.1d;
                ((InfogramModel.InfogramParameters) this._parms)._net_information_threshold = 0.1d;
            } else if (((InfogramModel.InfogramParameters) this._parms)._net_information_threshold > 1.0d || ((InfogramModel.InfogramParameters) this._parms)._net_information_threshold < 0.0d) {
                error("net_information_threshold", " should be set to be between 0 and 1.");
            } else {
                ((InfogramModel.InfogramParameters) this._parms)._cmi_threshold = ((InfogramModel.InfogramParameters) this._parms)._net_information_threshold;
            }
            if (((InfogramModel.InfogramParameters) this._parms)._total_information_threshold == -1.0d) {
                ((InfogramModel.InfogramParameters) this._parms)._relevance_threshold = 0.1d;
                ((InfogramModel.InfogramParameters) this._parms)._total_information_threshold = 0.1d;
            } else if (((InfogramModel.InfogramParameters) this._parms)._total_information_threshold < 0.0d || ((InfogramModel.InfogramParameters) this._parms)._total_information_threshold > 1.0d) {
                error("total_information_threshold", " should be set to be between 0 and 1.");
            } else {
                ((InfogramModel.InfogramParameters) this._parms)._relevance_threshold = ((InfogramModel.InfogramParameters) this._parms)._total_information_threshold;
            }
            if (((InfogramModel.InfogramParameters) this._parms)._safety_index_threshold != -1.0d) {
                warn("safety_index_threshold", "Should not set safety_index_threshold for core infogram runs.  Set net_information_threshold instead.  Using default of 0.1 if not set");
            }
            if (((InfogramModel.InfogramParameters) this._parms)._relevance_index_threshold != -1.0d) {
                warn("relevance_index_threshold", "Should not set relevance_index_threshold for core infogram runs.  Set total_information_threshold instead.  Using default of 0.1 if not set");
            }
        } else {
            if (((InfogramModel.InfogramParameters) this._parms)._safety_index_threshold == -1.0d) {
                ((InfogramModel.InfogramParameters) this._parms)._cmi_threshold = 0.1d;
                ((InfogramModel.InfogramParameters) this._parms)._safety_index_threshold = 0.1d;
            } else if (((InfogramModel.InfogramParameters) this._parms)._safety_index_threshold < 0.0d || ((InfogramModel.InfogramParameters) this._parms)._safety_index_threshold > 1.0d) {
                error("safety_index_threshold", " should be set to be between 0 and 1.");
            } else {
                ((InfogramModel.InfogramParameters) this._parms)._cmi_threshold = ((InfogramModel.InfogramParameters) this._parms)._safety_index_threshold;
            }
            if (((InfogramModel.InfogramParameters) this._parms)._relevance_index_threshold == -1.0d) {
                ((InfogramModel.InfogramParameters) this._parms)._relevance_threshold = 0.1d;
                ((InfogramModel.InfogramParameters) this._parms)._relevance_index_threshold = 0.1d;
            } else if (((InfogramModel.InfogramParameters) this._parms)._relevance_index_threshold < 0.0d || ((InfogramModel.InfogramParameters) this._parms)._relevance_index_threshold > 1.0d) {
                error("relevance_index_threshold", " should be set to be between 0 and 1.");
            } else {
                ((InfogramModel.InfogramParameters) this._parms)._relevance_threshold = ((InfogramModel.InfogramParameters) this._parms)._relevance_index_threshold;
            }
            if (((InfogramModel.InfogramParameters) this._parms)._net_information_threshold != -1.0d) {
                warn("net_information_threshold", "Should not set net_information_threshold for fair infogram runs, set safety_index_threshold instead.  Using default of 0.1 if not set");
            }
            if (((InfogramModel.InfogramParameters) this._parms)._total_information_threshold != -1.0d) {
                warn("total_information_threshold", "Should not set total_information_threshold for fair infogram runs, set relevance_index_threshold instead.  Using default of 0.1 if not set");
            }
            if (InfogramModel.InfogramParameters.Algorithm.AUTO.equals(((InfogramModel.InfogramParameters) this._parms)._algorithm)) {
                ((InfogramModel.InfogramParameters) this._parms)._algorithm = InfogramModel.InfogramParameters.Algorithm.gbm;
            }
        }
        if (((InfogramModel.InfogramParameters) this._parms)._top_n_features < 0) {
            error("_topk", "topk must be between 0 and the number of predictor columns in your training dataset.");
        }
        this._numPredictors = ((InfogramModel.InfogramParameters) this._parms).train().numCols() - 1;
        if (((InfogramModel.InfogramParameters) this._parms)._weights_column != null) {
            this._numPredictors--;
        }
        if (((InfogramModel.InfogramParameters) this._parms)._offset_column != null) {
            this._numPredictors--;
        }
        if (((InfogramModel.InfogramParameters) this._parms)._top_n_features > this._numPredictors) {
            warn("top_n_features", "The top_n_features exceed the actual number of predictor columns in your training dataset.  It will be set to the number of predictors in your training dataset.");
            ((InfogramModel.InfogramParameters) this._parms)._top_n_features = this._numPredictors;
        }
        if (((InfogramModel.InfogramParameters) this._parms)._nparallelism < 0) {
            error("nparallelism", "must be >= 0.  If 0, it is adaptive");
        }
        if (((InfogramModel.InfogramParameters) this._parms)._nparallelism == 0) {
            ((InfogramModel.InfogramParameters) this._parms)._nparallelism = H2O.NUMCPUS;
        }
        if (((InfogramModel.InfogramParameters) this._parms)._compute_p_values) {
            error("compute_p_values", " compute_p_values calculation is not yet implemented.");
        }
        if (nclasses() < 2) {
            error("distribution", " infogram currently only supports classification models");
        }
        if (DistributionFamily.AUTO.equals(((InfogramModel.InfogramParameters) this._parms)._distribution)) {
            ((InfogramModel.InfogramParameters) this._parms)._distribution = nclasses() == 2 ? DistributionFamily.bernoulli : DistributionFamily.multinomial;
        }
        if (this._cvDone) {
            this._nFoldOrig = ((InfogramModel.InfogramParameters) this._parms)._nfolds;
            this._foldColumnOrig = ((InfogramModel.InfogramParameters) this._parms)._fold_column;
            this._foldAssignmentOrig = ((InfogramModel.InfogramParameters) this._parms)._fold_assignment;
            ((InfogramModel.InfogramParameters) this._parms)._fold_column = null;
            ((InfogramModel.InfogramParameters) this._parms)._nfolds = 0;
            ((InfogramModel.InfogramParameters) this._parms)._fold_assignment = null;
        }
    }
}
