57 #include <sphinxbase/err.h>
58 #include <sphinxbase/ckd_alloc.h>
59 #include <sphinxbase/strfuncs.h>
60 #include <sphinxbase/cmd_ln.h>
65 #include "fsg_search_internal.h"
66 #include "fsg_history.h"
67 #include "fsg_lextree.h"
71 #define __FSG_DBG_CHAN__ 0
90 fsg_search_add_silences(
fsg_search_t *fsgs, fsg_model_t *fsg)
96 dict = ps_search_dict(fsgs);
110 fsg_model_add_silence(fsg,
"<sil>", -1,
111 cmd_ln_float32_r(ps_search_config(fsgs),
"-silprob"));
114 for (wid = dict_filler_start(dict); wid < dict_filler_end(dict); ++wid) {
115 char const *word = dict_wordstr(dict, wid);
116 if (wid == dict_startwid(dict) || wid == dict_finishwid(dict))
118 fsg_model_add_silence(fsg, word, -1,
119 cmd_ln_float32_r(ps_search_config(fsgs),
"-fillprob"));
128 fsg_search_check_dict(
fsg_search_t *fsgs, fsg_model_t *fsg)
133 dict = ps_search_dict(fsgs);
134 for (i = 0; i < fsg_model_n_word(fsg); ++i) {
138 word = fsg_model_word_str(fsg, i);
141 E_ERROR(
"The word '%s' is missing in the dictionary\n", word);
150 fsg_search_add_altpron(
fsg_search_t *fsgs, fsg_model_t *fsg)
156 dict = ps_search_dict(fsgs);
159 n_word = fsg_model_n_word(fsg);
160 for (i = 0; i < n_word; ++i) {
164 word = fsg_model_word_str(fsg, i);
167 while ((wid = dict_nextalt(dict, wid)) !=
BAD_S3WID) {
168 n_alt += fsg_model_add_alt(fsg, word, dict_wordstr(dict, wid));
173 E_INFO(
"Added %d alternate word transitions\n", n_alt);
178 fsg_search_init(
const char *name,
186 ps_search_init(ps_search_base(fsgs), &fsg_funcs, PS_SEARCH_TYPE_FSG, name, config, acmod, dict, d2p);
188 fsgs->
fsg = fsg_model_retain(fsg);
192 if (fsgs->
hmmctx == NULL) {
193 ps_search_free(ps_search_base(fsgs));
198 fsgs->
history = fsg_history_init(NULL, dict);
204 = (int32) logmath_log(acmod->
lmath, cmd_ln_float64_r(config,
"-beam"))
207 = (int32) logmath_log(acmod->
lmath, cmd_ln_float64_r(config,
"-pbeam"))
210 = (int32) logmath_log(acmod->
lmath, cmd_ln_float64_r(config,
"-wbeam"))
214 fsgs->lw = cmd_ln_float32_r(config,
"-lw");
215 fsgs->pip = (int32) (logmath_log(acmod->
lmath, cmd_ln_float32_r(config,
"-pip"))
218 fsgs->
wip = (int32) (logmath_log(acmod->
lmath, cmd_ln_float32_r(config,
"-wip"))
223 fsgs->
ascale = 1.0 / cmd_ln_float32_r(config,
"-ascale");
225 E_INFO(
"FSG(beam: %d, pbeam: %d, wbeam: %d; wip: %d, pip: %d)\n",
227 fsgs->
wip, fsgs->pip);
229 if (!fsg_search_check_dict(fsgs, fsg)) {
230 fsg_search_free(ps_search_base(fsgs));
234 if (cmd_ln_boolean_r(config,
"-fsgusefiller") &&
235 !fsg_model_has_sil(fsg))
236 fsg_search_add_silences(fsgs, fsg);
238 if (cmd_ln_boolean_r(config,
"-fsgusealtpron") &&
239 !fsg_model_has_alt(fsg))
240 fsg_search_add_altpron(fsgs, fsg);
242 if (fsg_search_reinit(ps_search_base(fsgs),
243 ps_search_dict(fsgs),
244 ps_search_dict2pid(fsgs)) < 0)
246 ps_search_free(ps_search_base(fsgs));
250 return ps_search_base(fsgs);
261 fsg_history_reset(fsgs->
history);
262 fsg_history_set_fsg(fsgs->
history, NULL, NULL);
263 fsg_history_free(fsgs->
history);
266 fsg_model_free(fsgs->
fsg);
287 ps_search_acmod(fsgs)->mdef,
291 fsg_history_set_fsg(fsgs->
history, fsgs->
fsg, dict);
306 for (gn = fsgs->
pnode_active; gn; gn = gnode_next(gn)) {
308 hmm = fsg_pnode_hmmptr(pnode);
309 assert(hmm_frame(hmm) == fsgs->
frame);
331 E_ERROR(
"Frame %d: No active HMM!!\n", fsgs->
frame);
335 for (n = 0, gn = fsgs->
pnode_active; gn; gn = gnode_next(gn), n++) {
339 hmm = fsg_pnode_hmmptr(pnode);
340 assert(hmm_frame(hmm) == fsgs->
frame);
343 E_INFO(
"pnode(%08x) active @frm %5d\n", (int32) pnode,
349 E_INFO(
"pnode(%08x) after eval @frm %5d\n",
350 (int32) pnode, fsgs->
frame);
359 E_INFO(
"[%5d] %6d HMM; bestscr: %11d\n", fsgs->
frame, n, bestscore);
364 maxhmmpf = cmd_ln_int32_r(ps_search_config(fsgs),
"-maxhmmpf");
365 if (maxhmmpf != -1 && n > maxhmmpf) {
387 if (n > fsg_lextree_n_pnode(fsgs->
lextree))
388 E_FATAL(
"PANIC! Frame %d: #HMM evaluated(%d) > #PNodes(%d)\n",
400 int32 newscore, thresh, nf;
403 assert(!fsg_pnode_leaf(pnode));
405 nf = fsgs->
frame + 1;
408 hmm = fsg_pnode_hmmptr(pnode);
410 for (child = fsg_pnode_succ(pnode);
411 child; child = fsg_pnode_sibling(child)) {
412 newscore = hmm_out_score(hmm) + child->logs2prob;
415 && (newscore
BETTER_THAN hmm_in_score(&child->hmm))) {
417 if (hmm_frame(&child->hmm) < nf) {
424 hmm_enter(&child->hmm, newscore, hmm_out_history(hmm), nf);
439 assert(fsg_pnode_leaf(pnode));
441 hmm = fsg_pnode_hmmptr(pnode);
442 fl = fsg_pnode_fsglink(pnode);
445 wid = fsg_link_wid(fl);
449 E_INFO(
"[%5d] Exit(%08x) %10d(score) %5d(pred)\n",
450 fsgs->
frame, (int32) pnode,
451 hmm_out_score(hmm), hmm_out_history(hmm));
458 if (fsg_model_is_filler(fsgs->
fsg, wid)
460 || (dict_is_single_phone(ps_search_dict(fsgs),
462 fsg_model_word_str(fsgs->
fsg, wid))))) {
467 fsg_history_entry_add(fsgs->
history,
471 hmm_out_history(hmm),
472 pnode->ci_ext, ctxt);
477 fsg_history_entry_add(fsgs->
history,
481 hmm_out_history(hmm),
482 pnode->ci_ext, pnode->ctxt);
499 int32 thresh, word_thresh, phone_thresh;
504 phone_thresh = fsgs->
bestscore + fsgs->pbeam;
507 for (gn = fsgs->
pnode_active; gn; gn = gnode_next(gn)) {
509 hmm = fsg_pnode_hmmptr(pnode);
511 if (hmm_bestscore(hmm) >= thresh) {
513 if (hmm_frame(hmm) == fsgs->
frame) {
514 hmm_frame(hmm) = fsgs->
frame + 1;
520 assert(hmm_frame(hmm) == fsgs->
frame + 1);
523 if (!fsg_pnode_leaf(pnode)) {
524 if (hmm_out_score(hmm) >= phone_thresh) {
526 fsg_search_pnode_trans(fsgs, pnode);
530 if (hmm_out_score(hmm) >= word_thresh) {
532 fsg_search_pnode_exit(fsgs, pnode);
546 int32 bpidx, n_entries, thresh, newscore;
555 n_entries = fsg_history_n_entries(fsgs->
history);
557 for (bpidx = fsgs->
bpidx_start; bpidx < n_entries; bpidx++) {
559 hist_entry = fsg_history_entry_get(fsgs->
history, bpidx);
561 l = fsg_hist_entry_fsglink(hist_entry);
564 s = l ? fsg_link_to_state(l) : fsg_model_start_state(fsg);
572 for (itor = fsg_model_arcs(fsg, s); itor;
573 itor = fsg_arciter_next(itor)) {
574 fsg_link_t *l = fsg_arciter_get(itor);
577 if (fsg_link_wid(l) != -1)
580 fsg_hist_entry_score(hist_entry) +
583 if (newscore >= thresh) {
584 fsg_history_entry_add(fsgs->
history, l,
585 fsg_hist_entry_frame(hist_entry),
588 fsg_hist_entry_lc(hist_entry),
589 fsg_hist_entry_rc(hist_entry));
603 int32 bpidx, n_entries;
606 int32 score, newscore, thresh, nf, d;
610 n_entries = fsg_history_n_entries(fsgs->
history);
613 nf = fsgs->
frame + 1;
615 for (bpidx = fsgs->
bpidx_start; bpidx < n_entries; bpidx++) {
616 hist_entry = fsg_history_entry_get(fsgs->
history, bpidx);
618 score = fsg_hist_entry_score(hist_entry);
619 assert(fsgs->
frame == fsg_hist_entry_frame(hist_entry));
621 l = fsg_hist_entry_fsglink(hist_entry);
624 d = l ? fsg_link_to_state(l) : fsg_model_start_state(fsgs->
627 lc = fsg_hist_entry_lc(hist_entry);
630 for (root = fsg_lextree_root(fsgs->
lextree, d);
631 root; root = root->sibling) {
634 if ((root->ctxt.bv[lc >> 5] & (1 << (lc & 0x001f))) &&
635 (hist_entry->rc.bv[rc >> 5] & (1 << (rc & 0x001f)))) {
643 newscore = score + root->logs2prob;
646 && (newscore
BETTER_THAN hmm_in_score(&root->hmm))) {
647 if (hmm_frame(&root->hmm) < nf) {
654 (
"[%5d] WordTrans bpidx[%d] -> pnode[%08x] (activated)\n",
655 fsgs->
frame, bpidx, (int32) root);
661 (
"[%5d] WordTrans bpidx[%d] -> pnode[%08x]\n",
662 fsgs->
frame, bpidx, (int32) root);
666 hmm_enter(&root->hmm, newscore, bpidx, nf);
675 fsg_search_step(
ps_search_t *search,
int frame_idx)
686 fsg_search_sen_active(fsgs);
696 fsg_search_hmm_eval(fsgs);
703 fsg_search_hmm_prune_prop(fsgs);
704 fsg_history_end_frame(fsgs->
history);
710 fsg_search_null_prop(fsgs);
711 fsg_history_end_frame(fsgs->
history);
717 fsg_search_word_trans(fsgs);
725 for (gn = fsgs->
pnode_active; gn; gn = gnode_next(gn)) {
727 hmm = fsg_pnode_hmmptr(pnode);
729 if (hmm_frame(hmm) == fsgs->
frame) {
734 assert(hmm_frame(hmm) == (fsgs->
frame + 1));
776 fsg_history_reset(fsgs->
history);
777 fsg_history_utt_start(fsgs->
history);
786 fsg_history_entry_add(fsgs->
history,
787 NULL, -1, 0, -1, silcipid, ctxt);
791 fsg_search_null_prop(fsgs);
794 fsg_search_word_trans(fsgs);
820 for (gn = fsgs->
pnode_active; gn; gn = gnode_next(gn)) {
836 n_hist = fsg_history_n_entries(fsgs->
history);
838 (
"%d frames, %d HMMs (%d/fr), %d senones (%d/fr), %d history entries (%d/fr)\n\n",
843 n_hist, (fsgs->
frame > 0) ? n_hist / fsgs->
frame : 0);
849 fsg_search_find_exit(
fsg_search_t *fsgs,
int frame_idx,
int final, int32 *out_score, int32* out_is_final)
853 int bpidx, frm, last_frm, besthist;
857 *out_is_final = FALSE;
860 frame_idx = fsgs->
frame - 1;
861 last_frm = frm = frame_idx;
864 bpidx = fsg_history_n_entries(fsgs->
history) - 1;
866 hist_entry = fsg_history_entry_get(fsgs->
history, bpidx);
867 if (fsg_hist_entry_frame(hist_entry) <= frame_idx) {
868 frm = last_frm = fsg_hist_entry_frame(hist_entry);
882 while (frm == last_frm) {
886 fl = fsg_hist_entry_fsglink(hist_entry);
887 score = fsg_hist_entry_score(hist_entry);
893 if (score == bestscore && fsg_link_to_state(fl) == fsg_model_final_state(fsg)) {
898 || fsg_link_to_state(fl) == fsg_model_final_state(fsg)) {
907 hist_entry = fsg_history_entry_get(fsgs->
history, bpidx);
908 frm = fsg_hist_entry_frame(hist_entry);
912 if (besthist == -1) {
913 E_ERROR(
"Final result does not match the grammar in frame %d\n", frame_idx);
919 *out_score = bestscore;
922 hist_entry = fsg_history_entry_get(fsgs->
history, besthist);
923 fl = fsg_hist_entry_fsglink(hist_entry);
924 *out_is_final = (fsg_link_to_state(fl) == fsg_model_final_state(fsg));
931 fsg_search_bestpath(
ps_search_t *search, int32 *out_score,
int backward)
942 if (search->
post == 0)
951 fsg_search_hyp(
ps_search_t *search, int32 *out_score, int32 *out_is_final)
954 dict_t *dict = ps_search_dict(search);
960 bpidx = fsg_search_find_exit(fsgs, fsgs->
frame, fsgs->
final, out_score, out_is_final);
971 if ((dag = fsg_search_lattice(search)) == NULL) {
972 E_WARN(
"Failed to obtain the lattice while bestpath enabled\n");
975 if ((link = fsg_search_bestpath(search, out_score, FALSE)) == NULL) {
976 E_WARN(
"Failed to find the bestpath in a lattice\n");
986 fsg_link_t *fl = fsg_hist_entry_fsglink(hist_entry);
987 char const *baseword;
990 bp = fsg_hist_entry_pred(hist_entry);
991 wid = fsg_link_wid(fl);
992 if (wid < 0 || fsg_model_is_filler(fsgs->
fsg, wid))
994 baseword = dict_basestr(dict,
996 fsg_model_word_str(fsgs->
fsg, wid)));
997 len += strlen(baseword) + 1;
1005 search->
hyp_str = ckd_calloc(1, len);
1008 c = search->
hyp_str + len - 1;
1011 fsg_link_t *fl = fsg_hist_entry_fsglink(hist_entry);
1012 char const *baseword;
1015 bp = fsg_hist_entry_pred(hist_entry);
1016 wid = fsg_link_wid(fl);
1017 if (wid < 0 || fsg_model_is_filler(fsgs->
fsg, wid))
1019 baseword = dict_basestr(dict,
1021 fsg_model_word_str(fsgs->
fsg, wid)));
1022 len = strlen(baseword);
1024 memcpy(c, baseword, len);
1041 if ((bp = fsg_hist_entry_pred(hist_entry)) >= 0)
1042 ph = fsg_history_entry_get(fsgs->
history, bp);
1043 seg->
word = fsg_model_word_str(fsgs->
fsg, hist_entry->fsglink->wid);
1044 seg->
ef = fsg_hist_entry_frame(hist_entry);
1045 seg->
sf = ph ? fsg_hist_entry_frame(ph) + 1 : 0;
1047 if (seg->
sf > seg->
ef) seg->
sf = seg->
ef;
1054 seg->
ascr = hist_entry->score - ph->score - seg->
lscr;
1057 seg->
ascr = hist_entry->score - seg->
lscr;
1065 ckd_free(itor->
hist);
1079 fsg_seg_bp2itor(seg, itor->
hist[itor->
cur]);
1089 fsg_search_seg_iter(
ps_search_t *search, int32 *out_score)
1095 bpidx = fsg_search_find_exit(fsgs, fsgs->
frame, fsgs->
final, out_score, NULL);
1105 if ((dag = fsg_search_lattice(search)) == NULL)
1107 if ((link = fsg_search_bestpath(search, out_score, TRUE)) == NULL)
1116 itor = ckd_calloc(1,
sizeof(*itor));
1117 itor->
base.
vt = &fsg_segfuncs;
1124 bp = fsg_hist_entry_pred(hist_entry);
1136 itor->
hist[cur] = hist_entry;
1137 bp = fsg_hist_entry_pred(hist_entry);
1142 fsg_seg_bp2itor((
ps_seg_t *)itor, itor->hist[0]);
1157 if ((dag = fsg_search_lattice(search)) == NULL)
1159 if ((link = fsg_search_bestpath(search, NULL, TRUE)) == NULL)
1161 return search->
post;
1170 find_node(
ps_lattice_t *dag, fsg_model_t *fsg,
int sf, int32 wid, int32 node_id)
1174 for (node = dag->
nodes; node; node = node->
next)
1175 if ((node->
sf == sf) && (node->
wid == wid) && (node->
node_id == node_id))
1181 new_node(
ps_lattice_t *dag, fsg_model_t *fsg,
int sf,
int ef, int32 wid, int32 node_id, int32 ascr)
1185 node = find_node(dag, fsg, sf, wid, node_id);
1189 if (node->
lef == -1 || node->
lef < ef)
1191 if (node->
fef == -1 || node->
fef > ef)
1202 node->
fef = node->
lef = ef;
1221 glist_t start = NULL;
1225 for (node = dag->
nodes; node; node = node->
next) {
1226 if (node->
sf == 0 && node->
exits) {
1227 E_INFO(
"Start node %s.%d:%d:%d\n",
1228 fsg_model_word_str(fsgs->
fsg, node->
wid),
1230 start = glist_add_ptr(start, node);
1239 node = gnode_ptr(start);
1245 wid = fsg_model_word_add(fsgs->
fsg,
"<s>");
1246 if (fsgs->
fsg->silwords)
1247 bitvec_set(fsgs->
fsg->silwords, wid);
1248 node = new_node(dag, fsgs->
fsg, 0, 0, wid, -1, 0);
1249 for (st = start; st; st = gnode_next(st))
1264 for (node = dag->
nodes; node; node = node->
next) {
1266 E_INFO(
"End node %s.%d:%d:%d (%d)\n",
1267 fsg_model_word_str(fsgs->
fsg, node->
wid),
1269 end = glist_add_ptr(end, node);
1275 node = gnode_ptr(end);
1277 else if (nend == 0) {
1283 for (node = dag->
nodes; node; node = node->
next) {
1291 E_INFO(
"End node %s.%d:%d:%d (%d)\n",
1292 fsg_model_word_str(fsgs->
fsg, node->
wid),
1301 wid = fsg_model_word_add(fsgs->
fsg,
"</s>");
1302 if (fsgs->
fsg->silwords)
1303 bitvec_set(fsgs->
fsg->silwords, wid);
1304 node = new_node(dag, fsgs->
fsg, fsgs->
frame, fsgs->
frame, wid, -1, 0);
1307 for (st = end; st; st = gnode_next(st)) {
1323 q = glist_add_ptr(NULL, end);
1329 q = gnode_free(q, NULL);
1331 for (x = node->
entries; x; x = x->next) {
1335 q = glist_add_ptr(q, next);
1377 n = fsg_history_n_entries(fsgs->
history);
1378 for (i = 0; i < n; ++i) {
1384 if (fh->fsglink == NULL || fh->fsglink->wid == -1)
1397 ascr = fh->score - pfh->score;
1398 sf = pfh->frame + 1;
1411 new_node(dag, fsg, sf, fh->frame, fh->fsglink->wid, fsg_link_to_state(fh->fsglink), ascr);
1417 n = fsg_history_n_entries(fsgs->
history);
1418 for (i = 0; i < n; ++i) {
1420 fsg_arciter_t *itor;
1426 if (fh->fsglink == NULL || fh->fsglink->wid == -1)
1432 sf = pfh->frame + 1;
1433 ascr = fh->score - pfh->score;
1439 src = find_node(dag, fsg, sf, fh->fsglink->wid, fsg_link_to_state(fh->fsglink));
1442 for (itor = fsg_model_arcs(fsg, fsg_link_to_state(fh->fsglink));
1443 itor; itor = fsg_arciter_next(itor)) {
1444 fsg_link_t *link = fsg_arciter_get(itor);
1447 if (link->wid >= 0) {
1452 if ((dest = find_node(dag, fsg, sf, link->wid, fsg_link_to_state(link))) != NULL)
1460 fsg_arciter_t *itor2;
1463 for (itor2 = fsg_model_arcs(fsg, fsg_link_to_state(link));
1464 itor2; itor2 = fsg_arciter_next(itor2)) {
1465 fsg_link_t *link = fsg_arciter_get(itor2);
1467 if (link->wid == -1)
1470 if ((dest = find_node(dag, fsg, sf, link->wid, fsg_link_to_state(link))) != NULL) {
1480 if ((dag->
start = find_start_node(fsgs, dag)) == NULL) {
1481 E_WARN(
"Failed to find the start node\n");
1484 if ((dag->
end = find_end_node(fsgs, dag)) == NULL) {
1485 E_WARN(
"Failed to find the end node\n");
1490 E_INFO(
"lattice start node %s.%d end node %s.%d\n",
1492 fsg_model_word_str(fsg, dag->
end->
wid), dag->
end->
sf);
1498 for (node = dag->
nodes; node; node = node->
next) {
1500 fsg_model_word_str(fsg, node->
wid));
1510 mark_reachable(dag, dag->
end);
1514 int32 silpen, fillpen;
1516 silpen = (int32)(logmath_log(fsg->lmath,
1517 cmd_ln_float32_r(ps_search_config(fsgs),
"-silprob"))
1520 fillpen = (int32)(logmath_log(fsg->lmath,
1521 cmd_ln_float32_r(ps_search_config(fsgs),
"-fillprob"))