from dataclasses import dataclass
import argparse
from typing import List, Tuple, Dict
from time import time
from pprint import pprint

# Command-line interface: one required size parameter and an optional
# verbosity switch. Help strings are provided so `--help` is self-explanatory.
parser = argparse.ArgumentParser(
    description=(
        "Run mean-payoff policy iteration on the quadr(n) graph and report "
        "the elapsed time and the number of policies visited."
    )
)
parser.add_argument(
    "n",
    type=int,
    help="size parameter passed to quadr(n)",
)
parser.add_argument(
    "--print",
    action="store_true",
    help="also print the graph, every intermediate policy and the appraisals",
)

@dataclass
class Config:
    """Run-wide settings shared by the algorithm and the driver script.

    In the visible code the flag is read and written on the class itself
    (``Config.print``); no instance of ``Config`` is created.
    """

    # When true, the graph, intermediate policies and appraisals are printed.
    print: bool = False


class Graph:
    """Directed, edge-weighted graph stored as a nested adjacency mapping.

    Attributes:
        vertices: maps each vertex name to its position in the list passed
            to the constructor.
        adjacency: the ``edges`` mapping as given, i.e.
            ``adjacency[u][v]`` is the weight of the edge ``u -> v``.
    """

    def __init__(self, vertices: List[str], edges: Dict[str, Dict[str, int]]):
        """Index *vertices* by position and keep *edges* by reference."""
        index_of = {}
        for position, name in enumerate(vertices):
            index_of[name] = position
        self.vertices = index_of
        self.adjacency = edges

    def num_vertices(self) -> int:
        """Return the number of indexed vertices."""
        return len(self.vertices)

    def num_edges(self) -> int:
        """Return the total number of edges over all adjacency entries."""
        total = 0
        for successors in self.adjacency.values():
            total += len(successors)
        return total
    

class MeanPayoffPolicyIteration:
    """Policy iteration over a ``Graph`` where a policy picks one successor per vertex.

    Each round appraises every vertex under the current policy (mean weight of
    the cycle it eventually reaches, plus a potential measuring the path to
    that cycle) and then switches every vertex to the successor with the
    lexicographically largest appraisal. Iteration stops when a round leaves
    the policy unchanged.
    """

    def __init__(self, graph: Graph):
        """Store *graph* and start from the policy of ``get_initial_policy``."""
        self.graph = graph
        self.policy = self.get_initial_policy()

    def get_initial_policy(self):
        """Return the policy sending each vertex to its lowest-indexed successor.

        Raises:
            KeyError: if a vertex has no entry in the adjacency mapping.
            ValueError: if a vertex has an empty successor mapping.
        """
        policy = {}
        for vertex in self.graph.vertices:
            successors = self.graph.adjacency[vertex]
            policy[vertex] = min(successors, key=lambda v: self.graph.vertices[v])
        return policy

    def get_optimal_policy(self):
        """Improve the policy until it is stable.

        Returns:
            A pair ``(policy, info)`` where ``policy`` is the final policy and
            ``info['policies']`` lists a copy of every policy visited, starting
            with the initial one.
        """
        info = {'policies': [self.policy.copy()]}
        while True:
            new_policy = self.get_new_policy()
            if new_policy == self.policy:
                break
            self.policy = new_policy
            if Config.print:
                print("#"*10 + f"Policy {len(info['policies'])}" + "#"*10)
                pprint(new_policy)
            info['policies'].append(self.policy.copy())

        return self.policy, info

    def _appraise(self, vertex):
        """Return ``(mean, potential)`` for *vertex* under the current policy.

        The policy is followed from *vertex* until a vertex repeats. ``mean``
        is the average edge weight on the cycle that is reached. ``potential``
        is the weight of the walk from *vertex* to the lowest-indexed vertex of
        that cycle, minus ``mean`` times the number of edges on that walk.
        """
        adjacency = self.graph.adjacency
        policy = self.policy

        # Visit position of every vertex on the walk. A dict gives O(1)
        # membership and index lookups (a list would make the walk quadratic)
        # and its insertion order is the visit order.
        position = {}
        weights = []
        current = vertex
        while current not in position:
            position[current] = len(weights)
            successor = policy[current]
            weights.append(adjacency[current][successor])
            current = successor

        # `current` is the first repeated vertex, i.e. where the cycle begins.
        cycle_start = position[current]
        cycle_weight = sum(weights[cycle_start:])
        cycle_len = len(weights) - cycle_start

        # Potentials are measured up to the lowest-indexed vertex on the cycle.
        walk = list(position)
        anchor = min(walk[cycle_start:], key=lambda v: self.graph.vertices[v])
        path_len = position[anchor]
        path_weight = sum(weights[:path_len])

        mp_val = cycle_weight / cycle_len
        pot = path_weight - mp_val * path_len
        return (mp_val, pot)

    def get_new_policy(self):
        """Return the policy obtained by one improvement round.

        Every vertex is switched to the successor ``v`` maximising, in order:
        the mean of ``v``; the potential of ``v`` minus the mean of the vertex
        plus the edge weight; and a tie-break that keeps the current choice
        (key 1 beats any ``-index``, which is never positive) and otherwise
        prefers the lower vertex index.
        """
        appraisal = {}
        for vertex in self.graph.vertices:
            appraisal[vertex] = self._appraise(vertex)

        if Config.print:
            print("#"*10 + "Value and Potential Function" + "#"*10)
            pprint(appraisal)
            print("#"*10 + "Appraisal" + "#"*10)

        new_policy = {}
        for vertex in self.graph.vertices:
            neighbors = self.graph.adjacency[vertex]
            if Config.print:
                for v in neighbors:
                    print(f"({vertex} -> {v}): ({appraisal[v][0]}, {appraisal[v][1] - appraisal[vertex][0] + self.graph.adjacency[vertex][v]})")

            new_policy[vertex] = max(
                neighbors,
                key=lambda v: (
                    appraisal[v][0],
                    appraisal[v][1] - appraisal[vertex][0] + self.graph.adjacency[vertex][v],
                    -self.graph.vertices[v] if v != self.policy[vertex] else 1
                )
            )
        return new_policy

def quadr(n):
    """Build the ``Graph`` on vertices ``t1..tn`` and ``b1..bn`` for size *n*.

    Vertex order is ``t1, b1..bn, t2..tn``. Edges:

    * ``b_i -> b_j`` for ``j < i`` with weight ``(n+1)**2 + 1``, and
      ``b_i -> t_j`` for every ``j`` with weight 0;
    * ``t_i -> b_j`` for ``j <= i`` with weight ``(n+1)**2 + 1``,
      ``t_i -> t_j`` for ``j < i`` with weight 0, and a self-loop on ``t_i``
      with weight ``(n+1)**2 - n + (i-1)``.
    """
    heavy = (n + 1) ** 2 + 1
    indices = range(1, n + 1)

    vertices = ['t1']
    vertices.extend(f'b{i}' for i in indices)
    vertices.extend(f't{i}' for i in range(2, n + 1))

    edges = {}

    # Outgoing edges of the b-vertices: lower b's first, then every t.
    for i in indices:
        out = {f'b{j}': heavy for j in range(1, i)}
        for j in indices:
            out[f't{j}'] = 0
        edges[f'b{i}'] = out

    # Outgoing edges of the t-vertices: b's up to i, lower t's, then the self-loop.
    for i in indices:
        out = {f'b{j}': heavy for j in range(1, i + 1)}
        for j in range(1, i):
            out[f't{j}'] = 0
        out[f't{i}'] = (n + 1) ** 2 - n + (i - 1)
        edges[f't{i}'] = out

    return Graph(vertices, edges)



if __name__ == "__main__":
    # Parse the command line once (the result of a second call would be identical).
    args = parser.parse_args()
    n = args.n
    Config.print = args.print

    # quadr(n) always lists vertex 't1' but creates no edges for n < 1, so the
    # policy iteration would fail with a KeyError; report a usage error instead.
    if n < 1:
        parser.error("n must be a positive integer")

    # The timer covers graph construction as well as the policy iteration.
    current_time = time()
    graph = quadr(n)

    if Config.print:
        print("#"*10 + "Graph" + "#"*10)
        pprint(graph.adjacency)

    mpi = MeanPayoffPolicyIteration(graph)

    if Config.print:
        print("#"*10 + "Policy 0" + "#"*10)
        pprint(mpi.policy)

    policy, info = mpi.get_optimal_policy()

    print("#"*10 + "Time" + "#"*10)
    print(time() - current_time)

    if Config.print:
        print("#"*10 + "Optimal Policy" + "#"*10)
        pprint(policy)

    # "Actual" counts every policy visited, including the initial one; "Guess"
    # is the closed-form expression the script compares it against.
    print("#"*10 + "Num. Iterations" + "#"*10)
    print("Actual: ", len(info['policies']), "Guess: ", (n**2 + 7*n - 6) // 2)