beundling weights

This commit is contained in:
zomseffen 2022-12-21 16:08:22 +01:00
parent bd56173379
commit b0d22f6bf1

View file

@ -51,36 +51,40 @@ class EvolutionModel(nn.Module):
self.incoming_connections[connection.end].append(connection)
self.layers = {}
self.layer_non_recurrent_inputs = {}
self.layer_recurrent_inputs = {}
self.layer_results = {}
self.layer_num = 1
self.indices = {}
self.has_recurrent = False
self.non_recurrent_indices = {}
self.recurrent_indices = {}
with torch.no_grad():
for key, value in self.incoming_connections.items():
value.sort(key=lambda element: element.start)
lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
for index, connection in enumerate(value):
lin.weight[0, index] = value[index].weight
if self.genes.nodes[key].bias is not None:
lin.bias[0] = self.genes.nodes[key].bias
for key, value in self.incoming_connections.items():
value.sort(key=lambda element: element.start)
non_lin = nn.ELU()
sequence = nn.Sequential(
lin,
non_lin
)
self.add_module('layer_' + str(key), sequence)
self.layers[key] = sequence
self.indices[key] = list(map(lambda element: element.start, value))
# lin = nn.Linear(len(value), 1, bias=self.genes.nodes[key].bias is not None)
# for index, connection in enumerate(value):
# lin.weight[0, index] = value[index].weight
# if self.genes.nodes[key].bias is not None:
# lin.bias[0] = self.genes.nodes[key].bias
#
# non_lin = nn.ELU()
# sequence = nn.Sequential(
# lin,
# non_lin
# )
# self.add_module('layer_' + str(key), sequence)
# self.layers[key] = sequence
self.indices[key] = list(map(lambda element: element.start, value))
self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
self.has_recurrent = True
self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
self.non_recurrent_indices[key] = list(filter(lambda element: not element.recurrent, value))
self.recurrent_indices[key] = list(filter(lambda element: element.recurrent, value))
if not self.has_recurrent and len(self.non_recurrent_indices[key]) != len(self.indices[key]):
self.has_recurrent = True
self.non_recurrent_indices[key] = list(map(lambda element: element.start, self.non_recurrent_indices[key]))
self.recurrent_indices[key] = list(map(lambda element: element.start, self.recurrent_indices[key]))
rank_of_node = {}
for i in range(self.num_input_nodes):
rank_of_node[i] = 0
@ -101,20 +105,39 @@ class EvolutionModel(nn.Module):
rank_of_node[key] = max_rank + 1
layers_to_add = list(filter(lambda element: element[0] not in rank_of_node.keys(), layers_to_add))
ranked_layers = list(rank_of_node.items())
ranked_layers.sort(key=lambda element: element[1])
ranked_layers = list(filter(lambda element: element[1] > 0, ranked_layers))
ranked_layers = list(map(lambda element: (element, 0),
filter(lambda recurrent_element:
recurrent_element not in list(
map(lambda ranked_layer: ranked_layer[0], ranked_layers)
),
list(filter(lambda recurrent_keys:
len(self.recurrent_indices[recurrent_keys]) > 0,
self.recurrent_indices.keys()))))) + ranked_layers
with torch.no_grad():
self.layer_num = max_rank = max(map(lambda element: element[1], rank_of_node.items()))
#todo: handle solely recurrent nodes
for rank in range(1, max_rank + 1):
# get nodes
nodes = list(map(lambda element: element[0], filter(lambda item: item[1] == rank, rank_of_node.items())))
non_recurrent_inputs = list(set.union(*map(lambda node: set(self.non_recurrent_indices[node]), nodes)))
non_recurrent_inputs.sort()
recurrent_inputs = list(set.union(*map(lambda node: set(self.recurrent_indices[node]), nodes)))
recurrent_inputs.sort()
lin = nn.Linear(len(non_recurrent_inputs) + len(recurrent_inputs), len(nodes), bias=True)
# todo: load weights
# for index, connection in enumerate(value):
# lin.weight[0, index] = value[index].weight
# if self.genes.nodes[key].bias is not None:
# lin.bias[0] = self.genes.nodes[key].bias
#
non_lin = nn.ELU()
sequence = nn.Sequential(
lin,
non_lin
)
self.add_module('layer_' + str(rank), sequence)
self.layers[rank] = sequence
self.layer_results[rank] = nodes
self.layer_non_recurrent_inputs[rank] = non_recurrent_inputs
self.layer_recurrent_inputs[rank] = recurrent_inputs
self.layer_order = list(map(lambda element: element[0], ranked_layers))
self.memory_size = (max(map(lambda element: element[1].node_id, self.genes.nodes.items())) + 1)
self.memory = torch.Tensor(self.memory_size)
self.output_range = range(self.num_input_nodes, self.num_input_nodes + self.action_num * 2)
@ -130,24 +153,25 @@ class EvolutionModel(nn.Module):
outs = []
for batch_index, batch_element in enumerate(x_flat):
memory[0:self.num_input_nodes] = batch_element
for layer_index in self.layer_order:
non_recurrent_in = memory[self.non_recurrent_indices[layer_index]]
for layer_index in range(1, self.layer_num + 1):
non_recurrent_in = memory[self.layer_non_recurrent_inputs[layer_index]]
non_recurrent_in = torch.stack([non_recurrent_in])
if self.has_recurrent and len(self.recurrent_indices[layer_index]) > 0:
recurrent_in = last_memory_flat[batch_index, self.recurrent_indices[layer_index]]
if self.has_recurrent and len(self.layer_recurrent_inputs[layer_index]) > 0:
recurrent_in = last_memory_flat[batch_index, self.layer_recurrent_inputs[layer_index]]
recurrent_in = torch.stack([recurrent_in])
combined_in = torch.concat([non_recurrent_in, recurrent_in], dim=1)
else:
combined_in = non_recurrent_in
memory[layer_index] = self.layers[layer_index](combined_in)
outs.append(memory[self.num_input_nodes: self.num_input_nodes + self.action_num * 2])
memory[self.layer_results[layer_index]] = self.layers[layer_index](combined_in)
outs.append(memory[self.output_range])
outs = torch.stack(outs)
self.memory = torch.Tensor(memory)
return torch.reshape(outs, (x.shape[0], outs.shape[1]//2, 2))
def update_genes_with_weights(self):
# todo rework
for key, value in self.incoming_connections.items():
value.sort(key=lambda element: element.start)