Source code for dataio.save

"""Datapackage save module of dataio utility."""

import json
import logging
import os
from logging import getLogger
from pathlib import Path
import requests

import h5py
import numpy as np
import pandas as pd
import yaml
from uuid import uuid4
from io import StringIO

from dataio.tools import BonsaiBaseModel, BonsaiTableModel
from dataio.validate import validate_matrix, validate_table
from dataio.schemas.bonsai_api import DataResource

from .set_logger import set_logger

logger = getLogger("root")

SUPPORTED_TABLE_FILE_EXTENSIONS = [".parquet", ".xlsx", ".xls", ".csv", ".pkl"]

SUPPORTED_DICT_FILE_EXTENSIONS = [".json", ".yaml"]

SUPPORTED_MATRIX_FILE_EXTENSIONS = [".hdf5", ".h5"]


def save_dict(data, path: Path, append=False):
    """Save a dict or BonsaiBaseModel to a .yaml or .json file."""
    if isinstance(data, BonsaiBaseModel):
        data_dict = data.model_dump()
    elif isinstance(data, dict):
        data_dict = data
    else:
        raise TypeError("Data format not supported, use dict or BonsaiBaseModel")

    write_type = "w"
    if append:
        write_type = "w+"

    if path.suffix == ".yaml":
        with open(path, write_type) as file:
            yaml.dump(data_dict, file)
    elif path.suffix == ".json":
        with open(path, write_type) as file:
            json.dump(data_dict, file)

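# Example usage (illustrative sketch, not part of the module; file names and
# data below are assumptions):
#
#     from pathlib import Path
#     save_dict({"name": "demo", "version": "v1.0.0"}, Path("metadata.yaml"))
#     save_dict({"name": "demo", "version": "v1.0.0"}, Path("metadata.json"))
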
def save_table(data, path: Path, append=False):
    """Save a DataFrame or BonsaiTableModel to a tabular file format."""
    if isinstance(data, BonsaiTableModel):
        df = data.to_pandas()
    elif isinstance(data, pd.DataFrame):
        df = data
    else:
        raise TypeError("Data format not supported, use DataFrame or BonsaiTableModel")

    mode = "w"
    if append:
        mode = "a"

    if path.suffix == ".parquet":
        if append:
            logger.error("Append is not supported for .parquet files")
        else:
            df.to_parquet(path)
    elif path.suffix == ".xlsx" or path.suffix == ".xls":
        with pd.ExcelWriter(path, mode=mode) as writer:
            df.to_excel(writer)
    elif path.suffix == ".csv":
        df.to_csv(path, mode=mode, index=False, date_format="%Y-%m-%d %H:%M:%SZ")
    elif path.suffix == ".pkl":
        if append:
            logger.error("Append is not supported for .pkl files")
        else:
            df.to_pickle(path)

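# Example usage (illustrative sketch; paths and data are assumptions):
#
#     import pandas as pd
#     from pathlib import Path
#     df = pd.DataFrame({"value": [1.0, 2.0], "unit": ["kg", "kg"]})
#     save_table(df, Path("table.csv"))               # write a new file
#     save_table(df, Path("table.csv"), append=True)  # append further rows
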
def save_matrix(data: pd.DataFrame, name: str, path: Path, append=False):
    # Save a DataFrame to an HDF5 file, preserving index and columns
    mode = "w"
    if append:
        mode = "a"
    data.to_hdf(path, key=name, mode=mode)

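# Example usage (illustrative sketch; keys and file names are assumptions,
# and writing HDF5 requires the PyTables dependency of pandas):
#
#     import pandas as pd
#     from pathlib import Path
#     A = pd.DataFrame([[1.0, 0.0], [0.5, 1.0]], index=["a", "b"], columns=["a", "b"])
#     save_matrix(A, name="A", path=Path("matrices.h5"))
#     save_matrix(A * 2, name="B", path=Path("matrices.h5"), append=True)  # second key, same file
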
def save(data, name: str, path: Path, schema=None, overwrite=True):
    if not schema:
        return old_save(data, path)

    if str(path).startswith("http"):
        # Saving to an API url is not handled here; use save_to_api() for
        # resources that expose an api_endpoint.
        raise NotImplementedError(
            "Saving to a URL is not supported by save(); use save_to_api()."
        )
    else:
        if not path.parent.exists():
            path.parent.mkdir(parents=True)
        if path.exists() and not overwrite:
            # The target file exists and must not be overwritten, so append
            if path.suffix in SUPPORTED_DICT_FILE_EXTENSIONS:
                save_dict(data, path, True)
            elif path.suffix in SUPPORTED_TABLE_FILE_EXTENSIONS:
                # validate data before it is written
                validate_table(data, schema=schema)
                save_table(data, path, True)
            elif path.suffix in SUPPORTED_MATRIX_FILE_EXTENSIONS:
                validate_matrix(data, schema=schema)
                save_matrix(data, name, path, True)
        else:
            # Write (or overwrite) the target file
            if path.suffix in SUPPORTED_DICT_FILE_EXTENSIONS:
                save_dict(data, path, False)
            elif path.suffix in SUPPORTED_TABLE_FILE_EXTENSIONS:
                # validate data before it is written
                validate_table(data, schema=schema)
                save_table(data, path, False)
            elif path.suffix in SUPPORTED_MATRIX_FILE_EXTENSIONS:
                validate_matrix(data, schema=schema)
                save_matrix(data, name, path)

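# Example usage (hedged sketch): save() dispatches on the file suffix and
# validates against a schema before writing. `my_schema` stands in for a
# dataio schema object and is an assumption; without a schema, save() falls
# back to old_save().
#
#     from pathlib import Path
#     save(df, name="use", path=Path("data/use.csv"), schema=my_schema)
#     save(A, name="A", path=Path("data/matrices.h5"), schema=my_schema)
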
def old_save(
    datapackage,
    root_path: str = ".",
    increment: str = None,
    overwrite: bool = False,
    create_path: bool = False,
    log_name: str = None,
):
    """Save datapackage tables and dataio.yaml metadata file.

    Parameters
    ----------
    datapackage: DataPackage
        dataio datapackage
    root_path : str
        path to root of database
    increment : str
        semantic level to increment, in [None, 'patch', 'minor', 'major']
    overwrite : bool
        whether to overwrite
    create_path : bool
        whether to create path
    log_name : str
        name of log file, if None no log is set
    """
    logger.info("Started datapackage save")

    metadata = datapackage.__metadata__
    full_path = Path(root_path).joinpath(metadata["path"])

    # open log file
    if log_name is not None:
        set_logger(filename=log_name, path=full_path.parent, overwrite=overwrite)
        logger.info("Started dataio save log file")
    else:
        logger.info("No new log file initialized")

    # auto-increment version number
    if increment is None:
        logger.info("No version number auto-increment")
    else:
        logger.info("Attempting to auto-increment version number")
        if "version" not in metadata.keys():
            logger.warning("No version field in metadata, so no increment")
        else:
            try:
                version_list = version_str2list(metadata["version"])
            except ValueError:
                logger.warning(
                    "Metadata version is not in semantic format, "
                    f"so no auto-increment: {metadata['version']}"
                )
            else:
                semantic_level_list = ["patch", "minor", "major"]
                semantic_level_pos = [2, 1, 0]
                if increment in semantic_level_list:
                    pos = semantic_level_list.index(increment)
                    version_list[semantic_level_pos[pos]] += 1
                    version_str = version_list2str(version_list)
                    metadata["version"] = version_str
                    logger.info(
                        f"Given semantic_level '{increment}', the "
                        f"new version number is '{version_str}'"
                    )
                else:
                    logger.warning(
                        "Unrecognized semantic level, no increment: "
                        f"{increment}. Should be in ['patch', 'minor', 'major']"
                    )

    # check path exists and overwrite options
    if not os.path.exists(full_path):
        if create_path:
            os.makedirs(full_path)
            logger.info(f"Path {full_path} created")
        else:
            logger.error(
                f"Path {full_path} does not exist and 'create_path' "
                "option is disabled"
            )
            raise FileNotFoundError
    if len(os.listdir(full_path)) > 0:
        if not overwrite:
            logger.error(
                f"Path {full_path} is not empty and 'overwrite' "
                "option is disabled"
            )
            raise FileExistsError

    # export csvs
    for pos, table in enumerate(metadata["tables"]):
        delimiter = ","
        quotechar = '"'
        if "dialect" in table.keys():
            if "csv" in table["dialect"].keys():
                if "delimiter" in table["dialect"]["csv"].keys():
                    delimiter = table["dialect"]["csv"]["delimiter"]
                if "quoteChar" in table["dialect"]["csv"].keys():
                    quotechar = table["dialect"]["csv"]["quoteChar"]
                if "skipInitialSpace" in table["dialect"]["csv"].keys():
                    if table["dialect"]["csv"]["skipInitialSpace"]:
                        logger.warning(
                            f"Initial space skip in table {table['name']} metadata "
                            "originally True and set to False"
                        )
                        table["dialect"]["csv"]["skipInitialSpace"] = False
                        metadata["tables"][pos] = table
        csv_path = full_path.joinpath(table["path"])
        df = datapackage.__dict__[table["name"]]
        df["id"] = df.index
        df.insert(0, "id", df.pop("id"))
        df.to_csv(csv_path, index=False, sep=delimiter, quotechar=quotechar)
        logger.info(f"Exported table {table['name']} to {csv_path}")

    # export metadata
    meta_path = full_path.joinpath(f"{metadata['name']}.dataio.yaml")
    try:
        with open(meta_path, "w") as f:
            yaml.safe_dump(metadata, f)
    except FileNotFoundError:
        logger.error(f"File '{meta_path}' could not be exported to output path")
    logger.info(f"Exported metadata to {meta_path}")

    logger.info("Finished datapackage save")

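# Example usage (hedged sketch): `my_datapackage` stands in for a dataio
# DataPackage object built elsewhere; the arguments below are illustrative.
#
#     old_save(my_datapackage, root_path="data", increment="patch",
#              overwrite=True, create_path=True, log_name="save.log")
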
def version_str2list(version_str):
    """Convert semantic version string 'vMAJOR.MINOR.PATCH' to list."""
    version_list = version_str.split(".")
    version_list[0] = version_list[0][1:]
    version_list = [int(version_level) for version_level in version_list]
    return version_list

def version_list2str(version_list):
    """Convert semantic version list to string 'vMAJOR.MINOR.PATCH'."""
    str_level = [str(version_level) for version_level in version_list]
    version_str = ".".join(str_level)
    return "v" + version_str

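# Example: round-tripping a semantic version with the two helpers above.
#
#     version_str2list("v1.2.3")   # -> [1, 2, 3]
#     version_list2str([1, 2, 4])  # -> "v1.2.4"
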
def save_to_api(data: pd.DataFrame, resource: DataResource, schema=None, overwrite=True):
    """
    Save the given DataFrame to resource.api_endpoint via a single JSON POST.

    The JSON body is a list of row objects:

        [ {...}, {...} ]

    so that multiple rows can be created in one request.

    Parameters
    ----------
    data : pd.DataFrame
        The data to be sent. Each row becomes one dict. A 'version' column
        holding resource.id is added to the DataFrame before sending.
    resource : DataResource
        Must have a non-empty 'api_endpoint'.
    schema : optional
        Currently unused; reserved for validating 'data' before sending.
    overwrite : bool
        Passed to the API as the 'overwrite' query parameter.

    Raises
    ------
    ValueError
        If 'resource.api_endpoint' is missing or if the POST fails.
    """
    if not resource.api_endpoint:
        raise ValueError(
            f"Resource '{resource.name}' has no api_endpoint. Cannot save via API."
        )

    # assign the resource id to the version column (note: this mutates the
    # caller's DataFrame)
    data["version"] = resource.id

    payload = data.to_dict(orient="records")

    # overwrite is passed as a query parameter
    params = {"overwrite": "true" if overwrite else "false"}

    try:
        response = requests.post(resource.api_endpoint, json=payload, params=params)
        response.raise_for_status()  # raises HTTPError on 4XX or 5XX
    except requests.RequestException as exc:
        raise ValueError(
            f"Failed to save data to API endpoint '{resource.api_endpoint}': {exc}"
        ) from exc

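# Example usage (hedged sketch): the endpoint URL and the DataResource fields
# below are illustrative assumptions; the actual DataResource model may
# require additional fields.
#
#     resource = DataResource(name="demo", api_endpoint="https://example.org/api/demo/")
#     save_to_api(df, resource, overwrite=True)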