Il Tuo Logo

                    
    import torch

    # Check if CUDA is correctly installed and whether a GPU is available
    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available? {cuda_available}")
    print(f"How many CUDA devices are available? {torch.cuda.device_count()}")

    # Asking for the name of device 0 raises an error when no CUDA device exists,
    # so the name is only queried on machines that actually have a GPU
    if cuda_available:
        print(f"Name of the CUDA device: {torch.cuda.get_device_name(0)}")
    else:
        print("Name of the CUDA device: none (running on CPU)")


    # Select the device to be used for the computation
    device = torch.device('cuda') if cuda_available else torch.device('cpu')

    # Create a tensor and send it to the device:
    # 1. The tensor is directly created on the device (more efficient)
    z = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 8]], dtype=torch.float32, device=device)

    # 2. The tensor is created on the CPU, converted to float32 and then moved to the device
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 8]]).to(torch.float32).to(device)

    print(z)
    print(x)

    # z.shape gives the shape of the tensor, i.e. the number of elements in each dimension
    print(z.shape)

    # z.size() is the method form of z.shape: it also returns the size of each dimension,
    # NOT the total number of elements
    print(z.size())

    # z.numel() is the call that returns the total number of elements in the tensor
    print(z.numel())

    # We can also get the data type of the tensor with z.dtype
    print(z.dtype)

    # Since we can also set the tensor's device, we can check the device of the tensor with z.device
    print(z.device)

    # Now, it is possible to create a tensor manually as we've seen before,
    # but PyTorch provides a variety of functions to create tensors with specific properties.
    # For instance, we can create a tensor with all zeros with torch.zeros(shape)
    z = torch.zeros((4, 4), dtype=torch.float32, device=device)
    print(z)

    # Similarly, we can create a tensor with all ones with torch.ones(shape)
    x = torch.ones((4, 4), dtype=torch.float32, device=device)
    print(x)

    # We can also create a tensor with random values with torch.rand(shape)
    y = torch.rand((4, 4), device=device)
    print(y)

    # We can generate a tensor with random values from a normal distribution with torch.randn(shape)
    y = torch.randn((4, 4), device=device)
    print(y)

    # We can also choose one value used to populate a tensor with torch.full(shape, value)
    # (no dtype is passed here, so it is inferred from the fill value 42)
    z = torch.full((4, 4), 42, device=device)
    print(z)

    # There are multiple ways to create tensors with specific properties, my suggestion is to check the documentation
                    
                    
                            
                    

                    
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The stride of a tensor tells, for every dimension, how many elements have to be
    # skipped in memory to reach the next element along that dimension
    z = torch.zeros((2, 2), dtype=torch.float64, device=device)
    print(z.stride())

    # PyTorch operators work element-wise on tensors.
    # For instance, a tensor can be multiplied by a scalar
    z = 2 * torch.ones((2, 2), dtype=torch.float64, device=device)
    print(z)

    # Two tensors of the same shape can be multiplied element by element
    y = 6 * torch.ones((2, 2), dtype=torch.float64, device=device)
    k = torch.mul(z, y)
    print(k)

    # Addition, subtraction, division and exponentiation are element-wise as well
    k = torch.add(z, y)
    print(k) # Addition

    k = torch.sub(z, y)
    print(k) # Subtraction

    k = torch.div(z, y)
    print(k) # Division

    k = torch.pow(z, y)
    print(k) # Exponentiation

    # The same operations also work on tensors of different shapes, provided the shapes
    # are broadcastable. Here y has shape (1, 2): its size-1 dimension is stretched
    # to match the (2, 2) shape of z
    y = 6 * torch.ones((1, 2), dtype=torch.float64, device=device)
    k = z - y
    print(k)

    # Reductions such as sum, mean, max, min, etc. can be applied to tensors,
    # either over all elements or along one specific dimension.
    # The reduced dimension collapses into a single element.
    # Below, the elements are summed along dimension 1
    z = 2 * torch.ones((2, 4), dtype=torch.float64, device=device)
    print(z)

    k = torch.sum(z, dim=1)
    print(k)
    print(k.shape)

    # With keepdim=True the reduced dimension is kept with size 1
    k = torch.sum(z, dim=1, keepdim=True)
    print(k)
    print(k.shape)

    # torch.cat joins tensors along an existing dimension.
    # The tensors must agree on every dimension except the one being concatenated
    z = 2 * torch.ones((2, 2), dtype=torch.float64, device=device)
    y = 6 * torch.ones((2, 2), dtype=torch.float64, device=device)
    k = torch.cat([z, y], dim=0)
    print(k)
    print(k.shape)
                    
                    
                            
                    

                    
    import torch

    # Prefer the GPU when one is visible, otherwise fall back to the CPU
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # The examples below draw random tensors, so the seed is fixed to make runs repeatable
    torch.manual_seed(42)
    
    
    ################### BROADCASTING
    # Here broadcasting is done automatically. What happens under the hood is that the
    # smaller tensor is (virtually) expanded to match the shape of the larger tensor.
    def broadcasting():
        """Print a series of products that illustrate how broadcasting behaves.

        Uses the module-level ``device``; nothing is returned.
        """
        ones = torch.ones([2, 2], dtype=torch.float32, device=device)
        column = torch.rand([2, 1], dtype=torch.float32, device=device)

        # A Python scalar is applied to every element
        print(ones * 3)

        # Same values as above: a length-2 vector of threes is broadcast over the rows
        threes = torch.tensor([3, 3], dtype=torch.float32, device=device)
        print(ones * threes)

        # (2, 2) * (2, 1): the size-1 dimension of the second operand is stretched
        print(ones * column)

        # Same product with the stretching written out explicitly
        stretched = column.expand(ones.shape)
        print(ones * stretched)

        # Squeezing the (2, 1) tensor leaves shape (2,). It still broadcasts, but it is
        # now aligned with the LAST dimension, so the values vary along columns
        # instead of rows and the product is generally not the same as the one above
        flat = column.squeeze()
        print(flat.shape)
        print(ones * flat)

        # Broadcasting also works against a higher-rank tensor: shapes are compared from
        # the trailing dimension, and each pair must be equal or contain a 1
        big = torch.ones([2, 2, 4, 2], dtype=torch.float32, device=device)
        pair = torch.ones([2], dtype=torch.float32, device=device) * 2
        print(big * pair.squeeze())

        # Trailing sizes 2 and 3 are neither equal nor 1, so this product raises
        try:
            big = torch.ones([2, 2, 4, 2], dtype=torch.float32, device=device)
            triple = torch.ones([3], dtype=torch.float32, device=device) * 2
            print(big * triple.squeeze())
        except RuntimeError as err:
            print(err)
    
    ################### SQUEEZE AND UNSQUEEZE
    # Squeeze removes all the dimensions of size 1
    def squeeze_unsqueeze():
        """Print the shapes produced by ``squeeze`` and ``unsqueeze``.

        Uses the module-level ``device``; nothing is returned.
        """
        padded = torch.ones([2, 1, 2, 1], dtype=torch.float32, device=device)

        # Without arguments every size-1 dimension is dropped
        fully_squeezed = padded.squeeze()
        print(fully_squeezed.shape)

        # With an index only that dimension is dropped (dimension 1 here)
        partly_squeezed = padded.squeeze(1)
        print(partly_squeezed.shape)

        # unsqueeze inserts a new size-1 dimension at the given position
        matrix = torch.ones([2, 2], dtype=torch.float32, device=device)
        with_leading_dim = matrix.unsqueeze(0)
        print(with_leading_dim.shape)
    
    ################### INDEXING AND SLICING
    def indexing_slicing():
        """Print examples of slicing, boolean masks, ``torch.where`` and ``torch.gather``.

        Uses the module-level ``device``; nothing is returned.
        """
        # Indexing and slicing follow the numpy conventions
        cube = torch.ones([10, 2, 3], dtype=torch.float32, device=device)

        # index 0 along dimension 1 -> result has shape (10, 3)
        print(cube[:, 0])
        # index 0 along dimension 0 -> result has shape (2, 3)
        print(cube[0, :])
        # last index along the last dimension -> result has shape (10, 2)
        print(cube[..., -1])

        # A boolean mask keeps only the elements where the mask is True
        values = torch.tensor([-1, 9, 3, -34, 12], dtype=torch.float32, device=device)
        positive = values > 0
        print(values[positive])

        # torch.where(condition, a, b) takes a where the condition holds and b elsewhere,
        # so here the non-positive entries are replaced by zeros
        zeros = torch.zeros_like(values)
        print(torch.where(positive, values, zeros))

        # gather(dim, index) picks values along `dim` using the positions in `index`:
        # with dim=1, each row of `index` selects columns from the matching row of the input
        table = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32, device=device)
        picks = torch.tensor([[0], [1], [0]], device=device)
        print(table.gather(1, picks))
        
    
    # Entry point: only one demo runs at a time; uncomment the other calls to see their output.
    if __name__ == "__main__":
        # broadcasting()
        # squeeze_unsqueeze()
        indexing_slicing()
                    
                    
                            
                    

                    

    import torch

    # Let's start by creating a tensor
    x = torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]])

    # The shape of our tensor is [3, 4]
    # keep it in mind for later
    print(x.shape)

    # as_strided creates a VIEW of the existing tensor using the size and the
    # stride we pass (here size [3, 3] and stride (2, 2)); no data is copied
    print(torch.as_strided(x, [3, 3], (2, 2)))

    # the stride of the original tensor is (4, 1): moving one step along dimension 0
    # skips 4 elements in memory, moving one step along dimension 1 skips 1 element
    print(x.stride())

    # the elements of a tensor may not be laid out contiguously in memory (e.g. after
    # a transpose), and this can make operations less efficient. contiguous() returns
    # a tensor whose elements are contiguous; the result must be assigned to be used
    # (calling it without keeping the return value has no effect on x)
    x = x.contiguous()

    # both of these change the shape of the tensor, BUT view only works when the
    # requested shape is compatible with the current memory layout, while reshape
    # also handles the other cases and may then return a copy of the data.
    # Use view when you need the guarantee that no copy is made
    print(x.view([1, -1]))
    print(torch.reshape(x, [1, -1]))

    # Tensors can live on different devices (all operands of an operation must be
    # on the same one). If you want to use CUDA you can call the following methods

    print(torch.cuda.is_available())
    # querying the name of GPU 0 raises when no CUDA device exists, so guard the call
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0)) # 0 is the gpu ID

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # move the tensor to the selected device
    x = x.to(device)

    # create the tensor directly on the device (more efficient when possible)
    x = torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]], device=device)

                    
                    
                            
                    

                    
    import torch

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # EPOCHS: how many times the whole dataset goes through the network
    EPOCHS = 500

    # N: batch size; the three dimensions are the sizes of input, hidden layer and output
    N, input_dimension, hidden_dimension, output_dimension = 64, 1000, 100, 10
    x = torch.randn(N, input_dimension, device=device) # random inputs
    y = torch.randn(N, output_dimension, device=device) # random targets

    # weight matrices of the two layers; requires_grad=True lets autograd track them
    w1 = torch.randn(input_dimension, hidden_dimension, device=device, requires_grad=True)
    w2 = torch.randn(hidden_dimension, output_dimension, device=device, requires_grad=True)

    learning_rate = 1e-6

    for epoch in range(EPOCHS):
        # forward pass: first matrix product, ReLU (negative values clamped to 0),
        # second matrix product
        pre_activation = x @ w1
        activated = pre_activation.clamp(min=0)
        output_data = activated @ w2

        # loss: SUM of the squared differences between prediction and target
        # (a sum, not a mean, of the squared errors)
        loss = torch.sum((output_data - y) ** 2)

        # backward pass: fills w1.grad and w2.grad
        loss.backward()
        print(loss.item())

        # the update itself must not be tracked by autograd
        with torch.no_grad():
            for weight in (w1, w2):
                # step against the gradient, scaled by the learning rate
                weight.sub_(learning_rate * weight.grad)
                # gradients accumulate across backward() calls, so clear them
                weight.grad.zero_()
                    
                    
                            
                    

                    
    import torch
    from torch import nn
    import torch.optim as optim
    
    
    
    # A neural network is defined as a class that inherits from nn.Module
    # The class has two main methods: __init__ and forward
    # __init__ is used to define the layers and attributes of the network
    # forward is used to define the forward pass of the network
    class Network(nn.Module):
        """Two-layer perceptron: Linear -> ReLU -> Linear."""

        def __init__ (self, input_dimension: int, hidden_dimension: int, output_dimension: int) -> None:
            """Build the layer stack.

            Args:
                input_dimension: number of features of each input sample.
                hidden_dimension: width of the hidden layer.
                output_dimension: number of values produced per sample.
            """
            super().__init__()

            # the modules are applied in the order in which they appear in the list
            layers = [
                nn.Linear(input_dimension, hidden_dimension),
                nn.ReLU(),
                nn.Linear(hidden_dimension, output_dimension),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run ``x`` through the layer stack and return the result."""
            prediction = self.net(x)
            return prediction
    
    # set seed for random generated numbers to allow reproducibility
    def set_seed(seed: int=42) -> None:
        """Seed the PyTorch random generators so that runs can be reproduced.

        Args:
            seed: value passed to every seeding call.
        """
        torch.manual_seed(seed)

        # the remaining settings only apply when a CUDA device is present
        if not torch.cuda.is_available():
            return

        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # trade cuDNN autotuning for deterministic behaviour
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    
    if __name__ == "__main__":
        set_seed(42)

        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # EPOCHS: how many times the whole dataset goes through the network
        EPOCHS = 500

        # N: batch size; the three dimensions are the sizes of input, hidden layer and output
        N, input_dimension, hidden_dimension, output_dimension = 64, 1000, 100, 10
        x = torch.randn(N, input_dimension, device=device) # random inputs
        y = torch.randn(N, output_dimension, device=device) # random targets

        # build the model and move its parameters to the selected device
        model = Network(input_dimension, hidden_dimension, output_dimension).to(device)

        # with reduction="sum" the squared errors are summed rather than averaged
        criterion = nn.MSELoss(reduction="sum")
        optimizer = optim.SGD(model.parameters(), lr=1e-4) # stochastic gradient descent

        for epoch in range(EPOCHS):
            # gradients accumulate across backward() calls, so clear them first
            optimizer.zero_grad()

            # forward pass followed by the loss on the training data
            train_output = model(x)
            loss = criterion(train_output, y)

            # backward pass, then log the loss of this epoch
            loss.backward()
            print(loss.item())

            # apply the parameter update
            optimizer.step()

        print("\n")

        # VALIDATION STEP:
        # a second random dataset is used to measure how the model behaves on data it
        # was not trained on. Inputs and targets are unrelated random numbers, so the
        # model can only memorise the training set and is expected to do badly here.
        x_val = torch.randn(N, input_dimension, device=device)
        y_val = torch.randn(N, output_dimension, device=device)

        # no training happens on the validation set, so no gradients are needed:
        # torch.no_grad() skips the bookkeeping, which saves time and memory
        with torch.no_grad():
            val_output = model(x_val)
            loss = criterion(val_output, y_val)

        print(f"validation loss: {loss.item()}")
                    
                    
                            
                    

                    
    import argparse

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    
    # we use torchvision to work with image datasets
    # we can download and load data, while also apply transforms on it
    from torchvision.datasets import MNIST
    import torchvision.transforms as T
    
    from tqdm import tqdm
    import logging
    
    # Let's define a basic Linear network with 1024 as hidden dimension
    # We use batch normalization, which normalizes tensors along the batch dimension
    # to help the model to better generalize
    class LinearNet(nn.Module):
        """Fully connected classifier with two hidden layers of width 1024.

        Each hidden layer is Linear -> BatchNorm1d -> ReLU; a final Linear layer
        produces one score per class.
        """

        def __init__(self, in_channels: int, out_classes: int) -> None:
            """Build the layer stack.

            Args:
                in_channels: number of features of each (flattened) input sample.
                out_classes: number of scores produced per sample.
            """
            super().__init__()

            hidden = 1024
            layers = []
            # two identical hidden stages; only the width of their input differs
            for stage_input in (in_channels, hidden):
                layers.append(nn.Linear(stage_input, hidden))
                layers.append(nn.BatchNorm1d(hidden))
                layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden, out_classes))

            self.arch = nn.Sequential(*layers)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return the class scores for the batch ``x``."""
            scores = self.arch(x)
            return scores
    
    # a collate function is a special function executed before the dataloader provides a batch.
    # it is very useful to apply further custom operation on the data before using it (e.g.
    # you may add here positional encoding)
    def collate_fn(batch: tuple, device: torch.device):
        """Merge a sequence of ``(image, label)`` samples into two batched tensors.

        Args:
            batch: sequence of ``(image, label)`` pairs; the images are stacked with
                ``torch.stack`` and the labels are packed with ``torch.tensor``.
            device: device on which both returned tensors are placed.

        Returns:
            A tuple ``(images, labels)`` with one leading batch dimension each.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            # without this check the unpacking below fails with an opaque message
            raise ValueError("collate_fn received an empty batch")

        images, labels = zip(*batch)
        images = torch.stack(images).to(device)
        # build the label tensor directly on the target device instead of creating
        # it on the CPU and copying it afterwards
        labels = torch.tensor(labels, device=device)

        return images, labels
    
    if __name__ == "__main__":
        import os

        # expose the parameters a user may want to change on the command line
        # (useful to train multiple configurations without editing the script)
        parser = argparse.ArgumentParser()
        parser.add_argument("-bs", "--batch-size", type=int, default=512, help="size of the batch of images")
        parser.add_argument("-ep", "--epochs", type=int, default=10, help="number of training epochs")

        args = parser.parse_args()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # the log file cannot be opened if its directory is missing, so create it first
        log_path = "Lecture6_Torchvision/training.log"
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        logging.basicConfig(filename=log_path, level=logging.INFO)

        # A transform is applied to every sample. Here the input image is first turned
        # into a tensor; then it is normalized with mean 0.5 and standard deviation 0.5,
        # which maps values from [0, 1] to [-1, 1]; finally a custom transformation
        # flattens the tensor so that it fits into nn.Linear

        # Compose takes a list where order matters!
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize((0.5,), (0.5,)),
            T.Lambda(lambda x: x.view(-1))
        ])

        # here we define the dataset:
        # - first param: path to the dataset folder within the filesystem
        # - train: which split we want (training or test)
        # - download: fetch the data if it is not already in the folder
        # - transform: the transform we wrote before
        trainset = MNIST("/tmp/data", train=True, download=True, transform=transform) # (60,000 images)
        testset = MNIST("/tmp/data", train=False, download=True, transform=transform) # (10,000 images)

        # The dataloader is an iterable object that yields one batch at a time
        # - first param: the dataset object --> trainset or testset in this case
        # - second param: the batch size
        # - shuffle: in training it is better to shuffle data, otherwise the network
        #            may pick up on the order of the input data
        # - num_workers: number of worker processes that load the data;
        #                0 means the data is loaded in the main process
        # - collate_fn: the function we wrote before, wrapped so the device is passed too

        trainloader = DataLoader(trainset, args.batch_size, shuffle=True, num_workers=0, collate_fn=lambda batch: collate_fn(batch, device))
        testloader = DataLoader(testset, args.batch_size, shuffle=False, num_workers=0, collate_fn=lambda batch: collate_fn(batch, device))

        model = LinearNet(in_channels=784, out_classes=10).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        numParameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logging.info(f"Model has {numParameters} parameters")
        logging.info(model)


        print("Training started!")
        pbar = tqdm(total=args.epochs, desc="EPOCH: 0 - running ...")
        for e in range(args.epochs):
            # batch normalization behaves differently in training and evaluation,
            # so the mode has to be switched back at the start of every epoch
            model.train()
            avg_loss = 0.0

            # Training Step: the loader yields (images, labels) tuples
            for (images, labels) in trainloader:
                predictions = model(images)

                loss = criterion(predictions, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                avg_loss += loss.item()

            # len(trainloader) is the number of batches, which is exactly the number
            # of loss values accumulated above
            avg_loss /= max(len(trainloader), 1)

            # Validation Step: evaluation mode makes batch normalization use its
            # running statistics instead of the statistics of the current batch
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for (images, labels) in testloader:
                    predictions = model(images)

                    # index of the highest score --> the class the model considers most likely
                    _, predicted = torch.max(predictions, 1)

                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / total
            message = f"EPOCH: {e}: average loss is {avg_loss}, while accuracy is {accuracy}"
            pbar.set_description(message)
            logging.info(message)
            pbar.update(1)

        # release the progress bar so the terminal line is finalized
        pbar.close()
                    
                    
                            
                    

                    
    import argparse

    import torch
    from torch import nn
    
    from torch.utils.data import DataLoader
    from torchvision import datasets
    import torchvision.transforms as T
    
    from tqdm import tqdm
    
    def collate_fn(batch: tuple, device: torch.device):
        """Merge a sequence of ``(image, label)`` samples into two batched tensors.

        Args:
            batch: sequence of ``(image, label)`` pairs; the images are stacked with
                ``torch.stack`` and the labels are packed with ``torch.tensor``.
            device: device on which both returned tensors are placed.

        Returns:
            A tuple ``(images, labels)`` with one leading batch dimension each.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            # without this check the unpacking below fails with an opaque message
            raise ValueError("collate_fn received an empty batch")

        images, labels = zip(*batch)
        images = torch.stack(images).to(device)
        # build the label tensor directly on the target device instead of creating
        # it on the CPU and copying it afterwards
        labels = torch.tensor(labels, device=device)

        return images, labels
    
    def get_dataset(batch_size: int, num_workers: int, device: torch.device):
        """Build the CIFAR10 training and validation data loaders.

        Args:
            batch_size: number of samples per batch for both loaders.
            num_workers: number of worker processes used by both loaders.
            device: device onto which ``collate_fn`` places every batch.

        Returns:
            A tuple ``(train_loader, val_loader)``.
        """
        from functools import partial

        data_path = '/tmp/data'

        # training samples are augmented (random flip, rotation of up to 10 degrees,
        # random erasing) after being converted to tensors and normalized
        train_transforms = T.Compose([
            T.ToTensor(),
            T.Normalize((0.5,), (0.5,)),
            T.RandomHorizontalFlip(),
            T.RandomRotation(10),
            T.RandomErasing()
        ])

        # validation samples are only converted and normalized
        test_transforms = T.Compose([
            T.ToTensor(),
            T.Normalize((0.5,), (0.5,))
        ])

        train_set = datasets.CIFAR10(data_path, train=True, download=True, transform=train_transforms)
        val_set = datasets.CIFAR10(data_path, train=False, download=True, transform=test_transforms)

        # a partial of the module-level collate_fn can be pickled, unlike a lambda,
        # so the loaders also work with num_workers > 0 when worker processes are
        # started with the "spawn" method
        # NOTE(review): with a CUDA device and num_workers > 0 the batches are moved
        # to the GPU inside the worker processes - confirm this is intended
        collate = partial(collate_fn, device=device)

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate)
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate)

        return train_loader, val_loader
    
    class ChannelSELayer(nn.Module):
        """Channel squeeze-and-excitation: rescale every channel by a learned gate.

        The input is averaged over its spatial positions, passed through two linear
        layers (ReLU in between, sigmoid at the end) and the resulting per-channel
        gate in (0, 1) multiplies the input.
        """

        def __init__(self, in_channels: int, reduction: int):
            """Build the two linear layers of the gate.

            Args:
                in_channels: number of channels of the input feature map.
                reduction: the hidden width is ``in_channels // reduction``.
            """
            super(ChannelSELayer, self).__init__()

            hidden_channels = in_channels // reduction
            self.reduction_ratio = reduction
            self.fc1 = nn.Linear(in_channels, hidden_channels, bias=True)
            self.fc2 = nn.Linear(hidden_channels, in_channels, bias=True)
            self.relu = nn.ReLU()
            self.sigmoid = nn.Sigmoid()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return ``x`` with every channel scaled by its gate.

            Args:
                x: 4-dimensional tensor; the code unpacks its size as
                    (batch, channels, height, width).

            Returns:
                A tensor with the same shape as ``x``.
            """
            batch_size, num_channels, _, _ = x.size()

            # squeeze: average over all spatial positions of each channel.
            # reshape (instead of view) also accepts inputs that are not contiguous
            squeeze_tensor = x.reshape(batch_size, num_channels, -1).mean(dim=2)

            # excitation: one gate value per (sample, channel)
            fc_out_1 = self.relu(self.fc1(squeeze_tensor))
            fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

            # add two trailing size-1 dimensions so the gate broadcasts over H and W
            output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1))
            return output_tensor
        
    class ConvBlock(nn.Module):
        """Residual block: 1x1 expansion, depthwise convolution, batch norm,
        channel squeeze-and-excitation and 1x1 reduction, added to a skip path."""

        def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int, reduction: int):
            """Build the layers of the block.

            Args:
                in_channels: channels of the input feature map.
                out_channels: channels of the output feature map; the inner layers
                    work on ``out_channels * 4`` channels.
                kernel_size: kernel size of the depthwise convolution.
                stride: stride of the depthwise convolution (and of the skip
                    projection, so both paths are downsampled together).
                padding: padding of the depthwise convolution.
                reduction: reduction ratio passed to the squeeze-and-excitation layer.

            Note:
                The two paths are summed, so their spatial sizes must match. The skip
                path uses a 1x1 convolution without padding, which matches the main
                path when ``padding == (kernel_size - 1) // 2`` with an odd kernel.
            """
            super(ConvBlock, self).__init__()

            self.expander = nn.Conv2d(in_channels, out_channels * 4, kernel_size=1, stride=1)
            # groups equal to the channel count makes this a depthwise convolution
            self.dwconv = nn.Conv2d(out_channels * 4, out_channels * 4, kernel_size, stride, padding, groups=out_channels * 4)
            self.bn = nn.BatchNorm2d(out_channels * 4)
            self.se = ChannelSELayer(out_channels * 4, reduction)
            self.reductor = nn.Conv2d(out_channels * 4, out_channels, kernel_size=1, stride=1)

            # the skip path needs a projection whenever the main path changes the
            # number of channels OR the spatial size; otherwise the residual
            # addition would fail with a shape mismatch
            needs_projection = in_channels != out_channels or stride != 1
            self.skip_connection = (
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
                if needs_projection
                else nn.Identity()
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return the sum of the main path and the skip path applied to ``x``."""
            skip = self.skip_connection(x)

            x = self.expander(x)
            x = self.dwconv(x)
            x = self.bn(x)
            x = self.se(x)
            x = self.reductor(x) + skip
            return x
        
    class ConvNet(nn.Module):
        """Image classifier: four ConvBlock stages followed by a linear layer.

        The first three stages are each followed by a 2x2 max pooling; the last one
        is followed by a global average pooling that leaves one value per channel.
        """

        def __init__(self, in_channels: int, out_classes: int, reduction: int):
            """Build the feature extractor and the classifier.

            Args:
                in_channels: channels of the input images.
                out_classes: number of scores produced per image.
                reduction: reduction ratio forwarded to every ConvBlock.
            """
            super().__init__()

            # output width of each stage
            # NOTE(review): 738 breaks the doubling pattern 96/192/384 and may have been
            # meant as 768 - kept as is, since changing it alters the architecture
            widths = (96, 192, 384, 738)
            last_stage = len(widths) - 1

            stages = []
            stage_input = in_channels
            for position, width in enumerate(widths):
                stages.append(
                    ConvBlock(stage_input, out_channels=width, kernel_size=3, stride=1, padding=1, reduction=reduction)
                )
                if position == last_stage:
                    stages.append(nn.AdaptiveAvgPool2d(1))
                else:
                    stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
                stage_input = width

            self.arch = nn.Sequential(*stages)

            self.classifier = nn.Sequential(
                nn.Flatten(),
                nn.Linear(widths[-1], out_classes),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return the class scores for the batch of images ``x``."""
            features = self.arch(x)
            return self.classifier(features)
    
    def ckpts_manager(ckpt_path: str, model: nn.Module, optimizer: torch.optim.Optimizer, mode: str):
        """Load or save a checkpoint holding the model and optimizer state.

        Args:
            ckpt_path: path of the checkpoint file; when it is ``None`` or empty a
                message is printed and nothing else happens.
            model: module whose ``state_dict`` is saved or restored.
            optimizer: optimizer whose ``state_dict`` is saved or restored.
            mode: ``'load'`` to restore from ``ckpt_path``, ``'save'`` to write to it.

        Returns:
            The tuple ``(model, optimizer)`` that was passed in.

        Raises:
            ValueError: if ``mode`` is neither ``'load'`` nor ``'save'``.
        """
        if not ckpt_path:
            print("No checkpoint path provided!")
            return model, optimizer

        if mode == 'load':
            # map_location="cpu" lets a checkpoint written on a GPU machine be read on
            # a machine without one; load_state_dict then copies the values into the
            # existing parameters, wherever they live.
            # NOTE: torch.load unpickles the file - only load checkpoints you trust.
            ckpt = torch.load(ckpt_path, map_location="cpu")
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
        elif mode == 'save':
            ckpt = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(ckpt, ckpt_path)
        else:
            # an unknown mode used to be ignored silently, hiding typos in the caller
            raise ValueError(f"Unknown mode {mode!r}: expected 'load' or 'save'")

        return model, optimizer
    
    
    def parse_args():
        """Parse the command-line options of the training script.

        Returns:
            The ``argparse.Namespace`` produced from the process arguments.
        """
        # one entry per option: (flags, keyword arguments for add_argument)
        options = (
            (("-bs", "--batch-size"), {"type": int, "default": 512, "help": "size of the batch of images"}),
            (("-ep", "--epochs"), {"type": int, "default": 10, "help": "number of training epochs"}),
            (("-r", "--reduction"), {"type": int, "default": 4, "help": "reduction ratio for SE block"}),
            (("-lr", "--learning-rate"), {"type": float, "default": 1e-4, "help": "learning rate for the optimizer"}),
            (("-nw", "--num-workers"), {"type": int, "default": 0, "help": "number of workers for the dataloader"}),
            (("-sw", "--save-weights"), {"type": str, "default": "weights.pth", "help": "path to save the weights"}),
            (("-lw", "--load-weights"), {"type": str, "default": None, "help": "path to load the weights"}),
            (("-cp", "--checkpoint"), {"type": str, "default": None, "help": "path to a checkpoint to load or store"}),
        )

        parser = argparse.ArgumentParser()
        for flags, settings in options:
            parser.add_argument(*flags, **settings)

        return parser.parse_args()
    
    if __name__ == "__main__":
        args = parse_args()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        train_loader, val_loader = get_dataset(args.batch_size, num_workers=args.num_workers, device=device)

        model = ConvNet(in_channels=3, out_classes=10, reduction=args.reduction).to(device)
        print(model)
        print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

        criterion = nn.CrossEntropyLoss()
        # honour the --learning-rate option (it used to be parsed and then ignored
        # in favour of a hard-coded value)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

        if args.load_weights:
            model, optimizer = ckpts_manager(args.load_weights, model, optimizer, mode='load')
            print("Weights loaded!")

        # NOTE(review): --save-weights is parsed but never used; only --checkpoint
        # triggers a save at the end of training - confirm which one is intended

        # iterating over the tqdm object already advances the bar by one per epoch,
        # so no manual update is needed
        pbar = tqdm(range(args.epochs))
        for epoch in pbar:
            # training mode: batch normalization uses the statistics of each batch
            model.train()
            for images, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()


            # evaluation mode: batch normalization uses its running statistics
            model.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in val_loader:
                    outputs = model(images)
                    # index of the highest score --> predicted class
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            # the loss shown is the one of the last training batch of the epoch
            pbar.set_description(f"Epoch {epoch + 1} | Loss: {loss.item():.4f} | Accuracy: {100 * correct / total:.2f}%")

        print("Training completed!")

        if args.checkpoint:
            model, optimizer = ckpts_manager(args.checkpoint, model, optimizer, mode='save')
            print("Checkpoint saved!")

        # same effect as exit(0) without relying on the helper injected by the site module
        raise SystemExit(0)