Wenn ich nur DataLoader.load_data() über Mein_model.load_data ausführe, dann funktioniert alles. Also die daten werden ohne fehlermeldung geladen. Der Fehler wird angezeigt wenn ich versuche tf.unittest meinen Code zu testen. Daher denke ich, dass mein unittest fehlerhaft ist (siehe 2ter Codeabschnitt)
Vereinfachter Code:
Code: Alles auswählen
# Funktion des Moduls Unet/ Klasse Mein_model
class Mein_model:
def load_data(self):
# HIER soll laut unittest der Fehler liegen
self.dataset, self.info = DataLoader.load_data() # DataLoader.load_data() ist eine funktion eines anderen Moduls
# anderes Modul
class DataLoader:
@staticmethod
def load_data():
return tfds.load(name="mnist", with_info = True)
Code: Alles auswählen
def dummy_load_data(*args, **kwargs):
with tfds.testing.mock_data(num_examples=1): # Mock tfds to generate random data
return tfds.load("mnist",with_info=True)
class UnetTest(tf.test.TestCase):
def setUp(self):
super(UnetTest,self).setUp()
@patch(target="model.Unet.DataLoader.load_data")
def test_load_data(self, mock_obj):
mock_obj.side_effect = dummy_load_data()
expected_shape = tf.TensorShape([None, 120, 120,3]) # Represents the shape of a Tensor.
Mein_model.load_data() # HIER wird der Fehlermeldung angezeigt