package dev.brachtendorf.jimagehash.matcher.categorize.supervised.randomForest;

import com.github.kilianB.pcg.fast.PcgRSFast;
import dev.brachtendorf.ArrayUtil;
import dev.brachtendorf.MathUtil;
import dev.brachtendorf.Require;
import dev.brachtendorf.datastructures.CountHashCollection;
import dev.brachtendorf.datastructures.Pair;
import dev.brachtendorf.datastructures.Triple;
import dev.brachtendorf.jimagehash.hash.FuzzyHash;
import dev.brachtendorf.jimagehash.hash.Hash;
import dev.brachtendorf.jimagehash.hashAlgorithms.AverageHash;
import dev.brachtendorf.jimagehash.hashAlgorithms.HashingAlgorithm;
import dev.brachtendorf.jimagehash.hashAlgorithms.PerceptiveHash;
import dev.brachtendorf.jimagehash.hashAlgorithms.RotAverageHash;
import dev.brachtendorf.jimagehash.matcher.PlainImageMatcher;
import dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher;
import dev.brachtendorf.jimagehash.matcher.categorize.CategorizationResult;
import dev.brachtendorf.jimagehash.matcher.categorize.supervised.LabeledImage;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import javax.imageio.ImageIO;

/* loaded from: input_file:dev/brachtendorf/jimagehash/matcher/categorize/supervised/randomForest/RandomForestCategorizer.class */
/**
 * Supervised image categorizer backed by a forest of binary decision trees.
 *
 * <p>Every inner node of a tree compares the normalized hamming distance between an image's hash
 * (computed by one of the registered hashing algorithms) and a randomly generated {@link FuzzyHash}
 * against a threshold. Labeled images are registered via {@link #addTestImages(LabeledImage)}; the
 * forest is built by {@link #trainMatcher(int, int, int)} and queried by
 * {@link #categorizeImage(BufferedImage)} through a majority vote of all trees.
 */
public class RandomForestCategorizer extends PlainImageMatcher implements CategoricalImageMatcher {

    /** Trees of the most recently trained forest. */
    protected List<TreeNode> forest = new ArrayList<>();

    /** Labeled images used as training data. */
    protected List<LabeledImage> labeledImages = new ArrayList<>();

    /** Sorted set of every category seen among the labeled images. */
    protected TreeSet<Integer> categories = new TreeSet<>();

    /**
     * Registers all labeled images of the collection as training data.
     *
     * @param collection images to add
     */
    public void addTestImages(Collection<LabeledImage> collection) {
        for (LabeledImage labeledImage : collection) {
            addTestImages(labeledImage);
        }
    }

    /**
     * Registers all given labeled images as training data.
     *
     * @param labeledImageArr images to add
     */
    public void addTestImages(LabeledImage... labeledImageArr) {
        for (LabeledImage labeledImage : labeledImageArr) {
            addTestImages(labeledImage);
        }
    }

    /**
     * Registers a single labeled image as training data and records its category.
     *
     * @param labeledImage image to add
     */
    public void addTestImages(LabeledImage labeledImage) {
        this.categories.add(labeledImage.getCategory());
        this.labeledImages.add(labeledImage);
    }

    /** Removes all training images and their recorded categories. The trained forest is kept. */
    public void clearTestImages() {
        this.categories.clear();
        this.labeledImages.clear();
    }

    /**
     * Trains the forest on all labeled images added so far.
     *
     * <p>Forests are built for every number of split variables in
     * {@code [sqrt(#fuzzyHashes) - i2, sqrt(#fuzzyHashes) + i2)}; negative values are skipped and the
     * search stops once the value exceeds the number of fuzzy hashes. The forest built last replaces
     * {@link #forest}. Progress is printed to {@code System.out}.
     *
     * @param i  number of trees per forest; must be odd so a vote cannot tie between two trees
     * @param i2 search range around {@code sqrt(#fuzzyHashes)} for the number of candidate variables
     *           considered per split; must be positive
     * @param i3 how often every fuzzy hash is added to each tree's candidate pool
     */
    public void trainMatcher(int i, int i2, int i3) {
        Require.positiveValue(Integer.valueOf(i2), "NumVarsSearchRange has to be positive.");
        Require.oddValue(Integer.valueOf(i), "The number of trees should be odd to prevent ambiguity");
        System.out.println("");
        System.out.println("Hashing algos available: " + this.steps);

        // Pre-compute the hash of every training image for every algorithm once.
        Map<HashingAlgorithm, Map<BufferedImage, Hash>> hashes = new HashMap<>();
        for (HashingAlgorithm algorithm : getAlgorithms()) {
            Map<BufferedImage, Hash> imageHashes = new HashMap<>();
            for (LabeledImage labeledImage : this.labeledImages) {
                imageHashes.put(labeledImage.getbImage(), algorithm.hash(labeledImage.getbImage()));
            }
            hashes.put(algorithm, imageHashes);
        }

        List<Pair<FuzzyHash, HashingAlgorithm>> fuzzyHashes = createFuzzyHashes(hashes);
        System.out.println(fuzzyHashes);

        int size = fuzzyHashes.size();
        int sqrt = (int) Math.sqrt(size);
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        try {
            for (int numVars = sqrt - i2; numVars < sqrt + i2; numVars++) {
                System.out.println("Create Forest with number of vars: " + numVars + "/" + size);
                if (numVars < 0) {
                    continue;
                }
                if (numVars > size) {
                    break;
                }
                Object[] result = createForest(i, numVars, i3, fuzzyHashes, hashes, pool);
                @SuppressWarnings("unchecked") // createForest always stores a List<TreeNode> at index 0
                List<TreeNode> trees = (List<TreeNode>) result[0];
                double outOfBagError = ((Double) result[1]).doubleValue();
                double classError = ((Double) result[2]).doubleValue();
                System.out.println("Out of bag error: " + outOfBagError);
                System.out.println("Class error all : " + classError);
                this.forest = trees;
                System.out.println(" ");
            }
        } finally {
            // Release the worker threads even if building a forest throws.
            pool.shutdown();
        }
        testDecisionTree(this.labeledImages);
    }

    /**
     * Classifies every image with the current forest and prints the fraction of misclassified images.
     * An empty list prints {@code NaN}.
     */
    private void testDecisionTree(List<LabeledImage> list) {
        int correct = 0;
        int wrong = 0;
        for (LabeledImage labeledImage : list) {
            if (labeledImage.getCategory() == categorizeImage(labeledImage.getbImage()).getCategory()) {
                correct++;
            } else {
                wrong++;
            }
        }
        // Floating point division: integer division would always report 0 (or 1).
        System.out.println("Classification Error: " + ((double) wrong / (correct + wrong)));
    }

    /**
     * Builds {@code i} trees in parallel on the given executor.
     *
     * @param i               number of trees to build
     * @param i2              number of candidate variables examined per split
     * @param i3              how often each fuzzy hash is added to a tree's candidate pool
     * @param list            all fuzzy hash / algorithm pairs
     * @param map             precomputed image hashes per algorithm
     * @param executorService executor the tree builds are submitted to; not shut down here
     * @return {@code {List<TreeNode> trees, Double outOfBagError, Double classError}}. Both errors are
     *         currently not computed and always {@code 0.0}. Trees whose build failed are omitted; if
     *         the calling thread is interrupted, the trees collected so far are returned and the
     *         interrupt flag is restored.
     */
    protected Object[] createForest(int i, int i2, int i3, List<Pair<FuzzyHash, HashingAlgorithm>> list,
            Map<HashingAlgorithm, Map<BufferedImage, Hash>> map, ExecutorService executorService) {
        List<Future<TreeNode>> pending = new ArrayList<>(Math.max(i, 0));
        for (int tree = 0; tree < i; tree++) {
            // Each task returns its tree instead of adding to a shared list: ArrayList is not thread safe.
            pending.add(executorService.submit(() -> {
                CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> candidates = new CountHashCollection<>();
                for (int copy = 0; copy < i3; copy++) {
                    candidates.addAll(list);
                }
                return buildTree(new ArrayList<>(this.labeledImages), candidates, i2, map);
            }));
        }

        List<TreeNode> trees = new ArrayList<>(pending.size());
        for (Future<TreeNode> future : pending) {
            try {
                trees.add(future.get());
            } catch (ExecutionException e) {
                // Best effort: a tree that failed to build is left out of the forest.
                e.printStackTrace();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return new Object[] {trees, Double.valueOf(0.0d), Double.valueOf(0.0d)};
    }

    /** Draws {@code list.size()} samples with replacement from {@code list}. Currently unused. */
    private List<TestData> bootstrapDataset(List<TestData> list, Random random) {
        int size = list.size();
        List<TestData> sample = new ArrayList<>(size);
        for (int draw = 0; draw < size; draw++) {
            sample.add(list.get(random.nextInt(size)));
        }
        return sample;
    }

    /** Builds a tree for the given images, starting without an upper bound on the split quality. */
    private TreeNode buildTree(List<LabeledImage> images,
            CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> candidates, int numVars,
            Map<HashingAlgorithm, Map<BufferedImage, Hash>> hashes) {
        return buildTree(images, candidates, numVars, hashes, Double.MAX_VALUE);
    }

    /**
     * Recursively builds a (sub)tree. A child becomes a leaf (labeled with its side's majority
     * category) when its partition is empty or its impurity is zero.
     *
     * @param parentQuality impurity a split has to beat to be accepted
     */
    private TreeNode buildTree(List<LabeledImage> images,
            CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> candidates, int numVars,
            Map<HashingAlgorithm, Map<BufferedImage, Hash>> hashes, double parentQuality) {
        // Copy so the variable consumed by this node's split is only removed for this subtree.
        CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> remaining = new CountHashCollection<>(candidates);
        Triple<TreeNode, List<LabeledImage>[], int[]> node = computeNode(images, remaining, numVars, hashes, parentQuality);
        if (node.getFirst() instanceof LeafNode) {
            return (TreeNode) node.getFirst();
        }

        InnerNode inner = (InnerNode) node.getFirst();
        List<LabeledImage>[] partitions = node.getSecond();
        int[] majorities = node.getThird();

        if (partitions[0].size() <= 0 || MathUtil.isDoubleEquals(inner.qualityLeft, 0.0d, 1.0E-8d)) {
            System.out.println("Create Leaf node");
            inner.leftNode = new LeafNode(majorities[0]);
        } else {
            inner.leftNode = buildTree(partitions[0], remaining, numVars, hashes, inner.qualityLeft);
        }

        if (partitions[1].size() <= 0 || MathUtil.isDoubleEquals(inner.qualityRight, 0.0d, 1.0E-8d)) {
            System.out.println("Create Leaf node");
            inner.rightNode = new LeafNode(majorities[1]);
        } else {
            inner.rightNode = buildTree(partitions[1], remaining, numVars, hashes, inner.qualityRight);
        }
        return inner;
    }

    /**
     * Searches the best split for the given images among up to {@code numVars} randomly chosen
     * candidate variables.
     *
     * <p>For each candidate, every midpoint between two consecutive distinct distances is tried as a
     * threshold; images with a distance strictly below it go left. The split with the lowest
     * size-weighted binary Gini impurity wins; ties prefer the larger left partition. The winning
     * variable is removed from {@code candidates}.
     *
     * @return {@code (InnerNode, {left, right}, {leftMajority, rightMajority})} if the best split beats
     *         {@code parentQuality}, otherwise {@code (LeafNode(majority category), partitions, null)}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Triple<TreeNode, List<LabeledImage>[], int[]> computeNode(List<LabeledImage> images,
            CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> candidates, int numVars,
            Map<HashingAlgorithm, Map<BufferedImage, Hash>> hashes, double parentQuality) {
        double bestThreshold = 0.0d;
        Pair<FuzzyHash, HashingAlgorithm> bestPair = null;
        double bestQuality = Double.MAX_VALUE;
        double bestQualityLeft = Double.MAX_VALUE;
        double bestQualityRight = Double.MAX_VALUE;
        int bestLeftCategory = -1;
        int bestRightCategory = -1;
        int bestLeftSize = 0;
        List<LabeledImage>[] partitions = new List[] {new ArrayList<LabeledImage>(), new ArrayList<LabeledImage>()};

        // Visit the unique candidate variables in random order and examine the first numVars of them.
        Pair[] pool = (Pair[]) candidates.toArrayUnique(new Pair[candidates.sizeUnique()]);
        List<Integer> order = new ArrayList<>(pool.length);
        for (int idx = 0; idx < pool.length; idx++) {
            order.add(idx);
        }
        Collections.shuffle(order, new PcgRSFast());
        int tries = Math.min(numVars, order.size());

        for (int k = 0; k < tries; k++) {
            Pair<FuzzyHash, HashingAlgorithm> candidate = pool[order.get(k)];
            FuzzyHash fuzzyHash = candidate.getFirst();
            Map<BufferedImage, Hash> algorithmHashes = hashes.get(candidate.getSecond());

            Map<BufferedImage, Double> distances = new HashMap<>();
            TreeSet<Double> distinctDistances = new TreeSet<>();
            for (LabeledImage image : images) {
                double distance = fuzzyHash.normalizedHammingDistance(algorithmHashes.get(image.getbImage()));
                distances.put(image.getbImage(), distance);
                distinctDistances.add(distance);
            }

            // Midpoints between consecutive distinct distances guarantee two non-empty partitions.
            LinkedHashSet<Double> thresholds = new LinkedHashSet<>();
            Double previous = null;
            for (Double value : distinctDistances) {
                if (previous != null) {
                    thresholds.add((previous + value) / 2.0d);
                }
                previous = value;
            }

            for (double threshold : thresholds) {
                List<LabeledImage> left = new ArrayList<>();
                List<LabeledImage> right = new ArrayList<>();
                Map<Integer, Integer> leftCounts = new HashMap<>();
                Map<Integer, Integer> rightCounts = new HashMap<>();
                for (LabeledImage image : images) {
                    if (distances.get(image.getbImage()) < threshold) {
                        left.add(image);
                        leftCounts.merge(image.getCategory(), 1, Integer::sum);
                    } else {
                        right.add(image);
                        rightCounts.merge(image.getCategory(), 1, Integer::sum);
                    }
                }

                int[] leftMajority = majority(leftCounts);
                int[] rightMajority = majority(rightCounts);
                double giniLeft = binaryGini(leftMajority[1], left.size() - leftMajority[1]);
                double giniRight = binaryGini(rightMajority[1], right.size() - rightMajority[1]);
                double total = images.size();
                double quality = (left.size() / total) * giniLeft + (right.size() / total) * giniRight;

                if (quality < bestQuality || (quality == bestQuality && left.size() > bestLeftSize)) {
                    bestThreshold = threshold;
                    bestPair = candidate;
                    partitions[0] = left;
                    partitions[1] = right;
                    bestQuality = quality;
                    bestLeftSize = left.size();
                    bestLeftCategory = leftMajority[0];
                    bestRightCategory = rightMajority[0];
                    bestQualityLeft = giniLeft;
                    bestQualityRight = giniRight;
                }
            }
        }

        if (bestPair != null) {
            candidates.remove(bestPair);
        }
        if (bestPair != null && bestQuality < parentQuality
                && !MathUtil.isDoubleEquals(bestQualityLeft, parentQuality, 1.0E-8d)) {
            InnerNode inner = new InnerNode(bestPair.getFirst(), bestPair.getSecond(), bestThreshold,
                    bestQuality, bestQualityLeft, bestQualityRight);
            return new Triple<TreeNode, List<LabeledImage>[], int[]>(inner, partitions,
                    new int[] {bestLeftCategory, bestRightCategory});
        }

        // No acceptable split: label the node with the most frequent category (-1 if no images).
        Map<Integer, Integer> counts = new HashMap<>();
        for (LabeledImage image : images) {
            counts.merge(image.getCategory(), 1, Integer::sum);
        }
        return new Triple<TreeNode, List<LabeledImage>[], int[]>(new LeafNode(majority(counts)[0]), partitions, null);
    }

    /**
     * Returns {@code {category, count}} of the most frequent category; on ties the entry visited first
     * wins. Returns {@code {-1, 0}} for an empty map.
     */
    private static int[] majority(Map<Integer, Integer> counts) {
        int category = -1;
        int count = 0;
        boolean first = true;
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            if (first || entry.getValue() > count) {
                category = entry.getKey();
                count = entry.getValue();
                first = false;
            }
        }
        return new int[] {category, count};
    }

    /** Gini impurity {@code 1 - p^2 - q^2} of a node split into majority hits and misses. */
    private static double binaryGini(int hits, int misses) {
        double total = hits + misses;
        double p = hits / total;
        double q = misses / total;
        return 1.0d - p * p - q * q;
    }

    /**
     * Creates one random fuzzy hash per distinct training category for every registered hashing
     * algorithm. The map argument is currently unused.
     */
    private List<Pair<FuzzyHash, HashingAlgorithm>> createFuzzyHashes(Map<HashingAlgorithm, Map<BufferedImage, Hash>> map) {
        HashSet<Integer> distinctCategories = new HashSet<>();
        for (LabeledImage labeledImage : this.labeledImages) {
            distinctCategories.add(labeledImage.getCategory());
        }

        List<Pair<FuzzyHash, HashingAlgorithm>> fuzzyHashes = new ArrayList<>();
        for (HashingAlgorithm algorithm : this.steps) {
            int keyResolution = algorithm.getKeyResolution();
            Random rng = new PcgRSFast();
            for (int c = 0; c < distinctCategories.size(); c++) {
                Hash randomHash = new Hash(new BigInteger(keyResolution, rng), keyResolution, algorithm.algorithmId());
                fuzzyHashes.add(new Pair<>(new FuzzyHash(randomHash), algorithm));
            }
        }
        return fuzzyHashes;
    }

    /**
     * Counts the leaves of the first tree of the forest per category.
     *
     * @return category to number of leaves labeled with it
     * @throws IndexOutOfBoundsException if the forest is empty
     */
    public Map<Integer, Integer> countLeafCategories() {
        Map<Integer, Integer> counts = new HashMap<>();
        ArrayDeque<TreeNode> queue = new ArrayDeque<>();
        queue.add(this.forest.get(0));
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (node instanceof InnerNode) {
                InnerNode inner = (InnerNode) node;
                queue.add(inner.rightNode);
                queue.add(inner.leftNode);
            } else {
                counts.merge(((LeafNode) node).category, 1, Integer::sum);
            }
        }
        return counts;
    }

    /** Small manual demo training on the images in {@code src/test/resources}. */
    public static void main(String[] strArr) throws IOException {
        RandomForestCategorizer categorizer = new RandomForestCategorizer();
        categorizer.addHashingAlgorithm(new AverageHash(32));
        categorizer.addHashingAlgorithm(new PerceptiveHash(32));
        categorizer.addHashingAlgorithm(new RotAverageHash(32));
        categorizer.addTestImages(new LabeledImage(0, new File("src/test/resources/ballon.jpg")));
        categorizer.addTestImages(new LabeledImage(1, new File("src/test/resources/copyright.jpg")));
        categorizer.addTestImages(new LabeledImage(1, new File("src/test/resources/highQuality.jpg")));
        categorizer.addTestImages(new LabeledImage(1, new File("src/test/resources/lowQuality.jpg")));
        categorizer.addTestImages(new LabeledImage(2, new File("src/test/resources/Lenna.png")));
        categorizer.addTestImages(new LabeledImage(2, new File("src/test/resources/Lenna90.png")));
        categorizer.addTestImages(new LabeledImage(2, new File("src/test/resources/Lenna180.png")));
        categorizer.addTestImages(new LabeledImage(2, new File("src/test/resources/LennaSaltAndPepper.png")));
        categorizer.addTestImages(new LabeledImage(3, new File("src/test/resources/TestShapes.png")));
        categorizer.trainMatcher(3, 2, 1);
        categorizer.forest.get(0).printTree();
        System.out.println(categorizer.categorizeImage(ImageIO.read(new File("src/test/resources/lowQuality.jpg"))));
    }

    /**
     * Categorizes an image by majority vote of all trees.
     *
     * @return the winning category and the fraction of trees that voted for it
     * @throws IllegalStateException if the forest has not been trained
     */
    @Override // dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher
    public CategorizationResult categorizeImage(BufferedImage bufferedImage) {
        if (this.forest.isEmpty()) {
            throw new IllegalStateException("The forest is empty. Call trainMatcher before categorizing images.");
        }
        List<Integer> categories = getCategories();
        int[] votes = new int[categories.size()];
        for (TreeNode tree : this.forest) {
            int predicted = tree.predictAgainstAll(bufferedImage)[0];
            // Map the predicted category onto its position; categories need not be 0..k-1.
            int index = categories.indexOf(predicted);
            if (index != -1) {
                votes[index]++;
            }
        }
        int winner = ArrayUtil.maximumIndex(votes);
        return new CategorizationResult(categories.get(winner).intValue(), votes[winner] / (double) this.forest.size());
    }

    /** Returns a sorted copy of all categories seen among the training images. */
    @Override // dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher
    public List<Integer> getCategories() {
        return new ArrayList<>(this.categories);
    }

    /** No-op: categories are fixed by the labels of the training images. */
    @Override // dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher
    public void recomputeCategories() {
    }

    /** Not supported by this categorizer; always returns an empty list. */
    @Override // dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher
    public List<String> getImagesInCategory(int i) {
        return new ArrayList<>();
    }

    /** Not supported by this categorizer; always returns {@code 0}. */
    @Override // dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher
    public int getCategory(String str) {
        return 0;
    }

    /**
     * Prints the first tree of the forest.
     *
     * @throws IndexOutOfBoundsException if the forest is empty
     */
    public void printTree() {
        this.forest.get(0).printTree();
    }

    /** Unsupported: the forest would have to be rebuilt for every added image. */
    @Override // dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher
    public CategorizationResult categorizeImageAndAdd(BufferedImage bufferedImage, String str) {
        throw new UnsupportedOperationException("Can't add images on the fly. Rebuilding time to expensive");
    }
}
