"""
    Knowledgedump.org - Image Classification - image_classification
    This script trains a simple CNN model for the CIFAR-10 dataset and briefly analyzes its performance.
    (CIFAR-10 dataset from "Learning Multiple Layers of Features from Tiny Images", Alex Krizhevsky, 2009)
    The individual steps are carried out in the respective modules, whose names describe their function:
        - preprocess_images.py - load and preprocess the data for training.
        - define_model.py - set up the CNN model.
        - train_model.py - train the model on the data and evaluate it at each step.
        - visualize_results.py - plot the performance results.


    The images are preprocessed, and the dataset is split into training and test sets.
    CIFAR-10 consists of 60000 32x32 color images in 10 classes, with 6000 images per class. The classes are mutually
    exclusive, i.e. each image belongs to exactly one class, and each class is split into 5000 training and 1000 test images.
    The dataset is loaded using PyTorch and torchvision.

    Required packages: torch, torchvision, numpy, matplotlib
"""

from preprocess_images import load_and_preprocess_cifar10
from define_model import SimpleCNN
from train_model import train_model
from visualize_results import plot_results

# Train and evaluate a freshly initialized model for each (epochs, learning_rate) combination.
if __name__ == "__main__":

    # Load and preprocess the CIFAR-10 dataset once (batch_size=32); the same loaders are reused for every run.
    trainloader, testloader = load_and_preprocess_cifar10(batch_size=32)

    # (epochs, learning_rate) pairs to compare, in run order. Keeping each pair together avoids the
    # silent truncation zip() would perform if two parallel lists of settings drifted out of sync.
    runs = [
        (10, 0.001),
        (10, 0.0005),
        (20, 0.001),
        (20, 0.0005),
    ]

    for epochs, learning_rate in runs:
        # Build a new SimpleCNN for every run so no trained weights carry over between runs.
        model = SimpleCNN()

        # Train and evaluate with this run's settings; device=None is forwarded unchanged to train_model.
        perf_res = train_model(
            model,
            trainloader,
            testloader,
            epochs=epochs,
            learning_rate=learning_rate,
            device=None,
        )

        # Plot the returned performance results (unpacked as positional arguments).
        plot_results(*perf_res)