import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Callable
from PASRL_buffer import LAP
from utils.Pink_noise import ColoredActionNoise


@dataclass
class Hyperparameters:
    """Tunable settings for the agent, its networks and its replay buffer.

    Every field annotated ``int`` has a genuine integer default, so the values
    can be used directly as sizes, counters or step thresholds. Override any
    field by keyword, e.g. ``Hyperparameters(batch_size=128)``.
    """

    # Generic
    batch_size: int = 256
    buffer_size: int = 1_000_000
    discount: float = 0.99
    target_update_rate: int = 250  # training steps between target / fixed network syncs
    exploration_noise: float = 0.1  # scale of the normalised per-episode action noise
    hidden_exploration_noise: float = 0.1  # std of the noise added to the next hidden state
    beta: float = 1  # first argument handed to ColoredActionNoise
    noise_scale: float = 0.3  # second argument handed to ColoredActionNoise

    # TD3
    target_policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_freq: int = 2  # the actor is updated every `policy_freq` training steps

    # LAP
    alpha: float = 0.4
    min_priority: float = 1

    # ERE
    eta: float = 0.994
    ck_min: int = 25_000  # todo: after ablation runs change this parameter

    # TD3+BC
    lmbda: float = 0.1  # weight of the behaviour-cloning term (offline mode only)

    # Checkpointing
    max_eps_when_checkpointing: int = 20
    steps_before_checkpointing: int = 750_000
    reset_weight: float = 0.9

    # Encoder Model
    zs_dim: int = 256
    enc_hdim: int = 256
    enc_activ: Callable = F.elu
    encoder_lr: float = 3e-4
    lstm_dim: int = 80  # hidden size of the recurrent (GRU) layers
    lstm_l: int = 2  # number of stacked recurrent (GRU) layers

    # Critic Model
    critic_hdim: int = 256
    critic_activ: Callable = F.elu
    critic_lr: float = 3e-4

    # Actor Model
    actor_hdim: int = 256
    actor_activ: Callable = F.relu
    actor_lr: float = 3e-4

    # Training
    ep_max_length: int = 1000


def count_parameters(model):
    """Return the total element count of the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        # frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total


def AvgL1Norm(x, eps=1e-8):
    """Divide `x` by the mean absolute value taken over its last dimension.

    The divisor is clamped from below at `eps`, so an all-zero slice maps to
    zeros rather than NaN.
    """
    scale = x.abs().mean(dim=-1, keepdim=True)
    return x / torch.clamp(scale, min=eps)


def LAP_huber(x, min_priority=1):
    """Huber-style loss: quadratic below `min_priority`, linear from there on.

    Element-wise losses are summed over dimension 1 and the resulting sums are
    averaged. Callers in this file pass absolute TD errors.
    """
    quadratic = 0.5 * x.pow(2)
    linear = min_priority * x
    per_element = torch.where(x < min_priority, quadratic, linear)
    return per_element.sum(1).mean()


class Actor(nn.Module):
    """Policy network mapping the latest state, previous action and latent `zs` to an action.

    The output passes through ``tanh`` and therefore lies in [-1, 1].
    """

    def __init__(self, state_dim, action_dim, zs_dim=256, hdim=256, activ=F.relu):
        super().__init__()
        self.activ = activ

        # l0 embeds the most recent (state, previous action) pair
        self.l0 = nn.Linear(state_dim + action_dim, hdim)
        # l1 takes that embedding concatenated with the latent zs
        self.l1 = nn.Linear(zs_dim + hdim, hdim)
        self.l2 = nn.Linear(hdim, hdim)
        self.l3 = nn.Linear(hdim, action_dim)

    def forward(self, state, prev_actions, zs):
        """Compute the action from the last time step of the histories and the latent `zs`.

        `state` and `prev_actions` are indexed as (batch, time, feature); only
        the final time index is read.
        """
        latest = torch.cat((state[:, -1, :], prev_actions[:, -1, :]), dim=1)
        hidden = torch.cat((AvgL1Norm(self.l0(latest)), zs), dim=1)
        for layer in (self.l1, self.l2):
            hidden = self.activ(layer(hidden))
        return torch.tanh(self.l3(hidden))


class Encoder(nn.Module):
    """Recurrent encoder with a state branch (`zs`) and a state-action branch (`zsa`).

    Each branch is Linear -> GRU -> Dropout(0.2) -> Linear. The `lstm_l` and
    `lstm_dim` arguments configure the GRU's layer count and hidden size.
    """

    def __init__(self, state_dim, action_dim, lstm_l, lstm_dim, zs_dim=256, hdim=256, activ=F.elu):
        super().__init__()

        self.activ = activ
        self.lstm_l = lstm_l
        self.lstm_dim = lstm_dim

        # Initial hidden state kept as a parameter, shape (layers, 1, hidden)
        start_state = torch.randn(self.lstm_l, 1, self.lstm_dim).type(torch.FloatTensor)
        self.initial_hidden = nn.Parameter(start_state, requires_grad=True)
        nn.init.xavier_uniform_(self.initial_hidden)

        # state branch
        self.zs1 = nn.Linear(state_dim, hdim)
        self.zs2 = nn.GRU(input_size=hdim, hidden_size=lstm_dim, num_layers=lstm_l, batch_first=True)
        self.zs_dropout = nn.Dropout(0.2)
        self.zs3 = nn.Linear(lstm_dim, zs_dim)

        # state-action branch
        self.zsa1 = nn.Linear(zs_dim + action_dim, hdim)
        self.zsa2 = nn.GRU(input_size=hdim, hidden_size=lstm_dim, num_layers=lstm_l, batch_first=True)
        self.zsa_dropout = nn.Dropout(0.2)
        self.zsa3 = nn.Linear(lstm_dim, zs_dim)

    def reset_hx(self):
        """Re-draw `initial_hidden` in place (Xavier uniform) and return it.

        Every call overwrites the stored parameter values.
        """
        nn.init.xavier_uniform_(self.initial_hidden)
        return self.initial_hidden.contiguous()

    def zs(self, state, hx):
        """Encode a state sequence starting from hidden state `hx`.

        Returns a 3-tuple: the embedding at the last time step, the embeddings
        for every time step, and the GRU's final hidden state.
        """
        features = self.activ(self.zs1(state))
        features, next_hx = self.zs2(features, hx)
        features = self.zs_dropout(features)
        embedded = AvgL1Norm(self.zs3(features))
        last_step = embedded[:, -1, :]
        return last_step, embedded, next_hx

    def zsa(self, zs, action, hx):
        """Encode per-step state embeddings joined with actions.

        `zs` and `action` are concatenated on the last dimension and run through
        the state-action branch; only the last time step is projected and
        returned. The GRU's final hidden state is discarded.
        """
        joint = torch.cat([zs, action], -1)

        features = self.activ(self.zsa1(joint))
        features, _ = self.zsa2(features, hx)
        features = self.zsa_dropout(features)
        return self.zsa3(features[:, -1, :])


class Critic(nn.Module):
    """Twin Q-network: two independent heads with the same architecture.

    Each head embeds the latest (state, action) pair, joins it with the `zsa`
    and `zs` embeddings, and outputs one scalar per batch row.
    """

    def __init__(self, state_dim, action_dim, zs_dim=256, hdim=256, activ=F.elu):
        super().__init__()

        self.activ = activ

        pair_dim = state_dim + action_dim
        trunk_dim = 2 * zs_dim + hdim

        # first Q head
        self.q01 = nn.Linear(pair_dim, hdim)
        self.q1 = nn.Linear(trunk_dim, hdim)
        self.q2 = nn.Linear(hdim, hdim)
        self.q3 = nn.Linear(hdim, 1)

        # second Q head
        self.q02 = nn.Linear(pair_dim, hdim)
        self.q4 = nn.Linear(trunk_dim, hdim)
        self.q5 = nn.Linear(hdim, hdim)
        self.q6 = nn.Linear(hdim, 1)

    def forward(self, state, action, zsa, zs):
        """Return both Q estimates concatenated on dimension 1 (two columns).

        Only the last time index of `state` and `action` is read.
        """
        # current timestep info
        sa = torch.cat((state[:, -1, :], action[:, -1, :]), dim=1)
        embeddings = torch.cat((zsa, zs), dim=1)

        heads = (
            (self.q01, self.q1, self.q2, self.q3),
            (self.q02, self.q4, self.q5, self.q6),
        )
        estimates = []
        for embed, first, second, readout in heads:
            hidden = torch.cat((AvgL1Norm(embed(sa)), embeddings), dim=1)
            hidden = self.activ(first(hidden))
            hidden = self.activ(second(hidden))
            estimates.append(readout(hidden))
        return torch.cat(estimates, dim=1)


class Agent(object):
    """Recurrent actor-critic agent with a learned encoder and policy checkpointing.

    Owns the actor, the twin critic and the encoder together with their target,
    "fixed" and checkpoint copies, one NAdam optimiser per trained network, and
    the LAP replay buffer.
    """

    def __init__(self, state_dim, action_dim, max_action, history_length, offline=False, hp=None):
        """Build the networks, optimisers, replay buffer and bookkeeping state.

        Args:
            state_dim: Size of one state vector.
            action_dim: Size of one action vector.
            max_action: Factor applied to the clipped [-1, 1] policy output; also
                handed to the replay buffer.
            history_length: Stored on the instance; not otherwise used by this class.
            offline: When True the actor loss gets an extra behaviour-cloning term.
            hp: Hyperparameters instance. ``None`` (the default) creates a fresh
                ``Hyperparameters()`` for this agent.
        """
        # Changing hyperparameters example: hp=Hyperparameters(batch_size=128)
        # A default built in the signature would be created once and shared by
        # every Agent, so the default object is created per instance instead.
        if hp is None:
            hp = Hyperparameters()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.history_length = history_length
        self.action_dim = action_dim
        self.hp = hp

        self.actor = Actor(state_dim, action_dim,
                           hp.zs_dim, hp.actor_hdim, hp.actor_activ).to(self.device)
        self.actor_optimizer = torch.optim.NAdam(self.actor.parameters(), lr=hp.actor_lr)
        self.actor_target = copy.deepcopy(self.actor)

        self.critic = Critic(state_dim, action_dim, hp.zs_dim, hp.critic_hdim, hp.critic_activ).to(self.device)
        self.critic_optimizer = torch.optim.NAdam(self.critic.parameters(), lr=hp.critic_lr)
        self.critic_target = copy.deepcopy(self.critic)

        self.encoder = Encoder(state_dim, action_dim, hp.lstm_l, hp.lstm_dim,
                               hp.zs_dim, hp.enc_hdim, hp.enc_activ).to(self.device)
        self.encoder_optimizer = torch.optim.NAdam(self.encoder.parameters(), lr=hp.encoder_lr)
        self.fixed_encoder = copy.deepcopy(self.encoder)
        self.fixed_encoder_target = copy.deepcopy(self.encoder)

        self.checkpoint_actor = copy.deepcopy(self.actor)
        self.checkpoint_encoder = copy.deepcopy(self.encoder)

        self.replay_buffer = LAP(self.device, hp.buffer_size, hp.batch_size, max_action, normalize_actions=True)

        self.max_action = max_action
        self.offline = offline

        self.training_steps = 0

        # Per-episode exploration noise; filled in by init_episode_noise()
        self.noise = None

        # Checkpointing tracked values
        self.eps_since_update = 0
        self.timesteps_since_update_list = []
        self.max_eps_before_update = 1
        self.min_return = 1e8
        self.best_min_return = -1e8

        # Value clipping tracked values
        self.max = -1e8
        self.min = 1e8
        self.max_target = 0
        self.min_target = 0

        # flatten params
        self.flatten_recurrent_params()

        # model parameters
        print(f'The Actor has {count_parameters(self.actor):,} trainable parameters')
        print(f'The Critic has {count_parameters(self.critic):,} trainable parameters')
        print(f'The Encoder has {count_parameters(self.encoder):,} trainable parameters')

    def init_episode_noise(self):
        """Generate the exploration-noise table used by `select_action`.

        The generator's buffer is normalised by its largest absolute value and
        then scaled by ``hp.exploration_noise``.
        """
        generator = ColoredActionNoise(self.hp.beta, self.hp.noise_scale, self.hp.ep_max_length, self.action_dim)
        samples = generator.gen.buffer
        self.noise = samples / np.max(np.abs(samples)) * self.hp.exploration_noise

    def select_action(self, state, previous_actions, timestep=None, use_checkpoint=False,
                      use_exploration=True, hx=None):
        """Choose an action for one (unbatched) state / previous-action history.

        Args:
            state: 2-D array; a batch axis is prepended before the forward pass.
            previous_actions: 2-D array, treated like `state`.
            timestep: Column of the episode noise table to add. Required when
                `use_exploration` is True.
            use_checkpoint: Use the checkpoint actor/encoder instead of the
                current actor and fixed encoder.
            use_exploration: Add the pre-generated episode noise to the action.
            hx: Recurrent hidden state handed to the encoder.

        Returns:
            ``(action, hx, nhx)``: the action clipped to [-1, 1] and multiplied
            by `max_action`, the hidden state that was passed in, and the hidden
            state returned by the encoder.

        Raises:
            ValueError: `use_exploration` is True but `timestep` is None.
            RuntimeError: `use_exploration` is True but `init_episode_noise`
                has not been called yet.
        """
        if use_exploration:
            # Indexing the noise table with None would add an axis rather than
            # select a column, so reject it explicitly.
            if timestep is None:
                raise ValueError("timestep is required when use_exploration=True")
            if getattr(self, "noise", None) is None:
                raise RuntimeError("call init_episode_noise() before selecting exploratory actions")

        # set model mode
        self.checkpoint_encoder.eval()
        self.fixed_encoder.eval()

        with torch.no_grad():
            state = torch.tensor(state[np.newaxis, :, :], dtype=torch.float, device=self.device)
            previous_actions = torch.tensor(previous_actions[np.newaxis, :, :], dtype=torch.float, device=self.device)

            if use_checkpoint:
                zs, _, nhx = self.checkpoint_encoder.zs(state, hx=hx)
                action = self.checkpoint_actor(state, previous_actions, zs)
            else:
                zs, _, nhx = self.fixed_encoder.zs(state, hx=hx)
                action = self.actor(state, previous_actions, zs)

            action = action.cpu().data.numpy().flatten()

            if use_exploration:
                action = action + self.noise[:, timestep]

            return action.clip(-1, 1) * self.max_action, hx, nhx

    def train(self, ck):
        """Run one training step: encoder, critic, (periodically) actor, then target syncs.

        Args:
            ck: Forwarded unchanged to ``replay_buffer.sample``.
        """
        # set model modes
        # NOTE: fixed_encoder stays in train mode; the actor update below
        # backpropagates through its GRU.
        self.fixed_encoder.train()
        self.fixed_encoder_target.eval()
        self.encoder.train()

        self.training_steps += 1

        state, action, prev_action, next_state, reward, not_done, hx, nhx = self.replay_buffer.sample(ck)
        # Size the target noise from the sampled batch, not from hp.batch_size
        batch_size = action.shape[0]

        #########################
        # Update Encoder
        #########################
        with torch.no_grad():
            next_zs, _, _ = self.encoder.zs(next_state, hx=nhx)

        _, zsl, _ = self.encoder.zs(state, hx=hx)
        pred_zs = self.encoder.zsa(zsl, action, hx)
        encoder_loss = F.mse_loss(pred_zs, next_zs)

        self.encoder_optimizer.zero_grad()
        encoder_loss.backward()
        self.encoder_optimizer.step()

        #########################
        # Update Critic
        #########################
        with torch.no_grad():
            fixed_target_zs, fixed_target_zsl, _ = self.fixed_encoder_target.zs(next_state, hx=nhx)

            # add noise to actions
            noise = (torch.randn((batch_size, self.action_dim)) * self.hp.target_policy_noise).clamp(
                -self.hp.noise_clip, self.hp.noise_clip).to(self.device)
            # `action` ends with the current action, i.e. the previous action of next_state
            next_action = (self.actor_target(next_state, action, fixed_target_zs) + noise).clamp(-1, 1)

            # add noise to the nhx (out of place, so the sampled tensor is left untouched)
            noisy_nhx = nhx + torch.randn_like(nhx) * self.hp.hidden_exploration_noise

            # make new history seq next action: shift the history one step and
            # put the target action in the last slot (torch.roll returns a new tensor)
            next_actions = torch.roll(action, shifts=-1, dims=1)
            next_actions[:, -1, :] = next_action

            fixed_target_zsa = self.fixed_encoder_target.zsa(fixed_target_zsl, next_actions, noisy_nhx)

            Q_target = self.critic_target(next_state, next_actions, fixed_target_zsa, fixed_target_zs).min(1, keepdim=True)[0]
            Q_target = reward.unsqueeze(1) + not_done.unsqueeze(1) * self.hp.discount * Q_target.clamp(self.min_target, self.max_target)

            # running extremes of the targets; they become the clamp range at the next sync
            self.max = max(self.max, float(Q_target.max()))
            self.min = min(self.min, float(Q_target.min()))

            fixed_zs, fixed_zsl, _ = self.fixed_encoder.zs(state, hx)
            fixed_zsa = self.fixed_encoder.zsa(fixed_zsl, action, hx)

        Q = self.critic(state, action, fixed_zsa, fixed_zs)
        td_loss = (Q - Q_target).abs()
        critic_loss = LAP_huber(td_loss)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        #########################
        # Update Actor
        #########################
        if self.training_steps % self.hp.policy_freq == 0:
            policy_action = self.actor(state, prev_action, fixed_zs)

            # replace the current (last) action of the history with the policy's action
            actor_actions = action.clone()
            actor_actions[:, -1, :] = policy_action

            fixed_zsa = self.fixed_encoder.zsa(fixed_zsl, actor_actions, hx)
            Q = self.critic(state, actor_actions, fixed_zsa, fixed_zs)

            actor_loss = -Q.mean()
            if self.offline:
                # Behaviour cloning compares the policy with the action actually
                # taken, which is the last entry of the action history.
                bc_loss = F.mse_loss(policy_action, action[:, -1, :])
                actor_loss = actor_loss + self.hp.lmbda * Q.abs().mean().detach() * bc_loss

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

        #########################
        # Update Iteration
        #########################
        if self.training_steps % self.hp.target_update_rate == 0:
            self.actor_target.load_state_dict(self.actor.state_dict())
            self.critic_target.load_state_dict(self.critic.state_dict())
            # order matters: the target takes the old fixed encoder before it is refreshed
            self.fixed_encoder_target.load_state_dict(self.fixed_encoder.state_dict())
            self.fixed_encoder.load_state_dict(self.encoder.state_dict())

            self.max_target = self.max
            self.min_target = self.min

    # If using checkpoints: run when each episode terminates
    def maybe_train_and_checkpoint(self, ep_timesteps, ep_return):
        """Record a finished episode and decide whether to train and/or checkpoint.

        Training is triggered early when the worst return seen since the last
        update falls below `best_min_return`. Otherwise, once
        `max_eps_before_update` episodes have been collected, the current actor
        and fixed encoder are copied into the checkpoint networks and training
        runs.
        """
        self.eps_since_update += 1
        self.timesteps_since_update_list.append(ep_timesteps)

        self.min_return = min(self.min_return, ep_return)

        # End evaluation of current policy early
        if self.min_return < self.best_min_return:
            self.train_and_reset()

        # Update checkpoint
        elif self.eps_since_update == self.max_eps_before_update:
            self.best_min_return = self.min_return
            self.checkpoint_actor.load_state_dict(self.actor.state_dict())
            self.checkpoint_encoder.load_state_dict(self.fixed_encoder.state_dict())

            self.train_and_reset()

    # Batch training
    def train_and_reset(self):
        """Train for each recorded episode, then clear the episode bookkeeping.

        An episode of ``element`` timesteps yields ``element - 1`` calls to
        `train`.
        """
        for element in self.timesteps_since_update_list:
            for k in range(1, element):
                if self.training_steps == self.hp.steps_before_checkpointing:
                    self.best_min_return *= self.hp.reset_weight
                    self.max_eps_before_update = self.hp.max_eps_when_checkpointing

                # calculate ck and train: the buffer length decayed by eta, floored at ck_min
                c_k = max(int(len(self.replay_buffer.buffer) * self.hp.eta **
                              (k * (self.hp.ep_max_length / element))), int(self.hp.ck_min))
                self.train(c_k)

        self.eps_since_update = 0
        self.min_return = 1e8
        self.timesteps_since_update_list = []

    def save(self, filename):
        """Write network and optimiser state dicts to files prefixed with `filename`."""
        # save critic
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        # save actor
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

        # save encoder
        torch.save(self.encoder.state_dict(), filename + "_encoder")
        torch.save(self.encoder_optimizer.state_dict(), filename + "_encoder_optimizer")

        # save the checkpoint ones (actor and encoder)
        torch.save(self.checkpoint_actor.state_dict(), filename + '_checkpoint_actor')
        torch.save(self.checkpoint_encoder.state_dict(), filename + '_checkpoint_encoder')

        print("Agent saved")

    def load(self, filename):
        """Restore the state written by `save` and rebuild the derived networks.

        The target and fixed copies are not stored on disk; they are re-created
        as deep copies of the freshly loaded networks.
        """
        # NOTE(security): torch.load unpickles the file; only load files from a trusted source.
        # load critic
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=self.device))
        self.critic_target = copy.deepcopy(self.critic)

        # load actor
        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=self.device))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer", map_location=self.device))
        self.actor_target = copy.deepcopy(self.actor)

        # load encoder
        self.encoder.load_state_dict(torch.load(filename + "_encoder", map_location=self.device))
        self.encoder_optimizer.load_state_dict(torch.load(filename + "_encoder_optimizer", map_location=self.device))
        self.fixed_encoder = copy.deepcopy(self.encoder)
        self.fixed_encoder_target = copy.deepcopy(self.encoder)

        # load checkpoint actor and encoder
        self.checkpoint_actor.load_state_dict(torch.load(filename + "_checkpoint_actor", map_location=self.device))
        self.checkpoint_encoder.load_state_dict(torch.load(filename + "_checkpoint_encoder", map_location=self.device))

        self.flatten_recurrent_params()

    def flatten_recurrent_params(self):
        """Call ``flatten_parameters()`` on both GRUs of every encoder copy."""
        encoders = (self.encoder, self.checkpoint_encoder, self.fixed_encoder, self.fixed_encoder_target)
        for enc in encoders:
            enc.zs2.flatten_parameters()
            enc.zsa2.flatten_parameters()
