#ifndef _SFEN_STREAM_H_ #define _SFEN_STREAM_H_ #include "nnue_training_data_formats.h" #include #include #include #include #include #include namespace training_data { using namespace binpack; static bool ends_with(const std::string& lhs, const std::string& end) { if (end.size() > lhs.size()) return false; return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); } static bool has_extension(const std::string& filename, const std::string& extension) { return ends_with(filename, "." + extension); } static std::string filename_with_extension(const std::string& filename, const std::string& ext) { if (ends_with(filename, ext)) { return filename; } else { return filename + "." + ext; } } struct BasicSfenInputStream { virtual std::optional next() = 0; virtual void fill(std::vector& vec, std::size_t n) { for (std::size_t i = 0; i < n; ++i) { auto v = this->next(); if (!v.has_value()) { break; } vec.emplace_back(*v); } } virtual bool eof() const = 0; virtual ~BasicSfenInputStream() {} }; struct BinSfenInputStream : BasicSfenInputStream { static constexpr auto openmode = std::ios::in | std::ios::binary; static inline const std::string extension = "bin"; BinSfenInputStream(std::string filename, bool cyclic, std::function skipPredicate) : m_stream(filename, openmode), m_filename(filename), m_eof(!m_stream), m_cyclic(cyclic), m_skipPredicate(std::move(skipPredicate)) { } std::optional next() override { nodchip::PackedSfenValue e; bool reopenedFileOnce = false; for(;;) { if(m_stream.read(reinterpret_cast(&e), sizeof(nodchip::PackedSfenValue))) { auto entry = packedSfenValueToTrainingDataEntry(e); if (!m_skipPredicate || !m_skipPredicate(entry)) return entry; } else { if (m_cyclic) { if (reopenedFileOnce) return std::nullopt; m_stream = std::fstream(m_filename, openmode); reopenedFileOnce = true; if (!m_stream) return std::nullopt; continue; } m_eof = true; return std::nullopt; } } } bool eof() const override { return m_eof; } ~BinSfenInputStream() override {} private: std::fstream m_stream; std::string m_filename; bool m_eof; bool m_cyclic; std::function m_skipPredicate; }; struct BinpackSfenInputStream : BasicSfenInputStream { static constexpr auto openmode = std::ios::in | std::ios::binary; static inline const std::string extension = "binpack"; BinpackSfenInputStream(std::string filename, bool cyclic, std::function skipPredicate) : m_stream(std::make_unique(filename, openmode)), m_filename(filename), m_eof(!m_stream->hasNext()), m_cyclic(cyclic), m_skipPredicate(std::move(skipPredicate)) { } std::optional next() override { bool reopenedFileOnce = false; for(;;) { if (!m_stream->hasNext()) { if (m_cyclic) { if (reopenedFileOnce) return std::nullopt; m_stream = std::make_unique(m_filename, openmode); reopenedFileOnce = true; if (!m_stream->hasNext()) return std::nullopt; continue; } m_eof = true; return std::nullopt; } auto e = m_stream->next(); if (!m_skipPredicate || !m_skipPredicate(e)) return e; } } bool eof() const override { return m_eof; } ~BinpackSfenInputStream() override {} private: std::unique_ptr m_stream; std::string m_filename; bool m_eof; bool m_cyclic; std::function m_skipPredicate; }; struct BinpackSfenInputParallelStream : BasicSfenInputStream { static constexpr auto openmode = std::ios::in | std::ios::binary; static inline const std::string extension = "binpack"; BinpackSfenInputParallelStream(int concurrency, const std::vector& filenames, bool cyclic, std::function skipPredicate) : m_stream(std::make_unique(concurrency, filenames, openmode, cyclic, skipPredicate)), m_filenames(filenames), m_concurrency(concurrency), m_eof(false), m_cyclic(cyclic), m_skipPredicate(skipPredicate) { } std::optional next() override { // filtering is done a layer deeper. auto v = m_stream->next(); if (!v.has_value()) { m_eof = true; return std::nullopt; } return v; } void fill(std::vector& v, std::size_t n) override { auto k = m_stream->fill(v, n); if (n != k) { m_eof = true; } } bool eof() const override { return m_eof; } ~BinpackSfenInputParallelStream() override {} private: std::unique_ptr m_stream; std::vector m_filenames; int m_concurrency; bool m_eof; bool m_cyclic; std::function m_skipPredicate; }; inline std::unique_ptr open_sfen_input_file(const std::string& filename, bool cyclic, std::function skipPredicate = nullptr) { if (has_extension(filename, BinSfenInputStream::extension)) return std::make_unique(filename, cyclic, std::move(skipPredicate)); else if (has_extension(filename, BinpackSfenInputStream::extension)) return std::make_unique(filename, cyclic, std::move(skipPredicate)); return nullptr; } inline std::unique_ptr open_sfen_input_file_parallel(int concurrency, const std::vector& filenames, bool cyclic, std::function skipPredicate = nullptr) { // TODO (low priority): optimize and parallelize .bin reading. if (has_extension(filenames[0], BinSfenInputStream::extension)) return std::make_unique(filenames[0], cyclic, std::move(skipPredicate)); else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension)) return std::make_unique(concurrency, filenames, cyclic, std::move(skipPredicate)); return nullptr; } } #endif