Home » Source Code » ID3

ID3

makz121
2018-07-21 09:36:20
The author
View(s):
Download(s): 0
Point (s): 1 
Category:
C++ / All

Description

#include <cmath>
#include <cstddef>
#include <cstdlib>


#include <algorithm>
#include <array>
#include <fstream>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>


/// Forward declarations.
template<size_t N> class ID3;
template<size_t N> class ID3Train;


/// Number of attributes in each sample.
constexpr size_t N = 4;
/// Tolerance below which an entropy value is treated as zero.
constexpr float EPS = 1e-7f;


/**
 * A single sample with attributes & target class.
 *
 * The type owns no raw resources, so it relies on the compiler-generated
 * copy/move operations (Rule of Zero).
 *
 * @tparam N Number of attributes stored per sample.
 */
template<size_t N>
struct Sample {
  // Array of attributes of the sample.
  std::array<std::string, N> attributes;
  // Class to which the sample belongs.
  std::string clazz;
};


/**
 * ID3 tree.
 *
 * An immutable decision tree over samples with N string attributes.
 * Instances can only be created by ID3Train.
 *
 * @tparam N Number of attributes per sample.
 */
template<size_t N>
class ID3 {
 public:
  /**
   * Walks the tree and returns the predicted class for the given attributes.
   *
   * @param sample Attribute values, in the same order as the training data.
   * @return The class label of the leaf that was reached, or the majority
   *         class of the last inner node when an attribute value was never
   *         seen during training.
   */
  std::string classify(const std::array<std::string, N> &sample) const
  {
    return root_->classify(sample);
  }


  /**
   * Writes the tree to the stream, one line per branch value / leaf, indented
   * by one space per tree level.
   *
   * @param os Stream that receives the textual dump.
   */
  void print(std::ostream &os) const
  {
    root_->print(os, 0);
  }


 private:
  /// Common interface of all tree nodes.
  class Node {
   public:
    virtual ~Node() = default;
    virtual std::string classify(const std::array<std::string, N> &sample) const = 0;
    virtual void print(std::ostream &os, size_t level) const = 0;
  };


  /// Leaf node that always answers with a fixed class.
  class TerminalNode final : public Node {
   public:
    explicit TerminalNode(std::string clazz)
      : clazz_(std::move(clazz))
    {
    }


    std::string classify(const std::array<std::string, N> &) const override {
      return clazz_;
    }


    void print(std::ostream &os, size_t level) const override {
      os << std::string(level, ' ') << clazz_ << '\n';
    }


   private:
    std::string clazz_;
  };


  /// Inner node that makes a decision.
  class InnerNode final : public Node {
   public:
    /**
     * @param attribute Index of the attribute this node branches on.
     * @param clazz     Majority class of the training subset that reached this
     *                  node; used when the attribute value has no branch.
     * @param branches  Child node for every attribute value seen in training.
     */
    InnerNode(
        size_t attribute,
        std::string clazz,
        std::map<std::string, std::unique_ptr<Node>> branches)
      : attribute_(attribute)
      , clazz_(std::move(clazz))
      , branches_(std::move(branches))
    {
    }


    std::string classify(const std::array<std::string, N> &sample) const override {
      const auto it = branches_.find(sample[attribute_]);
      if (it == branches_.end()) {
        // Value never seen while training: fall back to the majority class.
        return clazz_;
      }
      return it->second->classify(sample);
    }


    void print(std::ostream &os, size_t level) const override {
      const std::string indent(level, ' ');
      for (const auto &branch : branches_) {
        os << indent << branch.first << ":" << '\n';
        branch.second->print(os, level + 1);
      }
    }


   private:
    size_t attribute_;
    std::string clazz_;
    // Ordered map so that print() output is deterministic.
    std::map<std::string, std::unique_ptr<Node>> branches_;
  };


  explicit ID3(std::unique_ptr<Node> root)
    : root_(std::move(root))
  {
  }


  /// Root node of the decision tree.
  std::unique_ptr<Node> root_;


  /// Only the trainer can construct this class.
  template<size_t> friend class ID3Train;
};


/**
 * ID3 Trainer.
 *
 * Owns the training samples and builds an ID3 tree from them by repeatedly
 * splitting on the attribute with the highest information gain.
 *
 * @tparam N Number of attributes per sample.
 */
template<size_t N>
class ID3Train {
 private:
  // Some shorthand aliases.
  using Tree      = ::ID3<N>;
  using Node      = typename Tree::Node;
  using Leaf      = typename Tree::TerminalNode;
  using Branch    = typename Tree::InnerNode;
  using Row       = ::Sample<N>;
  using Iter      = typename std::vector<Row>::iterator;
  using Histogram = std::unordered_map<std::string, size_t>;


 public:
  /**
   * @param samples Training set; ownership is taken because training
   *                reorders the samples in place.
   */
  explicit ID3Train(std::vector<Row> &&samples)
    : samples_(std::move(samples))
  {
  }


  /**
   * Builds the decision tree from the stored samples.
   *
   * @return The trained tree.
   * @throws std::invalid_argument if there are no samples to train on.
   */
  std::unique_ptr<Tree> train()
  {
    if (samples_.empty()) {
      throw std::invalid_argument("ID3Train: cannot train on an empty sample set");
    }
    // The tree constructor is private (friend-only), so std::make_unique
    // cannot be used here.
    return std::unique_ptr<Tree>(new Tree(buildNode(samples_.begin(), samples_.end())));
  }


 private:
  /// Shannon entropy (in bits) of a class histogram holding `total` items.
  static float entropyOf(const Histogram &histogram, size_t total)
  {
    float entropy = 0.0f;
    for (const auto &entry : histogram) {
      const float p = static_cast<float>(entry.second) / static_cast<float>(total);
      entropy -= p * std::log2(p);
    }
    return entropy;
  }


  /**
   * Recursively builds the subtree for the non-empty range [start, end).
   * The range is re-sorted in place.
   */
  static std::unique_ptr<Node> buildNode(Iter start, Iter end)
  {
    // For each attribute/value pair, compute how many items fall into that
    // category, together with the overall class histogram.
    Histogram clazzes;
    std::array<std::unordered_map<std::string, Histogram>, N> count;
    for (auto it = start; it != end; ++it) {
      for (size_t i = 0; i < N; ++i) {
        ++count[i][it->attributes[i]][it->clazz];
      }
      ++clazzes[it->clazz];
    }


    // Majority class of the current set; ties are resolved towards the
    // lexicographically smallest label so the result is deterministic.
    std::string maxClazz;
    size_t maxCount = 0;
    for (const auto &clazz : clazzes) {
      if (clazz.second > maxCount ||
          (clazz.second == maxCount && clazz.first < maxClazz)) {
        maxClazz = clazz.first;
        maxCount = clazz.second;
      }
    }


    // If set is all classified, return leaf node.
    const auto total = static_cast<size_t>(end - start);
    const float entropy = entropyOf(clazzes, total);
    if (std::fabs(entropy) <= EPS) {
      return std::make_unique<Leaf>(std::move(maxClazz));
    }


    // Compute the information gain on all possible splits. Attributes with a
    // single value in this set are skipped: splitting on them would hand the
    // identical set to the recursive call and never terminate.
    bool found = false;
    size_t best = 0;
    float bestGain = std::numeric_limits<float>::lowest();
    for (size_t i = 0; i < N; ++i) {
      if (count[i].size() < 2) {
        continue;
      }
      float gain = entropy;
      for (const auto &split : count[i]) {
        size_t subsetTotal = 0;
        for (const auto &clazz : split.second) {
          subsetTotal += clazz.second;
        }
        gain -= static_cast<float>(subsetTotal) / static_cast<float>(total)
              * entropyOf(split.second, subsetTotal);
      }
      // `>=` keeps the original tie-break: later attributes win ties.
      if (gain >= bestGain) {
        best = i;
        bestGain = gain;
        found = true;
      }
    }


    // Samples are indistinguishable by attributes but disagree on the class:
    // the best possible answer is the majority class.
    if (!found) {
      return std::make_unique<Leaf>(std::move(maxClazz));
    }


    // Sort the set by the chosen attribute so equal values are contiguous.
    std::sort(start, end, [best] (const Row &a, const Row &b) {
      return a.attributes[best] < b.attributes[best];
    });


    // Split the samples by the attribute and train one child per value.
    std::map<std::string, std::unique_ptr<Node>> nodes;
    for (auto groupStart = start; groupStart != end; ) {
      // Copy the key: the recursive call reorders the rows it is given.
      std::string key = groupStart->attributes[best];
      const auto groupEnd = std::find_if(groupStart, end, [best, &key] (const Row &row) {
        return row.attributes[best] != key;
      });
      nodes.emplace(std::move(key), buildNode(groupStart, groupEnd));
      groupStart = groupEnd;
    }


    return std::make_unique<Branch>(best, std::move(maxClazz), std::move(nodes));
  }


  std::vector<Row> samples_;
};


/**
 * Entry point of the application.
 */
int main(int argc, char **argv)
{
  std::vector<Sample> samples;


  // Read the samples.
  {
    std::ifstream is(argc >= 2 ? argv[1] : "data");
    while (!is.eof()) {
      Sample sample;
      for (auto i = 0; i < N; ++i) {
        if (!(is >> sample.attributes[i])) {
          break;
        }
      }
      if (!(is >> sample.clazz)) {
        break;
      }
      samples.push_back(std::move(sample));
    }
  }


  // Train the ID3.
  auto id3 = ID3Train(std::move(samples)).train();


  // Print the ID3 tree.
  id3->print(std::cout);


  // Classify some samples from stdin.
  {
    std::string line;
    while (line.resize(512), cin.getline(&line[0], line.size())) {
      line.resize(line.find_first_of('\0'));
      if (line.size() == 0) {
        break;
      }
      std::stringstream is(line);
      std::array sample;
      for (auto i = 0; i < N; ++i) {
        if (!(is >> sample[i])) {
          return EXIT_SUCCESS;
        }
      }
      std::cout << id3->classify(sample) << std::endl;
    }
  }


  return EXIT_SUCCESS;
}

Sponsored links

File list

Tip: You can preview a file's contents by clicking its name. ^_^
Name Size Date
01.97 kB
id3.cpp | 7.14 kB | 2015-09-24 15:29
...
Sponsored links

Comments

(Add your comment, get 0.1 Point)
Minimum:15 words, Maximum:160 words
  • 1
  • Page 1
  • Total 1

ID3 (2.62 kB)

Need 1 Point(s)
Your Point (s)

You don't have enough Points.

Get 22 Points immediately via PayPal

Points will be added to your account automatically after the transaction.

More(Debit card / Credit card / PayPal Credit / Online Banking)

Submit your source codes. Get more Points

LOGIN

Don't have an account? Register now
Need any help?
Mail to: support@codeforge.com

切换到中文版?

CodeForge Chinese Version
CodeForge English Version

Where are you going?

^_^"Oops ...

Sorry! This user is mysterious: their blog hasn't been opened yet. Please try another one!
OK

Warm tip!

Add CodeForge to your Favorites with Ctrl+D