package edu.rit.compbio.phyl;

import edu.rit.numeric.NonNegativeLeastSquares;

/* loaded from: input_file:edu/rit/compbio/phyl/LeastSquaresBranchLengths.class */
/**
 * Static utilities for fitting the branch lengths of a {@link DnaSequenceTree}
 * by least squares against a pairwise {@link Distance} between its tip sequences.
 * <p>
 * Throughout this class, the branch length stored at node {@code i} is treated
 * as the length of the branch joining node {@code i} to its parent. The tree
 * distance between two tips is therefore the sum of the branch lengths of the
 * nodes that lie on exactly one of the two tips' paths to the root.
 */
public final class LeastSquaresBranchLengths {

    /** Not instantiable; all members are static. */
    private LeastSquaresBranchLengths() {
    }

    /**
     * Computes the squared error between the observed pairwise tip distances
     * and the distances implied by the tree's current branch lengths.
     * <p>
     * For every unordered pair of tips (a, b), the residual
     * {@code distance.distance(seq_a, seq_b) - treeDistance(a, b)} is squared
     * and summed. Nodes whose branch length is {@code null} contribute 0.
     *
     * @param dnaSequenceTree tree whose tip sequences and branch lengths are read;
     *                        this method only calls accessors on it
     * @param distance        pairwise distance measure between tip sequences
     * @return the sum of squared residuals over all tip pairs
     *         (0 when the tree has fewer than two tips)
     */
    public static double squaredError(DnaSequenceTree dnaSequenceTree, Distance distance) {
        double[] branchLengths = getBranchLengths(dnaSequenceTree);
        int[] tips = getTipNodes(dnaSequenceTree);
        boolean[][] rootPaths = getRootPaths(dnaSequenceTree, tips);
        int tipCount = tips.length;
        double sum = 0.0;
        for (int a = 0; a < tipCount - 1; ++a) {
            DnaSequence seqA = dnaSequenceTree.seq(tips[a]);
            for (int b = a + 1; b < tipCount; ++b) {
                double residual = distance.distance(seqA, dnaSequenceTree.seq(tips[b]))
                        - treeDistance(a, b, branchLengths, rootPaths);
                sum += residual * residual;
            }
        }
        return sum;
    }

    /**
     * Fits non-negative branch lengths to the pairwise tip distances by
     * non-negative least squares and stores them in the tree.
     * <p>
     * There is one equation per unordered tip pair and one unknown per node.
     * The coefficient is 1 when that node's branch lies on the path between
     * the two tips, and 0 otherwise.
     * <p>
     * After solving, every node's branch length is set from the solution
     * vector. The root's branch length is then reset to {@code null}.
     *
     * @param dnaSequenceTree tree whose tip sequences are read and whose branch
     *                        lengths are overwritten
     * @param distance        pairwise distance measure between tip sequences
     * @return the solver's {@code normsqr} value (presumably the squared
     *         residual norm of the fit; see {@code NonNegativeLeastSquares})
     */
    public static double solve(DnaSequenceTree dnaSequenceTree, Distance distance) {
        int nodeCount = dnaSequenceTree.length();
        int[] tips = getTipNodes(dnaSequenceTree);
        boolean[][] rootPaths = getRootPaths(dnaSequenceTree, tips);
        int tipCount = tips.length;

        // Rows: unordered tip pairs. Columns: one branch per node.
        NonNegativeLeastSquares nnls =
                new NonNegativeLeastSquares((tipCount * (tipCount - 1)) / 2, nodeCount);
        int row = 0;
        for (int a = 0; a < tipCount - 1; ++a) {
            DnaSequence seqA = dnaSequenceTree.seq(tips[a]);
            boolean[] pathA = rootPaths[a];
            for (int b = a + 1; b < tipCount; ++b) {
                DnaSequence seqB = dnaSequenceTree.seq(tips[b]);
                boolean[] pathB = rootPaths[b];
                nnls.b[row] = distance.distance(seqA, seqB);
                double[] coeffs = nnls.a[row];
                for (int node = 0; node < nodeCount; ++node) {
                    // A branch separates the two tips iff it is on exactly one root path.
                    coeffs[node] = (pathA[node] != pathB[node]) ? 1.0 : 0.0;
                }
                ++row;
            }
        }

        nnls.solve();
        for (int node = 0; node < nodeCount; ++node) {
            dnaSequenceTree.branchLength(node, Double.valueOf(nnls.x[node]));
        }
        // The root is on every root path, so its column is all zeros and it
        // has no branch of its own; clear it.
        dnaSequenceTree.branchLength(dnaSequenceTree.root(), null);
        return nnls.normsqr;
    }

    /**
     * Reads every node's branch length into a primitive array.
     *
     * @param tree the tree to read
     * @return an array of length {@code tree.length()}; entries for nodes
     *         whose branch length is {@code null} are 0
     */
    private static double[] getBranchLengths(DnaSequenceTree tree) {
        int nodeCount = tree.length();
        double[] lengths = new double[nodeCount];
        for (int node = 0; node < nodeCount; ++node) {
            Double len = tree.branchLength(node);
            if (len != null) {
                lengths[node] = len.doubleValue();
            }
        }
        return lengths;
    }

    /**
     * Collects the indices of the tip nodes, i.e. nodes with no first child
     * ({@code child1 == -1}), in increasing index order.
     * <p>
     * Tips are counted before the array is allocated. A size of
     * {@code (n + 1) / 2} is only correct for strictly bifurcating trees: it
     * would overflow on multifurcating trees and leave spurious node-0 entries
     * on trees with unary nodes.
     *
     * @param tree the tree to scan
     * @return the tip node indices; the array length is exactly the number of tips
     */
    private static int[] getTipNodes(DnaSequenceTree tree) {
        int nodeCount = tree.length();
        int tipCount = 0;
        for (int node = 0; node < nodeCount; ++node) {
            if (tree.child1(node) == -1) {
                ++tipCount;
            }
        }
        int[] tips = new int[tipCount];
        int next = 0;
        for (int node = 0; node < nodeCount; ++node) {
            if (tree.child1(node) == -1) {
                tips[next++] = node;
            }
        }
        return tips;
    }

    /**
     * For each tip, marks every node on the path from that tip up to the root,
     * inclusive.
     *
     * @param tree the tree to walk; {@code parent(node) == -1} ends the walk
     * @param tips tip node indices, as returned by {@link #getTipNodes}
     * @return {@code paths[t][node]} is {@code true} iff {@code node} lies on
     *         the path from {@code tips[t]} to the root
     */
    private static boolean[][] getRootPaths(DnaSequenceTree tree, int[] tips) {
        int nodeCount = tree.length();
        boolean[][] paths = new boolean[tips.length][nodeCount];
        for (int t = 0; t < tips.length; ++t) {
            boolean[] path = paths[t];
            // Stop at the root's parent sentinel; the previous while(true)
            // had no exit and never terminated.
            for (int node = tips[t]; node != -1; node = tree.parent(node)) {
                path[node] = true;
            }
        }
        return paths;
    }

    /**
     * Computes the tree distance between tips {@code i} and {@code j}: the sum
     * of the branch lengths of nodes on exactly one of their root paths.
     *
     * @param i             index into {@code rootPaths} of the first tip
     * @param j             index into {@code rootPaths} of the second tip
     * @param branchLengths per-node branch lengths
     * @param rootPaths     per-tip root-path membership, as from {@link #getRootPaths}
     * @return the path length between the two tips
     */
    private static double treeDistance(int i, int j, double[] branchLengths, boolean[][] rootPaths) {
        boolean[] pathI = rootPaths[i];
        boolean[] pathJ = rootPaths[j];
        double total = 0.0;
        for (int node = 0; node < pathI.length; ++node) {
            if (pathI[node] != pathJ[node]) {
                total += branchLengths[node];
            }
        }
        return total;
    }
}
