ISubGVQA.training.train_loop

Functions

train(→ float)

Train the model for a specified number of epochs.

Module Contents

ISubGVQA.training.train_loop.train(args: argparse.Namespace, model: torch.nn.Module, dataloaders: dict, criterion: torch.nn.Module, optimizer: torch.optim.Optimizer, gradscaler: torch.cuda.amp.GradScaler, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, save_model=False) → float

Train the model for a specified number of epochs.

Args:

args (Namespace): Arguments containing training configurations.

model (torch.nn.Module): The model to be trained.

dataloaders (dict): Dictionary containing 'train' and 'dev' dataloaders.

criterion (torch.nn.Module): Loss function.

optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.

gradscaler (torch.cuda.amp.GradScaler): Gradient scaler for mixed precision training.

lr_scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.

sampler_train (torch.utils.data.Sampler): Sampler for the training data. Note: this parameter is documented in the docstring but does not appear in the signature above.

save_model (bool, optional): Flag to save the model checkpoints. Defaults to False.
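As a rough illustration of how the arguments described above might be assembled, the sketch below builds toy objects of the documented types. The linear model, the random tensors, the `epochs` field on `args`, and the choice of SGD and StepLR are all assumptions made for the example; the fields that `train` actually reads from `args` are not documented on this page.

```python
import argparse

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical configuration: `epochs` is an illustrative guess, since the
# fields train() reads from `args` are not listed in this reference.
args = argparse.Namespace(epochs=2)

# Toy stand-ins for the real model and data (assumptions for this sketch).
model = nn.Linear(4, 3)
features = torch.randn(16, 4)
labels = torch.randint(0, 3, (16,))
dataset = TensorDataset(features, labels)

# The documented contract: a dict with 'train' and 'dev' dataloaders.
dataloaders = {
    "train": DataLoader(dataset, batch_size=8, shuffle=True),
    "dev": DataLoader(dataset, batch_size=8),
}

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Gradient scaling is disabled when no GPU is present, so the scaler is a no-op on CPU.
gradscaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

# With the package installed, the call would presumably look like this:
# from ISubGVQA.training.train_loop import train
# best_dev_acc = train(args, model, dataloaders, criterion, optimizer,
#                      gradscaler, lr_scheduler, save_model=False)
```

The call itself is left commented out because the real model and dataloaders expect project-specific inputs that this toy setup does not provide.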

Returns:

float: The highest validation accuracy achieved during training.
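The return contract (the best dev accuracy seen across all epochs, not the accuracy of the final epoch) can be pictured with a small sketch. This is not the library's implementation; `track_best_accuracy`, `run_epoch`, and `evaluate` are hypothetical names introduced only to show the bookkeeping.

```python
from typing import Callable


def track_best_accuracy(
    num_epochs: int,
    run_epoch: Callable[[int], None],
    evaluate: Callable[[int], float],
) -> float:
    """Run `num_epochs` epochs and return the highest dev accuracy observed."""
    best = 0.0
    for epoch in range(num_epochs):
        run_epoch(epoch)            # one pass over the 'train' data (hypothetical)
        accuracy = evaluate(epoch)  # accuracy on the 'dev' data (hypothetical)
        best = max(best, accuracy)  # keep the running maximum
    return best


# Made-up per-epoch dev accuracies: the peak is at the second epoch.
dev_accuracies = [0.41, 0.57, 0.52]
best = track_best_accuracy(3, lambda epoch: None, lambda epoch: dev_accuracies[epoch])
print(best)  # 0.57
```

Under this reading, a caller comparing runs should treat the returned value as a peak score, which may come from an earlier epoch than the last one.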