/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.BiasedLogisticObjectiveFunction;
import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.LogisticClassifierFactory;
import edu.stanford.nlp.classify.LogisticObjectiveFunction;
import edu.stanford.nlp.classify.RVFClassifier;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.StringUtils;
import java.io.File;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Properties;

/**
 * A binary linear classifier: a feature set is scored as the sum of the weights of its
 * indexed features (multiplied by the feature values when a {@link Counter} is supplied), and
 * the score is turned into a label or a logistic probability.
 *
 * <p>{@code classes[0]} is the internal negative class and {@code classes[1]} the internal
 * positive class. A strictly positive score selects the positive class; any other score selects
 * the negative class. Features that are not present in the feature index are ignored.
 *
 * <p>The scoring methods dereference the feature index and weights without a null check, so an
 * instance built with one of the deprecated prior/bias constructors must be trained before use.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class LogisticClassifier<L, F>
implements Classifier<L, F>,
Serializable,
RVFClassifier<L, F> {
    private static final long serialVersionUID = 6672245467246897192L;

    /** Minimizer that is loaded reflectively when a positive L1 regularization strength is given. */
    private static final String OWLQN_MINIMIZER_CLASS = "edu.stanford.nlp.optimization.OWLQNMinimizer";

    /** Convergence tolerance used when the caller does not supply one. */
    private static final double DEFAULT_TOLERANCE = 1.0E-4;

    /** One weight per feature, addressed by the feature's position in {@link #featureIndex}. */
    private double[] weights;
    private Index<F> featureIndex;
    /** Exactly two labels: index 0 is the negative class, index 1 the positive class. */
    private L[] classes = ErasureUtils.mkTArray(Object.class, 2);
    @Deprecated
    private LogPrior prior;
    @Deprecated
    private boolean biased = false;

    /**
     * Renders every indexed feature as {@code <positive label> / <feature> = <weight>}.
     *
     * <p>NOTE(review): entries are concatenated without any separator, exactly as before; confirm
     * whether a delimiter is wanted before changing the output format.
     *
     * @return the concatenated entries, or the empty string when no feature index is set
     */
    @Override
    public String toString() {
        if (this.featureIndex == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        for (F feature : this.featureIndex) {
            sb.append(this.weightKey(feature)).append(" = ").append(this.weightOf(feature));
        }
        return sb.toString();
    }

    /** @return the label stored at {@code classes[1]} */
    public L getLabelForInternalPositiveClass() {
        return this.classes[1];
    }

    /** @return the label stored at {@code classes[0]} */
    public L getLabelForInternalNegativeClass() {
        return this.classes[0];
    }

    /**
     * Exposes the weights keyed by the string {@code <positive label> / <feature>}.
     *
     * @return a new counter holding one entry per indexed feature
     */
    public Counter<String> weightsAsCounter() {
        Counter<String> counter = new ClassicCounter<String>();
        for (F feature : this.featureIndex) {
            counter.incrementCount(this.weightKey(feature), this.weightOf(feature));
        }
        return counter;
    }

    /**
     * Exposes the weights keyed by the feature objects themselves.
     *
     * @return a new counter holding only the features whose weight is not zero
     */
    public Counter<F> weightsAsGenericCounter() {
        Counter<F> counter = new ClassicCounter<F>();
        for (F feature : this.featureIndex) {
            double weight = this.weightOf(feature);
            if (weight != 0.0) {
                counter.setCount(feature, weight);
            }
        }
        return counter;
    }

    /** @return the feature index held by this classifier (not a copy) */
    public Index<F> getFeatureIndex() {
        return this.featureIndex;
    }

    /** @return the weight array held by this classifier (not a copy) */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Builds a ready-to-use classifier from existing parameters. The arguments are stored as
     * given, without copying.
     *
     * @param weights one weight per entry of {@code featureIndex}
     * @param featureIndex maps features to positions in {@code weights}
     * @param classes the negative label at index 0 and the positive label at index 1
     */
    public LogisticClassifier(double[] weights, Index<F> featureIndex, L[] classes) {
        this.weights = weights;
        this.featureIndex = featureIndex;
        this.classes = classes;
    }

    /**
     * Creates an untrained classifier with a quadratic prior.
     *
     * @param biased whether {@link #train(GeneralDataset, double, double)} uses the biased objective
     * @deprecated marked deprecated in this class; {@code main} trains through
     *     {@link LogisticClassifierFactory}, presumably the intended replacement (confirm)
     */
    @Deprecated
    public LogisticClassifier(boolean biased) {
        this(new LogPrior(LogPrior.LogPriorType.QUADRATIC), biased);
    }

    /**
     * Creates an untrained, unbiased classifier with the given prior.
     *
     * @param prior prior handed to the objective function during training
     * @deprecated marked deprecated in this class; {@code main} trains through
     *     {@link LogisticClassifierFactory}, presumably the intended replacement (confirm)
     */
    @Deprecated
    public LogisticClassifier(LogPrior prior) {
        this.prior = prior;
    }

    /**
     * Creates an untrained classifier with the given prior and bias setting.
     *
     * @param prior prior handed to the objective function during training
     * @param biased whether {@link #train(GeneralDataset, double, double)} uses the biased objective
     * @deprecated marked deprecated in this class; {@code main} trains through
     *     {@link LogisticClassifierFactory}, presumably the intended replacement (confirm)
     */
    @Deprecated
    public LogisticClassifier(LogPrior prior, boolean biased) {
        this.prior = prior;
        this.biased = biased;
    }

    /** @return a new mutable list holding the negative label followed by the positive label */
    @Override
    public Collection<L> labels() {
        LinkedList<L> both = new LinkedList<L>();
        both.add(this.classes[0]);
        both.add(this.classes[1]);
        return both;
    }

    /**
     * Classifies a datum, using its feature values when it is an {@link RVFDatum} and plain
     * feature presence otherwise.
     *
     * @param datum the example to classify
     * @return the positive label if the score is strictly positive, else the negative label
     */
    @Override
    public L classOf(Datum<L, F> datum) {
        if (datum instanceof RVFDatum) {
            return this.classOfRVFDatum((RVFDatum<L, F>) datum);
        }
        return this.classOf(datum.asFeatures());
    }

    /**
     * @param example the real-valued example to classify
     * @return the predicted label
     * @deprecated marked deprecated in this class; {@link #classOf(Datum)} routes
     *     {@link RVFDatum} instances through the same value-weighted scoring
     */
    @Override
    @Deprecated
    public L classOf(RVFDatum<L, F> example) {
        return this.classOf(example.asFeaturesCounter());
    }

    private L classOfRVFDatum(RVFDatum<L, F> example) {
        return this.classOf(example.asFeaturesCounter());
    }

    /**
     * @param features feature values
     * @return the positive label if {@link #scoreOf(Counter)} is strictly positive, else the
     *     negative label
     */
    public L classOf(Counter<F> features) {
        return this.labelForScore(this.scoreOf(features));
    }

    /**
     * @param features the features that are present
     * @return the positive label if {@link #scoreOf(Collection)} is strictly positive, else the
     *     negative label
     */
    public L classOf(Collection<F> features) {
        return this.labelForScore(this.scoreOf(features));
    }

    /**
     * Sums the weights of the given features. A feature that occurs several times in the
     * collection contributes once per occurrence; unindexed features contribute nothing.
     *
     * @param features the features that are present
     * @return the linear score
     */
    public double scoreOf(Collection<F> features) {
        double total = 0.0;
        for (F feature : features) {
            int position = this.featureIndex.indexOf(feature);
            if (position >= 0) {
                total += this.weights[position];
            }
        }
        return total;
    }

    /**
     * Sums weight times count over the keys of the counter; unindexed features contribute nothing.
     *
     * @param features feature values
     * @return the linear score
     */
    public double scoreOf(Counter<F> features) {
        double total = 0.0;
        for (F feature : features.keySet()) {
            int position = this.featureIndex.indexOf(feature);
            if (position >= 0) {
                total += this.weights[position] * features.getCount(feature);
            }
        }
        return total;
    }

    /**
     * Breaks the score of {@link #scoreOf(Counter)} down by feature.
     *
     * @param features feature values
     * @return a new counter mapping each indexed feature to its weight times its count;
     *     unindexed features are left out
     */
    public Counter<F> justificationOf(Counter<F> features) {
        Counter<F> contributions = new ClassicCounter<F>();
        for (F feature : features.keySet()) {
            int position = this.featureIndex.indexOf(feature);
            if (position >= 0) {
                contributions.incrementCount(feature, this.weights[position] * features.getCount(feature));
            }
        }
        return contributions;
    }

    /**
     * Breaks the score of {@link #scoreOf(Collection)} down by feature.
     *
     * @param features the features that are present
     * @return a new counter in which each indexed feature accumulates its weight once per
     *     occurrence; unindexed features are left out
     */
    public Counter<F> justificationOf(Collection<F> features) {
        Counter<F> contributions = new ClassicCounter<F>();
        for (F feature : features) {
            int position = this.featureIndex.indexOf(feature);
            if (position >= 0) {
                contributions.incrementCount(feature, this.weights[position]);
            }
        }
        return contributions;
    }

    /**
     * Scores both classes for a datum.
     *
     * @param datum the example to score
     * @return a new counter giving the positive label the linear score and the negative label
     *     its negation
     */
    @Override
    public Counter<L> scoresOf(Datum<L, F> datum) {
        if (datum instanceof RVFDatum) {
            return this.scoresOfRVFDatum((RVFDatum<L, F>) datum);
        }
        Collection<F> features = datum.asFeatures();
        return this.scoresForSum(this.scoreOf(features));
    }

    /**
     * @param example the real-valued example to score
     * @return the per-label scores
     * @deprecated marked deprecated in this class; {@link #scoresOf(Datum)} routes
     *     {@link RVFDatum} instances through the same value-weighted scoring
     */
    @Override
    @Deprecated
    public Counter<L> scoresOf(RVFDatum<L, F> example) {
        return this.scoresOfRVFDatum(example);
    }

    private Counter<L> scoresOfRVFDatum(RVFDatum<L, F> example) {
        Counter<F> features = example.asFeaturesCounter();
        return this.scoresForSum(this.scoreOf(features));
    }

    /**
     * Probability that the model assigns to the datum's own label.
     *
     * @param example a labeled example
     * @return the logistic probability of {@code example.label()}
     */
    public double probabilityOf(Datum<L, F> example) {
        if (example instanceof RVFDatum) {
            return this.probabilityOfRVFDatum((RVFDatum<L, F>) example);
        }
        return this.probabilityOf(example.asFeatures(), example.label());
    }

    /**
     * @param features the features that are present
     * @param label the label whose probability is wanted; any label that does not equal the
     *     negative label is treated as the positive label
     * @return {@code 1 / (1 + exp(score))} for the negative label, {@code 1 / (1 + exp(-score))}
     *     otherwise
     */
    public double probabilityOf(Collection<F> features, L label) {
        double sign = this.exponentSign(label);
        return 1.0 / (1.0 + Math.exp(sign * this.scoreOf(features)));
    }

    /**
     * @param example a labeled real-valued example
     * @return the logistic probability of {@code example.label()}
     */
    public double probabilityOf(RVFDatum<L, F> example) {
        return this.probabilityOfRVFDatum(example);
    }

    private double probabilityOfRVFDatum(RVFDatum<L, F> example) {
        return this.probabilityOf(example.asFeaturesCounter(), example.label());
    }

    /**
     * @param features feature values
     * @param label the label whose probability is wanted; any label that does not equal the
     *     negative label is treated as the positive label
     * @return {@code 1 / (1 + exp(score))} for the negative label, {@code 1 / (1 + exp(-score))}
     *     otherwise
     */
    public double probabilityOf(Counter<F> features, L label) {
        double sign = this.exponentSign(label);
        return 1.0 / (1.0 + Math.exp(sign * this.scoreOf(features)));
    }

    /**
     * Trains on per-example weighted data with a {@link QNMinimizer} and a tolerance of 1e-4,
     * then adopts the dataset's feature index and its two labels.
     *
     * @param data a two-label {@link Dataset} or {@link RVFDataset}
     * @param dataWeights weights passed through to the objective function
     * @throws RuntimeException if the dataset does not have exactly two labels
     * @throws IllegalArgumentException if the dataset is neither a {@link Dataset} nor an
     *     {@link RVFDataset}
     * @deprecated marked deprecated in this class; {@code main} trains through
     *     {@link LogisticClassifierFactory}, presumably the intended replacement (confirm)
     */
    @Deprecated
    public void trainWeightedData(GeneralDataset<L, F> data, float[] dataWeights) {
        LogisticClassifier.requireBinary(data);
        LogisticObjectiveFunction objective;
        if (data instanceof Dataset) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), this.prior, dataWeights);
        } else if (data instanceof RVFDataset) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), this.prior, dataWeights);
        } else {
            // Previously a null objective reached the minimizer and failed there without context.
            throw LogisticClassifier.unsupportedDataset(data);
        }
        Minimizer<DiffFunction> minimizer = new QNMinimizer(objective);
        this.weights = minimizer.minimize(objective, DEFAULT_TOLERANCE, new double[data.numFeatureTypes()]);
        this.adoptIndices(data);
    }

    /**
     * Trains without L1 regularization and with a tolerance of 1e-4.
     *
     * @param data a two-label dataset
     * @deprecated marked deprecated in this class; {@code main} trains through
     *     {@link LogisticClassifierFactory}, presumably the intended replacement (confirm)
     */
    @Deprecated
    public void train(GeneralDataset<L, F> data) {
        this.train(data, 0.0, DEFAULT_TOLERANCE);
    }

    /**
     * Fits the weights, then adopts the dataset's feature index and its two labels.
     *
     * <p>The biased objective is built from the data and label arrays only. The unbiased
     * objective additionally receives the value array when the data is an {@link RVFDataset}.
     *
     * @param data a two-label dataset
     * @param l1reg when strictly positive, the OWLQN minimizer is loaded reflectively with this
     *     value; otherwise a {@link QNMinimizer} is used
     * @param tol tolerance handed to the minimizer
     * @throws RuntimeException if the dataset does not have exactly two labels
     * @throws IllegalArgumentException if training is unbiased and the dataset is neither a
     *     {@link Dataset} nor an {@link RVFDataset}
     * @deprecated marked deprecated in this class; {@code main} trains through
     *     {@link LogisticClassifierFactory}, presumably the intended replacement (confirm)
     */
    @Deprecated
    public void train(GeneralDataset<L, F> data, double l1reg, double tol) {
        LogisticClassifier.requireBinary(data);
        DiffFunction objective;
        if (this.biased) {
            objective = new BiasedLogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), this.prior);
        } else if (data instanceof Dataset) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), this.prior);
        } else if (data instanceof RVFDataset) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), this.prior);
        } else {
            // Previously a null objective reached the minimizer and failed there without context.
            throw LogisticClassifier.unsupportedDataset(data);
        }
        Minimizer<DiffFunction> minimizer;
        if (l1reg > 0.0) {
            // Loaded by name so this class has no compile-time dependency on the OWLQN minimizer.
            minimizer = ReflectionLoading.loadByReflection(OWLQN_MINIMIZER_CLASS, l1reg);
        } else {
            minimizer = new QNMinimizer(objective);
        }
        this.weights = minimizer.minimize(objective, tol, new double[data.numFeatureTypes()]);
        this.adoptIndices(data);
    }

    /**
     * Command-line driver: trains on {@code trainFile} and prints a prediction for every line of
     * {@code testFile}.
     *
     * <p>Every line is split on whitespace. In the training file the first token is the label
     * and the remaining tokens are features; in the test file the first token is skipped. Each
     * output line is the predicted label, a tab, and the original test line. Optional
     * properties: {@code l1reg} (default {@code 0.0}) and {@code biased} (enabled only by the
     * exact value {@code true}).
     *
     * @param args arguments converted to properties by {@link StringUtils#argsToProperties}
     * @throws IllegalArgumentException if {@code trainFile} or {@code testFile} is not supplied
     * @throws Exception on any other failure
     */
    public static void main(String[] args) throws Exception {
        Properties prop = StringUtils.argsToProperties(args);
        // Validate both paths up front so a missing test file is reported before training starts.
        String trainPath = LogisticClassifier.requiredProperty(prop, "trainFile");
        String testPath = LogisticClassifier.requiredProperty(prop, "testFile");
        double l1reg = Double.parseDouble(prop.getProperty("l1reg", "0.0"));

        Dataset<String, String> ds = new Dataset<String, String>();
        for (String line : ObjectBank.getLineIterator(new File(trainPath))) {
            String[] bits = line.split("\\s+");
            ds.add(LogisticClassifier.featuresAfterFirstToken(bits), bits[0]);
        }
        ds.summaryStatistics();

        boolean biased = prop.getProperty("biased", "false").equals("true");
        LogisticClassifierFactory<String, String> factory = new LogisticClassifierFactory<String, String>();
        LogisticClassifier<String, String> lc = factory.trainClassifier(ds, l1reg, DEFAULT_TOLERANCE, biased);

        for (String line : ObjectBank.getLineIterator(new File(testPath))) {
            String[] bits = line.split("\\s+");
            String guess = lc.classOf(LogisticClassifier.featuresAfterFirstToken(bits));
            System.out.println(guess + '\t' + line);
        }
    }

    /** Builds the {@code <positive label> / <feature>} key shared by the weight views. */
    private String weightKey(F feature) {
        return this.classes[1] + " / " + feature;
    }

    /** Looks up the weight of a feature that is known to be in the feature index. */
    private double weightOf(F feature) {
        return this.weights[this.featureIndex.indexOf(feature)];
    }

    /** Maps a linear score to a label: strictly positive means the positive class. */
    private L labelForScore(double score) {
        return score > 0.0 ? this.classes[1] : this.classes[0];
    }

    /** Builds the two-entry score counter: negative label first, then positive label. */
    private Counter<L> scoresForSum(double sum) {
        Counter<L> scores = new ClassicCounter<L>();
        scores.setCount(this.classes[0], -sum);
        scores.setCount(this.classes[1], sum);
        return scores;
    }

    /** Sign applied to the score inside the logistic exponent: +1 for the negative label, else -1. */
    private double exponentSign(L label) {
        return label.equals(this.classes[0]) ? 1.0 : -1.0;
    }

    /** Records the feature index and the two labels of the dataset that was just trained on. */
    private void adoptIndices(GeneralDataset<L, F> data) {
        this.featureIndex = data.featureIndex;
        this.classes[0] = data.labelIndex.get(0);
        this.classes[1] = data.labelIndex.get(1);
    }

    /** Rejects datasets that do not have exactly two labels. */
    private static void requireBinary(GeneralDataset<?, ?> data) {
        if (data.labelIndex.size() != 2) {
            throw new RuntimeException("LogisticClassifier is only for binary classification!");
        }
    }

    /** Creates the exception raised for a dataset type that no objective function is built for. */
    private static IllegalArgumentException unsupportedDataset(GeneralDataset<?, ?> data) {
        return new IllegalArgumentException("LogisticClassifier cannot train on dataset type " + data.getClass().getName() + "; expected a Dataset or an RVFDataset");
    }

    /** Returns the named property, failing with a clear message when it is absent. */
    private static String requiredProperty(Properties props, String key) {
        String value = props.getProperty(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required property: " + key);
        }
        return value;
    }

    /** Copies every token after the first into a new list. */
    private static LinkedList<String> featuresAfterFirstToken(String[] bits) {
        return new LinkedList<String>(Arrays.asList(bits).subList(1, bits.length));
    }
}

