import numpy as np
import gym
from scipy.linalg import solve_discrete_are
from lqr_env import LQREnv
from IPython import embed
import torch
import scipy
from bandit_env import Controller
import time
# Run tensors on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Unbounded limits (+/- infinity). Neither constant is referenced in the
# visible part of this file.
UPPER_BOUND = np.inf
LOWER_BOUND = -np.inf


def sample(dim, H):
    """Create a DarkroomEnv whose goal cell is drawn uniformly at random.

    Args:
        dim: Side length of the grid; each goal coordinate is drawn from
            the integers in [0, dim).
        H: Episode horizon handed through to the environment.

    Returns:
        The newly constructed DarkroomEnv.
    """
    return DarkroomEnv(dim, np.random.randint(0, dim, 2), H)



class DarkroomEnv(LQREnv):
    """Square grid world with a single rewarded goal cell.

    A state is a pair of integer coordinates, each kept inside [0, dim - 1].
    An action is a one-hot vector of length ``du``; its argmax selects the
    move: 0 / 1 increment / decrement the first coordinate, 2 / 3 increment /
    decrement the second coordinate, and 4 leaves the state unchanged.
    """

    # action index -> (coordinate axis, step); indices not listed do not move.
    _MOVES = {0: (0, 1), 1: (0, -1), 2: (1, 1), 3: (1, -1)}

    def __init__(self, dim, goal, H):
        """Store grid size, goal and horizon, and build the gym spaces.

        Args:
            dim: Side length of the grid.
            goal: Goal cell as a length-2 sequence of coordinates.
            H: Number of steps in one episode.
        """
        # NOTE(review): LQREnv.__init__ is not invoked here; confirm that
        # this is intentional.
        self.dim = dim
        self.goal = np.array(goal)
        self.H = H
        self.dx = 2
        self.du = 5
        self.observation_space = gym.spaces.Box(low=0, high=dim - 1, shape=(self.dx,))
        self.action_space = gym.spaces.Discrete(self.du)

    def sample_x(self):
        """Return a uniformly random grid cell as an integer array."""
        return np.random.randint(0, self.dim, self.dx)

    def sample_u(self):
        """Return a uniformly random one-hot action vector."""
        n_actions = self.action_space.n
        one_hot = np.zeros(n_actions)
        one_hot[np.random.randint(0, n_actions)] = 1
        return one_hot

    def reset(self):
        """Start a new episode at cell (0, 0) and return the initial state.

        A copy is returned (as ``step`` does) so that in-place edits made by
        the caller cannot alter the environment's internal state.
        """
        self.current_step = 0
        self.state = np.array([0, 0])
        return self.state.copy()

    def transit(self, s, a):
        """Apply action ``a`` to state ``s`` without touching episode state.

        Args:
            s: Current cell (length-2 array-like).
            a: Action scores / one-hot vector; the argmax is the move taken.

        Returns:
            Tuple ``(next_state, reward)`` where the reward is 1 when the
            next state equals the goal and 0 otherwise.
        """
        idx = int(np.argmax(a))
        # argmax indexes the flattened input, so this trips on oversize actions.
        assert 0 <= idx < self.action_space.n
        nxt = np.array(s)  # copy: the caller's array is never modified
        move = self._MOVES.get(idx)
        if move is not None:
            axis, delta = move
            nxt[axis] += delta
        # Moves that would leave the grid keep the agent on the border.
        nxt = np.clip(nxt, 0, self.dim - 1)

        r = 1 if np.all(nxt == self.goal) else 0
        return nxt, r

    def step(self, action):
        """Advance the episode by one action.

        Returns:
            Tuple ``(state, reward, done, info)``; ``done`` becomes True once
            ``H`` steps have been taken and ``info`` is an empty dict.

        Raises:
            ValueError: If called after the horizon has been reached.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        self.state, r = self.transit(self.state, action)
        self.current_step += 1
        done = (self.current_step >= self.H)
        return self.state.copy(), r, done, {}

    def opt_a(self, x):
        """Return the one-hot action that moves ``x`` toward the goal.

        The first coordinate is corrected before the second; once both match
        the goal the stay action (index 4) is returned.
        """
        if x[0] < self.goal[0]:
            a = 0
        elif x[0] > self.goal[0]:
            a = 1
        elif x[1] < self.goal[1]:
            a = 2
        elif x[1] > self.goal[1]:
            a = 3
        else:
            a = 4
        one_hot = np.zeros(self.action_space.n)
        one_hot[a] = 1
        return one_hot


class DarkroomEnvStitch(DarkroomEnv):
    """
    Darkroom environment with two goals, one on the right and one on the bottom.
    If the goal is on the right, the agent is initialized on the left during train and top during eval.
    If the goal is on the bottom, the agent is initialized on the top during train and left during eval.

    Concretely, the two admissible goals are ``(dim // 2, dim - 1)`` and
    ``(dim - 1, dim // 2)``; ``initial_states[i]`` and ``demo_states[i]`` are
    the start cell and the demonstration path used with ``goals[i]``.
    """

    def __init__(self, dim, goal, H, eval=False):
        """Set up goals, start cells and demonstration paths.

        Args:
            dim: Side length of the grid.
            goal: Must equal one of the two admissible goals (checked with
                an assertion).
            H: Number of steps in one episode.
            eval: Selects the evaluation start cells / paths when true,
                the training ones otherwise.
        """
        mid = dim // 2
        self.goals = [np.array([mid, dim - 1]), np.array([dim - 1, mid])]
        assert any(np.all(goal == g) for g in self.goals)
        super().__init__(dim, goal, H)

        self.eval = eval
        if eval:
            # L-shaped paths: travel along one axis to the centre, then along the other.
            self.initial_states = [np.array([0, mid]), np.array([mid, 0])]
            self.demo_states = [
                [np.array([i, mid]) for i in range(mid)] + [np.array([mid, i]) for i in range(mid, dim)],
                [np.array([mid, i]) for i in range(mid)] + [np.array([i, mid]) for i in range(mid, dim)],
            ]
        else:
            # Straight paths through the centre line toward the goal.
            self.initial_states = [np.array([mid, 0]), np.array([0, mid])]
            self.demo_states = [
                [np.array([mid, i]) for i in range(dim)],
                [np.array([i, mid]) for i in range(dim)],
            ]

    def _goal_index(self):
        """Return 0 when the active goal is ``goals[0]``, otherwise 1."""
        return 0 if np.all(self.goal == self.goals[0]) else 1

    def reset(self):
        """Start a new episode at the start cell paired with the goal.

        Both the internal state and the returned array are copies, so caller
        edits can never change the stored start cell used by later resets.
        """
        self.current_step = 0
        self.state = self.initial_states[self._goal_index()].copy()
        return self.state.copy()

    def sample_stitch_x(self):
        """Return a random cell from the union of both demonstration paths.

        Only valid in eval mode (asserted). A copy is returned so the stored
        path is never aliased.
        """
        assert self.eval
        all_demo_states = self.demo_states[0] + self.demo_states[1]
        return all_demo_states[np.random.randint(0, len(all_demo_states))].copy()

    def sample_stitch_opt_a(self, x):
        """Return a one-hot action for a state drawn by ``sample_stitch_x``.

        Only valid in eval mode (asserted). When only the first coordinate is
        on the centre line, increment the second coordinate (action 2) until
        it reaches ``dim - 1``, then stay (action 4). When only the second
        coordinate is on the centre line, increment the first (action 0)
        until it reaches ``dim - 1``, then stay. In every other case,
        including the centre cell, actions 2 and 0 are equally likely.
        """
        assert self.eval
        mid = self.dim // 2
        last = self.dim - 1
        if x[0] == mid and x[1] != mid:
            a = 2 if x[1] < last else 4
        elif x[1] == mid and x[0] != mid:
            a = 0 if x[0] < last else 4
        else:
            a = 2 if np.random.rand() < 0.5 else 0
        one_hot = np.zeros(self.action_space.n)
        one_hot[a] = 1
        return one_hot

    def sample_opt_x(self):
        """Return a copy of a random cell on the active goal's demonstration path."""
        path = self.demo_states[self._goal_index()]
        return path[np.random.randint(0, len(path))].copy()

    def opt_a(self, x):
        """Return the one-hot action that follows the demonstration path.

        In training mode, and in eval mode with ``goals[0]`` (down then
        right), this is the parent policy, which fixes the first coordinate
        before the second. In eval mode with ``goals[1]`` (right then down)
        the second coordinate is fixed first.
        """
        if not self.eval or self._goal_index() == 0:
            # down then right
            return super().opt_a(x)

        # right then down
        if x[1] < self.goal[1]:
            a = 2
        elif x[1] > self.goal[1]:
            a = 3
        elif x[0] < self.goal[0]:
            a = 0
        elif x[0] > self.goal[0]:
            a = 1
        else:
            a = 4
        one_hot = np.zeros(self.action_space.n)
        one_hot[a] = 1
        return one_hot


class DarkroomOptPolicy(Controller):
    """Controller that plays whatever action the environment reports as optimal."""

    def __init__(self, env):
        """Keep a handle on ``env`` and on its goal."""
        super().__init__()
        self.goal = env.goal
        self.env = env

    def reset(self):
        """Nothing to reset: the policy keeps no per-episode state."""
        return None

    def act(self, x):
        """Return ``env.opt_a(x)`` for the given state ``x``."""
        best_action = self.env.opt_a(x)
        return best_action
        
        
class RandCommit(Controller):
    """Act randomly until a rewarded transition shows up in the batch, then
    commit to walking toward the state that produced the reward."""

    def __init__(self, env):
        """Start with no known goal."""
        super().__init__()
        self.env = env
        self.goal = None

    def reset(self):
        """Forget any previously discovered goal."""
        self.goal = None

    def set_batch(self, batch):
        """Store ``batch`` and look for a rewarded transition in it.

        The flattened ``batch['rollin_rs']`` is scanned; when it is non-empty
        and its maximum is positive, the goal becomes row ``[0, i, :]`` of
        ``batch['rollin_xps']`` with ``i`` the argmax of the rewards.
        Otherwise the goal is cleared.
        """
        self.batch = batch
        rewards = batch['rollin_rs'].flatten().cpu().detach().numpy()
        found_reward = len(rewards) > 0 and np.max(rewards) > 0
        if not found_reward:
            self.goal = None
            return
        best = np.argmax(rewards)
        self.goal = batch['rollin_xps'][0, best, :].cpu().detach().numpy()

    def act(self, x):
        """Return a one-hot action for state ``x``.

        Without a goal the action is uniformly random; with one, the first
        coordinate is corrected before the second and action 4 is played
        once ``x`` matches the goal.
        """
        target = self.goal
        if target is None:
            choice = np.random.choice(np.arange(self.env.action_space.n))
        elif x[0] < target[0]:
            choice = 0
        elif x[0] > target[0]:
            choice = 1
        elif x[1] < target[1]:
            choice = 2
        elif x[1] > target[1]:
            choice = 3
        else:
            choice = 4
        one_hot = np.zeros(self.env.action_space.n)
        one_hot[choice] = 1
        return one_hot


class RandPolicy(DarkroomOptPolicy):
    """Controller that ignores the state and plays a uniformly random action.

    Construction is inherited from DarkroomOptPolicy, which already stores
    ``env`` on the instance.
    """

    def act(self, x):
        """Return a uniformly random one-hot action; ``x`` is unused.

        The number of actions is read from the environment's action space so
        the sampled index always fits the one-hot vector built from it.
        """
        n_actions = self.env.action_space.n
        a = np.random.choice(n_actions)
        one_hot = np.zeros(n_actions)
        one_hot[a] = 1
        return one_hot


class DarkroomTransformerController(Controller):
    """Controller that queries a model for action scores and plays one action.

    ``act`` reads ``self.batch``, which this class never assigns; it has to be
    set on the instance beforehand (presumably through a ``set_batch``
    provided by Controller; confirm against that class).
    """

    def __init__(self, model, sample=False):
        """Cache model dimensions and pre-allocate the zero tensors fed to it.

        Args:
            model: Callable model exposing ``config['du']``, ``config['dx']``
                and ``H``.
            sample: When true, actions are sampled from a softmax over the
                model output; otherwise the argmax is played.
        """
        # Match the other controllers in this file, which all run the base initialiser.
        super().__init__()
        self.model = model
        self.du = model.config['du']
        self.dx = model.config['dx']
        self.H = model.H
        # NOTE(review): dx**2 equals 2 * dx only when dx == 2; confirm which
        # width the model actually expects.
        self.zeros = torch.zeros(1, self.dx**2 + self.du + 1).float().to(device)
        self.zerosQ = torch.zeros(1, self.H, self.dx**2).float().to(device)
        self.sample = sample
        self.temp = 1.0  # softmax temperature used when sampling

    def act(self, x):
        """Return a one-hot action (length ``du``) for state ``x``.

        The current batch is augmented with the zero tensors and with ``x``
        as a single-row float tensor, then passed to the model; the first
        row of the output is used as the action scores.
        """
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ

        states = torch.tensor(x)[None, :].float().to(device)
        self.batch['states'] = states

        scores = self.model(self.batch)
        scores = scores.cpu().detach().numpy()[0]

        if self.sample:
            # Import the function explicitly: a bare `import scipy` does not
            # guarantee that the `scipy.special` subpackage is loaded.
            from scipy.special import softmax

            probs = softmax(scores / self.temp)
            i = np.random.choice(np.arange(self.du), p=probs)
        else:
            i = np.argmax(scores)

        a = np.zeros(self.du)
        a[i] = 1.0
        return a



if __name__ == '__main__':
    # Interactive smoke test: build a small random environment and a random
    # policy, then drop into an IPython shell to poke at them.
    # sample() requires a horizon; 10 is an arbitrary value for debugging.
    env = sample(3, 10)
    ctrl = RandPolicy(env)
    embed()