How do I balance my data across the partitions?

I have an RDD with 202092 partitions, which reads a dataset created by others. Inspecting it by hand, I can see that the data is not balanced across the partitions: some of them have 0 images and others have 4k, while the mean is 432. When processing the data, I got this error:

I think the memory overhead limit being exceeded is due to the DirectMemory buffers used during fetch, and I think it's fixed in 2.0.0. (We had the same issue, but stopped digging when we found that upgrading to 2.0.0 resolved it. Unfortunately, I don't have Spark issue numbers to back me up.)
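
On the balancing question itself, here is a minimal sketch of how one might measure the skew and then reshuffle, assuming a PySpark RDD. The helper names (`partition_sizes`, `summarize`, `rebalance`) are made up for illustration; the only Spark calls used are `mapPartitions`, `collect` and `repartition`. I have not run this against your dataset, so treat it as a starting point rather than a fix.

```python
from statistics import mean


def partition_sizes(rdd):
    """Return the number of records in each partition, in partition order."""
    # Count inside each partition so only one integer per partition
    # is sent back to the driver, not the records themselves.
    return rdd.mapPartitions(lambda it: [sum(1 for _ in it)]).collect()


def summarize(sizes):
    """Summarize skew from a list of per-partition record counts."""
    return {
        "partitions": len(sizes),
        "min": min(sizes),
        "max": max(sizes),
        "mean": mean(sizes),
        "empty": sum(1 for s in sizes if s == 0),
    }


def rebalance(rdd, num_partitions):
    """Redistribute records across num_partitions via a full shuffle."""
    # repartition() always shuffles; the resulting partitions should be
    # roughly even, but check again with partition_sizes() afterwards.
    return rdd.repartition(num_partitions)
```

With a real RDD you would call `summarize(partition_sizes(rdd))` before and after `rebalance(rdd, n)`. The value of `n` is yours to choose. Note that `repartition` moves all the data over the network, so it has a cost of its own, and it may hit the same fetch path that the error seems to come from.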