```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

training_data.data.shape
```
torch.Size([60000, 28, 28])
```python
test_data.data.shape
```
torch.Size([10000, 28, 28])
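The raw `.data` tensors hold uint8 pixel values with shape (N, 28, 28), with no channel axis. Each sample the dataset returns has been passed through `transform=ToTensor()`, which, for a grayscale image, scales pixels from [0, 255] to [0.0, 1.0] and adds a leading channel axis. The helper `to_tensor_like` is a hypothetical numpy stand-in written only to illustrate that conversion; it is not torchvision's implementation.

```python
import numpy as np

def to_tensor_like(img):
    """Hypothetical numpy stand-in for ToTensor on one grayscale (H, W) uint8 image."""
    if img.dtype != np.uint8 or img.ndim != 2:
        raise ValueError("expected a 2-D uint8 array")
    scaled = img.astype(np.float32) / 255.0  # [0, 255] -> [0.0, 1.0]
    return scaled[np.newaxis, :, :]  # (H, W) -> (C=1, H, W)

raw = np.full((28, 28), 255, dtype=np.uint8)  # hypothetical all-white 28x28 image
raw[0, 0] = 0  # one black pixel
out = to_tensor_like(raw)
print(out.shape, out.dtype, out.min(), out.max())
```

This single channel axis is the `C = 1` in the `[N, C, H, W]` batch shape the data loader reports.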
```python
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Inspect the shape of a single batch, then stop.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
```
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
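With `batch_size = 64` and `DataLoader`'s defaults (`shuffle=False`, `drop_last=False`), the loader walks the dataset in order and keeps a smaller final batch. For the 10,000-image test set that means 157 batches, the last holding 16 images; for the 60,000-image training set, 938 batches, the last holding 32. The sketch below uses hypothetical helpers `iterate_batches` and `num_batches` to mimic that in-order batching and the stacking of per-sample (1, 28, 28) images into an [N, C, H, W] array. It is a numpy illustration under those assumptions, not how `DataLoader` is implemented.

```python
import math
import numpy as np

def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches an in-order loader yields (hypothetical helper)."""
    if drop_last:
        return n_samples // batch_size  # incomplete final batch is discarded
    return math.ceil(n_samples / batch_size)  # final batch may be smaller

def iterate_batches(images, labels, batch_size):
    """Yield (X, y) in dataset order without shuffling, keeping a short final batch."""
    for start in range(0, len(images), batch_size):
        # Stack per-sample (1, 28, 28) images into one (B, 1, 28, 28) array.
        X = np.stack(images[start:start + batch_size])
        # Integer labels become a (B,) int64 array, like the torch.int64 labels printed above.
        y = np.asarray(labels[start:start + batch_size], dtype=np.int64)
        yield X, y

# Hypothetical stand-in for test_data: 10,000 blank images, all labelled 0.
images = [np.zeros((1, 28, 28), dtype=np.float32) for _ in range(10_000)]
labels = [0] * 10_000

X, y = next(iterate_batches(images, labels, 64))
print(X.shape, y.shape, y.dtype)  # (64, 1, 28, 28) (64,) int64
print(num_batches(10_000, 64), num_batches(60_000, 64))  # 157 938
```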
```python
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
```
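The chained conditional expression reads right to left less easily than it looks: Python parses `a if c1 else b if c2 else d` as `a if c1 else (b if c2 else d)`, so CUDA wins when available, MPS is the fallback, and CPU is the last resort. The sketch below spells out that priority order with a hypothetical `pick_device` helper whose two boolean arguments stand in for `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, so it runs without torch.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: same priority as the chained conditional (CUDA, then MPS, then CPU)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# Check that the explicit version agrees with the chained expression for every combination.
for cuda in (False, True):
    for mps in (False, True):
        chained = "cuda" if cuda else "mps" if mps else "cpu"
        assert pick_device(cuda, mps) == chained
```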
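```python
# Assumes `model` (the trained network) and `classes` (the label names) are already defined.
model.eval()  # switch to inference mode before predicting
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
```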
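For one input image the model returns a (1, 10) tensor of logits; `pred[0]` drops the batch axis and `.argmax(0)` picks the index of the highest-scoring class, which is then used to look up a name in `classes`. The numpy sketch below reproduces that decoding step with made-up logits (hypothetical values, not real model output) and a hypothetical `decode_prediction` helper. The class list follows the order of torchvision's FashionMNIST labels.

```python
import numpy as np

# FashionMNIST label names, index-aligned with the dataset's integer targets.
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def decode_prediction(logits, classes):
    """Hypothetical helper: map a (1, num_classes) logits array to a class name."""
    logits = np.asarray(logits)
    # logits[0] drops the batch axis; argmax(0) returns the top-scoring class index.
    return classes[int(logits[0].argmax(0))]

# Hypothetical logits for a batch of one image; index 9 has the largest score.
pred = np.array([[-2.1, -3.0, -1.5, -2.2, -1.9, 0.4, -1.0, 1.2, -0.3, 3.3]])
print(decode_prediction(pred, classes))  # prints "Ankle boot"
```

Because argmax only compares scores, no softmax is needed to pick the label; softmax would matter only if class probabilities were wanted.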