package net.haesleinhuepf.clijx.weka;

import hr.irb.fastRandomForest.FastRandomForest;
import ij.IJ;
import ij.ImageJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.Prefs;
import ij.process.FloatProcessor;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import net.haesleinhuepf.clij.clearcl.ClearCLBuffer;
import net.haesleinhuepf.clij2.CLIJ2;
import trainableSegmentation.WekaSegmentation;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

/* JADX INFO: Access modifiers changed from: package-private */
@Deprecated
/* loaded from: input_file:net/haesleinhuepf/clijx/weka/CLIJxWeka.class */
public class CLIJxWeka {
    private FastRandomForest classifier;
    private Integer numberOfClasses;
    private Integer numberOfFeatures;
    private CLIJ2 clij2;
    private ClearCLBuffer featureStack;
    ClearCLBuffer classification;
    private ClearCLBuffer distribution;
    private int frf_numberOfTrees = 200;
    private int frf_maxDepth = 0;
    private int frf_numberOfFeatures = 2;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:net/haesleinhuepf/clijx/weka/CLIJxWeka$Classificator.class */
    public static class Classificator implements Runnable {
        private final Instances dataSet;
        private final AbstractClassifier classifier;
        private final int numberOfFeatures;
        float[] features;
        float[] classes;
        private final int width;
        private final int height;

        public Classificator(float[] fArr, float[] fArr2, int i, int i2, Instances instances, AbstractClassifier abstractClassifier, int i3) {
            this.features = fArr;
            this.classes = fArr2;
            this.width = i;
            this.height = i2;
            this.dataSet = instances;
            this.classifier = abstractClassifier;
            this.numberOfFeatures = i3;
        }

        @Override // java.lang.Runnable
        public void run() {
            for (int i = 0; i < this.height; i++) {
                double[] dArr = new double[this.numberOfFeatures + 1];
                for (int i2 = 0; i2 < this.numberOfFeatures; i2++) {
                    dArr[i2] = this.features[(i * this.numberOfFeatures) + i2];
                }
                DenseInstance denseInstance = new DenseInstance(1.0d, dArr);
                denseInstance.setDataset(this.dataSet);
                try {
                    this.classes[i] = ((float) this.classifier.classifyInstance(denseInstance)) + 1.0f;
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

        public float[] getClasses() {
            return this.classes;
        }
    }

    public CLIJxWeka(CLIJ2 clij2, ClearCLBuffer clearCLBuffer, ClearCLBuffer clearCLBuffer2) {
        this.clij2 = clij2;
        this.featureStack = clearCLBuffer;
        this.classification = clearCLBuffer2;
    }

    public CLIJxWeka(CLIJ2 clij2, ClearCLBuffer clearCLBuffer, FastRandomForest fastRandomForest, Integer num) {
        this.clij2 = clij2;
        this.featureStack = clearCLBuffer;
        this.classifier = fastRandomForest;
        this.numberOfClasses = num;
        this.numberOfFeatures = Integer.valueOf((int) clearCLBuffer.getDepth());
    }

    public CLIJxWeka(CLIJ2 clij2, ClearCLBuffer clearCLBuffer, String str) {
        this.clij2 = clij2;
        this.featureStack = clearCLBuffer;
        loadClassifier(str);
    }

    private void trainClassifier() {
        if (this.classifier != null) {
            System.out.println("Already trained.");
            return;
        }
        if (this.classification == null) {
            System.out.println("No ground truth available");
            return;
        }
        this.numberOfClasses = Integer.valueOf((int) this.clij2.maximumOfAllPixels(this.classification));
        this.numberOfFeatures = Integer.valueOf((int) this.featureStack.getDepth());
        ArrayList<Attribute> makeAttributes = makeAttributes(this.numberOfClasses.intValue(), this.numberOfFeatures.intValue());
        System.out.println("att size" + makeAttributes.size());
        Instances instances = new Instances("segment", makeAttributes, 1);
        instances.setClassIndex(makeAttributes.size() - 1);
        featureStackToInstance(this.clij2, this.featureStack, this.classification, instances);
        System.out.println("Balance training data");
        System.out.println("Num classes " + instances.numClasses());
        Instances balanceTrainingData = WekaSegmentation.balanceTrainingData(instances);
        System.out.println("Init classifier");
        FastRandomForest fastRandomForest = new FastRandomForest();
        fastRandomForest.setNumTrees(this.frf_numberOfTrees);
        fastRandomForest.setSeed(new Random().nextInt());
        fastRandomForest.setNumFeatures(this.frf_numberOfFeatures);
        fastRandomForest.setNumThreads(Prefs.getThreads());
        fastRandomForest.setMaxDepth(this.frf_maxDepth);
        System.out.println("Train classifier");
        try {
            fastRandomForest.buildClassifier(balanceTrainingData);
        } catch (InterruptedException e) {
            IJ.log("Classifier construction was interrupted.");
        } catch (Exception e2) {
            IJ.showMessage(e2.getMessage());
            e2.printStackTrace();
        }
        IJ.log(fastRandomForest.toString());
        System.out.println("Evaluate classifier on training data");
        try {
            Evaluation evaluation = new Evaluation(balanceTrainingData);
            evaluation.evaluateModel(fastRandomForest, balanceTrainingData, new Object[0]);
            System.out.println(evaluation.toSummaryString("\n=== Test data evaluation ===\n", false));
            System.out.println(evaluation.toClassDetailsString() + "\n");
            System.out.println(evaluation.toMatrixString());
            evaluation.errorRate();
        } catch (Exception e3) {
            e3.printStackTrace();
        }
        this.classifier = fastRandomForest;
        this.numberOfClasses = this.numberOfClasses;
    }

    private static ArrayList<Attribute> makeAttributes(int i, int i2) {
        System.out.println("Number of classes: " + i);
        ArrayList arrayList = new ArrayList();
        for (int i3 = 0; i3 < i; i3++) {
            arrayList.add("C" + (i3 + 1));
        }
        System.out.println("Classes: " + arrayList.size());
        ArrayList<Attribute> arrayList2 = new ArrayList<>();
        for (int i4 = 0; i4 < i2; i4++) {
            arrayList2.add(new Attribute("F" + (i4 + 1)));
        }
        arrayList2.add(new Attribute("class", arrayList));
        return arrayList2;
    }

    private static void featureStackToInstance(CLIJ2 clij2, ClearCLBuffer clearCLBuffer, ClearCLBuffer clearCLBuffer2, Instances instances) {
        ClearCLBuffer create = clij2.create(new long[]{clearCLBuffer.getDepth(), clearCLBuffer.getHeight(), clearCLBuffer.getWidth()}, clij2.Float);
        clij2.transposeXZ(clearCLBuffer, create);
        ClearCLBuffer clearCLBuffer3 = clearCLBuffer2;
        if (clearCLBuffer3.getNativeType() != clij2.Float) {
            clearCLBuffer3 = clij2.create(new long[]{clearCLBuffer2.getWidth(), clearCLBuffer2.getHeight()}, clij2.Float);
            clij2.copy(clearCLBuffer2, clearCLBuffer3);
        }
        ImagePlus pull = clij2.pull(create);
        float[] fArr = (float[]) clij2.pull(clearCLBuffer2).getProcessor().getPixels();
        int depth = (int) clearCLBuffer.getDepth();
        int width = (int) clearCLBuffer.getWidth();
        int height = (int) clearCLBuffer.getHeight();
        System.out.println("Number of features: " + depth);
        for (int i = 0; i < width; i++) {
            pull.setZ(i + 1);
            float[] fArr2 = (float[]) pull.getProcessor().getPixels();
            for (int i2 = 0; i2 < height; i2++) {
                if (fArr[(i2 * width) + i] != 0.0f) {
                    double[] dArr = new double[depth + 1];
                    for (int i3 = 0; i3 < depth; i3++) {
                        dArr[i3] = fArr2[(i2 * depth) + i3];
                    }
                    dArr[dArr.length - 1] = fArr[(i2 * width) + i] - 1.0f;
                    instances.add(new DenseInstance(1.0d, dArr));
                }
            }
        }
        System.out.println("number of instances: " + instances.size());
        if (clearCLBuffer2 != clearCLBuffer3) {
            clij2.release(clearCLBuffer3);
        }
        clij2.release(create);
    }

    public static void main(String[] strArr) {
        new ImageJ();
        ImagePlus openImage = IJ.openImage("src/test/resources/blobs.tif");
        CLIJ2 clij2 = CLIJ2.getInstance();
        ClearCLBuffer push = clij2.push(openImage);
        ClearCLBuffer create = clij2.create(push);
        for (int i = 0; i < 10; i++) {
            long currentTimeMillis = System.currentTimeMillis();
            BinaryWekaPixelClassifier.binaryWekaPixelClassifier(clij2, push, create, "original gaussianblur=1 gaussianblur=5 sobelofgaussian=1 sobelofgaussian=5", "src/test/resources/blobs.model");
            System.out.println("Duration " + (System.currentTimeMillis() - currentTimeMillis));
        }
        clij2.show(create, "res");
    }

    private static ClearCLBuffer featureStackToInstance(CLIJ2 clij2, ClearCLBuffer clearCLBuffer, AbstractClassifier abstractClassifier, int i) {
        ClearCLBuffer create = clij2.create(new long[]{clearCLBuffer.getDepth(), clearCLBuffer.getHeight(), clearCLBuffer.getWidth()}, clij2.Float);
        clij2.transposeXZ(clearCLBuffer, create);
        ImageStack imageStack = clij2.pull(create).getImageStack();
        ImagePlus imagePlus = new ImagePlus("classified", new FloatProcessor((int) clearCLBuffer.getWidth(), (int) clearCLBuffer.getHeight()));
        float[] fArr = (float[]) imagePlus.getProcessor().getPixels();
        int depth = (int) clearCLBuffer.getDepth();
        int width = (int) clearCLBuffer.getWidth();
        int height = (int) clearCLBuffer.getHeight();
        ArrayList<Attribute> makeAttributes = makeAttributes(i, depth);
        Instances instances = new Instances("segment", makeAttributes, 1);
        instances.setClassIndex(makeAttributes.size() - 1);
        System.out.println("Hello object " + width);
        long currentTimeMillis = System.currentTimeMillis();
        Thread[] threadArr = new Thread[width];
        Classificator[] classificatorArr = new Classificator[width];
        for (int i2 = 0; i2 < width; i2++) {
            classificatorArr[i2] = new Classificator((float[]) imageStack.getProcessor(i2 + 1).getPixels(), new float[height], width, height, instances, abstractClassifier, depth);
            threadArr[i2] = new Thread(classificatorArr[i2]);
            threadArr[i2].start();
        }
        for (int i3 = 0; i3 < width; i3++) {
            try {
                threadArr[i3].join();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        for (int i4 = 0; i4 < width; i4++) {
            float[] classes = classificatorArr[i4].getClasses();
            for (int i5 = 0; i5 < height; i5++) {
                fArr[(i5 * width) + i4] = classes[i5];
            }
        }
        System.out.println("inner duration " + (System.currentTimeMillis() - currentTimeMillis));
        clij2.release(create);
        return clij2.push(imagePlus);
    }

    private void applyClassifier() {
        if (this.classification != null) {
            System.out.println("Alread classified");
        } else if (this.classifier == null) {
            System.out.println("No classifier available.");
        } else {
            this.classification = featureStackToInstance(this.clij2, this.featureStack, (AbstractClassifier) this.classifier, this.numberOfClasses.intValue());
        }
    }

    public FastRandomForest getClassifier() {
        trainClassifier();
        return this.classifier;
    }

    public ClearCLBuffer getDistribution() {
        return null;
    }

    public ClearCLBuffer getClassification() {
        applyClassifier();
        return this.classification;
    }

    public void saveClassifier(String str) {
        if (this.classifier == null) {
            trainClassifier();
        }
        if (this.classifier == null) {
            System.out.println("No classifier to save");
            return;
        }
        if (new File(str).getParentFile() != null) {
            new File(str).getParentFile().mkdirs();
        }
        try {
            File file = new File(str);
            OutputStream fileOutputStream = new FileOutputStream(file);
            if (file.getName().endsWith(".gz")) {
                fileOutputStream = new GZIPOutputStream(fileOutputStream);
            }
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream);
            objectOutputStream.writeObject(this.classifier);
            objectOutputStream.writeObject(this.numberOfClasses);
            objectOutputStream.writeObject(this.numberOfFeatures);
            objectOutputStream.flush();
            objectOutputStream.close();
        } catch (Exception e) {
            IJ.error("Save Failed", "Error when saving classifier into a file");
        }
    }

    private void loadClassifier(String str) {
        try {
            File file = new File(str);
            InputStream fileInputStream = new FileInputStream(file);
            if (file.getName().endsWith(".gz")) {
                fileInputStream = new GZIPInputStream(fileInputStream);
            }
            ObjectInputStream objectInputStream = new ObjectInputStream(fileInputStream);
            this.classifier = (FastRandomForest) objectInputStream.readObject();
            this.numberOfClasses = (Integer) objectInputStream.readObject();
            this.numberOfFeatures = (Integer) objectInputStream.readObject();
            objectInputStream.close();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e2) {
            e2.printStackTrace();
        }
    }

    public Integer getNumberOfClasses() {
        return this.numberOfClasses;
    }

    public void printClassifier() {
        System.out.println(this.classifier);
    }

    public void setNumberOfTrees(int i) {
        this.frf_numberOfTrees = i;
    }

    public void setMaxDepth(int i) {
        this.frf_maxDepth = i;
    }

    public void setNumberOfFeatures(int i) {
        this.frf_numberOfFeatures = i;
    }

    public void setFeatureStack(ClearCLBuffer clearCLBuffer) {
        this.featureStack = clearCLBuffer;
        if (this.classification != null) {
            this.clij2.release(this.classification);
        }
        this.classification = null;
    }
}
