/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.List;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.PairCounts;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RStringVector;

public class NaiveBayesConverter
extends ModelConverter<RGenericVector> {
    public NaiveBayesConverter(RGenericVector naiveBayes) {
        super(naiveBayes);
    }

    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector naiveBayes = (RGenericVector)this.getObject();
        RGenericVector tables = naiveBayes.getGenericElement("tables");
        RStringVector levels = naiveBayes.getStringElement("levels");
        DataField dataField = encoder.createDataField("_target", OpType.CATEGORICAL, DataType.STRING, levels.getValues());
        encoder.setLabel(dataField);
        RStringVector tableNames = tables.names();
        for (int i = 0; i < tables.size(); ++i) {
            RDoubleVector table = tables.getDoubleValue(i);
            RStringVector tableRows = table.dimnames(0);
            RStringVector tableColumns = table.dimnames(1);
            String name = tableNames.getValue(i);
            DataField dataField2 = tableColumns != null ? encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING, tableColumns.getValues()) : encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
            encoder.addFeature((Field<?>)dataField2);
        }
    }

    @Override
    public Model encodeModel(Schema schema) {
        RGenericVector naiveBayes = (RGenericVector)this.getObject();
        RIntegerVector apriori = naiveBayes.getIntegerElement("apriori");
        RGenericVector tables = naiveBayes.getGenericElement("tables");
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        List features = schema.getFeatures();
        BayesInputs bayesInputs = new BayesInputs();
        for (int i = 0; i < features.size(); ++i) {
            Feature feature = (Feature)features.get(i);
            String name = feature.getName();
            RDoubleVector table = tables.getDoubleElement(name);
            RStringVector tableRows = table.dimnames(0);
            RStringVector tableColumns = table.dimnames(1);
            BayesInput bayesInput = new BayesInput(name, null, null);
            if (feature instanceof CategoricalFeature) {
                CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
                for (int column = 0; column < tableColumns.size(); ++column) {
                    TargetValueCounts targetValueCounts = new TargetValueCounts();
                    List probabilities = FortranMatrixUtil.getColumn(table.getValues(), (int)tableRows.size(), (int)tableColumns.size(), (int)column);
                    for (int row = 0; row < tableRows.size(); ++row) {
                        double count = (double)apriori.getValue(row).intValue() * (Double)probabilities.get(row);
                        TargetValueCount targetValueCount = new TargetValueCount((Object)tableRows.getValue(row), (Number)count);
                        targetValueCounts.addTargetValueCounts(new TargetValueCount[]{targetValueCount});
                    }
                    PairCounts pairCounts = new PairCounts((Object)tableColumns.getValue(column), targetValueCounts);
                    bayesInput.addPairCounts(new PairCounts[]{pairCounts});
                }
            } else if (feature instanceof ContinuousFeature) {
                ContinuousFeature continuousFeature = (ContinuousFeature)feature;
                TargetValueStats targetValueStats = new TargetValueStats();
                for (int row = 0; row < tableRows.size(); ++row) {
                    List stats = FortranMatrixUtil.getRow(table.getValues(), (int)tableRows.size(), (int)2, (int)row);
                    double mean = (Double)stats.get(0);
                    double variance = Math.pow((Double)stats.get(1), 2.0);
                    TargetValueStat targetValueStat = new TargetValueStat((Object)tableRows.getValue(row), (ContinuousDistribution)new GaussianDistribution((Number)mean, (Number)variance));
                    targetValueStats.addTargetValueStats(new TargetValueStat[]{targetValueStat});
                }
                bayesInput.setTargetValueStats(targetValueStats);
            } else {
                throw new IllegalArgumentException();
            }
            bayesInputs.addBayesInputs(new BayesInput[]{bayesInput});
        }
        BayesOutput bayesOutput = new BayesOutput().setTargetField(categoricalLabel.getName());
        TargetValueCounts targetValueCounts = new TargetValueCounts();
        RStringVector aprioriRows = apriori.dimnames(0);
        for (int row = 0; row < aprioriRows.size(); ++row) {
            int count = apriori.getValue(row);
            TargetValueCount targetValueCount = new TargetValueCount((Object)aprioriRows.getValue(row), (Number)count);
            targetValueCounts.addTargetValueCounts(new TargetValueCount[]{targetValueCount});
        }
        bayesOutput.setTargetValueCounts(targetValueCounts);
        NaiveBayesModel naiveBayesModel = new NaiveBayesModel((Number)0.0, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)categoricalLabel), bayesInputs, bayesOutput).setOutput(ModelUtil.createProbabilityOutput((DataType)DataType.DOUBLE, (DiscreteLabel)categoricalLabel));
        return naiveBayesModel;
    }
}

