// file: $isip/class/search/HierarchicalSearch/HierarchicalSearch.h // version: $Id: HierarchicalSearch.h 10636 2007-01-26 22:18:09Z tm334 $ // // make sure definitions are only made once // #ifndef ISIP_HIERARCHICAL_SEARCH #define ISIP_HIERARCHICAL_SEARCH // isip include files: // #ifndef ISIP_FRONT_END #include #endif #ifndef ISIP_VECTOR_FLOAT #include #endif /* #ifndef ISIP_STACK */ /* #include */ /* #endif */ #ifndef ISIP_QUEUE #include #endif #ifndef ISIP_TRACE #include #endif #ifndef ISIP_INSTANCE #include #endif #ifndef ISIP_CONTEXT_POOL #include #endif #ifndef ISIP_HISTORY_POOL #include #endif #ifndef ISIP_HASH_KEY #include #endif #ifndef ISIP_HIERARCHICAL_DIGRAPH #include #endif #ifndef ISIP_SEARCH_NODE #include #endif #ifndef ISIP_BI_GRAPH #include #endif #ifndef NGRAM_CACHE #include #endif #ifndef ISIP_SYMBOL_GRAPH #include #endif #ifndef ISIP_SYMBOL_GRAPH_NODE #include #endif // HierarchicalSearch: A hierarchical, synchronous, Viterbi search engine. // // The basic premise behind this is as follows: // // 1) The search space is hierarchical in that each level has nodes and // each node is either: // a) composed of a sub-graph + probability evaluation // b) composed of probability evaluation. // // 2) At each level you will be able to do the following: // a) beam prune // b) instance prune // c) generate context-dependent scoring models on the fly including // cross-symbol models // d) specify whether search information is to be maintained - // otherwise we can prune away redundant information. For instance, // once a word is completed, it is no longer necessary to maintain // the sub-word units that correspond to that word. // e) turn on/off sub-level evaluation // // 3) Each graph position will be capable of the following: // a) generate a probability of being in that graph position given the // input stream. For acoustic models, this will involve direct // evaluation of the data stream. For levels like language models, // this will not involve manipulation of the data stream at all. // Any level can access the data stream as it will be passed from // the highest level all the way down to the lowest via pointer // reference. // b) store the overlapping paths in the sub-graph as a "lexical" tree. // c) use N-symbol models related to the possible entries on that level // // 4) We will revolve around the concept of a "trace" where the // trace moves through the graph and keeps track of the paths/scores // through the search space // class HierarchicalSearch { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; //---------------------------------------- // // other important constants // //---------------------------------------- // define the search mode choices // enum SEARCH_MODE { DECODE = 0, TRAIN, DEF_SEARCH_MODE = DECODE }; // define the start frame index // static const int64 DEF_START_FRAME = -1; static const int32 DEF_CONTEXT_LEVEL = 0; static const bool8 DEF_CONTEXT_MODE = false; //---------------------------------------- // // default values and arguments // //---------------------------------------- // default values // static const int32 DEF_NUM_LEVELS = 1; static const int32 DEF_CAPACITY = 12000; // default arguments to methods // static const int32 DEF_NUM_FRAMES = -1; static const int32 DEF_INITIAL_LEVEL = 0; static const int32 ALL_LEVELS = -1; //--------------------------------------- // // error codes // //--------------------------------------- static const int32 ERR = (int32)90700; static const int32 ERR_LEVEL = (int32)90701; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // define a search mode flag // SEARCH_MODE search_mode_d; // current frame's feature vector // VectorFloat features_d; // HierarchicalDigraph, sub-graphs are maintained at each SearchLevel // HierarchicalDigraph* h_digraph_d; // initial search level // Long initial_level_d; // current frame in the search // Long current_frame_d; // the lexical tree list // SingleLinkedList lex_tree_list_d; // current maximal trace scores // Vector max_trace_scores_d; // current maximal instance scores // Vector max_instance_scores_d; // symbol graph generation structure // SymbolGraph symbol_graph_d; // n-symbol probability lookup variables // NGramCache nsymbol_cache_d; Vector nsymbol_indices_d; // lists to keep track of instances during the search // Vector > instance_lists_d; DoubleLinkedList instance_valid_hyps_d; // lists to keep track of traces during the search // Vector > trace_lists_d; DoubleLinkedList trace_valid_hyps_d; // define a bi-graph structure that stores the entire search path // BiGraph trellis_d; // define a mapping from the hypothesis path to the vertices // HashTable, BiGraphVertex > instance_mapping_d; // define the context generation level // Long context_level_d; bool8 context_generation_mode_d; // define the context generation list // Vector context_list_d; HashTable context_hash_d; // history and context pools // HistoryPool history_pool_d; ContextPool context_pool_d; // define a static debug level // static Integral::DEBUG debug_level_d; // define a static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } static bool8 diagnose(Integral::DEBUG debug_level); // debug methods: // the setDebug method for this class is static because the debug_level is // shared across all objects of this class // bool8 debug(const unichar* message) const; // method: setDebug // static bool8 setDebug(Integral::DEBUG debug_level) { debug_level_d = debug_level; return true; } // destructor/constructor(s) // ~HierarchicalSearch(); HierarchicalSearch(); HierarchicalSearch(const HierarchicalSearch& copy_search); // method: assign // bool8 assign(const HierarchicalSearch& copy_search) { return Error::handle(name(), L"assign", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: sofSize // int32 sofSize() const { return Error::handle(name(), L"sofSize", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: read // bool8 read(Sof& sof, int32 tag, const String& cname = CLASS_NAME) { return Error::handle(name(), L"read", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: write // bool8 write(Sof& sof, int32 tag, const String& cname = CLASS_NAME) const { return Error::handle(name(), L"write", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: readData // bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false) { return Error::handle(name(), L"readData", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: writeData // bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const { return Error::handle(name(), L"writeData", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: eq // bool8 eq(const HierarchicalSearch& compare_search) const { return h_digraph_d->eq(*compare_search.h_digraph_d); } // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // clear method // bool8 clear(Integral::CMODE ctype = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods // //--------------------------------------------------------------------------- // method: setSearchMode // bool8 setSearchMode(SEARCH_MODE arg) { search_mode_d = arg; return true; } // method: getSearchMode // SEARCH_MODE getSearchMode() { return search_mode_d; } // method: setInitialLevel // bool8 setInitialLevel(int32 init_level) { initial_level_d = init_level; return true; } // method: getInitialLevel // int32 getInitialLevel() const { return (int32)initial_level_d; } // search level manipulation // bool8 setNumLevels(int32 num_levels); // method: getNumLevels // int32 getNumLevels() const { return h_digraph_d->length(); } // method: getContextLevel // int32 getContextLevel() { return (int32)context_level_d; } // method: setContextLevel // bool8 setContextLevel(int32 arg) { return context_level_d.assign(arg); } // method: getContextList // Vector& getContextList() { return context_list_d; } // method: setContextList // bool8 setContextList(Vector& arg) { return context_list_d.assign(arg); } // method: setFeatures // bool8 setFeatures(Vector& features) { current_frame_d = DEF_START_FRAME; return true; } // method: getSearchLevel // SearchLevel& getSearchLevel(int32 level) { return (*h_digraph_d)(level); } // method: getHDigraph // HierarchicalDigraph& getHDigraph() { return *h_digraph_d; } // method: setHDigraph // bool8 setHDigraph(HierarchicalDigraph& h_digraph) { h_digraph_d = &h_digraph; return true; } // method: getContextPool // ContextPool& getContextPool() { return context_pool_d; } // method: setContextPool // bool8 setContextPool(ContextPool& arg) { return context_pool_d.assign(arg); } // method: getHistoryPool // HistoryPool& getHistoryPool() { return history_pool_d; } // method: setHistoryPool // bool8 setHistoryPool(HistoryPool& arg) { return history_pool_d.assign(arg); } // search initialization methods // bool8 initializeLinearDecoder(); bool8 initializeLinearPartial(); bool8 initializeGrammarDecoder(); bool8 initializeGrammarPartial(); bool8 initializeNetworkDecoder(); bool8 initializeContextGeneration(); // decoding methods // bool8 linearDecoder(Vector& vector_fe, int32 num_frames = DEF_NUM_FRAMES); bool8 linearDecoder(Vector& fe, int32 num_frames = DEF_NUM_FRAMES); bool8 grammarDecoder(Vector& vector_fe, int32 num_frames = DEF_NUM_FRAMES); bool8 grammarDecoder(Vector& fe, int32 num_frames = DEF_NUM_FRAMES); bool8 networkDecoder(Vector& vector_fe, int32 num_frames = DEF_NUM_FRAMES); bool8 networkDecoder(Vector& fe, int32 num_frames = DEF_NUM_FRAMES); // hypothesis methods // bool8 getHypotheses(String& output_hyp, int32 level, float64& total_score, int32& num_frames, DoubleLinkedList& trace_path); bool8 getHypotheses(String& output_hyp, int32 level, float64& total_score, int32& num_frames, DoubleLinkedList& trace_path); bool8 convertLexInstances( DoubleLinkedList& instance_path); bool8 printInstances( DoubleLinkedList& instance_path); bool8 getLexHypotheses(String& output_hyp, int32 level, float32& total_score, int32& num_frames, DoubleLinkedList& trace_path); bool8 forcedAlignment(String& alignment, Long level, DoubleLinkedList& trace_path, bool8 cumulative = false); // trace manipulation methods // bool8 connectValidHypothesis(); BiGraphVertex* insertTrace(Trace* arg); BiGraphVertex* insertInstance(Instance* arg); // trellis state occupancy computation routines // BiGraph* computeForwardBackward(Vector& data, float32 beta_threshold); // method: getSymbolGraph // SymbolGraph& getSymbolGraph() { return symbol_graph_d; } // symbol graph generation methods // bool8 generateSymbolGraph(); bool8 initializeSymbolGraphTrace(float64& score, int32& num_frames); bool8 initializeSymbolGraphInstance(float64& score, int32& num_frames); // methods for parsing and generating transcriptions // bool8 parseGraph(Vector& node_a, const Vector& transcription_a, int32 index_a, int32 search_level_a, Long& operation_level_a, bool8 exclude_a); bool8 genRandomSentences(int32 sent_len, Random& random_gen, int32 output_level, String& sentence, bool8 exclude); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: // symbol graph generation methods // bool8 generateSymbolGraph(SymbolGraphNode* prev_node, SymbolGraphNode* curr_node, float32 score, float32 lm_score); // search initialization methods // bool8 clearTraceStorage(); bool8 clearSearchNodesTraceLists(); bool8 clearValidHypsTrace(); bool8 clearInstanceStorage(); bool8 clearSearchNodesInstanceLists(); bool8 clearValidHypsInstance(); bool8 clearValidHypsLexInstance(); // search traversal methods // bool8 traverseTraceLevels(); bool8 pathsRemainTrace(); bool8 traverseInstanceLevels(); bool8 pathsRemainInstance(); // trace propagation methods // bool8 propagateTraces(); bool8 propagateTracesDown(int32 level_num); bool8 propagateTracesUp(int32 level_num); bool8 evaluateTraceModels(); // instance propagation methods // bool8 propagateInstances(); bool8 propagateInstancesDown(int32 level_num); bool8 propagateInstancesUp(int32 level_num); bool8 evaluateInstanceModels(); bool8 evaluateLexInstanceModels(); // trace bookkeeping methods // int32 getActiveTraces(int32 search_level = ALL_LEVELS); // instance bookkeeping methods // int32 getActiveInstances(int32 search_level = ALL_LEVELS); // debugging methods // bool8 printNewPath(Trace* new_trace, Trace* old_trace); bool8 printDeletedPath(Trace* new_trace, Trace* old_trace); bool8 printNewPath(Instance* new_instance, Instance* old_instance); bool8 printDeletedPath(Instance* new_instance, Instance* old_instance); // trellis insertion methods // bool8 insertNewPath(Trace* parent, Trace* child, float32 weight); bool8 insertOldPath(Trace* parent, Trace* child, float32 weight); bool8 insertNewPath(Instance* parent, Instance* child, float32 weight); bool8 insertOldPath(Instance* parent, Instance* child, float32 weight); // method: getPosteriorScore (not implemented yet) // float32 getPosteriorScore(Trace* curr_trace) { return 0.0; } // method to compute the n-symbol probability // float32 getPosteriorScore(Context** context, int32 level); // pruning method // bool8 beamPruneTrace(int32 level_num); bool8 beamPruneInstance(int32 level_num); bool8 instancePruneTrace(int32 level_num); bool8 instancePruneInstance(int32 level_num); bool8 beamPruneLexInstance(int32 level_num); bool8 instancePruneLexInstance(int32 level_num); // initialization routine // bool8 initializeForwardBackward(); // backward probability computation routines // bool8 computeBackward(Vector& data, float32 beta_threshold); bool8 computeBeta(Vector& data, BiGraphVertex* vertex, float32** model_cache); // forward probability computation routines // bool8 computeForward(Vector& data); bool8 computeAlpha(Vector& data, BiGraphVertex* vertex); // cross-word context related methods // bool8 generateRightContexts(DoubleLinkedList& context_list, DoubleLinkedList& score_list, Context* initial_context, int32 depth, int32 level); bool8 initializeRightContexts(Instance* curr_instance, Context& init_context, int32 level_num, float32 curr_weight, int32 depth); bool8 extendRightContexts(Instance* curr_instance, int32 level_num, Context* prev_symbol = (Context*)NULL); bool8 printHistory(const History* history); bool8 isTerminal(Instance* curr_instance, int32 level_num); bool8 propagateUp(int32 curr_level, int32 level, int32 curr_depth, int32 depth, DoubleLinkedList& inst_list, DoubleLinkedList& curr_hist, DoubleLinkedList& prev_hist, DoubleLinkedList& score_list); bool8 propagateDown(Context* symbol, Context* symbol1, GraphVertex* curr_vertex, int32 curr_level, int32 level, int32 curr_depth, int32 depth, DoubleLinkedList& inst_list, DoubleLinkedList& curr_hist, DoubleLinkedList& prev_hist, DoubleLinkedList& score_list); bool8 ascend(Instance* instance); bool8 descend(Instance* instance); bool8 lookAhead(int32 level, int32 depth, DoubleLinkedList& inst_list, DoubleLinkedList& score_list); bool8 lookAheadHelper(GraphVertex* curr_vertex, int32 curr_level, int32 level, int32 curr_depth, int32 depth, DoubleLinkedList& inst_list, DoubleLinkedList& curr_hist, DoubleLinkedList& prev_hist, DoubleLinkedList& score_list); // context related methods // bool8 generateRightLexContexts(DoubleLinkedList& context_list, const Context& initial_context, int32 depth, int32 level); // start search engine from linear structure // bool8 linearStart(); // start search engine from lexical tree structure // bool8 lexicalTreeStart(); // initialize the lexical tree // bool8 initializeLexicalTree(); // lexical tree related methods // bool8 printInstance(Instance* new_instance, int32 level_num = -1, bool8 recursive = false); bool8 changeHistory(Instance*& curr_instance, bool8 ascend, int32 level_num = 0, GraphVertex* start_vert = (GraphVertex*)NULL); bool8 addHypothesisPath(Trace*& new_instance, Trace*& old_instance, GraphArc*& arc, int32 level_num); bool8 addHypothesisPath(Instance*& new_instance, Instance*& old_instance, int32 level_num, float32 weight); bool8 addInstance(int32 level_num, Instance*& curr_instance, Instance*& next_instance, bool8 pruning); bool8 ascendInstance(int32 level_num, Instance*& curr_instance); bool8 descendInstance(int32 level_num, Instance*& curr_instance); bool8 traverseLexInstanceLevels(); bool8 propagateLexInstances(); bool8 propagateLexInstancesUp(int32 level_num); bool8 propagateLexInstancesDown(int32 level_num); bool8 networkLexStart(); bool8 expandLexicalTree(int32 level_num); // convert hypotheses // bool8 addToPath( Instance*& curr_inst, Instance*& prev_inst, Vector< DoubleLinkedList >& inst_levels, Vector< DoubleLinkedList >& prev_levels, int32 level_num, int32& frame_ind, DoubleLinkedList& instance_path); // n-symbol probability related methods // bool8 shiftNSymbol(Trace*& trace, int32 level, GraphVertex* vertex); bool8 shiftNSymbol(Instance*& instance, int32 level, GraphVertex* vertex); bool8 initializeNsymbolInstance(Instance*& instance); bool8 initializeNsymbolTrace(Trace*& trace); bool8 applyNSymbolScore(Instance*& curr_instance, int32 level_num); bool8 addInstanceScore(Instance*& curr_instance, float32 score, int32 level_num, bool8 weight_score); // methods for generation and parsing of transcriptions // given a hierarchical digraph // bool8 getNextNodes(Vector& input_node_a, Vector< Vector >& next_nodes_a, int32 current_level_a, Long& operation_level_d, bool8 exclude_a); Long getSymbolIndex(int32 current_level_a, Vector& input_node_a); bool8 getVectorSymbol(int32 output_level_a, Vector& vec_symbol_a); bool8 getVertexArc(int32 output_level_a, Long& symbol_index_a, Vector& vec_vertex_a, Vector< Pair >& vec_arc_a); }; // end of include file // #endif