I've been working on this code for quite a while now, but nothing works the way it should. Maybe you can spot something in the code.
Thanks in advance

Code: Select all
import traceback
import config
import copy
import numpy as np
import tensorflow as tf
import astor
from semantic_graph import GraphHyp, get_final_propagation_schedule
from graph_model import GraphModel
import logging
import sys
from typing import List

class GraphDecoder():
    def __init__(self, model: GraphModel):
        super().__init__()
        self.graph_model = model

    def decode(self, example, grammar, terminal_vocab, beam_size, max_time_step, log=False):
        # beam search decoding
        eos = 1
        unk = terminal_vocab.unk

        query_tokens = example.data[0]
        query_embed = self.graph_model.encode_query(query_tokens)
        query_mask = tf.not_equal(query_tokens, 0)
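        # query_mask marks the real (non-padding) query positions; token id 0 is assumed
        # to be the padding id here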
        completed_hyps = []
        completed_hyp_num = 0
        live_hyp_num = 1
        initial_node_state_capacity = 100

        root_hyp = GraphHyp(grammar)
        root_hyp.decoder_output = np.zeros(self.graph_model.decoder_hidden_dim).astype('float32')
        root_hyp.decoder_hidden_state = np.zeros(self.graph_model.decoder_hidden_dim).astype('float32')
        root_hyp.action_embed = np.zeros(self.graph_model.rule_embed_dim).astype('float32')
        root_hyp.current_node_states = np.zeros((initial_node_state_capacity, self.graph_model.node_embed_dim), dtype='float32')

        hyp_samples: List[GraphHyp] = [root_hyp]  # [list() for i in range(live_hyp_num)]

        # source word id in the terminal vocab
        src_token_id = [terminal_vocab[t] for t in example.query][:config.max_query_length]
        unk_pos_list = [x for x, t in enumerate(src_token_id) if t == unk]
        # a word may appear multiple times in the source; in that case we only copy
        # its first occurrence, so every occurrence after the first is masked to -1
        token_set = set()
        for i, tid in enumerate(src_token_id):
            if tid in token_set:
                src_token_id[i] = -1  # todo: maybe allow multiple indices here?
            else:
                token_set.add(tid)
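        # illustrative example with made-up ids: src_token_id [12, 5, 7, 5] becomes
        # [12, 5, 7, -1], because only the first occurrence of 5 can be copied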
        for t in range(max_time_step):
            hyp_num = len(hyp_samples)
            # print('time step [%d]' % t)

            decoder_prev_output = np.array([hyp.decoder_output for hyp in hyp_samples]).astype('float32')
            decoder_prev_state = np.array([hyp.decoder_hidden_state for hyp in hyp_samples]).astype('float32')
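            # hist_h collects each live hypothesis's decoder outputs from earlier time steps,
            # zero-padded to max_time_step; it is fed to predict_actions as output_hist below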
            hist_h = np.zeros((hyp_num, max_time_step, self.graph_model.decoder_hidden_dim)).astype('float32')
            if t > 0:
                for i, hyp in enumerate(hyp_samples):
                    hist_h[i, :len(hyp.hist_h), :] = hyp.hist_h

            prop_infos = [hyp.get_new_node_representation_info() for hyp in hyp_samples]
            node_remapping, node_count, \
                node_ids, node_type_tensor, synthesized_terminal_node_ids_tensor, terminal_tokens_tensor, terminal_token_types_tensor, \
                sending_nodes_list, edge_labels_list, msg_targets_list, receiving_nodes_list, \
                propagation_substeps, sending_node_num, receiving_node_num = get_final_propagation_schedule(prop_infos)
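            # the schedule batches the per-hypothesis graphs together: node_remapping maps each
            # hypothesis's local node ids to the ids used in the merged node state matrix below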
            current_propagating_node_states = np.zeros((node_count, self.graph_model.node_embed_dim), dtype='float32')

            # gather relevant node states from hyps
            for hyp_id, hyp_remapping in node_remapping.items():
                current_hyp = hyp_samples[hyp_id]
                hyp_node_states_size = np.shape(current_hyp.current_node_states)[0]
                hyp_node_ids = np.array(list(hyp_remapping.keys()))
                propagating_node_ids = np.array(list(hyp_remapping.values()))

                # expand hyp node states if necessary
                while np.amax(hyp_node_ids) + 1 > hyp_node_states_size:
                    current_hyp.current_node_states = np.concatenate(
                        (current_hyp.current_node_states, np.zeros_like(current_hyp.current_node_states)), axis=0)
                    hyp_node_states_size *= 2

                current_propagating_node_states[propagating_node_ids, :] = hyp_samples[hyp_id].current_node_states[hyp_node_ids, :]

            prev_action = np.array([hyp.last_action for hyp in hyp_samples]).astype('int32')
            prev_action = tf.convert_to_tensor(np.expand_dims(prev_action, axis=1))
            prev_action_type = np.array([hyp.last_action_type for hyp in hyp_samples]).astype('int32')
            prev_action_type = tf.convert_to_tensor(np.expand_dims(prev_action_type, axis=1))

            terminal_count = synthesized_terminal_node_ids_tensor.shape[0]
            query_embed_terminal_tiled = tf.convert_to_tensor(np.tile(query_embed, [terminal_count, 1, 1]))
            query_embed_tiled = tf.convert_to_tensor(np.tile(query_embed, [live_hyp_num, 1, 1]))
            query_mask_tiled = tf.convert_to_tensor(np.tile(query_mask, [live_hyp_num, 1]))

            terminal_token_embed = self.graph_model.encode_tokens(
                query_embed_terminal_tiled,
                tf.convert_to_tensor(terminal_tokens_tensor[:, :, 0]), tf.convert_to_tensor(terminal_tokens_tensor[:, :, 1]),
                tf.convert_to_tensor(terminal_token_types_tensor[:, :, :]))
            prev_terminal_token_embed = self.graph_model.encode_tokens(
                query_embed_tiled,
                prev_action[:, :, 1], prev_action[:, :, 2], prev_action_type[:, :, 1:])

            node_representations = self.graph_model.get_node_representations(
                tf.convert_to_tensor(node_type_tensor), tf.convert_to_tensor(synthesized_terminal_node_ids_tensor), tf.convert_to_tensor(terminal_token_embed),
                tf.reduce_any(tf.cast(tf.convert_to_tensor(terminal_token_types_tensor), tf.bool)),
                tf.convert_to_tensor(propagation_substeps), tf.convert_to_tensor(sending_node_num), tf.convert_to_tensor(receiving_node_num),
                tf.convert_to_tensor(sending_nodes_list), tf.convert_to_tensor(edge_labels_list), tf.convert_to_tensor(msg_targets_list), tf.convert_to_tensor(receiving_nodes_list),
                initial_node_embeddings=tf.convert_to_tensor(current_propagating_node_states), updated_node_ids=tf.convert_to_tensor(node_ids)) * 1
            node_representations_array = np.array(node_representations, dtype='float32')

            # scatter relevant node states back into the hyp nodes
            for hyp_id, hyp_remapping in node_remapping.items():
                current_hyp = hyp_samples[hyp_id]
                hyp_node_ids = np.array(list(hyp_remapping.keys()))
                propagating_node_ids = np.array(list(hyp_remapping.values()))
                current_hyp.current_node_states[hyp_node_ids, :] = node_representations_array[propagating_node_ids, :]

            node_embed_per_action = np.array([hyp.current_node_states[hyp.frontier_nt().graph_id, :] for hyp in hyp_samples], dtype='float32')
            node_embed_per_action = tf.convert_to_tensor(np.expand_dims(node_embed_per_action, axis=1))

            # (batch_size, max_example_action_num, rule_num / target_vocab_size / max_query_length / _ / _)
            rule_prob, vocab_prob, copy_prob, decoder_outputs, decoder_hidden_states = self.graph_model.predict_actions(
                query_embed_tiled, prev_action[:, :, 0],
                prev_terminal_token_embed, prev_action_type, node_embed_per_action,
                query_mask=query_mask_tiled, tgt_action_seq_type=tf.ones([1, 1, 3], dtype=tf.int32), time_steps=tf.constant([t], dtype=tf.int32),
                decoder_prev_output=tf.convert_to_tensor(decoder_prev_output), decoder_prev_state=tf.convert_to_tensor(decoder_prev_state),
                output_hist=tf.convert_to_tensor(hist_h), train=False)

            rule_prob = np.array(rule_prob[:, 0, :])
            vocab_prob = np.array(vocab_prob[:, 0, :])
            copy_prob = np.array(copy_prob[:, 0, :])
            decoder_output_states = np.array(decoder_outputs[:, 0, :])
            decoder_hidden_states = np.array(decoder_hidden_states[:, 0, :])

            new_hyp_samples = []

            word_prob = vocab_prob
            word_prob[:, unk] = 0
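            # the generation probability for <unk> is zeroed here on purpose; it is re-filled
            # further down with the copy score of the most promising unknown source word, so
            # <unk> can only be produced via copying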
            cpy_src_ids = np.ones_like(word_prob) * -1.0

            hyp_scores = np.array([hyp.score for hyp in hyp_samples])

            rule_apply_cand_hyp_ids = []
            rule_apply_cand_scores = []
            rule_apply_cand_rules = []
            rule_apply_cand_rule_ids = []

            hyp_frontier_nts = []
            word_gen_hyp_ids = []
            unk_words = []

            for k in range(live_hyp_num):
                hyp = hyp_samples[k]
                frontier_nt = hyp.frontier_nt()
                hyp_frontier_nts.append(frontier_nt)

                assert hyp, 'none hyp!'

                # if it's not a leaf
                if not grammar.is_terminal(frontier_nt):
                    # iterate over all the possible rules
                    rules = grammar[frontier_nt.as_type_node] if config.head_nt_constraint else grammar
                    assert len(rules) > 0, 'fail to expand nt node %s' % frontier_nt

                    for rule in rules:
                        rule_id = grammar.rule_to_id[rule]
                        cur_rule_score = np.log(rule_prob[k, rule_id])
                        new_hyp_score = hyp.score + cur_rule_score

                        rule_apply_cand_hyp_ids.append(k)
                        rule_apply_cand_scores.append(new_hyp_score)
                        rule_apply_cand_rules.append(rule)
                        rule_apply_cand_rule_ids.append(rule_id)
                else:  # it's a leaf that holds values
                    cand_copy_prob = 0.0
                    for i, tid in enumerate(src_token_id):  # todo: enumerate over the token set instead?
                        if tid != -1:
                            word_prob[k, tid] += copy_prob[k, i]
                            cpy_src_ids[k, tid] = i

                    # and the unk copy probability
                    if len(unk_pos_list) > 0:
                        unk_pos = copy_prob[k, unk_pos_list].argmax()  # todo: only the unk with max probability is considered? maybe iterate
                        unk_pos = unk_pos_list[unk_pos]

                        unk_copy_score = copy_prob[k, unk_pos]
                        word_prob[k, unk] = unk_copy_score
                        cpy_src_ids[k, unk] = unk_pos

                        unk_word = example.query[unk_pos]
                        unk_words.append(unk_word)

                    word_gen_hyp_ids.append(k)

            word_scores = np.log(word_prob + 1.e-7)

            word_gen_hyp_num = len(word_gen_hyp_ids)
            rule_apply_cand_num = len(rule_apply_cand_scores)

            if word_gen_hyp_num > 0:
                word_gen_cand_scores = hyp_scores[word_gen_hyp_ids, None] + word_scores[word_gen_hyp_ids, :]
                word_gen_cand_scores_flat = np.reshape(word_gen_cand_scores, [-1])

                cpy_cand_src_ids = cpy_src_ids[word_gen_hyp_ids, :]

                cand_scores = np.concatenate([rule_apply_cand_scores, word_gen_cand_scores_flat])
            else:
                cand_scores = np.array(rule_apply_cand_scores)

            top_cand_ids = (-cand_scores).argsort()[:beam_size - completed_hyp_num]
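            # candidate ids below rule_apply_cand_num are rule applications; the remaining ids
            # index the flattened word scores and are decomposed back into
            # (word_gen_hyp_id, tid) in the else branch below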
            # expand_cand_num = 0
            for cand_id in top_cand_ids:
                new_hyp = None
                if cand_id < rule_apply_cand_num:
                    # candidate is a rule application
                    hyp_id = rule_apply_cand_hyp_ids[cand_id]
                    hyp = hyp_samples[hyp_id]
                    rule_id = rule_apply_cand_rule_ids[cand_id]
                    rule = rule_apply_cand_rules[cand_id]
                    new_hyp_score = rule_apply_cand_scores[cand_id]

                    new_hyp = GraphHyp(hyp)
                    new_hyp.apply_rule(rule)
                    new_hyp.score = new_hyp_score
                    new_hyp.decoder_output = copy.copy(decoder_output_states[hyp_id])
                    new_hyp.hist_h.append(copy.copy(new_hyp.decoder_output))
                    new_hyp.decoder_hidden_state = copy.copy(decoder_hidden_states[hyp_id])
                else:
                    # candidate is a word generation / copy action
                    tid = (cand_id - rule_apply_cand_num) % word_scores.shape[1]
                    word_gen_hyp_id = (cand_id - rule_apply_cand_num) // word_scores.shape[1]
                    hyp_id = word_gen_hyp_ids[word_gen_hyp_id]

                    cpy_src_id = cpy_cand_src_ids[word_gen_hyp_id, tid]
                    if tid == unk:
                        token = unk_words[word_gen_hyp_id]
                    else:
                        token = terminal_vocab.id_token_map[tid]

                    if cpy_src_id == -1:
                        token_gen_id = tid
                    else:
                        token_gen_id = -1

                    frontier_nt = hyp_frontier_nts[hyp_id]
                    hyp = hyp_samples[hyp_id]

                    new_hyp_score = word_gen_cand_scores[word_gen_hyp_id, tid]

                    new_hyp = GraphHyp(hyp)
                    new_hyp.append_token(token, token_gen_id=token_gen_id, token_src_id=cpy_src_id)
                    new_hyp.score = new_hyp_score
                    new_hyp.decoder_output = copy.copy(decoder_output_states[hyp_id])
                    new_hyp.hist_h.append(copy.copy(new_hyp.decoder_output))
                    new_hyp.decoder_hidden_state = copy.copy(decoder_hidden_states[hyp_id])

                # get the new frontier nt after rule application
                new_frontier_nt = new_hyp.frontier_nt()

                # if new_frontier_nt is None, then we have a new completed hyp!
                if new_frontier_nt is None:
                    new_hyp.n_timestep = t + 1
                    completed_hyps.append(new_hyp)
                    completed_hyp_num += 1
                else:
                    new_hyp.parent_rule_id = grammar.rule_to_id[new_frontier_nt.parent.applied_rule]
                    new_hyp_samples.append(new_hyp)

            live_hyp_num = min(len(new_hyp_samples), beam_size - completed_hyp_num)
            if live_hyp_num < 1:
                break

            hyp_samples = new_hyp_samples

            # prune the hyp space
            if completed_hyp_num >= beam_size:
                break

        completed_hyps = sorted(completed_hyps, key=lambda x: x.score, reverse=True)
        return completed_hyps


def decode_python_dataset(model: GraphModel, dataset, verbose=True):
    from lang.py.parse import decode_tree_to_python_ast

    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decoder = GraphDecoder(model)
    decode_results = []
    cum_num = 0
    # note: only a fixed slice of the examples is decoded here
    for example in dataset.examples[329:339]:
        cand_list = decoder.decode(example, dataset.grammar, dataset.terminal_vocab,
                                   beam_size=config.beam_size, max_time_step=config.decode_max_time_step)
        if cum_num % 10 == 0:
            logging.debug(cum_num)

        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:10]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
                exg_decode_results.append((cid, cand, ast_tree, code, cand.tree))
            except Exception:
                if verbose:
                    print("Exception in converting tree to code:")
                    print('-' * 60)
                    print('raw_id: %d, beam pos: %d' % (example.raw_id, cid))
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60)

        cum_num += 1
        if cum_num % 50 == 0 and verbose:
            print('%d examples so far ...' % cum_num)

        decode_results.append(exg_decode_results)

    return decode_results
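
For reference, here is a small self-contained sketch (plain numpy, with made-up scores and sizes) of how I understand the candidate flattening and index decomposition in the beam step above. It only illustrates that indexing in isolation and is not part of the model itself.

Code: Select all
import numpy as np

# hypothetical values, just to check the indexing
rule_apply_cand_scores = [-1.2, -0.5, -2.0]            # 3 rule-application candidates
word_gen_cand_scores = np.array([[-0.3, -1.5, -0.9],   # 2 word-generating hyps x vocab of 3
                                 [-2.1, -0.4, -1.1]])
rule_apply_cand_num = len(rule_apply_cand_scores)

# flatten word candidates and append them after the rule candidates
cand_scores = np.concatenate([rule_apply_cand_scores, np.reshape(word_gen_cand_scores, [-1])])
top_cand_ids = (-cand_scores).argsort()[:4]

for cand_id in top_cand_ids:
    if cand_id < rule_apply_cand_num:
        print('rule candidate', cand_id, 'score', cand_scores[cand_id])
    else:
        # decompose the flat id back into (hypothesis, token) as in the decoder
        tid = (cand_id - rule_apply_cand_num) % word_gen_cand_scores.shape[1]
        word_gen_hyp_id = (cand_id - rule_apply_cand_num) // word_gen_cand_scores.shape[1]
        print('word candidate: hyp', word_gen_hyp_id, 'token', tid,
              'score', word_gen_cand_scores[word_gen_hyp_id, tid])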