// file: $isip/class/search/StackSearch/StackSearch.h // version: $Id: StackSearch.h 9742 2004-08-20 20:43:15Z may $ // // make sure definitions are only made once // #ifndef ISIP_STACK_SEARCH #define ISIP_STACK_SEARCH // isip include files: // #ifndef ISIP_PRIORITY_QUEUE #include #endif #ifndef ISIP_HYPOTHESIS #include #endif #ifndef ISIP_CONTEXT_POOL #include #endif #ifndef ISIP_HISTORY_POOL #include #endif #ifndef ISIP_HASH_KEY #include #endif #ifndef ISIP_MATRIX_FLOAT #include #endif #ifndef ISIP_FRONTEND #include #endif #ifndef ISIP_HIERARCHICAL_DIGRAPH #include #endif // StackSearch: A hierarchical, asynchronous, stack search engine. // // StackSearch can decode same hierarchical structure of graphs. // Decoding algorithm follows the multistack approach described in: // // Renals S., Hochberg. M.: Decoder Technology for Connectionist // Large Vocabulary Speech Recognition // // class StackSearch { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String PARAM_DECODING_MODE; static const String PARAM_SCORE_NORMALIZATION; static const String PARAM_STACK_LEVEL; static const String PARAM_MAX_MIN_FILENAME; static const String PARAM_SIL_DUR_PENALTY; //---------------------------------------- // // other important constants // //---------------------------------------- enum EVALUATION_MODE { FRAME_EVAL = 0, SEGMENT_EVAL, DEF_EVALUATION_MODE = FRAME_EVAL }; enum DECODING_MODE { GLOBAL_STACK = 0, MULTI_STACK, VITERBI, DEF_DECODING_MODE = GLOBAL_STACK }; enum SCORE_NORMALIZATION { NONE = 0, DURATION, NUM_OF_PHONES, DEF_SCORE_NORMALIZATION = NONE }; // define static NameMap objects for the enumerated values // static const NameMap EVALUATION_MODE_MAP; static const NameMap DECODING_MODE_MAP; static const NameMap SCORE_NORMALIZATION_MAP; // define the start frame index // static const int32 START_FRAME = -1; // define how many frames to decode // static const int32 ALL_FRAMES = -1; //---------------------------------------- // // default values and arguments // //---------------------------------------- // default values // static const int32 MAX_NUM_LEVELS = 3; static const int32 MAX_NUM_FRAMES = 1000; static const int32 DEF_N_BEST = 1; static const int32 DEF_STACK_LEVEL = 0; static const int32 DEF_GLOB_STACK_CAPACITY = 100000; static const float32 DEF_SIL_DUR_PENALTY = 0; // default arguments to methods // static const int32 ALL_LEVELS = -1; //--------------------------------------- // // error codes // //--------------------------------------- static const int32 ERR = (int32)90800; static const int32 ERR_LEVELS_NOT_LOADED = (int32)90801; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // type definition // typedef GraphVertex GVSnode; // evaluation mode // EVALUATION_MODE evaluation_mode_d; // decoding mode // DECODING_MODE decoding_mode_d; // score normalization // SCORE_NORMALIZATION score_normalization_d; bool8 user_requested_no_score_normalization_d; // feature vectors // Vector features_d; // svm features // Vector > svm_features_d; // total number of frames to decode // Long num_frames_d; // number of best word hypotheses to find // Long n_best_d; // lever where to put the stack // Long stack_level_d; // current frame in the search // Long current_frame_d; // search levels, sub-graphs are maintained at each SearchLevel // HierarchicalDigraph* h_digraph_d; // vector that contains the location (index) of the nsymbol mapping // Vector nsymbol_mapping_d; // vector that contains the location (index) of the context mapping // Vector context_mapping_d; // maximal trace scores at each frame for all levels // used for beam pruning // MatrixFloat max_frame_scores_d; // maximal trace scores for stacks // VectorFloat max_stack_scores_d; // statistical model scores // MatrixFloat stat_model_scores_d; // vector that maps states to statistical model indices // Vector stat_model_mapping_d; // lists to keep track of traces during the search // Vector > trace_lists_d; PriorityQueue valid_hyps_d; DoubleLinkedList term_hyps_d; // vector of hypothesis stacks // each stack contains word level hypothesess of one frame // Vector< HashTable > > stacks_d; // vector of hypothesis stacks // each stack contains word level hypothesess of one frame // PriorityQueue global_stack_d; // score of the best valid hypothesis found so far // float32 best_valid_score_d; Float sil_dur_penalty_d; // statistics // VectorLong num_eval_d; VectorLong num_avoided_eval_d; VectorLong num_traces_gen_d; VectorLong num_traces_vit_prun_d; VectorLong num_traces_beam_prun_d; // output file // Filename max_min_filename_d; // history and context pools // HistoryPool history_pool_d; ContextPool context_pool_d; // define a static debug level // static Integral::DEBUG debug_level_d; // define a static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } static bool8 diagnose(Integral::DEBUG debug_level); // debug methods: // the setDebug method for this class is static because the debug_level is // shared across all objects of this class // bool8 debug(const unichar* message) const; // method: setDebug // static bool8 setDebug(Integral::DEBUG debug_level) { debug_level_d = debug_level; return true; } // destructor/constructor(s) // ~StackSearch(); StackSearch(); StackSearch(const StackSearch& copy_search); // method: assign // bool8 assign(const StackSearch& copy_search) { return Error::handle(name(), L"assign", Error::NOT_IMPLEM, __FILE__, __LINE__); } // sofSize method // int32 sofSize() const; // read method // bool8 read(Sof& sof, int32 tag, const String& cname = CLASS_NAME); // write method // bool8 write(Sof& sof, int32 tag, const String& cname = CLASS_NAME) const; // method: readData // bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); // method: writeData // bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; // method: eq // bool8 eq(const StackSearch& compare_search) const { return h_digraph_d->eq(*compare_search.h_digraph_d); } // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // clear method // bool8 clear(Integral::CMODE ctype = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods // //--------------------------------------------------------------------------- // search level manipulation // bool8 setNumLevels(int32 num_levels); // method: getNumLevels // int32 getNumLevels() const { return h_digraph_d->length(); } // method: getSearchLevel // SearchLevel& getSearchLevel(int32 level) { return (*h_digraph_d)(level); } // method: getHDigraph // HierarchicalDigraph& getHDigraph() { return *h_digraph_d; } // method: setHDigraph // bool8 setHDigraph(HierarchicalDigraph& h_digraph) { h_digraph_d = &h_digraph; return true; } // search initialization methods // bool8 initSizes(int32 total_num_frames); bool8 initializeLinearDecoder(); // stack decoding methods // bool8 decode(Vector& vector_fe, int32 num_frames = ALL_FRAMES); bool8 decode(Vector& features, int32 num_frames = ALL_FRAMES); // hypothesis methods // bool8 getHypotheses(String& output_hyp, int32 level, float64& total_score, int32& num_frames, DoubleLinkedList& trace_path, String* word_hyp = NULL); // method: setDecodingMode // bool8 setDecodingMode(DECODING_MODE decoding_mode) { decoding_mode_d = decoding_mode; return true; } // method: getDecodingMode // DECODING_MODE getDecodingMode() { return decoding_mode_d; } // method: setStackLevel // bool8 setStackLevel(int32 level) { stack_level_d = level; return true; } // method: getStackLevel // int32 getStackLevel() { return stack_level_d; } // method: getScoreNormalization // SCORE_NORMALIZATION getScoreNormalization() { return score_normalization_d; } // method: setScoreNormalization // bool8 setScoreNormalization(SCORE_NORMALIZATION arg) { score_normalization_d = arg; if (arg == NONE) { user_requested_no_score_normalization_d = true; } return true; } // method: setSilDurPenalty // bool8 setSilDurPenalty(float32 arg) { sil_dur_penalty_d = arg; return true; } // method: getSilDurPenalty // float32 getSilDurPenalty() { return sil_dur_penalty_d; } //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: // search initialization methods // bool8 clearTraceStorage(); bool8 clearSearchNodesTraceLists(); bool8 clearValidHypsTrace(); bool8 clearTermHypsTrace(); // search traversal methods // bool8 traverseTraceLevels(); bool8 pathsRemainTrace(); // trace propagation methods // bool8 propagateTraces(DoubleLinkedList* pickup_list = NULL, int32 pickup_level = -1); bool8 propagateTracesDown(int32 level_num_a, DoubleLinkedList* pickup_list = NULL, int32 pickup_level = -1); bool8 propagateTracesUp(int32 level_num_a, DoubleLinkedList* pickup_list = NULL, int32 pickup_level = -1); bool8 evaluateTraceModels(DoubleLinkedList* pickup_list = NULL, int32 pickup_level = -1); // trace bookkeeping methods // int32 getActiveTraces(int32 search_level = ALL_LEVELS); // debugging methods // bool8 printNewPath(Trace* new_trace, Trace* old_trace); bool8 printDeletedPath(Trace* new_trace, Trace* old_trace); // pruning method // bool8 beamPruneTrace(int32 level_num); // context related methods // bool8 generateRightContexts(DoubleLinkedList& context_list, Context* initial_context, int32 depth); bool8 generateInitialTraces(); // method extending hypothesis by one word // bool8 extend(DoubleLinkedList& hlist, int32 start_frame); bool8 getTraces(DoubleLinkedList& trace_list, int32 level); bool8 setTraces(DoubleLinkedList& trace_list, int32 level); // stack decoding core method // bool8 decode(); // method: setFrame // bool8 setFrame(int32 frame) { current_frame_d = frame; return true; } // method to determine evaluation mode (FRAME, SEGMENT) // bool8 determineEvaluationMode(); // method to normalize features // bool8 normalizeFeatures(); // method to transform skip models // bool8 transformSkipModels(); // method to add skip transitions // bool8 addSkipArcs(GraphVertex* vertex_a, const int32 level_ind_a); }; // end of include file // #endif