import tensorly as tl
Tensor Factorisation
Interactive tutorial on tensor factorisation with practical implementations and visualizations
import cvxpy as cp
import numpy as np
# Ensure repeatably random problem data.
0)
np.random.seed(
# Generate random data matrix A.
= 10
m = 10
n = 5
o = 2
k = 10*np.ones((m, n, o)) + np.random.randn(m, n, o)
D
# Initialize Y randomly.
= 10*np.ones((m, k))
A_init = 10*np.ones((n, k))
B_init = 10*np.ones((o, k))
C_init
= np.einsum('ir, jr, kr ->ijk', A_init, B_init, C_init) Pred_A
# Ensure same initial random Y, rather than generate new one
# when executing this cell.
= B_init
B = C_init
C
# Perform alternating minimization.
= 100
MAX_ITERS = np.zeros(MAX_ITERS)
residual for iter_num in range(0, 1+MAX_ITERS):
if iter_num % 3 == 0:
= cp.Variable(shape=(n, k))
A = [A >= 0]
constraint = A@tl.tenalg.khatri_rao([C, B]).T
prediction elif iter_num % 3 == 1:
= cp.Variable(shape=(m, k))
B = [B >= 0]
constraint = B@tl.tenalg.khatri_rao([A, C]).T
prediction elif iter_num % 3 == 2:
= cp.Variable(shape=(o, k))
C = [C >= 0]
constraint = C@tl.tenalg.khatri_rao([B, A]).T
prediction
= cp.Minimize(cp.norm(D.reshape(prediction.shape) - prediction, 'fro')/D.size)
obj = cp.Problem(obj, constraint)
prob =cp.SCS, max_iters=10000)
prob.solve(solver
if prob.status != cp.OPTIMAL:
raise Exception("Solver did not converge!")
print('Iteration {}, residual norm {}'.format(iter_num, prob.value))
-1] = prob.value
residual[iter_num
# Convert variable to NumPy array constant for next iteration.
if iter_num % 3 == 0:
= A.value
A elif iter_num%3 == 1:
= B.value
B else:
= C.value C
Iteration 0, residual norm 0.044229038768601577
Iteration 1, residual norm 0.04438975125966638
Iteration 2, residual norm 0.04485089174072711
Iteration 3, residual norm 0.0446730384004453
Iteration 4, residual norm 0.044526862069177754
Iteration 5, residual norm 0.04484264445045543
Iteration 6, residual norm 0.044478708695822676
Iteration 7, residual norm 0.04447482085818169
Iteration 8, residual norm 0.045090403949033964
Iteration 9, residual norm 0.04454108258260972
Iteration 10, residual norm 0.04409564845828368
Iteration 11, residual norm 0.04513476609369982
Iteration 12, residual norm 0.04452491993393663
Iteration 13, residual norm 0.04378836637021587
Iteration 14, residual norm 0.045190029913906464
Iteration 15, residual norm 0.04437455947694065
Iteration 16, residual norm 0.04376389953499031
Iteration 17, residual norm 0.04474118548369649
Iteration 18, residual norm 0.04485265629997688
Iteration 19, residual norm 0.04423721109971314
Iteration 20, residual norm 0.04463836313028598
Iteration 21, residual norm 0.04509480034427943
Iteration 22, residual norm 0.04449824189344165
Iteration 23, residual norm 0.0446297404908044
Iteration 24, residual norm 0.044916791718091605
Iteration 25, residual norm 0.04475629136811872
Iteration 26, residual norm 0.044530205009593614
Iteration 27, residual norm 0.04481551774462266
Iteration 28, residual norm 0.04450448414700286
Iteration 29, residual norm 0.04473161074066948
Iteration 30, residual norm 0.044760915844488124
Iteration 31, residual norm 0.04447831568798033
Iteration 32, residual norm 0.04458463344810684
Iteration 33, residual norm 0.04486246216584251
Iteration 34, residual norm 0.04441553594150665
Iteration 35, residual norm 0.04465894246345131
Iteration 36, residual norm 0.04504337082596328
Iteration 37, residual norm 0.044526299849176346
Iteration 38, residual norm 0.04482460772221398
Iteration 39, residual norm 0.04500800584243032
Iteration 40, residual norm 0.04450100433007423
Iteration 41, residual norm 0.04484822280377619
Iteration 42, residual norm 0.04482311584413178
Iteration 43, residual norm 0.044522803791676675
Iteration 44, residual norm 0.044772484551616504
Iteration 45, residual norm 0.04458696318356589
Iteration 46, residual norm 0.044520256516783666
Iteration 47, residual norm 0.04482402095966313
Iteration 48, residual norm 0.04426750617419428
Iteration 49, residual norm 0.044306760469399145
Iteration 50, residual norm 0.04488415609564971
Iteration 51, residual norm 0.04462703244119978
Iteration 52, residual norm 0.04393438972821902
Iteration 53, residual norm 0.04496061133340096
Iteration 54, residual norm 0.04474813896252984
Iteration 55, residual norm 0.04412129577027855
Iteration 56, residual norm 0.04476074674202858
Iteration 57, residual norm 0.04493986609489073
Iteration 58, residual norm 0.04433894515459506
Iteration 59, residual norm 0.04479359793862809
Iteration 60, residual norm 0.04507730499950036
Iteration 61, residual norm 0.04443726027910054
Iteration 62, residual norm 0.04482400811580273
Iteration 63, residual norm 0.044904850352243696
Iteration 64, residual norm 0.04455396547559336
Iteration 65, residual norm 0.04468608722592516
Iteration 66, residual norm 0.04458327810705387
Iteration 67, residual norm 0.0445511237555968
Iteration 68, residual norm 0.044703743776719096
Iteration 69, residual norm 0.04439204969221895
Iteration 70, residual norm 0.04419479288919463
Iteration 71, residual norm 0.04496926843613955
Iteration 72, residual norm 0.04456014579787714
Iteration 73, residual norm 0.04390674892508923
Iteration 74, residual norm 0.04471211816232015
Iteration 75, residual norm 0.04477670162347586
Iteration 76, residual norm 0.044259574198538376
Iteration 77, residual norm 0.044866509022554346
Iteration 78, residual norm 0.04499103987882674
Iteration 79, residual norm 0.04443576055416042
Iteration 80, residual norm 0.04494346910320409
Iteration 81, residual norm 0.04480069447165729
Iteration 82, residual norm 0.044256568121726014
Iteration 83, residual norm 0.044744496252969
Iteration 84, residual norm 0.04448314930970543
Iteration 85, residual norm 0.044155459981671225
Iteration 86, residual norm 0.04521780706935539
Iteration 87, residual norm 0.04446147094598418
Iteration 88, residual norm 0.04404940602099275
Iteration 89, residual norm 0.04516034088743775
Iteration 90, residual norm 0.04457572511910495
Iteration 91, residual norm 0.044197230653923565
Iteration 92, residual norm 0.045118201306481545
Iteration 93, residual norm 0.04421183315228502
Iteration 94, residual norm 0.04375429174860549
Iteration 95, residual norm 0.04441022241202105
Iteration 96, residual norm 0.04401508759046129
Iteration 97, residual norm 0.04393321896924147
Iteration 98, residual norm 0.04435341885307329
Iteration 99, residual norm 0.04411809866902199
Iteration 100, residual norm 0.044273830016112556
A
array([[ 1.23599056e-01, 4.95799879e-02],
[-6.88417340e-12, 4.89807637e-02],
[ 1.97771012e-01, 5.00354889e-02],
[ 2.97660174e-11, 4.86588823e-02],
[ 1.16488475e-01, 4.82973254e-02],
[-1.07464945e-11, 4.91216434e-02],
[ 2.41321912e-11, 4.81319502e-02],
[ 1.23285226e-01, 4.79485790e-02],
[ 8.32457971e-12, 4.89433882e-02],
[ 1.25287913e-11, 4.91104662e-02]])
B
array([[ 9.47630862e+00, 1.31432989e+01],
[ 2.21280935e+00, 1.29931743e+01],
[ 5.84494148e-10, 1.33764365e+01],
[ 1.08369019e+01, 1.28286679e+01],
[ 5.85242755e+00, 1.28260099e+01],
[-2.53650233e-09, 1.30457870e+01],
[ 2.49310302e+00, 1.27637902e+01],
[-9.13914054e-10, 1.27890647e+01],
[ 1.19440017e+00, 1.29894214e+01],
[-5.49265493e-10, 1.30485092e+01]])
C
array([[ 1.52713520e-01, 1.58149369e+01],
[ 4.52202894e-01, 1.58153740e+01],
[ 5.73957401e-10, 1.56441552e+01],
[-1.73226998e-12, 1.54265628e+01],
[-1.68743313e-10, 1.57148110e+01]])
'ir, jr, kr ->ijk', A, B, C) np.einsum(
array([[[10.48458589, 10.83565147, 10.19442928, 10.05263637,
10.24047173],
[10.22977218, 10.31196443, 10.07798707, 9.93781374,
10.12350361],
[10.48852243, 10.4888123 , 10.37525943, 10.23095138,
10.42211858],
[10.26356388, 10.66498719, 9.95038983, 9.81199123,
9.99533009],
[10.16739655, 10.38431148, 9.94832818, 9.80995826,
9.99325913],
[10.22925871, 10.22954142, 10.11879543, 9.97805451,
10.16449628],
[10.05520169, 10.1477645 , 9.90006825, 9.76236957,
9.94478124],
[10.02796162, 10.02823877, 9.9196721 , 9.78170076,
9.96447363],
[10.20760682, 10.25210095, 10.07507622, 9.93494338,
10.12057961],
[10.2313932 , 10.23167597, 10.12090687, 9.98013658,
10.16661726]],
[[10.18116321, 10.18144459, 10.07121931, 9.93114011,
10.11670529],
[10.06487253, 10.0651507 , 9.95618442, 9.81770523,
10.00115085],
[10.36175805, 10.36204442, 10.24986394, 10.1073 ,
10.29615676],
[ 9.93744133, 9.93771598, 9.83012932, 9.69340341,
9.87452643],
[ 9.93538237, 9.93565696, 9.82809259, 9.69139501,
9.87248051],
[10.10562779, 10.10590708, 9.99649957, 9.85745964,
10.04164808],
[ 9.88718524, 9.88745849, 9.78041593, 9.64438148,
9.82458851],
[ 9.90676358, 9.90703738, 9.79978285, 9.66347903,
9.84404291],
[10.06196547, 10.06224355, 9.95330875, 9.81486955,
9.99826219],
[10.10773648, 10.10801583, 9.99858549, 9.85951655,
10.04374343]],
[[10.68660524, 11.24817743, 10.28808747, 10.14499189,
10.33455292],
[10.34843597, 10.47978553, 10.17057548, 10.02911436,
10.2165102 ],
[10.58488251, 10.58517505, 10.47057894, 10.32494511,
10.5178686 ],
[10.47872824, 11.12088193, 10.04180598, 9.90213589,
10.08715912],
[10.32608231, 10.67300395, 10.0397254 , 9.90008425,
10.08506914],
[10.32323689, 10.32352219, 10.21175876, 10.06972482,
10.25787948],
[10.17538798, 10.3233344 , 9.99102209, 9.85205835,
10.03614586],
[10.12009045, 10.12037014, 10.01080605, 9.87156714,
10.05601918],
[10.314708 , 10.38573678, 10.16763789, 10.02621762,
10.21355934],
[10.32539099, 10.32567635, 10.2138896 , 10.07182603,
10.26001994]],
[[10.11425681, 10.11453634, 10.00503541, 9.86587676,
10.05022247],
[ 9.99873034, 9.99900668, 9.89075648, 9.75318732,
9.93542741],
[10.29366485, 10.29394934, 10.18250606, 10.040879 ,
10.22849466],
[ 9.87213657, 9.87240941, 9.76552976, 9.62970236,
9.80963512],
[ 9.87009114, 9.87036392, 9.76350642, 9.62770716,
9.80760264],
[10.03921777, 10.03949523, 9.9308067 , 9.79268048,
9.97565852],
[ 9.82221073, 9.82248219, 9.71614307, 9.58100258,
9.76002537],
[ 9.84166042, 9.84193242, 9.73538272, 9.59997463,
9.77935192],
[ 9.99584238, 9.99611864, 9.88789971, 9.75037028,
9.93255774],
[10.04131261, 10.04159012, 9.93287891, 9.79472388,
9.97774009]],
[[10.20768094, 10.53855895, 9.93069359, 9.79256895,
9.97554489],
[ 9.96379984, 10.04127254, 9.8172638 , 9.68071684,
9.86160281],
[10.21717838, 10.21746076, 10.10684556, 9.96627084,
10.15249244],
[ 9.99156385, 10.36990232, 9.69296758, 9.55814943,
9.73674521],
[ 9.90086298, 10.10530773, 9.69095927, 9.55616906,
9.73472783],
[ 9.96462196, 9.96489736, 9.85701643, 9.71991656,
9.90153498],
[ 9.7935781 , 9.88082458, 9.64394785, 9.50981151,
9.68750408],
[ 9.76853255, 9.76880252, 9.66304454, 9.52864259,
9.70668703],
[ 9.94281649, 9.98475981, 9.81442826, 9.67792073,
9.85875446],
[ 9.96670123, 9.96697669, 9.85907325, 9.72194476,
9.90360109]],
[[10.21044653, 10.21072872, 10.1001864 , 9.95970431,
10.14580321],
[10.09382137, 10.09410034, 9.98482065, 9.84594316,
10.02991642],
[10.3915608 , 10.39184799, 10.27934486, 10.13637088,
10.32577082],
[ 9.96602365, 9.96629909, 9.85840299, 9.72128382,
9.9029278 ],
[ 9.96395877, 9.96423415, 9.8563604 , 9.71926965,
9.90087599],
[10.13469385, 10.13497395, 10.02525176, 9.88581192,
10.07053013],
[ 9.91562301, 9.91589705, 9.80854661, 9.67212089,
9.85284625],
[ 9.93525767, 9.93553225, 9.82796924, 9.69127337,
9.87235659],
[10.09090595, 10.09118483, 9.98193671, 9.84309933,
10.02701945],
[10.13680861, 10.13708877, 10.02734368, 9.88787475,
10.0726315 ]],
[[10.00472845, 10.00500495, 9.89668982, 9.75903813,
9.94138755],
[ 9.89045303, 9.89072638, 9.78364843, 9.64756902,
9.82783562],
[10.18219366, 10.18247507, 10.07223863, 9.93214525,
10.11772921],
[ 9.76523015, 9.76550004, 9.65977781, 9.5254213 ,
9.70340554],
[ 9.76320687, 9.7634767 , 9.65777638, 9.5234477 ,
9.70139507],
[ 9.93050202, 9.93077647, 9.82326494, 9.68663451,
9.86763105],
[ 9.71584497, 9.71611349, 9.61092593, 9.47724889,
9.65433302],
[ 9.73508404, 9.73535309, 9.62995723, 9.49601549,
9.67345028],
[ 9.88759634, 9.88786961, 9.78082259, 9.64478249,
9.82499702],
[ 9.93257417, 9.93284868, 9.82531472, 9.68865578,
9.86969008]],
[[10.14502635, 10.4951919 , 9.85898582, 9.72185855,
9.90351326],
[ 9.89443406, 9.97640907, 9.74637509, 9.61081411,
9.79039394],
[10.14340196, 10.1436823 , 10.03386583, 9.89430618,
10.0791831 ],
[ 9.93205681, 10.33245243, 9.62297639, 9.48913174,
9.66643791],
[ 9.83619694, 10.05255267, 9.62098258, 9.48716567,
9.6644351 ],
[ 9.8926692 , 9.89294261, 9.78584067, 9.64973077,
9.83003776],
[ 9.7257684 , 9.81808778, 9.57431062, 9.44114286,
9.61755234],
[ 9.69799572, 9.69826374, 9.59326942, 9.45983796,
9.63659677],
[ 9.87241434, 9.91678695, 9.74356002, 9.60803819,
9.78756615],
[ 9.89473346, 9.89500692, 9.78788264, 9.65174433,
9.83208895]],
[[10.17339432, 10.17367549, 10.06353431, 9.92356201,
10.10898558],
[10.05719238, 10.05747033, 9.9485872 , 9.81021368,
9.99351932],
[10.35385135, 10.3541375 , 10.24204263, 10.09958747,
10.28830012],
[ 9.92985842, 9.93013285, 9.82262829, 9.68600671,
9.86699153],
[ 9.92780103, 9.92807541, 9.82059312, 9.68399985,
9.86494716],
[10.09791654, 10.09819562, 9.98887159, 9.84993776,
10.03398565],
[ 9.87964067, 9.87991372, 9.77295283, 9.63702219,
9.81709171],
[ 9.89920408, 9.89947766, 9.79230498, 9.65610517,
9.83653126],
[10.05428753, 10.0545654 , 9.94571372, 9.80738017,
9.99063287],
[10.10002362, 10.10030276, 9.99095592, 9.8519931 ,
10.0360794 ]],
[[10.20812322, 10.20840535, 10.09788818, 9.95743805,
10.14349461],
[10.0915246 , 10.0918035 , 9.98254868, 9.84370279,
10.02763418],
[10.38919627, 10.3894834 , 10.27700587, 10.13406442,
10.32342127],
[ 9.96375596, 9.96403133, 9.85615978, 9.71907182,
9.90067446],
[ 9.96169154, 9.96196686, 9.85411766, 9.7170581 ,
9.89862311],
[10.13238778, 10.13266781, 10.02297059, 9.88356248,
10.06823865],
[ 9.91336678, 9.91364076, 9.80631475, 9.66992007,
9.8506043 ],
[ 9.93299697, 9.93327149, 9.82573295, 9.68906819,
9.87011021],
[10.08860984, 10.08888866, 9.97966539, 9.84085961,
10.02473787],
[10.13450206, 10.13478215, 10.02506203, 9.88562483,
10.07033954]]])