package ws.palladian.classification.utils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ws.palladian.classification.zeror.ZeroRLearner;
import ws.palladian.core.Instance;
import ws.palladian.helper.collection.Bag;
import ws.palladian.helper.collection.LazyMap;
import ws.palladian.helper.functional.Factories;

/* loaded from: input_file:ws/palladian/classification/utils/ClassDistributionResampler.class */
/**
 * Re-samples a data set so that its class distribution is (approximately) equalized, optionally
 * biased by per-category weights.
 * <p>
 * Each instance of category {@code c} is kept with probability
 * {@code (minProbability / probability(c)) * weight(c)}, where {@code probability(c)} is the
 * category probability reported by {@link ZeroRLearner} for the input data and
 * {@code minProbability} is the smallest of those probabilities. Categories without an explicit
 * weight use a weight of {@code 1.0}. Acceptance values above {@code 1.0} mean every instance of
 * that category is kept.
 * <p>
 * Sampling is performed once, in the constructor; the resulting sample is exposed read-only via
 * {@link #iterator()}. The supplied {@link Iterable} is traversed twice (once to compute the
 * category probabilities, once to sample), so it must support repeated iteration yielding the
 * same data.
 */
public class ClassDistributionResampler implements Iterable<Instance> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ClassDistributionResampler.class);

    /** Shared default source of randomness, used when no explicit {@link Random} is supplied. */
    private static final Random RANDOM = new Random();

    /** The re-sampled instances, wrapped as an unmodifiable view. */
    private final Collection<Instance> sampled;

    /** Category probabilities of the input data, as computed by {@link ZeroRLearner}. */
    private final Map<String, Double> probabilities;

    /** Per-category weights; categories without an explicit entry resolve to {@code 1.0}. */
    private final Map<String, Double> weights;

    /**
     * Re-samples the given data with all category weights set to {@code 1.0}.
     *
     * @param iterable the data to re-sample; must not be {@code null} and must be re-iterable.
     */
    public ClassDistributionResampler(Iterable<Instance> iterable) {
        this(iterable, Collections.<String, Double> emptyMap());
    }

    /**
     * Re-samples the given data using the given per-category weights and a shared random source.
     *
     * @param iterable the data to re-sample; must not be {@code null} and must be re-iterable.
     * @param map per-category weights; must not be {@code null}. Missing categories use {@code 1.0}.
     */
    public ClassDistributionResampler(Iterable<Instance> iterable, Map<String, Double> map) {
        this(iterable, map, RANDOM);
    }

    /**
     * Re-samples the given data using the given per-category weights and random source. Passing a
     * seeded {@link Random} makes the sample reproducible.
     *
     * @param iterable the data to re-sample; must not be {@code null} and must be re-iterable.
     * @param map per-category weights; must not be {@code null}. Missing categories use {@code 1.0}.
     * @param random the random source used for the keep/drop decisions; must not be {@code null}.
     */
    public ClassDistributionResampler(Iterable<Instance> iterable, Map<String, Double> map, Random random) {
        Validate.notNull(iterable, "data must not be null");
        Validate.notNull(map, "weights must not be null");
        Validate.notNull(random, "random must not be null");
        this.probabilities = new ZeroRLearner().train(iterable).getCategoryProbabilities();
        // Defensive copy of the caller's weights; the lazy map supplies 1.0 for unknown categories.
        this.weights = new LazyMap<String, Double>(new HashMap<String, Double>(map),
                Factories.constant(Double.valueOf(1.0)));
        LOGGER.info("Class probabilities : {}", probabilities);
        // Read-only view, so that iterator().remove() cannot mutate the sample.
        this.sampled = Collections.unmodifiableCollection(reSample(iterable, random));
    }

    /**
     * Performs the actual sampling pass over the data.
     *
     * @param iterable the data to sample from.
     * @param random the random source for the keep/drop decisions.
     * @return the kept instances, in input order.
     */
    private Collection<Instance> reSample(Iterable<Instance> iterable, Random random) {
        double minProbability = Double.MAX_VALUE;
        for (Double probability : probabilities.values()) {
            minProbability = Math.min(minProbability, probability.doubleValue());
        }
        Collection<Instance> result = new ArrayList<Instance>();
        Bag<String> counts = new Bag<String>();
        for (Instance instance : iterable) {
            String category = instance.getCategory();
            Double probability = probabilities.get(category);
            if (probability == null) {
                // Only possible if the iterable yielded different data on this second traversal.
                throw new IllegalStateException("Category '" + category
                        + "' was not seen when computing class probabilities; the data must be re-iterable");
            }
            double acceptance = (minProbability / probability.doubleValue()) * weights.get(category).doubleValue();
            if (random.nextDouble() <= acceptance) {
                result.add(instance);
                counts.add(category);
            }
        }
        LOGGER.info("Re-weighted class counts: {}", counts);
        return result;
    }

    /**
     * Returns an iterator over the re-sampled instances. The iterator does not support removal.
     */
    @Override
    public Iterator<Instance> iterator() {
        return sampled.iterator();
    }

    @Override
    public String toString() {
        return "ClassDistributionResampler [#items=" + sampled.size() + ", probabilities=" + probabilities
                + ", weights=" + weights + "]";
    }
}
