// Program Listing for File datasets.hpp
// ↰ Return to documentation for file (/home/jenkins/docs/models/dataloader/datasets.hpp)
#ifndef MODELS_DATALOADER_DATASETS_HPP
#define MODELS_DATALOADER_DATASETS_HPP
#include <dataloader/preprocessor.hpp>
namespace mlpack {
namespace models {
/**
 * Plain record holding everything that describes one dataset: where its
 * files are downloaded from, where they are stored, the hashes recorded for
 * them, and the hook used to pre-process the loaded data.
 *
 * Every field has a default, so each constructor only names the fields it
 * actually receives.
 *
 * @tparam DatasetX First matrix type in the PreProcess signature.
 * @tparam DatasetY Second matrix type in the PreProcess signature.
 */
template<
  typename DatasetX = arma::mat,
  typename DatasetY = arma::mat
>
struct DatasetDetails
{
  //! Name identifying the dataset.
  std::string datasetName;
  //! Download locations and hashes used when train and test files are
  //! distributed separately.
  std::string trainDownloadURL;
  std::string testDownloadURL;
  std::string trainHash;
  std::string testHash;
  //! Kind of dataset; "none" until a constructor or caller sets it.
  std::string datasetType = "none";
  //! Paths of the training and testing files.
  std::string trainPath;
  std::string testPath;

  //! Flag, URL, path and hash used when the dataset ships as one archive.
  //! NOTE(review): presumably zipFile == true means datasetPath has to be
  //! extracted before trainPath / testPath exist -- confirm with the loader.
  bool zipFile = false;
  std::string datasetURL;
  std::string datasetPath;
  std::string datasetHash;
  //! Server that the URLs above are relative to.
  std::string serverName = "www.mlpack.org";

  //! Pre-Process functor; left empty by every constructor.
  //! NOTE(review): the five arguments look like (trainX, trainY, validX,
  //! validY, testX) -- confirm against PreProcessor.
  std::function<void(DatasetX&, DatasetY&,
      DatasetX&, DatasetY&, DatasetX&)> PreProcess;

  // The following parameters are for CSVs only.
  size_t startTrainingInputFeatures = 0;
  size_t endTrainingInputFeatures = 0;
  size_t startTrainingPredictionFeatures = 0;
  size_t endTrainingPredictionFeatures = 0;
  size_t startTestingInputFeatures = 0;
  size_t endTestingInputFeatures = 0;
  bool dropHeader = false;

  // The following data members correspond to image classification /
  // detection type of datasets.
  std::string trainingImagesPath;
  std::string testingImagesPath;
  std::string trainingAnnotationPath;
  std::vector<std::string> classes;
  size_t imageWidth = 0;
  size_t imageHeight = 0;
  size_t imageDepth = 0;

  //! Default constructor: every field keeps the default declared above.
  DatasetDetails() = default;

  /**
   * Describe a dataset whose training and testing files are downloaded
   * separately.
   *
   * @param datasetName Name of the dataset.
   * @param trainDownloadURL URL of the training file.
   * @param testDownloadURL URL of the testing file.
   * @param trainHash Hash recorded for the training file.
   * @param testHash Hash recorded for the testing file.
   * @param datasetType Kind of dataset.
   * @param trainPath Path of the training file.
   * @param testPath Path of the testing file.
   */
  DatasetDetails(const std::string& datasetName,
                 const std::string& trainDownloadURL,
                 const std::string& testDownloadURL,
                 const std::string& trainHash,
                 const std::string& testHash,
                 const std::string& datasetType,
                 const std::string& trainPath,
                 const std::string& testPath) :
      datasetName(datasetName),
      trainDownloadURL(trainDownloadURL),
      testDownloadURL(testDownloadURL),
      trainHash(trainHash),
      testHash(testHash),
      datasetType(datasetType),
      trainPath(trainPath),
      testPath(testPath)
  { /* Remaining fields keep their in-class defaults. */ }

  /**
   * Describe a dataset that is downloaded as a single file.
   *
   * @param datasetName Name of the dataset.
   * @param zipFile Value stored in the zipFile flag.
   * @param datasetURL URL of the dataset file.
   * @param datasetPath Path of the dataset file.
   * @param datasetHash Hash recorded for the dataset file.
   * @param datasetType Kind of dataset.
   * @param trainPath Path of the training file (empty by default).
   * @param testPath Path of the testing file (empty by default).
   */
  DatasetDetails(const std::string& datasetName,
                 const bool zipFile,
                 const std::string& datasetURL,
                 const std::string& datasetPath,
                 const std::string& datasetHash,
                 const std::string& datasetType,
                 const std::string& trainPath = "",
                 const std::string& testPath = "") :
      datasetName(datasetName),
      datasetType(datasetType),
      trainPath(trainPath),
      testPath(testPath),
      zipFile(zipFile),
      datasetURL(datasetURL),
      datasetPath(datasetPath),
      datasetHash(datasetHash)
  { /* Remaining fields keep their in-class defaults. */ }
};
/**
 * Collection of static factories, each returning a DatasetDetails record
 * filled in for one dataset.
 *
 * @tparam DatasetX First matrix type forwarded to DatasetDetails / PreProcessor.
 * @tparam DatasetY Second matrix type forwarded to DatasetDetails / PreProcessor.
 */
template<
  typename DatasetX = arma::mat,
  typename DatasetY = arma::mat
>
class Datasets
{
 public:
  //! Details for the MNIST dataset, shipped as an archive of CSV files.
  const static DatasetDetails<DatasetX, DatasetY> MNIST()
  {
    DatasetDetails<DatasetX, DatasetY> details(
        "mnist",
        true,
        "/datasets/mnist.tar.gz",
        "./../data/mnist.tar.gz",
        "33470ca3",
        "csv",
        "./../data/mnist-dataset/mnist_train.csv",
        "./../data/mnist-dataset/mnist_test.csv");

    // -1 converted to size_t is the largest representable value.
    // NOTE(review): presumably the loader reads it as "up to the last
    // column" -- confirm against the CSV loading code.
    const size_t wrappedMinusOne = static_cast<size_t>(-1);

    // Feature ranges for the training CSV.
    details.startTrainingInputFeatures = 1;
    details.endTrainingInputFeatures = wrappedMinusOne;
    details.startTrainingPredictionFeatures = 0;
    details.endTrainingPredictionFeatures = 0;

    // Feature range for the testing CSV.
    details.startTestingInputFeatures = 0;
    details.endTestingInputFeatures = wrappedMinusOne;

    details.dropHeader = true;

    // Hook up the MNIST pre-processing routine.
    details.PreProcess = PreProcessor<DatasetX, DatasetY>::MNIST;

    return details;
  }

  //! Details for the Pascal VOC 2012 detection dataset.
  const static DatasetDetails<DatasetX, DatasetY> VOCDetection()
  {
    DatasetDetails<DatasetX, DatasetY> details(
        "voc-detection",
        true,
        "/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "./../data/VOCtrainval_11-May-2012.tar",
        "504b9278",
        "image-detection");

    // This archive is served from a different host than the default one.
    details.serverName = "http://host.robots.ox.ac.uk";

    details.trainingImagesPath = "./../data/VOCdevkit/VOC2012/JPEGImages/";
    details.trainingAnnotationPath =
        "./../data/VOCdevkit/VOC2012/Annotations/";

    // Class labels of the dataset.
    details.classes = {
        "background", "aeroplane", "bicycle", "bird", "boat", "bottle",
        "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
        "motorbike", "person", "pottedplant", "sheep", "sofa", "train",
        "tvmonitor"};

    details.PreProcess = PreProcessor<DatasetX, DatasetY>::PascalVOC;

    return details;
  }

  //! Details for the CIFAR10 image classification dataset.
  const static DatasetDetails<DatasetX, DatasetY> CIFAR10()
  {
    DatasetDetails<DatasetX, DatasetY> details(
        "cifar10",
        true,
        "/datasets/cifar10.tar.gz",
        "./../data/cifar10.tar.gz",
        "4cd9757b",
        "image-classification");

    details.serverName = "www.mlpack.org";

    details.trainingImagesPath = "./../data/cifar10/train/";
    details.testingImagesPath = "./../data/cifar10/test/";

    details.PreProcess = PreProcessor<DatasetX, DatasetY>::CIFAR10;

    return details;
  }
};
} // namespace models
} // namespace mlpack
#endif