Commit 1b29e8f5 authored by DEBATY Thomas's avatar DEBATY Thomas
Browse files

add flappy bird exemple

parent 3c618630
import pygame
import os
class Base:
"""Représente le sol en mouvement (base)."""
VEL = 5 # Vitesse de défilement du sol
IMG = pygame.image.load(os.path.join("imgs", "ground.png")).convert_alpha() # Chargement de l'image du sol
WIDTH = IMG.get_width() # Largeur de l'image du sol
def __init__(self, y):
self.y = y # Position verticale (y) du sol
self.x1 = 0 # Position en x de la première image du sol
self.x2 = self.WIDTH # Position en x de la seconde image du sol, juste après la première
def move(self):
"""Déplace le sol pour créer un effet de défilement."""
self.x1 -= self.VEL # Déplace la première image vers la gauche
self.x2 -= self.VEL # Déplace la seconde image vers la gauche
# Si la première image sort de l'écran, elle est repositionnée après la seconde
if self.x1 + self.WIDTH < 0:
self.x1 = self.x2 + self.WIDTH
# Si la seconde image sort de l'écran, elle est repositionnée après la première
if self.x2 + self.WIDTH < 0:
self.x2 = self.x1 + self.WIDTH
def draw(self, win):
"""Dessine les images du sol à leurs positions actuelles."""
win.blit(self.IMG, (self.x1, self.y)) # Dessine la première image du sol
win.blit(self.IMG, (self.x2, self.y)) # Dessine la seconde image du sol
import pygame
import os
from utils import blit_rotate_center
class Bird:
"""Représente l'oiseau dans le jeu."""
# Variables pour ajuster facilement le comportement
MAX_ROTATION = 25 # Angle d'inclinaison maximal de l'oiseau
ROT_VEL = 20 # Vitesse de rotation (changement d'angle par tick)
ANIMATION_TIME = 5 # Temps entre chaque image de l'animation (en ticks)
# Paramètres ajustables
GRAVITY = 3 # Gravité (vitesse de descente)
JUMP_STRENGTH = -10.5 # Force du saut (vitesse de montée)
SCALE_FACTOR = 1 # Facteur de mise à l'échelle de la taille de l'oiseau (1 = taille originale)
# Chargement des images de l'oiseau (taille originale)
ORIGINAL_BIRD_IMGS = [
pygame.image.load(os.path.join("imgs", "bird_wing_up.png")).convert_alpha(),
pygame.image.load(os.path.join("imgs", "bird_wing_down.png")).convert_alpha()
]
def __init__(self, x, y):
# Mise à l'échelle des images en fonction du SCALE_FACTOR
self.BIRD_IMGS = [
pygame.transform.scale(
img,
(
int(img.get_width() * self.SCALE_FACTOR),
int(img.get_height() * self.SCALE_FACTOR)
)
)
for img in self.ORIGINAL_BIRD_IMGS
]
self.x = x # Position x de l'oiseau
self.y = y # Position y de l'oiseau
self.tilt = 0 # Angle d'inclinaison actuel de l'oiseau
self.tick_count = 0 # Nombre de ticks depuis le dernier saut
self.vel = 0 # Vitesse actuelle de l'oiseau
self.height = self.y # Hauteur de départ de l'oiseau
self.img_count = 0 # Compteur pour l'animation des ailes
self.img = self.BIRD_IMGS[0] # Image actuelle de l'oiseau (par défaut, la première image)
def jump(self):
"""Fait sauter l'oiseau."""
self.vel = self.JUMP_STRENGTH # Attribue la force du saut
self.tick_count = 0 # Réinitialise le compteur de ticks après le saut
self.height = self.y # Met à jour la hauteur de départ après le saut
def move(self):
"""Calcule la nouvelle position de l'oiseau."""
self.tick_count += 1 # Incrémente le nombre de ticks
# Calcul du déplacement
displacement = self.vel * self.tick_count + 0.5 * self.GRAVITY * self.tick_count ** 2
# Limite de la vitesse terminale
if displacement >= 16:
displacement = 16 # L'oiseau ne peut pas descendre plus vite que cette valeur
# Ajustement du mouvement vers le haut
if displacement < 0:
displacement -= 2 # Augmente la montée légèrement
self.y += displacement # Applique le déplacement vertical
# Gère l'inclinaison de l'oiseau (rotation)
if displacement < 0 or self.y < self.height + 50:
# Si l'oiseau monte ou est encore proche de la hauteur initiale après un saut
if self.tilt < self.MAX_ROTATION:
self.tilt = self.MAX_ROTATION # Inclinaison maximale (vers le haut)
else:
# Si l'oiseau descend, il se penche vers le bas
if self.tilt > -90:
self.tilt -= self.ROT_VEL # L'oiseau se penche progressivement vers le bas
def draw(self, win):
"""Dessine l'oiseau avec animation."""
self.img_count += 1 # Incrémente le compteur d'animation
# Cycle à travers les images pour l'animation des battements d'ailes
if self.img_count < self.ANIMATION_TIME:
self.img = self.BIRD_IMGS[0] # Ailes vers le haut
elif self.img_count < self.ANIMATION_TIME * 2:
self.img = self.BIRD_IMGS[1] # Ailes vers le bas
elif self.img_count >= self.ANIMATION_TIME * 2:
self.img = self.BIRD_IMGS[0] # Réinitialisation de l'animation
self.img_count = 0
# Lorsque l'oiseau plonge, il affiche toujours l'image avec les ailes vers le bas
if self.tilt <= -80:
self.img = self.BIRD_IMGS[1]
self.img_count = self.ANIMATION_TIME * 2 # Force l'image des ailes vers le bas
# Dessine l'oiseau avec rotation centrée
blit_rotate_center(win, self.img, (self.x, self.y), self.tilt)
def get_mask(self):
"""Récupère le masque de l'image pour la détection de collisions."""
return pygame.mask.from_surface(self.img) # Crée un masque pour l'oiseau basé sur l'image actuelle
[NEAT]
fitness_criterion = max
fitness_threshold = 100
pop_size = 100
reset_on_extinction = False
[DefaultGenome]
# node activation options
activation_default = tanh
activation_mutate_rate = 0.0
activation_options = tanh
# node aggregation options
aggregation_default = sum
aggregation_mutate_rate = 0.0
aggregation_options = sum
# node bias options
bias_init_mean = 0.0
bias_init_stdev = 1.0
bias_max_value = 30.0
bias_min_value = -30.0
bias_mutate_power = 0.5
bias_mutate_rate = 0.7
bias_replace_rate = 0.1
# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient = 0.5
# connection add/remove rates
conn_add_prob = 0.5
conn_delete_prob = 0.5
# connection enable options
enabled_default = True
enabled_mutate_rate = 0.01
# feed-forward or recurrent
feed_forward = True
initial_connection = full
# node add/remove rates
node_add_prob = 0.2
node_delete_prob = 0.2
# network parameters
num_hidden = 0
num_inputs = 6
num_outputs = 1
# node response options
response_init_mean = 1.0
response_init_stdev = 0.0
response_max_value = 30.0
response_min_value = -30.0
response_mutate_power = 0.0
response_mutate_rate = 0.0
response_replace_rate = 0.0
# connection weight options
weight_init_mean = 0.0
weight_init_stdev = 1.0
weight_max_value = 30
weight_min_value = -30
weight_mutate_power = 0.5
weight_mutate_rate = 0.8
weight_replace_rate = 0.1
[DefaultSpeciesSet]
compatibility_threshold = 3.0
[DefaultStagnation]
species_fitness_func = max
max_stagnation = 20
species_elitism = 2
[DefaultReproduction]
elitism = 2
survival_threshold = 0.2
exemple/tnn-flappy-bird-exemple/imgs/background.png

24.2 KB

exemple/tnn-flappy-bird-exemple/imgs/bird.gif

3.52 KB

exemple/tnn-flappy-bird-exemple/imgs/bird_wing_down.png

1.38 KB

exemple/tnn-flappy-bird-exemple/imgs/bird_wing_up.png

1.29 KB

exemple/tnn-flappy-bird-exemple/imgs/ground.png

2.78 KB

exemple/tnn-flappy-bird-exemple/imgs/pipe.png

6.18 KB

exemple/tnn-flappy-bird-exemple/imgs/pipe_body.png

331 Bytes

exemple/tnn-flappy-bird-exemple/imgs/pipe_end.png

692 Bytes

# main.py
import pygame
import sys
import os
from settings import SCREEN_WIDTH, SCREEN_HEIGHT # Import from settings.py
pygame.init()
# Create the window
WIN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Flappy Bird")
# Now that Pygame is initialized, import other modules
from bird import Bird
from pipe import Pipe
from base import Base
# Load background image
BG_IMG = pygame.image.load(os.path.join("imgs", "background.png")).convert()
BG_IMG = pygame.transform.scale(BG_IMG, (SCREEN_WIDTH, SCREEN_HEIGHT))
# Fonts
FONT = pygame.font.SysFont("comicsans", 50)
def draw_window(win, bird, pipes, base, score):
"""Draw everything on the screen."""
win.blit(BG_IMG, (0, 0))
for pipe in pipes:
pipe.draw(win)
base.draw(win)
bird.draw(win)
# Display score
text = FONT.render(f"Score: {score}", True, (255, 255, 255))
win.blit(text, (SCREEN_WIDTH - 10 - text.get_width(), 10))
pygame.display.update()
def main():
"""Main game loop."""
bird = Bird(230, 350)
base = Base(730) # Position the base at y=730
pipes = [Pipe(600)]
score = 0
clock = pygame.time.Clock()
run = True
while run:
clock.tick(30) # 30 FPS
for event in pygame.event.get():
if event.type == pygame.QUIT:
run = False
pygame.quit()
sys.exit()
# Key press event
if event.type == pygame.KEYDOWN:
if event.key == pygame.K_SPACE:
bird.jump()
# Move bird and base
bird.move()
base.move()
# Handle pipes
add_pipe = False
rem = []
for pipe in pipes:
pipe.move()
# Check for collision
if pipe.collide(bird):
main() # Restart the game
# Remove off-screen pipes
if pipe.x + pipe.PIPE_BODY_IMG.get_width() < 0:
rem.append(pipe)
# Check if we need to add a new pipe
if not pipe.passed and pipe.x < bird.x:
pipe.passed = True
add_pipe = True
if add_pipe:
score += 1
pipes.append(Pipe(600))
for r in rem:
pipes.remove(r)
# Check if the bird hits the ground or flies too high
if bird.y + bird.img.get_height() >= base.y or bird.y < 0:
main() # Restart the game
# Draw everything
draw_window(WIN, bird, pipes, base, score)
if __name__ == "__main__":
main()
import pygame
import sys
import os
import neat
import random
from settings import SCREEN_WIDTH, SCREEN_HEIGHT
import visualize #a commenter si marche pas ou veux pas visionner réseau
# Initialise Pygame
pygame.init()
# Crée la fenêtre de jeu avec les dimensions spécifiées
WIN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Flappy Bird NEAT")
# Maintenant que Pygame est initialisé, on importe d'autres modules
from bird import Bird # Classe représentant l'oiseau
from pipe import Pipe # Classe représentant les tuyaux
from base import Base # Classe représentant la base (sol du jeu)
# Charge l'image de fond et la redimensionne à la taille de l'écran
BG_IMG = pygame.image.load(os.path.join("imgs", "background.png")).convert()
BG_IMG = pygame.transform.scale(BG_IMG, (SCREEN_WIDTH, SCREEN_HEIGHT))
# Définition des polices pour afficher du texte
FONT = pygame.font.SysFont("comicsans", 50)
def draw_window(win, birds, pipes, base, score):
"""Dessine tous les éléments à l'écran."""
# Affiche l'image de fond
win.blit(BG_IMG, (0, 0))
# Affiche les tuyaux
for pipe in pipes:
pipe.draw(win)
# Affiche la base (sol)
base.draw(win)
# Affiche chaque oiseau
for bird in birds:
bird.draw(win)
# Affiche le score à l'écran
text = FONT.render(f"Score: {score}", True, (255, 255, 255))
win.blit(text, (SCREEN_WIDTH - 10 - text.get_width(), 10))
# Met à jour l'affichage
pygame.display.update()
def eval_genomes(genomes, config):
# Liste pour stocker les réseaux neuronaux, les génomes et les oiseaux
nets = []
ge = []
birds = []
# Initialisation de chaque génome et création des réseaux neuronaux pour chaque oiseau
for genome_id, genome in genomes:
genome.fitness = 0 # Fitness de départ
net = neat.nn.FeedForwardNetwork.create(genome, config) # Création du réseau pour chaque génome
nets.append(net)
birds.append(Bird(230, 350)) # Position initiale de l'oiseau
ge.append(genome)
base = Base(730) # Positionnement de la base à y=730
pipes = [Pipe(600)] # Création d'un premier tuyau
score = 0
clock = pygame.time.Clock() # Création d'un horloge pour réguler la vitesse du jeu
run = True
while run and len(birds) > 0:
# Limite la vitesse à 30 images par seconde
clock.tick(30)
# Gère les événements, notamment la fermeture de la fenêtre
for event in pygame.event.get():
if event.type == pygame.QUIT:
run = False
pygame.quit()
sys.exit()
# Détermine quel tuyau est le prochain à passer pour l'oiseau
pipe_ind = 0
if len(pipes) > 1 and len(birds) > 0 and birds[0].x > pipes[0].x + pipes[0].PIPE_BODY_IMG.get_width():
pipe_ind = 1
# Boucle sur chaque oiseau pour les déplacer et ajuster leur fitness
for i, bird in enumerate(birds):
bird.move() # Déplace l'oiseau
ge[i].fitness += 0.1 # Récompense pour chaque instant où l'oiseau survit
# Calcule le milieu du trou du tuyau
middle_of_gap = pipes[pipe_ind].height + (pipes[pipe_ind].GAP / 2)
# Les entrées pour le réseau neuronal :
output = nets[i].activate(
(
pipes[pipe_ind].x, # x du tuyau
pipes[pipe_ind].height, # y du haut du tuyau
pipes[pipe_ind].bottom, # y du bas du tuyau
bird.y - middle_of_gap, # Différence entre l'oiseau et le milieu du trou (non utilisé ici)
bird.x - pipes[pipe_ind].x, # Différence en x entre l'oiseau et le tuyau
bird.y # Position y de l'oiseau
)
)
# Si la sortie du réseau est supérieure à 0.5, l'oiseau saute
if output[0] > 0.5:
bird.jump()
base.move() # Déplace la base
rem = [] # Liste des tuyaux à supprimer
add_pipe = False # Indique s'il faut ajouter un nouveau tuyau
for pipe in pipes:
pipe.move() # Déplace le tuyau
# Vérifie si un oiseau entre en collision avec un tuyau
for i, bird in enumerate(birds):
if pipe.collide(bird):
ge[i].fitness -= 1 # Pénalité en cas de collision
birds.pop(i)
nets.pop(i)
ge.pop(i)
# Si l'oiseau dépasse un tuyau, marque le tuyau comme "passé"
if len(birds) > 0 and not pipe.passed and pipe.x < birds[0].x:
pipe.passed = True
add_pipe = True
# Si le tuyau est sorti de l'écran, on le marque pour suppression
if pipe.x + pipe.PIPE_BODY_IMG.get_width() < 0:
rem.append(pipe)
# Ajoute un nouveau tuyau si nécessaire
if add_pipe:
score += 1
for genome in ge:
genome.fitness += 5 # Récompense pour avoir passé un tuyau
pipes.append(Pipe(600))
# Supprime les tuyaux qui sont sortis de l'écran
for r in rem:
pipes.remove(r)
# Vérifie si un oiseau touche le sol ou dépasse le haut de l'écran
for i, bird in enumerate(birds):
if bird.y + bird.img.get_height() >= base.y or bird.y < 0:
ge[i].fitness -= 1 # Pénalité pour avoir touché le sol ou dépassé
birds.pop(i)
nets.pop(i)
ge.pop(i)
# Dessine l'écran seulement si des oiseaux sont encore en vie
if len(birds) > 0:
draw_window(WIN, birds, pipes, base, score)
# Arrête la simulation si le score dépasse une certaine limite
if score > 20:
break
def run(config_file):
# Charge la configuration de NEAT
config = neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_file
)
# Crée la population basée sur la configuration NEAT
p = neat.Population(config)
# Ajoute des reporters pour suivre l'évolution de l'entraînement
p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)
# Optionnel : ajoute un reporter pour sauvegarder des checkpoints
p.add_reporter(neat.Checkpointer(5)) # Sauvegarde toutes les 5 générations
# Exécute NEAT pour un maximum de 50 générations
winner = p.run(eval_genomes, 20) # Limite à 20 générations
# Visualisation des résultats. NE MARCHE PAS TRES BIEN DONC A COMMENTER !!
node_names = {-1: "X pipe", -2: "Y top pipe", -3: "Y bottom pipe", -4: "diff x bird and x pipes", -5: "Y bird", 0: "Jump"}
visualize.draw_net(config, winner, True, node_names=node_names)
visualize.draw_net(config, winner, True, node_names=node_names, prune_unused=True)
visualize.plot_stats(stats, ylog=False, view=True)
visualize.plot_species(stats, view=True)
# Affiche le meilleur génome après l'entraînement
print('\nBest genome:\n{!s}'.format(winner))
if __name__ == "__main__":
# Chemin vers le fichier de configuration NEAT
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, "config-feedforward.txt")
run(config_path)
import pygame
import os
import random
from settings import SCREEN_HEIGHT # Importation de la hauteur de l'écran depuis settings.py
class Pipe:
"""Représente un obstacle de type tuyau."""
GAP = 200 # Espace entre les tuyaux du haut et du bas
VEL = 5 # Vitesse à laquelle les tuyaux se déplacent vers la gauche
# Chargement des images des tuyaux
PIPE_BODY_IMG = pygame.image.load(os.path.join("imgs", "pipe_body.png")).convert_alpha() # Image du corps du tuyau
PIPE_END_IMG = pygame.image.load(os.path.join("imgs", "pipe_end.png")).convert_alpha() # Image de l'extrémité du tuyau
def __init__(self, x):
# Initialise les coordonnées et la hauteur du tuyau
self.x = x # Position horizontale du tuyau
self.height = 0 # Hauteur du tuyau (initialisée plus tard)
# Positions des tuyaux du haut et du bas
self.top = 0 # Position du tuyau supérieur (sera initialisée plus tard)
self.bottom = 0 # Position du tuyau inférieur (sera initialisée plus tard)
# Inversion des images pour le tuyau supérieur (car il est à l'envers)
self.PIPE_BODY_TOP = pygame.transform.flip(self.PIPE_BODY_IMG, False, True) # Inverse l'image du corps
self.PIPE_END_TOP = pygame.transform.flip(self.PIPE_END_IMG, False, True) # Inverse l'image de l'extrémité
self.passed = False # Indique si le tuyau a été dépassé par l'oiseau
self.set_height() # Détermine aléatoirement la hauteur du tuyau
def set_height(self):
"""Définit la hauteur du tuyau de manière aléatoire et crée les surfaces des tuyaux."""
# Détermine aléatoirement la hauteur du tuyau supérieur (du haut de l'écran jusqu'au début du GAP)
self.height = random.randrange(50, SCREEN_HEIGHT - self.GAP - 50)
self.bottom = self.height + self.GAP # Position du tuyau inférieur
# Création de la surface pour le tuyau supérieur
top_pipe_height = self.height # Hauteur du tuyau supérieur
self.PIPE_TOP_SURFACE = pygame.Surface((self.PIPE_BODY_TOP.get_width(), top_pipe_height), pygame.SRCALPHA)
# Dessin du tuyau supérieur en partant de l'extrémité
y = top_pipe_height - self.PIPE_END_TOP.get_height()
# Dessin de l'extrémité du tuyau supérieur
self.PIPE_TOP_SURFACE.blit(self.PIPE_END_TOP, (0, y))
y -= self.PIPE_BODY_TOP.get_height()
# Dessin du corps du tuyau supérieur
while y > -self.PIPE_BODY_TOP.get_height():
self.PIPE_TOP_SURFACE.blit(self.PIPE_BODY_TOP, (0, y))
y -= self.PIPE_BODY_TOP.get_height()
# Le tuyau supérieur commence à y=0 (du haut de l'écran)
self.top = 0
# Création de la surface pour le tuyau inférieur
bottom_pipe_height = SCREEN_HEIGHT - self.bottom # Hauteur du tuyau inférieur
self.PIPE_BOTTOM_SURFACE = pygame.Surface((self.PIPE_BODY_IMG.get_width(), bottom_pipe_height), pygame.SRCALPHA)
y = 0
# Dessin de l'extrémité du tuyau inférieur
self.PIPE_BOTTOM_SURFACE.blit(self.PIPE_END_IMG, (0, y))
y += self.PIPE_END_IMG.get_height()
# Dessin du corps du tuyau inférieur
while y < bottom_pipe_height:
self.PIPE_BOTTOM_SURFACE.blit(self.PIPE_BODY_IMG, (0, y))
y += self.PIPE_BODY_IMG.get_height()
def move(self):
"""Déplace le tuyau vers la gauche."""
self.x -= self.VEL # Diminue la position en x pour simuler le mouvement
def draw(self, win):
"""Dessine les tuyaux supérieurs et inférieurs."""
# Dessine le tuyau supérieur à la position calculée
win.blit(self.PIPE_TOP_SURFACE, (self.x, self.top))
# Dessine le tuyau inférieur à sa position calculée
win.blit(self.PIPE_BOTTOM_SURFACE, (self.x, self.bottom))
def collide(self, bird):
"""Vérifie si l'oiseau entre en collision avec un tuyau."""
bird_mask = bird.get_mask() # Récupère le masque de l'oiseau (pour la détection de collision)
# Crée des masques pour les tuyaux supérieurs et inférieurs
top_mask = pygame.mask.from_surface(self.PIPE_TOP_SURFACE)
bottom_mask = pygame.mask.from_surface(self.PIPE_BOTTOM_SURFACE)
# Calcule les décalages entre l'oiseau et les tuyaux
top_offset = (int(self.x - bird.x), int(self.top - bird.y)) # Décalage pour le tuyau supérieur
bottom_offset = (int(self.x - bird.x), int(self.bottom - bird.y)) # Décalage pour le tuyau inférieur
# Vérifie s'il y a des points de collision avec les tuyaux
t_point = bird_mask.overlap(top_mask, top_offset) # Collision avec le tuyau supérieur
b_point = bird_mask.overlap(bottom_mask, bottom_offset) # Collision avec le tuyau inférieur
# Si une collision est détectée, retourne True
if t_point or b_point:
return True
return False # Pas de collision détectée
# settings.py
SCREEN_WIDTH = 500
SCREEN_HEIGHT = 800
# utils.py
import pygame
def blit_rotate_center(surf, image, topleft, angle):
"""Rotate an image and blit it to the surface at the center."""
rotated_image = pygame.transform.rotate(image, angle)
new_rect = rotated_image.get_rect(center=image.get_rect(topleft=topleft).center)
surf.blit(rotated_image, new_rect.topleft)
import warnings
import graphviz
import matplotlib.pyplot as plt
import numpy as np
def plot_stats(statistics, ylog=False, view=False, filename='avg_fitness.svg'):
""" Plots the population's average and best fitness. """
if plt is None:
warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
return
generation = range(len(statistics.most_fit_genomes))
best_fitness = [c.fitness for c in statistics.most_fit_genomes]
avg_fitness = np.array(statistics.get_fitness_mean())
stdev_fitness = np.array(statistics.get_fitness_stdev())
plt.plot(generation, avg_fitness, 'b-', label="average")
plt.plot(generation, avg_fitness - stdev_fitness, 'g-.', label="-1 sd")
plt.plot(generation, avg_fitness + stdev_fitness, 'g-.', label="+1 sd")
plt.plot(generation, best_fitness, 'r-', label="best")
plt.title("Population's average and best fitness")
plt.xlabel("Generations")
plt.ylabel("Fitness")
plt.grid()
plt.legend(loc="best")
if ylog:
plt.gca().set_yscale('symlog')
plt.savefig(filename)
if view:
plt.show()
plt.close()
def plot_spikes(spikes, view=False, filename=None, title=None):
""" Plots the trains for a single spiking neuron. """
t_values = [t for t, I, v, u, f in spikes]
v_values = [v for t, I, v, u, f in spikes]
u_values = [u for t, I, v, u, f in spikes]
I_values = [I for t, I, v, u, f in spikes]
f_values = [f for t, I, v, u, f in spikes]
fig = plt.figure()
plt.subplot(4, 1, 1)
plt.ylabel("Potential (mv)")
plt.xlabel("Time (in ms)")
plt.grid()
plt.plot(t_values, v_values, "g-")
if title is None:
plt.title("Izhikevich's spiking neuron model")
else:
plt.title("Izhikevich's spiking neuron model ({0!s})".format(title))
plt.subplot(4, 1, 2)
plt.ylabel("Fired")
plt.xlabel("Time (in ms)")
plt.grid()
plt.plot(t_values, f_values, "r-")
plt.subplot(4, 1, 3)
plt.ylabel("Recovery (u)")
plt.xlabel("Time (in ms)")
plt.grid()
plt.plot(t_values, u_values, "r-")
plt.subplot(4, 1, 4)
plt.ylabel("Current (I)")
plt.xlabel("Time (in ms)")
plt.grid()
plt.plot(t_values, I_values, "r-o")
if filename is not None:
plt.savefig(filename)
if view:
plt.show()
plt.close()
fig = None
return fig
def plot_species(statistics, view=False, filename='speciation.svg'):
""" Visualizes speciation throughout evolution. """
if plt is None:
warnings.warn("This display is not available due to a missing optional dependency (matplotlib)")
return
species_sizes = statistics.get_species_sizes()
num_generations = len(species_sizes)
curves = np.array(species_sizes).T
fig, ax = plt.subplots()
ax.stackplot(range(num_generations), *curves)
plt.title("Speciation")
plt.ylabel("Size per Species")
plt.xlabel("Generations")
plt.savefig(filename)
if view:
plt.show()
plt.close()
def draw_net(config, genome, view=False, filename=None, node_names=None, show_disabled=True, prune_unused=False,
node_colors=None, fmt='svg'):
""" Receives a genome and draws a neural network with arbitrary topology. """
# Attributes for network nodes.
if graphviz is None:
warnings.warn("This display is not available due to a missing optional dependency (graphviz)")
return
# If requested, use a copy of the genome which omits all components that won't affect the output.
if prune_unused:
genome = genome.get_pruned_copy(config.genome_config)
if node_names is None:
node_names = {}
assert type(node_names) is dict
if node_colors is None:
node_colors = {}
assert type(node_colors) is dict
node_attrs = {
'shape': 'circle',
'fontsize': '9',
'height': '0.2',
'width': '0.2'}
dot = graphviz.Digraph(format=fmt, node_attr=node_attrs)
inputs = set()
for k in config.genome_config.input_keys:
inputs.add(k)
name = node_names.get(k, str(k))
input_attrs = {'style': 'filled', 'shape': 'box', 'fillcolor': node_colors.get(k, 'lightgray')}
dot.node(name, _attributes=input_attrs)
outputs = set()
for k in config.genome_config.output_keys:
outputs.add(k)
name = node_names.get(k, str(k))
node_attrs = {'style': 'filled', 'fillcolor': node_colors.get(k, 'lightblue')}
dot.node(name, _attributes=node_attrs)
used_nodes = set(genome.nodes.keys())
for n in used_nodes:
if n in inputs or n in outputs:
continue
attrs = {'style': 'filled',
'fillcolor': node_colors.get(n, 'white')}
dot.node(str(n), _attributes=attrs)
for cg in genome.connections.values():
if cg.enabled or show_disabled:
# if cg.input not in used_nodes or cg.output not in used_nodes:
# continue
input, output = cg.key
a = node_names.get(input, str(input))
b = node_names.get(output, str(output))
style = 'solid' if cg.enabled else 'dotted'
color = 'green' if cg.weight > 0 else 'red'
width = str(0.1 + abs(cg.weight / 5.0))
dot.edge(a, b, _attributes={'style': style, 'color': color, 'penwidth': width})
dot.render(filename, view=view)
return dot
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment