// file: $isip/class/pr/PhoneticDecisionTree/pdt_07.cc
// version: $Id: pdt_07.cc 9424 2004-03-10 21:09:43Z parihar $
//

// isip include files
//
#include "PhoneticDecisionTree.h"

// NOTE(review): the container declarations throughout this file
// (e.g. "Vector", "HashTable", "SingleLinkedList", "Vector >&") appear
// to have lost their template arguments during an export step --
// presumably "Vector<ContextMap>", "HashTable<SearchSymbol, Long>",
// "Vector<DiGraph<SearchSymbol> >&", etc. Restore them against the
// original repository source before building; they are left untouched
// here.

// method: getStatTrain
//
// arguments:
//  Vector& context_map_a: (input) context-map
//  Vector >& sub_graphs_a: (input) sub-graphs
//  Vector& symbol_table_a: (input) symbol-table
//  Vector& contextless_symbol_table_a: (input) contextless symbol-table
//  HashTable& symbol_hash_a: (input/output) mapping from symbols
//   (states) to statistical-model indices
//  Vector& stat_models_a: (input/output) pool of StatisticalModels
//  Filename& phonetic_dt_file_a: (input) output file for the
//   decision-tree (binary Sof format)
//  HashTable& tied_symbol_hash_a: (output) hashtable containing the
//   tied symbols
//  Vector& tied_stat_models_a: (input) tied models
//
// return: a bool8 value indicating status
//
// this method gets the new set of statistical models and a mapping
// table that maps symbols (states) to the statistical models. it also
// stores the decision-tree in the binary format.
//
// NOTE(review): context_map_a, sub_graphs_a, symbol_table_a and
// contextless_symbol_table_a are never referenced in the body below --
// presumably retained for interface symmetry with getStatTest; confirm
// against callers before removing.
//
bool8 PhoneticDecisionTree::getStatTrain(Vector& context_map_a,
                                         Vector >& sub_graphs_a,
                                         Vector& symbol_table_a,
                                         Vector& contextless_symbol_table_a,
                                         HashTable& symbol_hash_a,
                                         Vector& stat_models_a,
                                         Filename& phonetic_dt_file_a,
                                         HashTable& tied_symbol_hash_a,
                                         Vector& tied_stat_models_a) {

  // local variables
  //
  bool8 res = true;
  TreeNode* root_node = (TreeNode*)NULL;
  SingleLinkedList leaf_nodes(DstrBase::USER);
  HashTable map;
  Vector stat_models_out;
  HashTable symbol_hash_out;

  // get the root of the decision-tree; all leaf nodes below it carry
  // the clustered (tied) states
  //
  root_node = getFirst();

  // check the node
  //
  if (root_node == (TreeNode*)NULL) {
    return Error::handle(name(), L"getStatTrain - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }
  res = getLeafNodes(*root_node, leaf_nodes);

  // set the capacity of stat_models_out: each leaf node contributes
  // one typical statistical model, plus room for the tied models that
  // are appended later
  //
  int32 stat_len = leaf_nodes.length() + tied_stat_models_a.length();
  stat_models_out.setCapacity(stat_len);
  stat_models_out.setLength(leaf_nodes.length());

  // loop over all the leaf-nodes that exist (that were not merged)
  // and reindex the map
  //
  for (bool8 more = leaf_nodes.gotoFirst(); more;
       more = leaf_nodes.gotoNext()) {

    // local variables
    //
    TreeNode* node = (TreeNode*)NULL;

    // get the leaf-node
    //
    node = leaf_nodes.getCurr();

    // check the node
    //
    if (node == (TreeNode*)NULL) {
      return Error::handle(name(), L"getStatTrain - NULL VERTEX",
                           Error::ARG, __FILE__, __LINE__);
    }

    // local variables
    //
    PhoneticDecisionTreeNode* pdt_node = (PhoneticDecisionTreeNode*)NULL;
    Long index = (Long)-1;
    bool8 flag_exists = true;

    // loop over all the data at this leaf-node and get the typical
    // statistical-model to which all the rest of the
    // statistical-models at this node will be tied. the model with
    // the highest scale or lowest variance is the typical model.
    //
    // get all the data points on the node
    //
    pdt_node = node->getItem();

    // check the node
    //
    if (pdt_node == (PhoneticDecisionTreeNode*)NULL) {
      return Error::handle(name(), L"getStatTrain - NULL PDT VERTEX",
                           Error::ARG, __FILE__, __LINE__);
    }
    Data& data = pdt_node->getDataPoints();

    // get the typical statistical-model at this pdt-node
    //
    StatisticalModel& typical_stat_model = pdt_node->getTypicalStatModel();

    // see if this pdt-node exists
    //
    flag_exists = pdt_node->getFlagExists();

    // proceed further only if this node exists; otherwise the data on
    // this node has already been merged into some other node
    //
    if (flag_exists) {

      // get the typical-index of the statistical-model at this pdt-node
      //
      index = pdt_node->getTypicalIndex();

      for (bool8 morea = data.gotoFirst(); morea; morea = data.gotoNext()) {

        // local variables
        //
        Long temp_index = 0;

        // get the (old) index of this statistical-model
        //
        DataPoint* datapoint = data.getCurr();
        temp_index = datapoint->first();

        // record the reindexing from the previous index to the new
        // (typical) index in the hashtable
        //
        if (!map.insert(temp_index, &index)) {
          return Error::handle(name(), L"getStatTrain", ERR,
                               __FILE__, __LINE__);
        }
      } // end of for-loop over all datapoints at a leafnode
    } // end of if statement for leaf-nodes that exist

    // add the typical statistical model at this leafnode to the new
    // pool of statistical models
    //
    int32 actual_index = -1;
    actual_index = pdt_node->getActualIndex();
    stat_models_out(actual_index).assign(typical_stat_model);
  } // end of for loop over leaf-nodes

  // create the new mapping table: start from every symbol in the old
  // symbol-to-model hashtable
  //
  Vector symbols;
  res = symbol_hash_a.keys(symbols);

  // add the untied models
  //
  for (int32 moreb = 0; moreb < symbols.length(); moreb++) {

    SearchSymbol temp_symbol = symbols(moreb);
    Long prev_index;
    Long* prev_ptr = (Long*)NULL;
    prev_ptr = symbol_hash_a.get(temp_symbol);
    if (prev_ptr == (Long*)NULL) {
      temp_symbol.debug(L"symbol: ");
      return Error::handle(name(),
                           L"getStatTrain - symbol not found in the original mapping",
                           ERR, __FILE__, __LINE__);
    }
    else {
      prev_index = *prev_ptr;
    }

    Long new_index;

    // if the statistical model corresponding to the search symbol is
    // not tied, translate its old index through the reindexing map and
    // record the result in the output hashtable
    //
    if (!tied_symbol_hash_a.containsKey(temp_symbol)) {
      Long* ptr = (Long*)NULL;
      ptr = map.get(prev_index);
      if (ptr == (Long*)NULL) {
        prev_index.debug(L"index: ");
        return Error::handle(name(),
                             L"getStatTrain - index not found in the new mapping",
                             ERR, __FILE__, __LINE__);
      }
      else {
        new_index = *(ptr);
      }
      if (!symbol_hash_out.insert(symbols(moreb), &new_index)) {
        return Error::handle(name(), L"getStatTrain", ERR,
                             __FILE__, __LINE__);
      }
    }
  }

  // append the tied models (contexts whose central symbols are
  // contextless) and point their symbols at the appended entries
  //
  for (int32 i = 0; i < tied_stat_models_a.length(); i++) {

    // add the statistical model at the end of the output pool
    //
    Long tmp_index = stat_models_out.length();
    stat_models_out.setLength((int32)tmp_index + (int32)1);
    stat_models_out(tmp_index).assign(tied_stat_models_a(i));

    // add all the search-symbols that are tied to this same
    // statistical-model
    //
    Vector tied_symbols;
    tied_symbol_hash_a.keys(tied_symbols);
    for (int32 j = 0; j < tied_symbols.length(); j++) {
      int32 old_index;
      Long* ptr = (Long*)NULL;
      ptr = tied_symbol_hash_a.get(tied_symbols(j));
      if (ptr == (Long*)NULL) {
        tied_symbols(j).debug(L"symbol: ");
        return Error::handle(name(),
                             L"getStatTrain - symbol not found in the new mapping",
                             ERR, __FILE__, __LINE__);
      }
      else {
        old_index = *ptr;
      }

      // update hash_table with the search-symbol & statistical-model
      // index (only symbols tied to model i)
      //
      if (old_index == i) {
        if (!symbol_hash_out.insert(tied_symbols(j), &tmp_index)) {
          return Error::handle(name(),
                               L"getStatTrain - non-unique search-symbols",
                               ERR, __FILE__, __LINE__);
        }
      }
    }
  }

  // publish the new statistical-model pool and symbol hashtable back
  // through the in/out arguments
  //
  stat_models_a.assign(stat_models_out);
  symbol_hash_a.assign(symbol_hash_out);

  // store the decision-tree into the output file in binary format
  //
  if (phonetic_dt_file_a.length() == 0) {
    return Error::handle(name(),
                         L"getStatTrain - invalid decision-tree file",
                         ERR, __FILE__, __LINE__);
  }

  // open the decision-tree sof file
  //
  Sof out_sof;
  if (!out_sof.open(phonetic_dt_file_a, File::WRITE_ONLY, File::BINARY)) {
    return Error::handle(phonetic_dt_file_a, L"getStatTrain", ERR,
                         __FILE__, __LINE__);
  }

  // write the decision-tree to the sof file
  //
  write(out_sof, int32(0));

  // close the output sof file
  //
  out_sof.close();

  // exit gracefully
  //
  return res;
}

// method: getStatTest
//
// arguments:
//  Vector& context_map_a: (input) contextMaps
//  int32& left_context_a: (input) length of the left-context
//  int32& right_context_a: (input) length of the right-context
//  Vector& upper_symbol_table_a: (input) symbol-table at the level
//   above the lowest level
//  Vector& upper_contextless_symbol_table_a: (input) contextless
//   symbol-table
//  Vector >& sub_graphs_a: (input) subgraphs for the contextMaps
//  Vector& symbol_table_a: (input) symbol-table (states)
//  HashTable& symbol_hash_a: (input/output) mapping from symbols
//   to statistical-models
//  Filename& ques_ans_file_a: (input) file for
// phonetic questions and answers
//
// return: a bool8 value indicating status
//
// this method runs the PhoneticDecisionTree class in TEST mode: it
// loads the phonetic questions/answers, enumerates the context symbols
// that were never seen in training, and classifies each unseen state
// down the decision-tree to find the statistical model it is tied to.
//
bool8 PhoneticDecisionTree::getStatTest(Vector& context_map_a,
                                        int32& left_context_a,
                                        int32& right_context_a,
                                        Vector& upper_symbol_table_a,
                                        Vector& upper_contextless_symbol_table_a,
                                        Vector >& sub_graphs_a,
                                        Vector& symbol_table_a,
                                        HashTable& symbol_hash_a,
                                        Filename& ques_ans_file_a) {

  // local variables
  //
  Vector > ques_ans;
  SingleLinkedList > questions;
  HashTable answers;
  bool8 res = true;

  // check the question and answers file
  //
  // NOTE(review): the error tag below says "loadTest" although this
  // method is getStatTest -- looks like a copy/paste slip in the
  // message text; left as-is since it is a runtime string.
  //
  if (ques_ans_file_a.length() == 0) {
    return Error::handle(name(),
                         L"loadTest - invalid question-answer file",
                         ERR, __FILE__, __LINE__);
  }

  // print debugging information
  //
  if (debug_level_d >= Integral::DETAILED) {
    Console::increaseIndention();
    String output;
    output.assign(L"\nloading question and answers: ");
    output.concat(ques_ans_file_a);
    Console::put(output);
    Console::decreaseIndention();
  }

  // open the input sof file
  //
  Sof input_sof;
  if (!input_sof.open(ques_ans_file_a, File::READ_ONLY)) {
    return Error::handle(ques_ans_file_a, L"open", ERR,
                         __FILE__, __LINE__);
  }

  // read the question and answers from the sof file
  //
  ques_ans.read(input_sof, int32(0));

  // close the input questions and answers file
  //
  input_sof.close();

  // compute the length of the context (central symbol plus left and
  // right context lengths)
  //
  int32 context_len = (int32)1 + left_context_a + right_context_a;

  // create all possible context-symbols of the given length using the
  // upper_symbol_table
  //
  Vector all_context_map;
  createContexts(upper_symbol_table_a, context_len, all_context_map);

  // remove all the contexts that are not allowed
  //
  Vector valid_context_map;
  validateContexts(upper_contextless_symbol_table_a, all_context_map,
                   valid_context_map);

  // get the contexts that do not exist in the input context-maps
  //
  Vector unseen_context_map;
  getUnseenContexts(context_map_a, valid_context_map, unseen_context_map);

  // create the new context-map, lower-level subgraphs, symbol-table
  //
  updateLowerLevel(context_map_a, unseen_context_map, sub_graphs_a,
                   symbol_table_a, symbol_hash_a);

  // add the central-phone and state indices as the first two
  // attributes. note that the PhoneticDecisionTree is first split
  // using the central-symbol (phone) attribute, then using the
  // symbol-position (state) and then recursively using the phonetic
  // questions. this way we have n*m number of sub-trees.
  //
  Attribute temp_attribute;
  String temp_cph;
  temp_cph.assign(CPH);
  String temp_pos;
  temp_pos.assign(POS);
  SingleLinkedList temp_all_cph;
  SingleLinkedList temp_all_pos;

  // collect the central phones that can carry context -- skip the
  // contextless symbols and the NO_LEFT_CONTEXT / NO_RIGHT_CONTEXT
  // markers
  //
  for (int32 k = 0; k < upper_symbol_table_a.length(); k++) {

    // local variables
    //
    String temp_symbol;
    bool8 include = true;
    temp_symbol.assign(upper_symbol_table_a(k));

    // don't include the contextless symbols
    //
    for (int32 kk = 0; kk < upper_contextless_symbol_table_a.length();
         kk++) {
      if (temp_symbol.eq(upper_contextless_symbol_table_a(kk))) {
        include = false;
      }
    }

    // don't include NO_LEFT_CONTEXT and NO_RIGHT_CONTEXT
    //
    if ((temp_symbol.eq(SearchSymbol::NO_LEFT_CONTEXT)) ||
        (temp_symbol.eq(SearchSymbol::NO_RIGHT_CONTEXT))) {
      include = false;
    }

    // include only necessary symbols
    //
    // NOTE(review): temp_symbol is a loop-local whose address is
    // inserted here -- presumably the list copies the item on insert
    // (ISIP SYSTEM allocation mode); confirm, else this dangles.
    //
    if (include) {
      temp_all_cph.insert(&temp_symbol);
    }
  }

  // build the list of possible symbol positions 0..context_len-1
  //
  // NOTE(review): this inner "temp_pos" shadows the outer temp_pos
  // (which holds the POS attribute name) -- apparently intentional,
  // but worth confirming.
  //
  for (int32 l = 0; l < context_len; l++) {
    String temp_pos;
    temp_pos.assign(l);
    temp_all_pos.insert(&temp_pos);
  }

  // register the central-phone and position attributes
  //
  temp_attribute.assign(temp_cph, temp_all_cph);
  attributes_d.insert(&temp_attribute);
  temp_attribute.assign(temp_pos, temp_all_pos);
  attributes_d.insert(&temp_attribute);

  // get all possible answers for phonetic-decision-trees (YES / NO)
  //
  SingleLinkedList temp_all_ans;
  String tmp_yes;
  tmp_yes.assign(YES);
  String tmp_no;
  tmp_no.assign(NO);
  temp_all_ans.insert(&tmp_yes);
  temp_all_ans.insert(&tmp_no);

  // read questions and then answers separately; both will be used
  // later
  //
  for (int32 k = 0; k < ques_ans.length(); k++) {

    // local variables
    //
    Long temp_index;
    String temp_string;
    String temp_ques;
    String temp_symbol;
    Pair temp_question;

    temp_index = ques_ans(k).first();
    temp_ques = ques_ans(k).second();
    temp_symbol = ques_ans(k).third();

    // register the question only if it has not already been read
    //
    temp_question.assign(temp_index, temp_ques);
    if (!questions.contains(&temp_question)) {
      questions.insert(&temp_question);

      // add the question & possible-answers to attributes_d, keyed by
      // the concatenation of direction-index and question text
      //
      Attribute temp_attribute;
      String temp_question_string;
      temp_question_string.assign(temp_index);
      temp_question_string.concat(temp_ques);
      temp_attribute.assign(temp_question_string, temp_all_ans);
      attributes_d.insert(&temp_attribute);
    }

    // record the answer: key = direction + question + symbol,
    // value = YES (absence from the table later means NO)
    //
    temp_string.assign(temp_index);
    temp_string.concat(temp_ques);
    temp_string.concat(temp_symbol);
    String temp_yes;
    temp_yes.assign(YES);
    answers.insert(temp_string, &temp_yes);
  }

  // loop over all the unseen context-maps with no central symbol as
  // contextless, and tie the corresponding statistical-models
  //
  for (int32 i = 0; i < unseen_context_map.length(); i++) {

    // local variables
    //
    Vector context;
    Ulong context_index;
    DiGraph sub_graph;
    SingleLinkedList snodes;
    SingleLinkedList sarcs;
    HashTable attr_value;

    // get the context and the context-index
    //
    context = unseen_context_map(i).getContext();
    context_index = unseen_context_map(i).getContextIndex();

    // extract the attributes for this context by looping over all the
    // questions
    //
    for (bool8 morea = questions.gotoFirst(); morea;
         morea = questions.gotoNext()) {

      // local variables
      //
      Pair* question;
      Long direction;
      SearchSymbol symbol;
      String question_string;
      String extended_question;
      String direction_string;
      String* answer;
      String temp_no;
      temp_no.assign(NO);

      // get the direction (this also gives the position) from the
      // question
      //
      question = questions.getCurr();
      direction = question->first();

      // get the search symbol corresponding to the direction and
      // position (offset from the central symbol)
      //
      symbol = context((context.length()/2) + direction);

      // combine direction (position), question and symbol
      //
      direction_string.assign(direction);
      question_string.concat(direction_string);
      question_string.concat(question->second());
      extended_question.assign(question_string);
      extended_question.concat(symbol);

      // get the answer for this extended question from the answers
      // hashtable: if the hashtable contains the key the answer is
      // YES, else the answer is NO
      //
      if (answers.containsKey(extended_question)) {
        answer = answers.get(extended_question);
      }
      else answer = &temp_no;

      // use the question_string and answer to form the attribute-value
      // pair that is added to the attr_value hashtable
      //
      attr_value.insert(question_string, answer);
    }

    // add the central-symbol (central-phone) to the attributes
    //
    String answer;
    SearchSymbol symbol = context((context.length() - 1)/2);
    answer.assign(symbol);
    attr_value.insert(temp_cph, &answer);

    // get the corresponding sub-graph for this context-index
    //
    sub_graph = sub_graphs_a(context_index);

    // loop over all the SearchNodes in this sub-graph, get the
    // symbol ids (states) and then accumulate the statistical-models
    // corresponding to the symbols (states)
    //
    sub_graph.get(snodes, sarcs);

    // initialize the position of the symbol (state)
    //
    int32 position = (int32)0;
    for (bool8 more = snodes.gotoFirst(); more; more = snodes.gotoNext()) {

      // local variables
      //
      // NOTE(review): stat_model is declared but never used below.
      //
      SearchNode* snode;
      int32 symbol_id;
      Long* stat_model_index;
      StatisticalModel stat_model;
      HashTable temp_attr_value;

      // create a DataPoint that corresponds to the search-symbol at
      // this node: a single-Gaussian mixture with zero occupancy
      //
      DataPoint data_point;
      StatisticalModel sm;
      GaussianModel gm;
      MixtureModel mm;
      mm.add(gm);
      sm.assign(mm);
      sm.setOccupancy((Double)0);

      // get the SearchNode
      //
      snode = snodes.getCurr();

      // get the symbol id at this search node
      //
      symbol_id = snode->getSymbolId();

      // proceed only if the symbol is not dummy or default,
      // i.e. the symbol-id is non-negative
      //
      if (symbol_id >= (int32)0) {

        // get the index corresponding to this symbol (state)
        //
        SearchSymbol ss = symbol_table_a(symbol_id);
        stat_model_index = symbol_hash_a.get(ss);

        // add the symbol-position (state-position) to the
        // attributes, then increment the position
        //
        temp_attr_value.assign(attr_value);
        String position_string;
        position_string.assign(position);
        temp_attr_value.insert(temp_pos, &position_string);
        position++;

        // build a datapoint from the statistical-model index, the
        // statistical-model and the attribute-value hash-table
        //
        data_point.assign(*stat_model_index, sm, temp_attr_value);

        // classify the data-point down the tree to find the index of
        // the model it is tied to
        //
        Long n_index = 0;
        n_index = classifyDataPoint(data_point);

        // update the index in the symbol hash-table (remove the stale
        // entry first since insert does not overwrite)
        //
        symbol_hash_a.remove(ss);
        if (!symbol_hash_a.insert(ss, &n_index)) {
          return Error::handle(name(), L"getStatTest", ERR,
                               __FILE__, __LINE__);
        }
      } // end of if statement
    } // end of loop over SearchNodes that contain search-symbol ids
  } // end of loop over unseen context-maps

  // exit gracefully
  //
  return res;
}