from cmath import inf
import torch
import random
import numpy as np
import neptune
import csv
import os 
from itertools import combinations
import time
import json
import re
import ast


class Qtable_WQL:
    """Tabular value-iteration agent that moves servers between graph nodes to serve requests.

    A state is a pair ``(server_state_id, request_node)``. ``server_state_id``
    enumerates every sorted combination of ``num_servers`` distinct nodes out
    of ``num_nodes`` (see ``locations_from_id`` / ``locations_to_id``). An
    action is an index into that combination, i.e. which server answers the
    request. The answering server then moves onto the request node, unless a
    server is already there.

    ``env_rewards`` holds costs taken from ``env.cost_matrix``; they are
    minimized. After training, ``Q_table`` is stored negated, so ``argmax``
    (``Q_policy``) selects the cost-minimizing action.

    Lookup tables and the learned Q table / policy are cached under
    ``self.path`` and reused on later runs.
    """

    # Files written by ``precompute``; all must exist for the cache to be reused.
    _PRECOMPUTE_FILES = (
        'locations_from_id.json',
        'locations_to_id.json',
        'n_server_states.txt',
        'env_rewards.pth',
        'next_server_states.pth',
    )

    def __init__(self, env, seed = 42, lr=0.1, gamma = 0.99):
        """Seed RNGs, build/load lookup tables, and load or compute the Q policy.

        Args:
            env: environment providing ``batch_size``, ``num_nodes``,
                ``num_servers``, ``device``, ``graph_type``, ``var_distance``,
                ``probabilities`` and ``cost_matrix`` (all read in this class).
            seed: seed for ``random``, ``numpy`` and ``torch``.
            lr: stored as ``self.lr``; not used within this class.
            gamma: discount factor for the Bellman update.
        """
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.env = env
        # NOTE(review): original note says this "has to be 1 for Qtable",
        # though get_action accepts batched states — confirm.
        self.batch_sizex = env.batch_size
        self.num_nodes = self.env.num_nodes
        self.num_servers = self.env.num_servers
        self.lr = lr
        self.gamma = gamma
        self.device = self.env.device

        # Per-server reward accumulators; zero-initialised instead of torch.empty garbage.
        self.total_reward = torch.zeros(self.env.num_servers, device=self.device)
        self.total_reward_estimate = torch.zeros(self.env.num_servers, device=self.device)

        graph_group = self._graph_group()
        self.path = f"data/qtable_wql/var_distance{self.env.var_distance}/{graph_group}/{self.env.graph_type}_{self.num_nodes}_{self.num_servers}"
        print(self.path)
        self.precompute()

        q_policy_file = os.path.join(self.path, 'Q_policy.pth')
        q_table_file = os.path.join(self.path, 'Q_table.pth')
        if os.path.exists(q_policy_file) and os.path.exists(q_table_file):
            # map_location lets tables saved on another device (e.g. GPU) load here.
            self.Q_policy = torch.load(q_policy_file, map_location=self.device)
            self.Q_table = torch.load(q_table_file, map_location=self.device)
        else:
            # Recompute both, so Q_table and Q_policy always exist together.
            self.compute_Q_policy()

    def _graph_group(self):
        """Return the graph-family directory name derived from ``env.graph_type``.

        For types containing ``_`` this is the leading non-digit prefix that
        ends right before an underscore (e.g. ``"erdos_renyi_5"`` ->
        ``"erdos_renyi"``). Otherwise, or when that pattern does not match,
        the full ``graph_type`` is used.
        """
        graph_type = self.env.graph_type
        if '_' in graph_type:
            match = re.match(r'^\D+(?=_)', graph_type)
            if match is not None:
                return match.group(0)
        return graph_type

    def precompute(self):
        """Build, or load from the on-disk cache, the environment lookup tables.

        Sets ``locations_from_id`` (int -> sorted tuple of server nodes),
        ``locations_to_id`` (the inverse mapping), ``locations_to_id_str_keys``
        (the same mapping with str keys, as stored in JSON), ``n_server_states``,
        ``env_rewards`` (shape ``(n_server_states, num_nodes, num_servers)``)
        and ``next_server_states`` (same shape, int32).
        """
        start_time = time.time()
        os.makedirs(self.path, exist_ok=True)

        cache_complete = all(
            os.path.exists(os.path.join(self.path, name)) for name in self._PRECOMPUTE_FILES
        )
        if cache_complete:
            self._load_precomputed()
        else:
            self._build_precomputed()
            # Save everything so it need not be recomputed next time.
            self._save_precomputed()

        print(f"Running time for Precompute: {time.time() - start_time} seconds")

    def _load_precomputed(self):
        """Load the lookup tables written by ``_save_precomputed``."""
        with open(os.path.join(self.path, 'locations_from_id.json'), 'r') as f:
            raw_from_id = json.load(f)
        # JSON turns int keys into str and tuples into lists; restore the
        # in-memory types produced by _build_precomputed.
        self.locations_from_id = {int(key): tuple(value) for key, value in raw_from_id.items()}

        with open(os.path.join(self.path, 'locations_to_id.json'), 'r') as f:
            self.locations_to_id_str_keys = json.load(f)
        # Keys were stored as str(tuple); literal_eval safely parses them back.
        self.locations_to_id = {
            ast.literal_eval(key): value for key, value in self.locations_to_id_str_keys.items()
        }

        with open(os.path.join(self.path, 'n_server_states.txt'), 'r') as f:
            self.n_server_states = int(f.read())

        self.env_rewards = torch.load(
            os.path.join(self.path, 'env_rewards.pth'), map_location=self.device
        )
        self.next_server_states = torch.load(
            os.path.join(self.path, 'next_server_states.pth'), map_location=self.device
        )

    def _build_precomputed(self):
        """Enumerate server states and fill the reward and transition tables."""
        start_time_locations = time.time()
        self.locations_from_id = {}
        self.locations_to_id = {}
        for server_state_id, locations in enumerate(
            combinations(range(self.num_nodes), self.num_servers)
        ):
            self.locations_from_id[server_state_id] = locations
            self.locations_to_id[locations] = server_state_id
        self.n_server_states = len(self.locations_from_id)
        print(f"Running time for Precomputing Locations: {time.time() - start_time_locations} seconds")

        start_time_rewards = time.time()
        # env_rewards[s, r, a] = cost_matrix[node of server a in state s, request node r]
        self.env_rewards = torch.zeros(
            (self.n_server_states, self.num_nodes, self.num_servers), device=self.device
        )
        for s_server, locations in self.locations_from_id.items():
            for s_request in range(self.num_nodes):
                for a, from_node in enumerate(locations):
                    self.env_rewards[s_server, s_request, a] = self.env.cost_matrix[from_node, s_request]
        print(f"Running time for Precomputing Rewards: {time.time() - start_time_rewards} seconds")

        start_time_nss = time.time()
        # next_server_states[s, r, a] = state id after server a of state s moves to node r
        self.next_server_states = torch.zeros(
            (self.n_server_states, self.num_nodes, self.num_servers),
            dtype=torch.int32,
            device=self.device,
        )
        for s_server, locations in self.locations_from_id.items():
            for s_request in range(self.num_nodes):
                for a, from_node in enumerate(locations):
                    if s_request in locations:
                        # A server already sits on the request node: nobody moves,
                        # which avoids two servers on the same node.
                        self.next_server_states[s_server, s_request, a] = s_server
                    else:
                        next_locations = sorted(
                            [node for node in locations if node != from_node] + [s_request]
                        )
                        self.next_server_states[s_server, s_request, a] = self.locations_to_id[tuple(next_locations)]
        print(f"Running time for Precomputing Next Server States: {time.time() - start_time_nss} seconds")

    def _save_precomputed(self):
        """Write the lookup tables to ``self.path`` (read back by ``_load_precomputed``)."""
        with open(os.path.join(self.path, 'locations_from_id.json'), 'w') as f:
            json.dump(self.locations_from_id, f)

        # JSON cannot have tuple keys, so store str(tuple) and parse on load.
        self.locations_to_id_str_keys = {str(key): value for key, value in self.locations_to_id.items()}
        with open(os.path.join(self.path, 'locations_to_id.json'), 'w') as f:
            json.dump(self.locations_to_id_str_keys, f)

        with open(os.path.join(self.path, 'n_server_states.txt'), 'w') as f:
            f.write(str(self.n_server_states))

        torch.save(self.env_rewards, os.path.join(self.path, 'env_rewards.pth'))
        torch.save(self.next_server_states, os.path.join(self.path, 'next_server_states.pth'))

    def draw_request(self):
        """Sample a request node from ``env.probabilities`` using numpy's global RNG."""
        return np.random.choice(self.num_nodes, p= self.env.probabilities)

    def draw_next_state(self, s_server, s_request, a):
        """Return ``(next server state, freshly sampled request node)`` after action ``a``."""
        s_server_next = self.next_server_states[s_server, s_request, a]
        s_request_next = self.draw_request()
        return (s_server_next, s_request_next)

    def greedy_policy(self, s_server, s_request):
        """Pick the server with the smallest immediate cost for this request."""
        return np.argmin(self.env_rewards[s_server, s_request].cpu())

    def random_policy(self, s_server, s_request):
        """Pick a server uniformly at random (state is ignored)."""
        return np.random.choice(self.num_servers)

    def evaluate_policy(self, policy, n_iter=100_000):
        """Simulate ``policy`` for ``n_iter`` steps and return the mean per-step cost.

        Starts from server state 0 with request node 0. Reseeds numpy's global
        RNG with ``self.seed`` so every policy sees the same request stream.

        Args:
            policy: callable ``(s_server, s_request) -> action index``.
            n_iter: number of simulated steps.

        Returns:
            0-d tensor holding the mean of ``env_rewards`` over the trajectory.
        """
        np.random.seed(self.seed)
        s_server = 0
        s_request = 0
        rewards = []
        for _ in range(n_iter):
            a = policy(s_server, s_request)
            rewards.append(self.env_rewards[s_server, s_request, a].item())
            (s_server, s_request) = self.draw_next_state(s_server, s_request, a)
        return torch.mean(torch.tensor(rewards))

    def compute_Q_policy(self, n_iter=100, save = False):
        """Run ``n_iter`` sweeps of synchronous value iteration, then save the result.

        Each sweep applies the Bellman update
            Q(s, r, a) <- cost(s, r, a) + gamma * sum_r' p(r') * min_a' Q(s', r', a')
        with ``s' = next_server_states[s, r, a]`` and ``p = env.probabilities``.
        Afterwards ``Q_table`` is negated (so argmax picks the cheapest action),
        and ``Q_policy = argmax(Q_table, dim=-1)``. Both are written to ``self.path``.

        Args:
            n_iter: number of Bellman sweeps.
            save: currently ignored; the tables are always written (``__init__``
                reloads them on later runs).
        """
        print_interval = max(1, n_iter // 10)  # report ~10 times; never 0 (avoids i % 0)
        start_time = time.time()

        self.Q_table = torch.zeros(
            (self.n_server_states, self.num_nodes, self.num_servers),
            dtype=torch.float64,
            device=self.device,
        )
        arrival_rate_vector = torch.tensor(self.env.probabilities, dtype=torch.float64).view(self.num_nodes, 1).to(self.device)
        next_states = self.next_server_states.long()  # loop-invariant index tensor
        for i in range(n_iter):
            V_server_request = torch.min(self.Q_table, dim=-1)[0]  # value of state = min_action Q(state, action)
            # Expected value over the random next request. squeeze(-1) (not a bare
            # squeeze) keeps a 1-D result even when there is only one server state.
            V_server = torch.matmul(V_server_request, arrival_rate_vector).squeeze(-1)
            V_discounted = self.gamma * V_server  # temporal discount for value of next state
            Q_table_updated = self.env_rewards + V_discounted[next_states]  # Bellman equation
            if i % print_interval == 0:
                policy = torch.argmin(self.Q_table, dim=-1)
                policy_updated = torch.argmin(Q_table_updated, dim=-1)
                n_policy_changes = torch.sum(policy != policy_updated)
                pctg_policy_changes = n_policy_changes / (self.n_server_states * self.num_nodes) * 100
                Q_change = torch.mean(torch.abs(self.Q_table - Q_table_updated))
                print(f"Iteration {i + 1}: average Q change = {Q_change:.3f}, number of decisions changed = {n_policy_changes} ({pctg_policy_changes:.1f}%)")
            self.Q_table = Q_table_updated
        self.Q_table = -self.Q_table
        print(f"Running time: {time.time() - start_time} seconds")
        self.Q_policy = torch.argmax(self.Q_table, dim=-1)
        torch.save(self.Q_table, os.path.join(self.path, 'Q_table.pth'))
        torch.save(self.Q_policy, os.path.join(self.path, 'Q_policy.pth'))

    def compare_policies(self):
        """Evaluate greedy, random and Q-table policies; print and write them to CSV.

        Printed/written values are the negated mean costs from
        ``evaluate_policy``. The CSV is ``compare_policies.csv`` in ``self.path``.
        """
        mean_greedy = self.evaluate_policy(self.greedy_policy)
        print("Greedy policy:", -mean_greedy.item())
        mean_random = self.evaluate_policy(self.random_policy)
        print("Random policy:", -mean_random.item())
        mean_qtable = self.evaluate_policy(lambda s_server, s_request: self.Q_policy[s_server, s_request])
        print("Qtable:", -mean_qtable.item())

        output_file_name = os.path.join(self.path, 'compare_policies.csv')
        with open(output_file_name, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['graph_type', 'agent', 'gamma','num_nodes', 'num_servers', 'estimate'])
            writer.writerow([self.env.graph_type, self.class_name, self.gamma, self.env.num_nodes, self.num_servers,  round(-mean_qtable.item(), 3)])
            writer.writerow([self.env.graph_type, 'random', self.gamma, self.env.num_nodes, self.num_servers,  round(-mean_random.item(), 3)])
            writer.writerow([self.env.graph_type, 'greedy', self.gamma, self.env.num_nodes, self.num_servers,  round(-mean_greedy.item(), 3)])

    def get_action(self, state):
        """Return, per batch row, the node of the server chosen by ``Q_policy``.

        Args:
            state: 2-D tensor whose first ``num_servers`` columns are server
                node positions and whose next column is the request node.
                Assumes each row's server positions are sorted ascending, as in
                the ``locations_to_id`` keys — TODO confirm with env.

        Returns:
            ``(batch, 1)`` tensor of server node positions gathered from ``state``.
        """
        qt_index = state[:, :self.env.num_servers].to(self.device)
        server_location_batch = state[:, self.env.num_servers].to(self.device).long()
        server_tuples = [tuple(row) for row in qt_index.int().tolist()]
        q_id_tensor = torch.tensor(
            [self.locations_to_id[locations] for locations in server_tuples],
            dtype=torch.long,
            device=self.device,
        )
        chosen_server = self.Q_policy[q_id_tensor, server_location_batch].to(self.device)
        return torch.gather(qt_index, 1, chosen_server.view(-1, 1))

    def estimate_seq(self, state, requests):
        """Roll out ``Q_policy`` through ``env.step`` over a fixed request sequence.

        Args:
            state: initial batched state (see ``get_action``).
            requests: tensor whose column ``i`` is the next request for step ``i``.

        Returns:
            ``(sum of rewards, 0.25 quantile, 0.75 quantile, rewards)``, where
            ``rewards`` has one column per step.
        """
        total_reward_episode = []
        for i in range(requests.shape[1]):
            action = self.get_action(state).to(self.device)
            next_state, reward, _ = self.env.step(action, state, next_req = requests[:, i].reshape(state.shape[0],1))
            state = next_state
            total_reward_episode.append(reward.reshape(state.shape[0],1))

        estimates = torch.cat(total_reward_episode, dim=1)
        return torch.sum(estimates), torch.quantile(estimates, 0.25), torch.quantile(estimates, 0.75), estimates

    @property
    def class_name(self):
        """Name of the concrete class (used as the agent label in the CSV)."""
        return self.__class__.__name__

      