import numpy as np
import gym
from scipy.linalg import solve_discrete_are
from IPython import embed
import torch

# Torch device used for every tensor built in this module: the GPU when one is
# visible to torch, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Bounds passed to the gym Box spaces of LQREnv (both spaces are unbounded).
UPPER_BOUND = np.inf
LOWER_BOUND = -UPPER_BOUND


def sample(dx, du):
    """Sample a random LQR problem instance.

    The dynamics are ``x' = A x + B u`` where ``A`` is the ``dx x dx``
    identity and ``B`` is ``0.1`` times the ``dx x du`` rectangular identity.
    The state cost ``Q`` is diagonal with entries drawn uniformly from
    ``[0, 1)`` (NumPy global RNG), and the control cost is ``R = 0.1 * I_du``.

    Args:
        dx: State dimension.
        du: Control dimension.

    Returns:
        Tuple ``(A, B, Q, R)`` with shapes ``(dx, dx)``, ``(dx, du)``,
        ``(dx, dx)`` and ``(du, du)``.
    """
    A = np.eye(dx)
    # B maps a du-dimensional control into the dx-dimensional state, so it
    # must be (dx, du). np.eye(du) only had that shape when dx == du; for
    # dx == du this is exactly the same matrix as before.
    B = np.eye(dx, du) * .1

    Q = np.diag(np.random.uniform(0, 1, dx))
    R = .1 * np.eye(du)

    return A, B, Q, R




class LQREnv(gym.Env):
    """Finite-horizon linear-quadratic gym environment.

    The state evolves as ``x' = A x + B u``. Each step reports the quadratic
    stage cost ``x^T Q x + u^T R u`` (evaluated at the pre-transition state)
    in the reward slot of the gym ``step`` tuple. The value is a cost, so
    lower is better. An episode lasts exactly ``H`` steps.
    """

    def __init__(self, A, B, Q, R, H):
        # System dynamics.
        self.A = A
        self.B = B

        # Cost matrices; their sizes define the state / action dimensions.
        self.Q = Q
        self.R = R

        # Episode horizon (number of steps before done).
        self.H = H

        self.dx = Q.shape[0]
        self.du = R.shape[0]

        # State and action spaces (bounds come from module-level constants).
        self.observation_space = gym.spaces.Box(low=LOWER_BOUND, high=UPPER_BOUND, shape=(self.dx,))
        self.action_space = gym.spaces.Box(low=LOWER_BOUND, high=UPPER_BOUND, shape=(self.du,))

        # Placeholder state until reset() is called.
        self.state = np.zeros(self.dx)

        self.current_step = 0  # Steps taken in the current episode.

    def reset(self, x_init=None):
        """Start a new episode and return the initial state.

        Args:
            x_init: Optional initial state with ``dx`` entries. When omitted,
                every coordinate starts at 0.75.

        Returns:
            The initial state as a NumPy array of shape ``(dx,)``.

        Raises:
            ValueError: if ``x_init`` cannot be reshaped to ``(dx,)``.
        """
        if x_init is None:
            # Standard deviation 0 makes this the constant vector 0.75. The
            # call is kept as-is (rather than np.full) so the global RNG
            # draw sequence is unchanged.
            self.state = np.random.normal(.75, 0, self.dx)
        else:
            self.state = np.array(x_init, dtype=float).reshape(self.dx)
        self.current_step = 0
        return self.state

    def transit(self, x, u):
        """Return ``(A x + B u, x^T Q x + u^T R u)`` without touching env state."""
        cost = (x.dot(self.Q).dot(x) + u.dot(self.R).dot(u))
        x_prime = self.A.dot(x) + self.B.dot(u)
        return x_prime, cost

    def step(self, action, include_u=True):
        """Apply ``action`` for one step.

        Args:
            action: Control input with ``du`` entries.
            include_u: Unused; kept for interface compatibility.

        Returns:
            ``(next_state, cost, done, info)`` where ``cost`` is the stage
            cost at the state *before* the transition and ``info`` is ``{}``.

        Raises:
            ValueError: if ``H`` steps have already been taken this episode.
        """
        if self.current_step >= self.H:
            raise ValueError("Episode has already ended")

        next_state, cost = self.transit(self.state, np.asarray(action))

        self.state = next_state
        self.current_step += 1

        done = (self.current_step >= self.H)

        return next_state, cost, done, {}

    def render(self, mode='human'):
        pass

    def deploy_eval(self, ctrl, include_partial_hist=False, grow_context=False):
        """Alias for :meth:`deploy` with the same arguments."""
        return self.deploy(
            ctrl,
            include_partial_hist=include_partial_hist,
            grow_context=grow_context)

    def deploy(self, ctrl, include_partial_hist=False, grow_context=False):
        """Roll out ``ctrl`` for one full episode from the default reset state.

        Args:
            ctrl: Object with an ``act(x)`` method. When
                ``include_partial_hist`` is true it must also expose a
                ``batch`` dict holding ``rollin_xs``/``rollin_us``/
                ``rollin_xps``/``rollin_rs`` tensors and a ``set_batch``
                method.
            include_partial_hist: After every step, append the new transition
                to the controller's roll-in context.
            grow_context: If true, the context grows by one entry per step;
                otherwise the oldest entry is dropped (fixed-length window).

        Returns:
            ``(xs, us, xps, rs)`` as NumPy arrays: states, actions, next
            states and stage costs, one row per step. All are empty when
            ``H <= 0``.
        """
        x = self.reset()
        xs = []
        xps = []
        us = []
        rs = []
        # With H <= 0 the episode is over before it starts; checking here
        # avoids step() raising on the very first iteration.
        done = self.current_step >= self.H

        while not done:
            u = ctrl.act(x)

            xs.append(x)
            us.append(u)

            x, r, done, _ = self.step(u)

            rs.append(r)
            xps.append(x)

            if include_partial_hist:
                self._append_transition_to_context(ctrl, xs[-1], us[-1], x, r, grow_context)

        return np.array(xs), np.array(us), np.array(xps), np.array(rs)

    @staticmethod
    def _append_transition_to_context(ctrl, x, u, xp, r, grow_context):
        """Append one transition to ``ctrl.batch``'s roll-in tensors and push it via ``ctrl.set_batch``."""
        new_entries = {
            'rollin_xs': x,
            'rollin_us': u,
            'rollin_xps': xp,
            'rollin_rs': np.array([r]),
        }
        batch = {}
        for key, value in new_entries.items():
            # Add two leading singleton axes, then concatenate along axis 1.
            new = torch.tensor(value[None, None, :]).float().to(device)
            old = ctrl.batch[key]
            if not grow_context:
                # Sliding window: drop the oldest entry to keep the length fixed.
                old = old[:, 1:]
            batch[key] = torch.cat((old, new), dim=1)
        ctrl.set_batch(batch)

class LQRController:
    """Discrete-time infinite-horizon LQR state feedback, ``u = -K x``.

    ``P`` solves the discrete algebraic Riccati equation for ``(A, B, Q, R)``
    and the gain is ``K = (R + B^T P B)^+ B^T P A`` (pseudo-inverse).
    """

    def __init__(self, A, B, Q, R):
        self.A = np.array(A)
        self.B = np.array(B)
        self.Q = np.array(Q)
        self.R = np.array(R)
        # Use the converted arrays everywhere: the raw arguments may be plain
        # nested lists, which have no ``.T``.
        self.P = solve_discrete_are(self.A, self.B, self.Q, self.R)
        A, B, P = self.A, self.B, self.P
        gain_lhs = self.R + np.dot(B.T, np.dot(P, B))
        self.K = np.dot(np.linalg.pinv(gain_lhs), np.dot(B.T, np.dot(P, A)))

    def reset(self):
        """No internal state to reset."""
        return

    def act(self, x):
        """Return the feedback action ``-K x`` for state ``x``."""
        return -np.dot(self.K, x)

class RandController(LQRController):
    """LQR controller whose actions are perturbed by i.i.d. Gaussian noise.

    Each action is ``-K x + eps`` with ``eps ~ N(0, noise_std^2 I)`` drawn
    from NumPy's global RNG.
    """

    def __init__(self, A, B, Q, R, noise_std=.5):
        super().__init__(A, B, Q, R)
        # Standard deviation of the additive action noise (default .5, the
        # previously hard-coded value).
        self.noise_std = noise_std

    def act(self, x):
        """Return the LQR action for ``x`` plus Gaussian noise."""
        u = super().act(x)
        return u + np.random.normal(0, self.noise_std, u.shape)



class TransformerController:
    """Controller that asks a sequence model for each action.

    ``batch`` is the model's input dict. This controller adds zero-padding
    tensors (``zeros``, ``zerosQ``), an optional cost matrix (``Qs``) and the
    current state (``states``) before each model call.
    """

    def __init__(self, model, batch, Q=None):
        self.model = model
        # Shallow copy so the keys added below do not mutate the caller's dict.
        self.batch = dict(batch)
        self.du = model.config['du']
        self.dx = model.config['dx']
        self.H = model.H
        self.zeros = torch.zeros(1, self.dx**2 + self.du + 1).float().to(device)
        self.zerosQ = torch.zeros(1, self.H, self.dx**2).float().to(device)
        self.batch['zeros'] = self.zeros
        self.batch['zerosQ'] = self.zerosQ
        if Q is not None:
            self.batch['Qs'] = torch.tensor(Q[None, :, :]).float().to(device)

    def set_batch(self, batch):
        """Replace model-input entries, as ``LQREnv.deploy`` does with its context tensors.

        Keys present in ``batch`` overwrite existing entries. Keys it lacks
        (``zeros``, ``zerosQ``, ``Qs``, ``states``) are kept, so the model
        still receives them on the next :meth:`act`.
        """
        self.batch = {**self.batch, **batch}

    def act(self, x):
        """Return the model's action for state ``x`` as a NumPy array (first batch row)."""
        states = torch.tensor(x)[None, :].float().to(device)
        self.batch['states'] = states

        # Inference only: skip building an autograd graph on every call.
        with torch.no_grad():
            a = self.model(self.batch)
        return a.cpu().detach().numpy()[0]



