package org.ujmp.core.doublematrix.calculation.general.missingvalues;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute;
import org.ujmp.core.util.MathUtil;
import org.ujmp.core.util.UJMPSettings;

/* loaded from: input_file:org/ujmp/core/doublematrix/calculation/general/missingvalues/ImputeEM.class */
public class ImputeEM extends AbstractDoubleCalculation {
    private static final long serialVersionUID = -1272010036598212696L;
    private Matrix bestGuess;
    private Matrix imputed;
    private double delta;
    private final double decay = 0.66d;
    private File tempFile;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/ujmp/core/doublematrix/calculation/general/missingvalues/ImputeEM$PredictColumn.class */
    public class PredictColumn implements Callable<Long> {
        long column;

        public PredictColumn(long j) {
            this.column = 0L;
            this.column = j;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // java.util.concurrent.Callable
        public Long call() throws Exception {
            Matrix replaceInColumn = ImputeEM.replaceInColumn(ImputeEM.this.getSource(), ImputeEM.this.bestGuess, this.column);
            synchronized (ImputeEM.this.imputed) {
                for (int i = 0; i < replaceInColumn.getRowCount(); i++) {
                    ImputeEM.this.imputed.setAsDouble(replaceInColumn.getAsDouble(i, 0), i, this.column);
                }
            }
            return Long.valueOf(this.column);
        }
    }

    public ImputeEM(Matrix matrix) throws IOException {
        this(matrix, null);
    }

    public ImputeEM(Matrix matrix, Matrix matrix2) throws IOException {
        this(matrix, matrix2, 1.0E-6d, File.createTempFile("ujmp-impute-em-" + System.currentTimeMillis(), ".csv"));
    }

    public ImputeEM(Matrix matrix, Matrix matrix2, double d, File file) {
        super(matrix);
        this.bestGuess = null;
        this.imputed = null;
        this.delta = 1.0E-6d;
        this.decay = 0.66d;
        this.bestGuess = matrix2;
        this.delta = d;
        this.tempFile = file;
    }

    @Override // org.ujmp.core.doublematrix.calculation.DoubleCalculation
    public double getDouble(long... jArr) {
        if (this.imputed == null) {
            createMatrix();
        }
        double asDouble = getSource().getAsDouble(jArr);
        return MathUtil.isNaNOrInfinite(asDouble) ? this.imputed.getAsDouble(jArr) : asDouble;
    }

    private void createMatrix() {
        double euklideanDistanceTo;
        try {
            ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(UJMPSettings.getInstance().getNumberOfThreads());
            Matrix source = getSource();
            double valueCount = source.getValueCount();
            long euklideanValue = (long) source.countMissing(Calculation.Ret.NEW, Integer.MAX_VALUE).getEuklideanValue();
            System.out.println("missing values: " + euklideanValue + " (" + (((int) Math.round((euklideanValue * 1000.0d) / valueCount)) / 10.0d) + "%)");
            System.out.println("============================================");
            if (this.bestGuess == null) {
                this.bestGuess = getSource().impute(Calculation.Ret.NEW, Impute.ImputationMethod.RowMean, new Object[0]);
            }
            int i = 0;
            do {
                int i2 = i;
                i++;
                System.out.println("Iteration " + i2);
                ArrayList arrayList = new ArrayList();
                this.imputed = Matrix.Factory.zeros(source.getSize());
                long currentTimeMillis = System.currentTimeMillis();
                for (long j = 0; j < source.getColumnCount(); j++) {
                    if (containsMissingValues(j)) {
                        arrayList.add(newFixedThreadPool.submit(new PredictColumn(j)));
                    }
                }
                Iterator it = arrayList.iterator();
                while (it.hasNext()) {
                    Long l = (Long) ((Future) it.next()).get();
                    System.out.println((((l.longValue() * 1000) / source.getColumnCount()) / 10.0d) + "% completed (" + ((long) (((source.getColumnCount() - l.longValue()) / ((l.longValue() + 1) / (System.currentTimeMillis() - currentTimeMillis))) / 1000.0d)) + " seconds remaining)");
                }
                Matrix plus = this.bestGuess.times(0.66d).plus(this.imputed.times(0.33999999999999997d));
                for (int i3 = 0; i3 < getSource().getRowCount(); i3++) {
                    for (int i4 = 0; i4 < getSource().getColumnCount(); i4++) {
                        double asDouble = getSource().getAsDouble(i3, i4);
                        if (!MathUtil.isNaNOrInfinite(asDouble)) {
                            plus.setAsDouble(asDouble, i3, i4);
                        }
                    }
                }
                euklideanDistanceTo = plus.euklideanDistanceTo(this.bestGuess, true) / euklideanValue;
                System.out.println("delta: " + euklideanDistanceTo);
                System.out.println("============================================");
                this.bestGuess = plus;
                this.bestGuess.exportTo().file(this.tempFile).asDenseCSV();
            } while (this.delta < euklideanDistanceTo);
            newFixedThreadPool.shutdown();
            this.imputed = this.bestGuess;
            if (this.imputed.containsMissingValues()) {
                throw new RuntimeException("Matrix has still missing values after imputation");
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private boolean containsMissingValues(long j) {
        for (int i = 0; i < getSource().getRowCount(); i++) {
            if (MathUtil.isNaNOrInfinite(getSource().getAsDouble(i, j))) {
                return true;
            }
        }
        return false;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static Matrix replaceInColumn(Matrix matrix, Matrix matrix2, long j) {
        Matrix deleteColumns = matrix2.deleteColumns(Calculation.Ret.NEW, j);
        Matrix selectColumns = matrix.selectColumns(Calculation.Ret.NEW, j);
        ArrayList arrayList = new ArrayList();
        long rowCount = selectColumns.getRowCount();
        while (true) {
            long j2 = rowCount - 1;
            rowCount = j2;
            if (j2 < 0) {
                break;
            }
            if (MathUtil.isNaNOrInfinite(selectColumns.getAsDouble(rowCount, 0))) {
                arrayList.add(Long.valueOf(rowCount));
            }
        }
        if (arrayList.isEmpty()) {
            return selectColumns;
        }
        Matrix deleteRows = deleteColumns.deleteRows(Calculation.Ret.NEW, arrayList);
        Matrix mtimes = Matrix.Factory.horCat(deleteColumns, (DenseDoubleMatrix2D) DenseDoubleMatrix2D.Factory.ones(deleteColumns.getRowCount(), 1L)).mtimes(Matrix.Factory.horCat(deleteRows, (DenseDoubleMatrix2D) DenseDoubleMatrix2D.Factory.ones(deleteRows.getRowCount(), 1L)).pinv().mtimes(selectColumns.deleteRows(Calculation.Ret.NEW, arrayList)));
        for (int i = 0; i < selectColumns.getRowCount(); i++) {
            double asDouble = selectColumns.getAsDouble(i, 0);
            if (!Double.isNaN(asDouble)) {
                mtimes.setAsDouble(asDouble, i, 0);
            }
        }
        return mtimes;
    }
}
