/*
 * Decompiled with CFR 0.152.
 */
package marytts.unitselection.select.viterbi;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.text.DecimalFormat;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import marytts.exceptions.SynthesisException;
import marytts.unitselection.data.DiphoneUnit;
import marytts.unitselection.data.Unit;
import marytts.unitselection.data.UnitDatabase;
import marytts.unitselection.select.DiphoneTarget;
import marytts.unitselection.select.HalfPhoneTarget;
import marytts.unitselection.select.JoinCostFunction;
import marytts.unitselection.select.SelectedUnit;
import marytts.unitselection.select.StatisticalCostFunction;
import marytts.unitselection.select.Target;
import marytts.unitselection.select.TargetCostFunction;
import marytts.unitselection.select.viterbi.ViterbiCandidate;
import marytts.unitselection.select.viterbi.ViterbiPath;
import marytts.unitselection.select.viterbi.ViterbiPoint;
import marytts.util.MaryUtils;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/**
 * Viterbi (beam) search over a chain of {@link ViterbiPoint}s, one per target,
 * that picks for every target the candidate unit minimising the weighted sum of
 * target, join and statistical costs.
 *
 * <p>Typical use: construct, call {@link #apply()}, then {@link #getSelectedUnits()}.
 *
 * <p>NOTE(review): the per-database debug statistics live in a static, unsynchronised
 * map; if several searches run concurrently with debug logging enabled this should be
 * confirmed to be acceptable.
 */
public class Viterbi {
    /** Beam width; {@code -1} disables pruning, {@code 0} is rejected by the constructors. */
    protected int beamSize;
    /** Weight applied to the target cost of a candidate. */
    protected final float wTargetCosts;
    /** Weight applied to the join cost between consecutive units. */
    protected final float wJoinCosts;
    /** Weight applied to the statistical cost between consecutive units. */
    protected final float wSCosts;
    /** First point of the chain, or {@code null} when no targets were given. */
    protected ViterbiPoint firstPoint = null;
    /** Terminal point (without target) that collects the complete paths. */
    protected ViterbiPoint lastPoint = null;
    private final UnitDatabase database;
    protected TargetCostFunction targetCostFunction;
    protected JoinCostFunction joinCostFunction;
    protected StatisticalCostFunction sCostFunction;
    protected Logger logger;
    /** Sum of the weighted, finite join costs computed so far (debug statistics). */
    protected double cumulJoinCosts;
    /** Number of path extensions evaluated (debug statistics). */
    protected int nJoinCosts;
    /** Sum of the weighted target costs computed so far (debug statistics). */
    protected double cumulTargetCosts;
    /** Number of target costs accumulated (debug statistics). */
    protected int nTargetCosts;
    /** Running debug averages, kept per unit database across searches. */
    private static final Map<UnitDatabase, DebugStats> debugStats = new HashMap<UnitDatabase, DebugStats>();

    /**
     * Creates a search without statistical costs; the join weight is
     * {@code 1 - wTargetCosts}.
     *
     * @param targets      targets to find units for, in order
     * @param database     source of candidates and cost functions
     * @param wTargetCosts weight of the target costs
     * @param beamSize     beam width, {@code -1} for no pruning; must not be {@code 0}
     * @throws IllegalStateException if {@code beamSize} is {@code 0}
     */
    public Viterbi(List<Target> targets, UnitDatabase database, float wTargetCosts, int beamSize) {
        this(targets, database, wTargetCosts, 0.0f, beamSize);
    }

    /**
     * Creates a search; the join weight is {@code 1 - (wTargetCosts + wSCosts)}.
     *
     * <p>An empty {@code targets} list is accepted: {@link #apply()} is then a no-op
     * and {@link #getSelectedUnits()} returns an empty list.
     *
     * @param targets      targets to find units for, in order
     * @param database     source of candidates and cost functions
     * @param wTargetCosts weight of the target costs
     * @param wSCosts      weight of the statistical costs
     * @param beamSize     beam width, {@code -1} for no pruning; must not be {@code 0}
     * @throws IllegalStateException if {@code beamSize} is {@code 0}
     */
    public Viterbi(List<Target> targets, UnitDatabase database, float wTargetCosts, float wSCosts, int beamSize) {
        // Reject the unsupported configuration before doing any work.
        if (beamSize == 0) {
            throw new IllegalStateException("General beam search not implemented");
        }
        this.database = database;
        this.targetCostFunction = database.getTargetCostFunction();
        this.joinCostFunction = database.getJoinCostFunction();
        this.sCostFunction = database.getSCostFunction();
        this.logger = MaryUtils.getLogger("Viterbi");
        this.wTargetCosts = wTargetCosts;
        this.wSCosts = wSCosts;
        this.wJoinCosts = 1.0f - (wTargetCosts + wSCosts);
        this.beamSize = beamSize;
        this.cumulJoinCosts = 0.0;
        this.nJoinCosts = 0;
        this.cumulTargetCosts = 0.0;
        this.nTargetCosts = 0;

        ViterbiPoint last = null;
        for (Target target : targets) {
            ViterbiPoint nextPoint = new ViterbiPoint(target);
            if (last == null) {
                // The first point starts with a single empty path of score 0.
                this.firstPoint = nextPoint;
                this.firstPoint.getPaths().add(new ViterbiPath(null, null, 0.0));
            } else {
                last.setNext(nextPoint);
            }
            last = nextPoint;
        }
        this.lastPoint = new ViterbiPoint(null);
        if (last != null) {
            last.setNext(this.lastPoint);
        }
    }

    /**
     * Runs the search: for every point except the terminal one, fetches and sorts the
     * candidates and extends the retained paths to the next point.
     *
     * @throws SynthesisException if no candidate unit can be found for a target, not
     *                            even after splitting a diphone target into halfphones
     */
    public void apply() throws SynthesisException {
        this.logger.debug("Viterbi running with beam size " + this.beamSize);
        if (this.firstPoint == null) {
            // No targets were supplied: nothing to search.
            return;
        }
        for (ViterbiPoint point = this.firstPoint; point.next != null; point = point.next) {
            // May split a diphone point in two, which changes point.next.
            List<ViterbiCandidate> candidates = this.findCandidates(point);
            assert (candidates.size() > 0);
            Collections.sort(candidates);
            point.candidates = candidates;
            assert (this.beamSize != 0);

            List<ViterbiPath> paths = point.paths;
            int pathLimit = paths.size();
            if (this.beamSize != -1 && this.beamSize < pathLimit) {
                // Paths are appended in arrival order, so they must be sorted by
                // score before the beam keeps only the leading entries.
                Collections.sort(paths);
                pathLimit = this.beamSize;
            }

            int pathCount = 0;
            for (ViterbiPath previousPath : paths) {
                assert (previousPath != null);
                int candidateCount = 0;
                for (ViterbiCandidate candidate : candidates) {
                    this.addPath(point.next, this.getPath(previousPath, candidate));
                    // With beamSize == -1 this never matches, i.e. no pruning.
                    if (++candidateCount == this.beamSize) {
                        break;
                    }
                }
                if (++pathCount == pathLimit) {
                    break;
                }
            }
        }
    }

    /**
     * Looks up the candidates for a point. If a diphone target has none, the point is
     * rewritten to the left halfphone and a new point for the right halfphone is
     * inserted after it; the left halfphone's candidates are returned.
     */
    private List<ViterbiCandidate> findCandidates(ViterbiPoint point) throws SynthesisException {
        Target target = point.target;
        List<ViterbiCandidate> candidates = this.database.getCandidates(target);
        if (!candidates.isEmpty()) {
            return candidates;
        }
        if (!(target instanceof DiphoneTarget)) {
            throw new SynthesisException("Cannot find any units for target " + target);
        }
        this.logger.debug("No diphone '" + target.getName() + "' -- will build from halfphones");
        DiphoneTarget diphone = (DiphoneTarget)target;
        HalfPhoneTarget left = diphone.left;
        HalfPhoneTarget right = diphone.right;
        point.setTarget(left);
        ViterbiPoint rightPoint = new ViterbiPoint(right);
        rightPoint.next = point.next;
        point.next = rightPoint;
        candidates = this.database.getCandidates(left);
        if (candidates.isEmpty()) {
            throw new SynthesisException("Cannot even find any halfphone unit for target " + left);
        }
        return candidates;
    }

    /**
     * Registers {@code newPath} at {@code point} if it is the first path ending in its
     * candidate, or strictly cheaper than the best one recorded for that candidate
     * (which it then replaces).
     */
    void addPath(ViterbiPoint point, ViterbiPath newPath) {
        ViterbiCandidate candidate = newPath.candidate;
        assert (candidate != null);
        ViterbiPath bestPathSoFar = candidate.bestPath;
        if (bestPathSoFar != null && !(newPath.score < bestPathSoFar.score)) {
            // Existing path is at least as good: keep it.
            return;
        }
        List<ViterbiPath> paths = point.getPaths();
        if (bestPathSoFar != null) {
            paths.remove(bestPathSoFar);
        }
        paths.add(newPath);
        candidate.setBestPath(newPath);
    }

    /**
     * Returns the units along the best path, in target order. A diphone unit is
     * expanded into its left and right halves.
     *
     * @return the selected units; an empty list if there are no points to search;
     *         {@code null} if the terminal point holds no path (for example because
     *         {@link #apply()} has not been run)
     */
    public List<SelectedUnit> getSelectedUnits() {
        LinkedList<SelectedUnit> selectedUnits = new LinkedList<SelectedUnit>();
        if (this.firstPoint == null || this.firstPoint.getNext() == null) {
            return selectedUnits;
        }
        ViterbiPath best = this.findBestPath();
        if (best == null) {
            return null;
        }
        // Walk backwards from the end, prepending so the result is in forward order.
        for (ViterbiPath path = best; path != null; path = path.getPrevious()) {
            ViterbiCandidate candidate = path.candidate;
            if (candidate == null) {
                continue;
            }
            Unit unit = candidate.unit;
            Target target = candidate.target;
            if (unit instanceof DiphoneUnit) {
                assert (target instanceof DiphoneTarget);
                DiphoneUnit diphoneUnit = (DiphoneUnit)unit;
                DiphoneTarget diphoneTarget = (DiphoneTarget)target;
                selectedUnits.addFirst(new SelectedUnit(diphoneUnit.right, diphoneTarget.right));
                selectedUnits.addFirst(new SelectedUnit(diphoneUnit.left, diphoneTarget.left));
            } else {
                selectedUnits.addFirst(new SelectedUnit(unit, target));
            }
        }
        if (this.logger.isDebugEnabled() && !selectedUnits.isEmpty()) {
            this.logSelection(selectedUnits, best);
        }
        return selectedUnits;
    }

    /**
     * Writes the debug report for a non-empty selection: the units grouped into
     * stretches of consecutive unit indices with the origin of each stretch, the
     * average stretch length, average costs, and running averages per database.
     */
    private void logSelection(List<SelectedUnit> selectedUnits, ViterbiPath best) {
        StringWriter sw = new StringWriter();
        PrintWriter pw = new PrintWriter(sw);
        int[] lengthHistogram = new int[10];
        StringBuilder line = new StringBuilder();
        int prevIndex = -1;
        int length = 0;
        // First unit of the stretch currently being collected; null before the first unit.
        Unit firstUnitInStretch = null;
        for (SelectedUnit selected : selectedUnits) {
            Unit unit = selected.getUnit();
            int index = unit.index;
            if (prevIndex + 1 == index) {
                ++length;
                if (firstUnitInStretch == null) {
                    firstUnitInStretch = unit;
                }
            } else {
                // The stretch ended: count it, print it, and start a new one.
                lengthHistogram = this.countStretch(lengthHistogram, length);
                pw.print(line);
                if (firstUnitInStretch != null) {
                    this.printOrigin(pw, line.length(), firstUnitInStretch);
                }
                pw.println();
                length = 1;
                line.setLength(0);
                firstUnitInStretch = unit;
            }
            line.append(String.valueOf(this.database.getTargetCostFunction().getFeature(unit, "phone")))
                .append("(").append(unit.index).append(")");
            prevIndex = index;
        }
        // Flush the last stretch.
        lengthHistogram = this.countStretch(lengthHistogram, length);
        pw.print(line);
        this.printOrigin(pw, line.length(), firstUnitInStretch);
        pw.println();
        pw.flush();
        this.logger.debug("Selected units:\n" + sw.toString());

        // Slot 0 only counts the empty pseudo-stretch before the first unit; skip it.
        int total = 0;
        int nStretches = 0;
        for (int l = 1; l < lengthHistogram.length; ++l) {
            total += lengthHistogram[l] * l;
            nStretches += lengthHistogram[l];
        }
        float avgLength = (float)total / (float)nStretches;
        DecimalFormat df = new DecimalFormat("0.000");
        this.logger.debug("Avg. consecutive length: " + df.format(avgLength) + " units");

        double totalCost = best.score;
        // Divide by at least 1 so a single-unit selection cannot put Infinity/NaN
        // into the running averages.
        int divisor = Math.max(1, selectedUnits.size() - 1);
        double avgCostBestPath = totalCost / (double)divisor;
        double avgTargetCost = this.cumulTargetCosts / (double)this.nTargetCosts;
        double avgJoinCost = this.cumulJoinCosts / (double)this.nJoinCosts;
        this.logger.debug("Avg. cost: best path " + df.format(avgCostBestPath) + ", avg. target " + df.format(avgTargetCost) + ", join " + df.format(avgJoinCost) + " (n=" + this.nTargetCosts + ")");

        DebugStats stats = debugStats.get(this.database);
        if (stats == null) {
            stats = new DebugStats();
            debugStats.put(this.database, stats);
        }
        // Incremental mean: avg += (x - avg) / n.
        ++stats.n;
        stats.avgLength += ((double)avgLength - stats.avgLength) / (double)stats.n;
        stats.avgCostBestPath += (avgCostBestPath - stats.avgCostBestPath) / (double)stats.n;
        stats.avgTargetCost += (avgTargetCost - stats.avgTargetCost) / (double)stats.n;
        stats.avgJoinCost += (avgJoinCost - stats.avgJoinCost) / (double)stats.n;
        this.logger.debug("Total average of " + stats.n + " utterances for this voice:");
        this.logger.debug("Avg. length: " + df.format(stats.avgLength) + ", avg. cost best path: " + df.format(stats.avgCostBestPath) + ", avg. target cost: " + df.format(stats.avgTargetCost) + ", avg. join cost: " + df.format(stats.avgJoinCost));
    }

    /** Increments the histogram slot for {@code length}, growing the array if needed. */
    private int[] countStretch(int[] histogram, int length) {
        int[] result = histogram;
        if (result.length <= length) {
            result = new int[length + 1];
            System.arraycopy(histogram, 0, result, 0, histogram.length);
        }
        result[length] = result[length] + 1;
        return result;
    }

    /** Pads the current output line to column 80 and appends the origin of {@code unit}. */
    private void printOrigin(PrintWriter pw, int lineLength, Unit unit) {
        String origin = this.database.getFilenameAndTime(unit);
        for (int col = lineLength; col < 80; ++col) {
            pw.print(" ");
        }
        pw.print(origin);
    }

    /**
     * Builds the path that extends {@code path} with {@code candidate}. Its score is
     * the score of {@code path} plus the weighted join, target and statistical costs;
     * join and statistical costs are 0 when there is no previous candidate.
     */
    private ViterbiPath getPath(ViterbiPath path, ViterbiCandidate candidate) {
        Target candidateTarget = candidate.target;
        Unit candidateUnit = candidate.unit;
        double joinCost = 0.0;
        double sCost = 0.0;
        if (path != null && path.candidate != null) {
            ViterbiCandidate prevCandidate = path.candidate;
            Target prevTarget = prevCandidate.target;
            Unit prevUnit = prevCandidate.unit;
            joinCost = this.joinCostFunction.cost(prevTarget, prevUnit, candidateTarget, candidateUnit);
            if (this.sCostFunction != null) {
                sCost = this.sCostFunction.cost(prevUnit, candidateUnit);
            }
        }
        double weightedJoin = joinCost * (double)this.wJoinCosts;
        double weightedTarget = candidate.targetCost * (double)this.wTargetCosts;
        double weightedS = sCost * (double)this.wSCosts;
        double cost = weightedJoin + weightedTarget + weightedS;

        // Statistics use the weighted values; infinite join costs are not summed
        // but still counted.
        if (weightedJoin < Double.POSITIVE_INFINITY) {
            this.cumulJoinCosts += weightedJoin;
        }
        ++this.nJoinCosts;
        this.cumulTargetCosts += weightedTarget;
        ++this.nTargetCosts;

        if (path != null) {
            cost += path.score;
        }
        return new ViterbiPath(candidate, path, cost);
    }

    /**
     * Sorts the paths at the terminal point, takes the first one as the best, and
     * sets the forward ({@code next}) links along it.
     *
     * @return the best complete path, or {@code null} if the terminal point has none
     */
    private ViterbiPath findBestPath() {
        assert (this.beamSize != 0);
        List<ViterbiPath> paths = this.lastPoint.getPaths();
        if (paths.isEmpty()) {
            return null;
        }
        Collections.sort(paths);
        ViterbiPath best = paths.get(0);
        for (ViterbiPath path = best; path != null; path = path.previous) {
            ViterbiPath prev = path.previous;
            if (prev != null) {
                prev.setNext(path);
            }
        }
        return best;
    }

    /**
     * Running averages over all searches logged for one database. Static so that the
     * static statistics map does not keep the creating {@code Viterbi} instance alive.
     */
    private static final class DebugStats {
        int n;
        double avgLength;
        double avgCostBestPath;
        double avgTargetCost;
        double avgJoinCost;

        private DebugStats() {
        }
    }
}

