reworks loss function and training data, cleans up directions in code
This commit is contained in:
parent
cf4d773c10
commit
26e7ffb12b
2 changed files with 129 additions and 175 deletions
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import os
|
||||
|
@ -14,6 +14,7 @@ print(f"Using {device} device")
|
|||
# Define model
|
||||
class BaseModel(nn.Module):
|
||||
evolutionary = False
|
||||
|
||||
def __init__(self, view_dimension, action_num, channels):
|
||||
super(BaseModel, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
|
@ -59,11 +60,18 @@ def create_optimizer(model):
|
|||
|
||||
|
||||
def create_loss_function(action):
|
||||
lambda_factor = 0.0
|
||||
split_factor = 1.0
|
||||
def custom_loss(prediction, target):
|
||||
return torch.mean(0.5 * torch.square(
|
||||
0.1 * target[:, 0, 0] + target[:, 1, 0] - (
|
||||
prediction[:, action, 0] + prediction[:, action, 1])) + 0.5 * torch.square(
|
||||
target[:, 1, 0] - prediction[:, action, 0]), dim=0)
|
||||
return torch.mean(split_factor * torch.square(
|
||||
# discounted best estimate the old weights made for t+1
|
||||
lambda_factor * target[:, 0, 0] +
|
||||
# actual reward for t
|
||||
target[:, 1, 0] -
|
||||
# estimate for current weights
|
||||
(prediction[:, action, 0] + prediction[:, action, 1])) +
|
||||
# trying to learn present reward separate from future reward
|
||||
(1.0 - split_factor) * torch.square(target[:, 1, 0] - prediction[:, action, 0]), dim=0)
|
||||
|
||||
return custom_loss
|
||||
|
||||
|
@ -75,26 +83,31 @@ def from_numpy(x):
|
|||
def train(states, targets, model, optimizer):
|
||||
for action in range(model.action_num):
|
||||
data_set = BaseDataSet(states[action], targets[action])
|
||||
dataloader = DataLoader(data_set, batch_size=64, shuffle=True)
|
||||
dataloader = DataLoader(data_set, batch_size=256, shuffle=True)
|
||||
loss_fn = create_loss_function(action)
|
||||
|
||||
size = len(dataloader)
|
||||
model.train()
|
||||
for batch, (X, y) in enumerate(dataloader):
|
||||
X, y = X.to(device), y.to(device)
|
||||
|
||||
# Compute prediction error
|
||||
pred = model(X)
|
||||
loss = loss_fn(pred, y)
|
||||
epochs = 1
|
||||
with tqdm(range(epochs)) as progress_bar:
|
||||
for _ in enumerate(progress_bar):
|
||||
losses = []
|
||||
for batch, (X, y) in enumerate(dataloader):
|
||||
X, y = X.to(device), y.to(device)
|
||||
|
||||
# Backpropagation
|
||||
optimizer.zero_grad()
|
||||
loss.backward(retain_graph=True)
|
||||
optimizer.step()
|
||||
# Compute prediction error
|
||||
pred = model(X)
|
||||
loss = loss_fn(pred, y)
|
||||
|
||||
if batch % 100 == 0:
|
||||
loss, current = loss.item(), batch * len(X)
|
||||
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||
# Backpropagation
|
||||
optimizer.zero_grad()
|
||||
loss.backward(retain_graph=True)
|
||||
optimizer.step()
|
||||
|
||||
losses.append(loss.item())
|
||||
progress_bar.set_postfix(loss=np.average(losses))
|
||||
progress_bar.update()
|
||||
model.eval()
|
||||
|
||||
del data_set
|
||||
|
|
|
@ -108,34 +108,6 @@ class QLearner(Subject):
|
|||
def createState(self, world: LabyrinthWorld):
|
||||
state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.int) # - 1
|
||||
|
||||
# # floodfill state
|
||||
# queued = [(0, 0)]
|
||||
# todo = [(0, 0, 0)]
|
||||
# while todo != []:
|
||||
# doing = todo.pop(0)
|
||||
#
|
||||
# if self.x + doing[0] >= 0 and self.x + doing[0] < 64 and self.y + doing[1] >= 0 and self.y + doing[1] < 64:
|
||||
# value = world.board[self.x + doing[0], self.y + doing[1]]
|
||||
# state[self.viewD + doing[0], self.viewD + doing[1]] = value
|
||||
#
|
||||
# # if value == 3:
|
||||
# # state[self.viewD + doing[0], self.viewD + doing[1]] = value
|
||||
#
|
||||
# if value != 0 and doing[2] < self.viewD:
|
||||
# for i in range(-1, 2, 1):
|
||||
# for j in range(-1, 2, 1):
|
||||
# # 4-neighbour. without it it is 8-neighbour
|
||||
# if abs(i) + abs(j) == 1:
|
||||
# if (doing[0] + i, doing[1] + j) not in queued:
|
||||
# queued.append((doing[0] + i, doing[1] + j))
|
||||
# todo.append((doing[0] + i, doing[1] + j, doing[2] + 1))
|
||||
#
|
||||
# for sub in world.subjects:
|
||||
# if (sub.x - self.x, sub.y - self.y) in queued and state[
|
||||
# self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] != 3:
|
||||
# state[self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] = state[
|
||||
# self.viewD + sub.x - self.x, self.viewD + sub.y - self.y] * 100 + sub.col
|
||||
|
||||
maxdirleft = self.x - max(self.x - (self.viewD), 0)
|
||||
maxdirright = min(self.x + (self.viewD), (world.board_shape[0] - 1)) - self.x
|
||||
maxdirup = self.y - max(self.y - (self.viewD), 0)
|
||||
|
@ -300,6 +272,7 @@ class DoubleQLearner(QLearner):
|
|||
pass
|
||||
|
||||
|
||||
RECALCULATE = False
|
||||
class NetLearner(Subject):
|
||||
right = (1, 0)
|
||||
left = (-1, 0)
|
||||
|
@ -440,7 +413,6 @@ class NetLearner(Subject):
|
|||
axs[1, 1].set_title('grass')
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
def createState(self, world: LabyrinthWorld):
|
||||
state = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float) # - 1
|
||||
state2 = np.zeros((2 * self.viewD + 1, 2 * self.viewD + 1), np.float) # - 1
|
||||
|
@ -474,10 +446,8 @@ class NetLearner(Subject):
|
|||
action = self.lastAction
|
||||
return np.reshape(np.concatenate((area, action)), (1, 4 * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
|
||||
|
||||
def calculateAction(self, world: LabyrinthWorld, vals=None, state=None):
|
||||
# 0, 0 is top left
|
||||
def generate_valid_directions(self, world: LabyrinthWorld):
|
||||
directions = []
|
||||
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] != 0:
|
||||
directions.append(self.left)
|
||||
|
@ -493,9 +463,15 @@ class NetLearner(Subject):
|
|||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] != 0:
|
||||
directions.append(self.down)
|
||||
return directions
|
||||
|
||||
def calculateAction(self, world: LabyrinthWorld, vals=None, state=None):
|
||||
# 0, 0 is top left
|
||||
directions = self.generate_valid_directions(world)
|
||||
|
||||
if directions == []:
|
||||
print('Wut?')
|
||||
return
|
||||
|
||||
if directions != [] and self.alive:
|
||||
if state is None:
|
||||
|
@ -550,7 +526,8 @@ class NetLearner(Subject):
|
|||
self.nextTrain = min(self.batchsize + self.nextTrain, (self.historySizeMul + 1) * self.batchsize)
|
||||
print(len(self.samples), self.nextTrain)
|
||||
|
||||
self.accumulated_rewards += self.lastReward
|
||||
if not self.random:
|
||||
self.accumulated_rewards += self.lastReward
|
||||
|
||||
self.lastAction = self.action
|
||||
self.lastState = self.state
|
||||
|
@ -562,27 +539,10 @@ class NetLearner(Subject):
|
|||
self.executeAction(world, self.action)
|
||||
|
||||
def randomAct(self, world: LabyrinthWorld):
|
||||
right = (1, 0)
|
||||
left = (-1, 0)
|
||||
up = (0, -1)
|
||||
down = (0, 1)
|
||||
directions = []
|
||||
directions = self.generate_valid_directions(world)
|
||||
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] != 0:
|
||||
directions.append(left)
|
||||
|
||||
if self.x + 1 < world.board_shape[0]:
|
||||
if world.board[self.x + 1, self.y] != 0:
|
||||
directions.append(right)
|
||||
|
||||
if self.y - 1 >= 0:
|
||||
if world.board[self.x, self.y - 1] != 0:
|
||||
directions.append(up)
|
||||
|
||||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] != 0:
|
||||
directions.append(down)
|
||||
if len(directions) == 0:
|
||||
return 0, 0
|
||||
|
||||
d = random.randint(0, len(directions) - 1)
|
||||
action = directions[d]
|
||||
|
@ -616,16 +576,16 @@ class NetLearner(Subject):
|
|||
replace=True)
|
||||
samples = samples[index]
|
||||
# self.samples = []
|
||||
target[:, 1, 0] = samples[:, 0, 3] # reward t-2 got
|
||||
if partTwo:
|
||||
target[:, 1, 0] = samples[:, 1, 3] #reward t-2 got
|
||||
if RECALCULATE:
|
||||
nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
|
||||
nextVals = self.model(from_numpy(nextState)).detach().numpy()
|
||||
|
||||
nextState = np.concatenate(samples[:, 1, 0]) #states of t-1
|
||||
nextVals = self.model(from_numpy(nextState)).detach().numpy()
|
||||
|
||||
nextVals2 = nextVals[:, i, 0] + nextVals[:, i, 1]
|
||||
target[:, 0, 0] = nextVals2 #best q t-1
|
||||
else:
|
||||
target[:, 1, 0] = np.array(list(map(lambda elem: list(elem), list(np.array(samples[:, 1, 4])))))[:, i] # reward t-2 got
|
||||
nextVals2 = np.max(nextVals[:, :, 0] + nextVals[:, :, 1], axis=1)
|
||||
target[:, 0, 0] = nextVals2 #best q t-1
|
||||
else:
|
||||
target[:, 0, 0] = samples[:, 1, 2] #best q t-1
|
||||
|
||||
targets.append(target)
|
||||
|
||||
|
@ -697,27 +657,7 @@ class Herbivore(NetLearner):
|
|||
return np.reshape(np.concatenate((area, action)), (1, 4 * (2 * self.viewD + 1) * (2 * self.viewD + 1) + 2))
|
||||
|
||||
def executeAction(self, world: LabyrinthWorld, action):
|
||||
right = (1, 0)
|
||||
left = (-1, 0)
|
||||
up = (0, -1)
|
||||
down = (0, 1)
|
||||
directions = []
|
||||
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] != 0:
|
||||
directions.append(left)
|
||||
|
||||
if self.x + 1 < world.board_shape[0]:
|
||||
if world.board[self.x + 1, self.y] != 0:
|
||||
directions.append(right)
|
||||
|
||||
if self.y - 1 >= 0:
|
||||
if world.board[self.x, self.y - 1] != 0:
|
||||
directions.append(up)
|
||||
|
||||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] != 0:
|
||||
directions.append(down)
|
||||
directions = self.generate_valid_directions(world)
|
||||
if len(action) == 2:
|
||||
if len(world.subjectDict[(self.x + action[0], self.y + action[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
|
||||
|
@ -729,26 +669,26 @@ class Herbivore(NetLearner):
|
|||
self.alive = False
|
||||
|
||||
self.lastRewards = []
|
||||
if right in directions:
|
||||
if self.right in directions:
|
||||
self.lastRewards.append(world.grass[self.x + 1, self.y])
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
if left in directions:
|
||||
if self.left in directions:
|
||||
self.lastRewards.append(world.grass[self.x - 1, self.y])
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
if up in directions:
|
||||
if self.up in directions:
|
||||
self.lastRewards.append(world.grass[self.x, self.y - 1])
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
if down in directions:
|
||||
if self.down in directions:
|
||||
self.lastRewards.append(world.grass[self.x, self.y + 1])
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
assert len(self.lastRewards) == 4, 'Last Rewards not filled correctly!'
|
||||
|
||||
world.subjectDict[(self.x, self.y)].remove(self)
|
||||
self.lastReward += world.trailMix[self.x, self.y]
|
||||
# self.lastReward += world.trailMix[self.x, self.y]
|
||||
self.x += action[0]
|
||||
self.y += action[1]
|
||||
world.subjectDict[(self.x, self.y)].append(self)
|
||||
|
@ -757,33 +697,58 @@ class Herbivore(NetLearner):
|
|||
world.grass[self.x, self.y] = 0
|
||||
world.hunter_grass[self.x, self.y] = 0
|
||||
|
||||
def generate_valid_directions(self, world: LabyrinthWorld):
|
||||
directions = []
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] != 0:
|
||||
if not world.subjectDict[(self.x - 1, self.y)]:
|
||||
directions.append(self.left)
|
||||
|
||||
if self.x + 1 < world.board_shape[0]:
|
||||
if world.board[self.x + 1, self.y] != 0:
|
||||
if not world.subjectDict[(self.x + 1, self.y)]:
|
||||
directions.append(self.right)
|
||||
|
||||
if self.y - 1 >= 0:
|
||||
if world.board[self.x, self.y - 1] != 0:
|
||||
if not world.subjectDict[(self.x, self.y - 1)]:
|
||||
directions.append(self.up)
|
||||
|
||||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] != 0:
|
||||
if not world.subjectDict[(self.x, self.y + 1)]:
|
||||
directions.append(self.down)
|
||||
return directions
|
||||
|
||||
def randomAct(self, world: LabyrinthWorld):
|
||||
right = (1, 0)
|
||||
left = (-1, 0)
|
||||
up = (0, -1)
|
||||
down = (0, 1)
|
||||
directions = []
|
||||
actDict = {}
|
||||
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] != 0:
|
||||
directions.append(left)
|
||||
actDict[left] = world.grass[self.x - 1, self.y]
|
||||
if not world.subjectDict[(self.x - 1, self.y)]:
|
||||
directions.append(self.left)
|
||||
actDict[self.left] = world.grass[self.x - 1, self.y]
|
||||
|
||||
if self.x + 1 < world.board_shape[0]:
|
||||
if world.board[self.x + 1, self.y] != 0:
|
||||
directions.append(right)
|
||||
actDict[right] = world.grass[self.x + 1, self.y]
|
||||
if not world.subjectDict[(self.x + 1, self.y)]:
|
||||
directions.append(self.right)
|
||||
actDict[self.right] = world.grass[self.x + 1, self.y]
|
||||
|
||||
if self.y - 1 >= 0:
|
||||
if world.board[self.x, self.y - 1] != 0:
|
||||
directions.append(up)
|
||||
actDict[up] = world.grass[self.x, self.y - 1]
|
||||
if not world.subjectDict[(self.x, self.y - 1)]:
|
||||
directions.append(self.up)
|
||||
actDict[self.up] = world.grass[self.x, self.y - 1]
|
||||
|
||||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] != 0:
|
||||
directions.append(down)
|
||||
actDict[down] = world.grass[self.x, self.y + 1]
|
||||
if not world.subjectDict[(self.x, self.y + 1)]:
|
||||
directions.append(self.down)
|
||||
actDict[self.down] = world.grass[self.x, self.y + 1]
|
||||
if len(directions) == 0:
|
||||
return 0, 0
|
||||
|
||||
allowedActions = dict(filter(lambda elem: elem[0] in directions, actDict.items()))
|
||||
action = max(allowedActions, key=allowedActions.get)
|
||||
|
@ -792,7 +757,7 @@ class Herbivore(NetLearner):
|
|||
|
||||
def respawnUpdate(self, x, y, world: LabyrinthWorld):
|
||||
super(Herbivore, self).respawnUpdate(x, y, world)
|
||||
self.lastReward -= 1
|
||||
# self.lastReward -= 1
|
||||
|
||||
|
||||
class Hunter(NetLearner):
|
||||
|
@ -802,16 +767,12 @@ class Hunter(NetLearner):
|
|||
g = 255
|
||||
b = 255
|
||||
def randomAct(self, world: LabyrinthWorld):
|
||||
right = (1, 0)
|
||||
left = (-1, 0)
|
||||
up = (0, -1)
|
||||
down = (0, 1)
|
||||
directions = []
|
||||
actDict = {}
|
||||
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] > 0.01:
|
||||
directions.append(left)
|
||||
directions.append(self.left)
|
||||
|
||||
sub = self.getClosestSubject(world, self.x - 1, self.y)
|
||||
dist = self.viewD
|
||||
|
@ -819,15 +780,15 @@ class Hunter(NetLearner):
|
|||
dist = np.sqrt(np.square(self.x - 1 - sub.x) + np.square(self.y - sub.y))
|
||||
distReward = self.viewD - dist
|
||||
|
||||
actDict[left] = world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + left[0], self.y + left[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + left[0], self.y + left[1])]:
|
||||
actDict[self.left] = world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + self.left[0], self.y + self.left[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + self.left[0], self.y + self.left[1])]:
|
||||
if sub.col != self.col:
|
||||
actDict[left] += 10
|
||||
actDict[self.left] += 10
|
||||
|
||||
if self.x + 1 < world.board_shape[0]:
|
||||
if world.board[self.x + 1, self.y] > 0.01:
|
||||
directions.append(right)
|
||||
directions.append(self.right)
|
||||
|
||||
sub = self.getClosestSubject(world, self.x + 1, self.y)
|
||||
dist = self.viewD
|
||||
|
@ -835,15 +796,15 @@ class Hunter(NetLearner):
|
|||
dist = np.sqrt(np.square(self.x + 1 - sub.x) + np.square(self.y - sub.y))
|
||||
distReward = self.viewD - dist
|
||||
|
||||
actDict[right] = world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + right[0], self.y + right[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + right[0], self.y + right[1])]:
|
||||
actDict[self.right] = world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + self.right[0], self.y + self.right[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + self.right[0], self.y + self.right[1])]:
|
||||
if sub.col != self.col:
|
||||
actDict[right] += 10
|
||||
actDict[self.right] += 10
|
||||
|
||||
if self.y - 1 >= 0:
|
||||
if world.board[self.x, self.y - 1] > 0.01:
|
||||
directions.append(up)
|
||||
directions.append(self.up)
|
||||
|
||||
sub = self.getClosestSubject(world, self.x, self.y - 1)
|
||||
dist = self.viewD
|
||||
|
@ -851,15 +812,15 @@ class Hunter(NetLearner):
|
|||
dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y - 1 - sub.y))
|
||||
distReward = self.viewD - dist
|
||||
|
||||
actDict[up] = world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + up[0], self.y + up[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + up[0], self.y + up[1])]:
|
||||
actDict[self.up] = world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + self.up[0], self.y + self.up[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + self.up[0], self.y + self.up[1])]:
|
||||
if sub.col != self.col:
|
||||
actDict[up] += 10
|
||||
actDict[self.up] += 10
|
||||
|
||||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] > 0.01:
|
||||
directions.append(down)
|
||||
directions.append(self.down)
|
||||
|
||||
sub = self.getClosestSubject(world, self.x, self.y + 1)
|
||||
dist = self.viewD
|
||||
|
@ -867,11 +828,11 @@ class Hunter(NetLearner):
|
|||
dist = np.sqrt(np.square(self.x - sub.x) + np.square(self.y + 1 - sub.y))
|
||||
distReward = self.viewD - dist
|
||||
|
||||
actDict[down] = world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + down[0], self.y + down[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + down[0], self.y + down[1])]:
|
||||
actDict[self.down] = world.trailMix[self.x, self.y + 1] + world.hunter_grass[self.x, self.y + 1] * self.hunterGrassScale + distReward
|
||||
if len(world.subjectDict[(self.x + self.down[0], self.y + self.down[1])]) > 0:
|
||||
for sub in world.subjectDict[(self.x + self.down[0], self.y + self.down[1])]:
|
||||
if sub.col != self.col:
|
||||
actDict[down] += 10
|
||||
actDict[self.down] += 10
|
||||
|
||||
if len(actDict) > 0:
|
||||
allowedActions = dict(filter(lambda elem: elem[0] in directions, actDict.items()))
|
||||
|
@ -919,47 +880,27 @@ class Hunter(NetLearner):
|
|||
def executeAction(self, world: LabyrinthWorld, action):
|
||||
grass_factor = 0.5
|
||||
|
||||
right = (1, 0)
|
||||
left = (-1, 0)
|
||||
up = (0, -1)
|
||||
down = (0, 1)
|
||||
directions = []
|
||||
|
||||
if self.x - 1 >= 0:
|
||||
if world.board[self.x - 1, self.y] != 0:
|
||||
directions.append(left)
|
||||
|
||||
if self.x + 1 < world.board_shape[0]:
|
||||
if world.board[self.x + 1, self.y] != 0:
|
||||
directions.append(right)
|
||||
|
||||
if self.y - 1 >= 0:
|
||||
if world.board[self.x, self.y - 1] != 0:
|
||||
directions.append(up)
|
||||
|
||||
if self.y + 1 < world.board_shape[1]:
|
||||
if world.board[self.x, self.y + 1] != 0:
|
||||
directions.append(down)
|
||||
directions = self.generate_valid_directions(world)
|
||||
|
||||
if len(action) == 2:
|
||||
right_kill = left_kill = up_kill = down_kill = False
|
||||
if right in directions:
|
||||
for sub in world.subjectDict[(self.x + action[0], self.y + action[1])]:
|
||||
if self.right in directions:
|
||||
for sub in world.subjectDict[(self.x + self.right[0], self.y + self.right[1])]:
|
||||
if sub.alive:
|
||||
if sub.col != self.col:
|
||||
right_kill = True
|
||||
if left in directions:
|
||||
for sub in world.subjectDict[(self.x + left[0], self.y + left[1])]:
|
||||
if self.left in directions:
|
||||
for sub in world.subjectDict[(self.x + self.left[0], self.y + self.left[1])]:
|
||||
if sub.alive:
|
||||
if sub.col != self.col:
|
||||
left_kill = True
|
||||
if up in directions:
|
||||
for sub in world.subjectDict[(self.x + up[0], self.y + up[1])]:
|
||||
if self.up in directions:
|
||||
for sub in world.subjectDict[(self.x + self.up[0], self.y + self.up[1])]:
|
||||
if sub.alive:
|
||||
if sub.col != self.col:
|
||||
up_kill = True
|
||||
if down in directions:
|
||||
for sub in world.subjectDict[(self.x + down[0], self.y + down[1])]:
|
||||
if self.down in directions:
|
||||
for sub in world.subjectDict[(self.x + self.down[0], self.y + self.down[1])]:
|
||||
if sub.alive:
|
||||
if sub.col != self.col:
|
||||
down_kill = True
|
||||
|
@ -974,7 +915,7 @@ class Hunter(NetLearner):
|
|||
self.alive = True
|
||||
|
||||
self.lastRewards = []
|
||||
if right in directions:
|
||||
if self.right in directions:
|
||||
sub = self.getClosestSubject(world, self.x + 1, self.y)
|
||||
dist = self.viewD
|
||||
if sub is not None:
|
||||
|
@ -986,7 +927,7 @@ class Hunter(NetLearner):
|
|||
self.lastRewards.append(world.trailMix[self.x + 1, self.y] + world.hunter_grass[self.x + 1, self.y] * grass_factor + distReward)
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
if left in directions:
|
||||
if self.left in directions:
|
||||
sub = self.getClosestSubject(world, self.x - 1, self.y)
|
||||
dist = self.viewD
|
||||
if sub is not None:
|
||||
|
@ -998,7 +939,7 @@ class Hunter(NetLearner):
|
|||
self.lastRewards.append(world.trailMix[self.x - 1, self.y] + world.hunter_grass[self.x - 1, self.y] * grass_factor + distReward)
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
if up in directions:
|
||||
if self.up in directions:
|
||||
sub = self.getClosestSubject(world, self.x, self.y - 1)
|
||||
dist = self.viewD
|
||||
if sub is not None:
|
||||
|
@ -1010,7 +951,7 @@ class Hunter(NetLearner):
|
|||
self.lastRewards.append(world.trailMix[self.x, self.y - 1] + world.hunter_grass[self.x, self.y - 1] * grass_factor + distReward)
|
||||
else:
|
||||
self.lastRewards.append(0)
|
||||
if down in directions:
|
||||
if self.down in directions:
|
||||
sub = self.getClosestSubject(world, self.x, self.y + 1)
|
||||
dist = self.viewD
|
||||
if sub is not None:
|
||||
|
|
Loading…
Add table
Reference in a new issue