git clone git@github.com:yakhyo/unet-pytorch.git
cd unet-pytorch
dice_score = 1 - dice_loss is used for evaluation.
weights folder
The Carvana dataset is used to train and test the model.
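For reference, the relation dice_score = 1 - dice_loss used for evaluation can be illustrated with a minimal sketch. It is written in plain Python over flat lists so it runs without dependencies; the repository's own implementation presumably works on PyTorch tensors and may differ in smoothing and reduction. The function names and the `eps` smoothing term are assumptions made for this example.

```python
def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between two flat sequences of values in [0, 1].

    Hypothetical helper for illustration; `eps` is an assumed smoothing term.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    # Overlap between prediction and ground truth
    intersection = sum(p * t for p, t in zip(pred, target))
    # Combined "area" of both masks
    total = sum(pred) + sum(target)
    # eps keeps the ratio defined when both masks are empty
    return 1.0 - (2.0 * intersection + eps) / (total + eps)


def dice_score(pred, target, eps=1e-6):
    """Evaluation metric: 1 - dice_loss, so higher is better."""
    return 1.0 - dice_loss(pred, target, eps)
```

A perfect prediction gives a score close to 1 and a prediction with no overlap gives a score close to 0.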
bash tools/download.sh
Or simply download train_hq.zip and train_masks.zip and extract them as shown below:
├── data
    ├── images
         ├── xxx.jpg
         ├── xxy.jpg
         ├── xxz.jpg
          ....
    ├── masks
         ├── xxx_mask.gif
         ├── xxy_mask.gif
         ├── xxz_mask.gif
Note: Please download the kaggle.json file from your Kaggle account first, then save it as /home/username/.kaggle/kaggle.json (Ubuntu).
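After extracting the archives, it can be useful to confirm that every image has a matching mask. The following is a small sketch, not part of the repository: it assumes the data/images and data/masks layout and the xxx.jpg / xxx_mask.gif naming described in this section, and the helper name `find_unpaired` is hypothetical.

```python
from pathlib import Path


def find_unpaired(data_dir):
    """Return (images without a mask, masks without an image) as sorted lists of stems.

    Assumes <data_dir>/images/*.jpg and <data_dir>/masks/*_mask.gif.
    """
    data_dir = Path(data_dir)
    # "xxx.jpg" -> "xxx"
    image_stems = {p.stem for p in (data_dir / "images").glob("*.jpg")}
    # "xxx_mask.gif" -> "xxx"
    mask_stems = {
        p.stem[: -len("_mask")] for p in (data_dir / "masks").glob("*_mask.gif")
    }
    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)
```

Calling `find_unpaired("data")` from the repository root returns two lists; both are empty when the images and masks line up one to one.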
Training arguments
python -m tools.main -h
usage: main.py [-h] [--image_size IMAGE_SIZE] [--save-dir SAVE_DIR] [--epochs EPOCHS] [--batch-size BATCH_SIZE] [--lr LR] [--weights WEIGHTS] [--amp] [--num-classes NUM_CLASSES]
Train the model
python -m tools.main
Inference arguments
python -m tools.inference -h
usage: inference.py [-h] [--weights WEIGHTS] [--input INPUT] [--output OUTPUT] [--view] [--no-save] [--conf-thresh CONF_THRESH]
Run inference on an image
python -m tools.inference --weights weights/last.pt --input assets/image.jpg --output result.jpg
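The --conf-thresh option listed in the inference arguments is not described in this section. A plausible reading, stated here as an assumption rather than as the script's documented behavior, is that it sets the cut-off at which a predicted foreground probability is counted as part of the mask. A minimal sketch of that thresholding step, with a hypothetical helper name and an assumed default of 0.5:

```python
def binarize(probabilities, conf_thresh=0.5):
    """Turn a 2-D grid of foreground probabilities into a 0/1 mask.

    Hypothetical helper; both the name and the 0.5 default are assumptions.
    """
    # A pixel is foreground only if its probability is strictly above the threshold
    return [[1 if p > conf_thresh else 0 for p in row] for row in probabilities]
```

Under this reading, raising the threshold keeps only high-confidence pixels and produces a tighter mask, while lowering it produces a more inclusive one.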