
UNet Implementation using PyTorch | Car Segmentation


Input image

Predicted mask

Image and mask overlay

Getting Started

git clone git@github.com:yakhyo/unet-pytorch.git
cd unet-pytorch


The Carvana Image Masking (PNG) dataset is used to train the model. After downloading the data, place it under the ./data directory:

data
├── train_images
│   ├── xxx.jpg
│   ├── xxy.jpg
│   └── xxz.jpg
└── train_masks
    ├── xxx.png
    ├── xxy.png
    └── xxz.png
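
Before training, it can help to check that every image has a mask. The sketch below pairs files by name stem (xxx.jpg with xxx.png), which is an assumption read off the tree above. The helper name `pair_images_and_masks` is hypothetical and not part of this repository.

```python
from pathlib import Path


def pair_images_and_masks(data_dir="./data"):
    """Return (image, mask) path pairs matched by file stem, plus unmatched images."""
    root = Path(data_dir)
    # Index masks by stem so that xxx.png can be looked up for xxx.jpg.
    masks = {p.stem: p for p in (root / "train_masks").glob("*.png")}
    pairs, missing = [], []
    for image in sorted((root / "train_images").glob("*.jpg")):
        mask = masks.get(image.stem)
        if mask is None:
            missing.append(image)  # image with no corresponding mask
        else:
            pairs.append((image, mask))
    return pairs, missing
```

Calling `pairs, missing = pair_images_and_masks("./data")` and checking that `missing` is empty is a quick sanity test of the download.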


Training arguments

usage: train.py [-h] [--data DATA] [--scale SCALE] [--num-classes NUM_CLASSES] [--weights WEIGHTS] [--epochs EPOCHS] [--batch-size BATCH_SIZE] [--num-workers N] [--lr LR] [--weight-decay WEIGHT_DECAY] [--momentum MOMENTUM] [--amp] [--print-freq PRINT_FREQ]
                [--resume RESUME] [--use-deterministic-algorithms] [--save-dir SAVE_DIR]

UNet training arguments

  -h, --help            show this help message and exit
  --data DATA           Directory containing the dataset (default: './data')
  --scale SCALE         Scale factor for input image size (default: 0.5)
  --num-classes NUM_CLASSES
                        Number of output classes (default: 2)
  --weights WEIGHTS     Path to pretrained model weights (default: '')
  --epochs EPOCHS       Number of training epochs (default: 10)
  --batch-size BATCH_SIZE
                        Batch size for training (default: 4)
  --num-workers N       Number of data loading workers (default: 8)
  --lr LR               Learning rate (default: 1e-5)
  --weight-decay WEIGHT_DECAY
                        Weight decay (default: 1e-8)
  --momentum MOMENTUM   Momentum (default: 0.9)
  --amp                 Enable mixed precision training
  --print-freq PRINT_FREQ
                        Frequency of printing training progress (default: 10)
  --resume RESUME       Path to checkpoint to resume training from (default: '')
  --use-deterministic-algorithms
                        Forces the use of deterministic algorithms only.
  --save-dir SAVE_DIR   Directory to save model weights (default: 'weights')
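
For readers unfamiliar with argparse, the following sketch shows how a subset of the flags above could be declared and how they surface as attributes (dashes become underscores). It is not the repository's train.py: the `build_parser` name and the argument types are assumptions inferred from the documented defaults.

```python
import argparse


def build_parser():
    """Build a parser mirroring part of the documented train.py interface."""
    parser = argparse.ArgumentParser(description="UNet training arguments")
    parser.add_argument("--data", type=str, default="./data",
                        help="Directory containing the dataset")
    parser.add_argument("--scale", type=float, default=0.5,
                        help="Scale factor for input image size")
    parser.add_argument("--num-classes", type=int, default=2,
                        help="Number of output classes")
    parser.add_argument("--epochs", type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=4,
                        help="Batch size for training")
    parser.add_argument("--lr", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--amp", action="store_true",
                        help="Enable mixed precision training")
    parser.add_argument("--save-dir", type=str, default="weights",
                        help="Directory to save model weights")
    return parser


# Flags that are not passed keep their documented defaults.
args = build_parser().parse_args(["--epochs", "20", "--amp"])
print(args.epochs, args.batch_size, args.amp)  # prints: 20 4 True
```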

Train the model

python train.py


Inference arguments

usage: inference.py [-h] [--model-path MODEL_PATH] [--image-path IMAGE_PATH] [--scale SCALE] [--save-overlay]

Image Segmentation Inference

  -h, --help            show this help message and exit
  --model-path MODEL_PATH
                        Path to the model weights
  --image-path IMAGE_PATH
                        Path to the input image
  --scale SCALE         Scale factor for resizing the image
  --save-overlay        Save the overlay image if this flag is set
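
To illustrate what a scale factor does to the input size, here is a minimal sketch. It assumes that both dimensions are multiplied by the factor and truncated to integers; the repository may round differently. The function name `scaled_size` is hypothetical.

```python
def scaled_size(width, height, scale=0.5):
    """Return the (width, height) after applying a scale factor.

    Assumption: each side is multiplied by `scale` and truncated to an int.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    new_width, new_height = int(width * scale), int(height * scale)
    if new_width == 0 or new_height == 0:
        raise ValueError("scale is too small for this image size")
    return new_width, new_height


# A 1918x1280 image at the default scale of 0.5.
print(scaled_size(1918, 1280))  # prints: (959, 640)
```

Under this assumption, a smaller scale reduces memory use and speeds up inference at the cost of coarser mask boundaries.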


python inference.py --model-path weights/last.pt --image-path assets/image.jpg
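
The overlay saved by --save-overlay blends the predicted mask onto the input image. The sketch below shows one common way to do this with NumPy alpha blending. It is not taken from inference.py: the `overlay_mask` name, the green tint, and the 0.5 opacity are all assumptions made for illustration.

```python
import numpy as np


def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """Blend `color` into `image` wherever `mask` is non-zero.

    image: uint8 array of shape (H, W, 3); mask: array of shape (H, W).
    Returns a new uint8 array; the input image is not modified.
    """
    image = np.asarray(image, dtype=np.float32)
    foreground = np.asarray(mask) > 0
    tint = np.asarray(color, dtype=np.float32)
    blended = image.copy()
    # Only foreground pixels are mixed with the tint; the rest stay unchanged.
    blended[foreground] = (1 - alpha) * image[foreground] + alpha * tint
    return blended.round().astype(np.uint8)
```

The image and mask must share the same height and width, so a mask predicted at a reduced scale would first need to be resized back to the image size.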