package tl.lin.data.fd;

import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import tl.lin.data.fd.SortableEntries;
import tl.lin.data.map.Int2IntOpenHashMapWritable;
import tl.lin.data.pair.PairOfInts;

/* loaded from: input_file:tl/lin/data/fd/Int2IntFrequencyDistributionFastutil.class */
/**
 * Frequency distribution over {@code int} events, backed by an {@link Int2IntOpenHashMapWritable}
 * that maps each event to its count.
 *
 * <p>The total of all counts is cached in {@link #sumOfCounts}. Every mutating method funnels
 * through {@link #set(int, int)}, {@link #remove(int)} or {@link #clear()}, which keep that cache in
 * sync with the map. No synchronization is performed.
 */
public class Int2IntFrequencyDistributionFastutil implements Int2IntFrequencyDistribution {
    /** Event -> count. Only ever assigned here; {@link #readFields} refills it in place. */
    private final Int2IntOpenHashMapWritable counts = new Int2IntOpenHashMapWritable();

    /** Cached sum of all values in {@link #counts}; a long so the total cannot overflow an int. */
    private long sumOfCounts = 0;

    /*
     * Sort orders used by getEntries. Map keys (the left elements) are unique, so two distinct
     * entries never tie on the left element; the comparators still return 0 for identical
     * arguments, as the Comparator contract requires.
     */

    /** Count descending, then event ascending. */
    private final Comparator<PairOfInts> comparatorRightDescending = new Comparator<PairOfInts>() {
        @Override
        public int compare(PairOfInts a, PairOfInts b) {
            int byCount = Integer.compare(b.getRightElement(), a.getRightElement());
            return byCount != 0 ? byCount : Integer.compare(a.getLeftElement(), b.getLeftElement());
        }
    };

    /** Count ascending, then event ascending. */
    private final Comparator<PairOfInts> comparatorRightAscending = new Comparator<PairOfInts>() {
        @Override
        public int compare(PairOfInts a, PairOfInts b) {
            int byCount = Integer.compare(a.getRightElement(), b.getRightElement());
            return byCount != 0 ? byCount : Integer.compare(a.getLeftElement(), b.getLeftElement());
        }
    };

    /** Event ascending. */
    private final Comparator<PairOfInts> comparatorLeftAscending = new Comparator<PairOfInts>() {
        @Override
        public int compare(PairOfInts a, PairOfInts b) {
            return Integer.compare(a.getLeftElement(), b.getLeftElement());
        }
    };

    /** Event descending. */
    private final Comparator<PairOfInts> comparatorLeftDescending = new Comparator<PairOfInts>() {
        @Override
        public int compare(PairOfInts a, PairOfInts b) {
            return Integer.compare(b.getLeftElement(), a.getLeftElement());
        }
    };

    /**
     * Adds one observation of {@code i}.
     *
     * @param i the event
     */
    public void increment(int i) {
        increment(i, 1);
    }

    /**
     * Adds {@code i2} observations of {@code i}, creating the event if it is not yet present.
     *
     * @param i the event
     * @param i2 the amount to add
     */
    public void increment(int i, int i2) {
        if (contains(i)) {
            set(i, get(i) + i2);
        } else {
            set(i, i2);
        }
    }

    /**
     * Removes one observation of {@code i}; the event is removed entirely when its count hits zero.
     *
     * @param i the event
     * @throws RuntimeException if the event is absent or its count would go below zero
     */
    public void decrement(int i) {
        decrement(i, 1);
    }

    /**
     * Removes {@code i2} observations of {@code i}; the event is removed entirely when its count
     * reaches exactly zero.
     *
     * @param i the event
     * @param i2 the amount to subtract
     * @throws RuntimeException if the event is absent or its count is smaller than {@code i2}
     */
    public void decrement(int i, int i2) {
        if (!contains(i)) {
            throw new RuntimeException("Can't decrement non-existent event!");
        }
        int current = get(i);
        if (current < i2) {
            throw new RuntimeException("Can't decrement past zero!");
        }
        if (current == i2) {
            remove(i);
        } else {
            set(i, current - i2);
        }
    }

    /** Returns whether {@code i} has an entry in the distribution. */
    public boolean contains(int i) {
        return counts.containsKey(i);
    }

    /**
     * Returns the count of {@code i}. For an absent event this is the backing map's default return
     * value (presumably 0 — NOTE(review): confirm Int2IntOpenHashMapWritable does not change it).
     */
    public int get(int i) {
        return counts.get(i);
    }

    /**
     * Returns {@code count(i) / sumOfCounts} as a floating-point fraction.
     *
     * <p>The count is widened to {@code double} before dividing; dividing the raw {@code int} by
     * the {@code long} sum would be integer division and truncate to 0. With an empty
     * distribution, the result is {@code NaN}; it is not an {@link ArithmeticException}.
     */
    public double computeRelativeFrequency(int i) {
        return (double) counts.get(i) / getSumOfCounts();
    }

    /** Returns {@code log(count(i)) - log(sumOfCounts)}, i.e. the natural log of the relative frequency. */
    public double computeLogRelativeFrequency(int i) {
        return Math.log(counts.get(i)) - Math.log(getSumOfCounts());
    }

    /**
     * Sets the count of {@code i} to {@code i2} and adjusts the cached sum.
     *
     * @return the previous count, or the map's default return value if {@code i} was absent
     */
    public int set(int i, int i2) {
        int previous = counts.put(i, i2);
        sumOfCounts = sumOfCounts - previous + i2;
        return previous;
    }

    /**
     * Removes {@code i} and subtracts its count from the cached sum.
     *
     * @return the removed count, or the map's default return value if {@code i} was absent
     */
    public int remove(int i) {
        int removed = counts.remove(i);
        sumOfCounts -= removed;
        return removed;
    }

    /** Removes every event and resets the cached sum to zero. */
    public void clear() {
        counts.clear();
        sumOfCounts = 0L;
    }

    /**
     * Returns the backing map's key set. NOTE(review): presumably a live view; mutating through it
     * would bypass the cached sum.
     */
    public IntSet keySet() {
        return counts.keySet();
    }

    /** Returns the backing map's value collection (see the note on {@link #keySet()}). */
    public IntCollection values() {
        return counts.values();
    }

    /** Returns the backing map's entry set (see the note on {@link #keySet()}). */
    public Int2IntMap.FastEntrySet entrySet() {
        return counts.int2IntEntrySet();
    }

    /** Returns the number of distinct events. */
    public int getNumberOfEvents() {
        return counts.size();
    }

    /** Returns the cached sum of all counts. */
    public long getSumOfCounts() {
        return sumOfCounts;
    }

    /**
     * Iterates over (event, count) pairs in the backing map's iteration order.
     *
     * <p>The same {@link PairOfInts} instance is reused and overwritten on every call to
     * {@code next()}; copy it if you need to keep it. {@code next()} on an exhausted iterator
     * throws {@link java.util.NoSuchElementException} (from the underlying iterator) per the
     * {@link Iterator} contract. {@code remove()} is unsupported.
     */
    public Iterator<PairOfInts> iterator() {
        final Iterator<Int2IntMap.Entry> entries = counts.int2IntEntrySet().iterator();
        return new Iterator<PairOfInts>() {
            private final PairOfInts pair = new PairOfInts();

            @Override
            public boolean hasNext() {
                return entries.hasNext();
            }

            @Override
            public PairOfInts next() {
                Int2IntMap.Entry entry = entries.next();
                pair.set(entry.getIntKey(), entry.getIntValue());
                return pair;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Returns all (event, count) pairs as fresh objects sorted in the given order.
     *
     * @return the sorted list, or {@code null} if {@code order} is not one of the four supported
     *     orders
     */
    public List<PairOfInts> getEntries(SortableEntries.Order order) {
        Comparator<PairOfInts> comparator = comparatorFor(order);
        return comparator == null ? null : getEntriesSorted(comparator);
    }

    /**
     * Returns the first {@code i} (event, count) pairs in the given order. If the distribution has
     * fewer than {@code i} events, all of them are returned.
     *
     * @return the sorted prefix, or {@code null} if {@code order} is not one of the four supported
     *     orders
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public List<PairOfInts> getEntries(SortableEntries.Order order, int i) {
        Comparator<PairOfInts> comparator = comparatorFor(order);
        return comparator == null ? null : getEntriesSorted(comparator, i);
    }

    /** Maps a sort order to its comparator; {@code null} for an unrecognized order. */
    private Comparator<PairOfInts> comparatorFor(SortableEntries.Order order) {
        if (order.equals(SortableEntries.Order.ByRightElementDescending)) {
            return comparatorRightDescending;
        }
        if (order.equals(SortableEntries.Order.ByLeftElementAscending)) {
            return comparatorLeftAscending;
        }
        if (order.equals(SortableEntries.Order.ByRightElementAscending)) {
            return comparatorRightAscending;
        }
        if (order.equals(SortableEntries.Order.ByLeftElementDescending)) {
            return comparatorLeftDescending;
        }
        return null;
    }

    /** Copies every entry into a new, presized list and sorts it with {@code comparator}. */
    private List<PairOfInts> getEntriesSorted(Comparator<PairOfInts> comparator) {
        List<PairOfInts> entries = Lists.newArrayListWithCapacity(counts.size());
        for (Int2IntMap.Entry entry : counts.int2IntEntrySet()) {
            entries.add(new PairOfInts(entry.getIntKey(), entry.getIntValue()));
        }
        Collections.sort(entries, comparator);
        return entries;
    }

    /**
     * Returns at most {@code n} sorted entries. The limit is clamped to the number of events, so
     * asking for more than exist does not throw. The prefix is copied, so the returned list does
     * not pin the full sorted list in memory.
     */
    private List<PairOfInts> getEntriesSorted(Comparator<PairOfInts> comparator, int n) {
        List<PairOfInts> sorted = getEntriesSorted(comparator);
        return new ArrayList<PairOfInts>(sorted.subList(0, Math.min(n, sorted.size())));
    }

    /** Deserializes the cached sum (a long) followed by the backing map. */
    public void readFields(DataInput dataInput) throws IOException {
        sumOfCounts = dataInput.readLong();
        counts.readFields(dataInput);
    }

    /** Serializes the cached sum (a long) followed by the backing map, mirroring {@link #readFields}. */
    public void write(DataOutput dataOutput) throws IOException {
        dataOutput.writeLong(sumOfCounts);
        counts.write(dataOutput);
    }
}
