Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/numpy/core/einsumfunc.py: 3%

411 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2023-07-17 14:22 -0600

1""" 

2Implementation of optimized einsum. 

3 

4""" 

import itertools
import math
import operator

from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
from numpy.core.overrides import array_function_dispatch

11 

# Public API of this module.
__all__ = ['einsum', 'einsum_path']

# The 52 single-letter subscript labels einsum accepts (a-z then A-Z; the
# position of a letter in this string is its integer subscript value).
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form for O(1) membership tests during subscript parsing.
einsum_symbols_set = set(einsum_symbols)

16 

17 

def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction
    size_dictionary : dict
        The size of each of the indices in idx_contraction

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------

    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    contraction_size = _compute_size_by_dict(idx_contraction, size_dictionary)

    # One multiply for each term beyond the first (at least one operation
    # even for a single term); an inner product costs one extra operation
    # per element for the accumulation.
    factor = max(num_terms - 1, 1) + (1 if inner else 0)

    return contraction_size * factor

55 

56def _compute_size_by_dict(indices, idx_dict): 

57 """ 

58 Computes the product of the elements in indices based on the dictionary 

59 idx_dict. 

60 

61 Parameters 

62 ---------- 

63 indices : iterable 

64 Indices to base the product on. 

65 idx_dict : dictionary 

66 Dictionary of index sizes 

67 

68 Returns 

69 ------- 

70 ret : int 

71 The resulting product. 

72 

73 Examples 

74 -------- 

75 >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) 

76 90 

77 

78 """ 

79 ret = 1 

80 for i in indices: 

81 ret *= idx_dict[i] 

82 return ret 

83 

84 

85def _find_contraction(positions, input_sets, output_set): 

86 """ 

87 Finds the contraction for a given set of input and output sets. 

88 

89 Parameters 

90 ---------- 

91 positions : iterable 

92 Integer positions of terms used in the contraction. 

93 input_sets : list 

94 List of sets that represent the lhs side of the einsum subscript 

95 output_set : set 

96 Set that represents the rhs side of the overall einsum subscript 

97 

98 Returns 

99 ------- 

100 new_result : set 

101 The indices of the resulting contraction 

102 remaining : list 

103 List of sets that have not been contracted, the new set is appended to 

104 the end of this list 

105 idx_removed : set 

106 Indices removed from the entire contraction 

107 idx_contraction : set 

108 The indices used in the current contraction 

109 

110 Examples 

111 -------- 

112 

113 # A simple dot product test case 

114 >>> pos = (0, 1) 

115 >>> isets = [set('ab'), set('bc')] 

116 >>> oset = set('ac') 

117 >>> _find_contraction(pos, isets, oset) 

118 ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'}) 

119 

120 # A more complex case with additional terms in the contraction 

121 >>> pos = (0, 2) 

122 >>> isets = [set('abd'), set('ac'), set('bdc')] 

123 >>> oset = set('ac') 

124 >>> _find_contraction(pos, isets, oset) 

125 ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) 

126 """ 

127 

128 idx_contract = set() 

129 idx_remain = output_set.copy() 

130 remaining = [] 

131 for ind, value in enumerate(input_sets): 

132 if ind in positions: 

133 idx_contract |= value 

134 else: 

135 remaining.append(value) 

136 idx_remain |= value 

137 

138 new_result = idx_remain & idx_contract 

139 idx_removed = (idx_contract - new_result) 

140 remaining.append(new_result) 

141 

142 return (new_result, remaining, idx_removed, idx_contract) 

143 

144 

def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    n_terms = len(input_sets)

    # Each candidate is (accumulated_cost, contraction_positions, remaining_sets).
    candidates = [(0, [], input_sets)]

    for step in range(n_terms - 1):
        next_candidates = []

        # Expand every surviving candidate by every possible pair contraction.
        for prev_cost, prev_positions, remaining in candidates:
            for pair in itertools.combinations(range(n_terms - step), 2):

                result_inds, new_sets, removed, contracted = _find_contraction(
                    pair, remaining, output_set)

                # Discard intermediates that exceed the memory limit.
                if _compute_size_by_dict(result_inds, idx_dict) > memory_limit:
                    continue

                step_cost = _flop_count(contracted, removed, len(pair), idx_dict)
                next_candidates.append(
                    (prev_cost + step_cost, prev_positions + [pair], new_sets))

        if not next_candidates:
            # Nothing fit under the memory limit: take the cheapest partial
            # path found so far and contract everything left in one call.
            best_positions = min(candidates, key=lambda c: c[0])[1]
            return best_positions + [tuple(range(n_terms - step))]

        candidates = next_candidates

    # Defensive fallback: with no candidates, do a single einsum contraction.
    if not candidates:
        return [tuple(range(n_terms))]

    return min(candidates, key=lambda c: c[0])[1]

214 

def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """Compute the cost (removed size + flops) and resultant indices for
    performing the contraction specified by ``positions``.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    cost : (int, int)
        A tuple containing the size of any indices removed, and the flop cost.
    positions : tuple of int
        The locations of the proposed tensors to contract.
    new_input_sets : list of sets
        The resulting new list of indices if this proposed contraction is performed.

    """
    idx_result, new_input_sets, idx_removed, idx_contract = _find_contraction(
        positions, input_sets, output_set)

    # Reject intermediates that would not fit in memory.
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # Reject anything already more expensive than the unoptimized expression.
    cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
    if path_cost + cost > naive_cost:
        return None

    # Primary sort key: total operand size eliminated (larger reductions sort
    # first, hence the negation); secondary key: flop cost.
    # NB: removed_size used to be just the size of any removed indices i.e.:
    # helpers.compute_size_by_dict(idx_removed, idx_dict)
    removed_size = sum(_compute_size_by_dict(input_sets[p], idx_dict)
                       for p in positions) - new_size

    return [(-removed_size, cost), positions, new_input_sets]

271 

272 

273def _update_other_results(results, best): 

274 """Update the positions and provisional input_sets of ``results`` based on 

275 performing the contraction result ``best``. Remove any involving the tensors 

276 contracted. 

277 

278 Parameters 

279 ---------- 

280 results : list 

281 List of contraction results produced by ``_parse_possible_contraction``. 

282 best : list 

283 The best contraction of ``results`` i.e. the one that will be performed. 

284 

285 Returns 

286 ------- 

287 mod_results : list 

288 The list of modified results, updated with outcome of ``best`` contraction. 

289 """ 

290 

291 best_con = best[1] 

292 bx, by = best_con 

293 mod_results = [] 

294 

295 for cost, (x, y), con_sets in results: 

296 

297 # Ignore results involving tensors just contracted 

298 if x in best_con or y in best_con: 

299 continue 

300 

301 # Update the input_sets 

302 del con_sets[by - int(by > x) - int(by > y)] 

303 del con_sets[bx - int(bx > x) - int(bx > y)] 

304 con_sets.insert(-1, best[2][-1]) 

305 

306 # Update the position indices 

307 mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by) 

308 mod_results.append((cost, mod_con, con_sets)) 

309 

310 return mod_results 

311 

def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost (contracting everything at once); used below as
    # an upper bound when sieving candidate pairwise contractions.
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on first step, only previously found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
                                                     path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index (the (-removed_size, cost) sort tuple)
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path

411 

412 

413def _can_dot(inputs, result, idx_removed): 

414 """ 

415 Checks if we can use BLAS (np.tensordot) call and its beneficial to do so. 

416 

417 Parameters 

418 ---------- 

419 inputs : list of str 

420 Specifies the subscripts for summation. 

421 result : str 

422 Resulting summation. 

423 idx_removed : set 

424 Indices that are removed in the summation 

425 

426 

427 Returns 

428 ------- 

429 type : bool 

430 Returns true if BLAS should and can be used, else False 

431 

432 Notes 

433 ----- 

434 If the operations is BLAS level 1 or 2 and is not already aligned 

435 we default back to einsum as the memory movement to copy is more 

436 costly than the operation itself. 

437 

438 

439 Examples 

440 -------- 

441 

442 # Standard GEMM operation 

443 >>> _can_dot(['ij', 'jk'], 'ik', set('j')) 

444 True 

445 

446 # Can use the standard BLAS, but requires odd data movement 

447 >>> _can_dot(['ijj', 'jk'], 'ik', set('j')) 

448 False 

449 

450 # DDOT where the memory is not aligned 

451 >>> _can_dot(['ijk', 'ikj'], '', set('ijk')) 

452 False 

453 

454 """ 

455 

456 # All `dot` calls remove indices 

457 if len(idx_removed) == 0: 

458 return False 

459 

460 # BLAS can only handle two operands 

461 if len(inputs) != 2: 

462 return False 

463 

464 input_left, input_right = inputs 

465 

466 for c in set(input_left + input_right): 

467 # can't deal with repeated indices on same input or more than 2 total 

468 nl, nr = input_left.count(c), input_right.count(c) 

469 if (nl > 1) or (nr > 1) or (nl + nr > 2): 

470 return False 

471 

472 # can't do implicit summation or dimension collapse e.g. 

473 # "ab,bc->c" (implicitly sum over 'a') 

474 # "ab,ca->ca" (take diagonal of 'a') 

475 if nl + nr - 1 == int(c in result): 

476 return False 

477 

478 # Build a few temporaries 

479 set_left = set(input_left) 

480 set_right = set(input_right) 

481 keep_left = set_left - idx_removed 

482 keep_right = set_right - idx_removed 

483 rs = len(idx_removed) 

484 

485 # At this point we are a DOT, GEMV, or GEMM operation 

486 

487 # Handle inner products 

488 

489 # DDOT with aligned data 

490 if input_left == input_right: 

491 return True 

492 

493 # DDOT without aligned data (better to use einsum) 

494 if set_left == set_right: 

495 return False 

496 

497 # Handle the 4 possible (aligned) GEMV or GEMM cases 

498 

499 # GEMM or GEMV no transpose 

500 if input_left[-rs:] == input_right[:rs]: 

501 return True 

502 

503 # GEMM or GEMV transpose both 

504 if input_left[:rs] == input_right[-rs:]: 

505 return True 

506 

507 # GEMM or GEMV transpose right 

508 if input_left[-rs:] == input_right[-rs:]: 

509 return True 

510 

511 # GEMM or GEMV transpose left 

512 if input_left[:rs] == input_right[:rs]: 

513 return True 

514 

515 # Einsum is faster than GEMV if we have to copy data 

516 if not keep_left or not keep_right: 

517 return False 

518 

519 # We are a matrix-matrix product, but we need to copy data 

520 return True 

521 

522 

def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Parameters
    ----------
    operands : tuple
        Either ``(subscripts_str, *arrays)`` or the interleaved form
        ``(array0, sublist0, array1, sublist1, ..., [output_sublist])``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        # String form: first element is the subscripts, the rest are arrays.
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: (array, sublist, array, sublist, ..., [output]).
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # An odd-length tuple means a trailing output sublist was given.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]

        # Translate each integer sublist into the letter-based subscripts.
        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                else:
                    try:
                        s = operator.index(s)
                    except TypeError as e:
                        raise TypeError("For this input type lists must contain "
                                        "either int or Ellipsis") from e
                    subscripts += einsum_symbols[s]

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each "..." with concrete unused symbols drawn
    # from the end of ``ellipse_inds`` so all terms broadcast consistently.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: indices appearing exactly
            # once survive to the output, ellipse dims come first.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: implicit output keeps every index that
        # appears exactly once, in sorted order.
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)

694 

695 

def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
    """Dispatcher for `einsum_path`: yield the relevant arguments for
    ``__array_function__`` overrides."""
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands

704 

705 

706@array_function_dispatch(_einsum_path_dispatcher, module='numpy') 

707def einsum_path(*operands, optimize='greedy', einsum_call=False): 

708 """ 

709 einsum_path(subscripts, *operands, optimize='greedy') 

710 

711 Evaluates the lowest cost contraction order for an einsum expression by 

712 considering the creation of intermediate arrays. 

713 

714 Parameters 

715 ---------- 

716 subscripts : str 

717 Specifies the subscripts for summation. 

718 *operands : list of array_like 

719 These are the arrays for the operation. 

720 optimize : {bool, list, tuple, 'greedy', 'optimal'} 

721 Choose the type of path. If a tuple is provided, the second argument is 

722 assumed to be the maximum intermediate size created. If only a single 

723 argument is provided the largest input or output array size is used 

724 as a maximum intermediate size. 

725 

726 * if a list is given that starts with ``einsum_path``, uses this as the 

727 contraction path 

728 * if False no optimization is taken 

729 * if True defaults to the 'greedy' algorithm 

730 * 'optimal' An algorithm that combinatorially explores all possible 

731 ways of contracting the listed tensors and choosest the least costly 

732 path. Scales exponentially with the number of terms in the 

733 contraction. 

734 * 'greedy' An algorithm that chooses the best pair contraction 

735 at each step. Effectively, this algorithm searches the largest inner, 

736 Hadamard, and then outer products at each step. Scales cubically with 

737 the number of terms in the contraction. Equivalent to the 'optimal' 

738 path for most contractions. 

739 

740 Default is 'greedy'. 

741 

742 Returns 

743 ------- 

744 path : list of tuples 

745 A list representation of the einsum path. 

746 string_repr : str 

747 A printable representation of the einsum path. 

748 

749 Notes 

750 ----- 

751 The resulting path indicates which terms of the input contraction should be 

752 contracted first, the result of this contraction is then appended to the 

753 end of the contraction list. This list can then be iterated over until all 

754 intermediate contractions are complete. 

755 

756 See Also 

757 -------- 

758 einsum, linalg.multi_dot 

759 

760 Examples 

761 -------- 

762 

763 We can begin with a chain dot example. In this case, it is optimal to 

764 contract the ``b`` and ``c`` tensors first as represented by the first 

765 element of the path ``(1, 2)``. The resulting tensor is added to the end 

766 of the contraction and the remaining contraction ``(0, 1)`` is then 

767 completed. 

768 

769 >>> np.random.seed(123) 

770 >>> a = np.random.rand(2, 2) 

771 >>> b = np.random.rand(2, 5) 

772 >>> c = np.random.rand(5, 2) 

773 >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') 

774 >>> print(path_info[0]) 

775 ['einsum_path', (1, 2), (0, 1)] 

776 >>> print(path_info[1]) 

777 Complete contraction: ij,jk,kl->il # may vary 

778 Naive scaling: 4 

779 Optimized scaling: 3 

780 Naive FLOP count: 1.600e+02 

781 Optimized FLOP count: 5.600e+01 

782 Theoretical speedup: 2.857 

783 Largest intermediate: 4.000e+00 elements 

784 ------------------------------------------------------------------------- 

785 scaling current remaining 

786 ------------------------------------------------------------------------- 

787 3 kl,jk->jl ij,jl->il 

788 3 jl,ij->il il->il 

789 

790 

791 A more complex index transformation example. 

792 

793 >>> I = np.random.rand(10, 10, 10, 10) 

794 >>> C = np.random.rand(10, 10) 

795 >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, 

796 ... optimize='greedy') 

797 

798 >>> print(path_info[0]) 

799 ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)] 

800 >>> print(path_info[1])  

801 Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary 

802 Naive scaling: 8 

803 Optimized scaling: 5 

804 Naive FLOP count: 8.000e+08 

805 Optimized FLOP count: 8.000e+05 

806 Theoretical speedup: 1000.000 

807 Largest intermediate: 1.000e+04 elements 

808 -------------------------------------------------------------------------- 

809 scaling current remaining 

810 -------------------------------------------------------------------------- 

811 5 abcd,ea->bcde fb,gc,hd,bcde->efgh 

812 5 bcde,fb->cdef gc,hd,cdef->efgh 

813 5 cdef,gc->defg hd,defg->efgh 

814 5 defg,hd->efgh efgh->efgh 

815 """ 

816 

817 # Figure out what the path really is 

818 path_type = optimize 

819 if path_type is True: 

820 path_type = 'greedy' 

821 if path_type is None: 

822 path_type = False 

823 

824 explicit_einsum_path = False 

825 memory_limit = None 

826 

827 # No optimization or a named path algorithm 

828 if (path_type is False) or isinstance(path_type, str): 

829 pass 

830 

831 # Given an explicit path 

832 elif len(path_type) and (path_type[0] == 'einsum_path'): 

833 explicit_einsum_path = True 

834 

835 # Path tuple with memory limit 

836 elif ((len(path_type) == 2) and isinstance(path_type[0], str) and 

837 isinstance(path_type[1], (int, float))): 

838 memory_limit = int(path_type[1]) 

839 path_type = path_type[0] 

840 

841 else: 

842 raise TypeError("Did not understand the path: %s" % str(path_type)) 

843 

844 # Hidden option, only einsum should call this 

845 einsum_call_arg = einsum_call 

846 

847 # Python side parsing 

848 input_subscripts, output_subscript, operands = _parse_einsum_input(operands) 

849 

850 # Build a few useful list and sets 

851 input_list = input_subscripts.split(',') 

852 input_sets = [set(x) for x in input_list] 

853 output_set = set(output_subscript) 

854 indices = set(input_subscripts.replace(',', '')) 

855 

856 # Get length of each unique dimension and ensure all dimensions are correct 

857 dimension_dict = {} 

858 broadcast_indices = [[] for x in range(len(input_list))] 

859 for tnum, term in enumerate(input_list): 

860 sh = operands[tnum].shape 

861 if len(sh) != len(term): 

862 raise ValueError("Einstein sum subscript %s does not contain the " 

863 "correct number of indices for operand %d." 

864 % (input_subscripts[tnum], tnum)) 

865 for cnum, char in enumerate(term): 

866 dim = sh[cnum] 

867 

868 # Build out broadcast indices 

869 if dim == 1: 

870 broadcast_indices[tnum].append(char) 

871 

872 if char in dimension_dict.keys(): 

873 # For broadcasting cases we always want the largest dim size 

874 if dimension_dict[char] == 1: 

875 dimension_dict[char] = dim 

876 elif dim not in (1, dimension_dict[char]): 

877 raise ValueError("Size of label '%s' for operand %d (%d) " 

878 "does not match previous terms (%d)." 

879 % (char, tnum, dimension_dict[char], dim)) 

880 else: 

881 dimension_dict[char] = dim 

882 

883 # Convert broadcast inds to sets 

884 broadcast_indices = [set(x) for x in broadcast_indices] 

885 

886 # Compute size of each input array plus the output array 

887 size_list = [_compute_size_by_dict(term, dimension_dict) 

888 for term in input_list + [output_subscript]] 

889 max_size = max(size_list) 

890 

891 if memory_limit is None: 

892 memory_arg = max_size 

893 else: 

894 memory_arg = memory_limit 

895 

896 # Compute naive cost 

897 # This isn't quite right, need to look into exactly how einsum does this 

898 inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0 

899 naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict) 

900 

901 # Compute the path 

902 if explicit_einsum_path: 

903 path = path_type[1:] 

904 elif ( 

905 (path_type is False) 

906 or (len(input_list) in [1, 2]) 

907 or (indices == output_set) 

908 ): 

909 # Nothing to be optimized, leave it to einsum 

910 path = [tuple(range(len(input_list)))] 

911 elif path_type == "greedy": 

912 path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg) 

913 elif path_type == "optimal": 

914 path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg) 

915 else: 

916 raise KeyError("Path name %s not found", path_type) 

917 

918 cost_list, scale_list, size_list, contraction_list = [], [], [], [] 

919 

920 # Build contraction tuple (positions, gemm, einsum_str, remaining) 

921 for cnum, contract_inds in enumerate(path): 

922 # Make sure we remove inds from right to left 

923 contract_inds = tuple(sorted(list(contract_inds), reverse=True)) 

924 

925 contract = _find_contraction(contract_inds, input_sets, output_set) 

926 out_inds, input_sets, idx_removed, idx_contract = contract 

927 

928 cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict) 

929 cost_list.append(cost) 

930 scale_list.append(len(idx_contract)) 

931 size_list.append(_compute_size_by_dict(out_inds, dimension_dict)) 

932 

933 bcast = set() 

934 tmp_inputs = [] 

935 for x in contract_inds: 

936 tmp_inputs.append(input_list.pop(x)) 

937 bcast |= broadcast_indices.pop(x) 

938 

939 new_bcast_inds = bcast - idx_removed 

940 

941 # If we're broadcasting, nix blas 

942 if not len(idx_removed & bcast): 

943 do_blas = _can_dot(tmp_inputs, out_inds, idx_removed) 

944 else: 

945 do_blas = False 

946 

947 # Last contraction 

948 if (cnum - len(path)) == -1: 

949 idx_result = output_subscript 

950 else: 

951 sort_result = [(dimension_dict[ind], ind) for ind in out_inds] 

952 idx_result = "".join([x[1] for x in sorted(sort_result)]) 

953 

954 input_list.append(idx_result) 

955 broadcast_indices.append(new_bcast_inds) 

956 einsum_str = ",".join(tmp_inputs) + "->" + idx_result 

957 

958 contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas) 

959 contraction_list.append(contraction) 

960 

961 opt_cost = sum(cost_list) + 1 

962 

963 if len(input_list) != 1: 

964 # Explicit "einsum_path" is usually trusted, but we detect this kind of 

965 # mistake in order to prevent from returning an intermediate value. 

966 raise RuntimeError( 

967 "Invalid einsum_path is specified: {} more operands has to be " 

968 "contracted.".format(len(input_list) - 1)) 

969 

970 if einsum_call_arg: 

971 return (operands, contraction_list) 

972 

973 # Return the path along with a nice string representation 

974 overall_contraction = input_subscripts + "->" + output_subscript 

975 header = ("scaling", "current", "remaining") 

976 

977 speedup = naive_cost / opt_cost 

978 max_i = max(size_list) 

979 

980 path_print = " Complete contraction: %s\n" % overall_contraction 

981 path_print += " Naive scaling: %d\n" % len(indices) 

982 path_print += " Optimized scaling: %d\n" % max(scale_list) 

983 path_print += " Naive FLOP count: %.3e\n" % naive_cost 

984 path_print += " Optimized FLOP count: %.3e\n" % opt_cost 

985 path_print += " Theoretical speedup: %3.3f\n" % speedup 

986 path_print += " Largest intermediate: %.3e elements\n" % max_i 

987 path_print += "-" * 74 + "\n" 

988 path_print += "%6s %24s %40s\n" % header 

989 path_print += "-" * 74 

990 

991 for n, contraction in enumerate(contraction_list): 

992 inds, idx_rm, einsum_str, remaining, blas = contraction 

993 remaining_str = ",".join(remaining) + "->" + output_subscript 

994 path_run = (scale_list[n], einsum_str, remaining_str) 

995 path_print += "\n%4d %24s %40s" % path_run 

996 

997 path = ['einsum_path'] + path 

998 return (path, path_print) 

999 

1000 

1001def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs): 

1002 # Arguably we dispatch on more arguments than we really should; see note in 

1003 # _einsum_path_dispatcher for why. 

1004 yield from operands 

1005 yield out 

1006 

1007 

# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations (trace, diagonal, axis sums,
    transposes, inner/outer products, matrix multiplication, tensor
    contractions, broadcasting products) can be represented in a simple
    fashion.  In *implicit* mode the output subscripts are inferred and
    ordered alphabetically; in *explicit* mode the identifier ``'->'``
    followed by output labels controls the result directly, allowing
    summation to be forced or disabled.

    Parameters
    ----------
    subscripts : str
        Comma separated list of subscript labels, one term per operand.
        An implicit (classical Einstein summation) calculation is
        performed unless the explicit indicator '->' is included.
        An ellipsis ``'...'`` enables NumPy-style broadcasting.
    operands : list of array_like
        The arrays for the operation.  Alternatively the subscripts may
        be supplied interleaved as
        ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        Forces the calculation to use the given data type (may require a
        more liberal ``casting``).  Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Memory layout of the output.  Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls intermediate optimization.  ``False`` disables it,
        ``True`` defaults to the 'greedy' algorithm, and an explicit
        contraction list from ``np.einsum_path`` is also accepted.
        Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    Notes
    -----
    Repeated labels across operands are summed (``'i,i'`` is an inner
    product); repeated labels within one operand take the diagonal
    (``'ii'`` is the trace).  When there is only one operand, no axes are
    summed and no output is provided, a (writeable, since 1.10.0) view of
    the operand is returned.  With ``optimize`` enabled, contractions of
    three or more operands may be reordered and routed through
    ``tensordot`` for efficiency at the cost of extra memory; see
    :py:func:`numpy.einsum_path` for details and for reusing a
    precomputed path.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)

    >>> np.einsum('ii', a)          # trace
    60
    >>> np.einsum('ii->i', a)       # diagonal
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum('ij,j', a, b)     # matrix-vector product
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('ij->ji', a).shape  # transpose
    (5, 5)

    """
    # Remember whether the caller asked for output placement.
    have_out = out is not None

    # No optimization requested: hand everything straight to the C core.
    if optimize is False:
        if have_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Validate keyword arguments up front so the user gets a clear error
    # instead of a cryptic failure deep inside the contraction loop.
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [name for name in kwargs
                      if name not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and the parsed operand arrays.
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Resolve the order kwarg for the final output array; c_einsum itself
    # accepts mixed case, and 'A' depends on the inputs' layout.
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        all_fortran = all(arr.flags.f_contiguous for arr in operands)
        output_order = 'F' if all_fortran else 'C'

    n_steps = len(contraction_list)

    # Perform each pairwise (or grouped) contraction in turn.
    for step, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        tmp_operands = [operands.pop(x) for x in inds]

        # Only the very last contraction writes into the user's `out`.
        handle_out = have_out and (step + 1 == n_steps)

        if blas:
            # BLAS-eligible step: validity was established when the
            # contraction list was built, so go straight to tensordot.
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Output labels as tensordot will produce them: the surviving
            # left labels followed by the surviving right labels.
            tensor_result = input_left + input_right
            for label in idx_rm:
                tensor_result = tensor_result.replace(label, "")

            # Axis positions to contract over, in sorted label order.
            contracted = sorted(idx_rm)
            left_pos = tuple(input_left.find(label) for label in contracted)
            right_pos = tuple(input_right.find(label) for label in contracted)

            # Contract!
            new_view = tensordot(*tmp_operands, axes=(left_pos, right_pos))

            # Reorder axes (and/or place into `out`) via einsum if the
            # tensordot layout differs from the requested one.
            if tensor_result != results_index or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(tensor_result + '->' + results_index,
                                    new_view, **kwargs)
        else:
            # Generic step: let the C einsum core do the contraction.
            if handle_out:
                kwargs["out"] = out
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Push the intermediate back and drop references promptly so
        # large temporaries can be reclaimed between steps.
        operands.append(new_view)
        del tmp_operands, new_view

    if have_out:
        return out
    return asanyarray(operands[0], order=output_order)