In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Ask the inline backend to produce 'retina' format figures.
%config InlineBackend.figure_format = 'retina'

# Figure size handed to plt.figure(figsize=...) by the plotting cell.
figsize = (14, 6)

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Root folder for the downloaded datasets and the saved model files.
# Set the PYTORCH_DATA_ROOT environment variable to point it elsewhere; when
# the variable is unset the original location is used, so existing setups
# keep working unchanged.
DATA_ROOT_FOLDER = os.environ.get('PYTORCH_DATA_ROOT', '/Users/brandon/Data/pytorch')

# Folder that receives this demo's artifacts (the saved model weights).
DEMO_ROOT_FOLDER = f'{DATA_ROOT_FOLDER}/model/ee541-demo1'
# exist_ok=True makes re-running the cell harmless once the folder exists.
os.makedirs(DEMO_ROOT_FOLDER, exist_ok=True)

# Number of samples per mini-batch, shared by both loaders.
batch_size = 100

# Both FashionMNIST splits live under the same root, are downloaded when
# missing, and are converted with ToTensor() only.
fashion_mnist_options = dict(
    root=f"{DATA_ROOT_FOLDER}/data",
    download=True,
    transform=transforms.ToTensor(),
)
train_set = torchvision.datasets.FashionMNIST(train=True, **fashion_mnist_options)
test_set = torchvision.datasets.FashionMNIST(train=False, **fashion_mnist_options)

# Only the training data is reshuffled; the test data keeps its order.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
)

# Length of one image once flattened (28 x 28); also the model's input width.
image_side = 28
num_pixels = image_side * image_side

class Net(nn.Module):
    """Fully connected classifier with a single hidden layer.

    Inputs whose last dimension is ``num_pixels`` pass through a 128-unit
    linear layer with ReLU and then a linear layer producing 10 scores.
    """

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(in_features=num_pixels, out_features=128)
        self.output = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the output layer's raw scores for ``x`` (no softmax applied)."""
        activations = F.relu(self.hidden(x))
        return self.output(activations)
# Build the network from the Net class.
model = Net()

# Alternative nn.Sequential formulation with the same layers, kept for reference:
#model = torch.nn.Sequential(
#        nn.Linear(in_features=num_pixels, out_features=128),
#        nn.ReLU(),
#        nn.Linear(in_features=128, out_features=10)
#        #nn.Softmax(dim=1)
#)

# Choose the compute device: MPS when available, otherwise the first CUDA
# GPU, otherwise the CPU. Both backends are probed up front, CUDA first.
cuda_available = torch.cuda.is_available()
mps_available = torch.backends.mps.is_available()
if mps_available:
    device = torch.device("mps")
elif cuda_available:
    device = torch.device("cuda:0")
else:
    device = torch.device('cpu')

# Move the model's parameters onto the chosen device.
model.to(device)

print(device)

In [None]:
# Cross-entropy loss computed on the model's raw output scores.
loss_func = nn.CrossEntropyLoss()

num_epochs = 4
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# for plots
loss_list = []        # training loss (Python float) at each logged iteration
iteration_list = []   # iteration number of each logged point
accuracy_list = []    # test accuracy in percent at each logged iteration


def _evaluate_accuracy(net, loader):
    """Return the accuracy of ``net`` over ``loader`` as a percentage.

    Puts ``net`` into eval mode (the caller must switch back to train mode),
    runs every batch through it without tracking gradients, and compares the
    highest-scoring class with the labels on the CPU.

    Raises ZeroDivisionError when ``loader`` yields no samples.
    """
    # Mode does not change between batches, so set it once up front.
    net.eval()
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        # Dedicated names so the caller's training batch is never overwritten.
        for eval_images, eval_labels in loader:
            flat_images = eval_images.to(device).view(-1, num_pixels)
            scores = net(flat_images).cpu()
            predictions = scores.argmax(dim=1)
            # .item() keeps the running count a plain Python int.
            num_correct += (predictions == eval_labels).sum().item()
            num_seen += len(eval_labels)
    return num_correct * 100 / num_seen


count = 0
for epoch in range(num_epochs):
    for images, labels in train_loader:
        count += 1

        # The evaluation helper leaves the model in eval mode, so switch back
        # to train mode before every optimisation step.
        model.train()

        # Transfering images and labels to GPU if available
        images, labels = images.to(device), labels.to(device)

        # Flatten each image to a vector of num_pixels values.
        train = images.view(-1, num_pixels)

        # Forward pass
        outputs = model(train)
        loss = loss_func(outputs, labels)

        # back-prop
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()

        # Log every one of the first 29 iterations, then every 100th.
        if (count < 30) or not(count % 100):
            accuracy = _evaluate_accuracy(model, test_loader)
            # Store a plain float rather than a tensor obtained through .data.
            loss_value = loss.item()

            loss_list.append(loss_value)
            iteration_list.append(count)
            accuracy_list.append(accuracy)

            print(f'Epoch: {epoch+1:02d}, Iteration: {count:5d}, Loss: {loss_value:.4f}, Accuracy: {accuracy:.3f}%')

print('Finished Training')

In [None]:
# List every entry of the model's state_dict with the size of its tensor.
print("Model's state_dict:")
for entry_name, entry_tensor in model.state_dict().items():
    print(entry_name, "\t", entry_tensor.size())

# Optional: list the optimizer's state_dict as well.
#print("Optimizer's state_dict:")
#for var_name in optimizer.state_dict():
#    print(var_name, "\t", optimizer.state_dict()[var_name])

# Write the weights to the demo folder, then read the same file back.
checkpoint_path = f'{DEMO_ROOT_FOLDER}/model.pth'
torch.save(model.state_dict(), checkpoint_path)
copy = torch.load(checkpoint_path)

In [None]:
# Summarise the model with torchinfo for one input of shape (1, 1, num_pixels).
from torchinfo import summary

summary_input_size = (1, 1, num_pixels)
#summary_input_size = (batch_size, 1, num_pixels)
summary(model, input_size=summary_input_size, device=device)

In [None]:
# Two side-by-side panels over the logged iterations: training loss on the
# left, test accuracy on the right.
fig = plt.figure(figsize=figsize)

panels = (
    (1, loss_list, 'Loss'),
    (2, accuracy_list, 'Accuracy'),
)
for position, series, ylabel in panels:
    sp = fig.add_subplot(1, 2, position)
    sp.plot(iteration_list, series)

    sp.set_xlabel('Iteration')
    sp.set_ylabel(ylabel)

plt.show()