Flatten a List of Lists

Difficulty: 2

Duration: 20 min

How to Flatten a List of Lists in Python – Real Python

Input:

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

Output:

[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

General Solution with a for Loop

  1. Create a new empty list to store the flattened data.

  2. Iterate over each nested list or sublist in the original list.

  3. Add every item from the current sublist to the list of flattened data.

  4. Return the resulting list with the flattened data.
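The four steps map directly onto a nested loop that appends one item at a time. A minimal sketch; the name flatten_append is a hypothetical label for illustration, not from the source:

```python
def flatten_append(matrix):
    # Step 1: create an empty list for the flattened data
    flat_list = []
    # Step 2: iterate over each sublist (row) of the original list
    for row in matrix:
        # Step 3: add every item of the current row, one at a time
        for item in row:
            flat_list.append(item)
    # Step 4: return the flattened result
    return flat_list


matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]
print(flatten_append(matrix))
# [9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]
```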

Because .extend(row) adds all the items of a row in one call, a single for loop is enough:

>>> def flatten_extend(matrix):
...     flat_list = []
...     for row in matrix:
...         flat_list.extend(row)
...     return flat_list
...

>>> flatten_extend(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

The loop body can also use the augmented concatenation operator (+=):

>>> def flatten_concatenation(matrix):
...     flat_list = []
...     for row in matrix:
...         flat_list += row
...     return flat_list
...
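For lists, flat_list += row behaves like flat_list.extend(row): it grows the existing list object in place, whereas flat_list = flat_list + row builds a new list and rebinds the name. A small identity check (alias is just an illustrative second name) makes the difference visible:

```python
flat_list = [1, 2]
alias = flat_list  # a second name bound to the same list object

# += calls list.__iadd__, which extends the existing object in place
flat_list += [3, 4]
print(flat_list is alias)  # True
print(alias)               # [1, 2, 3, 4]

# flat_list + [5, 6] builds a new list, then the name is rebound to it
flat_list = flat_list + [5, 6]
print(flat_list is alias)  # False
print(alias)               # [1, 2, 3, 4] (the original object is unchanged)
```

This in-place behavior is why flatten_concatenation() keeps pace with flatten_extend() in the performance comparison.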

Comprehension to Flatten a List of Lists

A nested (two-level) comprehension does the same in a single expression:

>>> def flatten_comprehension(matrix):
...     return [item for row in matrix for item in row]
...
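The for clauses in the comprehension appear in the same order as the equivalent nested loops: the outer loop over rows first, the inner loop over items second. This sketch checks that both forms give the same list:

```python
matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

# Comprehension: outer clause first, inner clause second
from_comprehension = [item for row in matrix for item in row]

# Equivalent explicit nested loops, written in the same order
from_loops = []
for row in matrix:
    for item in row:
        from_loops.append(item)

print(from_comprehension == from_loops)  # True
```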

You can also combine a comprehension with a regular for-in loop; see the List Guide.

Flattening a List Using Standard-Library and Built-in Tools

Chaining Iterables With itertools.chain()

>>> from itertools import chain

>>> def flatten_chain(matrix):
...     return list(chain.from_iterable(matrix))
...

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_chain(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]
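chain.from_iterable() takes the outer list as one iterable and lazily yields the items of each row in turn; chain(*matrix) produces the same items by unpacking the rows as separate arguments. A short comparison:

```python
from itertools import chain

matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

# from_iterable() receives the outer list as a single argument
lazy = chain.from_iterable(matrix)
print(next(lazy), next(lazy))  # 9 3 (items are produced on demand)

# chain(*matrix) unpacks every row as a separate argument
print(list(chain(*matrix)) == list(chain.from_iterable(matrix)))  # True
```

from_iterable() is usually the better fit when the rows come from a generator, since chain(*rows) has to consume the whole generator up front to build its argument list.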

Concatenating Lists With functools.reduce()

>>> from functools import reduce

>>> def flatten_reduce_lambda(matrix):
...     return list(reduce(lambda x, y: x + y, matrix, []))
...

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_reduce_lambda(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]
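reduce() folds the rows into an accumulator from left to right, starting from the initial value []. Spelled out as a loop, the lambda version does roughly the following (flatten_reduce_explicit is a hypothetical name for illustration):

```python
from functools import reduce


def flatten_reduce_explicit(matrix):
    # reduce() starts from the initial value given as its third argument
    accumulator = []
    for row in matrix:
        # lambda x, y: x + y returns a NEW list on every step, copying
        # everything accumulated so far plus the current row
        accumulator = accumulator + row
    return accumulator


matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
    [6, 4, 3, 1],
    [1, 0, 4, 5],
]

print(flatten_reduce_explicit(matrix) == reduce(lambda x, y: x + y, matrix, []))  # True
```

Since reduce() already returns a list here, the list() wrapper in flatten_reduce_lambda() is not strictly needed; it only makes one extra copy.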

Using sum() to Concatenate Lists

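This style is not recommended, because each + inside sum() builds a brand-new list, but you should still be able to read it when you see it in other people's code: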

>>> def flatten_sum(matrix):
...     return sum(matrix, [])
...

>>> matrix = [
...     [9, 3, 8, 3],
...     [4, 5, 2, 8],
...     [6, 4, 3, 1],
...     [1, 0, 4, 5],
... ]

>>> flatten_sum(matrix)
[9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]
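sum(matrix, []) only works because of the start value: with [] as the start, sum() adds the rows with +, that is [] + row0 + row1 + and so on. Without it, sum() starts from 0 and fails on the first row. A short sketch on a two-row matrix:

```python
matrix = [
    [9, 3, 8, 3],
    [4, 5, 2, 8],
]

# With start=[] the rows are concatenated: [] + [9, 3, 8, 3] + [4, 5, 2, 8]
print(sum(matrix, []))  # [9, 3, 8, 3, 4, 5, 2, 8]

# Without it, sum() starts at 0 and tries 0 + [9, 3, 8, 3]
try:
    sum(matrix)
    failed = False
except TypeError as error:
    failed = True
    print("TypeError:", error)
```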

Performance While Flattening Your Lists

# flatten.py

from functools import reduce
from itertools import chain
from operator import add, concat, iconcat

def flatten_extend(matrix):
    flat_list = []
    for row in matrix:
        flat_list.extend(row)
    return flat_list

def flatten_concatenation(matrix):
    flat_list = []
    for row in matrix:
        flat_list += row
    return flat_list

def flatten_comprehension(matrix):
    return [item for row in matrix for item in row]

def flatten_chain(matrix):
    return list(chain.from_iterable(matrix))

def flatten_reduce_lambda(matrix):
    return list(reduce(lambda x, y: x + y, matrix, []))

def flatten_reduce_add(matrix):
    return reduce(add, matrix, [])

def flatten_reduce_concat(matrix):
    return reduce(concat, matrix, [])

def flatten_reduce_iconcat(matrix):
    return reduce(iconcat, matrix, [])

def flatten_sum(matrix):
    return sum(matrix, [])

# performance.py

from timeit import timeit

import flatten

SIZE = 1000
TO_MS = 1000
NUM = 10
FUNCTIONS = [
    "flatten_extend",
    "flatten_concatenation",
    "flatten_comprehension",
    "flatten_chain",
    "flatten_reduce_lambda",
    "flatten_reduce_add",
    "flatten_reduce_concat",
    "flatten_reduce_iconcat",
    "flatten_sum",
]

# SIZE references to one shared row; safe here, since no function mutates its input rows
matrix = [list(range(SIZE))] * SIZE

results = {
    func: timeit(f"flatten.{func}(matrix)", globals=globals(), number=NUM)
    for func in FUNCTIONS
}

print(f"Time to flatten a {SIZE}x{SIZE} matrix (in milliseconds):\n")

for func, time in sorted(results.items(), key=lambda result: result[1]):
    print(f"{func + '()':.<30}{time * TO_MS / NUM:.>7.2f} ms")

$ python performance.py
Time to flatten a 1000x1000 matrix (in milliseconds):

flatten_concatenation()..........1.95 ms
flatten_extend().................2.03 ms
flatten_reduce_iconcat().........2.68 ms
flatten_chain()..................4.60 ms
flatten_comprehension()..........7.79 ms
flatten_sum().................1113.22 ms
flatten_reduce_concat().......1117.15 ms
flatten_reduce_lambda().......1117.52 ms
flatten_reduce_add()..........1118.80 ms
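The gap between the two groups can be estimated by counting element copies. A rough sketch, assuming each + copies both operands into a fresh list while each in-place step (.extend(), +=, iconcat) copies only the new row; the helper names are hypothetical and the model ignores the amortized cost of list resizing:

```python
def copies_with_concatenation(n_rows, n_cols):
    # Step k builds a new list of k * n_cols items:
    # (k - 1) * n_cols already accumulated plus n_cols from the new row
    return sum(k * n_cols for k in range(1, n_rows + 1))


def copies_with_extend(n_rows, n_cols):
    # Each item is copied into the result exactly once
    return n_rows * n_cols


concat_copies = copies_with_concatenation(1000, 1000)
extend_copies = copies_with_extend(1000, 1000)
print(concat_copies, extend_copies, concat_copies // extend_copies)
# 500500000 1000000 500
```

About 500 times more copying for the 1000x1000 case, which is the same order of magnitude as the measured gap between about 2 ms and about 1100 ms.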

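Conclusion:

The in-place approaches are the fastest: += (flatten_concatenation), .extend(), and reduce() with operator.iconcat all take roughly 2 to 3 ms. chain.from_iterable() and the comprehension are a few times slower (about 4.6 and 7.8 ms) but still in the same range. Every approach built on plain +, namely sum() and reduce() with the lambda, operator.add, or operator.concat, takes over a second, because each step creates a new list and copies everything accumulated so far, so the cost grows quadratically with the number of rows. Avoid those for anything but small inputs.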

Flattening Python Lists for Data Science With NumPy

>>> import numpy as np

>>> matrix = np.array(
...     [
...         [9, 3, 8, 3],
...         [4, 5, 2, 8],
...         [6, 4, 3, 1],
...         [1, 0, 4, 5],
...     ]
... )

>>> matrix
array([[9, 3, 8, 3],
       [4, 5, 2, 8],
       [6, 4, 3, 1],
       [1, 0, 4, 5]])

>>> matrix.flatten()
array([9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5])
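.flatten() always returns a new, independent array. NumPy also provides .ravel(), which returns a view of the original data when it can, and .tolist() turns a flattened array back into a plain Python list. A short sketch of both (flat_copy and flat_view are illustrative names):

```python
import numpy as np

matrix = np.array(
    [
        [9, 3, 8, 3],
        [4, 5, 2, 8],
        [6, 4, 3, 1],
        [1, 0, 4, 5],
    ]
)

# .tolist() converts the flattened array into a regular Python list
flat_list = matrix.flatten().tolist()
print(flat_list)  # [9, 3, 8, 3, 4, 5, 2, 8, 6, 4, 3, 1, 1, 0, 4, 5]

flat_copy = matrix.flatten()  # always a new, independent array
flat_view = matrix.ravel()    # a view here, since matrix is C-contiguous

flat_copy[0] = 100
print(matrix[0, 0])  # 9 (the copy does not share memory with matrix)

flat_view[0] = 100
print(matrix[0, 0])  # 100 (the view writes through to matrix)
```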