/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

public class MILR
extends Classifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler {
    static final long serialVersionUID = 1996101190172373826L;
    protected double[] m_Par;
    protected int m_NumClasses;
    protected double m_Ridge = 1.0E-6;
    protected int[] m_Classes;
    protected double[][][] m_Data;
    protected Instances m_Attributes;
    protected double[] xMean = null;
    protected double[] xSD = null;
    protected int m_AlgorithmType = 0;
    public static final int ALGORITHMTYPE_DEFAULT = 0;
    public static final int ALGORITHMTYPE_ARITHMETIC = 1;
    public static final int ALGORITHMTYPE_GEOMETRIC = 2;
    public static final Tag[] TAGS_ALGORITHMTYPE = new Tag[]{new Tag(0, "standard MI assumption"), new Tag(1, "collective MI assumption, arithmetic mean for posteriors"), new Tag(2, "collective MI assumption, geometric mean for posteriors")};

    public String globalInfo() {
        return "Uses either standard or collective multi-instance assumption, but within linear regression. For the collective assumption, it offers arithmetic or geometric mean for the posteriors.";
    }

    @Override
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>();
        result.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        result.addElement(new Option("\tSet the ridge in the log-likelihood.", "R", 1, "-R <ridge>"));
        result.addElement(new Option("\tDefines the type of algorithm:\n\t 0. standard MI assumption\n\t 1. collective MI assumption, arithmetic mean for posteriors\n\t 2. collective MI assumption, geometric mean for posteriors", "A", 1, "-A [0|1|2]"));
        return result.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        this.setDebug(Utils.getFlag('D', options));
        String tmpStr = Utils.getOption('R', options);
        if (tmpStr.length() != 0) {
            this.setRidge(Double.parseDouble(tmpStr));
        } else {
            this.setRidge(1.0E-6);
        }
        tmpStr = Utils.getOption('A', options);
        if (tmpStr.length() != 0) {
            this.setAlgorithmType(new SelectedTag(Integer.parseInt(tmpStr), TAGS_ALGORITHMTYPE));
        } else {
            this.setAlgorithmType(new SelectedTag(0, TAGS_ALGORITHMTYPE));
        }
    }

    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getDebug()) {
            result.add("-D");
        }
        result.add("-R");
        result.add("" + this.getRidge());
        result.add("-A");
        result.add("" + this.m_AlgorithmType);
        return result.toArray(new String[result.size()]);
    }

    public String ridgeTipText() {
        return "The ridge in the log-likelihood.";
    }

    public void setRidge(double ridge) {
        this.m_Ridge = ridge;
    }

    public double getRidge() {
        return this.m_Ridge;
    }

    public String algorithmTypeTipText() {
        return "The mean type for the posteriors.";
    }

    public SelectedTag getAlgorithmType() {
        return new SelectedTag(this.m_AlgorithmType, TAGS_ALGORITHMTYPE);
    }

    public void setAlgorithmType(SelectedTag newType) {
        if (newType.getTags() == TAGS_ALGORITHMTYPE) {
            this.m_AlgorithmType = newType.getSelectedTag().getID();
        }
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    @Override
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    @Override
    public void buildClassifier(Instances train) throws Exception {
        this.getCapabilities().testWithFail(train);
        train = new Instances(train);
        train.deleteWithMissingClass();
        this.m_NumClasses = train.numClasses();
        int nR = train.attribute(1).relation().numAttributes();
        int nC = train.numInstances();
        this.m_Data = new double[nC][nR][];
        this.m_Classes = new int[nC];
        this.m_Attributes = train.attribute(1).relation();
        this.xMean = new double[nR];
        this.xSD = new double[nR];
        double sY1 = 0.0;
        double sY0 = 0.0;
        double totIns = 0.0;
        int[] missingbags = new int[nR];
        if (this.m_Debug) {
            System.out.println("Extracting data...");
        }
        int h = 0;
        while (h < this.m_Data.length) {
            Instance current = train.instance(h);
            this.m_Classes[h] = (int)current.classValue();
            Instances currInsts = current.relationalValue(1);
            int nI = currInsts.numInstances();
            totIns += (double)nI;
            int i = 0;
            while (i < nR) {
                this.m_Data[h][i] = new double[nI];
                double avg = 0.0;
                double std = 0.0;
                double num = 0.0;
                int k = 0;
                while (k < nI) {
                    if (!currInsts.instance(k).isMissing(i)) {
                        this.m_Data[h][i][k] = currInsts.instance(k).value(i);
                        avg += this.m_Data[h][i][k];
                        std += this.m_Data[h][i][k] * this.m_Data[h][i][k];
                        num += 1.0;
                    } else {
                        this.m_Data[h][i][k] = Double.NaN;
                    }
                    ++k;
                }
                if (num > 0.0) {
                    int n = i;
                    this.xMean[n] = this.xMean[n] + avg / num;
                    int n2 = i;
                    this.xSD[n2] = this.xSD[n2] + std / num;
                } else {
                    int n = i;
                    missingbags[n] = missingbags[n] + 1;
                }
                ++i;
            }
            if (this.m_Classes[h] == 1) {
                sY1 += 1.0;
            } else {
                sY0 += 1.0;
            }
            ++h;
        }
        int j = 0;
        while (j < nR) {
            this.xMean[j] = this.xMean[j] / (double)(nC - missingbags[j]);
            this.xSD[j] = Math.sqrt(Math.abs(this.xSD[j] / ((double)(nC - missingbags[j]) - 1.0) - this.xMean[j] * this.xMean[j] * (double)(nC - missingbags[j]) / ((double)(nC - missingbags[j]) - 1.0)));
            ++j;
        }
        if (this.m_Debug) {
            System.out.println("Descriptives...");
            System.out.println(String.valueOf(sY0) + " bags have class 0 and " + sY1 + " bags have class 1");
            System.out.println("\n Variable     Avg       SD    ");
            j = 0;
            while (j < nR) {
                System.out.println(String.valueOf(Utils.doubleToString(j, 8, 4)) + Utils.doubleToString(this.xMean[j], 10, 4) + Utils.doubleToString(this.xSD[j], 10, 4));
                ++j;
            }
        }
        int i = 0;
        while (i < nC) {
            int j2 = 0;
            while (j2 < nR) {
                int k = 0;
                while (k < this.m_Data[i][j2].length) {
                    if (this.xSD[j2] != 0.0) {
                        this.m_Data[i][j2][k] = !Double.isNaN(this.m_Data[i][j2][k]) ? (this.m_Data[i][j2][k] - this.xMean[j2]) / this.xSD[j2] : 0.0;
                    }
                    ++k;
                }
                ++j2;
            }
            ++i;
        }
        if (this.m_Debug) {
            System.out.println("\nIteration History...");
        }
        double[] x = new double[nR + 1];
        x[0] = Math.log((sY1 + 1.0) / (sY0 + 1.0));
        double[][] b = new double[2][x.length];
        b[0][0] = Double.NaN;
        b[1][0] = Double.NaN;
        int q = 1;
        while (q < x.length) {
            x[q] = 0.0;
            b[0][q] = Double.NaN;
            b[1][q] = Double.NaN;
            ++q;
        }
        OptEng opt = new OptEng(this.m_AlgorithmType);
        opt.setDebug(this.m_Debug);
        this.m_Par = opt.findArgmin(x, b);
        while (this.m_Par == null) {
            this.m_Par = opt.getVarbValues();
            if (this.m_Debug) {
                System.out.println("200 iterations finished, not enough!");
            }
            this.m_Par = opt.findArgmin(this.m_Par, b);
        }
        if (this.m_Debug) {
            System.out.println(" -------------<Converged>--------------");
        }
        if (this.m_AlgorithmType == 1) {
            double[] fs = new double[nR];
            int k = 1;
            while (k < nR + 1) {
                fs[k - 1] = Math.abs(this.m_Par[k]);
                ++k;
            }
            int[] idx = Utils.sort(fs);
            double max = fs[idx[idx.length - 1]];
            int k2 = idx.length - 1;
            while (k2 >= 0) {
                System.out.println(String.valueOf(this.m_Attributes.attribute(idx[k2]).name()) + "\t" + fs[idx[k2]] * 100.0 / max);
                --k2;
            }
        }
        int j3 = 1;
        while (j3 < nR + 1) {
            if (this.xSD[j3 - 1] != 0.0) {
                int n = j3;
                this.m_Par[n] = this.m_Par[n] / this.xSD[j3 - 1];
                this.m_Par[0] = this.m_Par[0] - this.m_Par[j3] * this.xMean[j3 - 1];
            }
            ++j3;
        }
    }

    @Override
    public double[] distributionForInstance(Instance exmp) throws Exception {
        Instances ins = exmp.relationalValue(1);
        int nI = ins.numInstances();
        int nA = ins.numAttributes();
        double[][] dat = new double[nI][nA + 1];
        int j = 0;
        while (j < nI) {
            dat[j][0] = 1.0;
            int idx = 1;
            int k = 0;
            while (k < nA) {
                dat[j][idx] = !ins.instance(j).isMissing(k) ? ins.instance(j).value(k) : this.xMean[idx - 1];
                ++idx;
                ++k;
            }
            ++j;
        }
        double[] distribution = new double[2];
        switch (this.m_AlgorithmType) {
            case 0: {
                distribution[0] = 0.0;
                int i = 0;
                while (i < nI) {
                    double exp = 0.0;
                    int r = 0;
                    while (r < this.m_Par.length) {
                        exp += this.m_Par[r] * dat[i][r];
                        ++r;
                    }
                    exp = Math.exp(exp);
                    distribution[0] = distribution[0] - Math.log(1.0 + exp);
                    ++i;
                }
                distribution[0] = Math.exp(distribution[0]);
                distribution[1] = 1.0 - distribution[0];
                break;
            }
            case 1: {
                distribution[0] = 0.0;
                int i = 0;
                while (i < nI) {
                    double exp = 0.0;
                    int r = 0;
                    while (r < this.m_Par.length) {
                        exp += this.m_Par[r] * dat[i][r];
                        ++r;
                    }
                    exp = Math.exp(exp);
                    distribution[0] = distribution[0] + 1.0 / (1.0 + exp);
                    ++i;
                }
                distribution[0] = distribution[0] / (double)nI;
                distribution[1] = 1.0 - distribution[0];
                break;
            }
            case 2: {
                int i = 0;
                while (i < nI) {
                    double exp = 0.0;
                    int r = 0;
                    while (r < this.m_Par.length) {
                        exp += this.m_Par[r] * dat[i][r];
                        ++r;
                    }
                    distribution[1] = distribution[1] + exp / (double)nI;
                    ++i;
                }
                distribution[1] = 1.0 / (1.0 + Math.exp(-distribution[1]));
                distribution[0] = 1.0 - distribution[1];
            }
        }
        return distribution;
    }

    public String toString() {
        String result = "Modified Logistic Regression";
        if (this.m_Par == null) {
            return String.valueOf(result) + ": No model built yet.";
        }
        result = String.valueOf(result) + "\nMean type: " + this.getAlgorithmType().getSelectedTag().getReadable() + "\n";
        result = String.valueOf(result) + "\nCoefficients...\nVariable      Coeff.\n";
        int j = 1;
        int idx = 0;
        while (j < this.m_Par.length) {
            result = String.valueOf(result) + this.m_Attributes.attribute(idx).name();
            result = String.valueOf(result) + " " + Utils.doubleToString(this.m_Par[j], 12, 4);
            result = String.valueOf(result) + "\n";
            ++j;
            ++idx;
        }
        result = String.valueOf(result) + "Intercept:";
        result = String.valueOf(result) + " " + Utils.doubleToString(this.m_Par[0], 10, 4);
        result = String.valueOf(result) + "\n";
        result = String.valueOf(result) + "\nOdds Ratios...\nVariable         O.R.\n";
        j = 1;
        idx = 0;
        while (j < this.m_Par.length) {
            result = String.valueOf(result) + " " + this.m_Attributes.attribute(idx).name();
            double ORc = Math.exp(this.m_Par[j]);
            result = String.valueOf(result) + " " + (ORc > 1.0E10 ? "" + ORc : Utils.doubleToString(ORc, 12, 4));
            ++j;
            ++idx;
        }
        result = String.valueOf(result) + "\n";
        return result;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9144 $");
    }

    public static void main(String[] argv) {
        MILR.runClassifier(new MILR(), argv);
    }

    private class OptEng
    extends Optimization {
        private int m_Type;

        public OptEng(int type) {
            this.m_Type = type;
        }

        @Override
        protected double objectiveFunction(double[] x) {
            double nll = 0.0;
            switch (this.m_Type) {
                case 0: {
                    double bag;
                    int nI;
                    int i = 0;
                    while (i < MILR.this.m_Classes.length) {
                        nI = MILR.this.m_Data[i][0].length;
                        bag = 0.0;
                        double prod = 0.0;
                        int j = 0;
                        while (j < nI) {
                            double exp = 0.0;
                            int k = MILR.this.m_Data[i].length - 1;
                            while (k >= 0) {
                                exp += MILR.this.m_Data[i][k][j] * x[k + 1];
                                --k;
                            }
                            exp += x[0];
                            exp = Math.exp(exp);
                            if (MILR.this.m_Classes[i] == 1) {
                                prod -= Math.log(1.0 + exp);
                            } else {
                                bag += Math.log(1.0 + exp);
                            }
                            ++j;
                        }
                        if (MILR.this.m_Classes[i] == 1) {
                            bag = -Math.log(1.0 - Math.exp(prod));
                        }
                        nll += bag;
                        ++i;
                    }
                    break;
                }
                case 1: {
                    double exp;
                    double bag;
                    int nI;
                    int i = 0;
                    while (i < MILR.this.m_Classes.length) {
                        nI = MILR.this.m_Data[i][0].length;
                        bag = 0.0;
                        int j = 0;
                        while (j < nI) {
                            exp = 0.0;
                            int k = MILR.this.m_Data[i].length - 1;
                            while (k >= 0) {
                                exp += MILR.this.m_Data[i][k][j] * x[k + 1];
                                --k;
                            }
                            exp += x[0];
                            exp = Math.exp(exp);
                            bag = MILR.this.m_Classes[i] == 1 ? (bag += 1.0 - 1.0 / (1.0 + exp)) : (bag += 1.0 / (1.0 + exp));
                            ++j;
                        }
                        nll -= Math.log(bag /= (double)nI);
                        ++i;
                    }
                    break;
                }
                case 2: {
                    double exp;
                    double bag;
                    int nI;
                    int i = 0;
                    while (i < MILR.this.m_Classes.length) {
                        nI = MILR.this.m_Data[i][0].length;
                        bag = 0.0;
                        int j = 0;
                        while (j < nI) {
                            exp = 0.0;
                            int k = MILR.this.m_Data[i].length - 1;
                            while (k >= 0) {
                                exp += MILR.this.m_Data[i][k][j] * x[k + 1];
                                --k;
                            }
                            bag = MILR.this.m_Classes[i] == 1 ? (bag -= exp / (double)nI) : (bag += (exp += x[0]) / (double)nI);
                            ++j;
                        }
                        nll += Math.log(1.0 + Math.exp(bag));
                        ++i;
                    }
                    break;
                }
            }
            int r = 1;
            while (r < x.length) {
                nll += MILR.this.m_Ridge * x[r] * x[r];
                ++r;
            }
            return nll;
        }

        @Override
        protected double[] evaluateGradient(double[] x) {
            double[] grad = new double[x.length];
            switch (this.m_Type) {
                case 0: {
                    int q;
                    double m;
                    int p;
                    int k;
                    double exp;
                    int j;
                    double denom;
                    int nI;
                    int i = 0;
                    while (i < MILR.this.m_Classes.length) {
                        nI = MILR.this.m_Data[i][0].length;
                        denom = 0.0;
                        double[] bag = new double[grad.length];
                        j = 0;
                        while (j < nI) {
                            exp = 0.0;
                            k = MILR.this.m_Data[i].length - 1;
                            while (k >= 0) {
                                exp += MILR.this.m_Data[i][k][j] * x[k + 1];
                                --k;
                            }
                            exp += x[0];
                            exp = Math.exp(exp) / (1.0 + Math.exp(exp));
                            if (MILR.this.m_Classes[i] == 1) {
                                denom -= Math.log(1.0 - exp);
                            }
                            p = 0;
                            while (p < x.length) {
                                m = 1.0;
                                if (p > 0) {
                                    m = MILR.this.m_Data[i][p - 1][j];
                                }
                                int n = p++;
                                bag[n] = bag[n] + m * exp;
                            }
                            ++j;
                        }
                        denom = Math.exp(denom);
                        q = 0;
                        while (q < grad.length) {
                            if (MILR.this.m_Classes[i] == 1) {
                                int n = q;
                                grad[n] = grad[n] - bag[q] / (denom - 1.0);
                            } else {
                                int n = q;
                                grad[n] = grad[n] + bag[q];
                            }
                            ++q;
                        }
                        ++i;
                    }
                    break;
                }
                case 1: {
                    int q;
                    double m;
                    int p;
                    int k;
                    double exp;
                    int j;
                    double denom;
                    int nI;
                    int i = 0;
                    while (i < MILR.this.m_Classes.length) {
                        nI = MILR.this.m_Data[i][0].length;
                        denom = 0.0;
                        double[] numrt = new double[x.length];
                        j = 0;
                        while (j < nI) {
                            exp = 0.0;
                            k = MILR.this.m_Data[i].length - 1;
                            while (k >= 0) {
                                exp += MILR.this.m_Data[i][k][j] * x[k + 1];
                                --k;
                            }
                            exp += x[0];
                            exp = Math.exp(exp);
                            denom = MILR.this.m_Classes[i] == 1 ? (denom += exp / (1.0 + exp)) : (denom += 1.0 / (1.0 + exp));
                            p = 0;
                            while (p < x.length) {
                                m = 1.0;
                                if (p > 0) {
                                    m = MILR.this.m_Data[i][p - 1][j];
                                }
                                int n = p++;
                                numrt[n] = numrt[n] + m * exp / ((1.0 + exp) * (1.0 + exp));
                            }
                            ++j;
                        }
                        q = 0;
                        while (q < grad.length) {
                            if (MILR.this.m_Classes[i] == 1) {
                                int n = q;
                                grad[n] = grad[n] - numrt[q] / denom;
                            } else {
                                int n = q;
                                grad[n] = grad[n] + numrt[q] / denom;
                            }
                            ++q;
                        }
                        ++i;
                    }
                    break;
                }
                case 2: {
                    double m;
                    int k;
                    double exp;
                    int j;
                    int nI;
                    int i = 0;
                    while (i < MILR.this.m_Classes.length) {
                        nI = MILR.this.m_Data[i][0].length;
                        double bag = 0.0;
                        double[] sumX = new double[x.length];
                        j = 0;
                        while (j < nI) {
                            int q;
                            exp = 0.0;
                            k = MILR.this.m_Data[i].length - 1;
                            while (k >= 0) {
                                exp += MILR.this.m_Data[i][k][j] * x[k + 1];
                                --k;
                            }
                            exp += x[0];
                            if (MILR.this.m_Classes[i] == 1) {
                                bag -= exp / (double)nI;
                                q = 0;
                                while (q < grad.length) {
                                    m = 1.0;
                                    if (q > 0) {
                                        m = MILR.this.m_Data[i][q - 1][j];
                                    }
                                    int n = q++;
                                    sumX[n] = sumX[n] - m / (double)nI;
                                }
                            } else {
                                bag += exp / (double)nI;
                                q = 0;
                                while (q < grad.length) {
                                    m = 1.0;
                                    if (q > 0) {
                                        m = MILR.this.m_Data[i][q - 1][j];
                                    }
                                    int n = q++;
                                    sumX[n] = sumX[n] + m / (double)nI;
                                }
                            }
                            ++j;
                        }
                        int p = 0;
                        while (p < x.length) {
                            int n = p;
                            grad[n] = grad[n] + Math.exp(bag) * sumX[p] / (1.0 + Math.exp(bag));
                            ++p;
                        }
                        ++i;
                    }
                    break;
                }
            }
            int r = 1;
            while (r < x.length) {
                int n = r;
                grad[n] = grad[n] + 2.0 * MILR.this.m_Ridge * x[r];
                ++r;
            }
            return grad;
        }

        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 9144 $");
        }
    }
}

