Most of us have used systems that recommend new content based on our previous interactions and on what other users with similar interests are consuming. Amazon pioneered product recommendation with the now ubiquitous "Customers who bought this also bought" feature; recommendation is now a core business feature of most media platforms, for example Netflix, Spotify, YouTube and iTunes.
At the core of these technologies lie recommendation algorithms based on collaborative filtering.
To demonstrate how we can build a recommendation pipeline from the ground up, we will use the Audioscrobbler dataset. Audioscrobbler tracked users' music preferences and recommended new bands and songs to listen to. In this example, we build a system that allows us to ask questions like "I like Deep Purple and Led Zeppelin; which other band should I listen to?" and get answers like "Black Sabbath"!
We start our exploration by importing all the required modules:
import org.apache.spark.ml.recommendation.ALSModel
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.ml.recommendation.ALS
import scala.util.Random
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Matrix, Vectors}
The Audioscrobbler dataset comes in three files:

- artist_data: maps artist ids to artist names
- user_artist_data: for each (user_id, artist_id) pair, contains the number of plays
- artist_alias: maps artist_ids that resolve to the same artist (we will not use this)

We first load the play counts file as a text file and print its first 5 lines.
// Load the raw play count data
val rawPlayCounts = spark.sparkContext.
textFile("../datasets/audioscrobbler/user_artist_data.txt")
rawPlayCounts.take(5).foreach(println)
1000002 1 55
1000002 1000006 33
1000002 1000007 8
1000002 1000009 144
1000002 1000010 314
We then define a schema to convert the raw text file to a DataFrame, which is much easier to work with (and is also what the ALS algorithm expects as input).
// Define the play count table schema
val pcSchemaString = "user_id artist_id playcount"
val pcFields =
pcSchemaString.split(" ").
map(fieldName => StructField(fieldName, IntegerType, nullable = true))
val pcSchema = StructType(pcFields)
Now that we have a schema, we need to convert our raw data RDD into a DataFrame. The source RDD is essentially a list of file lines. To convert it, we split each line into an array of Strings, which we then turn into a Row of 3 items. Row is Spark SQL's equivalent of Scala's Tuple, in the sense that it contains an arbitrary number of items of various types.
To obtain the DataFrame, we combine our list of Rows with the schema we defined earlier. We also register the resulting DataFrame with the SQL context, which allows us to run SQL queries on it.
val rowRDD = rawPlayCounts.
map{_.split(' ')}.
map{x => Row(x(0).toInt, x(1).toInt, x(2).toInt)}
// Create a Spark data frame
val playcounts = spark.createDataFrame(rowRDD, pcSchema).cache()
playcounts.createOrReplaceTempView("playcounts")
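As an aside, the same DataFrame could be built in a single step by letting Spark's CSV reader do the line splitting; a minimal sketch, assuming a Spark 2.x session (playcountsAlt is a hypothetical alternative, not used below):
// Hypothetical alternative: parse the space-separated file directly
val playcountsAlt = spark.read.
  option("sep", " ").
  schema(pcSchema).
  csv("../datasets/audioscrobbler/user_artist_data.txt")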
Let's run an SQL query to verify that everything works. We select the top 5 users by total number of plays:
spark.sql("""select user_id, sum(playcount) as total_plays
from playcounts
group by user_id
order by total_plays
desc limit 5""").show
+-------+-----------+
|user_id|total_plays|
+-------+-----------+
|1059637|     674412|
|2064012|     548427|
|2069337|     393515|
|2023977|     285978|
|1046559|     183972|
+-------+-----------+
We load the artist names as well; while this is not strictly necessary for the recommendation algorithm to work, it is nicer to see band names rather than ids in the output.
The artist_data file is a bit more challenging to handle, as it contains errors: while it is tab-separated, sometimes the artist name is missing or the artist id cannot be parsed as an integer. This is a nice use case for Scala's flatMap and monads. In the parsing code, instead of returning concrete values, we wrap our return values in the Option monad: when we can actually return something, we return a Some, otherwise a None. flatMap will automatically ignore the None values and extract the contents of Some at the flattening stage.
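To see these mechanics in isolation before applying them to the real file, here is a tiny standalone sketch (the sample lines are made up for illustration):
// One good line, one with a missing name, one with a non-numeric id
val sample = Seq("1\tRadiohead", "2\t", "x\tMuse", "3\tMogwai")
val parsed = sample.flatMap { line =>
  val (id, name) = line.span(_ != '\t')
  if (name.trim.isEmpty) None                    // missing artist name
  else scala.util.Try(id.toInt).toOption.        // bad id becomes None
    map(i => (i, name.trim))
}
// parsed == List((1,"Radiohead"), (3,"Mogwai"))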
// Load the artist names as well
val rawArtists = spark.sparkContext.
textFile("../datasets/audioscrobbler/artist_data.txt")
val aSchemaString = "id artist"
val aSchema = StructType(Array(StructField("id", IntegerType, nullable=true),
StructField("artist", StringType, nullable=true)))
val rowRDD = rawArtists.flatMap{line =>
val (id, artist) = line.span( _ != '\t')
if (artist.isEmpty) {
None
} else {
try {
Some(Row(id.toInt, artist.trim))
} catch {
case _ : Exception => None
}
}
}
val artists = spark.createDataFrame(rowRDD, aSchema).cache()
artists.createOrReplaceTempView("artists")
artists.count
1848281
We now perform some basic exploration of the dataset to make sure the data makes sense.
We first perform the join using DataFrame API calls:
playcounts.join(artists, playcounts.col("artist_id") === artists.col("id")).
select("user_id", "artist", "playcount").
show(10)
+-------+----------+---------+
|user_id|    artist|playcount|
+-------+----------+---------+
|1000019|The Smiths|        1|
|1000020|The Smiths|      199|
|1000022|The Smiths|       20|
|1000033|The Smiths|      466|
|1000056|The Smiths|       10|
|1000067|The Smiths|       18|
|1000070|The Smiths|      399|
|1000073|The Smiths|        3|
|1000077|The Smiths|       15|
|1000082|The Smiths|       16|
+-------+----------+---------+
only showing top 10 rows
Then, we perform the same join using SQL:
spark.sql("""select pc.user_id, a.artist, pc.playcount
from playcounts pc join artists a on a.id = pc.artist_id
limit 10""").show
+-------+----------+---------+
|user_id|    artist|playcount|
+-------+----------+---------+
|1000019|The Smiths|        1|
|1000020|The Smiths|      199|
|1000022|The Smiths|       20|
|1000033|The Smiths|      466|
|1000056|The Smiths|       10|
|1000067|The Smiths|       18|
|1000070|The Smiths|      399|
|1000073|The Smiths|        3|
|1000077|The Smiths|       15|
|1000082|The Smiths|       16|
+-------+----------+---------+
What are the most played artists in our dataset?
spark.sql("""select a.artist, sum(pc.playcount) as all_playcounts
from playcounts pc join artists a on a.id = pc.artist_id
group by a.artist
order by all_playcounts desc
limit 10""").show
+----------------+--------------+
|          artist|all_playcounts|
+----------------+--------------+
|       Radiohead|       2502133|
|     The Beatles|       2259185|
|       Green Day|       1930592|
|       Metallica|       1542806|
|System of a Down|       1425942|
|      Pink Floyd|       1399419|
| Nine Inch Nails|       1361392|
|    Modest Mouse|       1328882|
|     Bright Eyes|       1234387|
|         Nirvana|       1203227|
+----------------+--------------+
Who are the most active users in our dataset in terms of number of song plays?
spark.sql("""select user_id, sum(pc.playcount) as all_playcounts
from playcounts pc
group by user_id
having all_playcounts > 100
order by all_playcounts desc
limit 10""").show
+-------+--------------+
|user_id|all_playcounts|
+-------+--------------+
|1059637|        674412|
|2064012|        548427|
|2069337|        393515|
|2023977|        285978|
|1046559|        183972|
|1052461|        175822|
|1070932|        168977|
|1031009|        167273|
|2020513|        165642|
|2062243|        151504|
+-------+--------------+
Now that our data is in the required format, we need to train a model that will learn, from the play counts, which artists to recommend to the users already in the dataset. Our formulation of the problem is rather limited, as it accounts neither for new users (what should we recommend to users whose tastes we do not know yet?) nor for new items in the dataset.
The formulation of the problem is as follows: given a matrix whose rows are users (raters), whose columns are bands (items to be rated) and whose cells contain the rating a user gave to a band, can we predict, for each rater, the score for every item they have not rated yet? The problem is more intuitive to understand in table format:
|        | item 1 | item 2 | item 3 | item 4 |
|--------|--------|--------|--------|--------|
| user 1 | 3      | 3      | 4      | 5      |
| user 2 | 3      | 1      | 4      |        |
| user 3 | 4      | 2      |        |        |
| user 4 | 5      | 2      | 4      | 4      |
| Foo    | 4      |        |        | ????   |
The question we need to answer is: what rating will user Foo give to item 4?
We will use Spark ML's default implementation of collaborative filtering, namely the Alternating Least Squares method.
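Before using ALS, it helps to see what it actually computes. ALS assumes the sparse rating matrix $R$ can be approximated by the product of two low-rank matrices, $R \approx XY^\top$, where row $x_u$ of $X$ describes user $u$ and row $y_i$ of $Y$ describes item $i$ in a small number of latent dimensions (the rank). A sketch of the standard explicit-feedback objective, with regularization weight $\lambda$, is:

$$\min_{X,Y} \sum_{(u,i)\,\mathrm{observed}} \left(r_{ui} - x_u^\top y_i\right)^2 + \lambda\left(\sum_u \lVert x_u\rVert^2 + \sum_i \lVert y_i\rVert^2\right)$$

The "alternating" part refers to the optimization strategy: with $Y$ held fixed, each $x_u$ has a closed-form least-squares solution, and vice versa; the algorithm alternates between the two sides for a fixed number of iterations. A missing cell is then predicted as $x_u^\top y_i$.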
val Array(training, test) = playcounts.randomSplit(Array(0.8, 0.2))
The model is trained with mostly default settings.
val model = new ALS().
setSeed(42).
setImplicitPrefs(true).
setNumBlocks(10).
setRank(10).
setMaxIter(5).
setUserCol("user_id").
setItemCol("artist_id").
setRatingCol("playcount").
setPredictionCol("prediction").
fit(training)
Because our input data consists of implicit feedback rather than explicit ratings, ALS (with setImplicitPrefs(true)) will first convert the play counts to binary preferences: 1 if the user played the artist at all, 0 otherwise. The play count itself is used as a confidence weight, so higher counts give the model more confidence that the user genuinely likes the artist.
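More precisely, the implicit-feedback variant (Spark's implementation follows the formulation of Hu, Koren and Volinsky) derives from each raw count $r_{ui}$ a binary preference $p_{ui}$ and a confidence $c_{ui}$:

$$p_{ui} = \begin{cases}1 & r_{ui} > 0\\ 0 & r_{ui} = 0\end{cases} \qquad c_{ui} = 1 + \alpha\, r_{ui}$$

and minimizes $\sum_{u,i} c_{ui}\,(p_{ui} - x_u^\top y_i)^2$ plus the usual regularization term. The scaling factor $\alpha$ is exposed through ALS's setAlpha parameter (we left it at its default of 1.0).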
Once we have a model, we can use it to make predictions. What we get as a prediction is (roughly) an indicator of how well a given artist matches the user's existing preferences.
// Drop test rows whose user or artist was not seen during training
// (their predictions would otherwise be NaN)
model.setColdStartStrategy("drop")
val predictions = model.transform(test)
predictions.
sample(true, 0.10).
orderBy(desc("user_id")).
join(artists, col("artist_id") === artists.col("id")).
select("user_id", "artist", "playcount", "prediction").
show(10)
+-------+----------+---------+----------+
|user_id|    artist|playcount|prediction|
+-------+----------+---------+----------+
|2440851|The Smiths|        1| 0.3520709|
|2434693|The Smiths|        2| 0.4993019|
|2433747|The Smiths|        3| 0.5907715|
|2433323|The Smiths|        2|  0.494595|
|2427874|The Smiths|       17|0.37485796|
|2425447|The Smiths|        1|0.24094674|
|2423986|The Smiths|        1| 0.6224065|
|2423673|The Smiths|        5| 0.3433353|
|2421721|The Smiths|        2|0.43127826|
|2421721|The Smiths|        2|0.43127826|
+-------+----------+---------+----------+
only showing top 10 rows
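The Correlation and VectorAssembler imports at the top can be put to use here for a crude sanity check. This is only a sketch, not a proper evaluation (ranking metrics such as AUC are more appropriate for implicit feedback), but if the model has learned anything, predicted scores should correlate positively with the held-out play counts:
// Pack the two numeric columns into a vector column and compute the
// Spearman rank correlation between actual plays and predicted scores
val assembled = new VectorAssembler().
  setInputCols(Array("playcount", "prediction")).
  setOutputCol("features").
  transform(predictions)
val Row(corrMatrix: Matrix) = Correlation.corr(assembled, "features", "spearman").head
println(s"Correlation matrix:\n$corrMatrix")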
We can get an intuitive feeling of how the recommendation works by examining individual cases. We first define a function that, given a userId, returns the list of artists this user has played, ordered by number of plays.
def currentlyLikes(data: Dataset[Row], userId: Int) = {
  // Attach artist names, keep the given user's rows, most played first.
  // Note the join on data's own column, so the function works on any
  // dataset passed in, not just `training`.
  data.join(artists, data.col("artist_id") === artists.col("id")).
    filter(x => x.getAs[Int]("user_id") == userId).
    orderBy(desc("playcount")).
    select("user_id", "artist", "playcount")
}
Then, we define our recommendation function: given a trained model and a userId, it will return a dataset with artist recommendations and their scores.
def recommend(model: ALSModel, userID: Int, howMany: Int) : Dataset[Row] = {
  // Pair every artist id known to the model with the target user
  val toRecommend = model.
    itemFactors.
    withColumnRenamed("id", "artist_id").
    select("artist_id").
    withColumn("user_id", lit(userID))
  // Score all (user, artist) pairs, attach artist names, keep the top scores
  model.
    transform(toRecommend).
    select("artist_id", "prediction").
    join(artists, col("artist_id") === artists.col("id")).
    orderBy(desc("prediction")).
    select("artist", "prediction").
    limit(howMany)
}
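As an aside, newer Spark versions (2.3 and later) ship a built-in helper on ALSModel that computes top-N recommendations directly; a minimal sketch of the equivalent call (the users DataFrame must contain the configured user column):
// Built-in alternative: top-10 recommendations for a subset of users
import spark.implicits._
val users = Seq(2100000).toDF("user_id")
model.recommendForUserSubset(users, 10).show(false)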
In the case of user 2100000, the recommendations look pretty good: the user seems to like modern metal music and the algorithm suggests some more old-school options.
val u1 = 2100000
currentlyLikes(training, u1).show(10)
recommend(model, u1, 10).show
+-------+--------------------+---------+
|user_id|              artist|playcount|
+-------+--------------------+---------+
|2100000|Red Hot Chili Pep...|      197|
|2100000|          Symphony X|      166|
|2100000|           Pearl Jam|      156|
|2100000|         The Haunted|      146|
|2100000|            Mastodon|      146|
|2100000|   Pain of Salvation|      145|
|2100000|        Shadows Fall|      113|
|2100000|                Muse|      112|
|2100000|          Guano Apes|       89|
|2100000|               Opeth|       80|
+-------+--------------------+---------+
only showing top 10 rows

+--------------------+----------+
|              artist|prediction|
+--------------------+----------+
|           Metallica| 1.1241852|
|         Iron Maiden| 1.1201184|
|           In Flames|  1.094646|
|    System of a Down| 1.0938909|
|           Rammstein| 1.0563129|
|             Pantera| 1.0539472|
|            Megadeth|  1.045079|
|Queens of the Sto...|  1.042593|
|            Slipknot| 1.0424271|
|     Cradle of Filth| 1.0408039|
+--------------------+----------+
The recommendations for user 2062243 are more... interesting, as the user's interests are not particularly well defined.
val u2 = 2062243
currentlyLikes(training, u2).show(10)
recommend(model, u2, 10).show(10)
+-------+--------------------+---------+
|user_id|              artist|playcount|
+-------+--------------------+---------+
|2062243|           Morrissey|     7193|
|2062243|       Mouse on Mars|     4913|
|2062243|       The Movielife|     3983|
|2062243|         The Beatles|     3658|
|2062243|        Led Zeppelin|     3354|
|2062243|              Mogwai|     2843|
|2062243|               Queen|     1888|
|2062243|Motion City Sound...|     1706|
|2062243|         Talib Kweli|     1700|
|2062243|            Mudvayne|     1470|
+-------+--------------------+---------+
only showing top 10 rows

+---------------+----------+
|         artist|prediction|
+---------------+----------+
|Various Artists| 1.5014662|
|   Wu-Tang Clan|  1.379993|
|   The Specials| 1.3798537|
|    Andrew W.K.| 1.3796811|
|        Outkast| 1.3729253|
|   Sage Francis| 1.3654457|
|          N.W.A| 1.3599284|
|        Dr. Dre| 1.3540987|
|       MC Chris| 1.3460674|
|   Public Enemy| 1.3458596|
+---------------+----------+
Parts of the code and the original idea for this work were adapted from Ryza, Laserson, Owen and Wills, "Advanced Analytics with Spark", O'Reilly, 2017.