"""
    Knowledgedump.org - Image Classification - preprocess_images
    Loading and preprocessing the images from CIFAR-10 dataset.

    Required packages: torch, torchvision, numpy, matplotlib
"""

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as pypl
import os, sys



# Function for preparing the data, returning the DataLoaders for the training and testing dataset.
# The batch size defaults to 32 as a solid middle ground.

def load_and_preprocess_cifar10(batch_size=32):
    """
        Download (if necessary) and load CIFAR-10, returning a (trainloader, testloader) tuple of DataLoaders.

        torchvision delivers the images as PIL objects. Every image is passed through a transformation pipeline
        (handed to torchvision.datasets.CIFAR10 via its "transform" parameter) which turns it into a tensor and
        rescales the channel values from [0,255] to [-1,1], for potentially faster/more stable training later on.

        Input parameters:
            - batch_size: Number of images per batch for both DataLoaders (default 32).
    """
    # The dataset is stored in a "data" folder next to the script that was launched (sys.argv[0]),
    # so importing this module from another script places the data beside that script.
    script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
    data_root = os.path.join(script_dir, "data")

    # Per-image pipeline, applied in order:
    #   1. ToTensor: PIL image -> torch.FloatTensor of shape C x H x W, with values rescaled from [0,255] to [0,1].
    #   2. Normalize: (x - 0.5) / 0.5 per channel, which maps [0,1] to the zero-centered range [-1,1].
    # Alternatives to step 2 would be keeping the [0,1] values, or a z-score normalization with the
    # per-channel mean and standard deviation of the dataset.
    transforms = torchvision.transforms
    half = (0.5, 0.5, 0.5)
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ])

    # Training split first, then testing split; both share the same root directory and transformation.
    train_data, test_data = (
        torchvision.datasets.CIFAR10(root=data_root, train=is_train, download=True, transform=pipeline)
        for is_train in (True, False)
    )

    # Wrap the datasets for batch processing. Only the training data is shuffled.
    make_loader = torch.utils.data.DataLoader
    return (
        make_loader(train_data, batch_size=batch_size, shuffle=True),
        make_loader(test_data, batch_size=batch_size, shuffle=False),
    )



"""
    Function for displaying the loaded images of a batch with matplotlib, to get a feeling for the data.
    This displays a batch of images in a grid and names their respective content in the title (left to right).
    Input parameters:
        - images: Batch of images,
        - labels: Corresponding labels for the images (0,1,...,9),
        - classes: List of class names (airplane, automobile etc.),
        - batch_size: Number of images in the batch, i.e. how many labels are named in the title.
"""

def show_images(images, labels, classes, batch_size):
    """
        Display a batch of images in a single grid with matplotlib and name their content in the title.

        Input parameters:
            - images: Batch of image tensors (C x H x W each). The values are assumed to lie in [-1,1],
                      since they are mapped back to [0,1] for display via (x + 1) / 2.
            - labels: Corresponding labels for the images, used as indices into classes.
            - classes: List of class names (airplane, automobile etc.).
            - batch_size: Maximum number of labels named in the title. If the batch holds fewer samples
                          (e.g. the last batch of a DataLoader), only the available labels are used.

        Returns nothing; the call blocks in pypl.show() until the plot window is handled by matplotlib.
    """
    # Convert the tensor images to a grid.
    img_grid = torchvision.utils.make_grid(images)
    # Create a numpy array from the data. Detaching and moving to the CPU first keeps the conversion
    # working for tensors that track gradients or live on a GPU; for plain CPU tensors it is a no-op.
    np_img = img_grid.detach().cpu().numpy()
    # To visualize it with matplotlib, we have to rescale the [-1,1] values back to [0,1].
    np_img = (np_img + 1) / 2

    # Since the tensor objects are of shape C x H x W and we want H x W x C for displaying with pyplot,
    # transpose the array prior.
    pypl.imshow(np.transpose(np_img, (1, 2, 0)))
    # Add title that describes the content of each image in the grid.
    # Slicing (instead of indexing 0..batch_size-1) avoids an IndexError when the batch is smaller
    # than batch_size; int() turns a 0-dim label tensor into a plain index.
    shown_labels = labels[:batch_size]
    pypl.title(" ".join(classes[int(label)] for label in shown_labels))
    pypl.show()


    
if __name__ == "__main__":
    # CIFAR-10 class names; a label value i is used as index into this list.
    classes = [
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    ]

    # Build both DataLoaders with the default batch size of 32.
    trainloader, testloader = load_and_preprocess_cifar10()

    # Print the number of batches per split as a quick sanity check of the loaded data.
    print(f"Number of training batches: {len(trainloader)}")
    print(f"Number of testing batches: {len(testloader)}")

    # Preview the first 3 batches of the training set, pulling them one at a time from a single iterator.
    batches = iter(trainloader)
    shown = 0
    while shown < 3:
        images, labels = next(batches)
        show_images(images, labels, classes, batch_size=trainloader.batch_size)
        shown += 1
