Fehler im Code #2
Verfasst: Samstag 30. Juli 2022, 16:11
Ich habe wieder einen Fehler im Code den ich einfach nicht finde.
Wenn ich den oberen code vereinfacht auf jupyter notebook laufen lasse, dann ist alles okay. Alles funktioniert. (siehe folgender code)
Wenn ich aber mit unittest kontrollieren will ob mein oberer Code (erster Code abschnitt ganz oben) passt, dann wird mir ständig folgende Fehlermeldung angezeigt: train = self.dataset["train"].map(map_func =self._load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
KeyError: 'train'
Da mein vereinfachter Code bei jupyter notebook läuft, muss doch das self.dataset attribut den key "train" haben
Der code von meiner Testklasse ist folgend angegeben
Code: Alles auswählen
class Unet_Model(BaseModel):
def __init__(self,config):
super().__init__(config)
self.dataset = None
self.info = None
self.train_dataset = None
self.test_dataset = None
def load_data(self):
"""Loads and Preprocess data"""
self.dataset, self.info = DataLoader().load_data(self.config)
self._preprocess_data()
def _preprocess_data(self):
"""Splits into training and test """
train = self.dataset["train"].map(map_func =self._load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
test = self.dataset["test"].map(map_func = self._load_image_test)
self.train_dataset = train.cache().shuffle(self.buffer_size).batch(self.batch_size).prefetch(self.buffer_size)
self.test_dataset = test.shuffle(self.buffer_size)
Code: Alles auswählen
import tensorflow_datasets as tfds
class Unet_Model:
def __init__(self):
self.dataset = None
self.info = None
self.train_dataset = []
self.test_dataset = []
def load_data(self):
self.dataset, self.info = DataLoader.load_data()
self._preprocess_data()
def _preprocess_data(self):
self.train_dataset = self.dataset["train"]
self.test_dataset = self.dataset["test"]
KeyError: 'train'
Da mein vereinfachter Code bei jupyter notebook läuft, muss doch das self.dataset attribut den key "train" haben
Der code von meiner Testklasse ist folgend angegeben
Code: Alles auswählen
def dummy_load_data(*args, **kwargs):
with tfds.testing.mock_data(num_examples=1): # Mock tfds to generate random data
return tfds.load(Konfigurationen.data.path,with_info=Konfigurationen.data.load_with_info)
class UnetTest(tf.test.TestCase):
def setUp(self):
super().setUp() # super funktion muss bei tf.test.TestCase aufgerufen werden
self.unet = Unet_Model(CFG)
@patch(target="model.Unet.DataLoader.load_data")
def test_load_data(self, mock_obj):
mock_obj.return_value = dummy_load_data()
expected_shape = tf.TensorShape([None, self.unet.config.data.image_size, self.unet.config.data.image_size,3])
self.unet.load_data()
mock_obj.assert_called()
self.assertEqual(self.unet.train_dataset.element_spec[0].shape, expected_shape)
self.assertEqual(self.unet.train_dataset.element_spec[0].shape, expected_shape)