ToPS
ContextTree.cpp
00001 /*
00002  *       ContextTree.cpp
00003  *
00004  *       Copyright 2011 Andre Yoshiaki Kashiwabara <akashiwabara@usp.br>
00005  *                      Ígor Bonádio <ibonadio@ime.usp.br>
00006  *                      Vitor Onuchic <vitoronuchic@gmail.com>
00007  *                      Alan Mitchell Durham <aland@usp.br>
00008  *
00009  *       This program is free software; you can redistribute it and/or modify
00010  *       it under the terms of the GNU  General Public License as published by
00011  *       the Free Software Foundation; either version 3 of the License, or
00012  *       (at your option) any later version.
00013  *
00014  *       This program is distributed in the hope that it will be useful,
00015  *       but WITHOUT ANY WARRANTY; without even the implied warranty of
00016  *       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00017  *       GNU General Public License for more details.
00018  *
00019  *       You should have received a copy of the GNU General Public License
00020  *       along with this program; if not, write to the Free Software
00021  *       Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
00022  *       MA 02110-1301, USA.
00023  */
00024 
00025 #include "ContextTree.hpp"
00026 #include "util.hpp"
00027 #include <vector>
00028 #include "Symbol.hpp"
00029 
00030 namespace tops
00031 {
00032   void ContextTree::buildParameters(ContextTreeNodePtr node, std::map<std::string, double> & parameters) const
00033   {
00034     ContextTreeNodePtr current = node;
00035     AlphabetPtr alphabet = current->getDistribution()->alphabet();
00036     std::vector <std::string> aux;
00037     bool root=true;
00038     if(current != getRoot() )
00039       root = false;
00040     while(current != getRoot())
00041       {
00042         aux.push_back(alphabet->getSymbol(current->symbol())->name() );
00043         current = _all_context[current->getParent()];
00044       }
00045     for(int k = 0; k < (int)alphabet->size(); k++)
00046       {
00047         std::stringstream out;
00048         out << alphabet->getSymbol(k)->name() << "|" ;
00049         if(!root)
00050           {
00051             out << aux[aux.size() -1];
00052             for(int i = aux.size()-2; i >=0; i--)
00053               out << " " << aux[i] ;
00054           }
00055         parameters[out.str()] =  exp(node->getDistribution()->log_probability_of(k));
00056       }
00057     if(!node->isLeaf())
00058       for(int l = 0; l < node->alphabet_size(); l++)
00059         if(node->getChild(l) != NULL)
00060           buildParameters(node->getChild(l), parameters);
00061 
00062   }
00063   DoubleMapParameterValuePtr ContextTree::getParameterValue () const
00064   {
00065     std::map<std::string, double> probabilities;
00066     buildParameters(getRoot(), probabilities);
00067     DoubleMapParameterValuePtr v = DoubleMapParameterValuePtr(new DoubleMapParameterValue(probabilities));
00068     return v;
00069   }
00070   ContextTreeNode::ContextTreeNode(int alphabet_size)
00071   {
00072     _child.resize(alphabet_size);
00073     _counter.resize(alphabet_size);
00074     for(int i = 0; i < (int)_counter.size(); i++)
00075       _counter[i] = 0;
00076     _symbol = -1;
00077     _alphabet_size = alphabet_size;
00078     _leaf = true;
00079     _id = 0;
00080   }
00081 
00082   ContextTreeNode::ContextTreeNode() {
00083     _symbol = -1;
00084     _alphabet_size = 0;
00085     _leaf = true;
00086   }
00087 
00088   void ContextTreeNode::addCount (int s) {
00089     _counter[s] += 1.0;
00090   }
00091 
00092   void ContextTreeNode::addCount (int s, double weight) {
00093     _counter[s] += weight;
00094   }
00095 
00096 
00097   void ContextTreeNode::setCount (int s, double v) {
00098     _counter[s] = v;
00099   }
00100 
00101   std::vector<double> & ContextTreeNode::getCounter () {
00102     return _counter;
00103   }
00104 
00105   int ContextTreeNode::alphabet_size() {
00106     return _alphabet_size;
00107   }
00108 
00109   void ContextTreeNode::setParent(int parent)
00110   {
00111     _parent_id = parent;
00112   }
00113   int ContextTreeNode::getParent()
00114   {
00115     return _parent_id;
00116   }
00117   int ContextTreeNode::id () {
00118     return _id;
00119   }
00120   void ContextTreeNode::setId(int id)
00121   {
00122     _id = id;
00123   }
00124 
00125   void ContextTreeNode::setChild(ContextTreeNodePtr child, int symbol){
00126     assert((symbol >= 0) && (symbol < (int)_child.size()));
00127     _child[symbol] = child;
00128     child->setSymbol(symbol);
00129     child->setParent(id());
00130     _leaf = false;
00131   }
00132 
00133   int ContextTreeNode::symbol(){
00134     return _symbol;
00135   }
00136 
00137   void ContextTreeNode::setSymbol(int symbol){
00138     _symbol = symbol;
00139   }
00140 
00141   void ContextTreeNode::setDistribution(DiscreteIIDModelPtr distribution){
00142     _distribution = distribution;
00143   }
00144   ContextTreeNodePtr ContextTreeNode::getChild(int symbol){
00145     if(!((symbol >= 0) && (symbol < (int)_child.size())))
00146       {
00147         std::cerr << "ERROR: ContextTree has reached an invalid node !" << std::endl;
00148         std::cerr << "Symbol id : " << symbol << std::endl;
00149         exit(-1);
00150       }
00151     return _child[symbol] ;
00152   }
00153   DiscreteIIDModelPtr ContextTreeNode::getDistribution(){
00154     return _distribution;
00155   }
00156   void ContextTreeNode::deleteChildren() {
00157     ContextTreeNodePtr n;
00158     for(int m = 0; m < (int)_child.size(); m++)
00159       if(_child[m] != NULL)
00160         _child[m]->setParent(-1);
00161     _child.resize(0);
00162     _child.resize(_alphabet_size);
00163     _leaf = true;
00164   }
00165   std::vector <ContextTreeNodePtr> ContextTreeNode::getChildren() {
00166     return _child;
00167   }
00168 
00169 
00170 
00171   bool ContextTreeNode::isLeaf(){
00172     return _leaf;
00173   }
00174 
00175   std::string ContextTreeNode::str() const {
00176     std::stringstream out;
00177     return out.str();
00178   }
00179 
00180 
00181 
00182   void ContextTree::printTree(ContextTreeNodePtr node, std::stringstream & out) const
00183   {
00184     ContextTreeNodePtr current = node;
00185     while(current != getRoot())
00186       {
00187         out << current->symbol() << " ";
00188         current = _all_context[current->getParent()];
00189       }
00190     int sum = 0;
00191     out << ": ";
00192     for(int k = 0; k < (int)_alphabet->size(); k++)
00193       {
00194         sum += (node->getCounter())[k];
00195         out << (node->getCounter())[k] << " " ;
00196       }
00197 
00198     out << "("<< node->id() << ", "<<  sum << ") ";
00199     if(!node->isLeaf())
00200       out << "internal node" << std::endl;
00201     else
00202       out << "leaf node" << std::endl;
00203     if(!node->isLeaf())
00204       for(int l = 0; l < node->alphabet_size(); l++)
00205         if(node->getChild(l) != NULL)
00206           printTree(node->getChild(l), out);
00207   }
00208 
00209 
00210 
00211 
00212   ContextTree::ContextTree(AlphabetPtr alphabet){
00213     _alphabet = alphabet;
00214   }
00215 
00216   ContextTreeNodePtr ContextTree::getRoot() const {
00217     return _all_context[0];
00218   }
00219 
00220 
00221   ContextTreeNodePtr ContextTree::createContext() {
00222     ContextTreeNodePtr n = ContextTreeNodePtr(new ContextTreeNode(_alphabet->size()));
00223     n->setId(_all_context.size());
00224     if(n->id() == 0)
00225       {
00226         n->setParent(0);
00227       }
00228     _all_context.push_back(n);
00229     return n;
00230   }
00231 
00232   ContextTreeNodePtr ContextTree::getContext (int id)
00233   {
00234     return _all_context[id];
00235   }
00236 
00237 
00239   ContextTreeNodePtr ContextTree::getContext(const Sequence & s, int i){
00240     ContextTreeNodePtr c = _all_context[0];
00241     ContextTreeNodePtr p;
00242     int j;
00243     for(j = i-1; j >=0; j--){
00244       if(c->isLeaf())
00245         break;
00246       p = c;
00247       c = c->getChild(s[j]);
00248       if(c == NULL)
00249         {
00250           c = p;
00251           break;
00252         }
00253     }
00254 #if 0
00255     if(c == NULL)
00256       {
00257         std::cerr << "WARNING: You have reached an undefined context  !"<< std::endl;
00258         std::cerr << "WARNING: Probability distribution for the following context was not defined: " << std::endl;
00259         for(int k = j; k < i; k++)
00260           std::cerr << s[k] << " " ;
00261         std::cerr << std::endl;
00262         std::cerr << "Position: " << j << " " << i-1 << std::endl;
00263       }
00264 #endif
00265     return c;
00266   }
00267 
00268   std::set <int> ContextTree::getLevelOneNodes()
00269   {
00270     std::set<int> result;
00271     for(int i = 0; i  < (int)_all_context.size(); i++)
00272       if(_all_context[i]->isLeaf())
00273         {
00274           int parent_id = _all_context[i]->getParent();
00275           if (parent_id < 0)
00276             continue;
00277           ContextTreeNodePtr parent = _all_context[parent_id];
00278           bool levelOne = true;
00279           for(int l =0; l < (int)_alphabet->size(); l++)
00280             if((parent->getChild(l) != NULL )&& !parent->getChild(l)->isLeaf())
00281               levelOne = false;
00282           if(levelOne)
00283             result.insert(parent->id());
00284         }
00285     return result;
00286   }
00287 
00288 
00289   void ContextTree::removeContextNotUsed()
00290   {
00291     std::vector <ContextTreeNodePtr> newAllVector;
00292     for(int i = 0; i  < (int)_all_context.size(); i++)
00293       {
00294         if(_all_context[i]->getParent() >= 0)
00295           {
00296             _all_context[i]->setId(newAllVector.size());
00297             newAllVector.push_back(_all_context[i]);
00298             if((_all_context[i] != NULL) && (!_all_context[i]->isLeaf()))
00299               for(int  m = 0; m < (int)_alphabet->size(); m++)
00300                 _all_context[i]->getChild(m)->setParent(_all_context[i]->id());
00301           }
00302       }
00303     _all_context = newAllVector;
00304   }
00305 
00306   void ContextTree::normalize()
00307   {
00308     std::vector <ContextTreeNodePtr> newAllVector;
00309     for(int i = 0; i  < (int)_all_context.size(); i++)
00310       {
00311         double total = 0;
00312         DoubleVector probs(_alphabet->size());
00313         for(int l = 0; l < (int)_alphabet->size(); l++)
00314           total += (double)(_all_context[i]->getCounter())[l];
00315         for(int l = 0; l < (int)_alphabet->size(); l++){
00316           probs[l] = (double)((_all_context[i]->getCounter())[l])/total;
00317         }
00318         DiscreteIIDModelPtr distr = DiscreteIIDModelPtr(new DiscreteIIDModel(probs));
00319         distr->setAlphabet(_alphabet);
00320         _all_context[i]->setDistribution(distr);
00321       }
00322   }
00323 
00324     void ContextTree::normalize(ProbabilisticModelPtr old, double pseudocount, int t)
00325     {
00326         if(old == NULL){
00327             std::cerr << "ERROR: ContextTree -> a priori model is null !" << std::endl;
00328             exit(-1);
00329         }
00330     std::vector <ContextTreeNodePtr> newAllVector;
00331     for(int i = 0; i  < (int)_all_context.size(); i++)
00332       {
00333         double total = 0;
00334         DoubleVector probs(_alphabet->size());
00335 
00336         Sequence s;
00337         ContextTreeNodePtr current = _all_context[i];
00338         bool valid = true;
00339         while (current != getRoot() ) {
00340             s.push_back (current->symbol());
00341             if(current->getParent() < 0) {
00342                 valid = false;
00343                 break;
00344             }
00345             current = _all_context[current->getParent()];
00346         }
00347         if(!valid)
00348             continue;
00349 
00350         for(int l = 0; l < (int)_alphabet->size(); l++) {
00351             total += (double)(_all_context[i]->getCounter())[l];
00352         }
00353         for(int l = 0; l < (int)_alphabet->size(); l++){
00354             Sequence s3;
00355             s3 = s;
00356             s3.push_back(l);
00357 
00358             double prob = exp(old->evaluatePosition(s3,s3.size()-1, t));
00359 
00360             probs[l] = (double)((_all_context[i]->getCounter())[l] + pseudocount*prob)/(total + pseudocount);
00361         }
00362 
00363         DiscreteIIDModelPtr distr = DiscreteIIDModelPtr(new DiscreteIIDModel(probs));
00364         distr->setAlphabet(_alphabet);
00365         _all_context[i]->setDistribution(distr);
00366       }
00367     }
00368 
00369   std::string ContextTree::str() const{
00370     std::stringstream out;
00371     std::vector <ContextTreeNodePtr> stack;
00372     ContextTreeNodePtr current = getRoot();
00373     printTree(current, out);
00374     return out.str();
00375   }
00376 
00377   void ContextTree::initializeCounter(const SequenceEntryList & sequences, int order, const std::map<std::string, double> & weights)
00378   {
00379     initializeCounter(sequences, order,0, weights);
00380   }
00381 
00382   void ContextTree::initializeCounter(const SequenceEntryList & sequences, int order, double pseudocounts, const std::map<std::string, double> & weights)
00383   {
00384     if (order < 0) order = 0;
00385 
00386     ContextTreeNodePtr root = createContext();
00387     if(pseudocounts > 0) {
00388       for(int sym = 0; sym < root->alphabet_size(); sym++)
00389         {
00390           root->setCount(sym,pseudocounts);
00391         }
00392     }
00393 
00394     for ( int l = 0; l < (int)sequences.size(); l ++){
00395       std::string seqname = sequences[l]->getName();
00396       double weight = 1.0;
00397       if (weights.find(seqname) != weights.end())
00398         weight = (weights.find(seqname)->second);
00399       // std::cerr << seqname << " with weight " << weight << std::endl;
00400       for( int i = order; i < (int)(sequences[l]->getSequence()).size(); i++)
00401         {
00402           int currentSymbol = (sequences[l]->getSequence())[i];
00403           int j = i - 1;
00404 
00405           ContextTreeNodePtr w = getRoot();
00406 
00407           w->addCount(currentSymbol, weight);
00408 
00409           while((j >= 0) &&  ((i - j) <= order))
00410             {
00411               int symbol = (sequences[l]->getSequence())[j];
00412               if((w->getChild(symbol) == NULL) || w->isLeaf())
00413                 {
00414                   ContextTreeNodePtr c2 = createContext();
00415                   w->setChild(c2, symbol);
00416                 }
00417               w = w->getChild(symbol);
00418 
00419               if(pseudocounts > 0) {
00420                 for(int sym = 0; sym < root->alphabet_size(); sym++)
00421                   {
00422                     if(w->getCounter()[sym] <= 0.0)
00423                       w->setCount(sym,pseudocounts);
00424                   }
00425               }
00426 
00427 
00428               w->addCount(currentSymbol, weight);
00429               j -- ;
00430             }
00431         }
00432     }
00433   }
00434 
00435   void ContextTree::pruneTreeSmallSampleSize(int small_)
00436   {
00437 
00438     std::set<int> x = getLevelOneNodes();
00439     std::vector<int> nodesToPrune (x.begin(),x.end());
00440     std::set<int>::iterator it;
00441 
00442     while(nodesToPrune.size() > 0)
00443       {
00444         int id = nodesToPrune.back();
00445         nodesToPrune.pop_back();
00446         double total = 0.0;
00447         ContextTreeNodePtr parentNode = getContext(id);
00448         if(parentNode->isLeaf())
00449           break;
00450 
00451         for(int m = 0; m < (int)_alphabet->size(); m++)
00452           total += (parentNode->getCounter())[m];
00453         bool foundSmall = false;
00454         for (int l = 0; l < (int)_alphabet->size(); l++)
00455           {
00456             ContextTreeNodePtr childNode = parentNode->getChild(l);
00457             if(childNode == NULL)
00458               continue;
00459             double totalchild = 0;
00460             for(int m = 0; m < (int)_alphabet->size(); m++)
00461               {
00462                 totalchild += (childNode->getCounter())[m];
00463               }
00464             if(totalchild < small_){
00465               for(int m = 0; m < (int)_alphabet->size(); m++)
00466                 {
00467                   (childNode->getCounter())[m] = (parentNode->getCounter())[m];
00468                 }
00469             }
00470           }
00471         if(total < small_)
00472           {
00473             parentNode->deleteChildren();
00474             ContextTreeNodePtr parentNode2 = getContext(parentNode->getParent());
00475             bool toPrune = true;
00476             for(int l = 0; l < (int)_alphabet->size(); l++)
00477               if((parentNode2->getChild(l) != NULL) && !(parentNode2->getChild(l)->isLeaf()))
00478                 {
00479                   toPrune = false;
00480                   break;
00481                 }
00482             if(toPrune)
00483               nodesToPrune.push_back(parentNode2->id());
00484           }
00485       }
00486   }
00487 
00488 
00489 
00490   void ContextTree::pruneTree(double delta)
00491   {
00492 
00493     double sample_size = 0.0;
00494     for (int l = 0; l < (int)_alphabet->size(); l++)
00495       sample_size += (getRoot()->getCounter())[l];
00496     std::set<int> x = getLevelOneNodes();
00497     std::vector<int> nodesToPrune (x.begin(),x.end());
00498     std::set<int>::iterator it;
00499     double small_ = ((double)_alphabet->size())*log(sample_size);
00500 
00501     while(nodesToPrune.size() > 0)
00502       {
00503         int id = nodesToPrune.back();
00504         nodesToPrune.pop_back();
00505         double total = 0.0;
00506         double total_diff = 0.0;
00507         ContextTreeNodePtr parentNode = getContext(id);
00508         if(parentNode->isLeaf())
00509           break;
00510 
00511         for(int m = 0; m < (int)_alphabet->size(); m++)
00512           total += (parentNode->getCounter())[m];
00513         bool foundSmall = false;
00514         for (int l = 0; l < (int)_alphabet->size(); l++)
00515           {
00516             ContextTreeNodePtr childNode = parentNode->getChild(l);
00517             double totalChild = 0.0;
00518             for(int m = 0; m < (int)_alphabet->size(); m++)
00519               totalChild+= (childNode->getCounter())[m];
00520             for(int m = 0; m < (int)_alphabet->size(); m++)
00521               {
00522                 double diff = (double)(parentNode->getCounter())[m] / total;
00523                 diff -= (double)(childNode->getCounter())[m] /totalChild;
00524                 assert(childNode->isLeaf());
00525                 if((double)(childNode->getCounter())[m] < small_)
00526                   {
00527                     foundSmall = true;
00528                     break;
00529                   }
00530                 if(diff < 0)
00531                   total_diff -= diff;
00532                 else
00533                   total_diff += diff;
00534 
00535               }
00536             if(foundSmall)
00537               break;
00538           }
00539         if((total < small_) ||
00540            (total_diff <delta) ||
00541            (foundSmall==true))
00542           {
00543             parentNode->deleteChildren();
00544             ContextTreeNodePtr parentNode2 = getContext(parentNode->getParent());
00545             bool toPrune = true;
00546             for(int l = 0; l < (int)_alphabet->size(); l++)
00547               if((parentNode2->getChild(l) != NULL) && !(parentNode2->getChild(l)->isLeaf()))
00548                 {
00549                   toPrune = false;
00550                   break;
00551                 }
00552             if(toPrune)
00553               nodesToPrune.push_back(parentNode2->id());
00554           }
00555       }
00556   }
00557 
00558 
00559 
00560 
00561 
00562   void ContextTree::initializeContextTreeRissanen(const SequenceEntryList & sequences)
00563   {
00564     ContextTreeNodePtr root = createContext();
00565     for(int i = 0; i < (int)_alphabet->size(); i++)
00566       root->addCount(i);
00567 
00568     for(int s = 0; s < (int)sequences.size(); s++)
00569       {
00570         for(int i = 0; i < (int)(sequences[s]->getSequence()).size(); i++)
00571           {
00572             int v = (sequences[s]->getSequence())[i];
00573             ContextTreeNodePtr w = root;
00574             if((!w->isLeaf()) && ((w->getCounter())[v] == 1.0))
00575               {
00576                 for(int l = 0; l < (int)_alphabet->size(); l++)
00577                   {
00578                     ContextTreeNodePtr n = w->getChild(l);
00579                     n->addCount(v);
00580                   }
00581                 w->addCount(v);
00582                 continue;
00583               }
00584             if(w->isLeaf() && ((w->getCounter())[v] == 1.0))
00585               {
00586                 for(int l = 0; l < (int)_alphabet->size(); l++)
00587                   {
00588                     ContextTreeNodePtr n = createContext();
00589                     n->addCount(v);
00590                     w->setChild(n, l);
00591                   }
00592                 w->addCount(v);
00593                 continue;
00594               }
00595             int j = i - 1;
00596             if(j < 0)
00597               w->addCount(v);
00598             while(j >= 0)
00599               {
00600                 int u = (sequences[s]->getSequence())[j];
00601                 w->addCount(v);
00602                 w = w->getChild(u);
00603                 if((!w->isLeaf()) && (w->getCounter())[v] == 1.0)
00604                   {
00605                     for(int l = 0; l < (int)_alphabet->size(); l++)
00606                       {
00607                         ContextTreeNodePtr n = w->getChild(l);
00608                         n->addCount(v);
00609                       }
00610                     w->addCount(v);
00611                     break;
00612                   }
00613                 if(w->isLeaf() && ((w->getCounter())[v] == 1.0) )
00614                   {
00615                     for(int l = 0; l  < (int)_alphabet->size(); l++)
00616                       {
00617                         ContextTreeNodePtr n = createContext();
00618                         n->addCount(v);
00619                         w->setChild(n,l);
00620                       }
00621                     w->addCount(v);
00622                     break;
00623                   }
00624                 j = j-1;
00625               }
00626           }
00627       }
00628   }
00629 
00630 
00631 
00632 }