ML-Workshop/consts.py
2025-03-20 12:45:57 -04:00

19 lines
612 B
Python

from pathlib import Path
import torchvision
import torchvision.transforms as transforms
# Directory that holds the CIFAR-10 download. The path is relative, so it
# resolves against the current working directory of the importing process,
# not against the location of this file.
CIFAR_DIR = Path('Data/CIFAR10')
# parents=True: without it, mkdir raises FileNotFoundError when the
# intermediate 'Data' directory does not exist yet (e.g. a fresh checkout).
CIFAR_DIR.mkdir(parents=True, exist_ok=True)
# Per-channel (R, G, B) statistics passed to transforms.Normalize, which
# computes (x - mean) / std for each channel. With mean = std = 0.5, inputs
# in [0, 1] (the range ToTensor produces for 8-bit images) map to [-1, 1].
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# Preprocessing applied to every CIFAR-10 sample: image -> tensor -> normalized tensor.
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
# Constructor arguments shared by both CIFAR-10 splits; only `train` differs.
# NOTE: download=True means torchvision fetches the dataset into CIFAR_DIR on
# first import if it is not already present there, so importing this module
# may hit the network.
_CIFAR_COMMON_KWARGS = {
    'root': CIFAR_DIR,
    'transform': normalize,
    'download': True,
}
TRAIN_DATA = torchvision.datasets.CIFAR10(train=True, **_CIFAR_COMMON_KWARGS)
TEST_DATA = torchvision.datasets.CIFAR10(train=False, **_CIFAR_COMMON_KWARGS)
# Human-readable CIFAR-10 class names. Presumably indexed by the integer
# labels the datasets yield; confirm the order against TRAIN_DATA.classes.
CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]