summary refs log tree commit diff
path: root/nixos/tests/spark
diff options
context:
space:
mode:
authorillustris <rharikrishnan95@gmail.com>2021-09-17 22:31:01 +0530
committerillustris <rharikrishnan95@gmail.com>2021-09-17 22:40:06 +0530
commit13839b0022fee66a1291792c47f6bc2b71b91895 (patch)
tree8e3573de844269b13ca3f6c9d7b3e14712aab849 /nixos/tests/spark
parentdd987c2dbed988f573734f51f4f28c4c56f58b6b (diff)
downloadnixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.tar
nixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.tar.gz
nixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.tar.bz2
nixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.tar.lz
nixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.tar.xz
nixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.tar.zst
nixpkgs-13839b0022fee66a1291792c47f6bc2b71b91895.zip
nixos/spark: add test
Diffstat (limited to 'nixos/tests/spark')
-rw-r--r--nixos/tests/spark/default.nix28
-rw-r--r--nixos/tests/spark/spark_sample.py40
2 files changed, 68 insertions, 0 deletions
diff --git a/nixos/tests/spark/default.nix b/nixos/tests/spark/default.nix
new file mode 100644
index 00000000000..254cdec6e6b
--- /dev/null
+++ b/nixos/tests/spark/default.nix
@@ -0,0 +1,28 @@
+import ../make-test-python.nix ({...}: {
+  name = "spark";
+
+  nodes = {
+    worker = { nodes, pkgs, ... }: {
+      virtualisation.memorySize = 1024;
+      services.spark.worker = {
+        enable = true;
+        master = "master:7077";
+      };
+    };
+    master = { config, pkgs, ... }: {
+      services.spark.master = {
+        enable = true;
+        bind = "0.0.0.0";
+      };
+      networking.firewall.allowedTCPPorts = [ 22 7077 8080 ];
+    };
+  };
+
+  testScript = ''
+    master.wait_for_unit("spark-master.service")
+    worker.wait_for_unit("spark-worker.service")
+    worker.copy_from_host( "${./spark_sample.py}", "/spark_sample.py" )
+    assert "<title>Spark Master at spark://" in worker.succeed("curl -sSfkL http://master:8080/")
+    worker.succeed("spark-submit --master spark://master:7077 --executor-memory 512m --executor-cores 1 /spark_sample.py")
+  '';
+})
diff --git a/nixos/tests/spark/spark_sample.py b/nixos/tests/spark/spark_sample.py
new file mode 100644
index 00000000000..c4939451eae
--- /dev/null
+++ b/nixos/tests/spark/spark_sample.py
@@ -0,0 +1,40 @@
+from pyspark.sql import Row, SparkSession
+from pyspark.sql import functions as F
+from pyspark.sql.functions import udf
+from pyspark.sql.types import *
+from pyspark.sql.functions import explode
+
+def explode_col(weight):
+    return int(weight//10) * [10.0] + ([] if weight%10==0 else [weight%10])
+
+spark = SparkSession.builder.getOrCreate()
+
+dataSchema = [
+    StructField("feature_1", FloatType()),
+    StructField("feature_2", FloatType()),
+    StructField("bias_weight", FloatType())
+]
+
+data = [
+    Row(0.1, 0.2, 10.32),
+    Row(0.32, 1.43, 12.8),
+    Row(1.28, 1.12, 0.23)
+]
+
+df = spark.createDataFrame(spark.sparkContext.parallelize(data), StructType(dataSchema))
+
+normalizing_constant = 100
+sum_bias_weight = df.select(F.sum('bias_weight')).collect()[0][0]
+normalizing_factor = normalizing_constant / sum_bias_weight
+df = df.withColumn('normalized_bias_weight', df.bias_weight * normalizing_factor)
+df = df.drop('bias_weight')
+df = df.withColumnRenamed('normalized_bias_weight', 'bias_weight')
+
+my_udf = udf(lambda x: explode_col(x), ArrayType(FloatType()))
+df1 = df.withColumn('explode_val', my_udf(df.bias_weight))
+df1 = df1.withColumn("explode_val_1", explode(df1.explode_val)).drop("explode_val")
+df1 = df1.drop('bias_weight').withColumnRenamed('explode_val_1', 'bias_weight')
+
+df1.show()
+
+assert(df1.count() == 12)