Recommending music

Most of us have used systems that recommend new content based on our previous interactions with the system and on what other users with similar interests are consuming. Amazon pioneered product recommendation with the now ubiquitous "Customers who bought this also bought" feature; recommendation is now a core business feature of most media platforms, for example Netflix, Spotify, YouTube and iTunes.

At the core of these technologies lie recommendation algorithms based on the collaborative filtering technique.

To demonstrate how we can build a recommendation pipeline from the ground up, we will use the Audioscrobbler dataset. Audioscrobbler tracked users' music preferences and recommended new bands and songs to listen to. In this example, we build a system that allows us to ask questions like: "I like Deep Purple and Led Zeppelin, which other band should I listen to?" and get answers like "Black Sabbath"!

We start our exploration by loading all required modules:

In [1]:
import org.apache.spark.ml.recommendation.ALSModel
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.ml.recommendation.ALS
import scala.util.Random
import org.apache.spark.ml.stat.Correlation
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Matrix, Vectors}

The Audioscrobbler dataset comes in three files:

  • artist_data contains a map between artist ids and artist names
  • user_artist_data contains the number of plays for each pair of user_id and artist_id
  • artist_alias contains artist_ids that resolve to the same artist (we will not use this)

We first load the play counts file as a text file and print its first 5 lines:

In [2]:
// Load the raw play counts file
val rawPlayCounts = spark.sparkContext.
    textFile("../datasets/audioscrobbler/user_artist_data.txt")
rawPlayCounts.take(5).foreach(println)
1000002 1 55
1000002 1000006 33
1000002 1000007 8
1000002 1000009 144
1000002 1000010 314

We then define a schema to convert the raw text file to a DataFrame (a Dataset of Rows), which is much easier to work with (and is also what the ALS algorithm expects as input).

In [3]:
// Define the play count table schema
val pcSchemaString = "user_id artist_id playcount"
val pcFields = 
    pcSchemaString.split(" ").
    map(fieldName => StructField(fieldName, IntegerType, nullable = true))
val pcSchema = StructType(pcFields)

Now that we have a schema, we need to convert our raw data RDD into a DataFrame. The source RDD is essentially a list of file lines. To convert it, we split each line into an array of Strings, which we then convert to a Row of 3 items. Row is Spark SQL's equivalent of Scala's Tuple, in the sense that it contains an arbitrary number of items of various types.

To obtain the DataFrame, we combine our list of Rows with the schema we defined earlier. We also register the resulting DataFrame as a temporary view, which allows us to run SQL queries on it.

In [4]:
val rowRDD = rawPlayCounts.
             map{_.split(' ')}.
             map{x => Row(x(0).toInt, x(1).toInt, x(2).toInt)}

// Create a Spark data frame
val playcounts = spark.createDataFrame(rowRDD, pcSchema).cache()
playcounts.createOrReplaceTempView("playcounts")

Let's run an SQL query to verify that it works. We select the top 5 users by total play count:

In [5]:
spark.sql("""select user_id, sum(playcount) as total_plays 
             from playcounts 
             group by user_id 
             order by total_plays desc
             limit 5""").show
+-------+-----------+
|user_id|total_plays|
+-------+-----------+
|1059637|     674412|
|2064012|     548427|
|2069337|     393515|
|2023977|     285978|
|1046559|     183972|
+-------+-----------+

We also load the artist names; while this is not strictly necessary for our recommendation algorithm to work, it is nice to see band names rather than ids in the output.

The artist_data file is a bit more challenging to handle, as it contains errors. While it is tab-separated, sometimes the artist name is missing or the artist ID cannot be mapped to an integer. This is a nice use case of Scala's flatMap and monads. In the parsing code, instead of returning concrete values, we wrap our return values in the Option monad: when we can actually return something, we return a Some, otherwise a None. flatMap will automatically ignore the None values and extract the contents of Some at the flattening stage.

In [6]:
// Load the artist names as well
val rawArtists = spark.sparkContext.
    textFile("../datasets/audioscrobbler/artist_data.txt")

val aSchema = StructType(Array(StructField("id", IntegerType, nullable=true), 
                               StructField("artist", StringType, nullable=true)))

val rowRDD = rawArtists.flatMap{line => 
    val (id, artist) = line.span( _ != '\t')
    if (artist.isEmpty) {
        None
    } else {
        try {
            Some(Row(id.toInt, artist.trim))
        } catch {
            case _ : Exception => None
        }
    }
}

val artists = spark.createDataFrame(rowRDD, aSchema).cache()
artists.createOrReplaceTempView("artists")
artists.count
Out[6]:
1848281
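The Option and flatMap pattern used for parsing can also be tried on a plain Scala collection, without Spark. The following is a minimal sketch under stated assumptions: `parseArtist` is a hypothetical helper that mirrors the parsing logic of the cell above, and the sample lines are made up for illustration.

```scala
import scala.util.Try

// Hypothetical helper: parse one "id<TAB>name" line.
// Returns None when the tab is missing or the id is not an integer.
def parseArtist(line: String): Option[(Int, String)] = {
  val (id, rest) = line.span(_ != '\t')
  if (rest.isEmpty) None
  else Try((id.toInt, rest.trim)).toOption
}

// Made-up sample lines: two valid, one without a tab, one with a non-numeric id
val lines = Seq("1\tThe Smiths", "2\tRadiohead", "no tab here", "abc\tBroken id")

// flatMap keeps the contents of each Some and drops every None
val parsedArtists = lines.flatMap(parseArtist)
parsedArtists.foreach(println)
```

Wrapping the conversion in `Try(...).toOption` is an alternative to the explicit try/catch block; both turn a parsing failure into a None that flatMap then discards.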

Exploring the dataset

We now perform some basic exploration of the dataset to make sure the data makes sense.

Joining the two files

We first perform a join by means of API calls:

In [14]:
playcounts.join(artists, playcounts.col("artist_id") === artists.col("id")).
           select("user_id", "artist", "playcount").
           show(10)
+-------+----------+---------+
|user_id|    artist|playcount|
+-------+----------+---------+
|1000019|The Smiths|        1|
|1000020|The Smiths|      199|
|1000022|The Smiths|       20|
|1000033|The Smiths|      466|
|1000056|The Smiths|       10|
|1000067|The Smiths|       18|
|1000070|The Smiths|      399|
|1000073|The Smiths|        3|
|1000077|The Smiths|       15|
|1000082|The Smiths|       16|
+-------+----------+---------+
only showing top 10 rows

Then, we perform the same join using SQL:

In [9]:
spark.sql("""select pc.user_id, a.artist, pc.playcount 
          from playcounts pc join artists a on a.id = pc.artist_id 
          limit 10""").show
+-------+----------+---------+
|user_id|    artist|playcount|
+-------+----------+---------+
|1000019|The Smiths|        1|
|1000020|The Smiths|      199|
|1000022|The Smiths|       20|
|1000033|The Smiths|      466|
|1000056|The Smiths|       10|
|1000067|The Smiths|       18|
|1000070|The Smiths|      399|
|1000073|The Smiths|        3|
|1000077|The Smiths|       15|
|1000082|The Smiths|       16|
+-------+----------+---------+

Most played artists

What are the most played artists in our dataset?

In [15]:
spark.sql("""select a.artist, sum(pc.playcount) as all_playcounts 
            from playcounts pc join artists a on a.id = pc.artist_id 
            group by a.artist 
            order by all_playcounts desc
            limit 10""").show
+----------------+--------------+
|          artist|all_playcounts|
+----------------+--------------+
|       Radiohead|       2502133|
|     The Beatles|       2259185|
|       Green Day|       1930592|
|       Metallica|       1542806|
|System of a Down|       1425942|
|      Pink Floyd|       1399419|
| Nine Inch Nails|       1361392|
|    Modest Mouse|       1328882|
|     Bright Eyes|       1234387|
|         Nirvana|       1203227|
+----------------+--------------+

Most active users

What are the most active users in our dataset in terms of number of plays?

In [16]:
spark.sql("""select user_id, sum(pc.playcount) as all_playcounts 
            from playcounts pc
            group by user_id
            having all_playcounts > 100
            order by all_playcounts desc
            limit 10""").show
+-------+--------------+
|user_id|all_playcounts|
+-------+--------------+
|1059637|        674412|
|2064012|        548427|
|2069337|        393515|
|2023977|        285978|
|1046559|        183972|
|1052461|        175822|
|1070932|        168977|
|1031009|        167273|
|2020513|        165642|
|2062243|        151504|
+-------+--------------+
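To make the semantics of the group by / having / order by / limit pipeline concrete, here is a sketch of the same steps on a plain Scala collection. The triples are made up for illustration and are not taken from the dataset; on the real data, the SQL query above remains the appropriate tool.

```scala
// Made-up (user_id, artist_id, playcount) triples, for illustration only
val plays = Seq((1, 10, 120), (1, 11, 30), (2, 10, 5), (3, 12, 400), (3, 10, 1))

val topUsers = plays.
  groupBy(_._1).                                             // group by user_id
  map { case (user, rows) => (user, rows.map(_._3).sum) }.   // sum(playcount)
  filter { case (_, total) => total > 100 }.                 // having total > 100
  toSeq.
  sortBy { case (_, total) => -total }.                      // order by total desc
  take(10)                                                   // limit 10

topUsers.foreach(println)
```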

Train a recommender model

Now that our data is in the required format, we need to train a model that learns from the play counts which artists to recommend to users currently in the dataset. Our formulation of the problem is rather limited, as it accounts neither for new users (what should we recommend to users whose tastes we do not know yet?) nor for new items in the dataset.

The formulation of the problem is as follows: given a matrix whose rows are users (raters), whose columns are bands (items to be rated) and whose cells hold the rating a user gave to a band, can we predict a rater's score for every item they have not rated yet? The problem is more intuitive to understand in a table format:

item1 item 2 item 3 item 4
user 1 3 3 4 5
user 2 3 1 4
user 2 4 2
user 2 5 2 4 4
Foo 4 ????

The question we need to answer is the grade that user Foo will give to item 4.

We will use Spark ML's default implementation of collaborative filtering, namely the Alternating Least Squares method.
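As a rough sketch of the idea behind Alternating Least Squares, in its textbook explicit-feedback form: the ratings matrix is approximated by the product of two low-rank factor matrices, one with a vector $x_u$ per user and one with a vector $y_i$ per item. The symbols $x_u$, $y_i$, $\lambda$ and $\Omega$ (the set of observed user/item pairs) are notation assumed here, and Spark's implementation differs in details such as how the regularization term is scaled.

```latex
\hat{r}_{ui} = x_u^{\top} y_i,
\qquad
\min_{X,\,Y} \sum_{(u,i) \in \Omega} \left( r_{ui} - x_u^{\top} y_i \right)^2
  + \lambda \left( \sum_u \lVert x_u \rVert^2 + \sum_i \lVert y_i \rVert^2 \right)
```

The method is called "alternating" because it fixes the item vectors and solves a least squares problem for the user vectors, then fixes the user vectors and solves for the item vectors, and repeats for a set number of iterations.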

In [18]:
val Array(training, test) = playcounts.randomSplit(Array(0.8, 0.2))

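The model is trained with mostly default settings; the main departures are that implicit preferences are enabled and that the number of iterations is reduced to 5.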

In [19]:
val model = new ALS().
    setSeed(42).
    setImplicitPrefs(true).
    setNumBlocks(10).
    setRank(10).
    setMaxIter(5).
    setUserCol("user_id").
    setItemCol("artist_id").
    setRatingCol("playcount").
    setPredictionCol("prediction").
    fit(training)

Our input data is composed of implicit preferences: we use the number of plays as an indication of how much a user likes a band, rather than as an explicit rating. With setImplicitPrefs(true), ALS treats each observation as a binary preference (1 if the user has played the artist, 0 otherwise) and uses the play count as a measure of confidence in that preference.
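The usual implicit-feedback formulation (Hu, Koren and Volinsky, which the Spark documentation refers to) can be sketched in a few lines of plain Scala. The function names below are hypothetical and only illustrate the mapping from play counts to a binary preference and a confidence weight; `alpha` is a tunable parameter, assumed here to be 1.0.

```scala
// Binary preference: has the user played the artist at all?
def preference(playcount: Double): Double =
  if (playcount > 0) 1.0 else 0.0

// Confidence in that preference grows linearly with the play count.
// alpha is a tunable weight (assumed to be 1.0 in this sketch).
def confidence(playcount: Double, alpha: Double = 1.0): Double =
  1.0 + alpha * playcount

Seq(0.0, 1.0, 144.0).foreach { r =>
  println(s"plays=$r preference=${preference(r)} confidence=${confidence(r)}")
}
```

Under this view, an artist played 144 times and an artist played once both count as "liked", but the model is penalized far more for getting the first one wrong.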

Once we have a model, we can use it to make predictions. What we get as a result of a prediction is (roughly) an indicator of how closely a given artist matches the user's existing preferences.

In [20]:
model.setColdStartStrategy("drop")
val predictions = model.transform(test)
predictions.
    sample(true, 0.10).        // with replacement, so a row may appear more than once
    orderBy(desc("user_id")).
    join(artists, col("artist_id") === artists.col("id")).
    select("user_id", "artist", "playcount", "prediction").
    show(10)
+-------+----------+---------+----------+
|user_id|    artist|playcount|prediction|
+-------+----------+---------+----------+
|2440851|The Smiths|        1| 0.3520709|
|2434693|The Smiths|        2| 0.4993019|
|2433747|The Smiths|        3| 0.5907715|
|2433323|The Smiths|        2|  0.494595|
|2427874|The Smiths|       17|0.37485796|
|2425447|The Smiths|        1|0.24094674|
|2423986|The Smiths|        1| 0.6224065|
|2423673|The Smiths|        5| 0.3433353|
|2421721|The Smiths|        2|0.43127826|
|2421721|The Smiths|        2|0.43127826|
+-------+----------+---------+----------+
only showing top 10 rows
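Conceptually, each value in the prediction column is the dot product of the user's latent factor vector and the artist's latent factor vector. The sketch below illustrates this with made-up rank-3 vectors; the names and numbers are assumptions for illustration, not values taken from the trained model.

```scala
// Dot product of two equally sized factor vectors
def dot(a: Array[Double], b: Array[Double]): Double = {
  require(a.length == b.length, "factor vectors must have the same rank")
  a.zip(b).map { case (x, y) => x * y }.sum
}

// Made-up rank-3 factors, for illustration only
val userFactors = Array(0.5, 1.0, 0.0)
val artistFactors = Array(1.0, 0.5, 2.0)

val score = dot(userFactors, artistFactors)
println(score)
```

The higher the score, the more the artist's factors point in the same direction as the user's, which is why the predictions are useful for ranking rather than as absolute play count estimates.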

Evaluating recommendations for individual cases

We can get an intuitive feeling of how the recommendation works by examining individual cases. We first define a function that, given a userId, returns the artists this user has played, ordered by number of plays.

In [21]:
def currentlyLikes(data: Dataset[Row], userId: Int) = {
    // Join on the dataset passed in as an argument, not on the global training set
    data.join(artists, data.col("artist_id") === artists.col("id")).
          filter(col("user_id") === userId).
          orderBy(desc("playcount")).
          select("user_id", "artist", "playcount")
}

Then, we define our recommendation function: given a trained model and a userId, it will return a dataset with artist recommendations and their scores.

In [22]:
def recommend(model: ALSModel, userID: Int, howMany: Int) : Dataset[Row] = {
    val toRecommend = model.
                      itemFactors.
                      withColumnRenamed("id","artist_id").
                      select("artist_id").
                      withColumn("user_id", lit(userID))
    model.
        transform(toRecommend).
        select("artist_id", "prediction").
        join(artists, col("artist_id") === artists.col("id")).
        orderBy(desc("prediction")).
        select("artist", "prediction").
        limit(howMany)
}
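The recommend function scores every artist known to the model, including artists the user already listens to. A possible refinement is to exclude those before ranking. The sketch below shows the idea on plain Scala collections; `topN`, its parameters and the factor values are all hypothetical and not part of the notebook.

```scala
// Score every candidate artist for one user and keep the best `howMany`,
// skipping artists the user has already played
def topN(user: Array[Double],
         items: Map[String, Array[Double]],
         alreadyPlayed: Set[String],
         howMany: Int): Seq[(String, Double)] =
  items.toSeq.
    filterNot { case (artist, _) => alreadyPlayed(artist) }.
    map { case (artist, f) => (artist, user.zip(f).map { case (x, y) => x * y }.sum) }.
    sortBy { case (_, s) => -s }.
    take(howMany)

// Made-up rank-2 factors, for illustration only
val user = Array(1.0, 0.0)
val items = Map(
  "A" -> Array(0.9, 0.1),
  "B" -> Array(0.2, 0.8),
  "C" -> Array(0.5, 0.5))

val recommended = topN(user, items, alreadyPlayed = Set("A"), howMany = 2)
recommended.foreach(println)
```

Whether to filter out known artists is a design choice: keeping them is a useful sanity check on the model, while removing them makes the output more useful as a list of new suggestions.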

In the case of user 2100000, the recommendations look pretty good: the user seems to like modern metal music and the algorithm suggests some more old school options.

In [23]:
val u1 = 2100000

currentlyLikes(training, u1).show(10)
recommend(model, u1, 10).show
+-------+--------------------+---------+                                        
|user_id|              artist|playcount|
+-------+--------------------+---------+
|2100000|Red Hot Chili Pep...|      197|
|2100000|          Symphony X|      166|
|2100000|           Pearl Jam|      156|
|2100000|         The Haunted|      146|
|2100000|            Mastodon|      146|
|2100000|   Pain of Salvation|      145|
|2100000|        Shadows Fall|      113|
|2100000|                Muse|      112|
|2100000|          Guano Apes|       89|
|2100000|               Opeth|       80|
+-------+--------------------+---------+
only showing top 10 rows

+--------------------+----------+
|              artist|prediction|
+--------------------+----------+
|           Metallica| 1.1241852|
|         Iron Maiden| 1.1201184|
|           In Flames|  1.094646|
|    System of a Down| 1.0938909|
|           Rammstein| 1.0563129|
|             Pantera| 1.0539472|
|            Megadeth|  1.045079|
|Queens of the Sto...|  1.042593|
|            Slipknot| 1.0424271|
|     Cradle of Filth| 1.0408039|
+--------------------+----------+

The recommendations for user 2062243 are more... interesting, as the user's interests are not particularly well defined.

In [25]:
val u2 = 2062243

currentlyLikes(training, u2).show(10)
recommend(model, u2, 10).show(10)
+-------+--------------------+---------+                                        
|user_id|              artist|playcount|
+-------+--------------------+---------+
|2062243|           Morrissey|     7193|
|2062243|       Mouse on Mars|     4913|
|2062243|       The Movielife|     3983|
|2062243|         The Beatles|     3658|
|2062243|        Led Zeppelin|     3354|
|2062243|              Mogwai|     2843|
|2062243|               Queen|     1888|
|2062243|Motion City Sound...|     1706|
|2062243|         Talib Kweli|     1700|
|2062243|            Mudvayne|     1470|
+-------+--------------------+---------+
only showing top 10 rows

+---------------+----------+                                                    
|         artist|prediction|
+---------------+----------+
|Various Artists| 1.5014662|
|   Wu-Tang Clan|  1.379993|
|   The Specials| 1.3798537|
|    Andrew W.K.| 1.3796811|
|        Outkast| 1.3729253|
|   Sage Francis| 1.3654457|
|          N.W.A| 1.3599284|
|        Dr. Dre| 1.3540987|
|       MC Chris| 1.3460674|
|   Public Enemy| 1.3458596|
+---------------+----------+

Acknowledgements

Parts of the code and the original idea for this work were taken from Ryza, Laserson, Owen and Wills, "Advanced Analytics with Spark", O'Reilly, 2017.