// file: $isip/class/search/HierarchicalSearch/hsrch_16.cc // version: $Id: hsrch_16.cc 10528 2006-03-24 16:00:16Z suh $ // // isip include files // #include "HierarchicalSearch.h" typedef Triple< Pair, Float, Boolean> TopoTriple; // methond: parseGraph // // Vector& node_a: (input) input node // Vector& transcription_a: (input) vector of transcription // int32 index_a: (input) count the index of transcription // int32 search_level_a: (input) user specifies search level // Long& operation_level_a: (input) operation level // // return: a bool8 value indicating status // // this mothod parse the graph after getting the next nodes from getNextNodes // bool8 HierarchicalSearch::parseGraph(Vector& node_a, const Vector& transcription_a, int32 index_a, int32 search_level_a, Long& operation_level_a, bool8 exclude_a) { // declare local variables // bool8 success; Long start_node; Long end_node; int32 operation_level; Vector< Vector > next_nodes; Vector temp_node(search_level_a + 1); Vector vec_symbol; Long symbol_index; Vector vec_vertex; Vector< Pair > vec_arc; // get dummy and exclude symbols // Vector vec_dummy; Vector vec_exclude; success = false; start_node = -1; end_node = -2; operation_level = operation_level_a; if (node_a(0) == start_node) { getNextNodes(node_a, next_nodes, 0, operation_level_a, exclude_a); } else { getNextNodes(node_a, next_nodes, operation_level, operation_level_a, exclude_a); } getVectorSymbol((int32)operation_level, vec_symbol); vec_dummy.assign(getHDigraph()(operation_level).getDummySymbolTable()); vec_exclude.assign(getHDigraph()(operation_level).getExcludeSymbolTable()); for (int32 i = 0; i < next_nodes.length(); i++) { if (next_nodes(i)(0) == end_node && index_a != transcription_a.length()-1) { success = false; } else if (next_nodes(i)(0) == end_node && index_a == transcription_a.length()-1) { return true; } else if (next_nodes(i)(0) != end_node) { symbol_index = getSymbolIndex(operation_level, next_nodes(i)); getVertexArc(operation_level, 
symbol_index, vec_vertex, vec_arc); if (vec_symbol(vec_vertex(next_nodes(i)(operation_level))).eq(transcription_a(index_a))) { if (index_a == transcription_a.length()-1) { return true; } else if(parseGraph(next_nodes(i), transcription_a, index_a + 1, operation_level, operation_level_a, exclude_a)) { success = true; } } } } return success; } // method: getNextNodes // // Vector& input_node: (input) input node // Vector< Vector >& output_node: (output) output nodes // int32 current_level: (input) indicate the current level // Long& operationg_level: (input) user specified output level // bool8 exclude_a: (input) user specified to include exclude symbol // // return: a bool8 value indicating status // // this method get the vector input to indicate the each graph // of corresponding level output the coordinate and next possible arcs // bool8 HierarchicalSearch::getNextNodes(Vector& input_node_a, Vector >& next_nodes_a, int32 current_level_a, Long& operation_level_d, bool8 exclude_a) { int32 operation_level_a; operation_level_a = operation_level_d; // declare start and end symbol // Long start_symbol = -1; Long end_symbol = -2; Long symbol_index; // get symbol, vertex, and arc // Vector vec_symbol; Vector vec_vertex; Vector< Pair > vec_arc; getVectorSymbol((int32)current_level_a, vec_symbol); symbol_index = getSymbolIndex(current_level_a, input_node_a); getVertexArc((int32)current_level_a, symbol_index, vec_vertex, vec_arc); // get dummy and exclude symbols // Vector vec_dummy; Vector vec_exclude; vec_dummy.assign(getHDigraph()(current_level_a).getDummySymbolTable()); vec_exclude.assign(getHDigraph()(current_level_a).getExcludeSymbolTable()); // loop through all the arcs to read possible arcs // for (int32 i = 0; i < vec_arc.length(); i++) { // find the second element of arc, which depends on input node // if (input_node_a(current_level_a) == vec_arc(i).first() ) { Vector temp_next_node(operation_level_a+1); temp_next_node.assign(input_node_a); 
temp_next_node(current_level_a) = vec_arc(i).second(); if (vec_arc(i).second() == end_symbol) { // go to the lower level to find lower next node // if (current_level_a == 0) { next_nodes_a.concat(temp_next_node); } // update temp_next_node to next_nodes_a // else if (current_level_a != 0) { // this node will be started from begining // temp_next_node(current_level_a) = start_symbol; // find the next node from previous search level // getNextNodes(temp_next_node, next_nodes_a, current_level_a - 1, operation_level_d, exclude_a); } } else if (vec_arc(i).second() != end_symbol) { if ( (vec_dummy.contains(&vec_symbol(vec_vertex(vec_arc(i).second()))) == true) || (vec_exclude.contains(&vec_symbol(vec_vertex(vec_arc(i).second()))) == true && exclude_a == true) ) { getNextNodes(temp_next_node, next_nodes_a, current_level_a, operation_level_d, exclude_a); } // does not contain dummy or exclude symbol // else { // go to the lower level to find lower next node // if (current_level_a < operation_level_a) { // find the next node for lower search level // getNextNodes(temp_next_node, next_nodes_a, current_level_a + 1, operation_level_d, exclude_a); } // update temp_next_node to next_nodes_a // else if (current_level_a == operation_level_a) { // update all the next node to next_nodes // next_nodes_a.concat(temp_next_node); } } } } } // exit gracefully // return true; } // method: getSymbolIndex // // int32 currelt_level_a: (input) current search level // Vector& input_node_a: (input) input node // // return: a Long value for the symbol index // // this method get the symbol index value for the current vertex // Long HierarchicalSearch::getSymbolIndex(int32 current_level_a, Vector& input_node_a) { Long symbol_index; Long end_symbol; symbol_index = 0; end_symbol = -2; Vector vec_vertex; Vector< Pair > vec_arc; for (int32 i = 0; i < current_level_a; i++) { getVertexArc(i,symbol_index, vec_vertex, vec_arc); symbol_index = vec_vertex(input_node_a(i)); } return symbol_index; } // 
// method: getVectorSymbol
//
// int32 output_level_a: (input) search level whose symbol table is read
// Vector& vec_symbol_a: (output) copy of that level's symbol table
//
// return: a bool8 value, false if copying the symbol table fails and true
//         otherwise
//
// this method copies the symbol table of one search level into
// vec_symbol_a
//
bool8 HierarchicalSearch::getVectorSymbol(int32 output_level_a, Vector& vec_symbol_a) {

  // size the output like the symbol table, then copy the table into it
  //
  int32 symbol_length;
  symbol_length = getHDigraph()(output_level_a).getSymbolTable().length();
  vec_symbol_a.setLength(symbol_length);
  if (vec_symbol_a.assign(getHDigraph()(output_level_a).getSymbolTable()) != true) {
    return false;
  }
  return true;
}

// method: getVertexArc
//
// int32 output_level_a: (input) search level
// Long& symbol_index_a: (input) index of the subgraph within that level
// Vector& vec_vertex_a: (output) vertex data of the subgraph, in list order
// Vector< Pair >& vec_arc_a: (output) arcs of the subgraph; only the first
//                            element of each arc list entry is kept
//
// return: a bool8 value (always true)
//
// this method takes subgraph symbol_index_a of the level (through
// convertSubgraphs()) and copies its vertex list and arc list into vectors
//
// NOTE(review): convertSubgraphs() is called on every invocation, and this
// method is called repeatedly by the recursive search; if it rebuilds the
// subgraphs each time, this is expensive. confirm.
//
bool8 HierarchicalSearch::getVertexArc(int32 output_level_a, Long& symbol_index_a, Vector& vec_vertex_a, Vector >& vec_arc_a) {

  // declare a local variable
  //
  SingleLinkedList vertex_data;
  SingleLinkedList arc_data;
  getHDigraph()(output_level_a).convertSubgraphs()(symbol_index_a).get(vertex_data, arc_data);

  // set the length of vertex and arc
  //
  int32 vertex_length;
  int32 arc_length;
  vertex_length = vertex_data.length();
  arc_length = arc_data.length();
  vec_vertex_a.setLength((int32)vertex_length);
  vec_arc_a.setLength((int32)arc_length);

  // store the vertex into vector form
  //
  // NOTE(review): both copy loops are do-while loops, so for an empty list
  // the body still runs once: getCurr() is dereferenced (presumably NULL)
  // and element 0 of a zero-length vector is written. confirm that a
  // subgraph can never have an empty vertex or arc list, or guard the loops.
  //
  int32 i = 0;
  vertex_data.gotoFirst();
  do {
    vec_vertex_a(i) = *vertex_data.getCurr();
    i++;
  } while (vertex_data.gotoNext());

  // store the arc into vector form (only first() of each entry; the other
  // members of the entry -- presumably the TopoTriple weight and flag --
  // are dropped)
  //
  int32 j = 0;
  arc_data.gotoFirst();
  do {
    vec_arc_a(j).assign(arc_data.getCurr() -> first());
    j++;
  } while (arc_data.gotoNext());
  return true;
}

// method: genRandomSentences
//
// int32 sent_len_a: (input) maximum number of arcs taken during the walk
//                   (forward moves, counted over all levels)
// Random& random_gen_a: (input) random generator used to pick the arcs
// int32 output_level_a: (input) level whose symbols are written out; the
//                       walk never descends below this level
// String& sentence_a: (output) generated symbols, each followed by a space
// bool8 exclude_a: (input) not referenced in the visible code of this
//                  method (see the note on the lost text in the loop)
//
// return: a bool8 value (always true)
//
// this method generates a sentence specified by the grammar and by the
// weights assigned to the arcs. starting at the start vertex of subgraph 0
// of level 0, each iteration tries in this order to
//  - descend into the subgraph of the current vertex (if it has one and
//    the output level has not been reached),
//  - ascend back to the parent vertex after a terminal vertex or a dead
//    end flagged as dummy,
//  - move along an outgoing arc chosen at random with probability
//    proportional to exp(arc weight).
// every move creates a new trace. afterwards the traces scored
// Trace::DEF_SCORE are visited from first to last, and their symbols are
// concatenated into sentence_a. start/term vertices and dummy symbols of the
// output level are left out.
//
bool8 HierarchicalSearch::genRandomSentences(int32 sent_len_a, Random& random_gen_a, int32 output_level_a, String& sentence_a, bool8 exclude_a) {

  // declare local variables
  //
  Context tmp_context(1, 1);
  ContextPool context_pool;
  GraphVertex* v = (GraphVertex*)NULL;
  GraphVertex* nv = (GraphVertex*)NULL;

  // initialize the trace at the start vertex of subgraph 0 of level 0
  //
  v = getHDigraph()(0).getSubGraph(0).getStart();
  tmp_context.assignAndAdvance((ulong)v);

  // create a new trace
  //
  Trace* curr_trace = new Trace();
  curr_trace->setBackPointer((Trace*)NULL);

  // set the symbol
  //
  curr_trace->setSymbol(context_pool.get(tmp_context));

  // set the history: parent symbols are pushed on each descent and popped
  // on each ascent
  //
  History* history = new History();
  curr_trace->setHistory(history);

  // loop until we are done
  //
  int32 level = 0;
  int32 sent_len = 0;
  int32 level_ex = 0;
  bool8 ascend = false;
  bool8 status = true;
  bool8 dummy = false;
  bool8 skip_dummy = false;
  bool8 skip_exclude = false;

  while (status) {
    status = false;

    // have we reached the max sentence length
    //
    if (sent_len >= sent_len_a) {
      break;
    }

    v = curr_trace->getSymbol()->getCentralVertex();

    // get current symbol
    //
    String cur_sym;
    curr_trace ->getSymbol()->print(cur_sym);

    // declare vector for dummy exclude symbol
    //
    Vector vec_dummy;
    Vector vec_exclude;
    vec_dummy.assign(getHDigraph()(level).getDummySymbolTable());

    // check dummy symbol
    //
    // NOTE(review): source text has been lost here. The loop header breaks
    // off after "i", and everything up to "setScore" is missing, so this
    // line does not compile. cur_sym, vec_exclude, level_ex and exclude_a
    // are not used anywhere in the visible code, and skip_dummy and
    // skip_exclude are never set to true. The lost statements presumably
    // compared cur_sym against the dummy/exclude tables, set those flags,
    // and then reset the trace score with
    // curr_trace->setScore(Trace::INACTIVE_SCORE). Restore the text from
    // version control.
    //
    for (int i=0 ; isetScore(Trace::INACTIVE_SCORE);

    // only traces at the output level that are not flagged as dummy or
    // exclude symbols get Trace::DEF_SCORE, which marks them for output
    //
    if ((level == output_level_a) && !skip_exclude && !skip_dummy) {
      curr_trace->setScore(Trace::DEF_SCORE);
    }
    skip_dummy = false;

    // can we descend a level? (never below the output level, and not right
    // after an ascent, which would re-enter the subgraph just left)
    //
    // if ((level != output_level_a) && !status && !ascend) {
    if ((level != output_level_a ) && !status && !ascend) {

      // ignore start and term vertices
      //
      if (!v->isStart() && !v->isTerm()) {

        // does this vertex have a subgraph?
        //
        Ulong* subgr_ind = (Ulong*)NULL;
        Context* tmp_symbol = (Context*)NULL;
        curr_trace->getSymbol()->convert(tmp_symbol);
        subgr_ind = getHDigraph()(level).getSubGraphIndex(*tmp_symbol);

        if (subgr_ind != (Ulong*)NULL) {
          status = true;
          ascend = false;

          // descend the trace to the start vertex of the subgraph
          //
          v = (getHDigraph()(level + 1).getSubGraph((int32)*subgr_ind)).getStart();

          // create a new trace
          //
          Trace* next_trace = new Trace(*curr_trace);
          next_trace->setBackPointer(curr_trace);

          // set the symbol
          //
          tmp_context.assignAndAdvance((ulong)v);
          //tmp_context.print();
          next_trace->setSymbol(context_pool.get(tmp_context));

          // set the history: remember the parent symbol for the ascent
          //
          next_trace->getHistory()->push(curr_trace->getSymbol());
          curr_trace = next_trace;
          level++;
        }

        // free memory
        //
        if (tmp_symbol != (Context*)NULL) {
          delete tmp_symbol;
        }
      }
    }

    // can we ascend a level?
    //
    if ((level != 0) && !status) {

      // consider only terminals or dummy symbol in the lower level
      //
      if (v->isTerm() || dummy) {
        status = true;
        ascend = true;
        dummy = false;

        // go back to the parent vertex stored at the top of the history
        //
        v = curr_trace->getHistory()->peek()->getCentralVertex();

        // create a new trace
        //
        Trace* next_trace = new Trace(*curr_trace);
        next_trace->setBackPointer(curr_trace);

        // set the symbol
        //
        tmp_context.assignAndAdvance((ulong)v);
        //tmp_context.print();
        next_trace->setSymbol(context_pool.get(tmp_context));

        // set the history
        //
        next_trace->getHistory()->pop();
        curr_trace = next_trace;
        level--;
      }
    }

    // can we move forward
    //
    // NOTE(review): v->length() >= 0 presumably always holds, so this
    // branch runs whenever no descent or ascent happened. The stricter
    // test is kept commented out below.
    //
    if ((v->length() >= 0) && !status) {
    //if ((v->length() > 0) && !status) {

      // create a new trace
      //
      Trace* next_trace = new Trace(*curr_trace);
      next_trace->setBackPointer(curr_trace);
      float64 score = 0.0;
      float64 prob = random_gen_a.get();
      float64 scale = 0.0;
      float64 weight = 0.0;

      // compute the scale across all arcs: sum of exp(weight)
      //
      for (bool8 more = v->gotoFirst(); more; more = v->gotoNext()) {
        weight = Integral::exp(v->getCurr()->getWeight());
        scale += weight;
      }

      // inverse the scale (for a vertex without arcs this divides by zero,
      // but the scale is then never used)
      //
      scale = 1.0/scale;

      // pick the arc whose cumulative probability interval contains prob
      //
      for (bool8 more = v->gotoFirst(); more; more = v->gotoNext()) {
        weight = Integral::exp(v->getCurr()->getWeight()) * scale;
        nv = v->getCurr()->getVertex();
        if ((prob >= score) && (prob <= (score + weight))) {

          // set the symbol
          //
          sent_len++;
          tmp_context.assignAndAdvance((ulong)nv);
          //tmp_context.print();
          next_trace->setSymbol(context_pool.get(tmp_context));
          ascend = false;
          status = true;
          break;
        }
        score = score + weight;
      }

      // dummy symbol check: weight is still 0 presumably only when v has
      // no outgoing arcs. Below the top level, such a dead end sets the
      // dummy flag so that a later iteration can take the ascent branch.
      //
      if (level > 0 && !v->gotoNext() && weight == 0 ) {
        status = true;
        dummy = true;
      }
      //String symbol;
      //next_trace->getSymbol()->print(symbol);
      //symbol.debug(L"symbol");
      curr_trace = next_trace;
    }
  }

  // collect the traces from first to last by following the back pointers
  // from the final trace
  //
  Trace* tmp_trace = curr_trace;
  DoubleLinkedList trace_list(DstrBase::USER);
  bool8 skip = false;

  while (curr_trace != (Trace*)NULL) {
    trace_list.insertFirst(curr_trace);
    curr_trace = curr_trace->getBackPointer();
  }

  // build the sentence from the traces marked for output
  //
  sentence_a.clear();
  for (bool8 more = trace_list.gotoFirst(); more; more = trace_list.gotoNext()) {
    curr_trace = trace_list.getCurr();
    v = curr_trace->getSymbol()->getCentralVertex();
    if ((curr_trace->getScore() == Trace::DEF_SCORE) && !v->isStart() && !v->isTerm()) {
      String symbol;
      curr_trace->getSymbol()->print(symbol);
      Vector vec_dummy;
      vec_dummy.assign(getHDigraph()(output_level_a).getDummySymbolTable());

      // check dummy_symbol
      //
      for (int i = 0; i < vec_dummy.length();i++) {
        if (symbol.eq(vec_dummy(i))) {
          skip = true;
        }
      }

      // append the symbol unless it is a dummy symbol
      //
      if (!skip) {
        sentence_a.concat(symbol);
        sentence_a.concat(L" ");
      }
    }

    // reset the dummy flag for the next trace
    //
    skip = false;
  }

  // free memory
  //
  Trace::deleteTrace(tmp_trace, true);
  delete history;
  history = (History*)NULL;

  // exit gracefully
  //
  return true;
}