Probleme mit scipy.integrate.odeint

Wenn du dir nicht sicher bist, in welchem der anderen Foren du die Frage stellen sollst, dann bist du hier im Forum für allgemeine Fragen sicher richtig.
Antworten
MdI
User
Beiträge: 2
Registriert: Donnerstag 5. Juni 2014, 11:59

Hallo liebes Forum,

ich habe ein paar Probleme mit der Benutzung von scipy.integrate.odeint.

Grundsätzlich geht es um ein Dreikörperproblem. Die drei Körper ziehen sich gravitativ an. Das entsprechende System gekoppelter ODEs soll dann mit scipy.integrate.odeint gelöst und die Bahnen geplottet werden:

Code: Alles auswählen


import scipy as scp
from scipy.integrate import odeint

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
"""
dr1/dt = r1'
dr2/dt = r2'
dr3/dt = r3'

dr1'/dt = - G*m2*(r1-r2)/|r1-r2|^3 - G*m3*(r1-r3)/|r1-r3|^3
dr2'/dt = - G*m3*(r2-r3)/|r2-r3|^3 - G*m1*(r2-r1)/|r2-r1|^3
dr3'/dt = - G*m1*(r3-r1)/|r3-r1|^3 - G*m2*(r3-r2)/|r3-r2|^3

Definiere:

z[0] = r1
z[1] = r1'
z[2] = r2
z[3] = r2'
z[4] = r3
z[5] = r3'

ODE-System:

dz[0]/dt = z[1]
dz[2]/dt = z[3]
dz[4]/dt = z[5]

dz[1]/dt = - G*m2*(z[0]-z[2])/|z[0]-z[2]|^3 - G*m3*(z[0]-z[4])/|z[0]-z[4]|^3
dz[3]/dt = - G*m2*(z[2]-z[4])/|z[2]-z[4]|^3 - G*m3*(z[2]-z[0])/|z[2]-z[0]|^3
dz[5]/dt = - G*m2*(z[4]-z[0])/|z[4]-z[0]|^3 - G*m3*(z[4]-z[0])/|z[4]-z[0]|^3

"""

m1 = 1
m2 = 1
m3 = 1

def deriv(z, t):
	G = 6.573*10**(-11)
	dz = scp.array([z[1], z[3], z[5], \
					 - G*m2*(z[0]-z[2])/scp.fabs(z[0]-z[2])**3 - G*m3*(z[0]-z[4])/scp.fabs(z[0]-z[4])**3,\
					 - G*m2*(z[2]-z[4])/scp.fabs(z[2]-z[4])**3 - G*m3*(z[2]-z[0])/scp.fabs(z[2]-z[0])**3,\
					 - G*m2*(z[4]-z[0])/scp.fabs(z[4]-z[0])**3 - G*m3*(z[4]-z[0])/scp.fabs(z[4]-z[0])**3])
	return dz

time = scp.linspace(0.0, 10.0, 100)

# init values
zinit = scp.array([	0.0, -1.5, 3.0, -1.12, 10.32, 0.1,\
					-2.0, 3.5, 1.0, 0.65, -1.112, 1.43,\
					1.5, 0.0, -0.5, 0.23, 0.231, -1.545	])

r = odeint(deriv, zinit, time)


x1 = r[:][0]
y1 = r[:][1]
z1 = r[:][2]
x2 = r[:][6]
y2 = r[:][7]
z2 = r[:][8]
x3 = r[:][12]
y3 = r[:][13]
z3 = r[:][14]


fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

plt.plot(x1, y1, z1, 'blue', x2, y2, z3, 'red', x3, y3, z3, 'yellow')
plt.show()
Ich versuch jetzt schon ne ganze Weile daran herum, aber anscheinend übergebe ich irgendetwas nicht in der richtigen shape oder Dimension.

Hier der Plot:
http://www.directupload.net/file/d/3644 ... bx_png.htm

Falls mir da jemand einen Tipp geben könnte wäre ich sehr dankbar :)
Benutzeravatar
MagBen
User
Beiträge: 799
Registriert: Freitag 6. Juni 2014, 05:56
Wohnort: Bremen
Kontaktdaten:

Das hier sind 6 Vektor-Gleichungen

dr1/dt = r1'
dr2/dt = r2'
dr3/dt = r3'

dr1'/dt = - G*m2*(r1-r2)/|r1-r2|^3 - G*m3*(r1-r3)/|r1-r3|^3
dr2'/dt = - G*m3*(r2-r3)/|r2-r3|^3 - G*m1*(r2-r1)/|r2-r1|^3
dr3'/dt = - G*m1*(r3-r1)/|r3-r1|^3 - G*m2*(r3-r2)/|r3-r2|^3

In Python musst Du daraus 6x3 skalare Gleichungen machen und an odeint übergeben.

Du machst aber nur 6 skalare Gleichung daraus. Von Deinen 18 zinit Werten benutzt odeint nur die ersten 6. Deshalb enthalten nur r[:,0] bis r[:,5] sinnvolle Werte, ab r[:,6] stehen undefinierte Werte drin.

Bevor Du 3D Plots machst, kannst Du mit

Code: Alles auswählen

plt.figure()
for i in range(r.shape[1]):
    plt.plot(time, r[:,i], label=str(i))
plt.legend()
einfach überprüfen ob Deine Lösung plausibel ist.
a fool with a tool is still a fool, www.magben.de, YouTube
MdI
User
Beiträge: 2
Registriert: Donnerstag 5. Juni 2014, 11:59

Vielen Dank für die Antwort. Ich habs aber auch geschafft, das Gleichungssystem mit Vektoren aufzuschreiben, was natürlich insbesondere bei einer Anzahl von Dimensionen, die 3 weit übersteigt übersichtlicher ist. Am Anfang der deriv() müsste dann der init_state erst in Vektoren gepackt und nach der Berechnung wieder in eine 1-d-List entpackt werden. Hier ist das jetzt durchdefiniert, kann man natürlich auch automatisieren.

Jedenfalls waren da noch einige Fehler drin, die jetzt raus sind. Der Vollständigkeit halber hier der funktionierende Code für ein Sonne-Erde-Mond System.

Code: Alles auswählen


import scipy as scp
from scipy.integrate import odeint
import scipy.linalg as lin

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
"""
dr1/dt = r1'
dr2/dt = r2'
dr3/dt = r3'

dr1'/dt = - G*m2*(r1-r2)/|r1-r2|^3 - G*m3*(r1-r3)/|r1-r3|^3
dr2'/dt = - G*m3*(r2-r3)/|r2-r3|^3 - G*m1*(r2-r1)/|r2-r1|^3
dr3'/dt = - G*m1*(r3-r1)/|r3-r1|^3 - G*m2*(r3-r2)/|r3-r2|^3

Definiere:

z[0] = r1
z[1] = r1'
z[2] = r2
z[3] = r2'
z[4] = r3
z[5] = r3'

ODE-System:

dz[0]/dt = z[1]
dz[2]/dt = z[3]
dz[4]/dt = z[5]

dz[1]/dt = - G*m2*(z[0]-z[2])/|z[0]-z[2]|^3 - G*m3*(z[0]-z[4])/|z[0]-z[4]|^3
dz[3]/dt = - G*m2*(z[2]-z[4])/|z[2]-z[4]|^3 - G*m3*(z[2]-z[0])/|z[2]-z[0]|^3
dz[5]/dt = - G*m2*(z[4]-z[0])/|z[4]-z[0]|^3 - G*m3*(z[4]-z[0])/|z[4]-z[0]|^3

"""

#m1 = 1000 # Sonne
#m2 = 1000 # Erde
#m3 = 1000 # Mond


mS = 1.989*10**30 # Sonne
mE = 5.972*10**24 # Erde
mM = 7.349*10**22 # Mond


x0S = 0.0
y0S = 0.0
z0S = 0.0

x0E = 149.6 * 10**9 # große Halbachse in m
y0E = 0.0
z0E = 0.0

x0M = x0E + 384400000 # große Halbachse in m
y0M = 0.0
z0M = 0.0

vx0S = 0.0
vy0S = 0.0
vz0S = 0.0

vx0E = 0.0
vy0E = 29.29 * 10**3
vz0E = 0.0

vx0M = 0.0
vy0M = vy0E + 0.964 * 10**3 # Minimalgeschwindigkeit in m/s
vz0M = 0.0


zinit = [	x0S, y0S, z0S, vx0S, vy0S, vz0S,\
			x0E, y0E, z0E, vx0E, vy0E, vz0E,\
			x0M, y0M, z0M, vx0M, vy0M, vz0M	]

def deriv(state, time):
	G = 6.673*10**(-11)

	rS = scp.array(state[0:3])
	vS = scp.array(state[3:6])
	rE = scp.array(state[6:9])
	vE = scp.array(state[9:12])
	rM = scp.array(state[12:15])
	vM = scp.array(state[15:18])

	#drS/dt = rS' = vS
	drS = vS
	#drS'/dt = d²rS/dt² = - G*mE*(rS-rE)/lin.norm(rS-rE)**3 - G*mM*(rS-rM)/lin.norm(rS-rM)**3
	dvrS = - G*mE*(rS-rE)/lin.norm(rS-rE)**3 - G*mM*(rS-rM)/lin.norm(rS-rM)**3

	drE = vE
	dvrE = - G*mM*(rE-rM)/lin.norm(rE-rM)**3 - G*mS*(rE-rS)/lin.norm(rE-rS)**3

	drM = vM
	dvrM = - G*mS*(rM-rS)/lin.norm(rM-rS)**3 - G*mE*(rM-rE)/lin.norm(rM-rE)**3
	
	deriv = [	drS[0], drS[1], drS[2], dvrS[0], dvrS[1], dvrS[2],\
				drE[0], drE[1], drE[2], dvrE[0], dvrE[1], dvrE[2],\
				drM[0], drM[1], drM[2], dvrM[0], dvrM[1], dvrM[2]]

	return deriv

time = scp.linspace(0.0, 60*60*24*180, 380)


r = odeint(deriv, zinit, time)


x1 = []
x2 = []
x3 = []
y1 = []
y2 = []
y3 = []
z1 = []
z2 = []
z3 = []

for t in r:
	x1.append(t[0])
	y1.append(t[1])
	z1.append(t[2])
	x2.append(t[6])
	y2.append(t[7])
	z2.append(t[8])
	x3.append(t[12])
	y3.append(t[13])
	z3.append(t[14])

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
lim = 169948400.0*10**3
ax.set_xlim3d(-lim, lim)
ax.set_ylim3d(-lim, lim)
ax.set_zlim3d(-lim, lim)

ax.plot(x1,y1,z1, 'yo')
ax.plot(x2,y2,z2, 'bo')
ax.plot(x3,y3,z3, 'r.')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
Aufgrund der Größenverhältnisse am Plot etwas schwer zu erkennen, aber beim reinzoomen sieht man, dass es funktioniert.
EyDu
User
Beiträge: 4881
Registriert: Donnerstag 20. Juli 2006, 23:06
Wohnort: Berlin

Du solltest deinen Code noch ein wenig überarbeiten, ich list die wichtigsten Hinweise einfach mal der Reihe nach auf: Als erstes solltest du die ganzen durchnummerierten Namen loswerden. Wenn du beginnst Nummern zu vergeben, dann ist das ein eindeutiger Hinweis auf die Verwendung einer anderen Datenstruktur. In deinem Fall bietet sich Tupel, Listen oder NumPy-Arrays, bzw. Matrizen, an. Dann wird dein Code gleich kompakter, einfacher und deutlich übersichtlicher.

Als nächstes solltest du den ganzen doppelten Code loswerden. Zeilen 90, 93 und 96 sind quasi identisch, das solltest du in eine Funktion auslagern. Dann musst du das ganze auch nur einmal schreiben und hast nur eine Fehlerquelle, bzw. eine Stelle falls du etwas ändern musst. Auch Zeilen 120 bis 129 kannst du prima zusammenfassen, jede Zeile unterscheidet sich an nur genau zwei Stellen. Mit der richtigen Datenstruktur wäre das aber alles nicht notwendig, dann wäre der ganze Block ein Dreizeiler.

Die Zeilen 88, 92 und 95 sind überflüssig und statt in Zeile 98 das Ergebnis an einen Namen zu binden, kannst du das auch gleich mittels return-Anweisung zurückgeben. Bei der Gelegenheit solltest du dir auch gleich über Namensgebung Gedanken machen, die sind bei dir alle recht nichtssagend. Dir mag das jetzt noch sinnvoll erscheinen, aber in einem Monat weißt du selbst nicht mehr, was du da eigentlich gemacht hast. Behalte das immer im Hinterkopf.

Auch solltest du keinen Code, abgesehen von Funktionsdefinitionen, Klassen, Importen und Konstanten, auf modulebene haben. Hast du das doch, so handelst du dir einen unschönen Nebeneffekt ein: Du kannst deinen Code nicht vernünftig importieren und wiederverwenden, da er bei jedem Import ausgeführt wird. Das möchte man häufig natürlich nicht. Packe daher den Code in eine main-Funktion und rufe die über das folgende Idiom auf:

Code: Alles auswählen

if __name__ == "__main__":
    main()
Damit ist sichergestellt, dass die main-Funktion nur dann ausgeführt wird, wenn du das Modul direkt startest.

Am besten wirfst du noch einen genauen Blick auf NumPy, du nutzt es hier gar nicht richtig aus. Du machst viel zu viel von Hand, machst selbst Indexoperationen und baust quasi Matrizen zusammen und zerlegst sie wieder. Wenn du NumPy richtig verwendest, dann kannst du den ganzen Code ordentlich zusammendampfen. Dann sieht er den Originalformeln auch deutlich ähnlicher und nicht mehr so zerstreut.
Das Leben ist wie ein Tennisball.
Benutzeravatar
MagBen
User
Beiträge: 799
Registriert: Freitag 6. Juni 2014, 05:56
Wohnort: Bremen
Kontaktdaten:

Sehr schön, dass Du es geschafft hast.

Ein paar Tips, damit der Code mehr nach Python und Numerik aussieht:

Lass das 10**n weg, schreib die Zeilen 43 bis 45 lieber so

Code: Alles auswählen

mS = 1.989e30 # Sonne
mE = 5.972e24 # Erde
mM = 7.349e22 # Mond
Die Zeilen 110 bis 129 sehen so mehr nach Numpy aus:

Code: Alles auswählen

x1 = r[:,0]
y1 = r[:,1]
z1 = r[:,2]
x2 = r[:,6]
y2 = r[:,7]
z2 = r[:,8]
x3 = r[:,12]
y3 = r[:,13]
z3 = r[:,14]
Oder noch kürzer so (aber das ist Geschmackssache):

Code: Alles auswählen

r1 = r[:,0:3].T
r2 = r[:,6:9].T
r3 = r[:,12:15].T
und die Zeilen 138 bis 140 sehen damit so aus:

Code: Alles auswählen

ax.plot(r1[0], r1[1], r1[2], 'yo')
ax.plot(r2[0], r2[1], r2[2], 'bo')
ax.plot(r3[0], r3[1], r3[2], 'r.')
a fool with a tool is still a fool, www.magben.de, YouTube
Antworten