From 33b5d9c83e9f5ea7fa1de7af7e1adccbb1e7c0af Mon Sep 17 00:00:00 2001 From: zomseffen Date: Sat, 12 Feb 2022 19:30:03 +0100 Subject: [PATCH] solves exiting --- labirinth_ai/Models/BaseModel.py | 7 ++++++- labirinth_ai/Subject.py | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/labirinth_ai/Models/BaseModel.py b/labirinth_ai/Models/BaseModel.py index e4210ad..e87a3c9 100644 --- a/labirinth_ai/Models/BaseModel.py +++ b/labirinth_ai/Models/BaseModel.py @@ -4,6 +4,9 @@ import numpy as np import tqdm from torch.utils.data import Dataset, DataLoader +import os +os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0" + device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using {device} device") @@ -36,7 +39,6 @@ class BaseModel(nn.Module): actions.append(self.actions[action](x_flat)) return torch.stack(actions, dim=1) - class BaseDataSet(Dataset): def __init__(self, states, targets): assert len(states) == len(targets), "Needs to have as many states as targets!" @@ -93,6 +95,9 @@ def train(states, targets, model, optimizer): print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") model.eval() + del data_set + del dataloader + if __name__ == '__main__': sample = np.random.random((1, 4, 11, 11)) diff --git a/labirinth_ai/Subject.py b/labirinth_ai/Subject.py index 3fe291a..5afa2a7 100644 --- a/labirinth_ai/Subject.py +++ b/labirinth_ai/Subject.py @@ -370,8 +370,7 @@ class NetLearner(Subject): self.x_in = [] self.actions = [] self.target = [] - self.model = BaseModel(self.viewD, 4, 4) - self.model.to(device) + self.model = BaseModel(self.viewD, 4, 4).to(device) self.optimizer = create_optimizer(self.model) if len(self.samples) < self.randomBuffer: