import torch
import torch.nn.functional as F
import numpy as np
import torch.utils
import torchvision

from visgrid.envs import TaxiEnv
from utils.datasets import StoredDataset

from tqdm import tqdm
from utils.printarr import printarr

import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

def generate_taxi_data(n_samples=1000, w=None):
    '''
        Sample synthetic one-step transitions of a small taxi domain.

        Sampling (N = 9):
            x, y ~ randint(1, N)        taxi position, values in {1, ..., N-1}
            p_x, p_y ~ randint(1, N)    passenger position; replaced by the taxi position
                                        when the passenger is in the taxi, with probability
                                        0.33 otherwise, and for a random half of the pickup
                                        actions taken while the passenger is not in the taxi
            in_taxi ~ randint(0, 2)     0 or 1
            actions ~ randint(0, 6)     0: up, 1: down, 2: left, 3: right, 4: dropoff, 5: pickup

        Moves are not clamped to the grid; the passenger moves with the taxi only while
        riding in it. Reward is 0.5 for a pickup action after which the passenger is in
        the taxi and -0.5 otherwise, plus N(0, 0.01^2) noise.

        The factored state [x/N, y/N, p_x/N, p_y/N, in_taxi] - 0.5 (with small Gaussian
        noise) is projected through ``w``: obs = s @ w.T.

        Args:
            n_samples: number of transitions to draw.
            w: projection matrix of shape (d_out, 5). Defaults to the 5x5 identity.

        Returns:
            (obs, actions, next_obs, rewards) with shapes (n_samples, d_out),
            (n_samples,), (n_samples, d_out), (n_samples,).
    '''
    if w is None:
        # Built per call instead of as a default argument: default values are evaluated
        # once at definition time, which would share one tensor across all calls.
        w = torch.eye(5)

    N = 9
    smoothing_noise = 0.1/9
    jitter_noise = 0.1/9
    actions = torch.randint(0, 6, (n_samples,))
    x, y = torch.randint(1, N, (n_samples,)), torch.randint(1, N, (n_samples,))
    px, py = torch.randint(1, N, (n_samples,)), torch.randint(1, N, (n_samples,))
    in_taxi = torch.randint(0, 2, (n_samples,)).float()

    # The passenger shares the taxi cell if riding in it, or with probability 0.33 otherwise.
    same_pos = (torch.rand((n_samples,)) < 0.33).float() + in_taxi
    same_mask = (same_pos > 0).float()
    other_mask = (same_pos <= 0).float()
    px, py = x * same_mask + px * other_mask, y * same_mask + py * other_mask

    # For a random half of the pickups attempted outside the taxi, put the passenger on
    # the taxi cell so that a reasonable fraction of pickups succeed.
    in_taxi_possible = (actions == 5).float() * (1-in_taxi) * torch.randint_like(in_taxi, 2)
    px = x * in_taxi_possible + px * (1-in_taxi_possible)
    py = y * in_taxi_possible + py * (1-in_taxi_possible)

    delta_x = (actions == 3).float() - (actions == 2).float()  # left/right
    delta_y = (actions == 0).float() - (actions == 1).float()  # up/down
    next_x = x + delta_x
    next_y = y + delta_y
    # The passenger only moves while in the taxi.
    next_px = px + (delta_x * in_taxi)
    next_py = py + (delta_y * in_taxi)

    next_in_taxi = (1-(actions == 4).float()) * in_taxi  # dropoff
    # pickup: in the taxi afterwards iff taxi and passenger share a cell
    is_pickup = (actions == 5).float()
    next_in_taxi = is_pickup * (x == px).float() * (y == py).float() + (1-is_pickup) * next_in_taxi.float()
    rewards = is_pickup * (next_in_taxi == 1).float() + torch.randn((n_samples,))*0.01 - 0.5

    s = torch.stack([x/N, y/N, px/N, py/N, (in_taxi + torch.randn((n_samples,))*smoothing_noise)], dim=-1)-0.5
    next_s = torch.stack([next_x/N, next_y/N, next_px/N, next_py/N, (next_in_taxi + torch.randn((n_samples,))*smoothing_noise)], dim=-1)-0.5

    # 'bj,ij->bi' is s @ w.T, i.e. a linear projection of each factored state.
    obs = torch.einsum('bj, ij-> bi', s + torch.randn_like(s)*smoothing_noise, w)
    next_obs = torch.einsum('bj, ij-> bi', next_s + torch.randn_like(s)*jitter_noise, w)
    return obs, actions, next_obs, rewards



class GridworldN:
    '''
        N-dimensional continuous gridworld on the box [-1, 1]^N.

        Action 0: NoOp, observation actions
        Actions 2i+1 and 2i+2, for i in {0, ..., N-1}, move along the ith dimension by
        +step_size and -step_size respectively.
    '''

    def __init__(self, n_dim=2, step_size=1/5, jitter_std=1e-3, diagonal=False):
        assert n_dim > 0
        self.step_size = step_size
        self.jitter_std = jitter_std
        self.n_dim = n_dim  # number of dimensions
        # NOTE(review): stored but never read by this class; kept for config compatibility.
        self.diagonal = diagonal

    def step(self, state, action):
        '''
            Apply a batch of actions.

            Args:
                state: float tensor (..., n_dim).
                action: integer tensor (...,) with values in [0, 2*n_dim].

            Returns:
                next state: state + Gaussian jitter (std jitter_std) + the action's
                displacement, clamped to [-1, 1].
        '''
        assert torch.all(action < self.n_actions)
        # (action+1)//2 maps the NoOp 0 -> column 0 (dropped) and 2i+1, 2i+2 -> column i+1.
        action_dim = F.one_hot((action + 1) // 2, self.n_dim + 1)[..., 1:]
        # Odd actions move in the positive direction, even ones in the negative direction.
        direction = ((action % 2) * 2 - 1).unsqueeze(-1)
        jitter = torch.randn_like(state) * self.jitter_std
        return torch.clamp(state + jitter + action_dim * direction * self.step_size, -1, 1)

    def generate_dataset(self, n_samples):
        '''
            Sample states uniformly in [-1, 1]^n_dim and uniformly random actions.

            Returns a StoredDataset of (state, action, next_state, reward, state,
            next_state) tuples; rewards are all zero and the state doubles as the
            observation.
        '''
        states = torch.rand((n_samples, self.n_dim)) * 2 - 1
        actions = torch.randint(self.n_actions, (n_samples,))
        next_states = self.step(states, actions)
        rewards = torch.zeros((n_samples,))
        columns = [states.numpy(), actions.numpy(), next_states.numpy(), rewards.numpy(), states.numpy(), next_states.numpy()]
        return StoredDataset(list(zip(*columns)))

    @property
    def state_dim(self):
        return self.n_dim

    @property
    def obs_dim(self):
        return self.n_dim

    @property
    def n_actions(self):
        # NoOp plus a positive and a negative move per dimension.
        return self.n_dim*2+1

    @property
    def discrete(self):
        return False


class GridworldEgo:
    '''
        Egocentric continuous gridworld. The dynamics are not implemented yet:
        ``step`` and ``generate_dataset`` raise NotImplementedError.

        Action 0: NoOp, observation actions
        Action 1, 2: Rotate Left/Right
        Action 3, 4: Move Forward/Backward
    '''
    # One (rotation, translation) pair per action, presumably; nothing consumes this
    # table yet -- confirm once step() is implemented.
    ACTIONS = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]

    def __init__(self, step_size=1/10, jitter_std=0.01):
        self.step_size = step_size
        self.jitter_std = jitter_std

    @property
    def state_dim(self):
        return 4

    @property
    def obs_dim(self):
        return 4

    @property
    def n_actions(self):
        # One action per entry of the ACTIONS table (5).
        return len(self.ACTIONS)

    def step(self, state, action):
        # Raise instead of silently returning None so callers fail at the source.
        raise NotImplementedError('GridworldEgo.step is not implemented yet')

    def generate_dataset(self, n_samples):
        raise NotImplementedError('GridworldEgo.generate_dataset is not implemented yet')



class GridworldAttractionN:
    '''
        N dimensional continuous gridworld on [0, 1]^N with pushable objects.

        An object whose distance to the agent is below ``epsilon`` is displaced by the
        same commanded move as the agent (plus jitter); every other object takes a
        Gaussian random step with std ``step_size / 2``. Everything is clamped to [0, 1].

        Action 0: NoOp, observation actions
        Actions 2i+1 and 2i+2, for i in {0, ..., N-1}, move the agent along the ith
        dimension by +step_size and -step_size respectively.

        State layout: [agent (N), object_1 (N), ..., object_O (N)].
    '''

    def __init__(self, n_dim=2, step_size=1/10, jitter_std=0.01, n_objects=1, epsilon=1/10/10):
        assert n_dim > 0
        self.step_size = step_size
        self.jitter_std = jitter_std
        self.n_dim = n_dim  # number of dimensions
        self.n_objects = n_objects
        self.epsilon = epsilon  # distance threshold

    @property
    def state_dim(self):
        return self.n_dim * (1 + self.n_objects)

    @property
    def obs_dim(self):
        return self.state_dim

    @property
    def n_actions(self):
        return self.n_dim*2+1

    def step(self, state, action):
        '''
            Apply a batch of actions.

            Args:
                state: float tensor (B, state_dim) laid out as [agent, objects...].
                action: integer tensor (B,) with values in [0, 2*n_dim].

            Returns:
                next state tensor of shape (B, state_dim).
        '''
        assert torch.all(action < self.n_actions)
        # (action+1)//2 maps the NoOp 0 -> column 0 (dropped) and 2i+1, 2i+2 -> column i+1.
        action_dim = F.one_hot((action+1) // 2, self.n_dim+1)[..., 1:]
        direction = ((action % 2)*2 - 1).unsqueeze(-1)  # +1 for odd actions, -1 for even
        displacement = action_dim * direction * self.step_size  # B, N

        state_agent = state[..., :self.n_dim]
        next_state_agent = torch.clamp(state_agent + torch.randn_like(state_agent)*self.jitter_std + displacement, 0, 1)

        state_objects = state[:, self.n_dim:].reshape(-1, self.n_objects, self.n_dim)  # B, O, N
        sq_dist = ((state_objects - state[:, None, :self.n_dim])**2).sum(-1)  # B, O
        close_enough = (sq_dist < self.epsilon**2).float().unsqueeze(-1)  # B, O, 1

        pushed = displacement[:, None] + torch.randn_like(state_objects)*self.jitter_std + state_objects
        drifting = torch.randn_like(state_objects)*self.step_size/2 + state_objects
        next_state_objects = torch.clamp(close_enough * pushed + (1-close_enough) * drifting, 0, 1)

        # torch.cat (the original called the non-existent torch.can_dim).
        return torch.cat([next_state_agent.unsqueeze(1), next_state_objects], dim=1).reshape(-1, self.state_dim)

    def generate_dataset(self, n_samples):
        '''
            Sample agent and object positions uniformly in [0, 1]^N and apply uniformly
            random actions. In each sample at most one object is placed near the agent.

            Returns a StoredDataset of (state, action, next_state, reward, state,
            next_state) tuples; rewards are all zero.
        '''
        states = torch.rand((n_samples, self.n_dim))
        # sample object states
        states_objects = torch.rand((n_samples, self.n_objects, self.n_dim))

        # Choose which object (if any) to place near the agent; one-hot class 0 means
        # "none" and is dropped.
        close_to_agent = F.one_hot(torch.randint(self.n_objects+1, (n_samples,)), self.n_objects+1)[:, 1:, None]  # B, O, 1

        # Gaussian offset with std epsilon/2 per dimension, so the chosen object is often,
        # but not always, within epsilon of the agent.
        near_agent = states[:, None] + torch.randn_like(states_objects)*self.epsilon/2
        states_objects = states_objects * (1-close_to_agent) + close_to_agent * near_agent
        states = torch.cat([states[:, None], states_objects], dim=1).reshape(n_samples, -1)

        actions = torch.randint(self.n_actions, (n_samples,))
        next_states = self.step(states, actions)
        rewards = torch.zeros((n_samples,))

        columns = [states.numpy(), actions.numpy(), next_states.numpy(), rewards.numpy(), states.numpy(), next_states.numpy()]
        return StoredDataset(list(zip(*columns)))

    @property
    def discrete(self):
        return False


class VisualTaxi:
    '''
        Transition-dataset generator built around visgrid's TaxiEnv.

        ``taxi_cfg`` is forwarded to TaxiEnv except for the optional ``max_length`` key,
        which caps the number of steps per episode (absent or None: episodes end only
        when the environment reports done/truncated).
    '''

    def __init__(self, **taxi_cfg):
        self.cfg = taxi_cfg
        # Optional so that e.g. VisualTaxi(should_render=True) does not raise KeyError.
        self.max_length = taxi_cfg.get('max_length')
        self.env = TaxiEnv(**{k: v for k, v in self.cfg.items() if k != 'max_length'})
        # Resizes the smaller image edge to 32 pixels (used in smooth_n_normalize).
        self.resizer = torchvision.transforms.Resize(32)

    def generate_dataset(self, n_samples):
        '''
            Collect ``n_samples`` transitions with a scripted exploration policy.

            Each sample is (obs, action, next_obs, reward, state, next_state): obs are
            noised/normalized via smooth_n_normalize, states are the float32 factored
            states from ``info['state']``. The environment is reset on done, truncated,
            or after ``max_length`` steps. Prints diagnostic statistics and returns a
            StoredDataset.
        '''
        samples = []
        s, info = self.env.reset()
        t = 0
        for _ in tqdm(range(n_samples)):
            action = self._choose_action(info['state'])
            next_s, reward, done, truncated, next_info = self.env.step(action)
            samples.append((
                self.smooth_n_normalize(s, self.env.should_render),
                action,
                self.smooth_n_normalize(next_s, self.env.should_render),
                reward,
                info['state'].astype(np.float32),
                next_info['state'].astype(np.float32),
            ))
            s, info = next_s, next_info
            # Count the step before checking, so every episode (including the first)
            # is capped at exactly max_length steps.
            t += 1
            if done or truncated or (self.max_length is not None and t >= self.max_length):
                s, info = self.env.reset()
                t = 0

        # Transpose the list of tuples into one tensor per field.
        samples = [torch.from_numpy(np.array(field)) for field in zip(*samples)]
        self._print_statistics(samples)

        columns = [field.numpy() for field in samples]
        return StoredDataset(list(zip(*columns)))

    def _choose_action(self, factored_state):
        '''Scripted policy: favour action 5 when taxi and passenger share a cell, else sample uniformly.'''
        # Columns 0-1 / 2-3 are taxi / passenger positions, column 4 the in-taxi flag
        # (same layout as visualize_data).
        if not np.all(factored_state[0:2] == factored_state[2:4]):
            return self.env.action_space.sample()
        prob_pickup = 0.2
        if factored_state[4] < 0.5:  # not in taxi: action 5 with probability 1 - prob_pickup
            return 5 if np.random.rand(1) > prob_pickup else np.random.randint(5)
        # in taxi: action 5 with probability prob_pickup
        return 5 if np.random.rand(1) < prob_pickup else np.random.randint(5)

    @staticmethod
    def _print_statistics(samples):
        '''Print dropoff, affected-dimension, action and in-taxi distributions of the collected samples.'''
        ss, next_ss = samples[-2:]
        actions = samples[1]
        delta = next_ss - ss
        changed = (delta.abs() != 0).int()
        # Number of in-taxi flag decreases per action-5 step.
        dropoff = (delta[:, 4] < 0).float().sum() / (actions == 5).float().sum()
        print(f'Dropoff proportions {dropoff}')

        affected_dimensions = torch.nonzero(changed)
        affected_dimensions_distribution = torch.bincount(affected_dimensions[:, 1])
        affected_dimensions_distribution = affected_dimensions_distribution / affected_dimensions_distribution.sum()
        print(f'Effect Distribution {affected_dimensions_distribution}')

        action_dist = torch.bincount(actions)
        action_dist = action_dist / action_dist.sum()
        print(f'Action Distribution {action_dist}')

        in_taxi = torch.bincount(ss[..., -2].long()).float()/ss.shape[0]
        print(f'in_taxi proportions {in_taxi}')

    def visualize_data(self, states, obss):
        '''
            Debug helper: save per-goal heatmaps of taxi / passenger positions to
            'taxi_dist.png' and 5 randomly chosen observations per goal to
            'taxi_images.png'.

            Args:
                states: factored states (n, state_dim); columns 0-1 taxi position, 2-3
                    passenger position, 4 in-taxi flag, last column goal index.
                obss: observations (n, C, H, W), presumably normalized to [-0.5, 0.5]
                    by smooth_n_normalize.
        '''
        states = states.astype(np.int64)
        printarr(states)
        grid_size_x, grid_size_y, goal_positions = (*self.var_domains[0:2], self.var_domains[-1])
        n_pos = grid_size_x * grid_size_y
        n_bins = n_pos * goal_positions
        hist_shape = (goal_positions, grid_size_y, grid_size_x)
        taxi_positions = np.bincount(states[:, 0] + grid_size_x*states[:, 1] + n_pos*states[:, -1], minlength=n_bins).reshape(hist_shape)
        passenger_positions = states[:, 2] + grid_size_x*states[:, 3] + n_pos*states[:, -1]
        passenger_positions_in_taxi = np.bincount(passenger_positions[states[:, 4]==1], minlength=n_bins).reshape(hist_shape)
        passenger_positions_off_taxi = np.bincount(passenger_positions[states[:, 4]==0], minlength=n_bins).reshape(hist_shape)

        fig = plt.figure(figsize=(15, 15))
        for i in range(goal_positions):
            idx = i*3
            plt.subplot(goal_positions, 3, idx+1)
            sns.heatmap(taxi_positions[i] / taxi_positions[i].sum(), annot=True)
            plt.subplot(goal_positions, 3, idx+2)
            sns.heatmap(passenger_positions_in_taxi[i] / passenger_positions_in_taxi[i].sum(), annot=True)
            plt.subplot(goal_positions, 3, idx+3)
            sns.heatmap(passenger_positions_off_taxi[i] / passenger_positions_off_taxi[i].sum(), annot=True)
        plt.savefig('taxi_dist.png')
        plt.close(fig)

        # randomly choose some images per goal
        images = []
        printarr(obss)
        for i in range(goal_positions):
            goal_images = obss[states[:, -1]==i]
            if goal_images.shape[0] == 0:
                continue  # np.random.choice(0, 5) would raise
            # Index into this goal's subset (not the full obss array).
            picked = goal_images[np.random.choice(goal_images.shape[0], 5)]
            # Tile the 5 (C, H, W) images side by side along the width axis.
            images.append(np.concatenate(list(((picked + 0.5)*255).astype(np.int32)), axis=-1))

        if not images:
            return
        images = np.concatenate(images, axis=1)
        fig = plt.figure()
        plt.imshow(images.transpose(1, 2, 0))
        plt.savefig('taxi_images.png')
        plt.close(fig)

    def smooth_n_normalize(self, obs, pixel=False):
        '''
            Add noise to an observation and centre it around zero.

            pixel=True: obs is an (H, W, C) image, presumably in [0, 1]. Adds Gaussian
                noise (std 10/255), clips to [0, 1], shifts by -0.5, moves channels first,
                resizes the smaller edge to 32 and averages channels -> (1, H', W').
            pixel=False: obs is a factored state. Adds Gaussian noise (std 0.1), divides
                by each variable's domain size (state_space.nvec) and shifts by -0.5.

            Returns a float32 numpy array.
        '''
        if pixel:
            obs = np.clip(obs + np.random.randn(*obs.shape)*10/255, 0, 1) - 0.5
            obs = obs.transpose(2, 0, 1)
            obs = self.resizer(torch.asarray(obs)).mean(0, keepdim=True).numpy()
        else:
            factored_state_space = self.env.state_space
            obs = obs + np.random.randn(*obs.shape)*0.1
            obs = obs / factored_state_space.nvec - 0.5
        return obs.astype(np.float32)

    @property
    def obs_dim(self):
        # Rendered frames are reduced to one channel with smaller edge 32 by
        # smooth_n_normalize; [1, 32, 32] assumes the rendered frame is square.
        if self.env.should_render:
            return [1, 32, 32]
        return len(self.env.observation_space)

    @property
    def state_dim(self):
        return len(self.env.state_space.nvec)

    @property
    def n_actions(self):
        return self.env.action_space.n

    @property
    def discrete(self):
        return True

    @property
    def var_domains(self):
        # Domain size of each factored state variable.
        return self.env.state_space.nvec.tolist()



if __name__ == '__main__':
    # Smoke run: build the rendered (pixel-observation) taxi and collect 10k transitions.
    taxi_env = VisualTaxi(should_render=True)
    n_samples = 10_000
    samples = taxi_env.generate_dataset(n_samples)
    
