/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.converter.mining;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.Field;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.UnsupportedAttributeException;

/**
 * Static helpers for assembling and post-processing PMML {@link MiningModel} ensembles
 * (model chains, multi-model chains and their segmentations).
 */
public class MiningModelUtil {

    private MiningModelUtil() {
    }

    /**
     * Chains {@code model} with a regression model that reads the last output field of {@code model}.
     *
     * <p>The regression step uses a single coefficient of {@code 1.0} and a {@code null} intercept
     * argument, and applies {@code normalizationMethod}. The resulting chain uses
     * {@link Segmentation.MissingPredictionTreatment#RETURN_MISSING}.
     *
     * @param model the upstream model; it must declare at least one output field
     * @param normalizationMethod normalization method of the regression step
     * @param schema schema supplying the label and the encoder
     * @return a two-segment {@code modelChain} mining model
     * @throws InvalidElementException if {@code model} declares no output fields
     */
    public static MiningModel createRegression(Model model, RegressionModel.NormalizationMethod normalizationMethod, Schema schema) {
        ContinuousFeature feature = getPrediction(model, schema);
        MathContext mathContext = model.getMathContext();

        RegressionModel regressionModel = RegressionModelUtil.createRegression(mathContext,
            Collections.singletonList(feature), Collections.singletonList(1.0), null, normalizationMethod, schema);

        return createModelChain(Arrays.asList(model, regressionModel), Segmentation.MissingPredictionTreatment.RETURN_MISSING)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext));
    }

    /**
     * Chains {@code model} with a binary logistic classification model that reads the last output
     * field of {@code model}.
     *
     * @param model the upstream model; it must declare at least one output field
     * @param coefficient coefficient applied to the upstream prediction
     * @param intercept intercept of the logistic step
     * @param normalizationMethod normalization method of the logistic step
     * @param hasProbabilityDistribution passed through to the logistic model factory
     * @param schema schema supplying the label and the encoder
     * @return a two-segment {@code modelChain} mining model
     * @throws InvalidElementException if {@code model} declares no output fields
     */
    public static MiningModel createBinaryLogisticClassification(Model model, double coefficient, double intercept, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        ContinuousFeature feature = getPrediction(model, schema);
        MathContext mathContext = model.getMathContext();

        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(mathContext,
            Collections.singletonList(feature), Collections.singletonList(coefficient), intercept,
            normalizationMethod, hasProbabilityDistribution, schema);

        return createModelChain(Arrays.asList(model, regressionModel), Segmentation.MissingPredictionTreatment.RETURN_MISSING)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext));
    }

    /**
     * Combines one model per target category into a classification model chain.
     *
     * <p>The final segment is a classification {@link RegressionModel} with one regression table per
     * category. The i-th table reads the last output field of {@code models.get(i)} with coefficient
     * {@code 1.0}.
     *
     * @param models one model per category of the categorical label, in category order
     * @param normalizationMethod normalization method of the final regression model; may be {@code null}
     * @param hasProbabilityDistribution whether the final model gets a probability output
     * @param schema schema whose label must be a {@link CategoricalLabel}
     * @return a {@code modelChain} mining model whose last segment aggregates {@code models}
     * @throws IllegalArgumentException if the label has too few categories for the normalization
     *     method, if the normalization method is unsupported, or if the models disagree on their
     *     math context (a missing math context counts as {@link MathContext#DOUBLE})
     */
    public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        SchemaUtil.checkSize(models.size(), categoricalLabel);
        checkCategoryCount(normalizationMethod, categoricalLabel);

        MathContext mathContext = null;

        List<RegressionTable> regressionTables = new ArrayList<>();
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Model model = models.get(i);

            MathContext modelMathContext = model.getMathContext();
            if (modelMathContext == null) {
                modelMathContext = MathContext.DOUBLE;
            }

            // All member models must agree, because the aggregating model uses a single math context
            if (mathContext == null) {
                mathContext = modelMathContext;
            } else if (!Objects.equals(mathContext, modelMathContext)) {
                throw new IllegalArgumentException("Inconsistent math contexts: " + mathContext + " and " + modelMathContext);
            }

            ContinuousFeature feature = getPrediction(model, schema);

            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(mathContext,
                Collections.singletonList(feature), Collections.singletonList(1.0), null)
                .setTargetCategory(categoricalLabel.getValue(i));

            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(normalizationMethod)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext))
            .setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, (DiscreteLabel)categoricalLabel) : null);

        // Copy into a List<Model> so that the aggregating model can be appended as the final segment
        List<Model> segmentationModels = new ArrayList<>(models);
        segmentationModels.add(regressionModel);

        return createModelChain(segmentationModels, Segmentation.MissingPredictionTreatment.RETURN_MISSING)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext));
    }

    /**
     * Validates that the label has enough categories for {@code normalizationMethod}.
     *
     * <p>A {@code null} method or {@code NONE} requires at least three categories.
     * {@code SIMPLEMAX} and {@code SOFTMAX} require at least two. Any other method is rejected.
     */
    private static void checkCategoryCount(RegressionModel.NormalizationMethod normalizationMethod, CategoricalLabel categoricalLabel) {
        int minSize;

        if (normalizationMethod == null) {
            minSize = 3;
        } else {
            switch (normalizationMethod) {
                case NONE:
                    minSize = 3;
                    break;
                case SIMPLEMAX:
                case SOFTMAX:
                    minSize = 2;
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported normalization method " + normalizationMethod);
            }
        }

        if (categoricalLabel.size() < minSize) {
            throw new IllegalArgumentException("Expected " + minSize + " or more categories, got " + categoricalLabel.size());
        }
    }

    /**
     * Wraps {@code models} in a {@code modelChain} mining model.
     *
     * <p>The mining function is taken from the last model. The mining schema comes from
     * {@link #createMiningSchema(List)}.
     *
     * @param models the chained models, in evaluation order; must not be empty
     * @param missingPredictionTreatment treatment applied to missing segment predictions
     * @return the new mining model
     * @throws IllegalArgumentException if {@code models} is empty
     */
    public static MiningModel createModelChain(List<? extends Model> models, Segmentation.MissingPredictionTreatment missingPredictionTreatment) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Model chain requires at least one model");
        }

        Model lastModel = Iterables.getLast(models);
        MiningFunction miningFunction = lastModel.requireMiningFunction();

        return new MiningModel(miningFunction, createMiningSchema(models))
            .setSegmentation(createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, missingPredictionTreatment, models));
    }

    /**
     * Wraps {@code models} in a {@code multiModelChain} mining model.
     *
     * <p>The mining function is the common function of all models, or {@link MiningFunction#MIXED}
     * when they differ.
     *
     * @param models the chained models, in evaluation order; must not be empty
     * @param missingPredictionTreatment treatment applied to missing segment predictions
     * @return the new mining model
     * @throws IllegalArgumentException if {@code models} is empty
     */
    public static MiningModel createMultiModelChain(List<? extends Model> models, Segmentation.MissingPredictionTreatment missingPredictionTreatment) {
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Multi-model chain requires at least one model");
        }

        MiningFunction miningFunction = null;
        for (Model model : models) {
            MiningFunction modelMiningFunction = model.requireMiningFunction();

            if (miningFunction == null) {
                miningFunction = modelMiningFunction;
            } else if (!Objects.equals(miningFunction, modelMiningFunction)) {
                miningFunction = MiningFunction.MIXED;
            }
        }

        return new MiningModel(miningFunction, createMiningSchema(models))
            .setSegmentation(createSegmentation(Segmentation.MultipleModelMethod.MULTI_MODEL_CHAIN, missingPredictionTreatment, models));
    }

    /**
     * Builds a mining schema that declares, with {@code TARGET} usage, every distinct field name
     * that any of {@code models} uses as a {@code PREDICTED} or {@code TARGET} mining field.
     * Names appear in first-occurrence order.
     *
     * @param models the member models
     * @return a new mining schema; it has no mining fields if no member declares a target
     */
    public static MiningSchema createMiningSchema(List<? extends Model> models) {
        MiningField[] targetFields = models.stream()
            .map(Model::requireMiningSchema)
            .map(MiningSchema::getMiningFields)
            .flatMap(Collection::stream)
            .filter(MiningModelUtil::isTargetField)
            .map(MiningField::getName)
            .distinct()
            .map(name -> ModelUtil.createMiningField(name, MiningField.UsageType.TARGET))
            .toArray(MiningField[]::new);

        MiningSchema miningSchema = new MiningSchema();
        if (targetFields.length > 0) {
            miningSchema.addMiningFields(targetFields);
        }
        return miningSchema;
    }

    /** Returns {@code true} if the mining field's usage type is {@code PREDICTED} or {@code TARGET}. */
    private static boolean isTargetField(MiningField miningField) {
        MiningField.UsageType usageType = miningField.getUsageType();

        switch (usageType) {
            case PREDICTED:
            case TARGET:
                return true;
            default:
                return false;
        }
    }

    /**
     * Creates an unweighted segmentation.
     *
     * @see #createSegmentation(Segmentation.MultipleModelMethod, Segmentation.MissingPredictionTreatment, List, List)
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, List<? extends Model> models) {
        return createSegmentation(multipleModelMethod, missingPredictionTreatment, models, null);
    }

    /**
     * Creates a segmentation with one {@code True}-predicated segment per model.
     *
     * <p>Segments receive 1-based string ids. A weight is set on a segment only when it is
     * non-{@code null} and not equal to one.
     *
     * @param multipleModelMethod the segmentation's multiple model method
     * @param missingPredictionTreatment treatment applied to missing segment predictions
     * @param models the segment models, in order
     * @param weights per-model weights, or {@code null} for none
     * @return the new segmentation
     * @throws IllegalArgumentException if {@code weights} is non-{@code null} and its size differs from {@code models}
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, List<? extends Model> models, List<? extends Number> weights) {
        if (weights != null && models.size() != weights.size()) {
            throw new IllegalArgumentException("Expected " + models.size() + " weights, got " + weights.size());
        }

        List<Segment> segments = new ArrayList<>(models.size());
        for (int i = 0; i < models.size(); i++) {
            Model model = models.get(i);
            Number weight = (weights != null) ? weights.get(i) : null;

            Segment segment = new Segment((Predicate)True.INSTANCE, model)
                .setId(String.valueOf(i + 1));

            if (weight != null && !ValueUtil.isOne(weight)) {
                segment.setWeight(weight);
            }

            segments.add(segment);
        }

        return new Segmentation(multipleModelMethod, segments)
            .setMissingPredictionTreatment(missingPredictionTreatment);
    }

    /**
     * Returns the model that produces the final prediction of {@code model}.
     * Non-mining models are returned as-is.
     *
     * @see #getFinalModel(MiningModel)
     */
    public static Model getFinalModel(Model model) {
        if (model instanceof MiningModel) {
            return getFinalModel((MiningModel)model);
        }
        return model;
    }

    /**
     * Returns the model that produces the final prediction of {@code miningModel}.
     *
     * <p>For a {@code modelChain} whose segments all carry {@code True} predicates, this recurses
     * into the last segment's model. For every other segmentation it returns {@code miningModel}
     * itself.
     *
     * @throws UnsupportedAttributeException if the multiple model method is {@code selectAll}
     */
    public static Model getFinalModel(MiningModel miningModel) {
        Segmentation segmentation = miningModel.requireSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.requireMultipleModelMethod();

        switch (multipleModelMethod) {
            case SELECT_ALL:
                throw new UnsupportedAttributeException((PMMLObject)segmentation, (Enum<?>)multipleModelMethod);
            case MODEL_CHAIN:
                if (isChain(segmentation)) {
                    List<Segment> segments = segmentation.requireSegments();
                    Segment finalSegment = segments.get(segments.size() - 1);

                    // Defensive: isChain(..) has already verified the predicate type
                    finalSegment.requirePredicate(True.class);

                    return getFinalModel(finalSegment.requireModel());
                }
                break;
            default:
                break;
        }

        return miningModel;
    }

    /**
     * Returns {@code true} if every segment of {@code segmentation} has a {@code True} predicate.
     */
    public static boolean isChain(Segmentation segmentation) {
        List<Segment> segments = segmentation.requireSegments();

        return segments.stream()
            .allMatch(segment -> segment.requirePredicate() instanceof True);
    }

    /**
     * Moves output fields that all segments have in common up to {@code miningModel}'s own output.
     *
     * <p>Only {@code PROBABILITY} and {@code AFFINITY} output fields are candidates. A field is
     * common when the final model of every segment declares an output field with the same name
     * that is equal to it according to {@link ReflectionUtil#equals}. Common fields are removed from
     * the segments and appended to the mining model's output in their declaration order. A segment
     * output that becomes empty is dropped.
     */
    public static void optimizeOutputFields(MiningModel miningModel) {
        Segmentation segmentation = miningModel.requireSegmentation();

        Map<String, OutputField> commonOutputFields = collectCommonOutputFields(segmentation);
        if (commonOutputFields.isEmpty()) {
            return;
        }

        Output output = ModelUtil.ensureOutput(miningModel);

        removeCommonOutputFields(segmentation, commonOutputFields.keySet());

        List<OutputField> outputFields = output.getOutputFields();
        outputFields.addAll(commonOutputFields.values());
    }

    /**
     * Collects the candidate output fields shared by the final models of all segments, keyed by
     * name, in the declaration order of the first segment. Never returns {@code null}.
     */
    private static Map<String, OutputField> collectCommonOutputFields(Segmentation segmentation) {
        List<Segment> segments = segmentation.requireSegments();

        Map<String, OutputField> result = null;

        for (Segment segment : segments) {
            Model finalModel = getFinalModel(segment.requireModel());

            Output output = finalModel.getOutput();
            if (output == null || !output.hasOutputFields()) {
                // A segment without output fields shares nothing with the others
                return Collections.emptyMap();
            }

            List<OutputField> outputFields = output.getOutputFields();

            if (result == null) {
                result = collectCandidateOutputFields(outputFields);
            } else {
                retainEqualOutputFields(result, outputFields);
            }

            if (result.isEmpty()) {
                break;
            }
        }

        // Without segments there is nothing in common; avoid handing null to the caller
        return (result != null) ? result : Collections.emptyMap();
    }

    /**
     * Returns the {@code PROBABILITY} and {@code AFFINITY} fields of {@code outputFields}, keyed by
     * name, in declaration order.
     */
    private static Map<String, OutputField> collectCandidateOutputFields(List<OutputField> outputFields) {
        // Insertion-ordered, so that fields later move to the mining model in a stable order
        Map<String, OutputField> result = new LinkedHashMap<>();

        for (OutputField outputField : outputFields) {
            if (!isCommonCandidate(outputField)) {
                continue;
            }

            String name = outputField.requireName();
            if (result.putIfAbsent(name, outputField) != null) {
                throw new IllegalStateException("Duplicate output field " + name);
            }
        }

        return result;
    }

    /**
     * Drops from {@code commonOutputFields} every entry that is absent from {@code outputFields},
     * or that is present there with a different definition.
     */
    private static void retainEqualOutputFields(Map<String, OutputField> commonOutputFields, List<OutputField> outputFields) {
        Set<String> names = new LinkedHashSet<>();

        for (OutputField outputField : outputFields) {
            String name = outputField.requireName();
            names.add(name);

            OutputField commonOutputField = commonOutputFields.get(name);
            if (commonOutputField != null && !ReflectionUtil.equals((PMMLObject)outputField, (PMMLObject)commonOutputField)) {
                commonOutputFields.remove(name);
            }
        }

        commonOutputFields.keySet().retainAll(names);
    }

    /** Returns {@code true} if the output field's result feature is {@code PROBABILITY} or {@code AFFINITY}. */
    private static boolean isCommonCandidate(OutputField outputField) {
        ResultFeature resultFeature = outputField.getResultFeature();

        switch (resultFeature) {
            case PROBABILITY:
            case AFFINITY:
                return true;
            default:
                return false;
        }
    }

    /**
     * Removes the named output fields from the final model of every segment, and clears the output
     * of any final model that is left without output fields.
     */
    private static void removeCommonOutputFields(Segmentation segmentation, Set<String> names) {
        List<Segment> segments = segmentation.requireSegments();

        for (Segment segment : segments) {
            Model finalModel = getFinalModel(segment.requireModel());

            Output output = finalModel.getOutput();
            if (output == null || !output.hasOutputFields()) {
                continue;
            }

            List<OutputField> outputFields = output.getOutputFields();
            outputFields.removeIf(outputField -> names.contains(outputField.requireName()));

            if (outputFields.isEmpty()) {
                finalModel.setOutput(null);
            }
        }
    }

    /**
     * Wraps the last output field of {@code model} as a continuous feature.
     *
     * @throws InvalidElementException if {@code model} declares no output fields
     */
    private static ContinuousFeature getPrediction(Model model, Schema schema) {
        Output output = model.getOutput();
        if (output == null || !output.hasOutputFields()) {
            throw new InvalidElementException((PMMLObject)model);
        }

        List<OutputField> outputFields = output.getOutputFields();
        OutputField outputField = Iterables.getLast(outputFields);

        ModelEncoder encoder = schema.getEncoder();

        return new ContinuousFeature((PMMLEncoder)encoder, (Field<?>)outputField);
    }
}

