Code: Alles auswählen
import numpy
import math
import matplotlib.pyplot as plt
class SOM:
def __init__(self, size):
assert len(size) == 2
self.size = size
def train(self, data, weights, dimensions, iterations, epsilon):
self.nodes = numpy.random.rand(self.size[0], self.size[1], len(data[0]))
self.weights = weights
self.epsilon = epsilon
# plt.subplot(self.size[0], self.size[1], 1)
# plt.imshow(self.nodes.reshape(self.size[0], self.size[1], dimensions[0], dimensions[1])[3][3])
# plt.ion()
# plt.draw()
for i in range(iterations):
# if i % 1 == 0: # plot
# for x in range(self.size[0]):
# for y in range(self.size[1]):
# plt.title(str(i))
# plt.subplot(self.size[0], self.size[1], x + y * self.size[0] + 1)
# plot = plt.imshow(self.nodes.reshape(self.size[0], self.size[1], dimensions[0], dimensions[1])[x][y], interpolation="nearest")
# plot.axes.get_xaxis().set_visible(False)
# plot.axes.get_yaxis().set_visible(False)
# plt.draw()
# plt.show()
for currentData in data: #training
self.trainNode(currentData)
# plt.show()
input('Press Enter to exit')
def trainNode(self, data):
coordinate = self.findBestMatchingNode(data)
self.doTraining(data, coordinate)
def findBestMatchingNode(self, data):
bestDistance = float("nan")
result = [0, 0]
for x in range(self.size[0]):
for y in range(self.size[1]):
distance = self.calculateDistance(self.nodes[x][y], data)
if math.isnan(bestDistance) or distance < bestDistance:
bestDistance = distance
result = [x, y]
return result
def calculateDistance(self, a, b):
return numpy.sum(pow(a - b, 2) / (a + b))
def doTraining(self, data, coordinate):
for x in range(self.size[0]):
for y in range(self.size[1]):
f = 1 / (1 + pow(coordinate[0] - x, 2) + pow(coordinate[1] - y, 2))
self.nodes[x][y] = self.nodes[x][y] + self.weights[x][y] * self.epsilon * f * (data - self.nodes[x][y])
Ich bin gerade dabei sie zu analysieren und zu verstehen, aber komme leider an "data" nicht vorbei..
Da in Python die Variablen nicht explizit mit Datentypen deklariert werden, weiss ich nicht, was data (bzw. die Unterelemente der Unterelemente von "data", da "data" ein Containertyp zu sein scheint, die wiederum einen Containertyp enthält) für ein Datentyp sein könnte..
Ein paar Eckdaten:
Die SOM wird für Bilderkennung genutzt. Also gehe ich davon aus, dass "data" sowas wie die Bildpunktdaten enthalten könnte..
Das Auskommentierte ist eher unwichtig, da es zum Plotten verwendet wird.
Und es gibt leider keine Dokumentation oder Kommentare zum Quellcode.
Was mich noch ein wenig irritiert:
Die Elemente von "data" werden unverändert über verschiedene Methode bis zur Methode
Code: Alles auswählen
def calculateDistance(self, a, b):
return numpy.sum(pow(a - b, 2) / (a + b))
Wie ist das aber möglich, wenn doch die Elemente von "data" Containertypen sind?