TensorTrain

class trainsum.tensortrain.TensorTrain(base, copy_data=True)

N-dimensional tensor train. This is the main class for representing and manipulating tensor trains. It should not be instantiated directly; create instances through the tensortrain function instead.

Parameters:
  • base (TrainBase)

  • copy_data (bool)

property device: Any

Get and set the device.

property dtype: Any

Get and set the data type.

property shape: TrainShape

Return the shape of the tensor train with the current ranks.

property cores: Sequence

The tensor cores that make up the train.

to_tensor()

Construct the full tensor from the tensor train.

Return type:

S
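The contraction behind to_tensor can be sketched in plain NumPy. This is an independent illustration, not trainsum's implementation. It assumes each core is an array of shape (r_prev, n, r_next) with boundary ranks equal to 1, and the helper name tt_to_full is hypothetical.

```python
import numpy as np


def tt_to_full(cores):
    """Contract a list of TT cores of shape (r_prev, n, r_next) into a dense tensor."""
    result = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # Sum over the shared rank index: last axis of `result`, first axis of `core`.
        result = np.tensordot(result, core, axes=([-1], [0]))
    # Drop the two boundary rank axes, which both have size 1.
    return result[0, ..., 0]


rng = np.random.default_rng(0)
cores = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 4, 2), (2, 5, 1)]]
full = tt_to_full(cores)
print(full.shape)  # (2, 4, 5)
```

Each entry of the result is a product of matrices, for example full[i, j, k] = cores[0][0, i, :] @ cores[1][:, j, :] @ cores[2][:, k, 0]. The dense tensor needs memory exponential in the number of cores, so this operation is mainly practical for small trains and for testing.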

extend(*trains, copy_data=True)

Extend the tensor train in place by fusing it with one or more other tensor trains.

Parameters:
  • trains (Self)

  • copy_data (bool)

Return type:

None
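The docstring does not specify what fusing means at the level of cores. Assuming it appends the cores of the other trains after the existing ones (an assumption, with all boundary ranks equal to 1), the fused train represents the outer product of the individual tensors, as this NumPy sketch checks. The helpers tt_to_full and tt_extend are hypothetical, and unlike the in-place method, tt_extend returns a new list.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_extend(cores, *others):
    """Hypothetical fusion: append the cores of each other train (boundary ranks are 1)."""
    extended = list(cores)
    for other in others:
        extended.extend(other)
    return extended


rng = np.random.default_rng(1)
a = [rng.standard_normal(s) for s in [(1, 2, 2), (2, 3, 1)]]
b = [rng.standard_normal(s) for s in [(1, 4, 1)]]
ab = tt_extend(a, b)
# Under this assumption the fused train is the outer product of the two tensors.
print(np.allclose(tt_to_full(ab), np.multiply.outer(tt_to_full(a), tt_to_full(b))))  # True
```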

conj()

Return the complex conjugate of the tensor train.

Return type:

Self
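Because the complex conjugate of a product is the product of the conjugates, conjugating every core conjugates the whole tensor without changing any ranks. A minimal NumPy sketch of that identity follows; it assumes cores of shape (r_prev, n, r_next), the helper names are hypothetical, and it is not trainsum's code.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_conj(cores):
    """Conjugate a TT by conjugating each core; no rank changes are needed."""
    return [np.conj(core) for core in cores]


rng = np.random.default_rng(2)
shapes = [(1, 2, 3), (3, 2, 1)]
cores = [rng.standard_normal(s) + 1j * rng.standard_normal(s) for s in shapes]
print(np.allclose(tt_to_full(tt_conj(cores)), np.conj(tt_to_full(cores))))  # True
```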

normalize(idx, /)

Transform the tensor train in place into canonical form with respect to the core at index idx.

Parameters:

idx (int)

Return type:

None
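A common way to reach the mixed canonical form centred on core idx is a pair of QR sweeps. Cores left of idx become left-orthogonal, cores right of it become right-orthogonal, and the triangular factors are pushed into core idx. The sketch below assumes cores of shape (r_prev, n, r_next) and uses hypothetical helper names; whether trainsum uses QR or another factorization is not stated in this documentation.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_normalize(cores, idx):
    """Return a copy of the train in mixed canonical form centred on core idx."""
    cores = [c.copy() for c in cores]
    # Left-to-right sweep: cores 0..idx-1 become left-orthogonal.
    for k in range(idx):
        r0, n, r1 = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r0 * n, r1))
        cores[k] = q.reshape(r0, n, q.shape[1])
        # Push the triangular factor into the next core.
        cores[k + 1] = np.tensordot(r, cores[k + 1], axes=([1], [0]))
    # Right-to-left sweep: cores idx+1..d-1 become right-orthogonal.
    for k in range(len(cores) - 1, idx, -1):
        r0, n, r1 = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r0, n * r1).T)
        cores[k] = q.T.reshape(q.shape[1], n, r1)
        # Push the (transposed) triangular factor into the previous core.
        cores[k - 1] = np.tensordot(cores[k - 1], r.T, axes=([2], [0]))
    return cores


rng = np.random.default_rng(3)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 1)]]
canon = tt_normalize(cores, 2)
print(np.allclose(tt_to_full(canon), tt_to_full(cores)))  # True
# In mixed canonical form the Frobenius norm of the tensor sits entirely in the centre core.
print(np.isclose(np.linalg.norm(canon[2]), np.linalg.norm(tt_to_full(cores))))  # True
```

The represented tensor is unchanged; only the gauge of the cores differs. Concentrating the norm in a single core is what makes local updates and truncations at that core well conditioned.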

truncate()

Reduce the ranks of the tensor train according to the current einsum options.

Return type:

None
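How the einsum options map to a truncation rule is not described here. As an illustration of rank reduction in general, the following sketch implements standard TT-rounding: a right-to-left QR sweep followed by a left-to-right sweep of truncated SVDs. The parameters max_rank and tol (a relative singular-value cutoff) are hypothetical stand-ins for whatever the einsum options control, and the helper names are not part of trainsum.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_truncate(cores, max_rank, tol=0.0):
    """TT-rounding sketch: orthogonalize right to left, then truncate SVDs left to right."""
    cores = [c.copy() for c in cores]
    d = len(cores)
    # Make cores 1..d-1 right-orthogonal so each SVD below sees the true singular values.
    for k in range(d - 1, 0, -1):
        r0, n, r1 = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r0, n * r1).T)
        cores[k] = q.T.reshape(q.shape[1], n, r1)
        cores[k - 1] = np.tensordot(cores[k - 1], r.T, axes=([2], [0]))
    # Sweep left to right, keeping at most max_rank singular values above tol * s_max.
    for k in range(d - 1):
        r0, n, r1 = cores[k].shape
        u, s, vt = np.linalg.svd(cores[k].reshape(r0 * n, r1), full_matrices=False)
        keep = max(1, min(max_rank, int(np.count_nonzero(s > tol * s[0]))))
        cores[k] = u[:, :keep].reshape(r0, n, keep)
        # Absorb the kept singular values and right factor into the next core.
        cores[k + 1] = np.tensordot(s[:keep, None] * vt[:keep], cores[k + 1], axes=([1], [0]))
    return cores


def pad_ranks(core, r0, r1):
    """Embed a core in a larger zero-filled core; the represented tensor is unchanged."""
    out = np.zeros((r0, core.shape[1], r1))
    out[: core.shape[0], :, : core.shape[2]] = core
    return out


rng = np.random.default_rng(4)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 3, 2), (2, 3, 1)]]
padded = [pad_ranks(cores[0], 1, 4), pad_ranks(cores[1], 4, 4), pad_ranks(cores[2], 4, 1)]
small = tt_truncate(padded, max_rank=10, tol=1e-10)
print([c.shape[2] for c in small[:-1]])  # [2, 2]
print(np.allclose(tt_to_full(small), tt_to_full(cores)))  # True
```

Here the padded train has ranks 4 but represents a rank-2 tensor, so rounding recovers ranks 2 without changing the tensor. With a smaller max_rank or larger tol, the result becomes a lower-rank approximation instead.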

transform(func)

Apply the element-wise function func to the tensor represented by the train and return the result as a new tensor train.

Parameters:

func (Callable[[S], S])

Return type:

Self
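Evaluating a function of a tensor train without forming the dense tensor generally requires an approximation scheme, and this documentation does not say which one trainsum uses. As a small-scale reference for what the result should equal, this hypothetical sketch densifies the train, applies func element-wise, and recompresses with TT-SVD. It assumes cores of shape (r_prev, n, r_next) and is only practical for tiny trains.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_svd(full, tol=1e-12):
    """Decompose a dense tensor into TT cores with sequential truncated SVDs (TT-SVD)."""
    cores, r_prev, mat = [], 1, full
    for n in full.shape[:-1]:
        # Unfold: rows combine the previous rank and the current mode index.
        mat = mat.reshape(r_prev * n, -1)
        u, s, vt = np.linalg.svd(mat, full_matrices=False)
        r = max(1, int(np.count_nonzero(s > tol * s[0])))
        cores.append(u[:, :r].reshape(r_prev, n, r))
        mat = s[:r, None] * vt[:r]  # remainder carried to the next mode
        r_prev = r
    cores.append(mat.reshape(r_prev, full.shape[-1], 1))
    return cores


def tt_transform_dense(cores, func, tol=1e-12):
    """Reference only: densify, apply func element-wise, and recompress."""
    return tt_svd(func(tt_to_full(cores)), tol)


rng = np.random.default_rng(5)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 2), (2, 5, 1)]]
result = tt_transform_dense(cores, np.sin)
print(np.allclose(tt_to_full(result), np.sin(tt_to_full(cores))))  # True
```

Note that a nonlinear func such as np.sin usually increases the ranks, so the transformed train can be larger than the input even when both represent tensors of the same shape.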

reverse()

Return a new tensor train with reversed order of cores and digits.

Return type:

Self
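If each core carries a single index (a simplification; the method also reverses digits, whose layout is not described here), reversing a train amounts to reversing the list of cores and swapping each core's two rank axes. The helper names below are hypothetical. The check confirms that the reversed train represents the original tensor with its axes in reverse order.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def tt_reverse(cores):
    """Reverse the core order and swap the left and right rank axes of every core."""
    return [np.transpose(core, (2, 1, 0)) for core in reversed(cores)]


rng = np.random.default_rng(6)
cores = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 4, 2), (2, 5, 1)]]
rev = tt_reverse(cores)
print(tt_to_full(rev).shape)  # (5, 4, 2)
print(np.allclose(tt_to_full(rev), np.transpose(tt_to_full(cores))))  # True
```

This works because each entry is a scalar product of matrices, and a scalar equals its own transpose: (A_1 A_2 ... A_d)^T = A_d^T ... A_1^T. No factorization is needed, so the ranks are preserved exactly.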

permute_dims(order)

Permute the dimensions of the tensor train according to the given order.

Parameters:

order (Sequence[int])

Return type:

None
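When each dimension occupies one core (a simplifying assumption), a permutation can be carried out with adjacent swaps: merge two neighbouring cores, exchange their physical indices, and split the block again with an SVD. This is a common tensor-network technique, not necessarily what trainsum does; swap_adjacent and tt_permute_dims are hypothetical names, and each swap can increase the ranks. The result is compared against np.transpose on the dense tensor.

```python
import numpy as np


def tt_to_full(cores):
    """Contract TT cores of shape (r_prev, n, r_next) into the dense tensor."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=([-1], [0]))
    return result[0, ..., 0]


def swap_adjacent(cores, k, tol=1e-12):
    """Exchange the physical indices of cores k and k+1 in the list (merge, swap, SVD split)."""
    r0, n1, _ = cores[k].shape
    _, n2, r2 = cores[k + 1].shape
    # Merge the two cores into one block of shape (r0, n1, n2, r2).
    block = np.tensordot(cores[k], cores[k + 1], axes=([2], [0]))
    # Swap the physical indices, then group (r0, n2) as rows and (n1, r2) as columns.
    mat = block.transpose(0, 2, 1, 3).reshape(r0 * n2, n1 * r2)
    u, s, vt = np.linalg.svd(mat, full_matrices=False)
    r = max(1, int(np.count_nonzero(s > tol * s[0])))
    cores[k] = u[:, :r].reshape(r0, n2, r)
    cores[k + 1] = (s[:r, None] * vt[:r]).reshape(r, n1, r2)


def tt_permute_dims(cores, order, tol=1e-12):
    """Reorder cores so that new position i holds original dimension order[i]."""
    cores = [c.copy() for c in cores]
    current = list(range(len(cores)))  # current[k]: original dimension now held by core k
    for i, target in enumerate(order):
        j = current.index(target)
        # Bubble the wanted dimension leftwards until it reaches position i.
        while j > i:
            swap_adjacent(cores, j - 1, tol)
            current[j - 1], current[j] = current[j], current[j - 1]
            j -= 1
    return cores


rng = np.random.default_rng(7)
cores = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 4, 2), (2, 5, 1)]]
order = (2, 0, 1)
permuted = tt_permute_dims(cores, order)
print(tt_to_full(permuted).shape)  # (5, 2, 4)
print(np.allclose(tt_to_full(permuted), np.transpose(tt_to_full(cores), order)))  # True
```

The order argument follows the np.transpose convention here: axis i of the result is axis order[i] of the input. Adjacent swaps keep every step local, avoiding the dense tensor, at the cost of possible rank growth.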