Code: Alles auswählen
class Unet_Model(BaseModel):
    """U-Net model wrapper that loads its dataset and builds tf.data pipelines.

    Args:
        config: configuration object passed to ``BaseModel`` and to
            ``DataLoader().load_data``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.dataset = None  # split name -> dataset mapping, set by load_data()
        self.info = None  # second value returned by DataLoader().load_data
        self.train_dataset = None  # training pipeline, set by _preprocess_data()
        self.test_dataset = None  # evaluation pipeline, set by _preprocess_data()

    def load_data(self):
        """Load the raw dataset via DataLoader, then build the input pipelines."""
        self.dataset, self.info = DataLoader().load_data(self.config)
        self._preprocess_data()

    def _preprocess_data(self):
        """Build the training and test pipelines from ``self.dataset``.

        The training split is mapped, cached, shuffled, batched and prefetched.
        The test split is mapped and batched (no shuffling is needed for
        evaluation).

        Raises:
            KeyError: if ``self.dataset`` has no ``"train"`` or ``"test"``
                entry; the message lists the keys that are actually present.
        """
        try:
            train_split = self.dataset["train"]
            test_split = self.dataset["test"]
        except KeyError as err:
            # Surface the real keys so a wrong/missing split is easy to diagnose.
            available = (
                list(self.dataset.keys())
                if hasattr(self.dataset, "keys")
                else type(self.dataset).__name__
            )
            raise KeyError(f"dataset has no split {err}; available: {available}") from err

        train = train_split.map(
            map_func=self._load_image_train, num_parallel_calls=tf.data.AUTOTUNE
        )
        test = test_split.map(
            map_func=self._load_image_test, num_parallel_calls=tf.data.AUTOTUNE
        )

        # NOTE(review): self.buffer_size / self.batch_size are not set in this
        # class; presumably BaseModel derives them from config — confirm.
        self.train_dataset = (
            train.cache()
            .shuffle(self.buffer_size)
            .batch(self.batch_size)
            # prefetch() counts *batches*; AUTOTUNE avoids holding
            # buffer_size whole batches in memory.
            .prefetch(tf.data.AUTOTUNE)
        )
        # Batch the evaluation data like the training data so both pipelines
        # yield tensors of the same rank.
        self.test_dataset = test.batch(self.batch_size)
Code: Alles auswählen
import tensorflow_datasets as tfds
class Unet_Model:
    """Simplified U-Net wrapper that loads a dataset and exposes its splits."""

    def __init__(self):
        # Placeholders until load_data() has run.
        self.dataset = None
        self.info = None
        self.train_dataset = []
        self.test_dataset = []

    def load_data(self):
        """Fetch ``(dataset, info)`` from ``DataLoader.load_data`` and split it."""
        loaded = DataLoader.load_data()
        self.dataset, self.info = loaded
        self._preprocess_data()

    def _preprocess_data(self):
        """Copy the "train" then "test" entries of ``self.dataset`` onto the instance.

        Raises:
            KeyError: if either entry is missing. If only "test" is missing,
                train_dataset has already been assigned.
        """
        splits = self.dataset
        for split_name in ("train", "test"):
            setattr(self, f"{split_name}_dataset", splits[split_name])
KeyError: 'train'
Da mein vereinfachter Code im Jupyter Notebook läuft, muss das Attribut self.dataset doch den Key "train" haben.
Der Code meiner Testklasse ist im Folgenden angegeben:
Code: Alles auswählen
def dummy_load_data(*args, **kwargs):
    """Stand-in for ``DataLoader.load_data`` that returns mocked TFDS data.

    Positional and keyword arguments are accepted but ignored, so any call
    signature of the real loader is compatible. Inside
    ``tfds.testing.mock_data(num_examples=1)`` (random fake examples), the
    dataset named by ``Konfigurationen.data.path`` is loaded with
    ``tfds.load`` and the result is returned unchanged.

    NOTE(review): ``Unet_Model.load_data`` unpacks the result into
    ``(dataset, info)``, so ``Konfigurationen.data.load_with_info``
    presumably has to be True — confirm. If mock_data cannot find the
    dataset's metadata files, the split names may be unknown. The returned
    mapping might then lack "train"/"test", which is one possible cause of
    ``KeyError: 'train'``. Verify this, and consider passing an explicit
    ``split=`` to ``tfds.load``.
    """
    mock_ctx = tfds.testing.mock_data(num_examples=1)
    with mock_ctx:
        loaded = tfds.load(
            Konfigurationen.data.path,
            with_info=Konfigurationen.data.load_with_info,
        )
    return loaded
class UnetTest(tf.test.TestCase):
    """Unit tests for the data-loading path of Unet_Model."""

    def setUp(self):
        # tf.test.TestCase requires super().setUp() to be called.
        super().setUp()
        self.unet = Unet_Model(CFG)

    @patch(target="model.Unet.DataLoader.load_data")
    def test_load_data(self, mock_obj):
        """load_data() calls DataLoader once and builds pipelines of the configured image size."""
        mock_obj.return_value = dummy_load_data()
        image_size = self.unet.config.data.image_size
        expected_shape = tf.TensorShape([None, image_size, image_size, 3])

        self.unet.load_data()

        mock_obj.assert_called_once()
        self.assertEqual(self.unet.train_dataset.element_spec[0].shape, expected_shape)
        # Check the test split as well. Only the trailing (height, width,
        # channels) dims are compared, so this holds whether or not the test
        # pipeline is batched.
        self.assertEqual(
            self.unet.test_dataset.element_spec[0].shape[-3:], expected_shape[1:]
        )