TensorTrain
- class trainsum.tensortrain.TensorTrain(base, copy_data=True)
N-dimensional tensor train. Main class for representing and manipulating tensor trains. It should not be instantiated directly; use the tensortrain function instead.
- Parameters:
base (TrainBase)
copy_data (bool)
- property device: Any
Get and set the device.
- property dtype: Any
Get and set the data type.
- property shape: TrainShape
Return the shape of the tensor train with the current ranks.
- property cores: Sequence
Tensor cores.
- to_tensor()
Construct the full tensor from the tensor train.
- Return type:
S
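As an illustration of what constructing the full tensor involves, here is a minimal NumPy sketch. It does not use the trainsum API: the helper name `tt_to_tensor` is hypothetical, and it assumes each core is an array of shape (left rank, mode size, right rank) with boundary ranks of 1. The library's actual core layout may differ.

```python
import numpy as np

def tt_to_tensor(cores):
    """Contract TT cores of shape (r_left, n, r_right) into a dense array."""
    result = cores[0]
    for core in cores[1:]:
        # Contract the last (rank) axis of the running result
        # with the first (rank) axis of the next core.
        result = np.tensordot(result, core, axes=1)
    # Drop the two boundary rank axes, which are assumed to have size 1.
    return result.reshape(result.shape[1:-1])

rng = np.random.default_rng(0)
cores = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 4, 2), (2, 5, 1)]]
full = tt_to_tensor(cores)  # dense array with one axis per core
```

The dense result grows exponentially with the number of cores, so a conversion like this is mainly useful for small trains and for testing.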
- extend(*trains, copy_data=True)
Extend the tensor train by fusing it with one or more other tensor trains.
- Parameters:
trains (Self)
copy_data (bool)
- Return type:
None
- conj()
Return the complex conjugate of the tensor train.
- Return type:
Self
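Mathematically, conjugating a tensor train amounts to conjugating every core. The NumPy sketch below checks this identity on a small example. It is not the trainsum implementation; `tt_to_tensor` is a hypothetical helper, and the (left rank, mode size, right rank) core layout is an assumption.

```python
import numpy as np

def tt_to_tensor(cores):
    """Contract TT cores of shape (r_left, n, r_right) into a dense array."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=1)
    return result.reshape(result.shape[1:-1])

rng = np.random.default_rng(1)
shapes = [(1, 2, 3), (3, 4, 2), (2, 5, 1)]
cores = [rng.standard_normal(s) + 1j * rng.standard_normal(s) for s in shapes]

# Conjugate each core separately; the ranks are unchanged.
conj_cores = [np.conj(c) for c in cores]
```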
- normalize(idx, /)
Bring the tensor train into canonical form with respect to the core at index idx.
- Parameters:
idx (int)
- Return type:
None
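A common way to reach a canonical form is a pair of QR sweeps: cores to the left of idx become left-orthogonal and cores to the right become right-orthogonal. The following NumPy sketch shows that idea under assumptions. The function names are hypothetical, the cores are assumed to have shape (left rank, mode size, right rank), and the library may use a different procedure internally.

```python
import numpy as np

def tt_to_tensor(cores):
    """Contract TT cores of shape (r_left, n, r_right) into a dense array."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=1)
    return result.reshape(result.shape[1:-1])

def canonicalize(cores, idx):
    """Return new cores in canonical form with the center at position idx."""
    cores = [c.copy() for c in cores]
    # Left-to-right sweep: make cores 0 .. idx-1 left-orthogonal.
    for k in range(idx):
        r_left, n, r_right = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r_left * n, r_right))
        cores[k] = q.reshape(r_left, n, q.shape[1])
        cores[k + 1] = np.tensordot(r, cores[k + 1], axes=1)
    # Right-to-left sweep: make cores idx+1 .. end right-orthogonal.
    for k in range(len(cores) - 1, idx, -1):
        r_left, n, r_right = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r_left, n * r_right).T)
        cores[k] = q.T.reshape(q.shape[1], n, r_right)
        cores[k - 1] = np.tensordot(cores[k - 1], r.T, axes=1)
    return cores

rng = np.random.default_rng(2)
cores = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 4, 2), (2, 5, 1)]]
canon = canonicalize(cores, 1)
```

In this form the represented tensor is unchanged, and the Frobenius norm of the whole tensor equals the norm of the center core.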
- truncate()
Reduce the ranks of the tensor train according to the current einsum options.
- Return type:
None
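Rank reduction of a tensor train is typically done with an orthogonalization sweep followed by truncated SVDs. The sketch below illustrates that scheme in plain NumPy. It is not the trainsum implementation: `tt_truncate` and its `max_rank` argument are hypothetical stand-ins for whatever the einsum options control, and the cores are assumed to have shape (left rank, mode size, right rank).

```python
import numpy as np

def tt_to_tensor(cores):
    """Contract TT cores of shape (r_left, n, r_right) into a dense array."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=1)
    return result.reshape(result.shape[1:-1])

def tt_truncate(cores, max_rank):
    """Return new cores whose bond ranks are at most max_rank (hypothetical option)."""
    cores = [c.copy() for c in cores]
    # Right-to-left QR sweep: all cores except the first become right-orthogonal.
    for k in range(len(cores) - 1, 0, -1):
        r_left, n, r_right = cores[k].shape
        q, r = np.linalg.qr(cores[k].reshape(r_left, n * r_right).T)
        cores[k] = q.T.reshape(q.shape[1], n, r_right)
        cores[k - 1] = np.tensordot(cores[k - 1], r.T, axes=1)
    # Left-to-right SVD sweep: keep at most max_rank singular values per bond.
    for k in range(len(cores) - 1):
        r_left, n, r_right = cores[k].shape
        u, s, vh = np.linalg.svd(
            cores[k].reshape(r_left * n, r_right), full_matrices=False
        )
        keep = min(max_rank, len(s))
        cores[k] = u[:, :keep].reshape(r_left, n, keep)
        # Push the retained singular values and right factors into the next core.
        cores[k + 1] = np.tensordot(s[:keep, None] * vh[:keep], cores[k + 1], axes=1)
    return cores

rng = np.random.default_rng(3)
# Ranks of 6 are larger than a 4 x 4 x 4 tensor can ever need.
cores = [rng.standard_normal(s) for s in [(1, 4, 6), (6, 4, 6), (6, 4, 1)]]
small = tt_truncate(cores, max_rank=4)
```

Here the ranks drop from 6 to 4 without changing the represented tensor, because the original ranks were redundant. A smaller `max_rank` would give an approximation instead.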
- transform(func)
Perform an element-wise transformation of the tensor train defined by some function.
- Parameters:
func (Callable[[S], S])
- Return type:
Self
- reverse()
Return a new tensor train with reversed order of cores and digits.
- Return type:
Self
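To illustrate the core-order part of this operation, the NumPy sketch below reverses the list of cores and swaps the two rank axes of each core, which reverses the axis order of the dense tensor. It assumes cores of shape (left rank, mode size, right rank), uses a hypothetical helper `tt_to_tensor`, and does not model how the library treats digits.

```python
import numpy as np

def tt_to_tensor(cores):
    """Contract TT cores of shape (r_left, n, r_right) into a dense array."""
    result = cores[0]
    for core in cores[1:]:
        result = np.tensordot(result, core, axes=1)
    return result.reshape(result.shape[1:-1])

rng = np.random.default_rng(4)
cores = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 4, 2), (2, 5, 1)]]

# Reverse the core order and swap the left and right rank axes of every core.
reversed_cores = [np.transpose(c, (2, 1, 0)) for c in reversed(cores)]
```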
- permute_dims(order)
Permute the dimensions of the tensor train according to the given order.
- Parameters:
order (Sequence[int])
- Return type:
None