26 lines
644 B
Python
26 lines
644 B
Python
from torch.utils.data import Dataset
|
|
class DogCatDataset(Dataset):
    """Binary dog-vs-cat view over a labelled dataset.

    Wraps an indexable dataset whose items are ``(image, label)`` pairs and
    exposes only the items whose label is listed in ``dog`` or ``cat``.
    Labels are remapped to a binary target: ``1`` for dog, ``0`` for cat.

    Args:
        ds: Underlying dataset. Must support ``len(ds)`` and ``ds[i]``, with
            each item unpacking into exactly ``(image, label)``.
        dog: Labels of ``ds`` that count as "dog" (binary target ``1``).
        cat: Labels of ``ds`` that count as "cat" (binary target ``0``).

    Note:
        Construction reads every item of ``ds`` once to inspect its label.
    """

    def __init__(self, ds, dog=(5,), cat=(3,)):
        self.ds = ds
        # Keep private tuple copies: the defaults are immutable, and later
        # mutation of a caller's list cannot desynchronise the filter below
        # from the label mapping in __getitem__. Tuples (not sets) keep the
        # original ``==``-based membership test, so labels need not be hashable.
        self.dog = tuple(dog)
        self.cat = tuple(cat)

        # Positions in ``ds`` of the items that are kept.
        self.idx = []
        for i in range(len(ds)):
            _, lab = ds[i]
            if lab in self.dog or lab in self.cat:
                self.idx.append(i)

    def __len__(self):
        """Return the number of dog/cat items kept from the underlying dataset."""
        return len(self.idx)

    def __getitem__(self, idx):
        """Return ``(image, binary_label)`` for the ``idx``-th kept item.

        Raises:
            IndexError: If ``idx`` is out of range for the filtered dataset.
            ValueError: If the underlying item's label is neither a dog nor
                a cat label (e.g. the wrapped dataset changed after
                construction).
        """
        orig_idx = self.idx[idx]
        img, lab = self.ds[orig_idx]

        # Use the configured label sets rather than hard-coded class ids, so
        # custom ``dog``/``cat`` arguments are honoured here as well.
        if lab in self.dog:
            return img, 1
        if lab in self.cat:
            return img, 0
        raise ValueError(
            f"we got a non dog or cat label: {lab!r} "
            f"(underlying index {orig_idx})"
        )
|
|
|
|
|