Using config files

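Modify the configurations in .json config files, then run train.py with your config passed via -c (the same flag used in the multi-GPU example below):

```bash
python train.py -c config.json
```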
Resuming from checkpoints

You can resume training from a previously saved checkpoint:

```bash
python train.py --resume path/to/checkpoint
```

Using Multiple GPUs

You can enable multi-GPU training by setting the n_gpu argument in the config file to a larger number.
If configured to use fewer GPUs than are available, the first n devices will be used by default.
Specify the indices of available GPUs with the CUDA_VISIBLE_DEVICES environment variable; the --device option shown below is a shortcut for this.

```bash
python train.py --device 2,3 -c config.json
```

This is equivalent to:

```bash
CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json
```

Customization

Data Loader

Writing your own data loader

Inherit BaseDataLoader

BaseDataLoader is a subclass of torch.utils.data.DataLoader; you can use either of them.

Please refer to data_loader/data_loaders.py for an MNIST data loading example.
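A minimal sketch of such a loader, assuming BaseDataLoader's constructor takes (dataset, batch_size, shuffle, validation_split, num_workers) as in data_loader/data_loaders.py; the class name, dataset, and transform below are illustrative, not prescribed by the template:

```python
from torchvision import datasets, transforms
from base import BaseDataLoader

class MyDataLoader(BaseDataLoader):
    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1):
        # illustrative dataset and transform; replace with your own
        trsfm = transforms.Compose([transforms.ToTensor()])
        self.dataset = datasets.MNIST(data_dir, train=True, download=True, transform=trsfm)
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
```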

Trainer

Writing your own trainer

Inherit BaseTrainer

BaseTrainer handles:

Training process logging

Checkpoint saving

Checkpoint resuming

Reconfigurable performance monitoring for saving the current best model and early stopping.

If monitor in the config is set to max val_accuracy, the trainer will save a checkpoint model_best.pth whenever an epoch's validation accuracy exceeds the current maximum.

If early_stop is set in the config, training will be terminated automatically when model performance does not improve for the given number of epochs. This feature can be turned off by passing 0 to the early_stop option, or by deleting that line from the config.
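For example, both options might be combined in the trainer section of the config like this (a sketch; the keys monitor and early_stop come from above, the surrounding layout and values are assumptions):

```json
"trainer": {
    "monitor": "max val_accuracy",
    "early_stop": 10
}
```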

Implementing abstract methods

You need to implement _train_epoch() for your training process; if you also need validation, implement _valid_epoch() as in trainer/trainer.py.
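A minimal skeleton, assuming BaseTrainer sets up attributes such as self.model, self.optimizer, self.loss, self.data_loader, and self.device as in trainer/trainer.py:

```python
from base import BaseTrainer

class MyTrainer(BaseTrainer):
    def _train_epoch(self, epoch):
        """Train one epoch and return a dict of values to log."""
        self.model.train()
        total_loss = 0.0
        for data, target in self.data_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return {'loss': total_loss / len(self.data_loader)}
```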

Example

Please refer to trainer/trainer.py for MNIST training.

Model

Writing your own model

Inherit BaseModel

BaseModel handles:

Inherited from torch.nn.Module

__str__: modifies the native print function to also print the number of trainable parameters.

Implementing abstract methods

Implement the forward pass method forward().
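A minimal sketch (the architecture below is illustrative, not the template's LeNet):

```python
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel

class MyModel(BaseModel):
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # flatten the image batch, then apply two fully connected layers
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        return F.log_softmax(self.fc2(x), dim=1)
```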

Example

Please refer to model/model.py for a LeNet example.

Loss

Custom loss functions can be implemented in model/loss.py. Use them by changing the name given in the "loss" field of the config file to the corresponding function name.
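For instance, a sketch of a custom loss (the function name my_loss is hypothetical; any callable taking (output, target) fits this pattern):

```python
import torch.nn.functional as F

def my_loss(output, target):
    # wraps PyTorch's negative log likelihood loss
    return F.nll_loss(output, target)
```

Then set "loss": "my_loss" in the config file.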

Metrics

Metric functions are located in model/metric.py.

You can monitor multiple metrics by providing a list in the configuration file, e.g.:

```json
"metrics": ["my_metric", "my_metric2"],
```

Additional logging

If you have additional information to be logged, merge it with log in _train_epoch() of your trainer class before returning, as shown below.
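A sketch (the extra key and the grad_norm value are placeholders for whatever you computed):

```python
# inside _train_epoch(), just before returning
additional_log = {'gradient_norm': grad_norm}
log.update(additional_log)
return log
```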

Testing

You can test a trained model by running test.py, passing the path to the trained checkpoint with the --resume argument.
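For example:

```bash
python test.py --resume path/to/checkpoint
```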

Validation data

To split validation data from a data loader, call BaseDataLoader.split_validation(); it will return a validation data loader whose number of samples is determined by the validation_split ratio specified in your config file.

Note: the split_validation() method will modify the original data loader.
Note: split_validation() will return None if "validation_split" is set to 0.
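A usage sketch (MyDataLoader is the hypothetical loader from the Data Loader section above):

```python
data_loader = MyDataLoader('data/', batch_size=64, validation_split=0.1)
valid_data_loader = data_loader.split_validation()  # None when validation_split is 0
```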

Checkpoints

You can specify the name of the training session in config files:

```json
"name": "MNIST_LeNet",
```

Checkpoints will be saved in save_dir/name/timestamp/checkpoint_epoch_n, with the timestamp in mmdd_HHMMSS format.

TensorboardX Visualization

Run this at the project root:

```bash
tensorboard --logdir saved/runs/
```

The server will open at http://localhost:6006.

By default, the values of the loss and the metrics specified in the config file, input images, and histograms of model parameters will be logged.
If you need more visualizations, use add_scalar('tag', data), add_image('tag', image), etc., in the trainer._train_epoch method.
The add_something() methods in this template are basically wrappers for those of the tensorboardX.SummaryWriter module.

Note: You don't have to specify the current step, since the WriterTensorboardX class defined at logger/visualization.py tracks it for you.
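A sketch of adding custom visualizations inside _train_epoch(), assuming self.writer is the WriterTensorboardX instance provided by the template:

```python
from torchvision.utils import make_grid

# inside trainer._train_epoch(); the writer tracks the current step itself
self.writer.add_scalar('loss', loss.item())
self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
```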

Contributing

Feel free to contribute any kind of function or enhancement. The coding style here follows PEP8.