"""CIFAR-10 dataset setup.

Importing this module makes sure the CIFAR-10 train and test splits are
present under ``Data/CIFAR10`` (downloading them if needed) and exposes:

* ``CIFAR_DIR``  -- dataset root directory (a ``pathlib.Path``).
* ``normalize``  -- the image transform applied to both splits.
* ``TRAIN_DATA`` / ``TEST_DATA`` -- the ``torchvision`` CIFAR10 datasets.
* ``CLASSES``    -- human-readable label names.
"""
from pathlib import Path

import torchvision
import torchvision.transforms as transforms

# NOTE(review): this path is resolved relative to the current working
# directory, not to this file; consider anchoring it to
# Path(__file__).parent if the script is launched from elsewhere.
CIFAR_DIR = Path('Data/CIFAR10')
# parents=True so a missing intermediate "Data" directory is created
# instead of raising FileNotFoundError.
CIFAR_DIR.mkdir(parents=True, exist_ok=True)

# Despite its name, this is the full preprocessing pipeline, not only the
# normalisation step. ToTensor turns a PIL image into a float tensor scaled
# to [0, 1]. Normalize then computes (x - 0.5) / 0.5 per RGB channel, which
# maps values into [-1, 1].
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True fetches the archive into CIFAR_DIR when it is not already
# present; both splits share the same preprocessing.
TRAIN_DATA = torchvision.datasets.CIFAR10(
    root=CIFAR_DIR, train=True, transform=normalize, download=True,
)
TEST_DATA = torchvision.datasets.CIFAR10(
    root=CIFAR_DIR, train=False, transform=normalize, download=True,
)

# Label names indexed by class id 0-9; presumably identical in order to
# TRAIN_DATA.classes -- TODO confirm against the dataset metadata.
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']