/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.bayesian.training;

import java.util.Objects;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.training.BayesianInit;
import org.encog.ml.bayesian.training.estimator.BayesEstimator;
import org.encog.ml.bayesian.training.estimator.SimpleEstimator;
import org.encog.ml.bayesian.training.search.k2.BayesSearch;
import org.encog.ml.bayesian.training.search.k2.SearchK2;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * Trains a {@link BayesianNetwork} by stepping through a fixed sequence of phases,
 * one phase step per call to {@link #iteration()}:
 * <ol>
 * <li>{@code Init} - capture the current classification structure and prepare the
 * network's initial structure according to {@link BayesianInit};</li>
 * <li>{@code Search} - repeatedly run the structure search until it reports no more work;</li>
 * <li>{@code SearchDone} - finalize the structure and reset the network;</li>
 * <li>{@code Probability} - repeatedly run the probability estimator until it reports no
 * more work;</li>
 * <li>{@code Finish} - restore the captured classification structure and record the
 * network's error on the training data;</li>
 * <li>{@code Terminated} - training is complete.</li>
 * </ol>
 * Pausing and resuming are not supported.
 */
public class TrainBayesian
extends BasicTraining {
    /** The phase that the next call to {@link #iteration()} will execute. */
    private Phase p = Phase.Init;
    /** Training data handed to the search, the estimator and the final error calculation. */
    private final MLDataSet data;
    /** The network being trained. */
    private final BayesianNetwork network;
    /** Maximum-parents setting exposed to collaborators via {@link #getMaximumParents()}. */
    private final int maximumParents;
    /** Structure-search algorithm driven during the {@code Search} phase. */
    private final BayesSearch search;
    /** Probability estimator driven during the {@code Probability} phase. */
    private final BayesEstimator estimator;
    /** How the network structure is prepared during the {@code Init} phase. */
    private BayesianInit initNetwork;
    /**
     * Classification structure captured in the {@code Init} phase and re-applied in the
     * {@code Finish} phase.
     */
    private String holdQuery;

    /**
     * Creates a trainer with naive-Bayes initialization, {@link SearchK2} structure search
     * and {@link SimpleEstimator} probability estimation.
     *
     * @param theNetwork        the network to train
     * @param theData           the training data
     * @param theMaximumParents the maximum-parents setting exposed to the search
     */
    public TrainBayesian(BayesianNetwork theNetwork, MLDataSet theData, int theMaximumParents) {
        this(theNetwork, theData, theMaximumParents, BayesianInit.InitNaiveBayes, new SearchK2(), new SimpleEstimator());
    }

    /**
     * Creates a fully configured trainer.
     *
     * @param theNetwork        the network to train
     * @param theData           the training data
     * @param theMaximumParents the maximum-parents setting exposed to the search
     * @param theInit           how to prepare the network structure in the {@code Init} phase
     * @param theSearch         the structure-search algorithm
     * @param theEstimator      the probability estimator
     * @throws NullPointerException if {@code theNetwork}, {@code theData}, {@code theSearch}
     *                              or {@code theEstimator} is {@code null}
     */
    public TrainBayesian(BayesianNetwork theNetwork, MLDataSet theData, int theMaximumParents, BayesianInit theInit, BayesSearch theSearch, BayesEstimator theEstimator) {
        super(TrainingImplementationType.Iterative);
        this.network = Objects.requireNonNull(theNetwork, "theNetwork");
        this.data = Objects.requireNonNull(theData, "theData");
        this.maximumParents = theMaximumParents;
        this.initNetwork = theInit;
        this.search = Objects.requireNonNull(theSearch, "theSearch");
        this.estimator = Objects.requireNonNull(theEstimator, "theEstimator");
        // Every field is assigned before `this` is handed to the collaborators, so a search or
        // estimator that queries the trainer from inside init() sees the configured values.
        this.search.init(this, theNetwork, theData);
        this.estimator.init(this, theNetwork, theData);
        this.setError(1.0);
    }

    /**
     * Rebuilds the structure as naive Bayes. All relations are removed, and the
     * classification target event becomes the parent of every other event.
     */
    private void initNaiveBayes() {
        this.network.removeAllRelations();
        BayesianEvent classificationTarget = this.network.getClassificationTargetEvent();
        for (BayesianEvent event : this.network.getEvents()) {
            if (event == classificationTarget) {
                continue;
            }
            this.network.createDependency(classificationTarget, event);
        }
        this.network.finalizeStructure();
    }

    /**
     * {@code Init} phase. Saves the current classification structure, prepares the network
     * structure per {@link #initNetwork}, then moves to {@code Search}.
     */
    private void iterationInit() {
        // The structure is restored in iterationFinish(). Presumably the structure edits
        // below (removeAllRelations etc.) can disturb it - NOTE(review): confirm.
        this.holdQuery = this.network.getClassificationStructure();
        switch (this.initNetwork) {
            case InitEmpty:
                this.network.removeAllRelations();
                this.network.finalizeStructure();
                break;
            case InitNoChange:
                break;
            case InitNaiveBayes:
                this.initNaiveBayes();
                break;
        }
        this.p = Phase.Search;
    }

    /** {@code Search} phase: one search step; advances once the search returns {@code false}. */
    private void iterationSearch() {
        if (!this.search.iteration()) {
            this.p = Phase.SearchDone;
        }
    }

    /** {@code SearchDone} phase: finalizes the searched structure and resets the network. */
    private void iterationSearchDone() {
        this.network.finalizeStructure();
        this.network.reset();
        this.p = Phase.Probability;
    }

    /**
     * {@code Probability} phase: one estimator step; advances once the estimator returns
     * {@code false}.
     */
    private void iterationProbability() {
        if (!this.estimator.iteration()) {
            this.p = Phase.Finish;
        }
    }

    /**
     * {@code Finish} phase: restores the saved classification structure and records the
     * network's error on the training data.
     */
    private void iterationFinish() {
        this.network.defineClassificationStructure(this.holdQuery);
        this.setError(this.network.calculateError(this.data));
        this.p = Phase.Terminated;
    }

    /**
     * @return {@code true} if the base class reports completion or all phases have run
     */
    @Override
    public boolean isTrainingDone() {
        return super.isTrainingDone() || this.p == Phase.Terminated;
    }

    /**
     * Executes exactly one phase step. The base-class pre-iteration hook runs before the
     * step and the post-iteration hook after it.
     */
    @Override
    public void iteration() {
        this.preIteration();
        switch (this.p) {
            case Init:
                this.iterationInit();
                break;
            case Search:
                this.iterationSearch();
                break;
            case SearchDone:
                this.iterationSearchDone();
                break;
            case Probability:
                this.iterationProbability();
                break;
            case Finish:
                this.iterationFinish();
                break;
            case Terminated:
                // All phases complete; nothing left to do.
                break;
        }
        this.postIteration();
    }

    /** @return always {@code false}; this trainer cannot be paused and continued */
    @Override
    public boolean canContinue() {
        return false;
    }

    /** @return always {@code null}; pausing is not supported */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /** Does nothing; resuming is not supported. */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /** @return the network being trained */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /** @return the network being trained */
    public BayesianNetwork getNetwork() {
        return this.network;
    }

    /** @return the maximum-parents value supplied at construction */
    public int getMaximumParents() {
        return this.maximumParents;
    }

    /** @return the structure-search algorithm */
    public BayesSearch getSearch() {
        return this.search;
    }

    /** @return how the network structure is prepared in the {@code Init} phase */
    public BayesianInit getInitNetwork() {
        return this.initNetwork;
    }

    /**
     * Sets how the network structure is prepared. This only has an effect if it is called
     * before the {@code Init} phase runs.
     *
     * @param initNetwork the initialization mode
     */
    public void setInitNetwork(BayesianInit initNetwork) {
        this.initNetwork = initNetwork;
    }

    /** Training phases, executed in declaration order. */
    private enum Phase {
        Init,
        Search,
        SearchDone,
        Probability,
        Finish,
        Terminated
    }
}

