import torch
import random
import networkx as nx
from copy import deepcopy
import time




def WFA(s_init, s_current, h_r, k, sp, L = -10000, verbose=False):
    """Return the work-function score of serving the latest request with server ``k``.

    A min-cost max-flow network is built over the request history:

    * ``source`` feeds one unit of flow into each server's initial location;
    * every request ``t`` is split into ``t -> t_dash`` with weight ``L``
      (a large negative reward, so the optimum routes a server through it);
    * a served request may lead to any later request or to any final
      server slot, paying the shortest-path distance from ``sp``;
    * a server may also go straight from its initial to a final slot;
    * each final slot drains one unit into ``sink``.

    The final configuration is ``s_current`` with server ``k`` moved onto the
    latest request ``h_r[-1]``.  The returned score is the cost of that flow,
    corrected by ``len(h_r) * L``, plus the distance server ``k`` travels from
    its current location to the latest request.

    Nodes for initial/final slots are keyed by *server index*, not by graph
    location, so two servers that share a location keep separate unit-capacity
    slots (keying by location would merge them and silently drop one server
    from the flow).

    Args:
        s_init: Sequence of initial server locations (keys of ``sp``).
        s_current: Sequence of current server locations; not modified.
        h_r: Request history, oldest first; ``h_r[-1]`` is the request being
            served.  Must be non-empty.
        k: Index into ``s_current`` of the candidate server.
        sp: Nested mapping ``sp[a][b]`` giving the distance from ``a`` to ``b``.
            NOTE(review): networkx's min-cost-flow solver is documented to be
            reliable with integer weights only -- confirm ``sp`` stays integral.
        L: Weight placed on each ``t -> t_dash`` edge.  The correction
            ``- len(h_r) * L`` assumes the optimum serves every request, which
            presumably requires ``|L|`` to exceed any achievable path cost.
        verbose: When true, print the offline (flow) part of the score.

    Returns:
        The offline flow cost plus ``sp[s_current[k]][h_r[-1]]``.

    Raises:
        ValueError: If ``h_r`` is empty.
    """
    if not h_r:
        raise ValueError("h_r must contain at least the request being served")

    num_requests = len(h_r)
    current_request = h_r[-1]

    # Hypothetical final configuration: server k sits on the current request.
    final_config = list(s_current)
    final_config[k] = current_request

    # (node name, graph location) per server slot, keyed by server index.
    start_slots = [('s_start_' + str(idx), loc) for idx, loc in enumerate(s_init)]
    final_slots = [('s_curr_' + str(idx), loc) for idx, loc in enumerate(final_config)]

    G_dash = nx.DiGraph()
    # Keep the terminals present even if there are no servers at all.
    G_dash.add_node('source')
    G_dash.add_node('sink')

    for start_name, start_loc in start_slots:
        G_dash.add_edge('source', start_name, capacity=1, weight=0)
        # A server may begin by serving any request; this only enables the
        # option, the flow decides which requests it actually serves.
        for t, request in enumerate(h_r):
            G_dash.add_edge(start_name, str(t), capacity=1, weight=sp[start_loc][request])
        # ...or it may move directly to a final slot without serving anything.
        for final_name, final_loc in final_slots:
            G_dash.add_edge(start_name, final_name, capacity=1, weight=sp[start_loc][final_loc])

    for t, request in enumerate(h_r):
        served = str(t) + '_dash'
        # Splitting r -> r' with capacity 1 lets at most one server claim r.
        G_dash.add_edge(str(t), served, capacity=1, weight=L)
        # After serving r_t a server can continue to any later request...
        for later in range(t + 1, num_requests):
            G_dash.add_edge(served, str(later), capacity=1, weight=sp[request][h_r[later]])
        # ...or settle into one of the final slots.
        for final_name, final_loc in final_slots:
            G_dash.add_edge(served, final_name, capacity=1, weight=sp[request][final_loc])

    for final_name, _ in final_slots:
        G_dash.add_edge(final_name, 'sink', capacity=1, weight=0)

    # Solve once and reuse the flow for both the log line and the result.
    flow = nx.max_flow_min_cost(G_dash, 'source', 'sink')
    offline_cost = nx.cost_of_flow(G_dash, flow) - (num_requests * L)
    if verbose:
        print("Offline Cost is: {}".format(offline_cost))

    # Offline term + the online cost of moving server k to the request.
    return offline_cost + sp[s_current[k]][current_request]

class WorkFunction():
    """Baseline policy that picks, per request, the server with the lowest ``WFA`` score.

    The environment object ``env`` is expected to expose ``batch_size``,
    ``device``, ``num_servers``, ``num_nodes``, ``request_same_node``,
    ``cost_matrix`` (rows indexable as ``row[j].item()``) and a ``step``
    method.  NOTE(review): the exact contract of ``env.step`` is not visible
    here; it is used as ``step(action, state[, next_req=...])`` returning
    ``(next_state, reward, _)`` where the last column of a state is the
    pending request and the preceding columns are server locations.
    """

    def __init__(self, env, num_requests = 100):
        """Sample an initial server placement and cache the distance table.

        Args:
            env: Environment to act in; its ``batch_size`` must not exceed 1.
            num_requests: Stored on the instance; not used by the methods
                defined in this class.

        Raises:
            ValueError: If ``env.batch_size`` is greater than 1.
        """
        self.env = env
        self.batch_size = env.batch_size  # has to be 1 for Qtable
        self.num_requests = num_requests
        if self.batch_size > 1:
            raise ValueError('The environment batch size has to be equal to 1')
        self.device = self.env.device
        # Uninitialised placeholder of length num_servers; every reader slices
        # it off with ``[self.env.num_servers:]`` before using the rewards.
        self.total_reward = torch.empty((self.env.num_servers)).to(self.device)
        self.requests = []
        self.server_init = sorted(random.sample(range(self.env.num_nodes), self.env.num_servers))
        # Maximum number of most-recent requests passed to WFA by estimate_seq.
        self.lookout_w = 100
        # When requests may not land on a server's node, draw the first
        # request from the nodes not occupied by the initial placement.
        if not self.env.request_same_node:
            first_request = random.randint(0, self.env.num_nodes - 1)
            while first_request in self.server_init:
                first_request = random.randint(0, self.env.num_nodes - 1)
            self.requests.append(first_request)
        # Nested dict view of the cost matrix: sp[i][j] -> int distance.
        self.sp = {}
        for i, row in enumerate(env.cost_matrix):
            self.sp[i] = {j: int(row[j].item()) for j in range(len(row))}

    def _best_server(self, server_current, history):
        """Return the index of the server with the smallest WFA score (first wins on ties)."""
        min_cost = 1e10
        best_k = 0
        for k in range(len(self.server_init)):
            k_cost = WFA(s_init=self.server_init, s_current=server_current, h_r=history, k=k, sp=self.sp)
            if k_cost < min_cost:
                min_cost = k_cost
                best_k = k
        return best_k

    def estimate(self, num_steps=1,  print_results = False):
        """Run the policy from the initial placement and summarise the rewards.

        Args:
            num_steps: Multiplier; ``int(num_steps * 1000 / batch_size)``
                environment steps are executed.
            print_results: When true, print the running average reward every
                ``int(10000 / batch_size)`` steps and once at the end.

        Returns:
            ``(mean, 25% quantile, 75% quantile, rewards)`` over all rewards
            accumulated on this instance so far.
        """
        steps_for_display = int(10000/self.batch_size)
        num_steps = int(num_steps*1000/self.batch_size)

        # __init__ only seeds a first request when request_same_node is falsy;
        # without a seed the initial state below cannot be built.
        if not self.requests:
            self.requests.append(random.randint(0, self.env.num_nodes - 1))

        server_positions = torch.tensor(sorted(self.server_init))
        first_request = torch.tensor([self.requests[0]])
        state = torch.cat((server_positions, first_request)).unsqueeze(0).to(self.device)
        print("start")
        start_time = time.time()

        for step in range(num_steps):
            actions_all = state[:, :-1]
            server_current = actions_all.squeeze(0).tolist()
            best_k = self._best_server(server_current, self.requests[:step + 1])

            # The action is the current location of the chosen server.
            action = actions_all[:, best_k]
            next_state, reward, _ = self.env.step(action.unsqueeze(0), state)
            self.requests.append(next_state[:, -1].item())
            self.total_reward = torch.cat((self.total_reward, reward.to(self.device)), 0)
            state = next_state.to(self.device)
            if print_results and (step + 1) % steps_for_display == 0:
                print(f"Step {step+1},  Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

        if print_results:
            print(f"Average Reward {torch.mean(self.total_reward[self.env.num_servers:]):.2f}")

        # Drop the uninitialised placeholder prefix created in __init__.
        estimates = self.total_reward[self.env.num_servers:]

        elapsed_time = time.time() - start_time
        print(f"Overall, it took {round(elapsed_time, 3)}")

        return torch.mean(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    def estimate_seq(self, state, requests):
        """Run the policy on a given start state and a fixed request sequence.

        Args:
            state: Start state; the last column is the pending request and the
                other columns are server locations.
            requests: Tensor whose column ``i`` is handed to ``env.step`` as
                ``next_req`` at step ``i``.

        Returns:
            ``(sum, 25% quantile, 75% quantile, rewards)`` for this episode,
            with ``rewards`` holding one column per step.
        """
        total_reward_episode = []
        self.requests.append(state[:, -1].item())
        for i in range(requests.shape[1]):
            actions_all = state[:, :-1]
            server_current = actions_all.squeeze(0).tolist()
            # Only the most recent lookout_w requests are considered; the
            # slice yields the whole history while it is still shorter.
            history = self.requests[-self.lookout_w:]
            best_k = self._best_server(server_current, history)

            action = actions_all[:, best_k]
            next_req = requests[:, i].reshape(state.shape[0], 1)
            next_state, reward, _ = self.env.step(action.unsqueeze(0), state, next_req = next_req)
            self.requests.append(next_state[:, -1].item())
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0], 1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates






