import java.util.Random;
public class RandomizedSet {
// Key : number
// Value : corresponding index in arraylist
Map<Integer, Integer> map;
List<Integer> list;
Random rd;
/** Initialize your data structure here. */
public RandomizedSet() {
map = new HashMap<Integer, Integer>();
list = new ArrayList<Integer>();
rd = new Random();
}
/** Inserts a value to the set. Returns true if the set did not already contain the specified element. */
public boolean insert(int val) {
// Set already has given value
if(map.containsKey(val)) return false;
map.put(val, list.size());
list.add(val);
return true;
}
/** Removes a value from the set. Returns true if the set contained the specified element. */
public boolean remove(int val) {
if(!map.containsKey(val)) return false;
int indexA = map.get(val);
if(indexA != list.size() - 1) swap(list, map, indexA, list.size() - 1);
map.remove(list.get(list.size() - 1));
list.remove(list.size() - 1);
return true;
}
/** Get a random element from the set. */
public int getRandom() {
return list.get(rd.nextInt(list.size()));
}
// a : list index of element a
// b : list index of element b
private void swap(List<Integer> list, Map<Integer, Integer> map, int a, int b){
int valA = list.get(a);
int valB = list.get(b);
int indexA = map.get(valA);
int indexB = map.get(valB);
list.set(a, valB);
list.set(b, valA);
map.put(valA, indexB);
map.put(valB, indexA);
}
}
HashMap 的 value 只存一个 index 就不够用了,得存一个 Set<>,记录相同元素所有的 index,反正 add / remove 的平均复杂度都是 O(1)。
import java.util.*;
public class RandomizedCollection {
Map<Integer, Set<Integer>> map;
List<Integer> list;
Random rand;
/** Initialize your data structure here. */
public RandomizedCollection() {
map = new HashMap<>();
list = new ArrayList<>();
rand = new Random();
}
/** Inserts a value to the collection. Returns true if the collection did not already contain the specified element. */
public boolean insert(int val) {
boolean flag = false;
if(!map.containsKey(val)){
flag = true;
map.put(val, new HashSet<Integer>());
}
list.add(val);
map.get(val).add(list.size() - 1);
return flag;
}
/** Removes a value from the collection. Returns true if the collection contained the specified element. */
public boolean remove(int val) {
boolean flag = false;
if(map.containsKey(val) && map.get(val).size() > 0){
flag = true;
int indexA = map.get(val).iterator().next();
int indexB = list.size() - 1;
swap(map, list, indexA, indexB);
map.get(val).remove(list.size() - 1);
list.remove(list.size() - 1);
}
return flag;
}
/** Get a random element from the collection. */
public int getRandom() {
return list.get(rand.nextInt(list.size()));
}
// Swap elements at list index "a" and "b" in arraylist, and hashmap
private void swap(Map<Integer, Set<Integer>> map, List<Integer> list, int indexA, int indexB){
// O(1)
int valA = list.get(indexA);
int valB = list.get(indexB);
// O(1) average
map.get(valA).remove(indexA);
map.get(valB).remove(indexB);
// O(1) average
map.get(valA).add(indexB);
map.get(valB).add(indexA);
// O(1)
list.set(indexA, valB);
list.set(indexB, valA);
}
}
问题二:运气不好,或者 blacklist 非常大的话,我们可能要调用多次 getRandom,而这个 API 可能是非常贵的。
因此另一个思考角度是,从 white list 出发。
对于【0,10】,blacklist = 【1,2,3,7,8】来讲,white list = 【0,4,5,6,9,10】.因此假设我们有足够内存去维护 whitelist,我们可以直接生成并返回一个 white list 中的元素,这样可以保证一次 getRandom() 操作保证得到想要的元素。
假如 white list 内存放不下呢?
方法照旧。每次我们生成一个 white list 的合理 index 之后,我们要找的就是 “第 index + 1 个不在 blacklist 中的数”。这个可以通过扫 blacklist 实现,O(1) 的内存开销。
相当于实现了一个 index -> whitelist[index] 的 mapping,可以保证一次 getRandom() 就可以得到有效元素。
i 从 0 到 n 循环,如果大于 blacklist 最后一个数,或者小于 blacklist 当前数,都 count --;
否则 blackPtr ++ ;
每次循环的时候,count 扣完了,当前 i 就是目标元素。
static Random rand = new Random();
public static int getRandom(int n, List<Integer> blackList){
int totalNum = n + 1;
int whiteListSize = totalNum - blackList.size();
int index = rand.nextInt(whiteListSize);
int count = index + 1;
int blackPtr = 0;
for(int i = 0; i <= n; i++){
if(i > blackList.get(blackList.size() - 1) || i < blackList.get(blackPtr)){
count --;
} else {
blackPtr ++;
}
if(count == 0) return i;
}
return -1;
}
public static void main(String[] args){
List<Integer> blackList = new ArrayList<Integer>();
blackList.add(1);blackList.add(2);blackList.add(3);
blackList.add(7);blackList.add(8);
blackList.add(13);blackList.add(19);blackList.add(20);
for(int i = 0; i < 50; i++){
System.out.print(" , " + getRandom(20, blackList));
}
}
static Random rand = new Random();
public static int[] reserviorSampling(int k, int[] nums){
if(k >= nums.length) return nums;
int i = 0;
int[] rst = new int[k];
for(; i < k; i++){
rst[i] = nums[i];
}
for(; i < nums.length; i++){
// random is exclusive
int num = rand.nextInt(i + 1);
if(num < k) rst[num] = nums[i];
}
return rst;
}
public static void main(String[] args){
int[] count = new int[10];
for(int i = 0; i < 10000; i++){
int[] sampled = reserviorSampling(5, new int[]{0,1,2,3,4,5,6,7,8,9});
for(int num : sampled) count[num]++;
}
for(int i = 0; i < 10; i++){
System.out.println("Count of " + i + " " + count[i] + " times");
}
}
刚写完这章的总结,leetcode 就搞了个蓄水池抽样的题出来。。好与时俱进啊!
import java.util.*;
public class Solution {
ListNode head;
ListNode rst;
int count;
Random rand;
/** @param head The linked list's head. Note that the head is guanranteed to be not null, so it contains at least one node. */
public Solution(ListNode head) {
this.head = head;
rand = new Random();
}
/** Returns a random node's value. */
public int getRandom() {
ListNode cur = head.next;
rst = head;
count = 2;
while(cur != null){
int num = rand.nextInt(count++);
if(num == 0) rst = cur;
cur = cur.next;
}
return rst.val;
}
}
代码和流程惊人的简单:
遍历每个 i ,随机从 i 还有 i 后面的区间抽一个数,和 i 交换;
重点是 index = 【i, n - 1】区间。每个元素有留在原位置的概率。
其原理非常类似于一群人在一个黑箱子里抽彩票,for 循环中的每次迭代相当于一个人,每次抽样范围 index = 箱子中剩下的所有彩票,其数量依次递减。到最后每个人虽然抽奖时箱子里彩票总数不同,但是获奖概率却是均等的。
import java.util.*;
public class Solution {
int[] original;
Random rand;
public Solution(int[] nums) {
original = nums;
rand = new Random();
}
/** Resets the array to its original configuration and return it. */
public int[] reset() {
return original;
}
/** Returns a random shuffling of the array. */
public int[] shuffle() {
int[] rst = Arrays.copyOf(original, original.length);
for(int i = 0; i < rst.length; i++){
int index = rand.nextInt(rst.length - i) + i;
int temp = rst[i];
rst[i] = rst[index];
rst[index] = temp;
}
return rst;
}
}