Training

Contains the train module that governs the training of DeepMoD models.

train(model, data, target, optimizer, sparsity_scheduler, split=0.8, exp_ID=None, log_dir=None, max_iterations=10000, write_iterations=25, **convergence_kwargs)

Trains the DeepMoD model. This function automatically splits the data set into a train and a test set.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | DeepMoD | A DeepMoD object. | required |
| data | torch.Tensor | Tensor of shape (n_samples x (n_spatial + 1)) containing the coordinates; the first column should be the time coordinate. | required |
| target | torch.Tensor | Tensor of shape (n_samples x n_features) containing the target data. | required |
| optimizer | torch.optim.Optimizer | PyTorch optimizer. | required |
| sparsity_scheduler | | Decides when to update the sparsity mask. | required |
| split | float | Fraction of the data used for the train set. | 0.8 |
| exp_ID | str | Unique ID to identify the tensorboard file. Not used if log_dir is given; see the PyTorch documentation. | None |
| log_dir | str | Directory where the tensorboard file is written. | None |
| max_iterations | int | Maximum number of epochs. | 10000 |
| write_iterations | int | How often data is written to tensorboard and the train loss is checked. | 25 |
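
For orientation, the following is a minimal usage sketch, not a definitive recipe. The import paths and component constructors used below (NN, Library1D, Threshold, LeastSquares, TrainTestPeriodic) follow typical DeePyMoD example code and should be treated as assumptions to verify against the package; the random tensors stand in for real measurements.

```python
import torch

# Assumed DeePyMoD imports; check these paths against your installed version.
from deepymod import DeepMoD
from deepymod.model.func_approx import NN
from deepymod.model.library import Library1D
from deepymod.model.sparse_estimators import Threshold
from deepymod.model.constraint import LeastSquares
from deepymod.training import train
from deepymod.training.sparsity_scheduler import TrainTestPeriodic

# Coordinates: first column is time, second is space; one observed field.
data = torch.randn(1000, 2)    # (n_samples, n_spatial + 1)
target = torch.randn(1000, 1)  # (n_samples, n_features)

# Assumed components: a feed-forward network, a 1D candidate-term library,
# a thresholded sparse estimator, and a least-squares constraint.
network = NN(2, [30, 30, 30, 30], 1)
library = Library1D(poly_order=2, diff_order=3)
estimator = Threshold(0.1)
constraint = LeastSquares()
model = DeepMoD(network, library, estimator, constraint)

sparsity_scheduler = TrainTestPeriodic(periodicity=50, patience=200, delta=1e-5)
optimizer = torch.optim.Adam(model.parameters())

train(model, data, target, optimizer, sparsity_scheduler,
      split=0.8, exp_ID='example_run', max_iterations=10000)
```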
Source code in `deepymod/training/training.py`:

```python
def train(model: DeepMoD,
          data: torch.Tensor,
          target: torch.Tensor,
          optimizer,
          sparsity_scheduler,
          split: float = 0.8,
          exp_ID: str = None,
          log_dir: str = None,
          max_iterations: int = 10000,
          write_iterations: int = 25,
          **convergence_kwargs) -> None:
    """Trains the DeepMoD model. This function automatically splits the data set in a train and test set. 

    Args:
        model (DeepMoD):  A DeepMoD object.
        data (torch.Tensor):  Tensor of shape (n_samples x (n_spatial + 1)) containing the coordinates, first column should be the time coordinate.
        target (torch.Tensor): Tensor of shape (n_samples x n_features) containing the target data.
        optimizer ([type]):  Pytorch optimizer.
        sparsity_scheduler ([type]):  Decides when to update the sparsity mask.
        split (float, optional):  Fraction of the train set, by default 0.8.
        exp_ID (str, optional): Unique ID to identify tensorboard file. Not used if log_dir is given, see pytorch documentation.
        log_dir (str, optional): Directory where tensorboard file is written, by default None.
        max_iterations (int, optional): [description]. Max number of epochs , by default 10000.
        write_iterations (int, optional): [description]. Sets how often data is written to tensorboard and checks train loss , by default 25.
    """
    logger = Logger(exp_ID, log_dir)
    sparsity_scheduler.path = logger.log_dir # write checkpoint to same folder as tb output.

    # Splitting data, assumes data is already randomized
    n_train = int(split * data.shape[0])
    n_test = data.shape[0] - n_train
    data_train, data_test = torch.split(data, [n_train, n_test], dim=0)
    target_train, target_test = torch.split(target, [n_train, n_test], dim=0)

    # Training
    convergence = Convergence(**convergence_kwargs)
    for iteration in torch.arange(0, max_iterations):
        # ================== Training Model ============================
        prediction, time_derivs, thetas = model(data_train)

        MSE = torch.mean((prediction - target_train)**2, dim=0)  # loss per output
        Reg = torch.stack([torch.mean((dt - theta @ coeff_vector)**2)
                           for dt, theta, coeff_vector in zip(time_derivs, thetas, model.constraint_coeffs(scaled=False, sparse=True))])
        loss = torch.sum(MSE + Reg) 

        # Optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % write_iterations == 0:
            # ================== Validation costs ================
            with torch.no_grad():
                prediction_test = model.func_approx(data_test)[0]
                MSE_test = torch.mean((prediction_test - target_test)**2, dim=0)  # loss per output

            # ====================== Logging =======================
            _ = model.sparse_estimator(thetas, time_derivs) # calculating estimator coeffs but not setting mask
            logger(iteration, 
                   loss, MSE, Reg,
                   model.constraint_coeffs(sparse=True, scaled=True), 
                   model.constraint_coeffs(sparse=True, scaled=False),
                   model.estimator_coeffs(),
                   MSE_test=MSE_test)

            # ================== Sparsity update =============
            # Updating sparsity 
            update_sparsity = sparsity_scheduler(iteration, torch.sum(MSE_test), model, optimizer)
            if update_sparsity: 
                model.constraint.sparsity_masks = model.sparse_estimator(thetas, time_derivs)

            # ================= Checking convergence
            l1_norm = torch.sum(torch.abs(torch.cat(model.constraint_coeffs(sparse=True, scaled=True), dim=1)))
            converged = convergence(iteration, l1_norm)
            if converged:
                break
    logger.close(model)
```
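
Inside the loop, the loss sums, per output, the data misfit (MSE) with the residual of regressing each time derivative on its sparse library (Reg): `loss = torch.sum(MSE + Reg)`. Below is a minimal sketch of the Reg term in isolation, using hypothetical toy shapes; the variable names and sizes are illustrative, not taken from the library.

```python
import torch

# Hypothetical single-output problem with a 4-term candidate library.
n_samples, n_terms = 100, 4
time_deriv = torch.randn(n_samples, 1)   # u_t, as produced by the model
theta = torch.randn(n_samples, n_terms)  # evaluated library terms
coeffs = torch.randn(n_terms, 1)         # constraint coefficient vector

# Reg mirrors the regression residual inside train(): the mean squared
# error of u_t against theta @ coeffs.
Reg = torch.mean((time_deriv - theta @ coeffs) ** 2)
print(Reg)
```

Note also that when the sparsity scheduler triggers, the coefficients returned by `model.sparse_estimator` become the new `sparsity_masks` on the constraint, pruning library terms for the remaining iterations.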