move to pytorch

This commit is contained in:
zomseffen 2022-02-12 17:35:15 +01:00
parent 0638d5e666
commit e718873caa
4 changed files with 135 additions and 22 deletions

View file

@ -17,9 +17,9 @@ class LabyrinthClient(Client):
if self.world_provider.world.board[x, y] in [1, 2]:
r, g, b = 57, 92, 152
if 1.5 >= self.world_provider.world.hunter_grass[x, y] > 0.5:
r, g, b = 25, 149, 156
if 3 >= self.world_provider.world.hunter_grass[x, y] > 1.5:
r, g, b = 112, 198, 169
if 3 >= self.world_provider.world.hunter_grass[x, y] > 1.5:
r, g, b = 25, 149, 156
self.world_provider.world.set_color(x, y, 0, r / 255.0, g / 255.0, b / 255.0)
if self.world_provider.world.board[x, y] == 3:
self.world_provider.world.set_color(x, y, 0, 139 / 255.0, 72 / 255.0, 82 / 255.0)

View file

@ -0,0 +1,125 @@
import torch
from torch import nn
import numpy as np
import tqdm
from torch.utils.data import Dataset, DataLoader
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class BaseModel(nn.Module):
def __init__(self, view_dimension, action_num, channels):
super(BaseModel, self).__init__()
self.flatten = nn.Flatten()
self.actions = []
self.action_num = action_num
self.viewD = view_dimension
self.channels = channels
for action in range(action_num):
action_sequence = nn.Sequential(
nn.Linear(channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2,
(2 * self.viewD + 1) * (2 * self.viewD + 1)),
nn.ELU(),
nn.Linear((2 * self.viewD + 1) * (2 * self.viewD + 1), (self.viewD + 1) * (self.viewD + 1)),
nn.ELU(),
nn.Linear((self.viewD + 1) * (self.viewD + 1), 2)
)
self.add_module('action_' + str(action), action_sequence)
self.actions.append(action_sequence)
def forward(self, x):
x_flat = self.flatten(x)
actions = []
for action in range(self.action_num):
actions.append(self.actions[action](x_flat))
return torch.stack(actions, dim=1)
class BaseDataSet(Dataset):
def __init__(self, states, targets):
assert len(states) == len(targets), "Needs to have as many states as targets!"
self.states = torch.tensor(states, dtype=torch.float32)
self.targets = torch.tensor(targets, dtype=torch.float32)
def __len__(self):
return len(self.states)
def __getitem__(self, idx):
return self.states[idx], self.targets[idx]
def create_optimizer(model):
return torch.optim.RMSprop(model.parameters(), lr=1e-3)
def create_loss_function(action):
def custom_loss(prediction, target):
return torch.mean(0.5 * torch.square(
0.1 * target[:, 0, 0] + target[:, 1, 0] - (
prediction[:, action, 0] + prediction[:, action, 1])) + 0.5 * torch.square(
target[:, 1, 0] - prediction[:, action, 0]), dim=0)
return custom_loss
def from_numpy(x):
return torch.tensor(x, dtype=torch.float32)
def train(states, targets, model, optimizer):
for action in range(model.action_num):
data_set = BaseDataSet(states[action], targets[action])
dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
loss_fn = create_loss_function(action)
size = len(dataloader)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
model.eval()
if __name__ == '__main__':
sample = np.random.random((1, 4, 11, 11))
model = BaseModel(5, 4, 4).to(device)
print(model)
test = model(torch.tensor(sample, dtype=torch.float32))
# test = test.cpu().detach().numpy()
print(test)
state = np.random.random((4, 11, 11))
target = np.random.random((4, 2))
states = [
[state],
[state],
[state],
[state],
]
targets = [
[target],
[target],
[target],
[target],
]
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
train(states, targets, model, optimizer)

View file

View file

@ -5,6 +5,7 @@ from tensorflow import keras
from labirinth_ai.LabyrinthWorld import LabyrinthWorld
from labirinth_ai.loss import loss2, loss3
from labirinth_ai.Models.BaseModel import BaseModel, train, create_optimizer, device, from_numpy
# import torch
# dtype = torch.float
@ -369,22 +370,9 @@ class NetLearner(Subject):
self.x_in = []
self.actions = []
self.target = []
for i in range(4):
x_in = keras.Input(shape=(self.channels * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
self.x_in.append(x_in)
inVec = keras.layers.Flatten()(x_in)
actions = keras.layers.Dense(((2 * self.viewD + 1) * (2 * self.viewD + 1)), activation='elu',
kernel_regularizer=keras.regularizers.l2(0.001),
name=self.name + str(self.id) + 'Dense' + str(i) + 'l1')(inVec)
actions = keras.layers.Dense(((self.viewD + 1) * (self.viewD + 1)), activation='elu',
kernel_regularizer=keras.regularizers.l2(0.001))(actions)
self.target.append(keras.Input(shape=(2, 1)))
self.actions.append(keras.layers.Dense(2, activation='linear', use_bias=False, kernel_regularizer=keras.regularizers.l2(0.001))(actions))
self.model = keras.Model(inputs=self.x_in, outputs=self.actions)
self.model.compile(optimizer=tf.keras.optimizers.RMSprop(self.learningRate), loss=loss3,
target_tensors=self.target)
self.model = BaseModel(self.viewD, 4, 4)
self.model.to(device)
self.optimizer = create_optimizer(self.model)
if len(self.samples) < self.randomBuffer:
self.random = True
@ -508,7 +496,7 @@ class NetLearner(Subject):
if state is None:
state = self.createState(world)
if vals is None:
vals = self.model.predict([state, state, state, state])
vals = self.model(from_numpy(state)).detach().numpy()
vals = np.reshape(np.transpose(np.reshape(vals, (4, 2)), (1, 0)),
(1, 8))
@ -623,9 +611,9 @@ class NetLearner(Subject):
target[:, 1, 0] = samples[:, 1, 3] #reward t-2 got
nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
nextVals = self.model.predict([nextState, nextState, nextState, nextState])
nextVals = self.model(from_numpy(nextState)).detach().numpy()
nextVals2 = nextVals[i][:, 0] + nextVals[i][:, 1]
nextVals2 = nextVals[:, i, 0] + nextVals[:, i, 1]
target[:, 0, 0] = nextVals2 #best q t-1
else:
target[:, 1, 0] = np.array(list(map(lambda elem: list(elem), list(np.array(samples[:, 1, 4])))))[:, i] # reward t-2 got
@ -639,7 +627,7 @@ class NetLearner(Subject):
def train(self):
print(self.name)
states, target = self.generateSamples()
self.model.fit(states, target, epochs=1)
train(states, target, self.model, self.optimizer)
self.samples = self.samples[-self.historySizeMul*self.batchsize:]