Program Listing for File augmentation_impl.hpp¶
↰ Return to documentation for file (/home/jenkins/docs/models/augmentation/augmentation_impl.hpp
)
#ifndef MODELS_AUGMENTATION_AUGMENTATION_IMPL_HPP
#define MODELS_AUGMENTATION_AUGMENTATION_IMPL_HPP
// Incase it has not been included already.
#include "augmentation.hpp"
namespace mlpack {
namespace models {
template<typename DatasetType>
void Augmentation::Transform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth)
{
// Initialize the augmentation map.
std::unordered_map<std::string, void(*)(DatasetType&,
size_t, size_t, size_t, std::string&)> augmentationMap;
for (size_t i = 0; i < augmentations.size(); i++)
{
if (augmentationMap.count(augmentations[i]))
{
augmentationMap[augmentations[i]](dataset, datapointWidth,
datapointHeight, datapointDepth, augmentations[i]);
}
else if (this->HasResizeParam(augmentations[i]))
{
this->ResizeTransform(dataset, datapointWidth, datapointHeight,
datapointDepth, augmentations[i]);
}
else
{
mlpack::Log::Warn << "Unknown augmentation : \'" <<
augmentations[i] << "\' not found!" << std::endl;
}
}
}
template<typename DatasetType>
void Augmentation::ResizeTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
size_t outputWidth = 0, outputHeight = 0;
// Get output width and output height.
GetResizeParam(outputWidth, outputHeight, augmentation);
// We will use mlpack's bilinear interpolation layer to
// resize the input.
mlpack::ann::BilinearInterpolation<DatasetType, DatasetType> resizeLayer(
datapointWidth, datapointHeight, outputWidth, outputHeight,
datapointDepth);
DatasetType output;
resizeLayer.Forward(dataset, output);
dataset = std::move(output);
}
} // namespace models
} // namespace mlpack
#endif