Commit b51f2839 authored by anon's avatar anon
Browse files

Prioritized Sweeping Value iteration

parent d7f82854
......@@ -199,21 +199,43 @@ class PrioritizedSweepingValueIterationAgent(AsynchronousValueIterationAgent):
for state in states:
actions = self.mdp.getPossibleActions(state)
for action in actions:
nextState = self.mdp.getTransitionStatesAndProbs(state, action)[0]
try:
predecessors[nextState].add(state)
except KeyError:
predecessors[nextState] = set()
stateAndProbs = self.mdp.getTransitionStatesAndProbs(state, action)
for stateAndProb in stateAndProbs:
nextState = stateAndProb[0]
try:
predecessors[nextState].add(state)
except KeyError:
predecessors[nextState] = set()
predecessors[nextState].add(state)
# Initialize the priorityqueue
prQueue = util.PriorityQueue()
for state in states:
if mpd.isTerminal(state):
continue
actions = self.mdp.getPossibleActions(state)
qValues = []
for action in actions:
qValues.append(self.getQValue(state, action))
prQueue.update(state, -abs(values[state] - max(qValues)))
prQueue = util.PriorityQueue()
for state in states:
if self.mdp.isTerminal(state):
continue
prQueue.update(state, -abs(self.values[state] - self.getHighestQValue(state)))
for i in range(self.iterations):
#breakpoint()
if prQueue.isEmpty():
break
s = prQueue.pop()
if not self.mdp.isTerminal(s):
self.values[s] = self.getHighestQValue(s)
try:
for predecessor in predecessors[s]:
diff = abs(self.values[predecessor] - self.getHighestQValue(predecessor))
if diff>self.theta:
prQueue.update(predecessor, -diff)
except KeyError:
predecessors[s] = set()
def getHighestQValue(self, state):
actions = self.mdp.getPossibleActions(state)
qValues = []
for action in actions:
qValues.append(self.getQValue(state, action))
return max(qValues)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment