import torch
import numpy as np
import random
import gymnasium as gym
from torch.distributions import Categorical
from copy import deepcopy

class QLearning:
    """Tabular Q-learning agent keyed on (environment observation, LDBA observation).

    A state key is the tuple of Python scalars obtained by concatenating
    ``obs[0]`` with ``info['ldba_obs']``. Q-values are stored in ``self.qs``, a
    dict mapping each key to a float tensor of length ``n_actions``. The replay
    buffer ``self.buffer`` keeps only the most recent transition for every
    ``(key, action)`` pair.
    """

    def __init__(self, envs, args):
        """Read environment metadata and initialise empty tables.

        Args:
            envs: vectorised environment; ``envs.call('unwrapped')[0]`` must
                provide ``get_ldba()`` and ``get_jump_mask()``.
            args: config namespace; this class reads ``q_epsilon``,
                ``batch_size``, ``gamma``, ``q_eventual`` and ``q_lr``.

        Raises:
            Exception: if the single action space is not ``gym.spaces.Discrete``.
        """
        self.args = args
        base_env = envs.call('unwrapped')[0]
        # _ldba, jump_mask and n_states are not read by any other method of this class.
        self._ldba = deepcopy(base_env.get_ldba())
        self.jump_mask = torch.tensor(base_env.get_jump_mask())
        self.n_states = self.jump_mask.shape[0]
        self.is_discrete = isinstance(envs.unwrapped.single_action_space, gym.spaces.Discrete)
        if not self.is_discrete:
            raise Exception('Q-learning only supports tabular settings.')
        self.n_actions = envs.unwrapped.single_action_space.n
        self.qs, self.buffer = {}, {}
        # key -> tensor of ones; an entry is set to 0 once that action is chosen there.
        self.unseen_actions = {}
        self.epsilon = args.q_epsilon
        self.update_frequency = 1

    def anneal(self, update, num_updates):
        """Linearly decay epsilon from ``args.q_epsilon`` (at ``update == 1``)."""
        frac = 1.0 - (update - 1.0) / num_updates
        self.epsilon = frac * self.args.q_epsilon

    @staticmethod
    def _make_key(obs, info):
        """Return the hashable table key for ``obs[0]`` concatenated with ``info['ldba_obs']``."""
        return tuple(e.item() for e in torch.cat([obs[0], info['ldba_obs']], -1))

    def _greedy(self, key):
        """Sample from softmax(Q[key] / 1e-6): near-argmax, with ties sharing the mass."""
        probs = torch.softmax(self.qs[key] / 1e-6, dim=0)
        return Categorical(probs=probs).sample().reshape(1, 1)

    def store(self, obs, info, done, action, reward, logprob, value, next_obs, next_info, next_done, step):
        """Record one transition, overwriting any earlier one for the same (state, action).

        ``done``, ``logprob``, ``value`` and ``step`` are not used here. Zero
        Q-rows are created for the current and next keys if missing, and the
        current key gets an all-ones ``unseen_actions`` row if missing.
        """
        key = self._make_key(obs, info)
        next_key = self._make_key(next_obs, next_info)
        # Evaluate everything that can fail before mutating the buffer, so a bad
        # input never leaves an empty action dict behind.
        a = action.item()
        ap = tuple(e.item() for e in next_info['ap'][0])
        transition = (next_key, reward.item(), next_done.item(), ap)
        self.buffer.setdefault(key, {})[a] = transition
        if key not in self.qs:
            self.qs[key] = torch.zeros(self.n_actions)
        if next_key not in self.qs:
            self.qs[next_key] = torch.zeros(self.n_actions)
        if key not in self.unseen_actions:
            self.unseen_actions[key] = torch.ones(self.n_actions)

    def get_action_and_value(self, x, action=None, info=None):
        """Pick an epsilon-greedy action for ``x`` and mark it as seen.

        Args:
            x: observation batch; only ``x[0]`` contributes to the key.
            action: ignored; a fresh action is always selected.
            info: dict with an ``'ldba_obs'`` tensor (required despite the default).

        Returns:
            ``(action, zero, zero, zeros_like(action))`` with ``action`` of shape
            (1, 1). The zero placeholders presumably fill the log-prob/entropy/value
            slots of a shared agent interface (confirm against the caller).
        """
        info = {} if info is None else info
        key = self._make_key(x, info)
        if torch.rand(1).item() < self.epsilon or key not in self.qs:
            # Explore, or no Q-row exists yet: uniform random action.
            action = torch.randint(self.n_actions, (1, 1))
        else:
            action = self._greedy(key)

        if key not in self.unseen_actions:
            self.unseen_actions[key] = torch.ones(self.n_actions)
        self.unseen_actions[key][action] = 0
        dummy = action * 0
        return action, dummy.squeeze(), dummy.squeeze(), dummy

    def eval_action(self, x, info=None):
        """Return a near-greedy action of shape (1, 1), or a uniform one for unknown keys."""
        info = {} if info is None else info
        key = self._make_key(x, info)
        if key not in self.qs:
            return torch.randint(self.n_actions, (1, 1))
        return self._greedy(key)

    def _sample_batch(self):
        """Draw ``args.batch_size`` transitions: a uniform key, then a uniform stored action.

        Returns an empty list when nothing has been stored yet.
        """
        keys = list(self.buffer)
        if not keys:
            return []
        batch = []
        for _ in range(self.args.batch_size):
            key = random.choice(keys)
            a = int(random.choice(list(self.buffer[key])))
            next_key, r, done, _ = self.buffer[key][a]
            batch.append((key, a, next_key, r, done))
        return batch

    def train(self, step, n_stages, batch=None):
        """Apply one tabular TD update per transition in ``batch``.

        Args:
            step, n_stages: unused by this class.
            batch: sequence of ``(key, action, next_key, reward, done)``; sampled
                from the buffer when ``None``.

        Returns:
            dict with ``mean_q`` (mean post-update Q of the updated entries) and
            ``q_loss`` (mean squared pre-update TD error). Both are zero for an
            empty batch.
        """
        batch = self._sample_batch() if batch is None else batch
        if len(batch) == 0:
            return {"mean_q": torch.tensor(0.0), "q_loss": torch.tensor(0.0)}
        mean_q, q_loss = 0, 0
        # NOTE: ``done`` is not used; the target always bootstraps from next_key.
        for key, a, next_key, r, done in batch:
            # With q_eventual, gamma is applied only on rewarding steps (gamma**False == 1).
            discount = self.args.gamma**(r > 0 if self.args.q_eventual else 1)
            target = (r + discount * torch.max(self.qs[next_key])).item()
            q_loss += (self.qs[key][a] - target)**2
            self.qs[key][a] += self.args.q_lr * (target - self.qs[key][a])
            mean_q += self.qs[key][a]
        return {
            "mean_q": mean_q / len(batch),
            "q_loss": q_loss / len(batch),
        }

    def save(self, working_dir):
        """No-op: Q-tables are not persisted."""
        pass


class CycleQLearning:
    """Tabular Q-learning variant with a fixed target for final-stage states.

    Identical to plain tabular Q-learning over keys built from ``obs[0]``
    concatenated with ``info['ldba_obs']``, except that ``train`` regresses any
    key whose ``key[2]`` equals ``n_stages`` toward ``FINAL_STAGE_TARGET``
    instead of a bootstrapped TD target.
    """

    # Fixed regression target used in ``train`` when key[2] == n_stages.
    # Overridable per subclass or instance.
    FINAL_STAGE_TARGET = 150

    def __init__(self, envs, args):
        """Read environment metadata and initialise empty tables.

        Args:
            envs: vectorised environment; ``envs.call('unwrapped')[0]`` must
                provide ``get_ldba()`` and ``get_jump_mask()``.
            args: config namespace; this class reads ``q_epsilon``,
                ``batch_size``, ``gamma``, ``q_eventual`` and ``q_lr``.

        Raises:
            Exception: if the single action space is not ``gym.spaces.Discrete``.
        """
        self.args = args
        base_env = envs.call('unwrapped')[0]
        # _ldba, jump_mask and n_states are not read by any other method of this class.
        self._ldba = deepcopy(base_env.get_ldba())
        self.jump_mask = torch.tensor(base_env.get_jump_mask())
        self.n_states = self.jump_mask.shape[0]
        self.is_discrete = isinstance(envs.unwrapped.single_action_space, gym.spaces.Discrete)
        if not self.is_discrete:
            raise Exception('Q-learning only supports tabular settings.')
        self.n_actions = envs.unwrapped.single_action_space.n
        self.qs, self.buffer = {}, {}
        # key -> tensor of ones; an entry is set to 0 once that action is chosen there.
        self.unseen_actions = {}
        self.epsilon = args.q_epsilon
        self.update_frequency = 1

    def anneal(self, update, num_updates):
        """Linearly decay epsilon from ``args.q_epsilon`` (at ``update == 1``)."""
        frac = 1.0 - (update - 1.0) / num_updates
        self.epsilon = frac * self.args.q_epsilon

    @staticmethod
    def _make_key(obs, info):
        """Return the hashable table key for ``obs[0]`` concatenated with ``info['ldba_obs']``."""
        return tuple(e.item() for e in torch.cat([obs[0], info['ldba_obs']], -1))

    def _greedy(self, key):
        """Sample from softmax(Q[key] / 1e-6): near-argmax, with ties sharing the mass."""
        probs = torch.softmax(self.qs[key] / 1e-6, dim=0)
        return Categorical(probs=probs).sample().reshape(1, 1)

    def store(self, obs, info, done, action, reward, logprob, value, next_obs, next_info, next_done, step):
        """Record one transition, overwriting any earlier one for the same (state, action).

        ``done``, ``logprob``, ``value`` and ``step`` are not used here. Zero
        Q-rows are created for the current and next keys if missing, and the
        current key gets an all-ones ``unseen_actions`` row if missing.
        """
        key = self._make_key(obs, info)
        next_key = self._make_key(next_obs, next_info)
        # Evaluate everything that can fail before mutating the buffer, so a bad
        # input never leaves an empty action dict behind.
        a = action.item()
        ap = tuple(e.item() for e in next_info['ap'][0])
        transition = (next_key, reward.item(), next_done.item(), ap)
        self.buffer.setdefault(key, {})[a] = transition
        if key not in self.qs:
            self.qs[key] = torch.zeros(self.n_actions)
        if next_key not in self.qs:
            self.qs[next_key] = torch.zeros(self.n_actions)
        if key not in self.unseen_actions:
            self.unseen_actions[key] = torch.ones(self.n_actions)

    def get_action_and_value(self, x, action=None, info=None):
        """Pick an epsilon-greedy action for ``x`` and mark it as seen.

        Args:
            x: observation batch; only ``x[0]`` contributes to the key.
            action: ignored; a fresh action is always selected.
            info: dict with an ``'ldba_obs'`` tensor (required despite the default).

        Returns:
            ``(action, zero, zero, zeros_like(action))`` with ``action`` of shape
            (1, 1). The zero placeholders presumably fill the log-prob/entropy/value
            slots of a shared agent interface (confirm against the caller).
        """
        info = {} if info is None else info
        key = self._make_key(x, info)
        if torch.rand(1).item() < self.epsilon or key not in self.qs:
            # Explore, or no Q-row exists yet: uniform random action.
            action = torch.randint(self.n_actions, (1, 1))
        else:
            action = self._greedy(key)

        if key not in self.unseen_actions:
            self.unseen_actions[key] = torch.ones(self.n_actions)
        self.unseen_actions[key][action] = 0
        dummy = action * 0
        return action, dummy.squeeze(), dummy.squeeze(), dummy

    def eval_action(self, x, info=None):
        """Return a near-greedy action of shape (1, 1), or a uniform one for unknown keys."""
        info = {} if info is None else info
        key = self._make_key(x, info)
        if key not in self.qs:
            return torch.randint(self.n_actions, (1, 1))
        return self._greedy(key)

    def _sample_batch(self):
        """Draw ``args.batch_size`` transitions: a uniform key, then a uniform stored action.

        Returns an empty list when nothing has been stored yet.
        """
        keys = list(self.buffer)
        if not keys:
            return []
        batch = []
        for _ in range(self.args.batch_size):
            key = random.choice(keys)
            a = int(random.choice(list(self.buffer[key])))
            next_key, r, done, _ = self.buffer[key][a]
            batch.append((key, a, next_key, r, done))
        return batch

    def train(self, step, n_stages, batch=None):
        """Apply one tabular TD update per transition in ``batch``.

        Keys whose ``key[2]`` equals ``n_stages`` are regressed toward
        ``FINAL_STAGE_TARGET``. Presumably ``key[2]`` is the stage/LDBA component
        of the key; confirm against the environment's observation layout.

        Args:
            step: unused.
            n_stages: value of ``key[2]`` that selects the fixed target.
            batch: sequence of ``(key, action, next_key, reward, done)``; sampled
                from the buffer when ``None``.

        Returns:
            dict with ``mean_q`` (mean post-update Q of the updated entries) and
            ``q_loss`` (mean squared pre-update TD error). Both are zero for an
            empty batch.
        """
        batch = self._sample_batch() if batch is None else batch
        if len(batch) == 0:
            return {"mean_q": torch.tensor(0.0), "q_loss": torch.tensor(0.0)}
        mean_q, q_loss = 0, 0
        # NOTE: ``done`` is not used; non-final keys always bootstrap from next_key.
        for key, a, next_key, r, done in batch:
            if key[2] == n_stages:
                target = self.FINAL_STAGE_TARGET
            else:
                # With q_eventual, gamma is applied only on rewarding steps (gamma**False == 1).
                discount = self.args.gamma**(r > 0 if self.args.q_eventual else 1)
                target = (r + discount * torch.max(self.qs[next_key])).item()
            q_loss += (self.qs[key][a] - target)**2
            self.qs[key][a] += self.args.q_lr * (target - self.qs[key][a])
            mean_q += self.qs[key][a]
        return {
            "mean_q": mean_q / len(batch),
            "q_loss": q_loss / len(batch),
        }

    def save(self, working_dir):
        """No-op: Q-tables are not persisted."""
        pass