import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

import utils


class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions backed by numpy arrays.

    Each row stores ``(state, action, next_state, reward, not_done)``. Once
    ``max_size`` rows have been written, new transitions overwrite the oldest.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6), device=None):
        """
        Args:
            state_dim: Length of the flat state vector.
            action_dim: Length of the action vector.
            max_size: Capacity in transitions.
            device: Torch device sampled batches are placed on (``None`` uses
                torch's default device).
        """
        self.max_size = max_size
        self.ptr = 0  # next row to write
        self.size = 0  # number of valid rows

        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))

        self.device = device

    def add(self, state, action, next_state, reward, done):
        """Write one transition at ``ptr``, overwriting the oldest row when full."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly with replacement.

        Returns:
            Tuple of float32 tensors ``(state, action, next_state, reward,
            not_done)`` on ``self.device``.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        ind = np.random.randint(0, self.size, size=batch_size)
        # Fancy indexing already copies, so as_tensor can convert and place the
        # batch in one step.
        return tuple(
            torch.as_tensor(column[ind], dtype=torch.float32, device=self.device)
            for column in (self.state, self.action, self.next_state,
                           self.reward, self.not_done)
        )

    def convert_data(self, dataset):
        """Replace the buffer contents with a pre-collected dataset.

        ``dataset`` must provide the keys ``observations``, ``actions``,
        ``next_observations``, ``rewards`` and ``terminals`` (presumably a
        D4RL-style dict; confirm against the caller). Rewards and terminals are
        reshaped to column vectors. Note that the reward array is a reshape of
        the caller's array, so it may share memory with it.
        """
        self.state = dataset['observations']
        self.action = dataset['actions']
        self.next_state = dataset['next_observations']
        self.reward = dataset['rewards'].reshape(-1, 1)
        self.not_done = 1. - dataset['terminals'].reshape(-1, 1)
        self.size = self.state.shape[0]
        # The loaded arrays are exactly `size` rows long. The ring must wrap
        # there, and the next add() overwrites the oldest row. Keeping the old
        # ptr/max_size would write out of bounds or clamp `size` below the
        # dataset length.
        self.max_size = self.size
        self.ptr = 0

    def normalize_states(self, eps=1e-3):
        """Standardize states and next-states in place with per-feature statistics.

        ``mean`` and ``std`` are computed over all rows of ``self.state``, and
        ``eps`` is added to ``std`` to avoid division by zero.

        Returns:
            ``(mean, std)`` arrays of shape ``(1, state_dim)``.
        """
        mean = self.state.mean(0, keepdims=True)
        std = self.state.std(0, keepdims=True) + eps
        self.state = (self.state - mean) / std
        self.next_state = (self.next_state - mean) / std
        return mean, std




class Actor(nn.Module):
    """Deterministic policy: 256-256 ReLU MLP with a tanh output scaled to ±max_action."""

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        width = 256
        # Attribute names (l1..l3) are the state_dict keys used for checkpoints.
        self.l1 = nn.Linear(state_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, action_dim)

        self.max_action = max_action

    def forward(self, state):
        """Map a batch of states to actions bounded in [-max_action, max_action]."""
        hidden = state
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed


class Critic(nn.Module):
    """Twin Q-network: two independent 256-256 ReLU MLP heads over ``[state, action]``."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        sa_dim = state_dim + action_dim

        # First head (Q1). Attribute names are the state_dict keys used for
        # checkpoints, and creation order fixes the parameter initialization.
        self.l1 = nn.Linear(sa_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Second head (Q2).
        self.l4 = nn.Linear(sa_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    @staticmethod
    def _head(sa, first, second, out):
        """Evaluate one Q head: out(relu(second(relu(first(sa)))))."""
        return out(F.relu(second(F.relu(first(sa)))))

    def forward(self, state, action):
        """Return ``(q1, q2)``, each of shape ``(batch, 1)``."""
        sa = torch.cat([state, action], 1)
        q1 = self._head(sa, self.l1, self.l2, self.l3)
        q2 = self._head(sa, self.l4, self.l5, self.l6)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first head's estimate, shape ``(batch, 1)``."""
        return self._head(torch.cat([state, action], 1), self.l1, self.l2, self.l3)


class TD3(object):
    """Twin Delayed DDPG (TD3) agent.

    Holds an actor and a twin-head critic, each with a target copy that tracks
    the online network by Polyak averaging at rate ``tau``. The critic is
    updated on every :meth:`learn` call. The actor and both target networks
    are updated once every ``policy_freq`` calls.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            discount=0.99,
            tau=0.005,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2,
            normalize=True,
            mean=0,
            std=1,
            device=torch.device('cpu'),
            use_wandb=False
    ):
        """
        Args:
            state_dim: Length of the flat state vector.
            action_dim: Length of the action vector.
            max_action: Actions are clamped to [-max_action, max_action].
            discount: Discount factor on the bootstrapped target Q value.
            tau: Polyak averaging rate for target-network updates.
            policy_noise: Std of Gaussian noise added to target actions.
            noise_clip: Target-action noise is clamped to [-noise_clip, noise_clip].
            policy_freq: Actor/target updates happen every this many learn() calls.
            normalize: If True, learn() standardizes sampled states as
                ``(s - mean) / std``. select_action() does NOT normalize.
            mean, std: Normalization statistics (scalars or arrays that
                broadcast against a ``(batch, state_dim)`` batch).
            device: Torch device for all networks.
            use_wandb: If True, actor loss and actor gradient norm are logged
                to wandb on every actor update.
        """
        self.device = device

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.total_it = 0  # number of learn() calls so far

        self.normalize = normalize
        self.state_norm_mean = mean
        self.state_norm_std = std

        self.use_wandb = use_wandb

    def select_action(self, state):
        """Return the deterministic actor action for a single state as a flat numpy array.

        The state is used as given. Callers that trained with ``normalize=True``
        must apply the same normalization first.
        """
        state = torch.as_tensor(np.asarray(state).reshape(1, -1), dtype=torch.float32, device=self.device)
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy().flatten()

    def _soft_update(self, source, target):
        """Polyak-average ``source`` parameters into ``target`` in place."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def learn(self, replay_buffer, batch_size=256):
        """Run one TD3 update on a batch drawn from ``replay_buffer``.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, next_state, reward, not_done)`` tensors.
            batch_size: Transitions per update.
        """
        self.total_it += 1

        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

        if self.normalize:
            # Cast statistics to the batch dtype/device. A float64 array (e.g.
            # numpy statistics) would otherwise type-promote the float32 batch
            # to float64, which the float32 network layers reject.
            mean = torch.as_tensor(self.state_norm_mean, dtype=state.dtype, device=state.device)
            std = torch.as_tensor(self.state_norm_std, dtype=state.dtype, device=state.device)
            state = (state - mean) / std
            next_state = (next_state - mean) / std

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)

            # Clipped double-Q: bootstrap from the smaller of the two target heads.
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Delayed policy and target updates.
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            if self.use_wandb:
                # Compute the diagnostic only when it is logged. Each .item()
                # forces a device sync.
                grad_norms = [p.grad.detach().norm(2) for p in self.actor.parameters() if p.grad is not None]
                total_norm = torch.stack(grad_norms).norm(2).item() if grad_norms else 0.0
                wandb.log({'actor_loss': actor_loss.item(), 'total_norm': total_norm})

            self._soft_update(self.critic, self.critic_target)
            self._soft_update(self.actor, self.actor_target)

    def save(self, dir_path):
        """Write actor/critic weights and optimizer states as separate files in ``dir_path``."""
        torch.save(self.critic.state_dict(), os.path.join(dir_path, "critic"))
        torch.save(self.critic_optimizer.state_dict(), os.path.join(dir_path, "critic_optimizer"))

        torch.save(self.actor.state_dict(), os.path.join(dir_path, "actor"))
        torch.save(self.actor_optimizer.state_dict(), os.path.join(dir_path, "actor_optimizer"))

    def load(self, dir_path):
        """Restore the files written by :meth:`save`.

        Target networks are reset to copies of the loaded online networks.
        """
        def _load(name):
            # map_location lets checkpoints saved on a GPU load on a CPU-only device.
            return torch.load(os.path.join(dir_path, name), map_location=self.device)

        self.critic.load_state_dict(_load("critic"))
        self.critic_optimizer.load_state_dict(_load("critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(_load("actor"))
        self.actor_optimizer.load_state_dict(_load("actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)


class TD3_Wrapper:
    """Drives TD3 training and evaluation on gym-style environments.

    Uses the API visible below: ``reset()`` returns the observation and
    ``step()`` returns a 4-tuple ``(obs, reward, done, info)``.
    """

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        """
        Args:
            env: Environment used to collect training transitions.
            eval_env: Environment used for evaluation rollouts.
            config: Nested config; reads ``system``, ``policy``, ``train``,
                ``env``, ``wandb``, ``load_model`` and ``save_model``.
            agent_path: Directory that checkpoints are loaded from / saved to.
            evaluations_path: File the evaluation history is written to.
        """
        self.eval_env = eval_env
        self.env = env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path

        # NOTE: no RNG seeding (env, torch, numpy) is performed here.

        if config.system.cpu or not torch.cuda.is_available():
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda')

    def _make_policy(self, state_norm_mean, state_norm_std):
        """Build a TD3 agent sized to ``self.env`` from ``self.config.policy``.

        Noise scales in the config are relative and are multiplied by
        ``max_action``.
        """
        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]
        # Presumably a symmetric Box with one bound for all dims; only high[0] is read.
        max_action = float(self.env.action_space.high[0])
        policy_cfg = self.config.policy
        return TD3(
            state_dim=state_dim,
            action_dim=action_dim,
            max_action=max_action,
            discount=policy_cfg.discount,
            tau=policy_cfg.tau,
            policy_noise=policy_cfg.policy_noise * max_action,
            noise_clip=policy_cfg.noise_clip * max_action,
            policy_freq=policy_cfg.policy_freq,
            normalize=policy_cfg.normalize,
            mean=state_norm_mean,
            std=state_norm_std,
            device=self.device,
            use_wandb=self.config.wandb.enable,
        )

    def _policy_input(self, state, mean, std):
        """Convert an observation into the array passed to ``select_action``.

        Applies ``(state - mean) / std`` when ``config.policy.normalize`` is set,
        matching the normalization ``TD3.learn`` applies to sampled states.
        """
        state = np.array(state)
        if self.config.policy.normalize:
            return (state - mean) / std
        return state

    def eval_policy(self, policy, epoch, mean, std, eval_episodes=30):
        """Roll out ``policy`` deterministically on ``eval_env`` for ``eval_episodes`` episodes.

        When ``epoch > -1`` the summary is printed and, if enabled, logged to
        wandb.

        Returns:
            List of undiscounted per-episode returns.
        """
        all_rewards = []
        for _ in range(eval_episodes):
            state, done = self.eval_env.reset(), False
            ep_reward = 0.0
            while not done:
                action = policy.select_action(self._policy_input(state, mean, std))
                state, reward, done, _ = self.eval_env.step(action)
                ep_reward += reward

            all_rewards.append(ep_reward)

        avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(all_rewards, self.config.env.eval_env)

        if epoch > -1:
            print("---------------------------------------")
            print(f"Epoch {epoch}: Evaluation over {eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
            print("---------------------------------------")
            if self.config.wandb.enable:
                wandb.log(
                    {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward, 'eval std reward': std_reward,
                     'epochs': epoch})

        return all_rewards

    def test_policy(self):
        """Load the agent from ``agent_path``, evaluate it, and print summary statistics."""
        state_norm_mean = 0.0
        state_norm_std = 1.0

        policy = self._make_policy(state_norm_mean, state_norm_std)
        policy.load(self.agent_path)

        all_rewards = self.eval_policy(policy, -1, state_norm_mean, state_norm_std)
        # NOTE(review): eval_policy passes config.env.eval_env to
        # get_eval_statistics for the same rewards, while this passes
        # config.env.type. Confirm which name is intended.
        avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(all_rewards, self.config.env.type)
        print("---------------------------------------")
        print(f"Evaluation done. Computed on {len(all_rewards)} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
        print("---------------------------------------")

    def train(self):
        """Run online TD3 training on ``self.env`` with periodic evaluation and checkpointing."""
        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]
        max_action = float(self.env.action_space.high[0])

        state_norm_mean = 0.0
        state_norm_std = 1.0

        policy = self._make_policy(state_norm_mean, state_norm_std)

        if self.config.load_model:
            policy.load(self.agent_path)

        replay_buffer = ReplayBuffer(state_dim, action_dim, device=self.device)

        # Evaluate the untrained policy.
        evaluations = [self.eval_policy(policy, 0, state_norm_mean, state_norm_std)]

        # Loop-invariant. None means the env declares no time limit.
        max_steps = self.env.spec.max_episode_steps

        state, done = self.env.reset(), False
        episode_reward = 0
        episode_timesteps = 0
        episode_num = 0

        train_cfg = self.config.train
        for t in range(int(train_cfg.max_timesteps)):
            episode_timesteps += 1

            # Uniform random actions during warm-up, then noisy policy actions.
            if t < train_cfg.start_timesteps:
                action = self.env.action_space.sample()
            else:
                action = (
                        policy.select_action(self._policy_input(state, state_norm_mean, state_norm_std))
                        + np.random.normal(0, max_action * self.config.policy.expl_noise, size=action_dim)
                ).clip(-max_action, max_action)

            next_state, reward, done, _ = self.env.step(action)
            # Time-limit cutoffs are not true terminals. They are stored with
            # done=0 so the target still bootstraps from next_state.
            truncated = max_steps is not None and episode_timesteps >= max_steps
            done_bool = 0 if truncated else float(done)

            replay_buffer.add(state, action, next_state, reward, done_bool)

            state = next_state
            episode_reward += reward

            # Train once enough data has been collected.
            if t >= train_cfg.start_timesteps:
                policy.learn(replay_buffer, train_cfg.batch_size)

            if done:
                state, done = self.env.reset(), False
                episode_reward = 0
                episode_timesteps = 0
                episode_num += 1

            # Periodic evaluation (and latest-checkpoint overwrite).
            if t >= train_cfg.start_timesteps and (t + 1) % train_cfg.eval_freq == 0:
                evaluations.append(self.eval_policy(policy, t + 1, state_norm_mean, state_norm_std))
                if self.config.save_model:
                    os.makedirs(self.agent_path, exist_ok=True)
                    policy.save(self.agent_path)
                    torch.save(np.array(evaluations), self.evaluations_path)

            # Periodic snapshot in its own directory, tagged with the latest eval mean.
            if t >= train_cfg.start_timesteps and self.config.save_model and (t + 1) % train_cfg.save_freq == 0:
                path = os.path.join(self.agent_path, f'num_steps_{t+1}_reward_{int(np.mean(evaluations[-1]))}')
                os.makedirs(path, exist_ok=True)
                policy.save(path)
                print("#######################################")
                print(f"Saving policy on time step {t+1}, with mean reward {np.mean(evaluations[-1]):.3f}")
                print("#######################################")
