Coverage for /var/srv/projects/api.amasfac.comuna18.com/tmp/venv/lib/python3.9/site-packages/pandas/io/stata.py: 10%
1557 statements
« prev ^ index » next coverage.py v6.4.4, created at 2023-07-17 14:22 -0600
1"""
2Module contains tools for processing Stata files into DataFrames
4The StataReader below was originally written by Joe Presbrey as part of PyDTA.
5It has been extended and improved by Skipper Seabold from the Statsmodels
6project who also developed the StataWriter and was finally added to pandas in
7a once again improved version.
9You can find more information on http://presbrey.mit.edu/PyDTA and
10https://www.statsmodels.org/devel/
11"""
12from __future__ import annotations
14from collections import abc
15import datetime
16from io import BytesIO
17import os
18import struct
19import sys
20from typing import (
21 IO,
22 TYPE_CHECKING,
23 Any,
24 AnyStr,
25 Final,
26 Hashable,
27 Sequence,
28 cast,
29)
30import warnings
32from dateutil.relativedelta import relativedelta
33import numpy as np
35from pandas._libs.lib import infer_dtype
36from pandas._libs.writers import max_len_string_array
37from pandas._typing import (
38 CompressionOptions,
39 FilePath,
40 ReadBuffer,
41 StorageOptions,
42 WriteBuffer,
43)
44from pandas.errors import (
45 CategoricalConversionWarning,
46 InvalidColumnName,
47 PossiblePrecisionLoss,
48 ValueLabelTypeMismatch,
49)
50from pandas.util._decorators import (
51 Appender,
52 deprecate_nonkeyword_arguments,
53 doc,
54)
55from pandas.util._exceptions import find_stack_level
57from pandas.core.dtypes.common import (
58 ensure_object,
59 is_categorical_dtype,
60 is_datetime64_dtype,
61 is_numeric_dtype,
62)
64from pandas import (
65 Categorical,
66 DatetimeIndex,
67 NaT,
68 Timestamp,
69 isna,
70 to_datetime,
71 to_timedelta,
72)
73from pandas.core.arrays.boolean import BooleanDtype
74from pandas.core.arrays.integer import IntegerDtype
75from pandas.core.frame import DataFrame
76from pandas.core.indexes.base import Index
77from pandas.core.series import Series
78from pandas.core.shared_docs import _shared_docs
80from pandas.io.common import get_handle
82if TYPE_CHECKING: 82 ↛ 83line 82 didn't jump to line 83, because the condition on line 82 was never true
83 from typing import Literal
# Error message raised when a .dta file declares an unsupported format version.
_version_error = (
    "Version of given Stata file is {version}. pandas supports importing "
    "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
    "114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), 118 (Stata 14/15/16),"
    "and 119 (Stata 15/16, over 32,767 variables)."
)

# Shared docstring fragments, interpolated into the f-string docs below so
# read_stata, StataReader and StataReader.read stay consistent.
_statafile_processing_params1 = """\
convert_dates : bool, default True
    Convert date variables to DataFrame time values.
convert_categoricals : bool, default True
    Read value labels and convert columns to Categorical/Factor variables."""

_statafile_processing_params2 = """\
index_col : str, optional
    Column to set as index.
convert_missing : bool, default False
    Flag indicating whether to convert missing values to their Stata
    representations. If False, missing values are replaced with nan.
    If True, columns containing missing values are returned with
    object data types and missing values are represented by
    StataMissingValue objects.
preserve_dtypes : bool, default True
    Preserve Stata datatypes. If False, numeric data are upcast to pandas
    default types for foreign data (float64 or int64).
columns : list or None
    Columns to retain. Columns will be returned in the given order. None
    returns all columns.
order_categoricals : bool, default True
    Flag indicating whether converted categorical data are ordered."""

_chunksize_params = """\
chunksize : int, default None
    Return StataReader object for iterations, returns chunks with
    given number of lines."""

_iterator_params = """\
iterator : bool, default False
    Return StataReader object."""

_reader_notes = """\
Notes
-----
Categorical variables read through an iterator may not have the same
categories and dtype. This occurs when a variable stored in a DTA
file is associated to an incomplete set of value labels that only
label a strict subset of the values."""

# Public docstring for read_stata.
_read_stata_doc = f"""
Read Stata file into DataFrame.

Parameters
----------
filepath_or_buffer : str, path object or file-like object
    Any valid string path is acceptable. The string could be a URL. Valid
    URL schemes include http, ftp, s3, and file. For file URLs, a host is
    expected. A local file could be: ``file://localhost/path/to/table.dta``.

    If you want to pass in a path object, pandas accepts any ``os.PathLike``.

    By file-like object, we refer to objects with a ``read()`` method,
    such as a file handle (e.g. via builtin ``open`` function)
    or ``StringIO``.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_iterator_params}
{_shared_docs["decompression_options"] % "filepath_or_buffer"}
{_shared_docs["storage_options"]}

Returns
-------
DataFrame or StataReader

See Also
--------
io.stata.StataReader : Low-level reader for Stata data files.
DataFrame.to_stata: Export Stata data files.

{_reader_notes}

Examples
--------

Creating a dummy stata for this example
>>> df = pd.DataFrame({{'animal': ['falcon', 'parrot', 'falcon',
...                               'parrot'],
...                     'speed': [350, 18, 361, 15]}})  # doctest: +SKIP
>>> df.to_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file:

>>> df = pd.read_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file in 10,000 line chunks:
>>> values = np.random.randint(0, 10, size=(20_000, 1), dtype="uint8")  # doctest: +SKIP
>>> df = pd.DataFrame(values, columns=["i"])  # doctest: +SKIP
>>> df.to_stata('filename.dta')  # doctest: +SKIP

>>> itr = pd.read_stata('filename.dta', chunksize=10000)  # doctest: +SKIP
>>> for chunk in itr:
...     # Operate on a single chunk, e.g., chunk.mean()
...     pass  # doctest: +SKIP
"""

# Docstring for StataReader.read().
_read_method_doc = f"""\
Reads observations from Stata file, converting them into a dataframe

Parameters
----------
nrows : int
    Number of lines to read from data file, if None read whole file.
{_statafile_processing_params1}
{_statafile_processing_params2}

Returns
-------
DataFrame
"""

# Class docstring for StataReader.
_stata_reader_doc = f"""\
Class for reading Stata dta files.

Parameters
----------
path_or_buf : path (string), buffer or path object
    string, path object (pathlib.Path or py._path.local.LocalPath) or object
    implementing a binary read() functions.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_shared_docs["decompression_options"]}
{_shared_docs["storage_options"]}

{_reader_notes}
"""

# Stata date format codes handled by the SIF <-> datetime converters below.
_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]

# All Stata internal-format (SIF) dates count from 1 January 1960.
stata_epoch: Final = datetime.datetime(1960, 1, 1)
229# TODO: Add typing. As of January 2020 it is not possible to type this function since
230# mypy doesn't understand that a Series and an int can be combined using mathematical
231# operations. (+, -).
# TODO: Add typing. As of January 2020 it is not possible to type this function since
# mypy doesn't understand that a Series and an int can be combined using mathematical
# operations. (+, -).
def _stata_elapsed_date_to_datetime_vec(dates, fmt) -> Series:
    """
    Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        The Stata Internal Format date to convert to datetime according to fmt
    fmt : str
        The format to convert to. Can be, tc, td, tw, tm, tq, th, ty

    Returns
    -------
    converted : Series
        The converted dates

    Examples
    --------
    >>> dates = pd.Series([52])
    >>> _stata_elapsed_date_to_datetime_vec(dates , "%tw")
    0   1961-01-01
    dtype: datetime64[ns]

    Notes
    -----
    datetime/c - tc
        milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day
    datetime/C - tC - NOT IMPLEMENTED
        milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds
    date - td
        days since 01jan1960 (01jan1960 = 0)
    weekly date - tw
        weeks since 1960w1
        This assumes 52 weeks in a year, then adds 7 * remainder of the weeks.
        The datetime value is the start of the week in terms of days in the
        year, not ISO calendar weeks.
    monthly date - tm
        months since 1960m1
    quarterly date - tq
        quarters since 1960q1
    half-yearly date - th
        half-years since 1960h1 yearly
    date - ty
        years since 0000
    """
    # Bounds of the pandas datetime64[ns] range; the helpers below use them to
    # choose between fast vectorized conversion and a slow-but-robust
    # datetime.datetime fallback for out-of-range values.
    MIN_YEAR, MAX_YEAR = Timestamp.min.year, Timestamp.max.year
    MAX_DAY_DELTA = (Timestamp.max - datetime.datetime(1960, 1, 1)).days
    MIN_DAY_DELTA = (Timestamp.min - datetime.datetime(1960, 1, 1)).days
    MIN_MS_DELTA = MIN_DAY_DELTA * 24 * 3600 * 1000
    MAX_MS_DELTA = MAX_DAY_DELTA * 24 * 3600 * 1000

    def convert_year_month_safe(year, month) -> Series:
        """
        Convert year and month to datetimes, using pandas vectorized versions
        when the date range falls within the range supported by pandas.
        Otherwise it falls back to a slower but more robust method
        using datetime.
        """
        if year.max() < MAX_YEAR and year.min() > MIN_YEAR:
            return to_datetime(100 * year + month, format="%Y%m")
        else:
            index = getattr(year, "index", None)
            return Series(
                [datetime.datetime(y, m, 1) for y, m in zip(year, month)], index=index
            )

    def convert_year_days_safe(year, days) -> Series:
        """
        Converts year (e.g. 1999) and days since the start of the year to a
        datetime or datetime64 Series
        """
        if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR:
            return to_datetime(year, format="%Y") + to_timedelta(days, unit="d")
        else:
            index = getattr(year, "index", None)
            value = [
                datetime.datetime(y, 1, 1) + relativedelta(days=int(d))
                for y, d in zip(year, days)
            ]
            return Series(value, index=index)

    def convert_delta_safe(base, deltas, unit) -> Series:
        """
        Convert base dates and deltas to datetimes, using pandas vectorized
        versions if the deltas satisfy restrictions required to be expressed
        as dates in pandas.
        """
        index = getattr(deltas, "index", None)
        if unit == "d":
            if deltas.max() > MAX_DAY_DELTA or deltas.min() < MIN_DAY_DELTA:
                values = [base + relativedelta(days=int(d)) for d in deltas]
                return Series(values, index=index)
        elif unit == "ms":
            if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA:
                values = [
                    base + relativedelta(microseconds=(int(d) * 1000)) for d in deltas
                ]
                return Series(values, index=index)
        else:
            raise ValueError("format not understood")
        # In-range: let pandas do the vectorized arithmetic.
        base = to_datetime(base)
        deltas = to_timedelta(deltas, unit=unit)
        return base + deltas

    # TODO(non-nano): If/when pandas supports more than datetime64[ns], this
    # should be improved to use correct range, e.g. datetime[Y] for yearly
    bad_locs = np.isnan(dates)
    has_bad_values = False
    if bad_locs.any():
        has_bad_values = True
        # reset cache to avoid SettingWithCopy checks (we own the DataFrame and the
        # `dates` Series is used to overwrite itself in the DataFrame)
        dates._reset_cacher()
        # Temporary placeholder so the integer cast below succeeds; these
        # positions are restored to NaT at the end.
        dates[bad_locs] = 1.0  # Replace with NaT
    dates = dates.astype(np.int64)

    if fmt.startswith(("%tc", "tc")):  # Delta ms relative to base
        base = stata_epoch
        ms = dates
        conv_dates = convert_delta_safe(base, ms, "ms")
    elif fmt.startswith(("%tC", "tC")):

        warnings.warn(
            "Encountered %tC format. Leaving in Stata Internal Format.",
            stacklevel=find_stack_level(),
        )
        conv_dates = Series(dates, dtype=object)
        if has_bad_values:
            conv_dates[bad_locs] = NaT
        return conv_dates
    # Delta days relative to base
    elif fmt.startswith(("%td", "td", "%d", "d")):
        base = stata_epoch
        days = dates
        conv_dates = convert_delta_safe(base, days, "d")
    # does not count leap days - 7 days is a week.
    # 52nd week may have more than 7 days
    elif fmt.startswith(("%tw", "tw")):
        year = stata_epoch.year + dates // 52
        days = (dates % 52) * 7
        conv_dates = convert_year_days_safe(year, days)
    elif fmt.startswith(("%tm", "tm")):  # Delta months relative to base
        year = stata_epoch.year + dates // 12
        month = (dates % 12) + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%tq", "tq")):  # Delta quarters relative to base
        year = stata_epoch.year + dates // 4
        quarter_month = (dates % 4) * 3 + 1
        conv_dates = convert_year_month_safe(year, quarter_month)
    elif fmt.startswith(("%th", "th")):  # Delta half-years relative to base
        year = stata_epoch.year + dates // 2
        month = (dates % 2) * 6 + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%ty", "ty")):  # Years -- not delta
        year = dates
        first_month = np.ones_like(dates)
        conv_dates = convert_year_month_safe(year, first_month)
    else:
        raise ValueError(f"Date fmt {fmt} not understood")

    if has_bad_values:  # Restore NaT for bad values
        conv_dates[bad_locs] = NaT

    return conv_dates
def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series:
    """
    Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        Series or array containing datetime.datetime or datetime64[ns] to
        convert to the Stata Internal Format given by fmt
    fmt : str
        The format to convert to. Can be, tc, td, tw, tm, tq, th, ty

    Returns
    -------
    Series
        float64 Series of SIF values; missing entries are encoded with the
        Stata '.' double sentinel.
    """
    index = dates.index
    NS_PER_DAY = 24 * 3600 * 1000 * 1000 * 1000
    US_PER_DAY = NS_PER_DAY / 1000

    def parse_dates_safe(dates, delta=False, year=False, days=False):
        # Extract the pieces needed by the fmt-specific arithmetic below
        # (elapsed delta in microseconds, year/month, day-of-year), handling
        # both datetime64 and object (datetime.datetime) Series.
        d = {}
        if is_datetime64_dtype(dates.dtype):
            if delta:
                time_delta = dates - stata_epoch
                d["delta"] = time_delta._values.view(np.int64) // 1000  # microseconds
            if days or year:
                date_index = DatetimeIndex(dates)
                d["year"] = date_index._data.year
                d["month"] = date_index._data.month
            if days:
                # day-of-year = ns since epoch minus ns at start of the year
                days_in_ns = dates.view(np.int64) - to_datetime(
                    d["year"], format="%Y"
                ).view(np.int64)
                d["days"] = days_in_ns // NS_PER_DAY

        elif infer_dtype(dates, skipna=False) == "datetime":
            if delta:
                delta = dates._values - stata_epoch

                def f(x: datetime.timedelta) -> float:
                    return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds

                v = np.vectorize(f)
                d["delta"] = v(delta)
            if year:
                year_month = dates.apply(lambda x: 100 * x.year + x.month)
                d["year"] = year_month._values // 100
                d["month"] = year_month._values - d["year"] * 100
            if days:

                def g(x: datetime.datetime) -> int:
                    return (x - datetime.datetime(x.year, 1, 1)).days

                v = np.vectorize(g)
                d["days"] = v(dates)
        else:
            raise ValueError(
                "Columns containing dates must contain either "
                "datetime64, datetime.datetime or null values."
            )

        return DataFrame(d, index=index)

    bad_loc = isna(dates)
    index = dates.index
    if bad_loc.any():
        dates = Series(dates)
        # Fill missing entries with the epoch so the arithmetic below is
        # well-defined; they are overwritten with the missing sentinel at the end.
        if is_datetime64_dtype(dates):
            dates[bad_loc] = to_datetime(stata_epoch)
        else:
            dates[bad_loc] = stata_epoch

    if fmt in ["%tc", "tc"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta / 1000
    elif fmt in ["%tC", "tC"]:
        warnings.warn(
            "Stata Internal Format tC not supported.",
            stacklevel=find_stack_level(),
        )
        conv_dates = dates
    elif fmt in ["%td", "td"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta // US_PER_DAY
    elif fmt in ["%tw", "tw"]:
        d = parse_dates_safe(dates, year=True, days=True)
        conv_dates = 52 * (d.year - stata_epoch.year) + d.days // 7
    elif fmt in ["%tm", "tm"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 12 * (d.year - stata_epoch.year) + d.month - 1
    elif fmt in ["%tq", "tq"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3
    elif fmt in ["%th", "th"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int)
    elif fmt in ["%ty", "ty"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = d.year
    else:
        raise ValueError(f"Format {fmt} is not a known Stata date format")

    conv_dates = Series(conv_dates, dtype=np.float64)
    # Stata's generic '.' missing value for doubles.
    missing_value = struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
    conv_dates[bad_loc] = missing_value

    return Series(conv_dates, index=index)
# Warning/error message templates used by the writer classes below;
# placeholders ({0}, {1}) are filled via str.format at the raise/warn site.
excessive_string_length_error: Final = """
Fixed width strings in Stata .dta files are limited to 244 (or fewer)
characters. Column '{0}' does not satisfy this restriction. Use the
'version=117' parameter to write the newer (Stata 13 and later) format.
"""

precision_loss_doc: Final = """
Column converted from {0} to {1}, and some data are outside of the lossless
conversion range. This may result in a loss of precision in the saved data.
"""

value_label_mismatch_doc: Final = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
"""

invalid_name_doc: Final = """
Not all pandas column names were valid Stata variable names.
The following replacements have been made:

    {0}

If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (strings only, max 32 characters, only
alphanumerics and underscores, no Stata reserved words)
"""

categorical_conversion_warning: Final = """
One or more series with value labels are not fully labeled. Reading this
dataset with an iterator results in categorical variable with different
categories. This occurs since it is not possible to know all possible values
until the entire dataset has been read. To avoid this warning, you can either
read dataset without an iterator, or manually convert categorical data by
``convert_categoricals`` to False and then accessing the variable labels
through the value_labels method of the reader.
"""
548def _cast_to_stata_types(data: DataFrame) -> DataFrame:
549 """
550 Checks the dtypes of the columns of a pandas DataFrame for
551 compatibility with the data types and ranges supported by Stata, and
552 converts if necessary.
554 Parameters
555 ----------
556 data : DataFrame
557 The DataFrame to check and convert
559 Notes
560 -----
561 Numeric columns in Stata must be one of int8, int16, int32, float32 or
562 float64, with some additional value restrictions. int8 and int16 columns
563 are checked for violations of the value restrictions and upcast if needed.
564 int64 data is not usable in Stata, and so it is downcast to int32 whenever
565 the value are in the int32 range, and sidecast to float64 when larger than
566 this range. If the int64 values are outside of the range of those
567 perfectly representable as float64 values, a warning is raised.
569 bool columns are cast to int8. uint columns are converted to int of the
570 same size if there is no loss in precision, otherwise are upcast to a
571 larger type. uint64 is currently not supported since it is concerted to
572 object in a DataFrame.
573 """
574 ws = ""
575 # original, if small, if large
576 conversion_data: tuple[
577 tuple[type, type, type],
578 tuple[type, type, type],
579 tuple[type, type, type],
580 tuple[type, type, type],
581 tuple[type, type, type],
582 ] = (
583 (np.bool_, np.int8, np.int8),
584 (np.uint8, np.int8, np.int16),
585 (np.uint16, np.int16, np.int32),
586 (np.uint32, np.int32, np.int64),
587 (np.uint64, np.int64, np.float64),
588 )
590 float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
591 float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
593 for col in data:
594 # Cast from unsupported types to supported types
595 is_nullable_int = isinstance(data[col].dtype, (IntegerDtype, BooleanDtype))
596 orig = data[col]
597 # We need to find orig_missing before altering data below
598 orig_missing = orig.isna()
599 if is_nullable_int:
600 missing_loc = data[col].isna()
601 if missing_loc.any():
602 # Replace with always safe value
603 fv = 0 if isinstance(data[col].dtype, IntegerDtype) else False
604 data.loc[missing_loc, col] = fv
605 # Replace with NumPy-compatible column
606 data[col] = data[col].astype(data[col].dtype.numpy_dtype)
607 dtype = data[col].dtype
608 for c_data in conversion_data:
609 if dtype == c_data[0]:
610 if data[col].max() <= np.iinfo(c_data[1]).max:
611 dtype = c_data[1]
612 else:
613 dtype = c_data[2]
614 if c_data[2] == np.int64: # Warn if necessary
615 if data[col].max() >= 2**53:
616 ws = precision_loss_doc.format("uint64", "float64")
618 data[col] = data[col].astype(dtype)
620 # Check values and upcast if necessary
621 if dtype == np.int8:
622 if data[col].max() > 100 or data[col].min() < -127:
623 data[col] = data[col].astype(np.int16)
624 elif dtype == np.int16:
625 if data[col].max() > 32740 or data[col].min() < -32767:
626 data[col] = data[col].astype(np.int32)
627 elif dtype == np.int64:
628 if data[col].max() <= 2147483620 and data[col].min() >= -2147483647:
629 data[col] = data[col].astype(np.int32)
630 else:
631 data[col] = data[col].astype(np.float64)
632 if data[col].max() >= 2**53 or data[col].min() <= -(2**53):
633 ws = precision_loss_doc.format("int64", "float64")
634 elif dtype in (np.float32, np.float64):
635 if np.isinf(data[col]).any():
636 raise ValueError(
637 f"Column {col} contains infinity or -infinity"
638 "which is outside the range supported by Stata."
639 )
640 value = data[col].max()
641 if dtype == np.float32 and value > float32_max:
642 data[col] = data[col].astype(np.float64)
643 elif dtype == np.float64:
644 if value > float64_max:
645 raise ValueError(
646 f"Column {col} has a maximum value ({value}) outside the range "
647 f"supported by Stata ({float64_max})"
648 )
649 if is_nullable_int:
650 if orig_missing.any():
651 # Replace missing by Stata sentinel value
652 sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
653 data.loc[orig_missing, col] = sentinel
654 if ws:
655 warnings.warn(
656 ws,
657 PossiblePrecisionLoss,
658 stacklevel=find_stack_level(),
659 )
661 return data
class StataValueLabel:
    """
    Parse a categorical column and prepare formatted output

    Parameters
    ----------
    catarray : Series
        Categorical Series to encode
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(
        self, catarray: Series, encoding: Literal["latin-1", "utf-8"] = "latin-1"
    ) -> None:
        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = catarray.name
        self._encoding = encoding
        # Pair each category with its integer code and keep code order.
        cats = catarray.cat.categories
        self.value_labels: list[tuple[float, str]] = sorted(
            zip(np.arange(len(cats)), cats), key=lambda pair: pair[0]
        )
        self._prepare_value_labels()

    def _prepare_value_labels(self):
        """Encode value labels into the offset/value/text layout Stata expects."""
        offsets: list[int] = []
        values: list[float] = []
        encoded: list[bytes] = []
        total_text = 0

        for code, label in self.value_labels:
            if not isinstance(label, str):
                # Stata value labels must be strings: coerce and warn.
                label = str(label)
                warnings.warn(
                    value_label_mismatch_doc.format(self.labname),
                    ValueLabelTypeMismatch,
                    stacklevel=find_stack_level(),
                )
            raw = label.encode(self._encoding)
            offsets.append(total_text)
            total_text += len(raw) + 1  # +1 for the padding
            values.append(code)
            encoded.append(raw)

        if total_text > 32000:
            raise ValueError(
                "Stata value labels for a single variable must "
                "have a combined length less than 32,000 characters."
            )

        self.text_len = total_text
        self.txt = encoded
        self.n = len(encoded)
        # Offsets and values as int32, per the dta value-label table layout.
        self.off = np.array(offsets, dtype=np.int32)
        self.val = np.array(values, dtype=np.int32)
        # Total table length: len/n headers + offsets + values + text block.
        self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len

    def generate_value_label(self, byteorder: str) -> bytes:
        """
        Generate the binary representation of the value labels.

        Parameters
        ----------
        byteorder : str
            Byte order of the output

        Returns
        -------
        value_label : bytes
            Bytes containing the formatted value label
        """
        int32_fmt = byteorder + "i"
        bio = BytesIO()

        # len
        bio.write(struct.pack(int32_fmt, self.len))

        # labname, truncated to 32 chars and null-padded to the slot width.
        name = str(self.labname)[:32].encode(self._encoding)
        slot = 32 if self._encoding not in ("utf-8", "utf8") else 128
        bio.write(_pad_bytes(name, slot + 1))

        # padding - 3 bytes
        bio.write(b"\x00" * 3)

        # value_label_table: n, textlen, offsets, values, then the labels
        # themselves, each null terminated.
        bio.write(struct.pack(int32_fmt, self.n))
        bio.write(struct.pack(int32_fmt, self.text_len))
        for off in self.off:
            bio.write(struct.pack(int32_fmt, off))
        for code in self.val:
            bio.write(struct.pack(int32_fmt, code))
        for raw in self.txt:
            bio.write(raw + b"\x00")

        return bio.getvalue()
class StataNonCatValueLabel(StataValueLabel):
    """
    Prepare formatted version of value labels

    Parameters
    ----------
    labname : str
        Value label name
    value_labels: Dictionary
        Mapping of values to labels
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(
        self,
        labname: str,
        value_labels: dict[float, str],
        encoding: Literal["latin-1", "utf-8"] = "latin-1",
    ) -> None:
        # Deliberately bypass StataValueLabel.__init__: the value->label
        # mapping is supplied directly rather than derived from a
        # categorical Series.
        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = labname
        self._encoding = encoding
        # Sort by value so offsets line up with ascending codes.
        self.value_labels: list[tuple[float, str]] = sorted(
            value_labels.items(), key=lambda pair: pair[0]
        )
        self._prepare_value_labels()
class StataMissingValue:
    """
    An observation's missing value.

    Parameters
    ----------
    value : {int, float}
        The Stata missing value code

    Notes
    -----
    More information: <https://www.stata.com/help.cgi?missing>

    Integer missing values make the code '.', '.a', ..., '.z' to the ranges
    101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
    2147483647 (for int32). Missing values for floating point data types are
    more complex but the pattern is simple to discern from the following table.

    np.float32 missing values (float in Stata)
    0000007f    .
    0008007f    .a
    0010007f    .b
    ...
    00c0007f    .x
    00c8007f    .y
    00d0007f    .z

    np.float64 missing values (double in Stata)
    000000000000e07f    .
    000000000001e07f    .a
    000000000002e07f    .b
    ...
    000000000018e07f    .x
    000000000019e07f    .y
    00000000001ae07f    .z
    """

    # Construct a dictionary of missing values
    MISSING_VALUES: dict[float, str] = {}
    bases: Final = (101, 32741, 2147483621)
    for b in bases:
        # Conversion to long to avoid hash issues on 32 bit platforms #8968
        MISSING_VALUES[b] = "."
        for i in range(1, 27):
            MISSING_VALUES[i + b] = "." + chr(96 + i)

    # Walk the 27 float32 missing-value bit patterns ('.', '.a', ..., '.z');
    # successive patterns are 0x0800 apart in the int32 representation.
    float32_base: bytes = b"\x00\x00\x00\x7f"
    increment: int = struct.unpack("<i", b"\x00\x08\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<f", float32_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("<i", struct.pack("<f", key))[0] + increment
        float32_base = struct.pack("<i", int_value)

    # Same walk for the float64 patterns.
    float64_base: bytes = b"\x00\x00\x00\x00\x00\x00\xe0\x7f"
    increment = struct.unpack("q", b"\x00\x00\x00\x00\x00\x01\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<d", float64_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("q", struct.pack("<d", key))[0] + increment
        float64_base = struct.pack("q", int_value)

    # NOTE(review): float32_base/float64_base were advanced by the loops above,
    # so at this point they hold the bit pattern one increment PAST '.z', not
    # the '.' base pattern — confirm against upstream pandas that the float
    # entries below are intentional before relying on them.
    BASE_MISSING_VALUES: Final = {
        "int8": 101,
        "int16": 32741,
        "int32": 2147483621,
        "float32": struct.unpack("<f", float32_base)[0],
        "float64": struct.unpack("<d", float64_base)[0],
    }

    def __init__(self, value: float) -> None:
        self._value = value
        # Conversion to int to avoid hash issues on 32 bit platforms #8968
        value = int(value) if value < 2147483648 else float(value)
        self._str = self.MISSING_VALUES[value]

    @property
    def string(self) -> str:
        """
        The Stata representation of the missing value: '.', '.a'..'.z'

        Returns
        -------
        str
            The representation of the missing value.
        """
        return self._str

    @property
    def value(self) -> float:
        """
        The binary representation of the missing value.

        Returns
        -------
        {int, float}
            The binary representation of the missing value.
        """
        return self._value

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return f"{type(self)}({self})"

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, type(self))
            and self.string == other.string
            and self.value == other.value
        )

    @classmethod
    def get_base_missing_value(cls, dtype: np.dtype) -> float:
        # Map a NumPy dtype to the generic ('.') sentinel for that type.
        if dtype.type is np.int8:
            value = cls.BASE_MISSING_VALUES["int8"]
        elif dtype.type is np.int16:
            value = cls.BASE_MISSING_VALUES["int16"]
        elif dtype.type is np.int32:
            value = cls.BASE_MISSING_VALUES["int32"]
        elif dtype.type is np.float32:
            value = cls.BASE_MISSING_VALUES["float32"]
        elif dtype.type is np.float64:
            value = cls.BASE_MISSING_VALUES["float64"]
        else:
            raise ValueError("Unsupported dtype")
        return value
class StataParser:
    """Base class providing the type-code tables shared by Stata readers and writers."""

    def __init__(self) -> None:

        # type          code.
        # --------------------
        # str1        1 = 0x01
        # str2        2 = 0x02
        # ...
        # str244    244 = 0xf4
        # byte      251 = 0xfb  (sic)
        # int       252 = 0xfc
        # long      253 = 0xfd
        # float     254 = 0xfe
        # double    255 = 0xff
        # --------------------
        # NOTE: the byte type seems to be reserved for categorical variables
        # with a label, but the underlying variable is -127 to 100
        # we're going to drop the label and cast to int
        self.DTYPE_MAP = dict(
            list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)]))
            + [
                (251, np.dtype(np.int8)),
                (252, np.dtype(np.int16)),
                (253, np.dtype(np.int32)),
                (254, np.dtype(np.float32)),
                (255, np.dtype(np.float64)),
            ]
        )
        # Type codes used by the XML-like dta formats (117 and newer).
        self.DTYPE_MAP_XML: dict[int, np.dtype] = {
            32768: np.dtype(np.uint8),  # Keys to GSO
            65526: np.dtype(np.float64),
            65527: np.dtype(np.float32),
            65528: np.dtype(np.int32),
            65529: np.dtype(np.int16),
            65530: np.dtype(np.int8),
        }
        # Entries 0-250 are the str lengths themselves; 251-255 map to the
        # single-character struct codes of the numeric types.
        self.TYPE_MAP = list(tuple(range(251)) + tuple("bhlfd"))
        self.TYPE_MAP_XML = {
            # Not really a Q, unclear how to handle byteswap
            32768: "Q",
            65526: "d",
            65527: "f",
            65528: "l",
            65529: "h",
            65530: "b",
        }
        # NOTE: technically, some of these are wrong. there are more numbers
        # that can be represented. it's the 27 ABOVE and BELOW the max listed
        # numeric data type in [U] 12.2.2 of the 11.2 manual
        float32_min = b"\xff\xff\xff\xfe"
        float32_max = b"\xff\xff\xff\x7e"
        float64_min = b"\xff\xff\xff\xff\xff\xff\xef\xff"
        float64_max = b"\xff\xff\xff\xff\xff\xff\xdf\x7f"
        # Inclusive (min, max) of storable (non-missing) values per type code.
        self.VALID_RANGE = {
            "b": (-127, 100),
            "h": (-32767, 32740),
            "l": (-2147483647, 2147483620),
            "f": (
                np.float32(struct.unpack("<f", float32_min)[0]),
                np.float32(struct.unpack("<f", float32_max)[0]),
            ),
            "d": (
                np.float64(struct.unpack("<d", float64_min)[0]),
                np.float64(struct.unpack("<d", float64_max)[0]),
            ),
        }

        # Old (pre-113) dta formats store type codes as ASCII characters.
        self.OLD_TYPE_MAPPING = {
            98: 251,  # byte
            105: 252,  # int
            108: 253,  # long
            102: 254,  # float
            100: 255,  # double
        }

        # These missing values are the generic '.' in Stata, and are used
        # to replace nans
        self.MISSING_VALUES = {
            "b": 101,
            "h": 32741,
            "l": 2147483621,
            "f": np.float32(struct.unpack("<f", b"\x00\x00\x00\x7f")[0]),
            "d": np.float64(
                struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
            ),
        }
        # Stata struct code -> NumPy dtype string.
        self.NUMPY_TYPE_MAP = {
            "b": "i1",
            "h": "i2",
            "l": "i4",
            "f": "f4",
            "d": "f8",
            "Q": "u8",
        }

        # Reserved words cannot be used as variable names
        self.RESERVED_WORDS = (
            "aggregate",
            "array",
            "boolean",
            "break",
            "byte",
            "case",
            "catch",
            "class",
            "colvector",
            "complex",
            "const",
            "continue",
            "default",
            "delegate",
            "delete",
            "do",
            "double",
            "else",
            "eltypedef",
            "end",
            "enum",
            "explicit",
            "export",
            "external",
            "float",
            "for",
            "friend",
            "function",
            "global",
            "goto",
            "if",
            "inline",
            "int",
            "local",
            "long",
            "NULL",
            "pragma",
            "protected",
            "quad",
            "rowvector",
            "short",
            "typedef",
            "typename",
            "virtual",
            "_all",
            "_N",
            "_skip",
            "_b",
            "_pi",
            "str#",
            "in",
            "_pred",
            "strL",
            "_coef",
            "_rc",
            "using",
            "_cons",
            "_se",
            "with",
            "_n",
        )
1115class StataReader(StataParser, abc.Iterator):
1116 __doc__ = _stata_reader_doc
    def __init__(
        self,
        path_or_buf: FilePath | ReadBuffer[bytes],
        convert_dates: bool = True,
        convert_categoricals: bool = True,
        index_col: str | None = None,
        convert_missing: bool = False,
        preserve_dtypes: bool = True,
        columns: Sequence[str] | None = None,
        order_categoricals: bool = True,
        chunksize: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
    ) -> None:
        # See the class-level docstring (_stata_reader_doc) for parameter
        # descriptions.  The whole file is copied into memory here, then the
        # header is parsed eagerly; data rows are read lazily in read().
        super().__init__()
        self.col_sizes: list[int] = []

        # Arguments to the reader (can be temporarily overridden in
        # calls to read).
        self._convert_dates = convert_dates
        self._convert_categoricals = convert_categoricals
        self._index_col = index_col
        self._convert_missing = convert_missing
        self._preserve_dtypes = preserve_dtypes
        self._columns = columns
        self._order_categoricals = order_categoricals
        # set properly by _set_encoding() once the format version is known
        self._encoding = ""
        self._chunksize = chunksize
        self._using_iterator = False
        if self._chunksize is None:
            self._chunksize = 1
        elif not isinstance(chunksize, int) or chunksize <= 0:
            raise ValueError("chunksize must be a positive integer when set.")

        # State variables for the file
        self._has_string_data = False
        self._missing_values = False
        self._can_read_value_labels = False
        self._column_selector_set = False
        self._value_labels_read = False
        self._data_read = False
        self._dtype: np.dtype | None = None
        self._lines_read = 0

        self._native_byteorder = _set_endianness(sys.byteorder)
        with get_handle(
            path_or_buf,
            "rb",
            storage_options=storage_options,
            is_text=False,
            compression=compression,
        ) as handles:
            # Copy to BytesIO, and ensure no encoding
            self.path_or_buf = BytesIO(handles.handle.read())

        self._read_header()
        self._setup_dtype()
1176 def __enter__(self) -> StataReader:
1177 """enter context manager"""
1178 return self
1180 def __exit__(self, exc_type, exc_value, traceback) -> None:
1181 """exit context manager"""
1182 self.close()
1184 def close(self) -> None:
1185 """close the handle if its open"""
1186 self.path_or_buf.close()
1188 def _set_encoding(self) -> None:
1189 """
1190 Set string encoding which depends on file version
1191 """
1192 if self.format_version < 118:
1193 self._encoding = "latin-1"
1194 else:
1195 self._encoding = "utf-8"
1197 def _read_header(self) -> None:
1198 first_char = self.path_or_buf.read(1)
1199 if struct.unpack("c", first_char)[0] == b"<":
1200 self._read_new_header()
1201 else:
1202 self._read_old_header(first_char)
1204 self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0
1206 # calculate size of a data record
1207 self.col_sizes = [self._calcsize(typ) for typ in self.typlist]
    def _read_new_header(self) -> None:
        """
        Parse the XML-like header used by dta formats 117-119.

        Fixed-length tag strings are skipped with raw read()s whose byte
        counts match the tag text; the <map> section supplies absolute file
        offsets of each later section, which are stored (plus the length of
        the section's opening tag) for subsequent seeks.
        """
        # The first part of the header is common to 117 - 119.
        self.path_or_buf.read(27)  # stata_dta><header><release>
        self.format_version = int(self.path_or_buf.read(3))
        if self.format_version not in [117, 118, 119]:
            raise ValueError(_version_error.format(version=self.format_version))
        self._set_encoding()
        self.path_or_buf.read(21)  # </release><byteorder>
        self.byteorder = self.path_or_buf.read(3) == b"MSF" and ">" or "<"
        self.path_or_buf.read(15)  # </byteorder><K>
        # variable count widened from 2 to 4 bytes in format 119
        nvar_type = "H" if self.format_version <= 118 else "I"
        nvar_size = 2 if self.format_version <= 118 else 4
        self.nvar = struct.unpack(
            self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
        )[0]
        self.path_or_buf.read(7)  # </K><N>

        self.nobs = self._get_nobs()
        self.path_or_buf.read(11)  # </N><label>
        self._data_label = self._get_data_label()
        self.path_or_buf.read(19)  # </label><timestamp>
        self.time_stamp = self._get_time_stamp()
        self.path_or_buf.read(26)  # </timestamp></header><map>
        self.path_or_buf.read(8)  # 0x0000000000000000
        self.path_or_buf.read(8)  # position of <map>

        # Each map entry is an 8-byte offset; the added constants skip the
        # opening tag of the pointed-to section.
        self._seek_vartypes = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
        )
        self._seek_varnames = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
        )
        self._seek_sortlist = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
        )
        self._seek_formats = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
        )
        self._seek_value_label_names = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
        )

        # Requires version-specific treatment
        self._seek_variable_labels = self._get_seek_variable_labels()

        self.path_or_buf.read(8)  # <characteristics>
        self.data_location = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
        )
        self.seek_strls = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
        )
        self.seek_value_labels = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
        )

        self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)

        self.path_or_buf.seek(self._seek_varnames)
        self.varlist = self._get_varlist()

        self.path_or_buf.seek(self._seek_sortlist)
        # sort order list: nvar+1 2-byte ints, the trailing terminator dropped
        self.srtlist = struct.unpack(
            self.byteorder + ("h" * (self.nvar + 1)),
            self.path_or_buf.read(2 * (self.nvar + 1)),
        )[:-1]

        self.path_or_buf.seek(self._seek_formats)
        self.fmtlist = self._get_fmtlist()

        self.path_or_buf.seek(self._seek_value_label_names)
        self.lbllist = self._get_lbllist()

        self.path_or_buf.seek(self._seek_variable_labels)
        self._variable_labels = self._get_variable_labels()
1285 # Get data type information, works for versions 117-119.
1286 def _get_dtypes(
1287 self, seek_vartypes: int
1288 ) -> tuple[list[int | str], list[str | np.dtype]]:
1290 self.path_or_buf.seek(seek_vartypes)
1291 raw_typlist = [
1292 struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1293 for _ in range(self.nvar)
1294 ]
1296 def f(typ: int) -> int | str:
1297 if typ <= 2045:
1298 return typ
1299 try:
1300 return self.TYPE_MAP_XML[typ]
1301 except KeyError as err:
1302 raise ValueError(f"cannot convert stata types [{typ}]") from err
1304 typlist = [f(x) for x in raw_typlist]
1306 def g(typ: int) -> str | np.dtype:
1307 if typ <= 2045:
1308 return str(typ)
1309 try:
1310 return self.DTYPE_MAP_XML[typ]
1311 except KeyError as err:
1312 raise ValueError(f"cannot convert stata dtype [{typ}]") from err
1314 dtyplist = [g(x) for x in raw_typlist]
1316 return typlist, dtyplist
1318 def _get_varlist(self) -> list[str]:
1319 # 33 in order formats, 129 in formats 118 and 119
1320 b = 33 if self.format_version < 118 else 129
1321 return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
1323 # Returns the format list
1324 def _get_fmtlist(self) -> list[str]:
1325 if self.format_version >= 118:
1326 b = 57
1327 elif self.format_version > 113:
1328 b = 49
1329 elif self.format_version > 104:
1330 b = 12
1331 else:
1332 b = 7
1334 return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
1336 # Returns the label list
1337 def _get_lbllist(self) -> list[str]:
1338 if self.format_version >= 118:
1339 b = 129
1340 elif self.format_version > 108:
1341 b = 33
1342 else:
1343 b = 9
1344 return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
1346 def _get_variable_labels(self) -> list[str]:
1347 if self.format_version >= 118:
1348 vlblist = [
1349 self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar)
1350 ]
1351 elif self.format_version > 105:
1352 vlblist = [
1353 self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar)
1354 ]
1355 else:
1356 vlblist = [
1357 self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar)
1358 ]
1359 return vlblist
1361 def _get_nobs(self) -> int:
1362 if self.format_version >= 118:
1363 return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1364 else:
1365 return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1367 def _get_data_label(self) -> str:
1368 if self.format_version >= 118:
1369 strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1370 return self._decode(self.path_or_buf.read(strlen))
1371 elif self.format_version == 117:
1372 strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1373 return self._decode(self.path_or_buf.read(strlen))
1374 elif self.format_version > 105:
1375 return self._decode(self.path_or_buf.read(81))
1376 else:
1377 return self._decode(self.path_or_buf.read(32))
1379 def _get_time_stamp(self) -> str:
1380 if self.format_version >= 118:
1381 strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1382 return self.path_or_buf.read(strlen).decode("utf-8")
1383 elif self.format_version == 117:
1384 strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1385 return self._decode(self.path_or_buf.read(strlen))
1386 elif self.format_version > 104:
1387 return self._decode(self.path_or_buf.read(18))
1388 else:
1389 raise ValueError()
1391 def _get_seek_variable_labels(self) -> int:
1392 if self.format_version == 117:
1393 self.path_or_buf.read(8) # <variable_labels>, throw away
1394 # Stata 117 data files do not follow the described format. This is
1395 # a work around that uses the previous label, 33 bytes for each
1396 # variable, 20 for the closing tag and 17 for the opening tag
1397 return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
1398 elif self.format_version >= 118:
1399 return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
1400 else:
1401 raise ValueError()
    def _read_old_header(self, first_char: bytes) -> None:
        """
        Parse the fixed-layout binary header used by dta formats 104-115.

        Parameters
        ----------
        first_char : bytes
            The single byte already consumed by _read_header; it holds the
            format version for old-style files.
        """
        self.format_version = struct.unpack("b", first_char)[0]
        if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
            raise ValueError(_version_error.format(version=self.format_version))
        self._set_encoding()
        # 0x1 marks a big-endian (MSF) file
        self.byteorder = (
            struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 and ">" or "<"
        )
        self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
        self.path_or_buf.read(1)  # unused

        self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
        self.nobs = self._get_nobs()

        self._data_label = self._get_data_label()

        self.time_stamp = self._get_time_stamp()

        # descriptors
        if self.format_version > 108:
            # one numeric type code per variable
            typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
        else:
            # formats <= 108 store ASCII characters ('b', 'i', ...) instead;
            # translate them through OLD_TYPE_MAPPING, anything else is a
            # fixed-width string whose length is the character value - 127
            buf = self.path_or_buf.read(self.nvar)
            typlistb = np.frombuffer(buf, dtype=np.uint8)
            typlist = []
            for tp in typlistb:
                if tp in self.OLD_TYPE_MAPPING:
                    typlist.append(self.OLD_TYPE_MAPPING[tp])
                else:
                    typlist.append(tp - 127)  # bytes

        try:
            self.typlist = [self.TYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_types = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
        try:
            self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_dtypes = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err

        if self.format_version > 108:
            self.varlist = [
                self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar)
            ]
        else:
            self.varlist = [
                self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
            ]
        # sort order list: nvar+1 2-byte ints, the trailing terminator dropped
        self.srtlist = struct.unpack(
            self.byteorder + ("h" * (self.nvar + 1)),
            self.path_or_buf.read(2 * (self.nvar + 1)),
        )[:-1]

        self.fmtlist = self._get_fmtlist()

        self.lbllist = self._get_lbllist()

        self._variable_labels = self._get_variable_labels()

        # ignore expansion fields (Format 105 and later)
        # When reading, read five bytes; the last four bytes now tell you
        # the size of the next read, which you discard.  You then continue
        # like this until you read 5 bytes of zeros.

        if self.format_version > 104:
            while True:
                data_type = struct.unpack(
                    self.byteorder + "b", self.path_or_buf.read(1)
                )[0]
                if self.format_version > 108:
                    data_len = struct.unpack(
                        self.byteorder + "i", self.path_or_buf.read(4)
                    )[0]
                else:
                    data_len = struct.unpack(
                        self.byteorder + "h", self.path_or_buf.read(2)
                    )[0]
                if data_type == 0:
                    break
                self.path_or_buf.read(data_len)

        # necessary data to continue parsing
        self.data_location = self.path_or_buf.tell()
1489 def _setup_dtype(self) -> np.dtype:
1490 """Map between numpy and state dtypes"""
1491 if self._dtype is not None:
1492 return self._dtype
1494 dtypes = [] # Convert struct data types to numpy data type
1495 for i, typ in enumerate(self.typlist):
1496 if typ in self.NUMPY_TYPE_MAP:
1497 typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP
1498 dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ]))
1499 else:
1500 dtypes.append(("s" + str(i), "S" + str(typ)))
1501 self._dtype = np.dtype(dtypes)
1503 return self._dtype
1505 def _calcsize(self, fmt: int | str) -> int:
1506 if isinstance(fmt, int):
1507 return fmt
1508 return struct.calcsize(self.byteorder + fmt)
    def _decode(self, s: bytes) -> str:
        """
        Decode a null-terminated byte field using the file's encoding,
        falling back to latin-1 (with a warning) on decode failure.
        """
        # have bytes not strings, so must decode
        s = s.partition(b"\0")[0]
        try:
            return s.decode(self._encoding)
        except UnicodeDecodeError:
            # GH 25960, fallback to handle incorrect format produced when 117
            # files are converted to 118 files in Stata
            encoding = self._encoding
            msg = f"""
One or more strings in the dta file could not be decoded using {encoding}, and
so the fallback encoding of latin-1 is being used.  This can happen when a file
has been incorrectly encoded by Stata or some other software.  You should verify
the string values returned are correct."""
            warnings.warn(
                msg,
                UnicodeWarning,
                stacklevel=find_stack_level(),
            )
            return s.decode("latin-1")
    def _read_value_labels(self) -> None:
        """
        Populate ``self.value_label_dict`` mapping label-set name ->
        {value: label}.  Idempotent: subsequent calls return immediately.
        """
        if self._value_labels_read:
            # Don't read twice
            return
        if self.format_version <= 108:
            # Value labels are not supported in version 108 and earlier.
            self._value_labels_read = True
            self.value_label_dict: dict[str, dict[float, str]] = {}
            return

        if self.format_version >= 117:
            self.path_or_buf.seek(self.seek_value_labels)
        else:
            # old formats: the label table immediately follows the data block
            assert self._dtype is not None
            offset = self.nobs * self._dtype.itemsize
            self.path_or_buf.seek(self.data_location + offset)

        self._value_labels_read = True
        self.value_label_dict = {}

        while True:
            if self.format_version >= 117:
                if self.path_or_buf.read(5) == b"</val":  # <lbl>
                    break  # end of value label table

            slength = self.path_or_buf.read(4)
            if not slength:
                break  # end of value label table (format < 117)
            if self.format_version <= 117:
                labname = self._decode(self.path_or_buf.read(33))
            else:
                labname = self._decode(self.path_or_buf.read(129))
            self.path_or_buf.read(3)  # padding

            # n labels, then the total length of the text blob, then n text
            # offsets, n values, and finally the packed null-terminated texts
            n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
            txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
            off = np.frombuffer(
                self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
            )
            val = np.frombuffer(
                self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
            )
            # sort both arrays by text offset so each label's slice ends
            # where the next one begins
            ii = np.argsort(off)
            off = off[ii]
            val = val[ii]
            txt = self.path_or_buf.read(txtlen)
            self.value_label_dict[labname] = {}
            for i in range(n):
                end = off[i + 1] if i < n - 1 else txtlen
                self.value_label_dict[labname][val[i]] = self._decode(txt[off[i] : end])
            if self.format_version >= 117:
                self.path_or_buf.read(6)  # </lbl>
        self._value_labels_read = True
    def _read_strls(self) -> None:
        """
        Read the strL (long string) table into ``self.GSO``, keyed by the
        stringified (v, o) reference so uint64 values work on 32-bit OSes.
        """
        self.path_or_buf.seek(self.seek_strls)
        # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
        self.GSO = {"0": ""}
        while True:
            if self.path_or_buf.read(3) != b"GSO":
                break

            if self.format_version == 117:
                v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
            else:
                # 118/119 split the reference into v (2 or 3 bytes) and o;
                # repack the used bytes into a single 8-byte integer
                buf = self.path_or_buf.read(12)
                # Only tested on little endian file on little endian machine.
                v_size = 2 if self.format_version == 118 else 3
                if self.byteorder == "<":
                    buf = buf[0:v_size] + buf[4 : (12 - v_size)]
                else:
                    # This path may not be correct, impossible to test
                    buf = buf[0:v_size] + buf[(4 + v_size) :]
                v_o = struct.unpack("Q", buf)[0]
            typ = struct.unpack("B", self.path_or_buf.read(1))[0]
            length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
            va = self.path_or_buf.read(length)
            if typ == 130:
                # type 130 is text: strip the trailing null and decode
                decoded_va = va[0:-1].decode(self._encoding)
            else:
                # Stata says typ 129 can be binary, so use str
                decoded_va = str(va)
            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
            self.GSO[str(v_o)] = decoded_va
1616 def __next__(self) -> DataFrame:
1617 self._using_iterator = True
1618 return self.read(nrows=self._chunksize)
1620 def get_chunk(self, size: int | None = None) -> DataFrame:
1621 """
1622 Reads lines from Stata file and returns as dataframe
1624 Parameters
1625 ----------
1626 size : int, defaults to None
1627 Number of lines to read. If None, reads whole file.
1629 Returns
1630 -------
1631 DataFrame
1632 """
1633 if size is None:
1634 size = self._chunksize
1635 return self.read(nrows=size)
    @Appender(_read_method_doc)
    def read(
        self,
        nrows: int | None = None,
        convert_dates: bool | None = None,
        convert_categoricals: bool | None = None,
        index_col: str | None = None,
        convert_missing: bool | None = None,
        preserve_dtypes: bool | None = None,
        columns: Sequence[str] | None = None,
        order_categoricals: bool | None = None,
    ) -> DataFrame:
        # Handle empty file or chunk.  If reading incrementally raise
        # StopIteration.  If reading the whole thing return an empty
        # data frame.
        if (self.nobs == 0) and (nrows is None):
            self._can_read_value_labels = True
            self._data_read = True
            self.close()
            return DataFrame(columns=self.varlist)

        # Handle options: each None argument falls back to the value set
        # on the reader in __init__.
        if convert_dates is None:
            convert_dates = self._convert_dates
        if convert_categoricals is None:
            convert_categoricals = self._convert_categoricals
        if convert_missing is None:
            convert_missing = self._convert_missing
        if preserve_dtypes is None:
            preserve_dtypes = self._preserve_dtypes
        if columns is None:
            columns = self._columns
        if order_categoricals is None:
            order_categoricals = self._order_categoricals
        if index_col is None:
            index_col = self._index_col

        if nrows is None:
            nrows = self.nobs

        if (self.format_version >= 117) and (not self._value_labels_read):
            self._can_read_value_labels = True
            self._read_strls()

        # Read data as raw structured records
        assert self._dtype is not None
        dtype = self._dtype
        max_read_len = (self.nobs - self._lines_read) * dtype.itemsize
        read_len = nrows * dtype.itemsize
        read_len = min(read_len, max_read_len)
        if read_len <= 0:
            # Iterator has finished, should never be here unless
            # we are reading the file incrementally
            if convert_categoricals:
                self._read_value_labels()
            self.close()
            raise StopIteration
        offset = self._lines_read * dtype.itemsize
        self.path_or_buf.seek(self.data_location + offset)
        read_lines = min(nrows, self.nobs - self._lines_read)
        raw_data = np.frombuffer(
            self.path_or_buf.read(read_len), dtype=dtype, count=read_lines
        )

        self._lines_read += read_lines
        if self._lines_read == self.nobs:
            self._can_read_value_labels = True
            self._data_read = True
        # if necessary, swap the byte order to native here
        if self.byteorder != self._native_byteorder:
            raw_data = raw_data.byteswap().newbyteorder()

        if convert_categoricals:
            self._read_value_labels()

        if len(raw_data) == 0:
            data = DataFrame(columns=self.varlist)
        else:
            data = DataFrame.from_records(raw_data)
            data.columns = Index(self.varlist)

        # If index is not specified, use actual row number rather than
        # restarting at 0 for each chunk.
        if index_col is None:
            rng = np.arange(self._lines_read - read_lines, self._lines_read)
            data.index = Index(rng)  # set attr instead of set_index to avoid copy

        if columns is not None:
            try:
                data = self._do_select_columns(data, columns)
            except ValueError:
                self.close()
                raise

        # Decode strings (columns whose type code is an int are fixed-width
        # byte strings)
        for col, typ in zip(data, self.typlist):
            if type(typ) is int:
                data[col] = data[col].apply(self._decode, convert_dtype=True)

        data = self._insert_strls(data)

        cols_ = np.where([dtyp is not None for dtyp in self.dtyplist])[0]
        # Convert columns (if needed) to match input type
        ix = data.index
        requires_type_conversion = False
        data_formatted = []
        for i in cols_:
            if self.dtyplist[i] is not None:
                col = data.columns[i]
                dtype = data[col].dtype
                if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
                    requires_type_conversion = True
                    data_formatted.append(
                        (col, Series(data[col], ix, self.dtyplist[i]))
                    )
                else:
                    data_formatted.append((col, data[col]))
        if requires_type_conversion:
            data = DataFrame.from_dict(dict(data_formatted))
        del data_formatted

        data = self._do_convert_missing(data, convert_missing)

        if convert_dates:

            def any_startswith(x: str) -> bool:
                # True when the column's display format is a date format
                return any(x.startswith(fmt) for fmt in _date_formats)

            cols = np.where([any_startswith(x) for x in self.fmtlist])[0]
            for i in cols:
                col = data.columns[i]
                try:
                    data[col] = _stata_elapsed_date_to_datetime_vec(
                        data[col], self.fmtlist[i]
                    )
                except ValueError:
                    self.close()
                    raise

        if convert_categoricals and self.format_version > 108:
            data = self._do_convert_categoricals(
                data, self.value_label_dict, self.lbllist, order_categoricals
            )

        if not preserve_dtypes:
            # widen small floats to float64 and small ints to int64
            retyped_data = []
            convert = False
            for col in data:
                dtype = data[col].dtype
                if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
                    dtype = np.dtype(np.float64)
                    convert = True
                elif dtype in (
                    np.dtype(np.int8),
                    np.dtype(np.int16),
                    np.dtype(np.int32),
                ):
                    dtype = np.dtype(np.int64)
                    convert = True
                retyped_data.append((col, data[col].astype(dtype)))
            if convert:
                data = DataFrame.from_dict(dict(retyped_data))

        if index_col is not None:
            data = data.set_index(data.pop(index_col))

        return data
    def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame:
        """
        Replace out-of-range sentinel values: with StataMissingValue objects
        when ``convert_missing`` is True, otherwise with NaN.
        """
        # Check for missing values, and replace if found
        replacements = {}
        for i, colname in enumerate(data):
            fmt = self.typlist[i]
            if fmt not in self.VALID_RANGE:
                continue

            fmt = cast(str, fmt)  # only strs in VALID_RANGE
            nmin, nmax = self.VALID_RANGE[fmt]
            series = data[colname]

            # appreciably faster to do this with ndarray instead of Series
            svals = series._values
            missing = (svals < nmin) | (svals > nmax)

            if not missing.any():
                continue

            if convert_missing:  # Replacement follows Stata notation
                missing_loc = np.nonzero(np.asarray(missing))[0]
                umissing, umissing_loc = np.unique(series[missing], return_inverse=True)
                replacement = Series(series, dtype=object)
                for j, um in enumerate(umissing):
                    missing_value = StataMissingValue(um)

                    loc = missing_loc[umissing_loc == j]
                    replacement.iloc[loc] = missing_value
            else:  # All replacements are identical
                dtype = series.dtype
                # NaN requires a float column
                if dtype not in (np.float32, np.float64):
                    dtype = np.float64
                replacement = Series(series, dtype=dtype)
                if not replacement._values.flags["WRITEABLE"]:
                    # only relevant for ArrayManager; construction
                    # path for BlockManager ensures writeability
                    replacement = replacement.copy()
                # Note: operating on ._values is much faster than directly
                # TODO: can we fix that?
                replacement._values[missing] = np.nan
            replacements[colname] = replacement

        if replacements:
            for col in replacements:
                data[col] = replacements[col]
        return data
1852 def _insert_strls(self, data: DataFrame) -> DataFrame:
1853 if not hasattr(self, "GSO") or len(self.GSO) == 0:
1854 return data
1855 for i, typ in enumerate(self.typlist):
1856 if typ != "Q":
1857 continue
1858 # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
1859 data.iloc[:, i] = [self.GSO[str(k)] for k in data.iloc[:, i]]
1860 return data
1862 def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame:
1864 if not self._column_selector_set:
1865 column_set = set(columns)
1866 if len(column_set) != len(columns):
1867 raise ValueError("columns contains duplicate entries")
1868 unmatched = column_set.difference(data.columns)
1869 if unmatched:
1870 joined = ", ".join(list(unmatched))
1871 raise ValueError(
1872 "The following columns were not "
1873 f"found in the Stata data set: {joined}"
1874 )
1875 # Copy information for retained columns for later processing
1876 dtyplist = []
1877 typlist = []
1878 fmtlist = []
1879 lbllist = []
1880 for col in columns:
1881 i = data.columns.get_loc(col)
1882 dtyplist.append(self.dtyplist[i])
1883 typlist.append(self.typlist[i])
1884 fmtlist.append(self.fmtlist[i])
1885 lbllist.append(self.lbllist[i])
1887 self.dtyplist = dtyplist
1888 self.typlist = typlist
1889 self.fmtlist = fmtlist
1890 self.lbllist = lbllist
1891 self._column_selector_set = True
1893 return data[columns]
    def _do_convert_categoricals(
        self,
        data: DataFrame,
        value_label_dict: dict[str, dict[float, str]],
        lbllist: Sequence[str],
        order_categoricals: bool,
    ) -> DataFrame:
        """
        Converts categorical columns to Categorical type.

        Columns whose value-label name appears in ``value_label_dict`` are
        converted; their integer codes are renamed to the label text.
        """
        value_labels = list(value_label_dict.keys())
        cat_converted_data = []
        for col, label in zip(data, lbllist):
            if label in value_labels:
                # Explicit call with ordered=True
                vl = value_label_dict[label]
                keys = np.array(list(vl.keys()))
                column = data[col]
                key_matches = column.isin(keys)
                if self._using_iterator and key_matches.all():
                    initial_categories: np.ndarray | None = keys
                    # If all categories are in the keys and we are iterating,
                    # use the same keys for all chunks. If some are missing
                    # value labels, then we will fall back to the categories
                    # varying across chunks.
                else:
                    if self._using_iterator:
                        # warn is using an iterator
                        warnings.warn(
                            categorical_conversion_warning,
                            CategoricalConversionWarning,
                            stacklevel=find_stack_level(),
                        )
                    initial_categories = None
                cat_data = Categorical(
                    column, categories=initial_categories, ordered=order_categoricals
                )
                if initial_categories is None:
                    # If None here, then we need to match the cats in the Categorical
                    categories = []
                    for category in cat_data.categories:
                        if category in vl:
                            categories.append(vl[category])
                        else:
                            # value had no label; keep the raw value
                            categories.append(category)
                else:
                    # If all cats are matched, we can use the values
                    categories = list(vl.values())
                try:
                    # Try to catch duplicate categories
                    # TODO: if we get a non-copying rename_categories, use that
                    cat_data = cat_data.rename_categories(categories)
                except ValueError as err:
                    vc = Series(categories).value_counts()
                    repeated_cats = list(vc.index[vc > 1])
                    repeats = "-" * 80 + "\n" + "\n".join(repeated_cats)
                    # GH 25772
                    msg = f"""
Value labels for column {col} are not unique. These cannot be converted to
pandas categoricals.

Either read the file with `convert_categoricals` set to False or use the
low level interface in `StataReader` to separately read the values and the
value_labels.

The repeated labels are:
{repeats}
"""
                    raise ValueError(msg) from err
                # TODO: is the next line needed above in the data(...) method?
                cat_series = Series(cat_data, index=data.index)
                cat_converted_data.append((col, cat_series))
            else:
                cat_converted_data.append((col, data[col]))
        data = DataFrame(dict(cat_converted_data), copy=False)
        return data
    @property
    def data_label(self) -> str:
        """
        Return data label of Stata file.

        The label is read from the file header during construction.
        """
        return self._data_label
1979 def variable_labels(self) -> dict[str, str]:
1980 """
1981 Return a dict associating each variable name with corresponding label.
1983 Returns
1984 -------
1985 dict
1986 """
1987 return dict(zip(self.varlist, self._variable_labels))
    def value_labels(self) -> dict[str, dict[float, str]]:
        """
        Return a nested dict associating each variable name to its value and label.

        Returns
        -------
        dict
        """
        # Lazily parse the value-label table on first access.
        if not self._value_labels_read:
            self._read_value_labels()

        return self.value_label_dict
@Appender(_read_stata_doc)
@deprecate_nonkeyword_arguments(version=None, allowed_args=["filepath_or_buffer"])
def read_stata(
    filepath_or_buffer: FilePath | ReadBuffer[bytes],
    convert_dates: bool = True,
    convert_categoricals: bool = True,
    index_col: str | None = None,
    convert_missing: bool = False,
    preserve_dtypes: bool = True,
    columns: Sequence[str] | None = None,
    order_categoricals: bool = True,
    chunksize: int | None = None,
    iterator: bool = False,
    compression: CompressionOptions = "infer",
    storage_options: StorageOptions = None,
) -> DataFrame | StataReader:
    # See _read_stata_doc (attached via @Appender) for the full docstring.
    # Forward every option unchanged to the reader.
    reader = StataReader(
        filepath_or_buffer,
        convert_dates=convert_dates,
        convert_categoricals=convert_categoricals,
        index_col=index_col,
        convert_missing=convert_missing,
        preserve_dtypes=preserve_dtypes,
        columns=columns,
        order_categoricals=order_categoricals,
        chunksize=chunksize,
        storage_options=storage_options,
        compression=compression,
    )

    # When iteration is requested, hand the reader itself to the caller,
    # who becomes responsible for closing it.
    if iterator or chunksize:
        return reader

    # Otherwise read everything and close the handle via the context manager.
    with reader:
        return reader.read()
2041def _set_endianness(endianness: str) -> str:
2042 if endianness.lower() in ["<", "little"]:
2043 return "<"
2044 elif endianness.lower() in [">", "big"]:
2045 return ">"
2046 else: # pragma : no cover
2047 raise ValueError(f"Endianness {endianness} not understood")
2050def _pad_bytes(name: AnyStr, length: int) -> AnyStr:
2051 """
2052 Take a char string and pads it with null bytes until it's length chars.
2053 """
2054 if isinstance(name, bytes):
2055 return name + b"\x00" * (length - len(name))
2056 return name + "\x00" * (length - len(name))
2059def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
2060 """
2061 Convert from one of the stata date formats to a type in TYPE_MAP.
2062 """
2063 if fmt in [
2064 "tc",
2065 "%tc",
2066 "td",
2067 "%td",
2068 "tw",
2069 "%tw",
2070 "tm",
2071 "%tm",
2072 "tq",
2073 "%tq",
2074 "th",
2075 "%th",
2076 "ty",
2077 "%ty",
2078 ]:
2079 return np.dtype(np.float64) # Stata expects doubles for SIFs
2080 else:
2081 raise NotImplementedError(f"Format {fmt} not implemented")
2084def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict:
2085 new_dict = {}
2086 for key in convert_dates:
2087 if not convert_dates[key].startswith("%"): # make sure proper fmts
2088 convert_dates[key] = "%" + convert_dates[key]
2089 if key in varlist:
2090 new_dict.update({varlist.index(key): convert_dates[key]})
2091 else:
2092 if not isinstance(key, int):
2093 raise ValueError("convert_dates key must be a column or an integer")
2094 new_dict.update({key: convert_dates[key]})
2095 return new_dict
2098def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int:
2099 """
2100 Convert dtype types to stata types. Returns the byte of the given ordinal.
2101 See TYPE_MAP and comments for an explanation. This is also explained in
2102 the dta spec.
2103 1 - 244 are strings of this length
2104 Pandas Stata
2105 251 - for int8 byte
2106 252 - for int16 int
2107 253 - for int32 long
2108 254 - for float32 float
2109 255 - for double double
2111 If there are dates to convert, then dtype will already have the correct
2112 type inserted.
2113 """
2114 # TODO: expand to handle datetime to integer conversion
2115 if dtype.type is np.object_: # try to coerce it to the biggest string
2116 # not memory efficient, what else could we
2117 # do?
2118 itemsize = max_len_string_array(ensure_object(column._values))
2119 return max(itemsize, 1)
2120 elif dtype.type is np.float64:
2121 return 255
2122 elif dtype.type is np.float32:
2123 return 254
2124 elif dtype.type is np.int32:
2125 return 253
2126 elif dtype.type is np.int16:
2127 return 252
2128 elif dtype.type is np.int8:
2129 return 251
2130 else: # pragma : no cover
2131 raise NotImplementedError(f"Data type {dtype} not supported.")
2134def _dtype_to_default_stata_fmt(
2135 dtype, column: Series, dta_version: int = 114, force_strl: bool = False
2136) -> str:
2137 """
2138 Map numpy dtype to stata's default format for this type. Not terribly
2139 important since users can change this in Stata. Semantics are
2141 object -> "%DDs" where DD is the length of the string. If not a string,
2142 raise ValueError
2143 float64 -> "%10.0g"
2144 float32 -> "%9.0g"
2145 int64 -> "%9.0g"
2146 int32 -> "%12.0g"
2147 int16 -> "%8.0g"
2148 int8 -> "%8.0g"
2149 strl -> "%9s"
2150 """
2151 # TODO: Refactor to combine type with format
2152 # TODO: expand this to handle a default datetime format?
2153 if dta_version < 117:
2154 max_str_len = 244
2155 else:
2156 max_str_len = 2045
2157 if force_strl:
2158 return "%9s"
2159 if dtype.type is np.object_:
2160 itemsize = max_len_string_array(ensure_object(column._values))
2161 if itemsize > max_str_len:
2162 if dta_version >= 117:
2163 return "%9s"
2164 else:
2165 raise ValueError(excessive_string_length_error.format(column.name))
2166 return "%" + str(max(itemsize, 1)) + "s"
2167 elif dtype == np.float64:
2168 return "%10.0g"
2169 elif dtype == np.float32:
2170 return "%9.0g"
2171 elif dtype == np.int32:
2172 return "%12.0g"
2173 elif dtype == np.int8 or dtype == np.int16:
2174 return "%8.0g"
2175 else: # pragma : no cover
2176 raise NotImplementedError(f"Data type {dtype} not supported.")
@doc(
    storage_options=_shared_docs["storage_options"],
    compression_options=_shared_docs["compression_options"] % "fname",
)
class StataWriter(StataParser):
    """
    A class for writing Stata binary dta files

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() functions. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information
    write_index : bool
        Write the index to Stata dataset.
    byteorder : str
        Can be ">", "<", "little", or "big". default is `sys.byteorder`
    time_stamp : datetime
        A datetime to use as file creation date.  Default is the current time
    data_label : str
        A label for the data set.  Must be 80 characters or smaller.
    variable_labels : dict
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    {compression_options}

        .. versionadded:: 1.1.0

        .. versionchanged:: 1.4.0 Zstandard support.

    {storage_options}

        .. versionadded:: 1.2.0

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    writer : StataWriter instance
        The StataWriter instance has a write_file method, which will
        write the file to the given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          or datetime.datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    >>> data = pd.DataFrame([[1.0, 1]], columns=['a', 'b'])
    >>> writer = StataWriter('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file
    >>> compression = {{"method": "zip", "archive_name": "data_file.dta"}}
    >>> writer = StataWriter('./data_file.zip', data, compression=compression)
    >>> writer.write_file()

    Save a DataFrame with dates
    >>> from datetime import datetime
    >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
    >>> writer = StataWriter('./date_data_file.dta', data, {{'date' : 'tw'}})
    >>> writer.write_file()
    """

    # Longest fixed-width string representable in dta format 114;
    # subclasses for newer formats override this.
    _max_string_length = 244
    # dta 114 stores all text in latin-1; newer-format writers override
    # this with utf-8.
    _encoding: Literal["latin-1", "utf-8"] = "latin-1"

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime.datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
        *,
        value_labels: dict[Hashable, dict[float, str]] | None = None,
    ) -> None:
        super().__init__()
        self.data = data
        self._convert_dates = {} if convert_dates is None else convert_dates
        self._write_index = write_index
        self._time_stamp = time_stamp
        self._data_label = data_label
        self._variable_labels = variable_labels
        self._non_cat_value_labels = value_labels
        # Populated by _prepare_categoricals / _prepare_non_cat_value_labels.
        self._value_labels: list[StataValueLabel] = []
        self._has_value_labels = np.array([], dtype=bool)
        self._compression = compression
        self._output_file: IO[bytes] | None = None
        self._converted_names: dict[Hashable, str] = {}
        # attach nobs, nvars, data, varlist, typlist
        self._prepare_pandas(data)
        self.storage_options = storage_options

        if byteorder is None:
            byteorder = sys.byteorder
        self._byteorder = _set_endianness(byteorder)
        self._fname = fname
        self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}

    def _write(self, to_write: str) -> None:
        """
        Helper to call encode before writing to file for Python 3 compat.
        """
        self.handles.handle.write(to_write.encode(self._encoding))

    def _write_bytes(self, value: bytes) -> None:
        """
        Helper to assert file is open before writing.
        """
        self.handles.handle.write(value)

    def _prepare_non_cat_value_labels(
        self, data: DataFrame
    ) -> list[StataNonCatValueLabel]:
        """
        Check value labels provided for non-categorical columns and build
        a StataNonCatValueLabel for each.

        Raises KeyError if a labeled column is missing from the data and
        ValueError if it is not numeric.
        """
        non_cat_value_labels: list[StataNonCatValueLabel] = []
        if self._non_cat_value_labels is None:
            return non_cat_value_labels

        for labname, labels in self._non_cat_value_labels.items():
            # The user may refer to a column by its pre-sanitization name;
            # _check_column_names records the mapping in _converted_names.
            if labname in self._converted_names:
                colname = self._converted_names[labname]
            elif labname in data.columns:
                colname = str(labname)
            else:
                raise KeyError(
                    f"Can't create value labels for {labname}, it wasn't "
                    "found in the dataset."
                )

            if not is_numeric_dtype(data[colname].dtype):
                # Labels should not be passed explicitly for categorical
                # columns that will be converted to int
                raise ValueError(
                    f"Can't create value labels for {labname}, value labels "
                    "can only be applied to numeric columns."
                )
            svl = StataNonCatValueLabel(colname, labels, self._encoding)
            non_cat_value_labels.append(svl)
        return non_cat_value_labels

    def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
        """
        Check for categorical columns, retain categorical information for
        Stata file and convert categorical data to int
        """
        is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
        if not any(is_cat):
            return data

        # Every categorical column gets a value-label set.
        self._has_value_labels |= np.array(is_cat)

        get_base_missing_value = StataMissingValue.get_base_missing_value
        data_formatted = []
        for col, col_is_cat in zip(data, is_cat):
            if col_is_cat:
                svl = StataValueLabel(data[col], encoding=self._encoding)
                self._value_labels.append(svl)
                dtype = data[col].cat.codes.dtype
                if dtype == np.int64:
                    raise ValueError(
                        "It is not possible to export "
                        "int64-based categorical data to Stata."
                    )
                # Copy so mutating missing codes below can't alias the
                # original categorical's codes.
                values = data[col].cat.codes._values.copy()

                # Upcast if needed so that correct missing values can be set
                if values.max() >= get_base_missing_value(dtype):
                    if dtype == np.int8:
                        dtype = np.dtype(np.int16)
                    elif dtype == np.int16:
                        dtype = np.dtype(np.int32)
                    else:
                        dtype = np.dtype(np.float64)
                    values = np.array(values, dtype=dtype)

                # Replace missing values with Stata missing value for type
                values[values == -1] = get_base_missing_value(dtype)
                data_formatted.append((col, values))
            else:
                data_formatted.append((col, data[col]))
        return DataFrame.from_dict(dict(data_formatted))

    def _replace_nans(self, data: DataFrame) -> DataFrame:
        """
        Checks floating point data columns for nans, and replaces these with
        the generic Stata for missing value (.)
        """
        for c in data:
            dtype = data[c].dtype
            if dtype in (np.float32, np.float64):
                if dtype == np.float32:
                    replacement = self.MISSING_VALUES["f"]
                else:
                    replacement = self.MISSING_VALUES["d"]
                data[c] = data[c].fillna(replacement)

        return data

    def _update_strl_names(self) -> None:
        """No-op, forward compatibility"""
        pass

    def _validate_variable_name(self, name: str) -> str:
        """
        Validate variable names for Stata export.

        Parameters
        ----------
        name : str
            Variable name

        Returns
        -------
        str
            The validated name with invalid characters replaced with
            underscores.

        Notes
        -----
        Stata 114 and 117 support ascii characters in a-z, A-Z, 0-9
        and _.
        """
        for c in name:
            if (
                (c < "A" or c > "Z")
                and (c < "a" or c > "z")
                and (c < "0" or c > "9")
                and c != "_"
            ):
                name = name.replace(c, "_")
        return name

    def _check_column_names(self, data: DataFrame) -> DataFrame:
        """
        Checks column names to ensure that they are valid Stata column names.
        This includes checks for:
            * Non-string names
            * Stata keywords
            * Variables that start with numbers
            * Variables with names that are too long

        When an illegal variable name is detected, it is converted, and if
        dates are exported, the variable name is propagated to the date
        conversion dictionary
        """
        converted_names: dict[Hashable, str] = {}
        columns = list(data.columns)
        original_columns = columns[:]

        duplicate_var_id = 0
        for j, name in enumerate(columns):
            orig_name = name
            if not isinstance(name, str):
                name = str(name)

            name = self._validate_variable_name(name)

            # Variable name must not be a reserved word
            if name in self.RESERVED_WORDS:
                name = "_" + name

            # Variable name may not start with a number
            if "0" <= name[0] <= "9":
                name = "_" + name

            # Stata caps variable names at 32 characters.
            name = name[: min(len(name), 32)]

            if not name == orig_name:
                # check for duplicates
                while columns.count(name) > 0:
                    # prepend ascending number to avoid duplicates
                    name = "_" + str(duplicate_var_id) + name
                    name = name[: min(len(name), 32)]
                    duplicate_var_id += 1
                converted_names[orig_name] = name

            columns[j] = name

        data.columns = Index(columns)

        # Check date conversion, and fix key if needed
        if self._convert_dates:
            for c, o in zip(columns, original_columns):
                if c != o:
                    self._convert_dates[c] = self._convert_dates[o]
                    del self._convert_dates[o]

        if converted_names:
            # Warn the user once, listing every renamed column.
            conversion_warning = []
            for orig_name, name in converted_names.items():
                msg = f"{orig_name}   ->   {name}"
                conversion_warning.append(msg)

            ws = invalid_name_doc.format("\n    ".join(conversion_warning))
            warnings.warn(
                ws,
                InvalidColumnName,
                stacklevel=find_stack_level(),
            )

        self._converted_names = converted_names
        self._update_strl_names()

        return data

    def _set_formats_and_types(self, dtypes: Series) -> None:
        """
        Populate ``fmtlist`` (display formats) and ``typlist`` (dta type
        codes), one entry per column, in column order.
        """
        self.fmtlist: list[str] = []
        self.typlist: list[int] = []
        for col, dtype in dtypes.items():
            self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col]))
            self.typlist.append(_dtype_to_stata_type(dtype, self.data[col]))

    def _prepare_pandas(self, data: DataFrame) -> None:
        """
        Normalize the input frame into the writer's internal state:
        sanitized column names, Stata-compatible dtypes, missing values
        replaced, categoricals encoded, and format/type lists built.
        """
        # NOTE: we might need a different API / class for pandas objects so
        # we can set different semantics - handle this with a PR to pandas.io

        data = data.copy()

        if self._write_index:
            temp = data.reset_index()
            if isinstance(temp, DataFrame):
                data = temp

        # Ensure column names are strings
        data = self._check_column_names(data)

        # Check columns for compatibility with stata, upcast if necessary
        # Raise if outside the supported range
        data = _cast_to_stata_types(data)

        # Replace NaNs with Stata missing values
        data = self._replace_nans(data)

        # Set all columns to initially unlabelled
        self._has_value_labels = np.repeat(False, data.shape[1])

        # Create value labels for non-categorical data
        non_cat_value_labels = self._prepare_non_cat_value_labels(data)

        non_cat_columns = [svl.labname for svl in non_cat_value_labels]
        has_non_cat_val_labels = data.columns.isin(non_cat_columns)
        self._has_value_labels |= has_non_cat_val_labels
        self._value_labels.extend(non_cat_value_labels)

        # Convert categoricals to int data, and strip labels
        data = self._prepare_categoricals(data)

        self.nobs, self.nvar = data.shape
        self.data = data
        self.varlist = data.columns.tolist()

        dtypes = data.dtypes

        # Ensure all date columns are converted
        for col in data:
            if col in self._convert_dates:
                continue
            if is_datetime64_dtype(data[col]):
                self._convert_dates[col] = "tc"

        self._convert_dates = _maybe_convert_to_int_keys(
            self._convert_dates, self.varlist
        )
        for key in self._convert_dates:
            new_type = _convert_datetime_to_stata_type(self._convert_dates[key])
            dtypes[key] = np.dtype(new_type)

        # Verify object arrays are strings and encode to bytes
        self._encode_strings()

        self._set_formats_and_types(dtypes)

        # set the given format for the datetime cols
        if self._convert_dates is not None:
            for key in self._convert_dates:
                if isinstance(key, int):
                    self.fmtlist[key] = self._convert_dates[key]

    def _encode_strings(self) -> None:
        """
        Encode strings in dta-specific encoding

        Do not encode columns marked for date conversion or for strL
        conversion. The strL converter independently handles conversion and
        also accepts empty string arrays.
        """
        convert_dates = self._convert_dates
        # _convert_strl is not available in dta 114
        convert_strl = getattr(self, "_convert_strl", [])
        for i, col in enumerate(self.data):
            # Skip columns marked for date conversion or strl conversion
            if i in convert_dates or col in convert_strl:
                continue
            column = self.data[col]
            dtype = column.dtype
            if dtype.type is np.object_:
                inferred_dtype = infer_dtype(column, skipna=True)
                if not ((inferred_dtype == "string") or len(column) == 0):
                    col = column.name
                    raise ValueError(
                        f"""\
Column `{col}` cannot be exported.\n\nOnly string-like object arrays
containing all strings or a mix of strings and None can be exported.
Object arrays containing only null values are prohibited. Other object
types cannot be exported and must first be converted to one of the
supported types."""
                    )
                encoded = self.data[col].str.encode(self._encoding)
                # If larger than _max_string_length do nothing
                if (
                    max_len_string_array(ensure_object(encoded._values))
                    <= self._max_string_length
                ):
                    self.data[col] = encoded

    def write_file(self) -> None:
        """
        Export DataFrame object to Stata dta format.
        """
        with get_handle(
            self._fname,
            "wb",
            compression=self._compression,
            is_text=False,
            storage_options=self.storage_options,
        ) as self.handles:

            if self.handles.compression["method"] is not None:
                # ZipFile creates a file (with the same name) for each write call.
                # Write it first into a buffer and then write the buffer to the ZipFile.
                self._output_file, self.handles.handle = self.handles.handle, BytesIO()
                self.handles.created_handles.append(self.handles.handle)

            try:
                # Sections are emitted in the order mandated by the dta spec.
                self._write_header(
                    data_label=self._data_label, time_stamp=self._time_stamp
                )
                self._write_map()
                self._write_variable_types()
                self._write_varnames()
                self._write_sortlist()
                self._write_formats()
                self._write_value_label_names()
                self._write_variable_labels()
                self._write_expansion_fields()
                self._write_characteristics()
                records = self._prepare_data()
                self._write_data(records)
                self._write_strls()
                self._write_value_labels()
                self._write_file_close_tag()
                # Second pass fills in the real section offsets (117+ only).
                self._write_map()
                self._close()
            except Exception as exc:
                self.handles.close()
                # Best effort: remove the partially-written, invalid file.
                if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile(
                    self._fname
                ):
                    try:
                        os.unlink(self._fname)
                    except OSError:
                        warnings.warn(
                            f"This save was not successful but {self._fname} could not "
                            "be deleted. This file is not valid.",
                            ResourceWarning,
                            stacklevel=find_stack_level(),
                        )
                raise exc

    def _close(self) -> None:
        """
        Close the file if it was created by the writer.

        If a buffer or file-like object was passed in, for example a GzipFile,
        then leave this file open for the caller to close.
        """
        # write compression
        if self._output_file is not None:
            assert isinstance(self.handles.handle, BytesIO)
            bio, self.handles.handle = self.handles.handle, self._output_file
            self.handles.handle.write(bio.getvalue())

    def _write_map(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_file_close_tag(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_characteristics(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_strls(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_expansion_fields(self) -> None:
        """Write 5 zeros for expansion fields"""
        self._write(_pad_bytes("", 5))

    def _write_value_labels(self) -> None:
        """Emit the binary value-label table for every labeled column."""
        for vl in self._value_labels:
            self._write_bytes(vl.generate_value_label(self._byteorder))

    def _write_header(
        self,
        data_label: str | None = None,
        time_stamp: datetime.datetime | None = None,
    ) -> None:
        """
        Write the fixed-size dta 114 header: format version, byte order,
        variable/observation counts, data label and timestamp.
        """
        byteorder = self._byteorder
        # ds_format - just use 114
        self._write_bytes(struct.pack("b", 114))
        # byteorder
        self._write(byteorder == ">" and "\x01" or "\x02")
        # filetype
        self._write("\x01")
        # unused
        self._write("\x00")
        # number of vars, 2 bytes
        self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2])
        # number of obs, 4 bytes
        self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4])
        # data label 81 bytes, char, null terminated
        if data_label is None:
            self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80)))
        else:
            self._write_bytes(
                self._null_terminate_bytes(_pad_bytes(data_label[:80], 80))
            )
        # time stamp, 18 bytes, char, null terminated
        # format dd Mon yyyy hh:mm
        if time_stamp is None:
            time_stamp = datetime.datetime.now()
        elif not isinstance(time_stamp, datetime.datetime):
            raise ValueError("time_stamp should be datetime type")
        # GH #13856
        # Avoid locale-specific month conversion
        months = [
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ]
        month_lookup = {i + 1: month for i, month in enumerate(months)}
        ts = (
            time_stamp.strftime("%d ")
            + month_lookup[time_stamp.month]
            + time_stamp.strftime(" %Y %H:%M")
        )
        self._write_bytes(self._null_terminate_bytes(ts))

    def _write_variable_types(self) -> None:
        """Write one dta type-code byte per variable."""
        for typ in self.typlist:
            self._write_bytes(struct.pack("B", typ))

    def _write_varnames(self) -> None:
        """Write the 33-byte null-terminated name of each variable."""
        # varlist names are checked by _check_column_names
        # varlist, requires null terminated
        for name in self.varlist:
            name = self._null_terminate_str(name)
            name = _pad_bytes(name[:32], 33)
            self._write(name)

    def _write_sortlist(self) -> None:
        """Write an all-zero sort list (no sort order recorded)."""
        # srtlist, 2*(nvar+1), int array, encoded by byteorder
        srtlist = _pad_bytes("", 2 * (self.nvar + 1))
        self._write(srtlist)

    def _write_formats(self) -> None:
        """Write the 49-byte display format of each variable."""
        # fmtlist, 49*nvar, char array
        for fmt in self.fmtlist:
            self._write(_pad_bytes(fmt, 49))

    def _write_value_label_names(self) -> None:
        """Write the value-label-set name (or blank) for each variable."""
        # lbllist, 33*nvar, char array
        for i in range(self.nvar):
            # Use variable name when categorical
            if self._has_value_labels[i]:
                name = self.varlist[i]
                name = self._null_terminate_str(name)
                name = _pad_bytes(name[:32], 33)
                self._write(name)
            else:  # Default is empty label
                self._write(_pad_bytes("", 33))

    def _write_variable_labels(self) -> None:
        """Write each variable's 81-byte label, blank when unlabeled."""
        # Missing labels are 80 blank characters plus null termination
        blank = _pad_bytes("", 81)

        if self._variable_labels is None:
            for i in range(self.nvar):
                self._write(blank)
            return

        for col in self.data:
            if col in self._variable_labels:
                label = self._variable_labels[col]
                if len(label) > 80:
                    raise ValueError("Variable labels must be 80 characters or fewer")
                # dta 114 text must be representable in latin-1.
                is_latin1 = all(ord(c) < 256 for c in label)
                if not is_latin1:
                    raise ValueError(
                        "Variable labels must contain only characters that "
                        "can be encoded in Latin-1"
                    )
                self._write(_pad_bytes(label, 81))
            else:
                self._write(blank)

    def _convert_strls(self, data: DataFrame) -> DataFrame:
        """No-op, future compatibility"""
        return data

    def _prepare_data(self) -> np.recarray:
        """
        Convert the prepared frame into a record array ready to be dumped
        byte-for-byte into the data section of the file.
        """
        data = self.data
        typlist = self.typlist
        convert_dates = self._convert_dates
        # 1. Convert dates
        if self._convert_dates is not None:
            for i, col in enumerate(data):
                if i in convert_dates:
                    data[col] = _datetime_to_stata_elapsed_vec(
                        data[col], self.fmtlist[i]
                    )
        # 2. Convert strls
        data = self._convert_strls(data)

        # 3. Convert bad string data to '' and pad to correct length
        dtypes = {}
        # Numeric columns only need byte-swapping when the requested order
        # differs from the machine's native order.
        native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
        for i, col in enumerate(data):
            typ = typlist[i]
            if typ <= self._max_string_length:
                data[col] = data[col].fillna("").apply(_pad_bytes, args=(typ,))
                stype = f"S{typ}"
                dtypes[col] = stype
                data[col] = data[col].astype(stype)
            else:
                dtype = data[col].dtype
                if not native_byteorder:
                    dtype = dtype.newbyteorder(self._byteorder)
                dtypes[col] = dtype

        return data.to_records(index=False, column_dtypes=dtypes)

    def _write_data(self, records: np.recarray) -> None:
        """Write the record array's raw bytes as the data section."""
        self._write_bytes(records.tobytes())

    @staticmethod
    def _null_terminate_str(s: str) -> str:
        """Append a NUL terminator to ``s``."""
        s += "\x00"
        return s

    def _null_terminate_bytes(self, s: str) -> bytes:
        """NUL-terminate ``s`` and encode it in the writer's encoding."""
        return self._null_terminate_str(s).encode(self._encoding)
2880def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int:
2881 """
2882 Converts dtype types to stata types. Returns the byte of the given ordinal.
2883 See TYPE_MAP and comments for an explanation. This is also explained in
2884 the dta spec.
2885 1 - 2045 are strings of this length
2886 Pandas Stata
2887 32768 - for object strL
2888 65526 - for int8 byte
2889 65527 - for int16 int
2890 65528 - for int32 long
2891 65529 - for float32 float
2892 65530 - for double double
2894 If there are dates to convert, then dtype will already have the correct
2895 type inserted.
2896 """
2897 # TODO: expand to handle datetime to integer conversion
2898 if force_strl:
2899 return 32768
2900 if dtype.type is np.object_: # try to coerce it to the biggest string
2901 # not memory efficient, what else could we
2902 # do?
2903 itemsize = max_len_string_array(ensure_object(column._values))
2904 itemsize = max(itemsize, 1)
2905 if itemsize <= 2045:
2906 return itemsize
2907 return 32768
2908 elif dtype.type is np.float64:
2909 return 65526
2910 elif dtype.type is np.float32:
2911 return 65527
2912 elif dtype.type is np.int32:
2913 return 65528
2914 elif dtype.type is np.int16:
2915 return 65529
2916 elif dtype.type is np.int8:
2917 return 65530
2918 else: # pragma : no cover
2919 raise NotImplementedError(f"Data type {dtype} not supported.")
2922def _pad_bytes_new(name: str | bytes, length: int) -> bytes:
2923 """
2924 Takes a bytes instance and pads it with null bytes until it's length chars.
2925 """
2926 if isinstance(name, str):
2927 name = bytes(name, "utf-8")
2928 return name + b"\x00" * (length - len(name))
2931class StataStrLWriter:
2932 """
2933 Converter for Stata StrLs
2935 Stata StrLs map 8 byte values to strings which are stored using a
2936 dictionary-like format where strings are keyed to two values.
2938 Parameters
2939 ----------
2940 df : DataFrame
2941 DataFrame to convert
2942 columns : Sequence[str]
2943 List of columns names to convert to StrL
2944 version : int, optional
2945 dta version. Currently supports 117, 118 and 119
2946 byteorder : str, optional
2947 Can be ">", "<", "little", or "big". default is `sys.byteorder`
2949 Notes
2950 -----
2951 Supports creation of the StrL block of a dta file for dta versions
2952 117, 118 and 119. These differ in how the GSO is stored. 118 and
2953 119 store the GSO lookup value as a uint32 and a uint64, while 117
2954 uses two uint32s. 118 and 119 also encode all strings as unicode
2955 which is required by the format. 117 uses 'latin-1' a fixed width
2956 encoding that extends the 7-bit ascii table with an additional 128
2957 characters.
2958 """
2960 def __init__(
2961 self,
2962 df: DataFrame,
2963 columns: Sequence[str],
2964 version: int = 117,
2965 byteorder: str | None = None,
2966 ) -> None:
2967 if version not in (117, 118, 119):
2968 raise ValueError("Only dta versions 117, 118 and 119 supported")
2969 self._dta_ver = version
2971 self.df = df
2972 self.columns = columns
2973 self._gso_table = {"": (0, 0)}
2974 if byteorder is None:
2975 byteorder = sys.byteorder
2976 self._byteorder = _set_endianness(byteorder)
2978 gso_v_type = "I" # uint32
2979 gso_o_type = "Q" # uint64
2980 self._encoding = "utf-8"
2981 if version == 117:
2982 o_size = 4
2983 gso_o_type = "I" # 117 used uint32
2984 self._encoding = "latin-1"
2985 elif version == 118:
2986 o_size = 6
2987 else: # version == 119
2988 o_size = 5
2989 self._o_offet = 2 ** (8 * (8 - o_size))
2990 self._gso_o_type = gso_o_type
2991 self._gso_v_type = gso_v_type
2993 def _convert_key(self, key: tuple[int, int]) -> int:
2994 v, o = key
2995 return v + self._o_offet * o
def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]:
    """
    Build the GSO lookup table for the strl columns of the DataFrame.

    Returns
    -------
    gso_table : dict
        Ordered dictionary mapping each distinct string to its (v, o)
        lookup position.
    gso_df : DataFrame
        The DataFrame with strl columns replaced by their encoded
        (v, o) values, stored as uint64.

    Notes
    -----
    Mutates ``self.df`` in place.

    Each (v, o) pair is packed via ``self._convert_key`` so that v sits
    in the low bits and o in the high bits; the split depends on the dta
    version (o_size is 4 for 117, 6 for 118 and 5 for 119).
    """
    table = self._gso_table
    frame = self.df
    all_columns = list(frame.columns)
    strl_columns = self.columns
    # v is the column's position in the *full* frame, not in the strl subset
    frame_positions = [all_columns.index(name) for name in strl_columns]
    subset = frame[strl_columns]
    encoded = np.empty(subset.shape, dtype=np.uint64)
    for row_num, (_, row) in enumerate(subset.iterrows()):
        for out_idx, (name, frame_pos) in enumerate(
            zip(strl_columns, frame_positions)
        ):
            value = row[name]
            # Columns may mix str and None (GH 23633); map None to ""
            if value is None:
                value = ""
            if value in table:
                vo = table[value]
            else:
                # Stata prefers human (1-based) positions
                vo = (frame_pos + 1, row_num + 1)
                table[value] = vo
            encoded[row_num, out_idx] = self._convert_key(vo)
    for out_idx, name in enumerate(strl_columns):
        frame[name] = encoded[:, out_idx]

    return table, frame
def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes:
    """
    Build the binary blob of GSOs placed between the strl tags of a dta file.

    Parameters
    ----------
    gso_table : dict
        Ordered dictionary mapping string -> (v, o) position.

    Returns
    -------
    bytes
        Binary content to be written between the strl tags.

    Notes
    -----
    Layout depends on the dta version: 117 stores v and o as two uint32s,
    while 118/119 use a uint32 for v and a uint64 for o::

        117:      GSO vvvv oooo     t llll xxx...x\\0
        118/119:  GSO vvvv oooooooo t llll xxx...x\\0

    where llll counts the terminating null byte.
    """
    buffer = BytesIO()
    # "t" field byte (130) and the trailing null terminator
    type_byte = struct.pack(self._byteorder + "B", 130)
    terminator = struct.pack(self._byteorder + "B", 0)
    # Precompile the per-entry packers; formats are loop-invariant
    pack_v = struct.Struct(self._byteorder + self._gso_v_type).pack
    pack_o = struct.Struct(self._byteorder + self._gso_o_type).pack
    pack_len = struct.Struct(self._byteorder + "I").pack
    for text, (v, o) in gso_table.items():
        # (0, 0) is the sentinel entry and is never written out
        if (v, o) == (0, 0):
            continue
        payload = text.encode("utf-8")

        # GSO
        buffer.write(b"GSO")
        # vvvv
        buffer.write(pack_v(v))
        # oooo / oooooooo
        buffer.write(pack_o(o))
        # t
        buffer.write(type_byte)
        # llll -- includes the null terminator
        buffer.write(pack_len(len(payload) + 1))
        # xxx...xxx plus null term
        buffer.write(payload)
        buffer.write(terminator)

    return buffer.getvalue()
class StataWriter117(StataWriter):
    """
    A class for writing Stata binary dta files in Stata 13 format (117)

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() functions. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information
    write_index : bool
        Write the index to Stata dataset.
    byteorder : str
        Can be ">", "<", "little", or "big". default is `sys.byteorder`
    time_stamp : datetime
        A datetime to use as file creation date. Default is the current time
    data_label : str
        A label for the data set. Must be 80 characters or smaller.
    variable_labels : dict
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    convert_strl : list
        List of columns names to convert to Stata StrL format. Columns with
        more than 2045 characters are automatically written as StrL.
        Smaller columns can be converted by including the column name. Using
        StrLs can reduce output file size when strings are longer than 8
        characters, and either frequently repeated or sparse.
    {compression_options}

        .. versionadded:: 1.1.0

        .. versionchanged:: 1.4.0 Zstandard support.

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    writer : StataWriter117 instance
        The StataWriter117 instance has a write_file method, which will
        write the file to the given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          or datetime.datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    >>> data = pd.DataFrame([[1.0, 1, 'a']], columns=['a', 'b', 'c'])
    >>> writer = pd.io.stata.StataWriter117('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file
    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
    >>> writer = pd.io.stata.StataWriter117(
    ...     './data_file.zip', data, compression=compression
    ... )
    >>> writer.write_file()

    Or with long strings stored in strl format
    >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
    ...                     columns=['strls'])
    >>> writer = pd.io.stata.StataWriter117(
    ...     './data_file_with_long_strings.dta', data, convert_strl=['strls'])
    >>> writer.write_file()
    """

    _max_string_length = 2045
    _dta_version = 117

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime.datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        convert_strl: Sequence[Hashable] | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
        *,
        value_labels: dict[Hashable, dict[float, str]] | None = None,
    ) -> None:
        # Copy to new list since convert_strl might be modified later
        self._convert_strl: list[Hashable] = []
        if convert_strl is not None:
            self._convert_strl.extend(convert_strl)

        super().__init__(
            fname,
            data,
            convert_dates,
            write_index,
            byteorder=byteorder,
            time_stamp=time_stamp,
            data_label=data_label,
            variable_labels=variable_labels,
            value_labels=value_labels,
            compression=compression,
            storage_options=storage_options,
        )
        # Maps section tag name -> byte offset in the file; see _write_map
        self._map: dict[str, int] = {}
        # Binary strl payload produced by _convert_strls
        self._strl_blob = b""

    @staticmethod
    def _tag(val: str | bytes, tag: str) -> bytes:
        """Surround val with <tag></tag>"""
        if isinstance(val, str):
            val = bytes(val, "utf-8")
        return bytes("<" + tag + ">", "utf-8") + val + bytes("</" + tag + ">", "utf-8")

    def _update_map(self, tag: str) -> None:
        """Update map location for tag with file position"""
        assert self.handles.handle is not None
        self._map[tag] = self.handles.handle.tell()

    def _write_header(
        self,
        data_label: str | None = None,
        time_stamp: datetime.datetime | None = None,
    ) -> None:
        """Write the file header"""
        byteorder = self._byteorder
        self._write_bytes(bytes("<stata_dta>", "utf-8"))
        bio = BytesIO()
        # ds_format - 117
        bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release"))
        # byteorder: MSF (big-endian) or LSF (little-endian)
        bio.write(self._tag("MSF" if byteorder == ">" else "LSF", "byteorder"))
        # number of vars, 2 bytes in 117 and 118, 4 byte in 119
        nvar_type = "H" if self._dta_version <= 118 else "I"
        bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K"))
        # 117 uses 4 bytes, 118 uses 8
        nobs_size = "I" if self._dta_version == 117 else "Q"
        bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N"))
        # data label 81 bytes, char, null terminated
        label = data_label[:80] if data_label is not None else ""
        encoded_label = label.encode(self._encoding)
        label_size = "B" if self._dta_version == 117 else "H"
        label_len = struct.pack(byteorder + label_size, len(encoded_label))
        encoded_label = label_len + encoded_label
        bio.write(self._tag(encoded_label, "label"))
        # time stamp, 18 bytes, char, null terminated
        # format dd Mon yyyy hh:mm
        if time_stamp is None:
            time_stamp = datetime.datetime.now()
        elif not isinstance(time_stamp, datetime.datetime):
            raise ValueError("time_stamp should be datetime type")
        # Avoid locale-specific month conversion
        months = [
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ]
        month_lookup = {i + 1: month for i, month in enumerate(months)}
        ts = (
            time_stamp.strftime("%d ")
            + month_lookup[time_stamp.month]
            + time_stamp.strftime(" %Y %H:%M")
        )
        # '\x11' added due to inspection of Stata file
        stata_ts = b"\x11" + bytes(ts, "utf-8")
        bio.write(self._tag(stata_ts, "timestamp"))
        self._write_bytes(self._tag(bio.getvalue(), "header"))

    def _write_map(self) -> None:
        """
        Called twice during file write. The first populates the values in
        the map with 0s. The second call writes the final map locations when
        all blocks have been written.
        """
        assert self.handles.handle is not None
        if not self._map:
            self._map = {
                "stata_data": 0,
                "map": self.handles.handle.tell(),
                "variable_types": 0,
                "varnames": 0,
                "sortlist": 0,
                "formats": 0,
                "value_label_names": 0,
                "variable_labels": 0,
                "characteristics": 0,
                "data": 0,
                "strls": 0,
                "value_labels": 0,
                "stata_data_close": 0,
                "end-of-file": 0,
            }
        # Move to start of map
        self.handles.handle.seek(self._map["map"])
        bio = BytesIO()
        for val in self._map.values():
            bio.write(struct.pack(self._byteorder + "Q", val))
        self._write_bytes(self._tag(bio.getvalue(), "map"))

    def _write_variable_types(self) -> None:
        """Write the <variable_types> block, one uint16 type code per column."""
        self._update_map("variable_types")
        bio = BytesIO()
        for typ in self.typlist:
            bio.write(struct.pack(self._byteorder + "H", typ))
        self._write_bytes(self._tag(bio.getvalue(), "variable_types"))

    def _write_varnames(self) -> None:
        """Write the <varnames> block of null-terminated, padded names."""
        self._update_map("varnames")
        bio = BytesIO()
        # 118 scales by 4 to accommodate utf-8 data worst case encoding
        vn_len = 32 if self._dta_version == 117 else 128
        for name in self.varlist:
            name = self._null_terminate_str(name)
            name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1)
            bio.write(name)
        self._write_bytes(self._tag(bio.getvalue(), "varnames"))

    def _write_sortlist(self) -> None:
        """Write an all-zero <sortlist> block (no sort order recorded)."""
        self._update_map("sortlist")
        sort_size = 2 if self._dta_version < 119 else 4
        self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist"))

    def _write_formats(self) -> None:
        """Write the <formats> block of padded display formats."""
        self._update_map("formats")
        bio = BytesIO()
        fmt_len = 49 if self._dta_version == 117 else 57
        for fmt in self.fmtlist:
            bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len))
        self._write_bytes(self._tag(bio.getvalue(), "formats"))

    def _write_value_label_names(self) -> None:
        """Write the <value_label_names> block; labeled columns use their own name."""
        self._update_map("value_label_names")
        bio = BytesIO()
        # 118 scales by 4 to accommodate utf-8 data worst case encoding
        vl_len = 32 if self._dta_version == 117 else 128
        for i in range(self.nvar):
            # Use variable name when categorical
            name = ""  # default name
            if self._has_value_labels[i]:
                name = self.varlist[i]
            name = self._null_terminate_str(name)
            encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
            bio.write(encoded_name)
        self._write_bytes(self._tag(bio.getvalue(), "value_label_names"))

    def _write_variable_labels(self) -> None:
        """
        Write the <variable_labels> block.

        Labels longer than 80 characters or not encodable in the file
        encoding raise ValueError; columns without a label are written
        as blank (padded) entries.
        """
        # Missing labels are 80 blank characters plus null termination
        self._update_map("variable_labels")
        bio = BytesIO()
        # 118 scales by 4 to accommodate utf-8 data worst case encoding
        vl_len = 80 if self._dta_version == 117 else 320
        blank = _pad_bytes_new("", vl_len + 1)

        if self._variable_labels is None:
            for _ in range(self.nvar):
                bio.write(blank)
            self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
            return

        for col in self.data:
            if col in self._variable_labels:
                label = self._variable_labels[col]
                if len(label) > 80:
                    raise ValueError("Variable labels must be 80 characters or fewer")
                try:
                    encoded = label.encode(self._encoding)
                except UnicodeEncodeError as err:
                    raise ValueError(
                        "Variable labels must contain only characters that "
                        f"can be encoded in {self._encoding}"
                    ) from err

                bio.write(_pad_bytes_new(encoded, vl_len + 1))
            else:
                bio.write(blank)
        self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))

    def _write_characteristics(self) -> None:
        """Write an empty <characteristics> block."""
        self._update_map("characteristics")
        self._write_bytes(self._tag(b"", "characteristics"))

    def _write_data(self, records) -> None:
        """Write the <data> block from the prepared records array."""
        self._update_map("data")
        self._write_bytes(b"<data>")
        self._write_bytes(records.tobytes())
        self._write_bytes(b"</data>")

    def _write_strls(self) -> None:
        """Write the <strls> block containing the GSO blob (may be empty)."""
        self._update_map("strls")
        self._write_bytes(self._tag(self._strl_blob, "strls"))

    def _write_expansion_fields(self) -> None:
        """No-op in dta 117+"""
        pass

    def _write_value_labels(self) -> None:
        """Write the <value_labels> block; each label set is wrapped in <lbl>."""
        self._update_map("value_labels")
        bio = BytesIO()
        for vl in self._value_labels:
            lab = vl.generate_value_label(self._byteorder)
            lab = self._tag(lab, "lbl")
            bio.write(lab)
        self._write_bytes(self._tag(bio.getvalue(), "value_labels"))

    def _write_file_close_tag(self) -> None:
        """Write the closing </stata_dta> tag and record the end-of-file offset."""
        self._update_map("stata_data_close")
        self._write_bytes(bytes("</stata_dta>", "utf-8"))
        self._update_map("end-of-file")

    def _update_strl_names(self) -> None:
        """
        Update column names for conversion to strl if they might have been
        changed to comply with Stata naming rules
        """
        # Update convert_strl if names changed
        for orig, new in self._converted_names.items():
            if orig in self._convert_strl:
                idx = self._convert_strl.index(orig)
                self._convert_strl[idx] = new

    def _convert_strls(self, data: DataFrame) -> DataFrame:
        """
        Convert columns to StrLs if either very large or in the
        convert_strl variable
        """
        # 32768 is the strl type code assigned by _dtype_to_stata_type_117
        convert_cols = [
            col
            for i, col in enumerate(data)
            if self.typlist[i] == 32768 or col in self._convert_strl
        ]

        if convert_cols:
            ssw = StataStrLWriter(data, convert_cols, version=self._dta_version)
            tab, new_data = ssw.generate_table()
            data = new_data
            self._strl_blob = ssw.generate_blob(tab)
        return data

    def _set_formats_and_types(self, dtypes: Series) -> None:
        """Populate fmtlist/typlist from column dtypes, forcing strl where requested."""
        self.typlist = []
        self.fmtlist = []
        for col, dtype in dtypes.items():
            force_strl = col in self._convert_strl
            fmt = _dtype_to_default_stata_fmt(
                dtype,
                self.data[col],
                dta_version=self._dta_version,
                force_strl=force_strl,
            )
            self.fmtlist.append(fmt)
            self.typlist.append(
                _dtype_to_stata_type_117(dtype, self.data[col], force_strl)
            )
class StataWriterUTF8(StataWriter117):
    """
    Stata binary dta file writing in Stata 15 (118) and 16 (119) formats

    DTA 118 and 119 format files support unicode string data (both fixed
    and strL) format. Unicode is also supported in value labels, variable
    labels and the dataset label. Format 119 is automatically used if the
    file contains more than 32,767 variables.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() functions. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict, default None
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information
    write_index : bool, default True
        Write the index to Stata dataset.
    byteorder : str, default None
        Can be ">", "<", "little", or "big". default is `sys.byteorder`
    time_stamp : datetime, default None
        A datetime to use as file creation date. Default is the current time
    data_label : str, default None
        A label for the data set. Must be 80 characters or smaller.
    variable_labels : dict, default None
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    convert_strl : list, default None
        List of columns names to convert to Stata StrL format. Columns with
        more than 2045 characters are automatically written as StrL.
        Smaller columns can be converted by including the column name. Using
        StrLs can reduce output file size when strings are longer than 8
        characters, and either frequently repeated or sparse.
    version : int, default None
        The dta version to use. By default, uses the size of data to determine
        the version. 118 is used if data.shape[1] <= 32767, and 119 is used
        for storing larger DataFrames.
    {compression_options}

        .. versionadded:: 1.1.0

        .. versionchanged:: 1.4.0 Zstandard support.

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    StataWriterUTF8
        The instance has a write_file method, which will write the file to the
        given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          or datetime.datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    Using Unicode data and column names

    >>> from pandas.io.stata import StataWriterUTF8
    >>> data = pd.DataFrame([[1.0, 1, 'ᴬ']], columns=['a', 'β', 'ĉ'])
    >>> writer = StataWriterUTF8('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file
    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
    >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
    >>> writer.write_file()

    Or with long strings stored in strl format

    >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
    ...                     columns=['strls'])
    >>> writer = StataWriterUTF8('./data_file_with_long_strings.dta', data,
    ...                          convert_strl=['strls'])
    >>> writer.write_file()
    """

    _encoding: Literal["utf-8"] = "utf-8"

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime.datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        convert_strl: Sequence[Hashable] | None = None,
        version: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
        *,
        value_labels: dict[Hashable, dict[float, str]] | None = None,
    ) -> None:
        if version is None:
            # dta 118 caps the variable count at 32,767; use 119 beyond that
            version = 118 if data.shape[1] <= 32767 else 119
        elif version not in (118, 119):
            raise ValueError("version must be either 118 or 119.")
        elif version == 118 and data.shape[1] > 32767:
            raise ValueError(
                "You must use version 119 for data sets containing more than "
                "32,767 variables"
            )

        super().__init__(
            fname,
            data,
            convert_dates=convert_dates,
            write_index=write_index,
            byteorder=byteorder,
            time_stamp=time_stamp,
            data_label=data_label,
            variable_labels=variable_labels,
            value_labels=value_labels,
            convert_strl=convert_strl,
            compression=compression,
            storage_options=storage_options,
        )
        # Override version set in StataWriter117 init
        self._dta_version = version

    def _validate_variable_name(self, name: str) -> str:
        """
        Validate variable names for Stata export.

        Parameters
        ----------
        name : str
            Variable name

        Returns
        -------
        str
            The validated name with invalid characters replaced with
            underscores.

        Notes
        -----
        Stata 118+ support most unicode characters. The only limitation is in
        the ascii range where the characters supported are a-z, A-Z, 0-9 and _.
        """
        # High code points appear to be acceptable; reject ASCII outside
        # [A-Za-z0-9_], the 128-191 range, and the multiply/divide signs
        for c in name:
            if (
                (
                    ord(c) < 128
                    and (c < "A" or c > "Z")
                    and (c < "a" or c > "z")
                    and (c < "0" or c > "9")
                    and c != "_"
                )
                or 128 <= ord(c) < 192
                or c in {"×", "÷"}
            ):
                name = name.replace(c, "_")

        return name