Commit 7a8054ed authored by Erly Villaroel

Bug fixes

parent 76fb4c10
@@ -42,17 +42,18 @@ class Process:
         # Parsea los parámetros de entrada
         relation = relation_classname_identifier[self.descriptor["idScript"]]
         obj_script = globals()[relation](self.app)
         obj_script.parser(self.descriptor)
         # Iniciando process
         self.app.logger.info(f"Iniciando procesamiento de script")
         obj_script.process(source)
         # Guardando resultado
-        self.app.logger.info(f"Generado y guardando resultado")
+        self.app.logger.info(f"Generando resultados")
         response = obj_script.response()
-        # response.show()
+        if response["status"] != StatusEnum.OK.name:
+            raise RuntimeError(response["message"])
+        self.app.logger.info(f"Guardando resultados")
+        response = response["result"]
         result = self.utils.create_result(response, self.descriptor)
         save = self.utils.save_result(result, self.descriptor, db_session)
         if save["status"] == StatusEnum.ERROR.name:
...
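Side note: the added status check assumes the action's response() now returns an envelope rather than a bare result. A minimal sketch of that shape, inferred from this diff (the surrounding Process code is assumed):

    # Sketch only: envelope produced by MatchAndExcludeRecordsAction.process()
    # further down in this commit and consumed here by Process.
    response = obj_script.response()
    # success: {"status": "OK", "result": <pandas.DataFrame>}
    # failure: {"status": "TIMEOUT" or "ERROR", "message": <exception>}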
@@ -3,11 +3,11 @@ app:
   db_parameters:
     # BD Credentials
     type: 'mysql'
-    host: '192.168.1.37'
-    port: 13306
+    host: '192.168.0.11'
+    port: 3301
     user: root
     password: root
-    db: css_cuscatlan
+    db: cusca
     dialect: 'mysql+pymysql'
     # BD conexion configurations
     # https://docs.sqlalchemy.org/en/14/core/pooling.html
...
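For reference, db_parameters of this shape are typically assembled into a SQLAlchemy engine URL; a minimal sketch (the build_engine helper is hypothetical, not part of the repo):

    from sqlalchemy import create_engine

    def build_engine(params: dict):
        # dialect://user:password@host:port/db
        # e.g. mysql+pymysql://root:root@192.168.0.11:3301/cusca
        url = (f"{params['dialect']}://{params['user']}:{params['password']}"
               f"@{params['host']}:{params['port']}/{params['db']}")
        return create_engine(url)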
-from typing import Any, Dict, List
+from typing import Any, Dict
 import importlib.util
 import numpy as np
 import pandas as pd
+import subprocess
 import json
-from dask import dataframe as dd
-from numba import jit, types, typed
+import os
+import uuid
+from app.main.engine.enum.StatusEnum import StatusEnum
+from dpss import find_subset
 from wrapt_timeout_decorator import timeout
+from pyspark.sql.functions import sum, collect_list, round, when, col, lit, size, udf, array_except, array
+from pyspark.sql.types import ArrayType, StringType, IntegerType, LongType, List
 from app.main.engine.action.ActionInterface import ActionInterface
 
 # RELACION DE IDENTIFICADOR DE ACCION Y NOMBRE DE CLASE
@@ -14,27 +18,35 @@ relation_classname_identifier = {
     "match-and-exclude-records-actions": "MatchAndExcludeRecordsAction"
 }
 
+# CONFIGURACION DE SESION DE SPARK
+MASTER = "local[*]"
+DRIVER_MEMORY = "8g"
+EXECUTOR_MEMORY = "8g"
+MYSQL_JAR_PATH = "jars/mysql-connector-java-8.0.30.jar"
+
 # EXCLUDE VALIDATION FIELD
 EXCLUDE_ROWS_FIELD = "EXCLUDE_VALID"
 
 # REDONDEO DE DECIMALES
-ROUND_DECIMAL = 2
+ROUND_DECIMAL = 3
+FACTOR = 1000
 
 
 class MatchAndExcludeRecordsAction(ActionInterface):
 
     library_required = "pyspark"
     version_required = "3.4.0"
 
     def __init__(self, app) -> None:
         super().__init__(app)
         self.max_combinations = None
-        self.timeout = None
+        self.timeout = False
+        self.timeout_time = None
         self.exclude_pivot = None
         self.pivot_params = None
         self.ctp_params = None
         self.output = None
         self.config_params = ["max-records-per-combinations", "max-timeout-per-combinations", "exclude-entity-pivot"]
 
     def parser(self, descriptor: Dict[str, Any]):
         # Validar si pyspark y su versión está instalada
         pyspark_lib = importlib.util.find_spec(self.library_required)
@@ -43,7 +55,8 @@ class MatchAndExcludeRecordsAction(ActionInterface):
         import pyspark
         version = pyspark.__version__
         if version != self.version_required:
-            raise ImportError(f"Versión requerida no instalada. Requerida: {self.version_required}. Instalada: {version}")
+            raise ImportError(
+                f"Versión requerida no instalada. Requerida: {self.version_required}. Instalada: {version}")
 
         # Validación de parámetros de entrada
         entity_config_params = ["tablename", "id-column", "amount-column", "columns-group", "columns-transaction"]
@@ -69,209 +82,213 @@ class MatchAndExcludeRecordsAction(ActionInterface):
                 raise ReferenceError(f"Parámetro *{param}* no encontrado en pivot o contraparte")
 
         self.max_combinations = configs["max-records-per-combinations"]
-        self.timeout = configs["max-timeout-per-combinations"]
+        self.timeout_time = configs["max-timeout-per-combinations"]
         self.exclude_pivot = configs["exclude-entity-pivot"]
         self.pivot_params = pivot_params
         self.ctp_params = ctp_params
 
     def process(self, source_obs):
+        response = {"status": StatusEnum.ERROR.name}
         try:
-            @timeout(self.timeout)
+            @timeout(self.timeout_time)
             def __process(source_obj):
-                # Traer la data desde BD tanto pivot como contraparte
-                pivot_table, ctp_table = self.pivot_params["tablename"], self.ctp_params["tablename"]
-                dialect = source_obj.get_dialect()
-                pivot_df = dd.read_sql_table(pivot_table, dialect, index_col=self.pivot_params["id-column"],
-                                             npartitions=4)
-                pivot_df = pivot_df.reset_index()
-                ctp_df = dd.read_sql_table(ctp_table, dialect, index_col=self.ctp_params["id-column"], npartitions=4)
-                ctp_df = ctp_df.reset_index()
-
-                # Agregar un prefijo a cada columna, tanto del pivot como contraparte. Actualizar campos del input
-                # pivot: 'PIVOT_', contraparte: 'COUNTERPART_'
-                # Iterar sobre las columnas del DataFrame
-                for column in pivot_df.columns:
-                    if column == EXCLUDE_ROWS_FIELD:
-                        continue
-                    new_column_name = "PIVOT_" + column
-                    pivot_df = pivot_df.rename(columns={column: new_column_name})
-                for column in ctp_df.columns:
-                    if column == EXCLUDE_ROWS_FIELD:
-                        continue
-                    new_column_name = "COUNTERPART_" + column
-                    ctp_df = ctp_df.rename(columns={column: new_column_name})
-
-                for key_p, key_c in zip(self.pivot_params.keys(), self.ctp_params.keys()):
-                    if isinstance(self.pivot_params[key_p], str):
-                        self.pivot_params[key_p] = "PIVOT_"+self.pivot_params[key_p]
-                        self.ctp_params[key_c] = "COUNTERPART_"+self.ctp_params[key_c]
-                    else:
-                        self.pivot_params[key_p] = ["PIVOT_"+column for column in self.pivot_params[key_p]]
-                        self.ctp_params[key_c] = ["COUNTERPART_" + column for column in self.ctp_params[key_c]]
-
-                from pyspark.sql.functions import sum, collect_list, round, when, col, lit
-                pivot_cols = self.pivot_params["columns-transaction"].copy()
-                if self.pivot_params["amount-column"] in pivot_cols:
-                    pivot_cols.remove(self.pivot_params["amount-column"])
-
-                ctp_cols = self.ctp_params["columns-transaction"].copy()
-                if self.ctp_params["amount-column"] in ctp_cols:
-                    ctp_cols.remove(self.ctp_params["amount-column"])
-
-                max_combinations = self.max_combinations
-                # Ejecutamos lógica de excluir registros
-                if len(self.pivot_params["columns-group"]) == 0 and len(self.ctp_params["columns-group"]) == 0:
-                    raise RuntimeError(f"Debe haber al menos pivot o contraparte agrupado")
-
-                # Caso: 1 - Muchos
-                elif len(self.pivot_params["columns-group"]) == 0 and len(self.ctp_params["columns-group"]) > 0:
-                    ctp_df2 = ctp_df.groupby(self.ctp_params["columns-group"]).agg({
-                        self.ctp_params["amount-column"]: 'sum',  # Sumar la columna de cantidades
-                        self.ctp_params["id-column"]: list
-                    })
-                    ctp_df2 = ctp_df2.reset_index()
-
-                    pivot_df2 = pivot_df
-
-                # Caso: Muchos - 1
-                elif len(self.pivot_params["columns-group"]) > 0 and len(self.ctp_params["columns-group"]) == 0:
-                    pivot_df2 = pivot_df.groupby(self.pivot_params["columns-group"]).agg({
-                        self.pivot_params["amount-column"]: 'sum',
-                        self.pivot_params["id-column"]: list
-                    })
-                    pivot_df2 = pivot_df2.reset_index()
-
-                    ctp_df2 = ctp_df
-
-                # Caso: Muchos - Muchos
-                elif len(self.pivot_params["columns-group"]) > 0 and len(self.ctp_params["columns-group"]) > 0:
-                    pivot_df2 = pivot_df.groupby(self.pivot_params["columns-group"]).agg({
-                        self.pivot_params["amount-column"]: 'sum',
-                        self.pivot_params["id-column"]: list
-                    })
-                    pivot_df2 = pivot_df2.reset_index()
-                    ctp_df2 = ctp_df.groupby(self.ctp_params["columns-group"]).agg({
-                        self.ctp_params["amount-column"]: 'sum',  # Sumar la columna de cantidades
-                        self.ctp_params["id-column"]: list
-                    })
-                    ctp_df2 = ctp_df2.reset_index()
-
-                pivot_df2[self.pivot_params["amount-column"]] = pivot_df2[self.pivot_params["amount-column"]].round(
-                    ROUND_DECIMAL)
-                ctp_df2[self.ctp_params["amount-column"]] = ctp_df2[self.ctp_params["amount-column"]].round(
-                    ROUND_DECIMAL)
-
-                total_merged = pivot_df2.merge(ctp_df2, 'left', left_on=self.pivot_params["columns-transaction"],
-                                               right_on=self.ctp_params["columns-transaction"])
-                total_merged = total_merged.map_partitions(self.add_diff_column)
-
-                selected_columns = list(pivot_df2.columns) + ['DIFF']
-                total_merged = total_merged[selected_columns]
-
-                merged = total_merged.merge(ctp_df2, 'inner', left_on=pivot_cols, right_on=ctp_cols)
-                merged['DIFF'] = merged['DIFF'].where(merged['DIFF'].notnull(),
-                                                      merged[self.pivot_params["amount-column"]] - merged[
-                                                          self.ctp_params["amount-column"]])
-                if len(self.pivot_params["columns-group"]) == 0 and len(self.ctp_params["columns-group"]) > 0:
-                    merged = merged.drop_duplicates(subset=pivot_cols)
-                elif len(self.pivot_params["columns-group"]) > 0 and len(self.ctp_params["columns-group"]) == 0:
-                    merged = merged.drop_duplicates(subset=ctp_cols)
-
-                merged_df = merged.assign(DIFF=lambda partition: partition["DIFF"].round(ROUND_DECIMAL))
-
-                if self.exclude_pivot:
-                    df = pivot_df
-                    group_cols = self.pivot_params["columns-group"]
-                    amount_col = self.pivot_params["amount-column"]
-                    id_col = self.pivot_params["id-column"]
-                else:
-                    df = ctp_df
-                    group_cols = self.ctp_params["columns-group"]
-                    amount_col = self.ctp_params["amount-column"]
-                    id_col = self.ctp_params["id-column"]
-
-                total_tmp_cols = group_cols + ["DIFF"]
-                df3 = df.merge(merged_df[total_tmp_cols], 'inner', on=group_cols)
-                df3 = df3.compute()
-                total_cols = group_cols + [amount_col, id_col, EXCLUDE_ROWS_FIELD, "DIFF"]
-                resultado = df3.groupby(group_cols)[total_cols].apply(lambda x: custom_func(x, amount_col, id_col, max_combinations))
-                resultado = resultado.reset_index()
-                if len(resultado.columns) == 1:
-                    resultado = pd.DataFrame([], columns=group_cols + ["LISTA_DIFF"])
-                else:
-                    resultado.columns = group_cols + ["LISTA_DIFF"]
-                resultado = dd.from_pandas(resultado, npartitions=4)
-                meged2 = resultado.merge(merged_df, 'left', group_cols)
-                meged2 = meged2.map_partitions(lambda partition: partition.assign(
-                    LISTA_DIFF=partition['LISTA_DIFF'].apply(lambda x: [] if pd.isna(x) else x)), meta=meged2.dtypes.to_dict())
-                meged2 = meged2[
-                    (meged2['DIFF'] == 0) |
-                    ((meged2['DIFF'] != 0) & meged2['LISTA_DIFF'].apply(
-                        lambda x: True if not pd.isna(x) and ((isinstance(x, List) and len(x) > 0) or (isinstance(x, str) and len(x) > 2)) else False))
-                    ]
-                meged2 = meged2.compute()
-                if meged2.empty:
-                    pass
-                elif self.exclude_pivot:
-                    meged2['INTER_PIVOT_ID'] = meged2.apply(lambda row: self.array_except(row[self.pivot_params["id-column"]], row['LISTA_DIFF']), axis=1)
-                    meged2 = meged2.rename(columns={self.ctp_params["id-column"]: "INTER_CTP_ID"})
-                    if meged2['INTER_CTP_ID'].dtype == 'int64':
-                        meged2['INTER_CTP_ID'] = meged2['INTER_CTP_ID'].apply(lambda x: [x]).astype('object')
-                else:
-                    meged2['INTER_CTP_ID'] = meged2.apply(lambda row: self.array_except(row[self.ctp_params["id-column"]], row['LISTA_DIFF']), axis=1)
-                    meged2 = meged2.rename(columns={self.pivot_params["id-column"]: "INTER_PIVOT_ID"})
-                    if meged2['INTER_PIVOT_ID'].dtype == 'int64':
-                        meged2['INTER_PIVOT_ID'] = meged2['INTER_PIVOT_ID'].apply(lambda x: [x]).astype('object')
-                return meged2
-
-        except TimeoutError as e:
-            raise TimeoutError(f"Tiempo límite superado. {e}")
-
-        self.output = __process(source_obs)
+                try:
+                    # Inicializar la sesion de Spark
+                    session = self.createSession()
+                    # Traer la data desde BD tanto pivot como contraparte
+                    pivot_table, ctp_table = self.pivot_params["tablename"], self.ctp_params["tablename"]
+                    jdbc_conn = source_obj.create_spark_connection()
+                    jdbc_url = jdbc_conn["url"]
+                    jdbc_properties = jdbc_conn["properties"]
+                    pivot_df = session.read.jdbc(url=jdbc_url, table=pivot_table, properties=jdbc_properties)
+                    ctp_df = session.read.jdbc(url=jdbc_url, table=ctp_table, properties=jdbc_properties)
+
+                    # Agregar un prefijo a cada columna, tanto del pivot como contraparte. Actualizar campos del input
+                    # pivot: 'PIVOT_', contraparte: 'COUNTERPART_'
+                    for column in pivot_df.columns:
+                        if column == EXCLUDE_ROWS_FIELD:
+                            continue
+                        pivot_df = pivot_df.withColumnRenamed(column, "PIVOT_" + column)
+                    for column in ctp_df.columns:
+                        if column == EXCLUDE_ROWS_FIELD:
+                            continue
+                        ctp_df = ctp_df.withColumnRenamed(column, "COUNTERPART_" + column)
+
+                    for key_p, key_c in zip(self.pivot_params.keys(), self.ctp_params.keys()):
+                        if isinstance(self.pivot_params[key_p], str):
+                            self.pivot_params[key_p] = "PIVOT_" + self.pivot_params[key_p]
+                            self.ctp_params[key_c] = "COUNTERPART_" + self.ctp_params[key_c]
+                        else:
+                            self.pivot_params[key_p] = ["PIVOT_" + column for column in self.pivot_params[key_p]]
+                            self.ctp_params[key_c] = ["COUNTERPART_" + column for column in self.ctp_params[key_c]]
+
+                    from pyspark.sql.functions import sum, collect_list, round, when, col, lit
+                    pivot_cols = self.pivot_params["columns-transaction"].copy()
+                    if self.pivot_params["amount-column"] in pivot_cols:
+                        pivot_cols.remove(self.pivot_params["amount-column"])
+
+                    ctp_cols = self.ctp_params["columns-transaction"].copy()
+                    if self.ctp_params["amount-column"] in ctp_cols:
+                        ctp_cols.remove(self.ctp_params["amount-column"])
+
+                    max_combinations = self.max_combinations
+                    # Ejecutamos lógica de excluir registros
+                    if len(self.pivot_params["columns-group"]) == 0 and len(self.ctp_params["columns-group"]) == 0:
+                        raise RuntimeError(f"Debe haber al menos pivot o contraparte agrupado")
+                    # Caso: 1 - Muchos
+                    elif len(self.pivot_params["columns-group"]) == 0 and len(self.ctp_params["columns-group"]) > 0:
+                        ctp_df2 = ctp_df.groupby(self.ctp_params["columns-group"]). \
+                            agg(
+                            round(sum(self.ctp_params["amount-column"]), ROUND_DECIMAL).alias(self.ctp_params["amount-column"]),
+                            collect_list(self.ctp_params["id-column"]).alias(self.ctp_params["id-column"]))
+                        pivot_df2 = pivot_df
+                    # Caso: Muchos - 1
+                    elif len(self.pivot_params["columns-group"]) > 0 and len(self.ctp_params["columns-group"]) == 0:
+                        pivot_df2 = pivot_df.groupby(self.pivot_params["columns-group"]). \
+                            agg(round(sum(self.pivot_params["amount-column"]), ROUND_DECIMAL).alias(
+                                self.pivot_params["amount-column"]),
+                                collect_list(self.pivot_params["id-column"]).alias(self.pivot_params["id-column"]))
+                        ctp_df2 = ctp_df.limit(1)
+                    # Caso: Muchos - Muchos
+                    elif len(self.pivot_params["columns-group"]) > 0 and len(self.ctp_params["columns-group"]) > 0:
+                        pivot_df2 = pivot_df.groupby(self.pivot_params["columns-group"]). \
+                            agg(round(sum(self.pivot_params["amount-column"]), ROUND_DECIMAL).alias(
+                                self.pivot_params["amount-column"]),
+                                collect_list(self.pivot_params["id-column"]).alias(self.pivot_params["id-column"]))
+                        ctp_df2 = ctp_df.groupby(self.ctp_params["columns-group"]). \
+                            agg(
+                            round(sum(self.ctp_params["amount-column"]), ROUND_DECIMAL).alias(self.ctp_params["amount-column"]),
+                            collect_list(self.ctp_params["id-column"]).alias(self.ctp_params["id-column"]))
+
+                    condition = [pivot_df2[col1] == ctp_df2[col2] for col1, col2 in
+                                 zip(self.pivot_params["columns-transaction"],
+                                     self.ctp_params["columns-transaction"])]
+                    total_merged = pivot_df2.join(ctp_df2, condition, 'left')
+                    total_merged = total_merged.withColumn("DIFF",
+                                                           when(col(self.ctp_params["columns-transaction"][0]).isNotNull(),
+                                                                lit(0)).otherwise(lit(None)))
+                    total_merged = total_merged.select(*pivot_df2.columns, "DIFF")
+
+                    condition = [total_merged[col1] == ctp_df2[col2] for col1, col2 in zip(pivot_cols, ctp_cols)]
+                    merged = total_merged.join(ctp_df2, condition)
+                    merged = merged.withColumn("DIFF", when(col("DIFF").isNull(),
+                                                            total_merged[self.pivot_params["amount-column"]] - ctp_df2[
+                                                                self.ctp_params["amount-column"]]).otherwise(col("DIFF")))
+                    merged_df = merged.withColumn("DIFF", round(merged["DIFF"], ROUND_DECIMAL))
+
+                    if self.exclude_pivot:
+                        df = pivot_df
+                        group_cols = self.pivot_params["columns-group"]
+                        amount_col = self.pivot_params["amount-column"]
+                        id_col = self.pivot_params["id-column"]
+                    else:
+                        df = ctp_df
+                        group_cols = self.ctp_params["columns-group"]
+                        amount_col = self.ctp_params["amount-column"]
+                        id_col = self.ctp_params["id-column"]
+
+                    total_tmp_cols = group_cols + ["DIFF"]
+                    df3 = df.join(merged_df.select(*total_tmp_cols), group_cols)
+                    columns = [col(column) for column in group_cols]
+                    columns_amount = columns.copy()
+                    columns_amount.append(col(amount_col))
+                    custom = udf(custom_func_udf, ArrayType(IntegerType()))
+                    resultado = df3.groupby(*columns).agg(
+                        custom(collect_list(amount_col), collect_list(id_col), collect_list("DIFF"), collect_list(EXCLUDE_ROWS_FIELD),
+                               lit(max_combinations)).alias("LISTA_DIFF"))
+                    meged2 = resultado.join(merged_df, group_cols, 'left')
+                    handle_array_udf = udf(handle_array, ArrayType(IntegerType()))
+                    meged2 = meged2.withColumn("LISTA_DIFF", handle_array_udf("LISTA_DIFF"))
+                    meged2 = meged2.filter((col("DIFF") == 0) | ((col("DIFF") != 0) & (size(col("LISTA_DIFF")) > 0)))
+                    if self.exclude_pivot:
+                        meged2 = meged2.withColumn("INTER_PIVOT_ID", array_except(meged2[self.pivot_params["id-column"]],
+                                                                                  meged2["LISTA_DIFF"]))
+                        meged2 = meged2.withColumnRenamed(self.ctp_params["id-column"], "INTER_CTP_ID")
+                        if meged2.schema["INTER_CTP_ID"].dataType == LongType():
+                            meged2 = meged2.withColumn("INTER_CTP_ID",
+                                                       array(col("INTER_CTP_ID")).cast(ArrayType(LongType())))
+                    else:
+                        meged2 = meged2.withColumn("INTER_CTP_ID",
+                                                   array_except(meged2[self.ctp_params["id-column"]], meged2["LISTA_DIFF"]))
+                        meged2 = meged2.withColumnRenamed(self.pivot_params["id-column"], "INTER_PIVOT_ID")
+                        if meged2.schema["INTER_PIVOT_ID"].dataType == LongType():
+                            meged2 = meged2.withColumn("INTER_PIVOT_ID",
+                                                       array(col("INTER_PIVOT_ID")).cast(ArrayType(LongType())))
+                    meged2 = meged2.toPandas()
+                    return meged2
+
+                except Exception as e:
+                    self.timeout = True
+                    self.app.logger.error(f"Error de Timeout. Error: {e}")
+                    raise TimeoutError("Tiempo de ejecución superado.")
+
+            response["status"] = StatusEnum.OK.name
+            response["result"] = __process(source_obs)
+        except TimeoutError as e:
+            response["status"] = StatusEnum.TIMEOUT.name
+            response["message"] = e
+        except Exception as e:
+            response["status"] = StatusEnum.ERROR.name
+            response["message"] = e
+        finally:
+            self.output = response
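Side note on the timeout handling above: wrapt_timeout_decorator's timeout() raises TimeoutError once the wrapped call exceeds the limit in seconds, which is what the outer except branch maps to StatusEnum.TIMEOUT. A minimal sketch (the 5-second limit and slow_step function are made up; the engine passes configs["max-timeout-per-combinations"]):

    from wrapt_timeout_decorator import timeout

    @timeout(5)  # hypothetical 5-second limit
    def slow_step():
        ...  # raises TimeoutError if it runs longer than 5 seconds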
     def response(self):
         return self.output
 
-    def add_diff_column(self, partition):
-        partition['DIFF'] = np.where(partition[self.ctp_params["columns-transaction"][0]].notnull(), 0, np.nan)
-        return partition
-
-    def handle_array(self, x):
-        if isinstance(x, np.ndarray):
-            return x
-        else:
-            return []
+    def timeout_response(self):
+        return self.timeout
 
     def array_except(self, arr1, arr2):
-        # print(arr2)
         if arr2 is None:
             return arr1
-        elif not isinstance(arr2, List):
-            cadena_sin_corchetes = arr2.strip('[]')
-            partes = cadena_sin_corchetes.split()
-            # print(partes)
-            arr2 = [int(numero) for numero in partes]
-        arr1 = json.loads(arr1.replace(" ", ""))
-        return [item for item in arr1 if item not in arr2]
+        else:
+            return [item for item in arr1 if item not in arr2]
+
+    def createSession(self, name: str = "app_engine_spark"):
+        try:
+            from pyspark.sql import SparkSession
+            session = SparkSession.builder.master(MASTER) \
+                .appName(name) \
+                .config("spark.jars", MYSQL_JAR_PATH) \
+                .config("spark.executor.extraClassPath", MYSQL_JAR_PATH) \
+                .config("spark.driver.extraClassPath", MYSQL_JAR_PATH) \
+                .config("spark.driver.memory", DRIVER_MEMORY) \
+                .config("spark.executor.memory", EXECUTOR_MEMORY) \
+                .getOrCreate()
+            self.app.logger.info(f"Sesión creada exitosamente")
+            return session
+        except Exception as e:
+            raise Exception(f"Error creando sesion Spark. {e}")
+
+
+def handle_array(x):
+    if isinstance(x, List):
+        return x
+    else:
+        return []
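As a side note on the UDF wiring used in process() above, a minimal self-contained PySpark sketch of the same pattern (the pick_even function, column names, and the commented groupBy call are made up for illustration):

    from pyspark.sql.functions import collect_list, lit, udf
    from pyspark.sql.types import ArrayType, IntegerType

    def pick_even(values, limit):
        # hypothetical example: keep at most `limit` even values per group
        return [v for v in values if v % 2 == 0][:limit]

    pick_even_udf = udf(pick_even, ArrayType(IntegerType()))
    # df.groupBy("g").agg(pick_even_udf(collect_list("v"), lit(3)).alias("EVEN_VALUES"))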
 def custom_func(group, amount_field, id_field, max_combinations):
-    diff_value = group["DIFF"].values[0]
-    if np.isnan(diff_value):
-        return None
-    diff = int(diff_value*(10**ROUND_DECIMAL))
+    diff = int(group["DIFF"].values[0]*(10**ROUND_DECIMAL))
     if pd.isna(diff) or diff == 0:
         return None
     group = group[group[EXCLUDE_ROWS_FIELD] == 'S']
@@ -281,45 +298,67 @@ def custom_func(group, amount_field, id_field, max_combinations):
     values *= (10**ROUND_DECIMAL)
     values = values.astype(np.int64)
     ids = group[id_field].values
-    tam = len(values)
-    tam = tam if tam <= max_combinations else max_combinations
-    result = subset_sum_iter(values, diff, tam)
-    indices = ids[np.isin(values, result)]
+    result = find_subset(values, diff, max_combinations)
+    if len(result) < 1:
+        return None
+    result = result[0]
+    indices = []
+    for idx, val in zip(ids, values):
+        if val in result:
+            indices.append(idx)
+            result.remove(val)
     return indices
 
-
-@jit(nopython=False)
-def subset_sum_iter(numbers, target, num_elements):
-    # Initialize solutions list
-    final = typed.List.empty_list(types.int64)
-    for step in range(1, num_elements+1):
-        # Build first index by taking the first num_elements from the numbers
-        indices = list(range(step))
-        while True:
-            for i in range(step):
-                if indices[i] != i + len(numbers) - step:
-                    break
-            else:
-                # No combinations left
-                break
-            # Increase current index and all its following ones
-            indices[i] += 1
-            for j in range(i + 1, step):
-                indices[j] = indices[j - 1] + 1
-            # Check current solution
-            solution = typed.List.empty_list(types.int64)
-            for i in indices:
-                solution.append(numbers[i])
-            if round(sum(solution), ROUND_DECIMAL) == target:
-                final = solution
-                break
-        if len(final) > 0:
-            break
-    return final
+
+# def custom_func_udf(amount_values, id_values, diffs, max_combinations):
+#     diff = diffs[0]
+#     if pd.isna(diff) or diff == 0:
+#         return None
+#     diff = int(diff * FACTOR)
+#     amount_values = [int(value * FACTOR) for value in amount_values]
+#     result = find_subset(amount_values, diff, max_combinations)
+#     if len(result) < 1:
+#         return None
+#     result = result[0]
+#     indices = []
+#     for idx, val in zip(id_values, amount_values):
+#         if val in result:
+#             indices.append(idx)
+#             result.remove(val)
+#     return indices
+
+
+def custom_func_udf(amount_values, id_values, diffs, excludes, max_combinations):
+    diff = diffs[0]
+    if pd.isna(diff) or diff == 0:
+        return None
+    diff = int(diff * FACTOR)
+    amount_values = [int(value * FACTOR) for value, exclude in zip(amount_values, excludes) if exclude == "S"]
+    dir_name = str(uuid.uuid4())
+    prefix = "/tmp/" + dir_name + "_"
+    tmp_file_arr1, tmp_file_arr2 = "values.txt", "target.txt"
+    full_path_arr1, full_path_arr2 = prefix + tmp_file_arr1, prefix + tmp_file_arr2
+    with open(full_path_arr1, 'w') as archivo:
+        archivo.writelines([f'{entero}\n' for entero in amount_values])
+    with open(full_path_arr2, 'w') as archivo:
+        archivo.write(str(diff))
+    executable_path = '/home/evillarroel/Descargas/Telegram Desktop/subset_sum_linux'
+    indices = []
+    for comb in range(1, max_combinations+1):
+        argumentos = [full_path_arr1, full_path_arr2, str(comb), '1', '1', 'false', 'false']
+        result = subprocess.run([executable_path] + argumentos, check=True, capture_output=True, text=True)
+        result = str(result)
+        if "keys:[" in result:
+            match = result[result.index("keys:[") + 5:result.index("keys remainder") - 20]
+            match = match.replace("targets:", "").replace("+", ",")
+            match = match.split("==")[0].replace(" ", "")
+            match = json.loads(match)
+            for idx, val in zip(id_values, amount_values):
+                if val in match:
+                    indices.append(idx)
+                    match.remove(val)
+            break
+    os.remove(full_path_arr1), os.remove(full_path_arr2)
+    return indices
\ No newline at end of file
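For readers unfamiliar with the external subset_sum_linux helper invoked above: assuming it searches for a combination of at most max_combinations of the written values that sums exactly to the target (which is how custom_func_udf interprets its output), a rough pure-Python equivalent of that lookup would be:

    from itertools import combinations

    def find_matching_ids(amount_values, id_values, diff, max_combinations):
        # Try combination sizes 1..max_combinations and return the ids of the
        # first combination whose scaled amounts sum exactly to the scaled diff.
        for size in range(1, max_combinations + 1):
            for combo in combinations(zip(id_values, amount_values), size):
                if sum(v for _, v in combo) == diff:
                    return [i for i, _ in combo]
        return []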