import numpy as np
import math
import torch
from torch.nn import functional as F


def check(input):
    """Convert a NumPy array to a torch tensor; pass any other value through unchanged.

    ``torch.from_numpy`` shares memory with the array, so no copy is made.
    Inputs that are not ndarrays (e.g. values that are already tensors) are
    returned as-is rather than being replaced by ``None``.
    """
    if isinstance(input, np.ndarray):
        return torch.from_numpy(input)
    return input


def get_gard_norm(it):
    """Return the global L2 norm of the gradients of the parameters in ``it``.

    Parameters whose ``.grad`` is ``None`` are skipped. With no gradients at
    all the result is ``0.0``. The result is always a Python float.
    """
    squared_norms = (param.grad.norm() ** 2 for param in it if param.grad is not None)
    return math.sqrt(sum(squared_norms))


def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Anneal the learning rate linearly from ``initial_lr`` towards zero.

    Every param group of ``optimizer`` receives
    ``initial_lr - initial_lr * epoch / total_num_epochs``, which reaches 0 at
    ``epoch == total_num_epochs``.
    """
    progress = epoch / float(total_num_epochs)
    new_lr = initial_lr - (initial_lr * progress)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def huber_loss(e, d):
    """Element-wise Huber loss of the error tensor ``e`` with threshold ``d``.

    Returns ``e**2 / 2`` where ``|e| <= d`` and ``d * (|e| - d / 2)`` where
    ``|e| > d``. The loss is symmetric in ``e``.
    """
    abs_e = abs(e)
    quadratic = (abs_e <= d).float()
    # |e| (not e): otherwise errors below -d fall in neither branch and get zero loss
    linear = (abs_e > d).float()
    return quadratic * e ** 2 / 2 + linear * d * (abs_e - d / 2)


def mse_loss(e):
    """Return the element-wise halved squared error ``e**2 / 2``."""
    squared = e ** 2
    return squared / 2


def get_shape_from_obs_space(obs_space):
    """Return the observation shape described by ``obs_space``.

    The space type is matched by class name: a ``Box`` gives its ``.shape`` and
    a plain ``list`` is already a shape and is returned unchanged.

    Raises:
        NotImplementedError: for any other space type.
    """
    space_name = obs_space.__class__.__name__
    if space_name == 'list':
        return obs_space
    if space_name == 'Box':
        return obs_space.shape
    raise NotImplementedError


def get_shape_from_act_space(act_space):
    """Return the per-step action *storage* shape for ``act_space`` (matched by class name).

    ``Discrete`` -> 1 (a single index); ``MultiDiscrete`` -> its ``shape``;
    ``Box``/``MultiBinary`` -> ``shape[0]``. Anything else (the "agar" case)
    is indexed as a sequence and gives ``act_space[0].shape[0] + 1``.
    """
    space_name = act_space.__class__.__name__
    if space_name == 'Discrete':
        return 1
    if space_name == 'MultiDiscrete':
        return act_space.shape
    if space_name in ('Box', 'MultiBinary'):
        return act_space.shape[0]
    # "agar": width of the first sub-space plus one
    return act_space[0].shape[0] + 1


def get_dim_from_act_space(act_space):
    """Return the action dimension for ``act_space`` (matched by class name).

    Unlike ``get_shape_from_act_space``, a ``Discrete`` space gives its number
    of choices ``n``. ``MultiDiscrete`` -> its ``shape``; ``Box``/``MultiBinary``
    -> ``shape[0]``. Anything else (the "agar" case) is indexed as a sequence
    and gives ``act_space[0].shape[0] + 1``.
    """
    space_name = act_space.__class__.__name__
    if space_name == 'Discrete':
        return act_space.n
    if space_name == 'MultiDiscrete':
        return act_space.shape
    if space_name in ('Box', 'MultiBinary'):
        return act_space.shape[0]
    # "agar": width of the first sub-space plus one
    return act_space[0].shape[0] + 1


def tile_images(img_nhwc):
    """Tile N images into one big image laid out on a near-square grid.

    The grid has ``ceil(sqrt(N))`` rows and just enough columns to hold all N
    images (so a square N gives a square grid). Unused cells are filled with
    zeros of the same dtype arithmetic as ``img_nhwc[0] * 0``.

    Args:
        img_nhwc: list or array of images; ndim=4 once turned into an array
            (n = batch index, h = height, w = width, c = channel).

    Returns:
        ndarray with ndim=3 of shape (rows * h, cols * w, c).
    """
    imgs = np.asarray(img_nhwc)
    count, height, width, channels = imgs.shape
    grid_rows = int(np.ceil(np.sqrt(count)))
    grid_cols = int(np.ceil(float(count) / grid_rows))
    blanks = [imgs[0] * 0 for _ in range(grid_rows * grid_cols - count)]
    padded = np.array(list(imgs) + blanks)
    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) so each grid row's pixel rows are contiguous
    grid = padded.reshape(grid_rows, grid_cols, height, width, channels)
    return grid.transpose(0, 2, 1, 3, 4).reshape(grid_rows * height, grid_cols * width, channels)


class TrainDataSampler:
    """Expert-demonstration store that serves random mini-batches.

    Two modes, selected by ``pretrain_classifier``:

    * classifier pretraining (``pretrain_classifier=True``): every episode is
      turned into per-step ``(state, action, tag)`` samples, which are shuffled
      once and split 80/20 into ``self.train_set`` / ``self.valid_set``;
    * default: the per-episode transition arrays are concatenated along the
      first axis and sampled by ``sample_batch_data``.

    ``train_set`` is an iterable of per-episode dicts; the keys read depend on
    the mode (see ``_load_classifier_data`` / ``_load_expert_data``). The
    dicts themselves are not modified.
    """

    def __init__(self, train_set, action_type, obs_dim, act_dim, add_action_noise=False, expert_noise_rate=1.0, disc_use_act_prob=False,
                 pretrain_classifier=False, classifier_use_gru=False, classifier_gru_his_len=10,
                 classifier_use_act_enc=False, classifier_use_data_tag=False):
        self.pretrain_classifier = pretrain_classifier
        self.classifier_use_gru = classifier_use_gru
        self.classifier_gru_his_len = classifier_gru_his_len
        self.classifier_use_act_enc = classifier_use_act_enc
        self.classifier_use_data_tag = classifier_use_data_tag
        self.states = []
        self.share_states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.next_share_states = []
        self.dones = []
        self.episode_rewards = []
        # per-sample labels, only filled in classifier-pretraining mode
        self.agent_tags = []
        if self.pretrain_classifier:
            self._load_classifier_data(train_set, action_type, obs_dim, act_dim)
            return
        self._load_expert_data(train_set)
        # Settings for optionally perturbing expert actions in sample_batch_data
        # (to make the discriminator's task harder).
        self.action_type = action_type
        self.add_action_noise = add_action_noise
        self.act_dim = act_dim
        self.expert_noise_rate = expert_noise_rate
        self.disc_use_act_prob = disc_use_act_prob
        # std of the expert actions over the step axis (axis 0), scaled by the noise rate
        self.action_noise = torch.from_numpy(np.std(self.actions, axis=0)) * self.expert_noise_rate

    def _load_classifier_data(self, train_set, action_type, obs_dim, act_dim):
        """Build shuffled train/valid ``(state, action, tag)`` splits for classifier pretraining.

        Each episode needs ``'state'`` and ``'action'`` (converted with
        ``np.array``; axis 0 is the step, axis 1 the agent), plus
        ``'agent_tags'`` when ``classifier_use_data_tag`` is set. Samples are
        flattened to one row per (step, agent) unless ``classifier_use_act_enc``
        is set, in which case the per-episode (step, agent, ...) layout is kept
        for states, actions and tags alike.
        """
        # width of one action: the full vector for continuous spaces, a single index otherwise
        act_width = act_dim if action_type == 'Continuous' else 1
        flatten = not self.classifier_use_act_enc
        for episode_data in train_set:
            states = np.array(episode_data['state'])
            actions = np.array(episode_data['action'])
            episode_length, agent_num = states.shape[0], states.shape[1]
            if self.classifier_use_gru:
                # (traj_length, agent_num, his_length, act_dim)
                actions = self._stack_action_history(actions, episode_length)
                if flatten:
                    actions = actions.reshape((-1, self.classifier_gru_his_len, act_width))
            elif flatten:
                actions = actions.reshape(-1, act_width)
            # else: actions keep their (traj_length, agent_num, act_dim) layout
            self.states.append(states.reshape(-1, obs_dim) if flatten else states)
            self.actions.append(actions)
            if self.classifier_use_data_tag:
                # tags from the data, repeated for every agent: (len(agent_tags), agent_num, 1)
                tags = np.repeat(np.array(episode_data['agent_tags']).reshape((-1, 1, 1)), agent_num, axis=1)
            else:
                # default tag is the agent's index: (traj_length, agent_num)
                tags = np.repeat(np.arange(0, agent_num).reshape(1, -1), episode_length, axis=0)
            self.agent_tags.append(tags.reshape(-1, 1) if flatten else tags)
        self.states = np.concatenate(self.states, axis=0)
        self.actions = np.concatenate(self.actions, axis=0)
        self.agent_tags = np.concatenate(self.agent_tags, axis=0)
        print('self.states', self.states.shape)
        print('self.actions', self.actions.shape)
        print('self.agent_tags', self.agent_tags.shape)
        # shuffle all data
        indexes = np.arange(self.states.shape[0])
        np.random.shuffle(indexes)
        self.states = self.states[indexes]
        self.actions = self.actions[indexes]
        self.agent_tags = self.agent_tags[indexes]
        # split train/val set 80/20
        self.total_steps_num = self.states.shape[0]
        self.train_num = int(self.states.shape[0] * 0.8)
        self.valid_num = self.states.shape[0] - self.train_num
        self.train_set = {
            'states': self.states[:self.train_num],
            'actions': self.actions[:self.train_num],
            'tags': self.agent_tags[:self.train_num],
        }
        self.valid_set = {
            'states': self.states[self.train_num:],
            'actions': self.actions[self.train_num:],
            'tags': self.agent_tags[self.train_num:],
        }
        self.start = 0

    def _stack_action_history(self, actions, episode_length):
        """Return, for every step, the last ``classifier_gru_his_len`` actions (oldest first).

        ``actions`` is indexed (step, agent, act_dim). Steps before the episode
        start are zero-filled. Result shape: (steps, agent, his_len, act_dim).
        """
        history = [actions]  # history[k] = actions shifted k steps into the past
        for off in range(1, self.classifier_gru_his_len):
            if off <= episode_length:
                shifted = np.concatenate([
                    np.zeros((off, *actions.shape[1:]), dtype=actions.dtype),
                    actions[:episode_length - off, :],
                ], axis=0)
            else:
                shifted = np.zeros_like(actions)
            history.append(shifted)
        # (his_len, steps, agent, act_dim) with the largest shift (oldest action) first
        stacked = np.stack(history[::-1], axis=0)
        return stacked.transpose((1, 2, 0, 3))

    def _load_expert_data(self, train_set):
        """Concatenate per-episode transition arrays along axis 0.

        Reads ``state``, ``share_state``, ``action``, ``reward``, ``next_state``,
        ``next_share_state``, ``done`` and ``episode_reward`` from each episode.
        ``episode_rewards`` stays a list with one entry per episode.
        """
        for episode_data in train_set:
            self.states.append(episode_data['state'])
            self.share_states.append(episode_data['share_state'])
            self.actions.append(episode_data['action'])
            self.rewards.append(episode_data['reward'])
            self.next_states.append(episode_data['next_state'])
            self.next_share_states.append(episode_data['next_share_state'])
            self.dones.append(episode_data['done'])
            self.episode_rewards.append(episode_data['episode_reward'])
        self.states = np.concatenate(self.states, axis=0)
        self.share_states = np.concatenate(self.share_states, axis=0)
        self.actions = np.concatenate(self.actions, axis=0)
        self.rewards = np.concatenate(self.rewards, axis=0)
        self.next_states = np.concatenate(self.next_states, axis=0)
        self.next_share_states = np.concatenate(self.next_share_states, axis=0)
        self.dones = np.concatenate(self.dones, axis=0)
        self.total_steps_num = self.states.shape[0]

    def sample_batch_data(self, batch_size):
        """Sample ``batch_size`` distinct steps (without replacement).

        When ``add_action_noise`` is set, expert actions are perturbed. Continuous
        actions get Gaussian noise with std ``action_noise``. Discrete actions are
        replaced by a uniformly random action with probability
        ``expert_noise_rate``. Discrete actions are one-hot encoded when
        ``disc_use_act_prob`` is set. The first two axes of every output are
        merged.

        Returns:
            ``(share_states, states, actions)`` as NumPy arrays.
        """
        indexes = np.random.choice(np.arange(self.total_steps_num), size=batch_size, replace=False)
        batch_states = torch.from_numpy(self.states[indexes])
        batch_share_states = torch.from_numpy(self.share_states[indexes])
        batch_actions = torch.from_numpy(self.actions[indexes])
        # add noise to expert action if necessary
        if self.add_action_noise and self.action_type == 'Continuous':
            action_noise = torch.normal(mean=torch.zeros_like(batch_actions), std=self.action_noise)
            batch_actions = batch_actions + action_noise
        elif self.add_action_noise and self.action_type == 'Discrete':
            batch_actions = torch.where(
                torch.rand(batch_actions.shape) < self.expert_noise_rate,
                torch.randint(0, self.act_dim, batch_actions.shape, dtype=batch_actions.dtype),
                batch_actions,
            )
        # change action to one-hot if disc learns from action one-hot
        if self.disc_use_act_prob and self.action_type == 'Discrete':
            # F.one_hot only accepts int64 indices; stored actions may be float or int32
            batch_actions = F.one_hot(batch_actions.squeeze(-1).long(), num_classes=self.act_dim).to(dtype=torch.float32)
        batch_states = batch_states.reshape(-1, *batch_states.shape[2:]).numpy()
        batch_share_states = batch_share_states.reshape(-1, *batch_share_states.shape[2:]).numpy()
        batch_actions = batch_actions.reshape(-1, *batch_actions.shape[2:]).numpy()

        return batch_share_states, batch_states, batch_actions

    def get_total_steps_num(self):
        """Return the number of stored steps (rows along axis 0)."""
        return self.total_steps_num

    def get_total_episodes_num(self):
        """Return the number of loaded episodes (0 in classifier-pretraining mode)."""
        return len(self.episode_rewards)

    def sample_classifier_batch_data(self, batch_size):
        """Sample ``batch_size`` distinct training rows as ``(states, actions, tags)`` tensors."""
        indexes = np.random.choice(np.arange(self.train_num), size=batch_size, replace=False)
        batch_states = torch.from_numpy(self.train_set['states'][indexes])
        batch_actions = torch.from_numpy(self.train_set['actions'][indexes])
        batch_tags = torch.from_numpy(self.train_set['tags'][indexes])

        return batch_states, batch_actions, batch_tags

    def start_val(self):
        """Rewind the validation cursor used by ``get_all_valid_data``."""
        self.start = 0

    def get_all_valid_data(self, batch_size):
        """Return the next sequential validation chunk and whether it was the last.

        Returns:
            ``(states, actions, tags, finish)`` where ``finish`` is True once the
            chunk reaches the end of the validation set.
        """
        start = self.start
        end = min(start + batch_size, self.valid_num)
        val_states = torch.from_numpy(self.valid_set['states'][start: end])
        val_actions = torch.from_numpy(self.valid_set['actions'][start: end])
        val_tags = torch.from_numpy(self.valid_set['tags'][start: end])
        finish = end >= self.valid_num
        self.start = self.start + batch_size

        return val_states, val_actions, val_tags, finish


# add method for training gail
def get_flat_grads(f, net):
    """Return the gradient of scalar ``f`` w.r.t. ``net``'s parameters as one flat vector.

    Gradients are concatenated in ``net.parameters()`` order, each flattened
    with ``view(-1)``. Parameters that ``f`` does not depend on contribute
    zeros, so the layout always matches the full parameter list.
    ``create_graph=True`` keeps the result differentiable (e.g. for
    Hessian-vector products).
    """
    params = list(net.parameters())
    # one autograd pass only; a second pass would double the backward cost
    grads = torch.autograd.grad(f, params, create_graph=True, allow_unused=True)
    return torch.cat([
        grad.view(-1) if grad is not None else torch.zeros_like(param).view(-1)
        for grad, param in zip(grads, params)
    ])


def get_flat_params(net):
    """Concatenate all of ``net``'s parameters, in ``net.parameters()`` order, into one 1-D tensor."""
    flat_pieces = [tensor.view(-1) for tensor in net.parameters()]
    return torch.cat(flat_pieces)


def conjugate_gradient(Av_func, b, max_iter=10, residual_tol=1e-10):
    """Approximately solve ``A x = b`` with the conjugate-gradient method.

    Args:
        Av_func: callable returning the matrix-vector product ``A @ v``. CG
            assumes ``A`` is symmetric positive-definite.
        b: right-hand side (a 1-D tensor, since ``torch.dot`` is used).
        max_iter: maximum number of CG iterations.
        residual_tol: stop once the (recursively updated) residual norm drops
            below this value.

    Returns:
        The approximate solution ``x``, with the same shape and dtype as ``b``.
    """
    x = torch.zeros_like(b)
    r = b - Av_func(x)
    p = r
    rsold = r.norm() ** 2
    # Already converged (e.g. b == 0). Continuing would compute alpha = 0 / 0 = NaN.
    if torch.sqrt(rsold) < residual_tol:
        return x

    for _ in range(max_iter):
        Ap = Av_func(p)
        alpha = rsold / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rsnew = r.norm() ** 2
        if torch.sqrt(rsnew) < residual_tol:
            break
        # next search direction, A-conjugate to the previous ones
        p = r + (rsnew / rsold) * p
        rsold = rsnew

    return x


def set_params(net, new_flat_params):
    """Overwrite ``net``'s parameters from one flat vector, in ``net.parameters()`` order.

    Consecutive slices of ``new_flat_params`` are reshaped to each parameter's
    shape and assigned to ``param.data``. Because ``reshape`` may return a view,
    the parameters can end up sharing storage with ``new_flat_params``.
    """
    start_idx = 0
    for param in net.parameters():
        # numel() is an int even for 0-dim params (np.prod([]) is the float 1.0, an invalid slice bound)
        end_idx = start_idx + param.numel()
        param.data = torch.reshape(
            new_flat_params[start_idx:end_idx], param.shape
        )

        start_idx = end_idx


def print_network_grad(net_name, net):
    """Print a banner with ``net_name``, then the mean gradient of every parameter.

    Parameters without a gradient print the mean of a zero tensor of the same
    shape.
    """
    banner = ('-----------', net_name, '------------')
    print(*banner)
    for param in net.parameters():
        grad = param.grad
        summary = torch.zeros_like(param).mean() if grad is None else grad.mean()
        print(summary)
