1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
def _load_node_index(node_list, ndict_path):
    """Return a ``{"stoi": {node: index}, "itos": [node, ...]}`` node index.

    If ``ndict_path`` exists, the index is read from it and trusted as-is; it
    is not checked against ``node_list``. Otherwise the index is built from
    ``node_list`` (position = index) and written to ``ndict_path``.
    """
    if os.path.exists(ndict_path):
        with open(ndict_path, "r", encoding="utf-8") as fp:
            itos = json.load(fp)["itos"]
    else:
        itos = list(node_list)
        stoi = {node: ind for ind, node in enumerate(itos)}
        # Serialize before opening the file so a non-serializable node cannot
        # leave a truncated cache behind that breaks every later run.
        payload = json.dumps({"stoi": stoi, "itos": itos})
        with open(ndict_path, "w", encoding="utf-8") as fp:
            fp.write(payload)
    # JSON object keys are always strings, so a reloaded "stoi" keyed by int
    # nodes would no longer match int nodes from node_list/edge_list. "itos"
    # is a JSON array and keeps element types, so stoi is rebuilt from it.
    stoi = {node: ind for ind, node in enumerate(itos)}
    return {"stoi": stoi, "itos": itos}


def _flatten_adjacency(node_list, edge_list, stoi):
    """Flatten the undirected graph into parallel (node, neighbour, degree) lists.

    Each edge ``(n1, n2)`` is recorded in both directions and repeated edges
    collapse. For every node in ``node_list`` order, one entry is emitted per
    distinct neighbour: ``node_inds[k]`` / ``node_neis[k]`` are the indices of
    the node and that neighbour, and ``dg_list[k]`` is the node's degree.
    Nodes that appear in no edge contribute no entries.
    """
    neighbours = {}
    for n1, n2 in edge_list:
        neighbours.setdefault(n1, set()).add(n2)
        neighbours.setdefault(n2, set()).add(n1)

    node_inds, node_neis, dg_list = [], [], []
    for node in node_list:
        neis = neighbours.get(node, ())
        degree = len(neis)
        node_inds.extend([stoi[node]] * degree)
        node_neis.extend(stoi[nei] for nei in neis)
        dg_list.extend([degree] * degree)
    return node_inds, node_neis, dg_list


def train(node_list, edge_list, label_list, T, ndict_path="./node_dict.json",
          train_node_list=None, train_node_label=None,
          test_node_list=None, test_node_label=None,
          epochs=500, lr=0.01, weight_decay=0.01):
    """Train an ``OriLinearGNN`` node classifier and report per-epoch metrics.

    Args:
        node_list: All graph nodes. Positions define node indices when the
            index cache at ``ndict_path`` does not exist yet.
        edge_list: Iterable of ``(n1, n2)`` pairs, treated as undirected.
        label_list: Accepted for interface compatibility but not read. Labels
            come from ``train_node_label`` / ``test_node_label``.
        T: Iteration count forwarded to ``OriLinearGNN``.
        ndict_path: JSON cache of the node index. It is created on first use
            and reused afterwards without being checked against ``node_list``.
        train_node_list, test_node_list: Node indices (rows of the model
            output) for the train/test split. Default to the original
            hard-coded split.
        train_node_label, test_node_label: Class labels aligned with the
            corresponding node lists. Default to the original labels.
        epochs: Number of training epochs.
        lr, weight_decay: Adam optimizer settings.

    Returns:
        ``(train_loss_list, train_acc_list, test_acc_list)``, one entry per epoch.

    Raises:
        ValueError: If a node list and its label list differ in length.
    """
    if train_node_list is None:
        train_node_list = [0, 1, 2, 6, 7, 8, 12, 13, 14]
    if train_node_label is None:
        train_node_label = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    if test_node_list is None:
        test_node_list = [3, 4, 5, 9, 10, 11, 15, 16, 17]
    if test_node_label is None:
        test_node_label = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    if len(train_node_list) != len(train_node_label):
        raise ValueError("train_node_list and train_node_label differ in length")
    if len(test_node_list) != len(test_node_label):
        raise ValueError("test_node_list and test_node_label differ in length")

    node_dict = _load_node_index(node_list, ndict_path)
    node_inds, node_neis, dg_list = _flatten_adjacency(
        node_list, edge_list, node_dict["stoi"])

    model = OriLinearGNN(node_num=len(node_list), feat_dim=2, stat_dim=2, T=T)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # reduction="mean" is the non-deprecated spelling of size_average=True.
    criterion = nn.CrossEntropyLoss(reduction="mean")

    # Built once: none of these change between epochs. torch.tensor with an
    # explicit long dtype avoids the float round-trip of torch.Tensor(...).long().
    node_inds_tensor = torch.tensor(node_inds, dtype=torch.long)
    node_neis_tensor = torch.tensor(node_neis, dtype=torch.long)
    train_index = torch.tensor(train_node_list, dtype=torch.long)
    test_index = torch.tensor(test_node_list, dtype=torch.long)
    train_label = torch.tensor(train_node_label, dtype=torch.long)
    train_label_np = np.array(train_node_label)
    test_label_np = np.array(test_node_label)

    min_loss = float('inf')
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for ep in range(epochs):
        res = model(node_inds_tensor, node_neis_tensor, dg_list)
        train_res = torch.index_select(res, 0, train_index)
        test_res = torch.index_select(res, 0, test_index)
        loss = criterion(train_res, train_label)
        loss_val = loss.item()
        train_acc = CalAccuracy(train_res.cpu().detach().numpy(), train_label_np)
        test_acc = CalAccuracy(test_res.cpu().detach().numpy(), test_label_np)

        optimizer.zero_grad()
        # NOTE(review): retain_graph=True is kept from the original. If
        # OriLinearGNN carries tensors over from a previous forward pass,
        # removing it would make backward() fail. Confirm before dropping it.
        loss.backward(retain_graph=True)
        optimizer.step()

        train_loss_list.append(loss_val)
        test_acc_list.append(test_acc)
        train_acc_list.append(train_acc)
        if loss_val < min_loss:
            min_loss = loss_val
        print("==> [Epoch {}] : loss {:.4f}, min_loss {:.4f}, train_acc {:.3f}, test_acc {:.3f}".format(ep, loss_val, min_loss, train_acc, test_acc))
    return train_loss_list, train_acc_list, test_acc_list
|