/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Winnow;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;

/**
 * Trains a {@link Winnow} classifier with the Winnow multiplicative-update rule.
 *
 * <p>Each label owns a weight vector over the data alphabet, initialised to 1.0. The training
 * list is visited once, in order. For every instance, each label is judged independently by
 * summing that label's weights over the feature indices present in the instance and comparing
 * the sum with the threshold {@code theta}:
 *
 * <ul>
 *   <li>sum above {@code theta} for a label other than the instance's best label: the weights of
 *       the present features are divided by {@code beta} (demotion);
 *   <li>sum at or below {@code theta} for the instance's best label: the weights of the present
 *       features are multiplied by {@code alpha} (promotion).
 * </ul>
 *
 * <p>Only the feature indices stored in each {@link FeatureVector} are consulted; the feature
 * values themselves are ignored.
 */
public class WinnowTrainer extends ClassifierTrainer<Winnow> {
    /** Default promotion factor. */
    static final double DEFAULT_ALPHA = 2.0;
    /** Default demotion divisor. */
    static final double DEFAULT_BETA = 2.0;
    /** Default multiplier applied to the feature count to obtain {@code theta}. */
    static final double DEFAULT_NFACTOR = 0.5;

    /** Factor by which weights are multiplied on promotion. */
    double alpha;
    /** Divisor applied to weights on demotion. */
    double beta;
    /** Decision threshold; set by {@link #train} to {@code numFeatures * nfactor}. */
    double theta;
    /** Multiplier used to derive {@code theta} from the number of features. */
    double nfactor;
    /** Per-label weight vectors indexed {@code [label][feature]}; reallocated by each {@link #train} call. */
    double[][] weights;
    /** Classifier built by the most recent {@link #train} call (null if train has not run). */
    Winnow classifier;

    /** Creates a trainer using {@link #DEFAULT_ALPHA}, {@link #DEFAULT_BETA} and {@link #DEFAULT_NFACTOR}. */
    public WinnowTrainer() {
        this(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_NFACTOR);
    }

    /**
     * Creates a trainer with the given update factors and {@link #DEFAULT_NFACTOR}.
     *
     * @param a promotion factor ({@code alpha})
     * @param b demotion divisor ({@code beta})
     */
    public WinnowTrainer(double a, double b) {
        this(a, b, DEFAULT_NFACTOR);
    }

    /**
     * Creates a trainer with explicit parameters.
     *
     * @param a promotion factor ({@code alpha})
     * @param b demotion divisor ({@code beta})
     * @param nfact multiplier applied to the feature count to obtain {@code theta}
     */
    public WinnowTrainer(double a, double b, double nfact) {
        this.alpha = a;
        this.beta = b;
        this.nfactor = nfact;
    }

    /** Returns the classifier from the most recent {@link #train} call, or null if none. */
    @Override
    public Winnow getClassifier() {
        return this.classifier;
    }

    /**
     * Trains a fresh Winnow classifier in a single pass over {@code trainingList}.
     *
     * <p>Stops growth of both the data and target alphabets, reinitialises all weights to 1.0 and
     * sets {@code theta} to {@code dataAlphabetSize * nfactor}. Calling this again discards the
     * previous weights.
     *
     * @param trainingList labelled instances whose data are {@link FeatureVector}s
     * @return the trained classifier (also retained for {@link #getClassifier()})
     * @throws UnsupportedOperationException if the list carries a feature selection
     * @throws IllegalArgumentException if a training instance has no labeling
     */
    @Override
    public Winnow train(InstanceList trainingList) {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        trainingList.getDataAlphabet().stopGrowth();
        trainingList.getTargetAlphabet().stopGrowth();
        Pipe dataPipe = trainingList.getPipe();
        Alphabet dict = trainingList.getDataAlphabet();
        int numLabels = trainingList.getTargetAlphabet().size();
        int numFeats = dict.size();

        this.theta = (double) numFeats * this.nfactor;
        this.weights = new double[numLabels][numFeats];
        for (double[] row : this.weights) {
            for (int j = 0; j < numFeats; ++j) {
                row[j] = 1.0;
            }
        }

        for (int ii = 0; ii < trainingList.size(); ++ii) {
            Instance inst = (Instance) trainingList.get(ii);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                throw new IllegalArgumentException("Training instance " + ii + " has no labeling.");
            }
            FeatureVector fv = (FeatureVector) inst.getData();
            int correctIndex = labeling.getBestIndex();
            for (int lpos = 0; lpos < numLabels; ++lpos) {
                // Each label has its own weight row, so updating this label cannot change the
                // score of any other label for the current instance.
                boolean predictedPositive = this.score(lpos, fv) > this.theta;
                if (predictedPositive && lpos != correctIndex) {
                    this.demote(lpos, fv);
                } else if (!predictedPositive && lpos == correctIndex) {
                    this.promote(lpos, fv);
                }
            }
        }
        this.classifier = new Winnow(dataPipe, this.weights, this.theta, numLabels, numFeats);
        return this.classifier;
    }

    /** Sums label {@code lpos}'s weights over the feature indices stored in {@code fv}. */
    private double score(int lpos, FeatureVector fv) {
        double[] row = this.weights[lpos];
        double sum = 0.0;
        int fvisize = fv.numLocations();
        for (int fvi = 0; fvi < fvisize; ++fvi) {
            sum += row[fv.indexAtLocation(fvi)];
        }
        return sum;
    }

    /** Multiplies label {@code lpos}'s weight for every feature index in {@code fv} by {@code alpha}. */
    private void promote(int lpos, FeatureVector fv) {
        double[] row = this.weights[lpos];
        int fvisize = fv.numLocations();
        for (int fvi = 0; fvi < fvisize; ++fvi) {
            row[fv.indexAtLocation(fvi)] *= this.alpha;
        }
    }

    /** Divides label {@code lpos}'s weight for every feature index in {@code fv} by {@code beta}. */
    private void demote(int lpos, FeatureVector fv) {
        double[] row = this.weights[lpos];
        int fvisize = fv.numLocations();
        for (int fvi = 0; fvi < fvisize; ++fvi) {
            row[fv.indexAtLocation(fvi)] /= this.beta;
        }
    }
}

