# Tensor Factorisation

Nipun Batra  
2025-07-24

<figure>
<a
href="https://colab.research.google.com/github/nipunbatra/ml-teaching/blob/master/notebooks/tensor-factorisation.ipynb"><img
src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<figcaption>Open In Colab</figcaption>
</figure>

In [1]:
import tensorly as tl


In [2]:
import cvxpy as cp
import numpy as np

# Ensure repeatably random problem data.
np.random.seed(0)

# Generate random data matrix A.
m = 10
n = 10
o = 5
k = 2
D = 10*np.ones((m, n, o)) + np.random.randn(m, n, o)

# Initialize Y randomly.
A_init = 10*np.ones((m, k))
B_init = 10*np.ones((n, k))
C_init = 10*np.ones((o, k))




In [3]:
Pred_A = np.einsum('ir, jr, kr ->ijk', A_init, B_init, C_init)

In [4]:
# Ensure same initial random Y, rather than generate new one
# when executing this cell.
B = B_init
C = C_init

# Perform alternating minimization.
MAX_ITERS = 100
residual = np.zeros(MAX_ITERS)
for iter_num in range(0, 1+MAX_ITERS):

    if iter_num % 3 == 0:
        A = cp.Variable(shape=(n, k))
        constraint = [A >= 0]
        prediction = A@tl.tenalg.khatri_rao([C, B]).T
    elif iter_num % 3 == 1:
        B = cp.Variable(shape=(m, k))
        constraint = [B >= 0]
        prediction = B@tl.tenalg.khatri_rao([A, C]).T
    elif iter_num % 3 == 2:
        C = cp.Variable(shape=(o, k))
        constraint = [C >= 0]
        prediction = C@tl.tenalg.khatri_rao([B, A]).T

    obj = cp.Minimize(cp.norm(D.reshape(prediction.shape) - prediction, 'fro')/D.size)
    prob = cp.Problem(obj, constraint)
    prob.solve(solver=cp.SCS, max_iters=10000)

    if prob.status != cp.OPTIMAL:
        raise Exception("Solver did not converge!")

    print('Iteration {}, residual norm {}'.format(iter_num, prob.value))
    residual[iter_num-1] = prob.value

    # Convert variable to NumPy array constant for next iteration.
    if iter_num % 3 == 0:
        A = A.value
    elif iter_num%3 == 1:
        B = B.value
    else:
        C = C.value

Iteration 0, residual norm 0.044229038768601577
Iteration 1, residual norm 0.04438975125966638
Iteration 2, residual norm 0.04485089174072711
Iteration 3, residual norm 0.0446730384004453
Iteration 4, residual norm 0.044526862069177754
Iteration 5, residual norm 0.04484264445045543
Iteration 6, residual norm 0.044478708695822676
Iteration 7, residual norm 0.04447482085818169
Iteration 8, residual norm 0.045090403949033964
Iteration 9, residual norm 0.04454108258260972
Iteration 10, residual norm 0.04409564845828368
Iteration 11, residual norm 0.04513476609369982
Iteration 12, residual norm 0.04452491993393663
Iteration 13, residual norm 0.04378836637021587
Iteration 14, residual norm 0.045190029913906464
Iteration 15, residual norm 0.04437455947694065
Iteration 16, residual norm 0.04376389953499031
Iteration 17, residual norm 0.04474118548369649
Iteration 18, residual norm 0.04485265629997688
Iteration 19, residual norm 0.04423721109971314
Iteration 20, residual norm 0.0446383631302859

In [5]:
A

array([[ 1.23599056e-01,  4.95799879e-02],
       [-6.88417340e-12,  4.89807637e-02],
       [ 1.97771012e-01,  5.00354889e-02],
       [ 2.97660174e-11,  4.86588823e-02],
       [ 1.16488475e-01,  4.82973254e-02],
       [-1.07464945e-11,  4.91216434e-02],
       [ 2.41321912e-11,  4.81319502e-02],
       [ 1.23285226e-01,  4.79485790e-02],
       [ 8.32457971e-12,  4.89433882e-02],
       [ 1.25287913e-11,  4.91104662e-02]])

In [6]:
B

array([[ 9.47630862e+00,  1.31432989e+01],
       [ 2.21280935e+00,  1.29931743e+01],
       [ 5.84494148e-10,  1.33764365e+01],
       [ 1.08369019e+01,  1.28286679e+01],
       [ 5.85242755e+00,  1.28260099e+01],
       [-2.53650233e-09,  1.30457870e+01],
       [ 2.49310302e+00,  1.27637902e+01],
       [-9.13914054e-10,  1.27890647e+01],
       [ 1.19440017e+00,  1.29894214e+01],
       [-5.49265493e-10,  1.30485092e+01]])

In [7]:
C

array([[ 1.52713520e-01,  1.58149369e+01],
       [ 4.52202894e-01,  1.58153740e+01],
       [ 5.73957401e-10,  1.56441552e+01],
       [-1.73226998e-12,  1.54265628e+01],
       [-1.68743313e-10,  1.57148110e+01]])

In [8]:
np.einsum('ir, jr, kr ->ijk', A, B, C)

array([[[10.48458589, 10.83565147, 10.19442928, 10.05263637,
         10.24047173],
        [10.22977218, 10.31196443, 10.07798707,  9.93781374,
         10.12350361],
        [10.48852243, 10.4888123 , 10.37525943, 10.23095138,
         10.42211858],
        [10.26356388, 10.66498719,  9.95038983,  9.81199123,
          9.99533009],
        [10.16739655, 10.38431148,  9.94832818,  9.80995826,
          9.99325913],
        [10.22925871, 10.22954142, 10.11879543,  9.97805451,
         10.16449628],
        [10.05520169, 10.1477645 ,  9.90006825,  9.76236957,
          9.94478124],
        [10.02796162, 10.02823877,  9.9196721 ,  9.78170076,
          9.96447363],
        [10.20760682, 10.25210095, 10.07507622,  9.93494338,
         10.12057961],
        [10.2313932 , 10.23167597, 10.12090687,  9.98013658,
         10.16661726]],

       [[10.18116321, 10.18144459, 10.07121931,  9.93114011,
         10.11670529],
        [10.06487253, 10.0651507 ,  9.95618442,  9.81770523,
         10.0