Confusion Matrix für Keras LSTM Output

Wenn du dir nicht sicher bist, in welchem der anderen Foren du die Frage stellen sollst, dann bist du hier im Forum für allgemeine Fragen sicher richtig.
Antworten
Bayne
User
Beiträge: 40
Registriert: Freitag 31. Mai 2019, 16:28

Wie gebe ich eine Confusion Matrix für mein LSTM-RNN aus?



Bisheriger Versuch (lässt sich jedoch nicht plotten):

Code: Alles auswählen

# |___| CONFUSION MATRIX |___|


def _to_class_indices(y):
    """Return ``y`` as a 1-D array of integer class indices.

    2-D inputs whose last axis has more than one column (one-hot targets or
    per-class probabilities) are collapsed with ``argmax``; column vectors of
    labels are simply flattened.
    """
    y = np.asarray(y)
    if y.ndim > 1 and y.shape[-1] > 1:
        return y.argmax(axis=-1)
    return y.ravel()


def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):
    """Print and plot the confusion matrix of ``y_true`` against ``y_pred``.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels, either as 1-D integer class indices
        or as 2-D one-hot / probability arrays. 2-D inputs are reduced with
        ``argmax`` so both sides are compared as integer class indices.
    classes : sequence
        Display names indexed by integer class label. The sequence is converted
        to a NumPy array so it can be indexed with the labels present in the data.
    normalize : bool, default False
        If True, divide each row by its total so every non-empty row sums to 1.
        Rows without any true samples stay 0 instead of becoming NaN.
    title : str, optional
        Plot title. When omitted, a default that depends on ``normalize`` is used.
    cmap : matplotlib colormap, default ``plt.cm.Blues``
        Colormap passed to ``imshow``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the matrix was drawn on.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Mixing one-hot and integer labels makes sklearn raise, so bring both
    # inputs into the same integer-label form first.
    y_true = _to_class_indices(y_true)
    y_pred = _to_class_indices(y_pred)

    cm = confusion_matrix(y_true, y_pred)
    # Only label the classes that appear in the data. confusion_matrix uses the
    # same sorted label set, so the tick labels line up with rows and columns.
    # np.asarray allows index arrays even if `classes` was passed as a list.
    classes = np.asarray(classes)[unique_labels(y_true, y_pred)]
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Guard against classes with no true samples (division by zero -> NaN).
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # Show every tick and label it with the corresponding class name.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title, ylabel='True label', xlabel='Predicted label')

    # Rotate the x tick labels so long class names do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Write each cell's value into the plot. Use white text on dark cells so it stays readable.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax


# Display names, indexed by the integer class label used for training.
# A NumPy array (not a str/list) is needed so plot_confusion_matrix can index it
# with an array of labels.
# NOTE(review): the order must match the label encoding of test_y. "-3,2" looks
# out of place among the (x >= 0, y <= 0) pairs on the first line; confirm it is
# not a typo.
class_names = np.array(["0,0", "1,0", "2,0", "3,0", "4,0", "1,-1", "2,-1", "3,-1", "4,-1",
                        "2,-2", "3,-2", "4,-2", "-3,2", "3,-3", "4,-3", "4,-4",
                        "-1,0", "-2,0", "-3,0", "-4,0", "-2,1", "-3,1", "-4,1", "-4,2", "-4,3"])

print(class_names)

np.set_printoptions(precision=2)

# Sequential.predict_classes was removed in TensorFlow 2.6. Take the argmax of
# the predicted class probabilities instead (assumes a softmax output layer;
# TODO confirm).
pred_y = np.argmax(model.predict(test_x), axis=-1)

# confusion_matrix needs integer labels on both sides, so collapse one-hot
# targets (if test_y is one-hot encoded) and flatten column vectors.
true_y = np.asarray(test_y)
true_y = true_y.argmax(axis=-1) if true_y.ndim > 1 and true_y.shape[-1] > 1 else true_y.ravel()

# Plot non-normalized confusion matrix
plot_confusion_matrix(true_y, pred_y, classes=class_names, title='Confusion matrix, without normalization')
# Plot normalized confusion matrix
plot_confusion_matrix(true_y, pred_y, classes=class_names, normalize=True, title='Normalized confusion matrix')
plt.show()
Antworten