227 lines
9.9 KiB
Python
227 lines
9.9 KiB
Python
from abc import abstractmethod
|
|
from typing import List, Dict
|
|
from copy import copy
|
|
|
|
import numpy as np
|
|
|
|
|
|
class NodeGene:
    """A single node gene of a genome.

    Only ``'hidden'`` nodes carry a bias. For ``'sensor'`` and ``'output'`` nodes the
    ``bias`` argument is ignored and the attribute is set to ``None``.
    """

    valid_types = ['sensor', 'hidden', 'output']

    def __init__(self, node_id, node_type, bias=None):
        """Create a node gene.

        :param node_id: Identifier of the node within its genome
        :param node_type: One of ``valid_types``
        :param bias: Bias of a hidden node; drawn uniformly from [-1, 1) when omitted
        :raises ValueError: If ``node_type`` is not one of ``valid_types``
        """
        # Explicit check instead of ``assert`` so the validation survives ``python -O``.
        if node_type not in self.valid_types:
            raise ValueError('Unknown node type!')

        self.node_id = node_id
        self.node_type = node_type

        if node_type == 'hidden':
            if bias is None:
                bias = np.random.random(1)[0] * 2 - 1.0
            self.bias = bias
        else:
            self.bias = None

    def __copy__(self):
        """Return an independent copy with the same id, type and bias."""
        return NodeGene(self.node_id, self.node_type, bias=self.bias)

    def __repr__(self):
        return (f'NodeGene(node_id={self.node_id!r}, node_type={self.node_type!r}, '
                f'bias={self.bias!r})')
|
|
|
|
|
|
class ConnectionGene:
    """Directed, weighted link from node ``start`` to node ``end``.

    Connections flagged ``recurrent`` are skipped when the genome ranks its nodes.
    When ``weight`` is omitted it is drawn uniformly from [-1, 1).
    """

    def __init__(self, start, end, enabled, innovation_num, weight=None, recurrent=False):
        self.start, self.end = start, end
        self.enabled = enabled
        # NOTE: the attribute really is spelled ``innvovation_num``; crossover reads it under
        # that exact name, so it is intentionally not corrected here.
        self.innvovation_num = innovation_num
        self.recurrent = recurrent
        self.weight = weight if weight is not None else np.random.random(1)[0] * 2 - 1.0

    def __copy__(self):
        """Return a new gene with identical endpoints, flags, innovation number and weight."""
        return ConnectionGene(
            self.start,
            self.end,
            self.enabled,
            innovation_num=self.innvovation_num,
            weight=self.weight,
            recurrent=self.recurrent,
        )
|
|
|
|
|
|
class Genotype:
    """Base genome: node genes keyed by node id plus a list of connection genes.

    Subclasses implement :meth:`mutate` and :meth:`cross`.

    NOTE(review): this class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not
    enforced at instantiation time; the base implementations simply raise
    ``NotImplementedError``.
    """

    def __init__(self, action_num: int = None, num_input_nodes: int = None,
                 nodes: Dict[int, 'NodeGene'] = None, connections: List['ConnectionGene'] = None):
        """Build a genome.

        When both ``action_num`` and ``num_input_nodes`` are given, create ``num_input_nodes``
        sensor nodes followed by ``2 * action_num`` output nodes (ids assigned consecutively
        from 0), and connect every sensor to every output. The connection from sensor ``i``
        to the ``j``-th output gets innovation number ``i * (2 * action_num) + j``.

        When both ``nodes`` and ``connections`` are given, they replace whatever was built
        above. They are used as-is, not copied. Supplying only one of the pair has no effect.
        """
        self.nodes: Dict[int, 'NodeGene'] = {}
        self.connections: List['ConnectionGene'] = []

        if action_num is not None and num_input_nodes is not None:
            output_count = action_num * 2

            for node_id in range(num_input_nodes):
                self.nodes[node_id] = NodeGene(node_id, 'sensor')

            first_action = num_input_nodes
            for node_id in range(first_action, first_action + output_count):
                self.nodes[node_id] = NodeGene(node_id, 'output')

            for index in range(num_input_nodes):
                for action in range(output_count):
                    self.connections.append(
                        ConnectionGene(index, first_action + action, True, index * output_count + action)
                    )

        if nodes is not None and connections is not None:
            self.nodes = nodes
            self.connections = connections

    def calculate_rank_of_nodes(self):
        """Return the rank (longest-path depth) of every node in ``self.nodes``.

        Only connections that are enabled and not recurrent are considered. A node with no
        such incoming connection has rank 0. Otherwise its rank is one more than the highest
        rank among the nodes feeding it.

        The returned dict lists nodes in the order they were ranked: at each step, the
        earliest node (in ``self.nodes`` order) whose inputs are all ranked. Connections
        ending in a node that is not in ``self.nodes`` are ignored.

        :return: Dict mapping node id -> rank
        :raises ValueError: If some nodes can never be ranked, because their enabled,
            non-recurrent inputs form a cycle or start at a node missing from ``self.nodes``.
        """
        # Distinct source nodes of each node's enabled, non-recurrent inputs. Using a set
        # means duplicate (start, end) connections cannot block ranking.
        inputs_of = {node_id: set() for node_id in self.nodes}
        for connection in self.connections:
            if connection.enabled and not connection.recurrent and connection.end in inputs_of:
                inputs_of[connection.end].add(connection.start)

        rank_of_node = {}
        pending = list(self.nodes)
        while pending:
            for position, node_id in enumerate(pending):
                sources = inputs_of[node_id]
                if all(source in rank_of_node for source in sources):
                    rank_of_node[node_id] = max((rank_of_node[source] for source in sources), default=-1) + 1
                    del pending[position]
                    break
            else:
                # No pending node became rankable in a full scan: further passes cannot help.
                raise ValueError('Cannot rank nodes {}: their enabled, non-recurrent inputs form a '
                                 'cycle or start at an unknown node'.format(pending))
        return rank_of_node

    @abstractmethod
    def mutate(self, innovation_num) -> int:
        """
        Decides whether or not to mutate this network. Then returns the new innovation number.

        :param innovation_num: Current innovation number
        :return: Updated innovation number
        """
        raise NotImplementedError()

    @abstractmethod
    def cross(self, other, fitnes_self, fitness_other):
        """Combine this genome with ``other`` into a child genome (implemented by subclasses)."""
        raise NotImplementedError()
|
|
|
|
|
|
class NeatLike(Genotype):
    """NEAT-style genome with structural mutations and fitness-biased crossover."""

    connection_add_thr = 0.3  # probability, per mutate() call, of trying to add a connection
    node_add_thr = 0.3        # probability of splitting an enabled connection with a hidden node
    disable_conn_thr = 0.1    # probability of disabling one enabled connection

    # Number of random (start, end) candidate pairs tried when adding a connection.
    connection_add_tries = 50

    def mutate(self, innovation_num, allow_recurrent=False) -> int:
        """
        Decides whether or not to mutate this network. Then returns the new innovation number.

        Three mutations are attempted in this order, each with its class-level probability:
        add a connection, add a hidden node, disable a connection.

        ``innovation_num`` is treated as the last innovation number already handed out. Each
        new connection gene receives the next number, and the value returned is the last
        number consumed (or the input unchanged when nothing was added).

        :param innovation_num: Current innovation number
        :param allow_recurrent: Optional parameter allowing or disallowing recurrent connections to form
        :return: Updated innovation number
        """
        if np.random.random(1)[0] < self.connection_add_thr:
            innovation_num = self._mutate_add_connection(innovation_num, allow_recurrent)

        if np.random.random(1)[0] < self.node_add_thr:
            innovation_num = self._mutate_add_node(innovation_num)

        if np.random.random(1)[0] < self.disable_conn_thr:
            self._mutate_disable_connection()

        return innovation_num

    def _mutate_add_connection(self, innovation_num, allow_recurrent):
        """Try to add one new enabled connection; return the (possibly advanced) innovation number."""
        nodes = list(self.nodes.keys())
        rank_of_node = self.calculate_rank_of_nodes()
        # Only nodes with rank > 0 (at least one enabled, non-recurrent input) may be targets.
        # This keeps sensor nodes, which start without inputs, from ever being targeted.
        end_nodes = [node for node in nodes if rank_of_node[node] > 0]
        if not end_nodes:
            return innovation_num

        existing_pairs = {(connection.start, connection.end) for connection in self.connections}
        for _ in range(self.connection_add_tries):
            start = nodes[np.random.randint(0, len(nodes))]
            end = end_nodes[np.random.randint(0, len(end_nodes))]
            # Pointing from a higher-ranked node to a lower-ranked one would close a loop.
            recurrent = rank_of_node[start] > rank_of_node[end]
            if start == end or (start, end) in existing_pairs or (recurrent and not allow_recurrent):
                continue
            innovation_num += 1
            self.connections.append(ConnectionGene(start, end, True, innovation_num, recurrent=recurrent))
            break
        return innovation_num

    def _mutate_add_node(self, innovation_num):
        """Split a random enabled connection A->B into A->new and new->B; return the innovation number."""
        active_connections = [connection for connection in self.connections if connection.enabled]
        if not active_connections:
            return innovation_num
        old_connection = active_connections[np.random.randint(0, len(active_connections))]

        # Pre-increment, like _mutate_add_connection, so both mutations in one call never
        # reuse an innovation number. The node id comes from the same counter; skip any value
        # that already names a node here so an existing sensor/output node is not overwritten.
        innovation_num += 1
        while innovation_num in self.nodes:
            innovation_num += 1
        node_id = innovation_num

        new_node = NodeGene(node_id, 'hidden')
        connection_1 = ConnectionGene(old_connection.start, node_id, True, innovation_num,
                                      recurrent=old_connection.recurrent)
        innovation_num += 1
        connection_2 = ConnectionGene(node_id, old_connection.end, True, innovation_num)

        old_connection.enabled = False
        self.nodes[node_id] = new_node
        self.connections.append(connection_1)
        self.connections.append(connection_2)
        return innovation_num

    def _mutate_disable_connection(self):
        """Disable one randomly chosen enabled connection, if there is any."""
        active_connections = [connection for connection in self.connections if connection.enabled]
        if active_connections:
            active_connections[np.random.randint(0, len(active_connections))].enabled = False

    @staticmethod
    def _inherit_gene(key, own_genes, other_genes, own_rank, other_rank):
        """Return a copy of the gene ``key`` that the child inherits, or ``None`` if it inherits none.

        A gene present in both parents is copied from the fitter one, or from a random one on a
        tie. A gene present in only one parent is copied when that parent is at least as fit.
        """
        in_own = key in own_genes
        in_other = key in other_genes
        if in_own and in_other:
            if own_rank == other_rank:
                source = own_genes if np.random.randint(0, 2) == 0 else other_genes
            elif own_rank > other_rank:
                source = own_genes
            else:
                source = other_genes
            return copy(source[key])
        if in_own and own_rank >= other_rank:
            return copy(own_genes[key])
        if in_other and own_rank <= other_rank:
            return copy(other_genes[key])
        return None

    def cross(self, other, fitnes_self, fitness_other):
        """Create a child genome from ``self`` and ``other``.

        Nodes are matched by id and connections by innovation number. Fitnesses are compared
        after truncation with ``int()``, so e.g. 1.2 and 1.9 count as a tie. On a tie, matching
        genes come from a random parent and the unmatched genes of both parents are kept.
        Otherwise matching genes and unmatched genes come from the fitter parent only.

        :param other: The other parent genome
        :param fitnes_self: Fitness of ``self``
        :param fitness_other: Fitness of ``other``
        :return: A new ``NeatLike`` child; all genes are copies, the parents are not modified
        """
        own_rank = int(fitnes_self)
        other_rank = int(fitness_other)

        own_connections = {connection.innvovation_num: connection for connection in self.connections}
        other_connections = {connection.innvovation_num: connection for connection in other.connections}

        new_genes = NeatLike()

        # Sorted so the child's gene order is deterministic (by node id / innovation number).
        for node_num in sorted(set(self.nodes) | set(other.nodes)):
            node = self._inherit_gene(node_num, self.nodes, other.nodes, own_rank, other_rank)
            if node is not None:
                new_genes.nodes[node_num] = node

        for connection_num in sorted(set(own_connections) | set(other_connections)):
            connection = self._inherit_gene(connection_num, own_connections, other_connections,
                                            own_rank, other_rank)
            if connection is not None:
                new_genes.connections.append(connection)

        return new_genes
|