TPD

Module Contents

TPD

class TPD(sinkhorn_alpha, sinkhorn_max_iter=5000, stopThr=0.005)

Bases: torch.nn.Module

sinkhorn_alpha
sinkhorn_max_iter = 5000
stopThr = 0.005
epsilon = 1e-16
forward(topic_embeddings_list, weight_loss_TPD=20.0)
sinkhorn(M, return_transp=False)
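The constructor arguments (sinkhorn_alpha, sinkhorn_max_iter, stopThr) and the epsilon attribute match the usual parameters of entropic optimal transport solved with Sinkhorn iterations. Below is a minimal NumPy sketch of such a solver. It assumes uniform marginals, a Gibbs kernel exp(-sinkhorn_alpha * M), a stop test on the column-marginal error, and that epsilon guards the divisions. The name sinkhorn_sketch and its keyword arguments are hypothetical, and this is not the module's own implementation. It returns either the transport plan or the transport cost sum(plan * M), following the return_transp switch in the listed signature.

```python
import numpy as np


def sinkhorn_sketch(M, alpha=20.0, max_iter=5000, stop_thr=0.005, eps=1e-16,
                    return_transp=False):
    """Hypothetical stand-in for TPD.sinkhorn (assumed behaviour, not the real code).

    alpha    ~ sinkhorn_alpha: larger values mean weaker entropic smoothing.
    max_iter ~ sinkhorn_max_iter, stop_thr ~ stopThr, eps ~ epsilon.
    """
    n, m = M.shape
    a = np.full(n, 1.0 / n)  # uniform row marginal (assumption)
    b = np.full(m, 1.0 / m)  # uniform column marginal (assumption)
    K = np.exp(-alpha * M)   # Gibbs kernel of the cost matrix
    u = np.full(n, 1.0 / n)
    v = np.full(m, 1.0 / m)
    for _ in range(max_iter):
        # Alternate scalings so the plan's column sums, then row sums, match b and a.
        v = b / (K.T @ u + eps)
        u = a / (K @ v + eps)
        # Rows now match a (up to eps); stop once the columns are close to b.
        col_marginal = v * (K.T @ u)
        if np.abs(col_marginal - b).sum() < stop_thr:
            break
    transp = u[:, None] * K * v[None, :]  # plan = diag(u) K diag(v)
    if return_transp:
        return transp
    return float((transp * M).sum())  # transport cost <plan, M>


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Illustrative data only: two sets of 8-d "topic embeddings".
    upper = rng.normal(size=(3, 8))
    lower = rng.normal(size=(6, 8))
    M = ((upper[:, None, :] - lower[None, :, :]) ** 2).sum(-1)
    M = M / M.max()  # rescale so exp(-alpha * M) stays well away from underflow
    plan = sinkhorn_sketch(M, alpha=20.0, return_transp=True)
    print(plan.shape, plan.sum())
```

A larger alpha gives a sharper plan, closer to unregularized transport, but exp(-alpha * M) also underflows sooner. That is why the usage example rescales the cost matrix first.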