Help with code
Posted: Friday, 12 June 2020, 20:50
Hello,
I've been working on some code for a while now, but nothing works the way it should. Maybe you'll notice something in the code.
Many thanks in advance

Code: Select all
import traceback
import config
import copy
import numpy as np
import tensorflow as tf
import astor
from semantic_graph import GraphHyp, get_final_propagation_schedule
from graph_model import GraphModel
import logging
import sys
from typing import List
class GraphDecoder:
    def __init__(self, model: GraphModel):
        super().__init__()
        self.graph_model = model
    def decode(self, example, grammar, terminal_vocab, beam_size, max_time_step, log=False):
        # beam search decoding
        eos = 1
        unk = terminal_vocab.unk
        query_tokens = example.data[0]
        query_embed = self.graph_model.encode_query(query_tokens)
        query_mask = tf.not_equal(query_tokens, 0)
        completed_hyps = []
        completed_hyp_num = 0
        live_hyp_num = 1
        initial_node_state_capacity = 100
        root_hyp = GraphHyp(grammar)
        root_hyp.decoder_output = np.zeros(self.graph_model.decoder_hidden_dim).astype('float32')
        root_hyp.decoder_hidden_state = np.zeros(self.graph_model.decoder_hidden_dim).astype('float32')
        root_hyp.action_embed = np.zeros(self.graph_model.rule_embed_dim).astype('float32')
        root_hyp.current_node_states = np.zeros((initial_node_state_capacity, self.graph_model.node_embed_dim), dtype='float32')
        hyp_samples: List[GraphHyp] = [root_hyp]
        # source word id in the terminal vocab
        src_token_id = [terminal_vocab[t] for t in example.query][:config.max_query_length]
        unk_pos_list = [x for x, t in enumerate(src_token_id) if t == unk]
        # a word may appear multiple times in the source; in that case we only
        # copy its first position, so occurrences after the first are masked to -1
        token_set = set()
        for i, tid in enumerate(src_token_id):
            if tid in token_set:
                src_token_id[i] = -1  # todo: maybe allow multiple indices here?
            else:
                token_set.add(tid)
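
        # main beam-search loop: at every time step, each live hypothesis is
        # expanded by one action (a grammar rule application or a terminal token)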
        for t in range(max_time_step):
            hyp_num = len(hyp_samples)
            # print('time step [%d]' % t)
            decoder_prev_output = np.array([hyp.decoder_output for hyp in hyp_samples]).astype('float32')
            decoder_prev_state = np.array([hyp.decoder_hidden_state for hyp in hyp_samples]).astype('float32')
            hist_h = np.zeros((hyp_num, max_time_step, self.graph_model.decoder_hidden_dim)).astype('float32')
            if t > 0:
                for i, hyp in enumerate(hyp_samples):
                    hist_h[i, :len(hyp.hist_h), :] = hyp.hist_h
            prop_infos = [hyp.get_new_node_representation_info() for hyp in hyp_samples]
            (node_remapping, node_count,
             node_ids, node_type_tensor, synthesized_terminal_node_ids_tensor, terminal_tokens_tensor, terminal_token_types_tensor,
             sending_nodes_list, edge_labels_list, msg_targets_list, receiving_nodes_list,
             propagation_substeps, sending_node_num, receiving_node_num) = get_final_propagation_schedule(prop_infos)
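            # the schedule batches all per-hyp graphs into one graph;
            # node_remapping maps each hyp's local node ids to batched node ids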
            current_propagating_node_states = np.zeros((node_count, self.graph_model.node_embed_dim), dtype='float32')
            # gather relevant node states from hyps
            for hyp_id, hyp_remapping in node_remapping.items():
                current_hyp = hyp_samples[hyp_id]
                hyp_node_states_size = np.shape(current_hyp.current_node_states)[0]
                hyp_node_ids = np.array(list(hyp_remapping.keys()))
                propagating_node_ids = np.array(list(hyp_remapping.values()))
                # expand hyp node states if necessary (double the buffer until it fits)
                while np.amax(hyp_node_ids) + 1 > hyp_node_states_size:
                    current_hyp.current_node_states = np.concatenate(
                        (current_hyp.current_node_states, np.zeros_like(current_hyp.current_node_states)), axis=0)
                    hyp_node_states_size *= 2
                current_propagating_node_states[propagating_node_ids, :] = current_hyp.current_node_states[hyp_node_ids, :]
            prev_action = np.array([hyp.last_action for hyp in hyp_samples]).astype('int32')
            prev_action = tf.convert_to_tensor(np.expand_dims(prev_action, axis=1))
            prev_action_type = np.array([hyp.last_action_type for hyp in hyp_samples]).astype('int32')
            prev_action_type = tf.convert_to_tensor(np.expand_dims(prev_action_type, axis=1))
            terminal_count = synthesized_terminal_node_ids_tensor.shape[0]
            query_embed_terminal_tiled = tf.convert_to_tensor(np.tile(query_embed, [terminal_count, 1, 1]))
            query_embed_tiled = tf.convert_to_tensor(np.tile(query_embed, [live_hyp_num, 1, 1]))
            query_mask_tiled = tf.convert_to_tensor(np.tile(query_mask, [live_hyp_num, 1]))
            terminal_token_embed = self.graph_model.encode_tokens(
                query_embed_terminal_tiled,
                tf.convert_to_tensor(terminal_tokens_tensor[:, :, 0]),
                tf.convert_to_tensor(terminal_tokens_tensor[:, :, 1]),
                tf.convert_to_tensor(terminal_token_types_tensor[:, :, :]))
            prev_terminal_token_embed = self.graph_model.encode_tokens(
                query_embed_tiled,
                prev_action[:, :, 1], prev_action[:, :, 2], prev_action_type[:, :, 1:])
            node_representations = self.graph_model.get_node_representations(
                tf.convert_to_tensor(node_type_tensor),
                tf.convert_to_tensor(synthesized_terminal_node_ids_tensor),
                tf.convert_to_tensor(terminal_token_embed),
                tf.reduce_any(tf.cast(tf.convert_to_tensor(terminal_token_types_tensor), tf.bool)),
                tf.convert_to_tensor(propagation_substeps),
                tf.convert_to_tensor(sending_node_num),
                tf.convert_to_tensor(receiving_node_num),
                tf.convert_to_tensor(sending_nodes_list),
                tf.convert_to_tensor(edge_labels_list),
                tf.convert_to_tensor(msg_targets_list),
                tf.convert_to_tensor(receiving_nodes_list),
                initial_node_embeddings=tf.convert_to_tensor(current_propagating_node_states),
                updated_node_ids=tf.convert_to_tensor(node_ids))
            node_representations_array = np.array(node_representations, dtype='float32')
            # scatter relevant node states back into the hyp nodes
            for hyp_id, hyp_remapping in node_remapping.items():
                current_hyp = hyp_samples[hyp_id]
                hyp_node_ids = np.array(list(hyp_remapping.keys()))
                propagating_node_ids = np.array(list(hyp_remapping.values()))
                current_hyp.current_node_states[hyp_node_ids, :] = node_representations_array[propagating_node_ids, :]
            node_embed_per_action = np.array(
                [hyp.current_node_states[hyp.frontier_nt().graph_id, :] for hyp in hyp_samples], dtype='float32')
            node_embed_per_action = tf.convert_to_tensor(np.expand_dims(node_embed_per_action, axis=1))
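            # run a single decoder step for all live hypotheses in parallel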
            # (batch_size, max_example_action_num, rule_num / target_vocab_size / max_query_length / _ / _)
            rule_prob, vocab_prob, copy_prob, decoder_outputs, decoder_hidden_states = self.graph_model.predict_actions(
                query_embed_tiled, prev_action[:, :, 0],
                prev_terminal_token_embed, prev_action_type, node_embed_per_action,
                query_mask=query_mask_tiled, tgt_action_seq_type=tf.ones([1, 1, 3], dtype=tf.int32),
                time_steps=tf.constant([t], dtype=tf.int32),
                decoder_prev_output=tf.convert_to_tensor(decoder_prev_output),
                decoder_prev_state=tf.convert_to_tensor(decoder_prev_state),
                output_hist=tf.convert_to_tensor(hist_h), train=False)
            rule_prob = np.array(rule_prob[:, 0, :])
            vocab_prob = np.array(vocab_prob[:, 0, :])
            copy_prob = np.array(copy_prob[:, 0, :])
            decoder_output_states = np.array(decoder_outputs[:, 0, :])
            decoder_hidden_states = np.array(decoder_hidden_states[:, 0, :])
            new_hyp_samples = []
            word_prob = vocab_prob
            word_prob[:, unk] = 0
            cpy_src_ids = np.ones_like(word_prob) * -1.0
            hyp_scores = np.array([hyp.score for hyp in hyp_samples])
            rule_apply_cand_hyp_ids = []
            rule_apply_cand_scores = []
            rule_apply_cand_rules = []
            rule_apply_cand_rule_ids = []
            hyp_frontier_nts = []
            word_gen_hyp_ids = []
            unk_words = []
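            # collect expansion candidates for every live hypothesis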
            for k in range(live_hyp_num):
                hyp = hyp_samples[k]
                frontier_nt = hyp.frontier_nt()
                hyp_frontier_nts.append(frontier_nt)
                assert hyp, 'none hyp!'
                if not grammar.is_terminal(frontier_nt):
                    # frontier is a non-terminal: iterate over all applicable rules
                    rules = grammar[frontier_nt.as_type_node] if config.head_nt_constraint else grammar
                    assert len(rules) > 0, 'fail to expand nt node %s' % frontier_nt
                    for rule in rules:
                        rule_id = grammar.rule_to_id[rule]
                        cur_rule_score = np.log(rule_prob[k, rule_id])
                        new_hyp_score = hyp.score + cur_rule_score
                        rule_apply_cand_hyp_ids.append(k)
                        rule_apply_cand_scores.append(new_hyp_score)
                        rule_apply_cand_rules.append(rule)
                        rule_apply_cand_rule_ids.append(rule_id)
                else:
                    # frontier is a leaf that holds values: add copy probabilities
                    cand_copy_prob = 0.0
                    for i, tid in enumerate(src_token_id):  # todo: enumerate over set instead?
                        if tid != -1:
                            word_prob[k, tid] += copy_prob[k, i]
                            cpy_src_ids[k, tid] = i
                    # and unk copy probability
                    if len(unk_pos_list) > 0:
                        # todo: only the unk with max probability is considered; maybe iterate
                        unk_pos = copy_prob[k, unk_pos_list].argmax()
                        unk_pos = unk_pos_list[unk_pos]
                        unk_copy_score = copy_prob[k, unk_pos]
                        word_prob[k, unk] = unk_copy_score
                        cpy_src_ids[k, unk] = unk_pos
                        unk_word = example.query[unk_pos]
                        unk_words.append(unk_word)
                    word_gen_hyp_ids.append(k)
            word_scores = np.log(word_prob + 1.e-7)
            word_gen_hyp_num = len(word_gen_hyp_ids)
            rule_apply_cand_num = len(rule_apply_cand_scores)
            if word_gen_hyp_num > 0:
                word_gen_cand_scores = hyp_scores[word_gen_hyp_ids, None] + word_scores[word_gen_hyp_ids, :]
                word_gen_cand_scores_flat = np.reshape(word_gen_cand_scores, [-1])
                cpy_cand_src_ids = cpy_src_ids[word_gen_hyp_ids, :]
                cand_scores = np.concatenate([rule_apply_cand_scores, word_gen_cand_scores_flat])
            else:
                cand_scores = np.array(rule_apply_cand_scores)
            top_cand_ids = (-cand_scores).argsort()[:beam_size - completed_hyp_num]
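            # rebuild the beam from the highest-scoring candidates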
            for cand_id in top_cand_ids:
                new_hyp = None
                if cand_id < rule_apply_cand_num:
                    # candidate is a rule application
                    hyp_id = rule_apply_cand_hyp_ids[cand_id]
                    hyp = hyp_samples[hyp_id]
                    rule_id = rule_apply_cand_rule_ids[cand_id]
                    rule = rule_apply_cand_rules[cand_id]
                    new_hyp_score = rule_apply_cand_scores[cand_id]
                    new_hyp = GraphHyp(hyp)
                    new_hyp.apply_rule(rule)
                    new_hyp.score = new_hyp_score
                    new_hyp.decoder_output = copy.copy(decoder_output_states[hyp_id])
                    new_hyp.hist_h.append(copy.copy(new_hyp.decoder_output))
                    new_hyp.decoder_hidden_state = copy.copy(decoder_hidden_states[hyp_id])
                else:
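                    # candidate generates a terminal token, either taken from
                    # the vocabulary or copied from the source query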
                    tid = (cand_id - rule_apply_cand_num) % word_scores.shape[1]
                    word_gen_hyp_id = (cand_id - rule_apply_cand_num) // word_scores.shape[1]
                    hyp_id = word_gen_hyp_ids[word_gen_hyp_id]
                    cpy_src_id = cpy_cand_src_ids[word_gen_hyp_id, tid]
                    if tid == unk:
                        token = unk_words[word_gen_hyp_id]
                    else:
                        token = terminal_vocab.id_token_map[tid]
                    if cpy_src_id == -1:
                        token_gen_id = tid
                    else:
                        token_gen_id = -1
                    frontier_nt = hyp_frontier_nts[hyp_id]
                    hyp = hyp_samples[hyp_id]
                    new_hyp_score = word_gen_cand_scores[word_gen_hyp_id, tid]
                    new_hyp = GraphHyp(hyp)
                    new_hyp.append_token(token, token_gen_id=token_gen_id, token_src_id=cpy_src_id)
                    new_hyp.score = new_hyp_score
                    new_hyp.decoder_output = copy.copy(decoder_output_states[hyp_id])
                    new_hyp.hist_h.append(copy.copy(new_hyp.decoder_output))
                    new_hyp.decoder_hidden_state = copy.copy(decoder_hidden_states[hyp_id])
                # get the new frontier nt after rule application
                new_frontier_nt = new_hyp.frontier_nt()
                # if new_frontier_nt is None, we have a new completed hyp!
                if new_frontier_nt is None:
                    new_hyp.n_timestep = t + 1
                    completed_hyps.append(new_hyp)
                    completed_hyp_num += 1
                else:
                    new_hyp.parent_rule_id = grammar.rule_to_id[new_frontier_nt.parent.applied_rule]
                    new_hyp_samples.append(new_hyp)
            live_hyp_num = min(len(new_hyp_samples), beam_size - completed_hyp_num)
            if live_hyp_num < 1:
                break
            hyp_samples = new_hyp_samples
            # prune the hyp space
            if completed_hyp_num >= beam_size:
                break
        completed_hyps = sorted(completed_hyps, key=lambda x: x.score, reverse=True)
        return completed_hyps
def decode_python_dataset(model: GraphModel, dataset, verbose=True):
    from lang.py.parse import decode_tree_to_python_ast
    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)
    decoder = GraphDecoder(model)
    decode_results = []
    cum_num = 0
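    # note: only a fixed slice of the examples is decoded here (debugging leftover?)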
    for example in dataset.examples[329:339]:
        cand_list = decoder.decode(example, dataset.grammar, dataset.terminal_vocab,
                                   beam_size=config.beam_size, max_time_step=config.decode_max_time_step)
        if cum_num % 10 == 0:
            logging.debug(cum_num)
        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                exg_decode_results.append((cid, cand, ast_tree, code, cand.tree))
            except Exception:
                if verbose:
                    print("Exception in converting tree to code:")
                    print('-' * 60)
                    print('raw_id: %d, beam pos: %d' % (example.raw_id, cid))
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60)
        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print('%d examples so far ...' % cum_num)
        decode_results.append(exg_decode_results)
    return decode_results