package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/meta/RealAdaBoost.class */
/**
 * Meta classifier that boosts a two-class base classifier using the Real AdaBoost method.
 *
 * <p>The ensemble starts from a {@link ZeroR} model and then adds up to
 * {@code m_Classifiers.length} base classifiers. Every member contributes half the log-odds of
 * its (smoothed) class-0 probability estimate; contributions of the boosted members are
 * multiplied by the shrinkage parameter. Training stops early as soon as the summed log of the
 * instance weights increases from one iteration to the next.
 *
 * <p>Options handled here: {@code -P <num>} (weight threshold), {@code -Q} (use resampling) and
 * {@code -H <num>} (shrinkage); the remaining options are delegated to the superclass.
 */
public class RealAdaBoost extends RandomizableIteratedSingleClassifierEnhancer implements WeightedInstancesHandler, TechnicalInformationHandler {
    /** Serialization version identifier. */
    static final long serialVersionUID = -7378109809933197974L;

    /**
     * Number of boosted base classifiers used for prediction. While a model is being built it
     * doubles as the loop counter and is -1 during the initial ZeroR iteration.
     */
    protected int m_NumIterationsPerformed;

    /** Percentage of the weight mass to train on; values below 100 enable weight pruning. */
    protected int m_WeightThreshold = 100;

    /** Factor applied to the contribution of every boosted base classifier. */
    protected double m_Shrinkage = 1.0d;

    /** Whether to train on weighted resamples instead of reweighted data. */
    protected boolean m_UseResampling;

    /** Initial model of the ensemble; {@code null} until a model has been built. */
    protected Classifier m_ZeroR;

    /** Sum of the instance weights of the training data that has a class value. */
    protected double m_SumOfWeights;

    /** Creates the booster with a {@link DecisionStump} as base classifier. */
    public RealAdaBoost() {
        m_Classifier = new DecisionStump();
    }

    /**
     * Returns a short description of this classifier including the literature reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Class for boosting a 2-class classifier using the Real Adaboost method.\n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the Real AdaBoost method.
     *
     * @return the technical information describing the article
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        info.setValue(TechnicalInformation.Field.AUTHOR, "J. Friedman and T. Hastie and R. Tibshirani");
        info.setValue(TechnicalInformation.Field.TITLE, "Additive Logistic Regression: a Statistical View of Boosting");
        info.setValue(TechnicalInformation.Field.JOURNAL, "Annals of Statistics");
        // The article appeared in volume 28, number 2 of the Annals of Statistics (2000).
        info.setValue(TechnicalInformation.Field.VOLUME, "28");
        info.setValue(TechnicalInformation.Field.NUMBER, "2");
        info.setValue(TechnicalInformation.Field.PAGES, "337-407");
        info.setValue(TechnicalInformation.Field.YEAR, "2000");
        return info;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified class name of {@link DecisionStump}
     */
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Copies the instances that carry the bulk of the weight mass into a new dataset.
     *
     * <p>The index array returned by {@code Utils.sort} is walked from its end (the heaviest
     * instances, given an ascending sort). Instances are added until the accumulated weight
     * exceeds {@code d} times the total weight; the walk only stops between two instances whose
     * weights differ, so instances of equal weight are never split.
     *
     * @param instances the weighted data to select from
     * @param d the fraction of the total weight mass to cover
     * @return a new dataset holding copies of the selected instances
     */
    protected Instances selectWeightQuantile(Instances instances, double d) {
        int numInstances = instances.numInstances();
        Instances selected = new Instances(instances, numInstances);
        double[] weights = new double[numInstances];
        double totalWeight = 0.0d;
        for (int i = 0; i < numInstances; i++) {
            weights[i] = instances.instance(i).weight();
            totalWeight += weights[i];
        }
        double targetMass = totalWeight * d;
        int[] order = Utils.sort(weights);
        double coveredMass = 0.0d;
        for (int pos = numInstances - 1; pos >= 0; pos--) {
            int index = order[pos];
            selected.add((Instance) instances.instance(index).copy());
            coveredMass += weights[index];
            if (coveredMass > targetMass && pos > 0 && weights[index] != weights[order[pos - 1]]) {
                break;
            }
        }
        if (m_Debug) {
            System.err.println("Selected " + selected.numInstances() + " out of " + numInstances);
        }
        return selected;
    }

    /**
     * Lists the options of this classifier followed by the options of the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        options.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        options.addElement(new Option("\tShrinkage parameter.\n\t(default 1)", "H", 1, "-H <num>"));
        Enumeration<?> inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            options.addElement((Option) inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options {@code -P}, {@code -H} and {@code -Q} and passes the option array on
     * to the superclass. Absent {@code -P} / {@code -H} values reset the weight threshold to 100
     * and the shrinkage to 1.
     *
     * @param strArr the option array to parse
     * @throws Exception if an option value cannot be parsed or the superclass rejects an option
     */
    public void setOptions(String[] strArr) throws Exception {
        String thresholdText = Utils.getOption('P', strArr);
        if (thresholdText.length() != 0) {
            setWeightThreshold(Integer.parseInt(thresholdText));
        } else {
            setWeightThreshold(100);
        }
        String shrinkageText = Utils.getOption('H', strArr);
        if (shrinkageText.length() != 0) {
            setShrinkage(Double.parseDouble(shrinkageText));
        } else {
            setShrinkage(1.0d);
        }
        setUseResampling(Utils.getFlag('Q', strArr));
        super.setOptions(strArr);
    }

    /**
     * Returns the current settings as an option array: this class's options first, then those
     * of the superclass.
     *
     * @return the option array
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (getUseResampling()) {
            options.add("-Q");
        }
        options.add("-P");
        options.add("" + getWeightThreshold());
        options.add("-H");
        options.add("" + getShrinkage());
        for (String inherited : super.getOptions()) {
            options.add(inherited);
        }
        return options.toArray(new String[options.size()]);
    }

    /**
     * Returns the tip text for the shrinkage property.
     *
     * @return the tip text
     */
    public String shrinkageTipText() {
        return "Shrinkage parameter (use small value like 0.1 to reduce overfitting).";
    }

    /**
     * Returns the shrinkage parameter.
     *
     * @return the shrinkage
     */
    public double getShrinkage() {
        return m_Shrinkage;
    }

    /**
     * Sets the shrinkage parameter.
     *
     * @param d the new shrinkage
     */
    public void setShrinkage(double d) {
        m_Shrinkage = d;
    }

    /**
     * Returns the tip text for the weight threshold property.
     *
     * @return the tip text
     */
    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    /**
     * Sets the percentage of weight mass to train on.
     *
     * @param i the new weight threshold
     */
    public void setWeightThreshold(int i) {
        m_WeightThreshold = i;
    }

    /**
     * Returns the percentage of weight mass to train on.
     *
     * @return the weight threshold
     */
    public int getWeightThreshold() {
        return m_WeightThreshold;
    }

    /**
     * Returns the tip text for the resampling property.
     *
     * @return the tip text
     */
    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    /**
     * Sets whether resampling is used instead of reweighting.
     *
     * @param z {@code true} to use resampling
     */
    public void setUseResampling(boolean z) {
        m_UseResampling = z;
    }

    /**
     * Returns whether resampling is used instead of reweighting.
     *
     * @return {@code true} if resampling is used
     */
    public boolean getUseResampling() {
        return m_UseResampling;
    }

    /**
     * Returns the capabilities of this classifier: those of the superclass with all class
     * capabilities and class dependencies disabled, and binary classes re-enabled only if the
     * superclass capabilities handle them.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        if (super.getCapabilities().handles(Capabilities.Capability.BINARY_CLASS)) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        return result;
    }

    /**
     * Builds the boosted ensemble. Resampling is used when requested or when the base
     * classifier is not a {@link WeightedInstancesHandler}; otherwise the data is reweighted.
     *
     * @param instances the training data; it is copied, and instances with a missing class are
     *     removed from the copy
     * @throws Exception if the data fails the capabilities test or a member cannot be built
     */
    public void buildClassifier(Instances instances) throws Exception {
        super.buildClassifier(instances);
        getCapabilities().testWithFail(instances);
        Instances training = new Instances(instances);
        training.deleteWithMissingClass();
        m_SumOfWeights = training.sumOfWeights();
        if (m_UseResampling || !(m_Classifier instanceof WeightedInstancesHandler)) {
            buildClassifierUsingResampling(training);
        } else {
            buildClassifierWithWeights(training);
        }
    }

    /**
     * Runs the boosting loop, training every base classifier on a weighted resample.
     *
     * @param instances the training data (not modified; a copy carries the boosting weights)
     * @throws Exception if a member of the ensemble cannot be built
     */
    protected void buildClassifierUsingResampling(Instances instances) throws Exception {
        int numInstances = instances.numInstances();
        Random random = new Random(m_Seed);
        double previousLoss = Double.MAX_VALUE;
        // Working copy whose weights are updated by the boosting procedure.
        Instances boosted = new Instances(instances, 0, numInstances);
        for (m_NumIterationsPerformed = -1; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++) {
            if (m_Debug) {
                System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));
            }
            Instances normalized = new Instances(boosted);
            normalizeWeights(normalized, 1.0d);
            Instances trainData = m_WeightThreshold < 100 ? selectWeightQuantile(normalized, m_WeightThreshold / 100.0d) : new Instances(normalized);
            double[] weights = new double[trainData.numInstances()];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = trainData.instance(i).weight();
            }
            // The resample is drawn in every iteration, including the ZeroR one, so that the
            // random sequence stays the same as before.
            Instances sample = trainData.resampleWithWeights(random, weights);
            if (m_NumIterationsPerformed == -1) {
                m_ZeroR = new ZeroR();
                m_ZeroR.buildClassifier(instances);
            } else {
                m_Classifiers[m_NumIterationsPerformed].buildClassifier(sample);
            }
            setWeights(boosted, m_NumIterationsPerformed);
            double loss = logLoss(boosted);
            if (m_Debug) {
                System.err.println("Current loss on log scale: " + loss);
            }
            if (m_NumIterationsPerformed > -1 && loss > previousLoss) {
                if (m_Debug) {
                    System.err.println("Loss has increased: bailing out.");
                }
                // The member that increased the loss is excluded: the counter is not advanced.
                break;
            }
            previousLoss = loss;
        }
    }

    /**
     * Multiplies the weight of every instance by the exponential of the half-log-odds
     * contribution of the given ensemble member: the contribution itself for instances whose
     * class value is 1, its negation otherwise.
     *
     * @param instances the data whose weights are updated in place
     * @param i the index of the ensemble member; -1 denotes the ZeroR model
     * @throws Exception if the member cannot produce a class distribution
     */
    protected void setWeights(Instances instances, int i) throws Exception {
        for (Instance instance : instances) {
            double contribution = halfLogOdds(i, instance);
            double exponent = instance.classValue() == 1.0d ? contribution : -contribution;
            instance.setWeight(instance.weight() * Math.exp(exponent));
        }
    }

    /**
     * Rescales the instance weights in place so that they sum to the given value.
     *
     * @param instances the data whose weights are rescaled
     * @param d the desired sum of weights
     * @throws Exception declared for subclasses; not thrown by this implementation
     */
    protected void normalizeWeights(Instances instances, double d) throws Exception {
        double currentSum = instances.sumOfWeights();
        for (Instance instance : instances) {
            instance.setWeight((instance.weight() * d) / currentSum);
        }
    }

    /**
     * Runs the boosting loop, training every base classifier directly on reweighted data.
     *
     * @param instances the training data (not modified; a copy carries the boosting weights)
     * @throws Exception if a member of the ensemble cannot be built
     */
    protected void buildClassifierWithWeights(Instances instances) throws Exception {
        int numInstances = instances.numInstances();
        Random random = new Random(m_Seed);
        double previousLoss = Double.MAX_VALUE;
        // Working copy whose weights are updated by the boosting procedure.
        Instances boosted = new Instances(instances, 0, numInstances);
        for (m_NumIterationsPerformed = -1; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++) {
            if (m_Debug) {
                System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));
            }
            Instances normalized = new Instances(boosted);
            normalizeWeights(normalized, m_SumOfWeights);
            Instances trainData = m_WeightThreshold < 100 ? selectWeightQuantile(normalized, m_WeightThreshold / 100.0d) : new Instances(normalized, 0, numInstances);
            if (m_NumIterationsPerformed == -1) {
                m_ZeroR = new ZeroR();
                m_ZeroR.buildClassifier(instances);
            } else {
                Classifier current = m_Classifiers[m_NumIterationsPerformed];
                if (current instanceof Randomizable) {
                    // setSeed is declared by Randomizable, not by Classifier: the cast is required.
                    ((Randomizable) current).setSeed(random.nextInt());
                }
                current.buildClassifier(trainData);
            }
            setWeights(boosted, m_NumIterationsPerformed);
            double loss = logLoss(boosted);
            if (m_Debug) {
                System.err.println("Current loss on log scale: " + loss);
            }
            if (m_NumIterationsPerformed > -1 && loss > previousLoss) {
                if (m_Debug) {
                    System.err.println("Loss has increased: bailing out.");
                }
                // The member that increased the loss is excluded: the counter is not advanced.
                break;
            }
            previousLoss = loss;
        }
    }

    /**
     * Computes the class distribution for an instance by summing the half-log-odds
     * contributions of the ZeroR model and of the boosted members, then converting the two
     * opposite scores into probabilities.
     *
     * @param instance the instance to classify
     * @return the class probability distribution
     * @throws IllegalStateException if no model has been built yet
     * @throws Exception if a member cannot produce a class distribution
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_ZeroR == null) {
            throw new IllegalStateException("No model built yet.");
        }
        double[] scores = new double[instance.numClasses()];
        for (int i = -1; i < m_NumIterationsPerformed; i++) {
            scores[0] = scores[0] + halfLogOdds(i, instance);
        }
        scores[1] = -scores[0];
        return Utils.logs2probs(scores);
    }

    /**
     * Returns a textual description of the ensemble: the ZeroR model, every boosted member that
     * is used for prediction, and the number of performed iterations.
     *
     * @return the description, or a notice if no model has been built
     */
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (m_ZeroR == null) {
            text.append("No model built yet.\n\n");
        } else {
            text.append("RealAdaBoost: Base classifiers: \n\n");
            text.append(m_ZeroR.toString()).append("\n\n");
            for (int i = 0; i < m_NumIterationsPerformed; i++) {
                text.append(m_Classifiers[i].toString()).append("\n\n");
            }
            text.append("Number of performed Iterations: ").append(m_NumIterationsPerformed).append("\n");
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8109 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new RealAdaBoost(), strArr);
    }

    /**
     * Computes one ensemble member's contribution to the class-0 score: half the log-odds of
     * its class-0 probability. For boosted members the probability is first smoothed as
     * {@code (sumOfWeights * p + 1) / (sumOfWeights + 2)} and the result is multiplied by the
     * shrinkage; the ZeroR model (iteration -1) is used unsmoothed with a factor of 1.
     */
    private double halfLogOdds(int iteration, Instance instance) throws Exception {
        double probClassZero;
        double shrinkage;
        if (iteration == -1) {
            probClassZero = m_ZeroR.distributionForInstance(instance)[0];
            shrinkage = 1.0d;
        } else {
            probClassZero = ((m_SumOfWeights * m_Classifiers[iteration].distributionForInstance(instance)[0]) + 1.0d) / (m_SumOfWeights + 2.0d);
            shrinkage = m_Shrinkage;
        }
        return shrinkage * 0.5d * (Math.log(probClassZero) - Math.log(1.0d - probClassZero));
    }

    /** Sums the natural logarithms of all instance weights (the loss on the log scale). */
    private double logLoss(Instances instances) {
        double loss = 0.0d;
        for (Instance instance : instances) {
            loss += Math.log(instance.weight());
        }
        return loss;
    }
}
