// ToPS
/*
 *       ContextTree.cpp
 *
 *       Copyright 2011 Andre Yoshiaki Kashiwabara <akashiwabara@usp.br>
 *                      Ígor Bonádio <ibonadio@ime.usp.br>
 *                      Vitor Onuchic <vitoronuchic@gmail.com>
 *                      Alan Mitchell Durham <aland@usp.br>
 *
 *       This program is free software; you can redistribute it and/or modify
 *       it under the terms of the GNU General Public License as published by
 *       the Free Software Foundation; either version 3 of the License, or
 *       (at your option) any later version.
 *
 *       This program is distributed in the hope that it will be useful,
 *       but WITHOUT ANY WARRANTY; without even the implied warranty of
 *       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *       GNU General Public License for more details.
 *
 *       You should have received a copy of the GNU General Public License
 *       along with this program; if not, write to the Free Software
 *       Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 *       MA 02110-1301, USA.
00023 */ 00024 00025 #include "ContextTree.hpp" 00026 #include "util.hpp" 00027 #include <vector> 00028 #include "Symbol.hpp" 00029 00030 namespace tops 00031 { 00032 void ContextTree::buildParameters(ContextTreeNodePtr node, std::map<std::string, double> & parameters) const 00033 { 00034 ContextTreeNodePtr current = node; 00035 AlphabetPtr alphabet = current->getDistribution()->alphabet(); 00036 std::vector <std::string> aux; 00037 bool root=true; 00038 if(current != getRoot() ) 00039 root = false; 00040 while(current != getRoot()) 00041 { 00042 aux.push_back(alphabet->getSymbol(current->symbol())->name() ); 00043 current = _all_context[current->getParent()]; 00044 } 00045 for(int k = 0; k < (int)alphabet->size(); k++) 00046 { 00047 std::stringstream out; 00048 out << alphabet->getSymbol(k)->name() << "|" ; 00049 if(!root) 00050 { 00051 out << aux[aux.size() -1]; 00052 for(int i = aux.size()-2; i >=0; i--) 00053 out << " " << aux[i] ; 00054 } 00055 parameters[out.str()] = exp(node->getDistribution()->log_probability_of(k)); 00056 } 00057 if(!node->isLeaf()) 00058 for(int l = 0; l < node->alphabet_size(); l++) 00059 if(node->getChild(l) != NULL) 00060 buildParameters(node->getChild(l), parameters); 00061 00062 } 00063 DoubleMapParameterValuePtr ContextTree::getParameterValue () const 00064 { 00065 std::map<std::string, double> probabilities; 00066 buildParameters(getRoot(), probabilities); 00067 DoubleMapParameterValuePtr v = DoubleMapParameterValuePtr(new DoubleMapParameterValue(probabilities)); 00068 return v; 00069 } 00070 ContextTreeNode::ContextTreeNode(int alphabet_size) 00071 { 00072 _child.resize(alphabet_size); 00073 _counter.resize(alphabet_size); 00074 for(int i = 0; i < (int)_counter.size(); i++) 00075 _counter[i] = 0; 00076 _symbol = -1; 00077 _alphabet_size = alphabet_size; 00078 _leaf = true; 00079 _id = 0; 00080 } 00081 00082 ContextTreeNode::ContextTreeNode() { 00083 _symbol = -1; 00084 _alphabet_size = 0; 00085 _leaf = true; 00086 } 00087 00088 
void ContextTreeNode::addCount (int s) { 00089 _counter[s] += 1.0; 00090 } 00091 00092 void ContextTreeNode::addCount (int s, double weight) { 00093 _counter[s] += weight; 00094 } 00095 00096 00097 void ContextTreeNode::setCount (int s, double v) { 00098 _counter[s] = v; 00099 } 00100 00101 std::vector<double> & ContextTreeNode::getCounter () { 00102 return _counter; 00103 } 00104 00105 int ContextTreeNode::alphabet_size() { 00106 return _alphabet_size; 00107 } 00108 00109 void ContextTreeNode::setParent(int parent) 00110 { 00111 _parent_id = parent; 00112 } 00113 int ContextTreeNode::getParent() 00114 { 00115 return _parent_id; 00116 } 00117 int ContextTreeNode::id () { 00118 return _id; 00119 } 00120 void ContextTreeNode::setId(int id) 00121 { 00122 _id = id; 00123 } 00124 00125 void ContextTreeNode::setChild(ContextTreeNodePtr child, int symbol){ 00126 assert((symbol >= 0) && (symbol < (int)_child.size())); 00127 _child[symbol] = child; 00128 child->setSymbol(symbol); 00129 child->setParent(id()); 00130 _leaf = false; 00131 } 00132 00133 int ContextTreeNode::symbol(){ 00134 return _symbol; 00135 } 00136 00137 void ContextTreeNode::setSymbol(int symbol){ 00138 _symbol = symbol; 00139 } 00140 00141 void ContextTreeNode::setDistribution(DiscreteIIDModelPtr distribution){ 00142 _distribution = distribution; 00143 } 00144 ContextTreeNodePtr ContextTreeNode::getChild(int symbol){ 00145 if(!((symbol >= 0) && (symbol < (int)_child.size()))) 00146 { 00147 std::cerr << "ERROR: ContextTree has reached an invalid node !" 
<< std::endl; 00148 std::cerr << "Symbol id : " << symbol << std::endl; 00149 exit(-1); 00150 } 00151 return _child[symbol] ; 00152 } 00153 DiscreteIIDModelPtr ContextTreeNode::getDistribution(){ 00154 return _distribution; 00155 } 00156 void ContextTreeNode::deleteChildren() { 00157 ContextTreeNodePtr n; 00158 for(int m = 0; m < (int)_child.size(); m++) 00159 if(_child[m] != NULL) 00160 _child[m]->setParent(-1); 00161 _child.resize(0); 00162 _child.resize(_alphabet_size); 00163 _leaf = true; 00164 } 00165 std::vector <ContextTreeNodePtr> ContextTreeNode::getChildren() { 00166 return _child; 00167 } 00168 00169 00170 00171 bool ContextTreeNode::isLeaf(){ 00172 return _leaf; 00173 } 00174 00175 std::string ContextTreeNode::str() const { 00176 std::stringstream out; 00177 return out.str(); 00178 } 00179 00180 00181 00182 void ContextTree::printTree(ContextTreeNodePtr node, std::stringstream & out) const 00183 { 00184 ContextTreeNodePtr current = node; 00185 while(current != getRoot()) 00186 { 00187 out << current->symbol() << " "; 00188 current = _all_context[current->getParent()]; 00189 } 00190 int sum = 0; 00191 out << ": "; 00192 for(int k = 0; k < (int)_alphabet->size(); k++) 00193 { 00194 sum += (node->getCounter())[k]; 00195 out << (node->getCounter())[k] << " " ; 00196 } 00197 00198 out << "("<< node->id() << ", "<< sum << ") "; 00199 if(!node->isLeaf()) 00200 out << "internal node" << std::endl; 00201 else 00202 out << "leaf node" << std::endl; 00203 if(!node->isLeaf()) 00204 for(int l = 0; l < node->alphabet_size(); l++) 00205 if(node->getChild(l) != NULL) 00206 printTree(node->getChild(l), out); 00207 } 00208 00209 00210 00211 00212 ContextTree::ContextTree(AlphabetPtr alphabet){ 00213 _alphabet = alphabet; 00214 } 00215 00216 ContextTreeNodePtr ContextTree::getRoot() const { 00217 return _all_context[0]; 00218 } 00219 00220 00221 ContextTreeNodePtr ContextTree::createContext() { 00222 ContextTreeNodePtr n = ContextTreeNodePtr(new 
ContextTreeNode(_alphabet->size())); 00223 n->setId(_all_context.size()); 00224 if(n->id() == 0) 00225 { 00226 n->setParent(0); 00227 } 00228 _all_context.push_back(n); 00229 return n; 00230 } 00231 00232 ContextTreeNodePtr ContextTree::getContext (int id) 00233 { 00234 return _all_context[id]; 00235 } 00236 00237 00239 ContextTreeNodePtr ContextTree::getContext(const Sequence & s, int i){ 00240 ContextTreeNodePtr c = _all_context[0]; 00241 ContextTreeNodePtr p; 00242 int j; 00243 for(j = i-1; j >=0; j--){ 00244 if(c->isLeaf()) 00245 break; 00246 p = c; 00247 c = c->getChild(s[j]); 00248 if(c == NULL) 00249 { 00250 c = p; 00251 break; 00252 } 00253 } 00254 #if 0 00255 if(c == NULL) 00256 { 00257 std::cerr << "WARNING: You have reached an undefined context !"<< std::endl; 00258 std::cerr << "WARNING: Probability distribution for the following context was not defined: " << std::endl; 00259 for(int k = j; k < i; k++) 00260 std::cerr << s[k] << " " ; 00261 std::cerr << std::endl; 00262 std::cerr << "Position: " << j << " " << i-1 << std::endl; 00263 } 00264 #endif 00265 return c; 00266 } 00267 00268 std::set <int> ContextTree::getLevelOneNodes() 00269 { 00270 std::set<int> result; 00271 for(int i = 0; i < (int)_all_context.size(); i++) 00272 if(_all_context[i]->isLeaf()) 00273 { 00274 int parent_id = _all_context[i]->getParent(); 00275 if (parent_id < 0) 00276 continue; 00277 ContextTreeNodePtr parent = _all_context[parent_id]; 00278 bool levelOne = true; 00279 for(int l =0; l < (int)_alphabet->size(); l++) 00280 if((parent->getChild(l) != NULL )&& !parent->getChild(l)->isLeaf()) 00281 levelOne = false; 00282 if(levelOne) 00283 result.insert(parent->id()); 00284 } 00285 return result; 00286 } 00287 00288 00289 void ContextTree::removeContextNotUsed() 00290 { 00291 std::vector <ContextTreeNodePtr> newAllVector; 00292 for(int i = 0; i < (int)_all_context.size(); i++) 00293 { 00294 if(_all_context[i]->getParent() >= 0) 00295 { 00296 
_all_context[i]->setId(newAllVector.size()); 00297 newAllVector.push_back(_all_context[i]); 00298 if((_all_context[i] != NULL) && (!_all_context[i]->isLeaf())) 00299 for(int m = 0; m < (int)_alphabet->size(); m++) 00300 _all_context[i]->getChild(m)->setParent(_all_context[i]->id()); 00301 } 00302 } 00303 _all_context = newAllVector; 00304 } 00305 00306 void ContextTree::normalize() 00307 { 00308 std::vector <ContextTreeNodePtr> newAllVector; 00309 for(int i = 0; i < (int)_all_context.size(); i++) 00310 { 00311 double total = 0; 00312 DoubleVector probs(_alphabet->size()); 00313 for(int l = 0; l < (int)_alphabet->size(); l++) 00314 total += (double)(_all_context[i]->getCounter())[l]; 00315 for(int l = 0; l < (int)_alphabet->size(); l++){ 00316 probs[l] = (double)((_all_context[i]->getCounter())[l])/total; 00317 } 00318 DiscreteIIDModelPtr distr = DiscreteIIDModelPtr(new DiscreteIIDModel(probs)); 00319 distr->setAlphabet(_alphabet); 00320 _all_context[i]->setDistribution(distr); 00321 } 00322 } 00323 00324 void ContextTree::normalize(ProbabilisticModelPtr old, double pseudocount, int t) 00325 { 00326 if(old == NULL){ 00327 std::cerr << "ERROR: ContextTree -> a priori model is null !" 
<< std::endl; 00328 exit(-1); 00329 } 00330 std::vector <ContextTreeNodePtr> newAllVector; 00331 for(int i = 0; i < (int)_all_context.size(); i++) 00332 { 00333 double total = 0; 00334 DoubleVector probs(_alphabet->size()); 00335 00336 Sequence s; 00337 ContextTreeNodePtr current = _all_context[i]; 00338 bool valid = true; 00339 while (current != getRoot() ) { 00340 s.push_back (current->symbol()); 00341 if(current->getParent() < 0) { 00342 valid = false; 00343 break; 00344 } 00345 current = _all_context[current->getParent()]; 00346 } 00347 if(!valid) 00348 continue; 00349 00350 for(int l = 0; l < (int)_alphabet->size(); l++) { 00351 total += (double)(_all_context[i]->getCounter())[l]; 00352 } 00353 for(int l = 0; l < (int)_alphabet->size(); l++){ 00354 Sequence s3; 00355 s3 = s; 00356 s3.push_back(l); 00357 00358 double prob = exp(old->evaluatePosition(s3,s3.size()-1, t)); 00359 00360 probs[l] = (double)((_all_context[i]->getCounter())[l] + pseudocount*prob)/(total + pseudocount); 00361 } 00362 00363 DiscreteIIDModelPtr distr = DiscreteIIDModelPtr(new DiscreteIIDModel(probs)); 00364 distr->setAlphabet(_alphabet); 00365 _all_context[i]->setDistribution(distr); 00366 } 00367 } 00368 00369 std::string ContextTree::str() const{ 00370 std::stringstream out; 00371 std::vector <ContextTreeNodePtr> stack; 00372 ContextTreeNodePtr current = getRoot(); 00373 printTree(current, out); 00374 return out.str(); 00375 } 00376 00377 void ContextTree::initializeCounter(const SequenceEntryList & sequences, int order, const std::map<std::string, double> & weights) 00378 { 00379 initializeCounter(sequences, order,0, weights); 00380 } 00381 00382 void ContextTree::initializeCounter(const SequenceEntryList & sequences, int order, double pseudocounts, const std::map<std::string, double> & weights) 00383 { 00384 if (order < 0) order = 0; 00385 00386 ContextTreeNodePtr root = createContext(); 00387 if(pseudocounts > 0) { 00388 for(int sym = 0; sym < root->alphabet_size(); sym++) 00389 { 
00390 root->setCount(sym,pseudocounts); 00391 } 00392 } 00393 00394 for ( int l = 0; l < (int)sequences.size(); l ++){ 00395 std::string seqname = sequences[l]->getName(); 00396 double weight = 1.0; 00397 if (weights.find(seqname) != weights.end()) 00398 weight = (weights.find(seqname)->second); 00399 // std::cerr << seqname << " with weight " << weight << std::endl; 00400 for( int i = order; i < (int)(sequences[l]->getSequence()).size(); i++) 00401 { 00402 int currentSymbol = (sequences[l]->getSequence())[i]; 00403 int j = i - 1; 00404 00405 ContextTreeNodePtr w = getRoot(); 00406 00407 w->addCount(currentSymbol, weight); 00408 00409 while((j >= 0) && ((i - j) <= order)) 00410 { 00411 int symbol = (sequences[l]->getSequence())[j]; 00412 if((w->getChild(symbol) == NULL) || w->isLeaf()) 00413 { 00414 ContextTreeNodePtr c2 = createContext(); 00415 w->setChild(c2, symbol); 00416 } 00417 w = w->getChild(symbol); 00418 00419 if(pseudocounts > 0) { 00420 for(int sym = 0; sym < root->alphabet_size(); sym++) 00421 { 00422 if(w->getCounter()[sym] <= 0.0) 00423 w->setCount(sym,pseudocounts); 00424 } 00425 } 00426 00427 00428 w->addCount(currentSymbol, weight); 00429 j -- ; 00430 } 00431 } 00432 } 00433 } 00434 00435 void ContextTree::pruneTreeSmallSampleSize(int small_) 00436 { 00437 00438 std::set<int> x = getLevelOneNodes(); 00439 std::vector<int> nodesToPrune (x.begin(),x.end()); 00440 std::set<int>::iterator it; 00441 00442 while(nodesToPrune.size() > 0) 00443 { 00444 int id = nodesToPrune.back(); 00445 nodesToPrune.pop_back(); 00446 double total = 0.0; 00447 ContextTreeNodePtr parentNode = getContext(id); 00448 if(parentNode->isLeaf()) 00449 break; 00450 00451 for(int m = 0; m < (int)_alphabet->size(); m++) 00452 total += (parentNode->getCounter())[m]; 00453 bool foundSmall = false; 00454 for (int l = 0; l < (int)_alphabet->size(); l++) 00455 { 00456 ContextTreeNodePtr childNode = parentNode->getChild(l); 00457 if(childNode == NULL) 00458 continue; 00459 double 
totalchild = 0; 00460 for(int m = 0; m < (int)_alphabet->size(); m++) 00461 { 00462 totalchild += (childNode->getCounter())[m]; 00463 } 00464 if(totalchild < small_){ 00465 for(int m = 0; m < (int)_alphabet->size(); m++) 00466 { 00467 (childNode->getCounter())[m] = (parentNode->getCounter())[m]; 00468 } 00469 } 00470 } 00471 if(total < small_) 00472 { 00473 parentNode->deleteChildren(); 00474 ContextTreeNodePtr parentNode2 = getContext(parentNode->getParent()); 00475 bool toPrune = true; 00476 for(int l = 0; l < (int)_alphabet->size(); l++) 00477 if((parentNode2->getChild(l) != NULL) && !(parentNode2->getChild(l)->isLeaf())) 00478 { 00479 toPrune = false; 00480 break; 00481 } 00482 if(toPrune) 00483 nodesToPrune.push_back(parentNode2->id()); 00484 } 00485 } 00486 } 00487 00488 00489 00490 void ContextTree::pruneTree(double delta) 00491 { 00492 00493 double sample_size = 0.0; 00494 for (int l = 0; l < (int)_alphabet->size(); l++) 00495 sample_size += (getRoot()->getCounter())[l]; 00496 std::set<int> x = getLevelOneNodes(); 00497 std::vector<int> nodesToPrune (x.begin(),x.end()); 00498 std::set<int>::iterator it; 00499 double small_ = ((double)_alphabet->size())*log(sample_size); 00500 00501 while(nodesToPrune.size() > 0) 00502 { 00503 int id = nodesToPrune.back(); 00504 nodesToPrune.pop_back(); 00505 double total = 0.0; 00506 double total_diff = 0.0; 00507 ContextTreeNodePtr parentNode = getContext(id); 00508 if(parentNode->isLeaf()) 00509 break; 00510 00511 for(int m = 0; m < (int)_alphabet->size(); m++) 00512 total += (parentNode->getCounter())[m]; 00513 bool foundSmall = false; 00514 for (int l = 0; l < (int)_alphabet->size(); l++) 00515 { 00516 ContextTreeNodePtr childNode = parentNode->getChild(l); 00517 double totalChild = 0.0; 00518 for(int m = 0; m < (int)_alphabet->size(); m++) 00519 totalChild+= (childNode->getCounter())[m]; 00520 for(int m = 0; m < (int)_alphabet->size(); m++) 00521 { 00522 double diff = (double)(parentNode->getCounter())[m] / total; 
00523 diff -= (double)(childNode->getCounter())[m] /totalChild; 00524 assert(childNode->isLeaf()); 00525 if((double)(childNode->getCounter())[m] < small_) 00526 { 00527 foundSmall = true; 00528 break; 00529 } 00530 if(diff < 0) 00531 total_diff -= diff; 00532 else 00533 total_diff += diff; 00534 00535 } 00536 if(foundSmall) 00537 break; 00538 } 00539 if((total < small_) || 00540 (total_diff <delta) || 00541 (foundSmall==true)) 00542 { 00543 parentNode->deleteChildren(); 00544 ContextTreeNodePtr parentNode2 = getContext(parentNode->getParent()); 00545 bool toPrune = true; 00546 for(int l = 0; l < (int)_alphabet->size(); l++) 00547 if((parentNode2->getChild(l) != NULL) && !(parentNode2->getChild(l)->isLeaf())) 00548 { 00549 toPrune = false; 00550 break; 00551 } 00552 if(toPrune) 00553 nodesToPrune.push_back(parentNode2->id()); 00554 } 00555 } 00556 } 00557 00558 00559 00560 00561 00562 void ContextTree::initializeContextTreeRissanen(const SequenceEntryList & sequences) 00563 { 00564 ContextTreeNodePtr root = createContext(); 00565 for(int i = 0; i < (int)_alphabet->size(); i++) 00566 root->addCount(i); 00567 00568 for(int s = 0; s < (int)sequences.size(); s++) 00569 { 00570 for(int i = 0; i < (int)(sequences[s]->getSequence()).size(); i++) 00571 { 00572 int v = (sequences[s]->getSequence())[i]; 00573 ContextTreeNodePtr w = root; 00574 if((!w->isLeaf()) && ((w->getCounter())[v] == 1.0)) 00575 { 00576 for(int l = 0; l < (int)_alphabet->size(); l++) 00577 { 00578 ContextTreeNodePtr n = w->getChild(l); 00579 n->addCount(v); 00580 } 00581 w->addCount(v); 00582 continue; 00583 } 00584 if(w->isLeaf() && ((w->getCounter())[v] == 1.0)) 00585 { 00586 for(int l = 0; l < (int)_alphabet->size(); l++) 00587 { 00588 ContextTreeNodePtr n = createContext(); 00589 n->addCount(v); 00590 w->setChild(n, l); 00591 } 00592 w->addCount(v); 00593 continue; 00594 } 00595 int j = i - 1; 00596 if(j < 0) 00597 w->addCount(v); 00598 while(j >= 0) 00599 { 00600 int u = 
(sequences[s]->getSequence())[j]; 00601 w->addCount(v); 00602 w = w->getChild(u); 00603 if((!w->isLeaf()) && (w->getCounter())[v] == 1.0) 00604 { 00605 for(int l = 0; l < (int)_alphabet->size(); l++) 00606 { 00607 ContextTreeNodePtr n = w->getChild(l); 00608 n->addCount(v); 00609 } 00610 w->addCount(v); 00611 break; 00612 } 00613 if(w->isLeaf() && ((w->getCounter())[v] == 1.0) ) 00614 { 00615 for(int l = 0; l < (int)_alphabet->size(); l++) 00616 { 00617 ContextTreeNodePtr n = createContext(); 00618 n->addCount(v); 00619 w->setChild(n,l); 00620 } 00621 w->addCount(v); 00622 break; 00623 } 00624 j = j-1; 00625 } 00626 } 00627 } 00628 } 00629 00630 00631 00632 }