PySpark DataFrames Practice Questions with Answers

PySpark DataFrame Practice Questions

PySpark DataFrames provide a powerful and user-friendly API for working with structured and semi-structured data. In this article, we present a set of practice questions to help you reinforce your understanding of PySpark DataFrames and their operations.

1. Loading Data

Load the "sales_data.csv" file into a PySpark DataFrame. The CSV file contains the following columns: "transaction_id", "customer_id", "product_name", "quantity", and "price". Ensure that the DataFrame correctly infers the schema and displays the first 5 rows.

from pyspark.sql import SparkSession

# Initialize SparkSession
spark = SparkSession.builder.appName("DataFramePractice").getOrCreate()

# Load the CSV file into a DataFrame
df = spark.read.csv("sales_data.csv", header=True, inferSchema=True)

# Display the first 5 rows of the DataFrame
df.show(5)

2. Filtering Data

Filter the DataFrame to show only the transactions where the "quantity" is greater than or equal to 10 and the "price" is less than 50.

filtered_df = df.filter((df["quantity"] >= 10) & (df["price"] < 50))
filtered_df.show()

3. Grouping and Aggregating Data

Find the total revenue generated by each product and display the results in descending order.

from pyspark.sql import functions as F

# Multiply Column objects, not the column-name strings
revenue_df = df.groupBy("product_name").agg(F.sum(F.col("quantity") * F.col("price")).alias("total_revenue"))
revenue_df = revenue_df.orderBy(F.desc("total_revenue"))
revenue_df.show()

4. Joining DataFrames

Load the "customer_data.csv" file into another DataFrame. The CSV file contains the following columns: "customer_id", "customer_name", and "email". Perform an inner join between the "df" DataFrame and the "customer_df" DataFrame based on the "customer_id" column and display the results.

customer_df = spark.read.csv("customer_data.csv", header=True, inferSchema=True)

joined_df = df.join(customer_df, on="customer_id", how="inner")
joined_df.show()

5. Data Transformation

Transform the DataFrame to add a new column "total_amount" that represents the total amount for each transaction (quantity * price).

from pyspark.sql import functions as F

# Add a new column "total_amount"
df = df.withColumn("total_amount", df["quantity"] * df["price"])

6. Handling Missing Values

Count the number of missing values in each column of the DataFrame and display the results.

missing_values_df = df.select([F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c) for c in df.columns])
missing_values_df.show()

7. Data Visualization

Visualize the distribution of the "quantity" column using a histogram.

import matplotlib.pyplot as plt

# Convert DataFrame column to Pandas Series
quantity_series = df.select("quantity").toPandas()["quantity"]

# Plot histogram
plt.hist(quantity_series, bins=20, edgecolor='black')
plt.title("Distribution of Quantity")
plt.show()

8. Working with Dates

Assuming the DataFrame also has a "transaction_date" column stored as a string in yyyy-MM-dd format, convert it to a DateType and extract the year from it into a new column "transaction_year".

from pyspark.sql.functions import year, to_date

# Convert "transaction_date" column to DateType
df = df.withColumn("transaction_date", to_date("transaction_date", "yyyy-MM-dd"))

# Extract year and create "transaction_year" column
df = df.withColumn("transaction_year", year("transaction_date"))

9. Data Aggregation and Window Functions

For each product, use a window function to calculate a rolling average of quantity over each transaction and the two transactions before it (that is, the last three transactions up to that point).

from pyspark.sql import Window

# Define a Window specification
window_spec = Window.partitionBy("product_name").orderBy(F.desc("transaction_date")).rowsBetween(0, 2)

# Calculate average quantity for the last three transactions
df = df.withColumn("avg_quantity_last_three", F.avg("quantity").over(window_spec))

10. Pivot Table

Create a pivot table that shows the total quantity of each product for each year.

# Pivot table
pivot_table = df.groupBy("product_name").pivot("transaction_year").agg(F.sum("quantity"))

11. String Manipulation

Create a new column "upper_product_name" that contains the product names in uppercase.

# Uppercase product names
df = df.withColumn("upper_product_name", F.upper("product_name"))

12. User-Defined Functions (UDFs)

Create a UDF that calculates the total amount for each transaction and apply it to the DataFrame to add a new column "total_amount_udf".

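from pyspark.sql.types import DoubleType

# User-Defined Function
def calculate_total_amount(quantity, price):
    # Pass nulls through; cast to float so the value matches DoubleType
    if quantity is None or price is None:
        return None
    return float(quantity * price)

# Register UDF with an explicit return type (the default is StringType)
spark.udf.register("calculate_total_amount_udf", calculate_total_amount, DoubleType())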

# Apply UDF to create "total_amount_udf" column
df = df.withColumn("total_amount_udf", F.expr("calculate_total_amount_udf(quantity, price)"))

13. Joins and Aggregations

Join the "df" DataFrame with the "customer_df" DataFrame using the "customer_id" column. Then, find the total revenue generated by each customer and display the results in descending order.

joined_df = df.join(customer_df, on="customer_id", how="inner")

# Multiply Column objects, not the column-name strings
revenue_by_customer = joined_df.groupBy("customer_name").agg(F.sum(F.col("quantity") * F.col("price")).alias("total_revenue"))
revenue_by_customer = revenue_by_customer.orderBy(F.desc("total_revenue"))
revenue_by_customer.show()

14. Filtering and Date Manipulation

Filter the DataFrame to show only the transactions that occurred after a specific date, and calculate the total revenue for each day.

from pyspark.sql.functions import col

# Define the specific date
specific_date = "2023-08-01"

# Filter transactions after the specific date
filtered_df = df.filter(col("transaction_date") > specific_date)

# Calculate total revenue for each day
daily_revenue = filtered_df.groupBy("transaction_date").agg(F.sum(col("quantity") * col("price")).alias("total_revenue"))

15. Working with Arrays

Create a new column "product_list" that contains an array of product names for each transaction.

# Working with Arrays
df = df.withColumn("product_list", F.array("product_name"))

16. Window Functions and Ranking

Rank the customers based on their total revenue generated, and show the top 5 customers.

# Rank customers based on total revenue
ranked_customers = revenue_by_customer.withColumn("rank", F.rank().over(Window.orderBy(F.desc("total_revenue"))))

# Show top 5 customers
top_5_customers = ranked_customers.filter(F.col("rank") <= 5)
top_5_customers.show()

17. Data Deduplication

Remove duplicate rows from the DataFrame based on all columns and display the deduplicated DataFrame.

# Deduplicate rows based on all columns
deduplicated_df = df.dropDuplicates()
deduplicated_df.show()

18. Data Sampling

Take a random sample of 10% of the DataFrame and display the sampled data.

# Random sampling of 10%
sampled_df = df.sample(withReplacement=False, fraction=0.1)
sampled_df.show()

19. Data Reshaping

Assuming a wide-format DataFrame that has the columns "product_1", "product_2", and "product_3", melt it into a long format by unpivoting those columns into a single "product" column.

# Data Melt - Unpivot the columns
from pyspark.sql.functions import array, explode

melted_df = df.withColumn("product", explode(array("product_1", "product_2", "product_3"))) \
    .select("transaction_id", "customer_id", "product", "quantity", "price")

20. Handling Null Values

Replace the null values in the "product_name" column with a default value "Unknown".

# Replace null values in "product_name" column
df = df.fillna("Unknown", subset=["product_name"])


