/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class AODE
extends Classifier
implements OptionHandler,
WeightedInstancesHandler,
UpdateableClassifier,
TechnicalInformationHandler {
    static final long serialVersionUID = 9197439980415113523L;

    /*
     * Count layout. Every non-class attribute i owns a contiguous slice of a
     * "flattened value" index space: slots m_StartAttIndex[i] ..
     * m_StartAttIndex[i] + m_NumAttValues[i] - 1 hold its nominal values, and
     * the extra slot m_StartAttIndex[i] + m_NumAttValues[i] counts instances
     * where the attribute is missing. m_TotalAttValues is the size of that
     * space. Field names are kept unchanged for serialization compatibility.
     */

    /** [class][flattened value a][flattened value b]: weighted co-occurrence counts; the diagonal [c][a][a] is the per-class count of value a. */
    private double[][][] m_CondiCounts;
    /** Weighted number of training instances per class. */
    private double[] m_ClassCounts;
    /** [class][attribute]: weighted number of instances of that class whose attribute value is not missing. */
    private double[][] m_SumForCounts;
    private int m_NumClasses;
    private int m_NumAttributes;
    /** Number of (class-labelled) training instances seen; reported by toString(). */
    private int m_NumInstances;
    private int m_ClassIndex;
    /** Header-only copy of the training data after buildClassifier(); null until then. */
    private Instances m_Instances;
    private int m_TotalAttValues;
    private int[] m_StartAttIndex;
    /** Number of values per attribute; for the class attribute this holds the number of classes. */
    private int[] m_NumAttValues;
    /** Weighted count of every flattened value, summed over all classes. */
    private double[] m_Frequencies;
    /** Total weight of all training instances with a known class. */
    private double m_SumInstances;
    /** Minimum frequency a value needs before its attribute may act as a super-parent. */
    private int m_Limit = 1;
    private boolean m_Debug = false;
    /** If true use m-estimates, otherwise the Laplace correction. */
    private boolean m_MEstimates = false;
    /** The m in the m-estimate; only meaningful when m_MEstimates is true. */
    private int m_Weight = 1;

    /**
     * Returns a description of the classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "AODE achieves highly accurate classification by averaging over all of a small space of alternative naive-Bayes-like models that have weaker (and hence less detrimental) independence assumptions than naive Bayes. The resulting algorithm is computationally efficient while delivering highly accurate classification on many learning  tasks.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString() + "\n\n" + "Further papers are available at\n" + "  http://www.csse.monash.edu.au/~webb/.\n\n" + "Can use an m-estimate for smoothing base probability estimates " + "in place of the Laplace correction (via option -M).\n" + "Default frequency limit set to 1.";
    }

    /**
     * Returns the bibliographic reference for AODE.
     *
     * @return the technical information about this class
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "G. Webb and J. Boughton and Z. Wang");
        result.setValue(TechnicalInformation.Field.YEAR, "2005");
        result.setValue(TechnicalInformation.Field.TITLE, "Not So Naive Bayes: Aggregating One-Dependence Estimators");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "58");
        result.setValue(TechnicalInformation.Field.NUMBER, "1");
        result.setValue(TechnicalInformation.Field.PAGES, "5-24");
        return result;
    }

    /**
     * Returns the capabilities of this classifier: nominal attributes and a
     * nominal class only, with missing values allowed in both.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the count tables from the training data. Instances with a
     * missing class are discarded. Memory use is proportional to
     * numClasses * (totalAttributeValues)^2.
     *
     * @param instances the training data
     * @throws Exception if the data does not meet the capabilities
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        this.m_Instances = new Instances(instances);
        this.m_Instances.deleteWithMissingClass();
        this.m_SumInstances = 0.0;
        this.m_ClassIndex = instances.classIndex();
        this.m_NumInstances = this.m_Instances.numInstances();
        this.m_NumAttributes = this.m_Instances.numAttributes();
        this.m_NumClasses = this.m_Instances.numClasses();
        this.m_StartAttIndex = new int[this.m_NumAttributes];
        this.m_NumAttValues = new int[this.m_NumAttributes];
        this.m_TotalAttValues = 0;
        for (int i = 0; i < this.m_NumAttributes; ++i) {
            if (i == this.m_ClassIndex) {
                this.m_NumAttValues[i] = this.m_NumClasses;
                continue;
            }
            this.m_StartAttIndex[i] = this.m_TotalAttValues;
            this.m_NumAttValues[i] = this.m_Instances.attribute(i).numValues();
            // +1 reserves a slot that counts missing values for this attribute.
            this.m_TotalAttValues += this.m_NumAttValues[i] + 1;
        }
        this.m_CondiCounts = new double[this.m_NumClasses][this.m_TotalAttValues][this.m_TotalAttValues];
        this.m_ClassCounts = new double[this.m_NumClasses];
        this.m_SumForCounts = new double[this.m_NumClasses][this.m_NumAttributes];
        this.m_Frequencies = new double[this.m_TotalAttValues];
        for (int k = 0; k < this.m_NumInstances; ++k) {
            this.addToCounts(this.m_Instances.instance(k));
        }
        // Keep only the header; the counts are all that prediction needs.
        this.m_Instances = new Instances(this.m_Instances, 0);
    }

    /**
     * Incrementally adds one training instance to the model. Instances with a
     * missing class are ignored, as in buildClassifier.
     *
     * @param instance the new training instance
     */
    public void updateClassifier(Instance instance) {
        if (instance.classIsMissing()) {
            return;
        }
        this.addToCounts(instance);
        // Keep the instance count reported by toString() in sync.
        ++this.m_NumInstances;
    }

    /**
     * Adds the weight of one instance to all count tables. Missing attribute
     * values are recorded in each attribute's extra "missing" slot.
     *
     * @param instance the instance to add; ignored if its class is missing
     */
    private void addToCounts(Instance instance) {
        if (instance.classIsMissing()) {
            return;
        }
        int classVal = (int)instance.classValue();
        double weight = instance.weight();
        this.m_ClassCounts[classVal] += weight;
        this.m_SumInstances += weight;

        // Flattened value index per attribute; -1 marks the class attribute.
        int[] attIndex = new int[this.m_NumAttributes];
        for (int i = 0; i < this.m_NumAttributes; ++i) {
            if (i == this.m_ClassIndex) {
                attIndex[i] = -1;
            } else if (instance.isMissing(i)) {
                attIndex[i] = this.m_StartAttIndex[i] + this.m_NumAttValues[i];
            } else {
                attIndex[i] = this.m_StartAttIndex[i] + (int)instance.value(i);
            }
        }

        for (int att1 = 0; att1 < this.m_NumAttributes; ++att1) {
            if (attIndex[att1] == -1) {
                continue;
            }
            this.m_Frequencies[attIndex[att1]] += weight;
            if (!instance.isMissing(att1)) {
                this.m_SumForCounts[classVal][att1] += weight;
            }
            double[] countsPointer = this.m_CondiCounts[classVal][attIndex[att1]];
            for (int att2 = 0; att2 < this.m_NumAttributes; ++att2) {
                if (attIndex[att2] != -1) {
                    countsPointer[attIndex[att2]] += weight;
                }
            }
        }
    }

    /**
     * Smooths a frequency count into a probability estimate, using either the
     * Laplace correction {@code (count + 1) / (total + numValues)} or the
     * m-estimate {@code (count + m/numValues) / (total + m)} with
     * {@code m = m_Weight}, depending on m_MEstimates.
     *
     * @param count the observed count for the event
     * @param total the total count of the conditioning context
     * @param numValues the number of possible outcomes
     * @return the smoothed probability estimate
     */
    private double smoothedEstimate(double count, double total, int numValues) {
        if (this.m_MEstimates) {
            return (count + (double)this.m_Weight / (double)numValues) / (total + (double)this.m_Weight);
        }
        return (count + 1.0) / (total + (double)numValues);
    }

    /**
     * Computes class membership probabilities for an instance. For each class,
     * the estimates of all eligible one-dependence models are averaged. An
     * eligible model is one whose super-parent is a non-missing attribute
     * whose value was seen at least {@code m_Limit} times in training. If no
     * attribute qualifies, naive Bayes is used instead.
     *
     * @param instance the instance to classify
     * @return the normalized class probability distribution
     * @throws Exception if the distribution cannot be normalized
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] probs = new double[this.m_NumClasses];

        // Flattened value index per attribute; -1 marks the class attribute
        // and attributes whose value is missing in this instance.
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; ++att) {
            attIndex[att] = instance.isMissing(att) || att == this.m_ClassIndex
                ? -1
                : this.m_StartAttIndex[att] + (int)instance.value(att);
        }

        for (int classVal = 0; classVal < this.m_NumClasses; ++classVal) {
            int parentCount = 0;
            double[][] countsForClass = this.m_CondiCounts[classVal];
            for (int parent = 0; parent < this.m_NumAttributes; ++parent) {
                int pIndex = attIndex[parent];
                if (pIndex == -1 || this.m_Frequencies[pIndex] < (double)this.m_Limit) {
                    continue;
                }
                double[] countsForClassParent = countsForClass[pIndex];
                // Temporarily hide the parent so it is not also treated as a child.
                attIndex[parent] = -1;
                ++parentCount;

                // P(class, parent value), excluding instances missing the parent attribute.
                double classParentFreq = countsForClassParent[pIndex];
                double missingForParentAtt = this.m_Frequencies[this.m_StartAttIndex[parent] + this.m_NumAttValues[parent]];
                double spodeP = this.smoothedEstimate(classParentFreq,
                        this.m_SumInstances - missingForParentAtt,
                        this.m_NumClasses * this.m_NumAttValues[parent]);

                // Multiply by P(child value | class, parent value) for every other known attribute.
                for (int att = 0; att < this.m_NumAttributes; ++att) {
                    if (attIndex[att] == -1) {
                        continue;
                    }
                    double missingForParentAndChildAtt = countsForClassParent[this.m_StartAttIndex[att] + this.m_NumAttValues[att]];
                    spodeP *= this.smoothedEstimate(countsForClassParent[attIndex[att]],
                            classParentFreq - missingForParentAndChildAtt,
                            this.m_NumAttValues[att]);
                }
                probs[classVal] += spodeP;
                attIndex[parent] = pIndex;
            }
            if (parentCount < 1) {
                probs[classVal] = this.NBconditionalProb(instance, classVal);
            } else {
                probs[classVal] /= (double)parentCount;
            }
        }
        Utils.normalize(probs);
        return probs;
    }

    /**
     * Computes the (unnormalized) naive Bayes estimate of P(class, instance).
     * This is the fallback when no attribute qualifies as a super-parent.
     * Missing attribute values are skipped.
     *
     * @param instance the instance to evaluate
     * @param classVal the index of the class value
     * @return the unnormalized naive Bayes probability
     */
    public double NBconditionalProb(Instance instance, int classVal) {
        double prob = this.smoothedEstimate(this.m_ClassCounts[classVal], this.m_SumInstances, this.m_NumClasses);
        double[][] pointer = this.m_CondiCounts[classVal];
        for (int att = 0; att < this.m_NumAttributes; ++att) {
            if (att == this.m_ClassIndex || instance.isMissing(att)) {
                continue;
            }
            int aIndex = this.m_StartAttIndex[att] + (int)instance.value(att);
            // The diagonal entry is the per-class count of this attribute value.
            prob *= this.smoothedEstimate(pointer[aIndex][aIndex],
                    this.m_SumForCounts[classVal][att],
                    this.m_NumAttValues[att]);
        }
        return prob;
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of {@link Option}s
     */
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(4);
        newVector.addElement(new Option("\tOutput debugging information\n", "D", 0, "-D"));
        newVector.addElement(new Option("\tImpose a frequency limit for superParents\n\t(default is 1)", "F", 1, "-F <int>"));
        newVector.addElement(new Option("\tUse m-estimate instead of laplace correction\n", "M", 0, "-M"));
        newVector.addElement(new Option("\tSpecify a weight to use with m-estimate\n\t(default is 1)", "W", 1, "-W <int>"));
        return newVector.elements();
    }

    /**
     * Parses the options: {@code -D} (debug), {@code -F <int>} (frequency
     * limit, default 1), {@code -M} (use m-estimate) and {@code -W <int>}
     * (m-estimate weight; requires -M, must be positive, default 1).
     *
     * @param options the list of options as an array of strings
     * @throws Exception if -W is given without -M, the weight is not positive,
     *         a number cannot be parsed, or unknown options remain
     */
    public void setOptions(String[] options) throws Exception {
        this.m_Debug = Utils.getFlag('D', options);
        String freq = Utils.getOption('F', options);
        this.m_Limit = freq.length() != 0 ? Integer.parseInt(freq) : 1;
        this.m_MEstimates = Utils.getFlag('M', options);
        String weight = Utils.getOption('W', options);
        if (weight.length() != 0) {
            if (!this.m_MEstimates) {
                throw new Exception("Can't use Laplace AND m-estimate weight. Choose one.");
            }
            int w = Integer.parseInt(weight);
            // Same constraint as setWeight(): m <= 0 makes the m-estimate
            // denominators zero for unseen combinations (NaN probabilities).
            if (w <= 0) {
                throw new Exception("Weight must be greater than 0!");
            }
            this.m_Weight = w;
        } else if (this.m_MEstimates) {
            this.m_Weight = 1;
        }
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current option settings.
     *
     * @return an array of strings suitable for passing to setOptions
     */
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.m_Debug) {
            result.add("-D");
        }
        result.add("-F");
        result.add("" + this.m_Limit);
        if (this.m_MEstimates) {
            result.add("-M");
            result.add("-W");
            result.add("" + this.m_Weight);
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for the weight property.
     *
     * @return tip text for this property
     */
    public String weightTipText() {
        return "Set the weight for m-estimate.";
    }

    /**
     * Sets the m-estimate weight. Ignored, with a console message, unless
     * m-estimates are enabled or if the weight is not positive.
     *
     * @param w the new weight
     */
    public void setWeight(int w) {
        if (!this.getUseMEstimates()) {
            System.out.println("Weight is only used in conjunction with m-estimate - ignored!");
        } else if (w > 0) {
            this.m_Weight = w;
        } else {
            System.out.println("Weight must be greater than 0!");
        }
    }

    /**
     * Returns the m-estimate weight.
     *
     * @return the weight
     */
    public int getWeight() {
        return this.m_Weight;
    }

    /**
     * Returns the tip text for the useMEstimates property.
     *
     * @return tip text for this property
     */
    public String useMEstimatesTipText() {
        return "Use m-estimate instead of laplace correction.";
    }

    /**
     * Returns whether m-estimates are used instead of the Laplace correction.
     *
     * @return true if m-estimates are used
     */
    public boolean getUseMEstimates() {
        return this.m_MEstimates;
    }

    /**
     * Sets whether to use m-estimates instead of the Laplace correction.
     *
     * @param value true to use m-estimates
     */
    public void setUseMEstimates(boolean value) {
        this.m_MEstimates = value;
    }

    /**
     * Returns the tip text for the frequencyLimit property.
     *
     * @return tip text for this property
     */
    public String frequencyLimitTipText() {
        return "Attributes with a frequency in the train set below this value aren't used as parents.";
    }

    /**
     * Sets the minimum training frequency a value needs before its attribute
     * may act as a super-parent.
     *
     * @param f the frequency limit
     */
    public void setFrequencyLimit(int f) {
        this.m_Limit = f;
    }

    /**
     * Returns the super-parent frequency limit.
     *
     * @return the frequency limit
     */
    public int getFrequencyLimit() {
        return this.m_Limit;
    }

    /**
     * Returns a textual description of the model: class priors (computed with
     * the configured smoothing), dataset summary and correction method.
     *
     * @return the description
     */
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("The AODE Classifier");
        if (this.m_Instances == null) {
            text.append(": No model built yet.");
            return text.toString();
        }
        try {
            for (int i = 0; i < this.m_NumClasses; ++i) {
                // Use the same estimator as prediction so the reported priors match the model.
                double prior = this.smoothedEstimate(this.m_ClassCounts[i], this.m_SumInstances, this.m_NumClasses);
                text.append("\nClass " + this.m_Instances.classAttribute().value(i) + ": Prior probability = " + Utils.doubleToString(prior, 4, 2) + "\n\n");
            }
            text.append("Dataset: " + this.m_Instances.relationName() + "\n" + "Instances: " + this.m_NumInstances + "\n" + "Attributes: " + this.m_NumAttributes + "\n" + "Frequency limit for superParents: " + this.m_Limit + "\n");
            text.append("Correction: ");
            if (!this.m_MEstimates) {
                text.append("laplace\n");
            } else {
                text.append("m-estimate (m=" + this.m_Weight + ")\n");
            }
        }
        catch (Exception ex) {
            text.append(ex.getMessage());
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5516 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        AODE.runClassifier(new AODE(), argv);
    }
}

