Program Listing for File augmentation.hpp¶
↰ Return to documentation for file (/home/jenkins/docs/models/augmentation/augmentation.hpp
)
#ifndef MODELS_AUGMENTATION_AUGMENTATION_HPP
#define MODELS_AUGMENTATION_AUGMENTATION_HPP
#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
#include <mlpack/core/util/to_lower.hpp>
#include <boost/regex.hpp>
namespace mlpack {
namespace models {
class Augmentation
{
public:
Augmentation() :
augmentations(std::vector<std::string>()),
augmentationProbability(0.2)
{
// Nothing to do here.
}
Augmentation(const std::vector<std::string>& augmentations,
const double augmentationProbability) :
augmentations(augmentations),
augmentationProbability(augmentationProbability)
{
// Convert strings to lower case.
for (size_t i = 0; i < augmentations.size(); i++)
this->augmentations[i] = mlpack::util::ToLower(augmentations[i]);
// Sort the vector to place resize parameter to the front of the string.
// This prevents constant lookups for resize.
sort(this->augmentations.begin(), this->augmentations.end(), [](
std::string& str1, std::string& str2)
{
return str1.find("resize") != std::string::npos;
});
}
template<typename DatasetType>
void Transform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth = 1);
template<typename DatasetType>
void ResizeTransform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation);
private:
bool HasResizeParam(const std::string& augmentation = "")
{
if (augmentation.length())
return augmentation.find("resize") != std::string::npos;
// Search in augmentation vector.
return augmentations.size() <= 0 ? false :
augmentations[0].find("resize") != std::string::npos;
}
void GetResizeParam(size_t& outWidth,
size_t& outHeight,
const std::string& augmentation)
{
if (!HasResizeParam())
return;
outWidth = 0;
outHeight = 0;
// Use regex to find one or two numbers. If only one provided
// set output width equal to output height.
boost::regex regex{"[0-9]+"};
// Create an iterator to find matches.
boost::sregex_token_iterator matches(augmentation.begin(),
augmentation.end(), regex, 0), end;
size_t matchesCount = std::distance(matches, end);
if (matchesCount == 0)
{
mlpack::Log::Fatal << "Invalid size / shape in " <<
augmentation << std::endl;
}
if (matchesCount == 1)
{
outWidth = std::stoi(*matches);
outHeight = outWidth;
}
else
{
outWidth = std::stoi(*matches);
matches++;
outHeight = std::stoi(*matches);
}
}
std::vector<std::string> augmentations;
double augmentationProbability;
// The dataloader class should have access to internal functions of
// the augmentation class.
template<typename DatasetX, typename DatasetY, class ScalerType>
friend class DataLoader;
};
} // namespace models
} // namespace mlpack
#include "augmentation_impl.hpp" // Include implementation.
#endif