from models import *
import torch
import numpy as np
import random
import collections
from env import *
from torch.autograd import Variable
from torch import Tensor
import matplotlib.pyplot as plt


class DDPG:
    """One DDPG learner: an actor, a critic, their target copies and optimizers.

    MADDPG builds one instance per agent. The critic input size
    (``critic_input_dim``) is supplied by the caller, so the critic may score
    a joint input rather than only this agent's own state/action.
    """

    def __init__(self, state_dim, action_dim, critic_input_dim, actor_lr, critic_lr, sigma, device):
        self.actor = Actor(state_dim, action_dim).to(device)
        self.target_actor = Actor(state_dim, action_dim).to(device)
        # Critics must live on the same device as the actors: the training batches
        # fed to them are moved to ``device``, so leaving them on CPU breaks CUDA runs.
        self.critic = Critic(critic_input_dim).to(device)
        self.target_critic = Critic(critic_input_dim).to(device)
        # Targets start as exact copies of the online networks.
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.action_dim = action_dim
        self.device = device
        self.sigma = sigma  # std-dev of the Gaussian exploration noise

    def choose_action(self, state, explore, action_scale=0.1):
        """Return the actor's action for ``state`` as a NumPy array.

        Args:
            state: Tensor accepted by ``self.actor``.
            explore: If true, add i.i.d. Gaussian noise with std ``self.sigma``.
            action_scale: Factor applied to the raw actor output (default 0.1).

        Returns:
            ``actor(state) * action_scale``, plus noise of the same shape as the
            action when ``explore`` is true.
        """
        with torch.no_grad():  # inference only: no autograd graph is needed
            action = self.actor(state).cpu().numpy() * action_scale
        if explore:
            action = action + self.sigma * np.random.randn(*action.shape)
        return action

    def soft_update(self, net, target_net, tau):
        """Blend ``net`` into ``target_net`` in place: target <- (1 - tau) * target + tau * net."""
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - tau) + param.data * tau)


class MADDPG:
    """Multi-agent DDPG: one DDPG learner per agent.

    Every critic receives the concatenation of all agents' states and actions
    (see ``update``), while each actor sees only its own agent's state.
    """

    def __init__(self, agent_num, device, actor_lr, critic_lr, state_dims, action_dims, critic_input_dim, sigma, gamma,
                 tau):
        self.agents = [
            DDPG(state_dims[i], action_dims[i], critic_input_dim, actor_lr, critic_lr, sigma, device)
            for i in range(agent_num)
        ]
        self.agent_num = agent_num
        self.gamma = gamma  # discount factor in the TD target
        self.tau = tau  # soft-update rate for the target networks
        self.critic_loss = torch.nn.MSELoss()
        self.device = device
        self.state_dims = state_dims

    @property
    def policies(self):
        """Online actor of every agent, in agent order."""
        return [agt.actor for agt in self.agents]

    @property
    def target_policies(self):
        """Target actor of every agent, in agent order."""
        return [agt.target_actor for agt in self.agents]

    def choose_action(self, states, explore):
        """Return one float NumPy action per agent, given per-agent ``states``."""
        states = [
            torch.tensor(states[i], dtype=torch.float, device=self.device)
            for i in range(self.agent_num)
        ]
        return [
            agent.choose_action(state, explore).astype(float)
            for agent, state in zip(self.agents, states)
        ]

    def update(self, sample, agent_i):
        """Run one critic step and one actor step for agent ``agent_i``.

        Args:
            sample: ``(state, action, reward, next_state, done)``. Each entry is a
                list with one batched tensor per agent. States and actions are
                concatenated along dim 1; reward/done are reshaped to (-1, 1).
            agent_i: Index of the agent being trained.
        """
        state, action, reward, next_state, done = sample
        current_agent = self.agents[agent_i]

        # Critic: regress Q(s, a) onto the TD target built from the target networks.
        # The target is treated as a constant, so no autograd graph is recorded for it.
        with torch.no_grad():
            all_target_actions = [pi(_next_state)
                                  for pi, _next_state in zip(self.target_policies, next_state)]
            target_critic_input = torch.cat((*next_state, *all_target_actions), dim=1)
            target_critic_value = (
                reward[agent_i].view(-1, 1)
                + self.gamma * current_agent.target_critic(target_critic_input)
                * (1 - done[agent_i].view(-1, 1)))
        current_agent.critic_optimizer.zero_grad()
        critic_input = torch.cat((*state, *action), dim=1)
        critic_value = current_agent.critic(critic_input)
        critic_loss = self.critic_loss(critic_value, target_critic_value)
        critic_loss.backward()
        current_agent.critic_optimizer.step()

        # Actor: maximize the critic's value of agent_i's fresh action. The other
        # agents' current-policy actions are fixed inputs, so no gradient flows
        # into (and no stale gradient accumulates on) their actors.
        current_agent.actor_optimizer.zero_grad()
        current_action_i = current_agent.actor(state[agent_i])
        all_actions = []
        for i, (pi, s) in enumerate(zip(self.policies, state)):
            if i == agent_i:
                all_actions.append(current_action_i)
            else:
                with torch.no_grad():
                    all_actions.append(pi(s))

        actor_loss_input = torch.cat((*state, *all_actions), dim=1)
        actor_loss = -current_agent.critic(actor_loss_input).mean()
        # Small L2 penalty on the raw actor output.
        actor_loss += (current_action_i ** 2).mean() * 1e-3
        actor_loss.backward()
        current_agent.actor_optimizer.step()

    def update_all_targets(self):
        """Soft-update every agent's target actor and target critic with rate ``self.tau``."""
        for agent in self.agents:
            agent.soft_update(agent.actor, agent.target_actor, self.tau)
            agent.soft_update(agent.critic, agent.target_critic, self.tau)


class ReplayBuffer:
    """Bounded FIFO memory of (state, action, reward, next_state, done) transitions.

    When full, appending a transition silently evicts the oldest one.
    """

    def __init__(self, capacity):
        # A deque with maxlen provides the drop-oldest behaviour directly.
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Store one transition."""
        record = state, action, reward, next_state, done
        self.buffer.append(record)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns a 5-tuple of per-field columns. States and next states are
        stacked with ``np.array``; actions, rewards and dones stay as tuples.
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        stacked_states = np.array(states)
        stacked_next_states = np.array(next_states)
        return stacked_states, actions, rewards, stacked_next_states, dones

    def size(self):
        """Number of transitions currently held."""
        count = len(self.buffer)
        return count


def evaluate(agent_num, maddpg, num_trajectories, trajectory_length, env, collision_radius=0.1):
    """Roll out the joint policy without exploration noise and collect statistics.

    Args:
        agent_num: Number of agents; sizes the per-agent reward accumulator.
        maddpg: Object exposing ``choose_action(state, explore)``.
        num_trajectories: Number of episodes to run.
        trajectory_length: Fixed number of steps per episode (no early stop).
        env: Object with ``reset()`` and ``step(state, actions)`` returning
            ``(next_state, reward, done)``.
        collision_radius: A trajectory is flagged as a collision if the
            Euclidean distance between the first two components of
            ``state[0]`` and ``state[1]`` is ever at or below this value.

    Returns:
        ``(count, rewards, trajs, collision)``:
        count -- trajectories whose final ``done`` is 1 for agents 0 and 1;
        rewards -- per-agent reward summed over steps and averaged over
        trajectories (a list; all zeros when ``num_trajectories`` is 0);
        trajs -- per trajectory, the list of ``(state, actions)`` pairs visited;
        collision -- per trajectory, 1 if a collision was seen, else 0.
    """
    rewards = np.zeros(agent_num)
    trajs = []
    collision = []
    count = 0
    for _ in range(num_trajectories):
        state = env.reset()
        traj = []
        collided = 0
        done = np.zeros(2)  # "not done" if the trajectory has no steps
        for _ in range(trajectory_length):
            actions = maddpg.choose_action(state, False)
            traj.append((state, actions))
            next_state, reward, done = env.step(state, actions)
            rewards += np.array(reward)
            # Collision is judged on the state the actions were chosen from.
            distance = ((state[0][0] - state[1][0]) ** 2 + (state[0][1] - state[1][1]) ** 2) ** 0.5
            if distance <= collision_radius:
                collided = 1
            state = next_state
        if done[0] == 1 and done[1] == 1:
            count += 1
        trajs.append(traj)
        collision.append(collided)
    if num_trajectories:
        rewards = rewards / num_trajectories
    return count, rewards.tolist(), trajs, collision


def get_trajectories(maddpg, num_trajectories, trajectory_length, env, plot=True):
    """Roll out the joint policy without exploration noise, optionally plotting each run.

    Args:
        maddpg: Object exposing ``choose_action(state, explore)``.
        num_trajectories: Number of episodes to run.
        trajectory_length: Fixed number of steps per episode.
        env: Object with ``reset()`` and ``step(state, actions)`` returning
            ``(next_state, reward, done)``.
        plot: If true (default), show one matplotlib figure per trajectory:
            agent 0's first two state components as '+', agent 1's as 'x',
            plus two fixed circles of radius 0.05 centred at (0, 0.5) and
            (-0.5, 0). NOTE(review): what the circles represent is defined by
            the environment, not here; confirm.

    Returns:
        One list per trajectory of the ``(state, actions)`` pairs visited.
    """
    trajectories = []
    theta = np.linspace(0, 2 * np.pi, 100)
    for _ in range(num_trajectories):
        state = env.reset()
        traj = []
        for _ in range(trajectory_length):
            actions = maddpg.choose_action(state, False)
            traj.append((state, actions))
            state, _, _ = env.step(state, actions)
        trajectories.append(traj)
        if plot:
            plt.plot([s[0][0] for s, _ in traj], [s[0][1] for s, _ in traj], '+')
            plt.plot([s[1][0] for s, _ in traj], [s[1][1] for s, _ in traj], 'x')
            plt.plot(0 + 0.05 * np.cos(theta), 0.5 + 0.05 * np.sin(theta))
            plt.plot(-0.5 + 0.05 * np.cos(theta), 0 + 0.05 * np.sin(theta))
            plt.show()
    return trajectories


def _stack_array(x, device):
    """Regroup one sampled field from ``x[transition][agent]`` into one float tensor per agent on ``device``."""
    rearranged = [[sub_x[i] for sub_x in x]
                  for i in range(len(x[0]))]
    return [
        torch.FloatTensor(np.vstack(aa)).to(device)
        for aa in rearranged
    ]


def train(theta_1, theta_2, feedback, physical, delta_1, delta_2, save, iteration_number):
    """Train a two-agent MADDPG in ``ENV`` and return evaluation trajectories.

    Args:
        theta_1, theta_2, feedback, physical, delta_1, delta_2: Forwarded
            unchanged to ``ENV``.
        save: If true, save agent 0's actor weights to
            ``learner_actor_<iteration_number>.pth`` and agent 1's to
            ``expert_actor_<iteration_number>.pth``.
        iteration_number: Suffix used in the saved file names.

    Returns:
        The ``trajs`` list from ``evaluate`` run for 10 noise-free episodes of
        ``episode_length`` steps after training.
    """
    num_episodes = 5000
    episode_length = 40
    buffer_size = 100000
    actor_lr = 1e-2
    critic_lr = 1e-2
    gamma = 0.95
    sigma = 0.03
    tau = 1e-2
    batch_size = 1024
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    update_interval = 100  # env steps between gradient updates
    minimal_size = 2000  # transitions to collect before the first update
    replay_buffer = ReplayBuffer(buffer_size)
    state_dims = np.array([2, 2])
    action_dims = np.array([2, 2])
    agent_num = 2
    critic_input_dim = sum(state_dims) + sum(action_dims)
    maddpg = MADDPG(agent_num, device, actor_lr, critic_lr, state_dims, action_dims, critic_input_dim, sigma, gamma,
                    tau)
    total_step = 0
    env = ENV(theta_1, theta_2, feedback, physical, delta_1, delta_2)
    for _ in range(num_episodes):
        state = env.reset()
        for _ in range(episode_length):
            actions = maddpg.choose_action(state, True)
            next_state, reward, done = env.step(state, actions)
            replay_buffer.add(state, actions, reward, next_state, done)
            state = next_state
            total_step += 1

            if replay_buffer.size() >= minimal_size and total_step % update_interval == 0:
                sample = replay_buffer.sample(batch_size)
                # Per field: list of one (batch, dim) tensor per agent.
                new_sample = [_stack_array(x, device) for x in sample]
                for a in range(agent_num):
                    maddpg.update(new_sample, a)
                maddpg.update_all_targets()

    _, _, traj, _ = evaluate(agent_num, maddpg, 10, episode_length, env)
    if save:
        actors = maddpg.policies
        file_name_learner_actor = 'learner_actor_' + str(iteration_number) + '.pth'
        file_name_expert_actor = 'expert_actor_' + str(iteration_number) + '.pth'
        torch.save(actors[0].state_dict(), file_name_learner_actor)
        torch.save(actors[1].state_dict(), file_name_expert_actor)
    return traj
