import numpy as np
import torch as th
from torch import nn
import math

from stochastic_actor import StochasticActor
from critic import Critic
import os
from stable_baselines3.common.utils import polyak_update
from torch.distributions import Normal

from torch.utils.tensorboard import SummaryWriter

class TraininingModel(nn.Module):
    """Jointly trains an actor that proposes intermediate states and a critic.

    The actor is called with two batches of states and returns a batch of
    "middle" states; the critic is called with two batches of states and
    returns one value per pair.  Training data is produced by repeatedly
    inserting actor-proposed middles between neighbouring states, and the
    critic targets are built from ``space.calc_deltas`` on the finest
    segments (presumably a distance/cost between states -- confirm against
    the space implementation).

    Args:
        space: Object providing ``dim``, ``symmetric``, ``to``,
            ``get_random_state``, ``clamp`` and ``calc_deltas``.
        batch_size: Mini-batch size; also the amount of rollout data gathered
            before each round of updates.
        device: Torch device used for the networks and generated tensors.
        eps: An evaluation episode counts as a success when its largest
            segment delta is below this value.
        learning_rate: Forwarded to the actor and the critic constructors.
        log_dir: Directory for TensorBoard logs, checkpoints and ``.npz``
            histories.
        eval_freq: Evaluate every ``eval_freq`` training iterations.
        total_timesteps: Training stops once this many timesteps were counted.
        eval_episodes: Iterable of ``(start_state, end_state)`` tensor pairs.
            ``None`` (the default) means no evaluation episodes.
        depth: Maximum number of midpoint-insertion rounds.
        gae_lambda: Weight of the accumulated child targets versus the critic
            predictions when building targets (``1.0`` uses no bootstrapping).
        n_epochs: Number of passes over each gathered data set.
        not_midpoint: Selects the actor objective, see ``train_actor``.
        alpha: Weight of the second squared term in the default actor loss.
    """

    def __init__(
        self,
        space,
        batch_size=256,
        device='cpu',
        eps=0.1,
        learning_rate=0.0003,
        log_dir='log',
        eval_freq=500,
        total_timesteps=100000,
        eval_episodes=None,
        depth=6,
        gae_lambda: float = 1.0,
        n_epochs: int = 10,
        not_midpoint: bool = False,
        alpha=1.0,
    ):
        super().__init__()
        self.space = space
        space.to(device)
        self.eps = eps
        self.actor = StochasticActor(space.dim, learning_rate=learning_rate).to(device)
        self.critic = Critic(space.dim, learning_rate=learning_rate).to(device)
        self.batch_size = batch_size
        self.device = device
        self.log_dir = log_dir
        self.eval_freq = eval_freq
        # A ``None`` default avoids sharing one mutable list between instances.
        if eval_episodes is None:
            eval_episodes = []
        self.eval_episodes = [[ep[0].to(device), ep[1].to(device)] for ep in eval_episodes]
        self.total_timesteps = total_timesteps
        self.depth = depth
        self.gae_lambda = gae_lambda
        self.not_midpoint = not_midpoint
        self.alpha = alpha

        self.n_epochs = n_epochs

        self.summary_writer = SummaryWriter(log_dir + "/tensorboard")
        self.best_success_count = 0
        # Lower cost is better, so the "nothing seen yet" sentinel is +inf.
        self.best_average_cost = np.inf

        # Counters are (re)set by ``learn``; initialising them here lets
        # ``get_depth``, ``rollout`` and ``calc_TD`` be used on their own.
        self.num_cycle = 0
        self.timesteps = 0

    def get_depth(self):
        """Return the refinement depth to use for the current rollout.

        The depth grows by one every ``total_timesteps // (2**(depth+1)-1) + 1``
        rollouts and is capped at ``self.depth``.
        """
        cycles_per_level = self.total_timesteps // (2 ** (self.depth + 1) - 1) + 1
        return min(self.num_cycle // cycles_per_level, self.depth)

    def train_critic_target(self, states_0, states_1, target_v_values):
        """Run one critic optimisation step and return the scalar loss.

        The loss is the MSE between ``critic(states_0, states_1)`` and
        ``target_v_values``.  When ``space.symmetric`` is true, an extra MSE
        term pulls ``critic(states_1, states_0)`` towards the forward values.
        """
        current_v_values = self.critic(states_0, states_1)
        critic_loss = nn.functional.mse_loss(current_v_values, target_v_values)
        if self.space.symmetric:
            reverse_v_values = self.critic(states_1, states_0)
            critic_loss = critic_loss + nn.functional.mse_loss(reverse_v_values, current_v_values)

        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        return critic_loss.item()

    def train_actor(self, states_0, states_1):
        """Run one actor optimisation step and return the scalar loss.

        The actor proposes (clamped) middles between ``states_0`` and
        ``states_1``.  With ``not_midpoint`` the loss is the mean of the sum
        of the two critic values; otherwise it is the mean squared first
        value plus ``alpha`` times the mean squared second value.  When
        ``space.symmetric`` is true, the critic value between the middles and
        the actor output for the swapped pair is added (squared in the
        default mode).
        """
        middles = self.space.clamp(self.actor(states_0, states_1))
        q_value_0 = self.critic(states_0, middles)
        q_value_1 = self.critic(middles, states_1)

        if self.not_midpoint:
            actor_loss = (q_value_0 + q_value_1).mean()
            if self.space.symmetric:
                actor_loss = actor_loss + self.critic(middles, self.actor(states_1, states_0)).mean()
        else:
            actor_loss = (q_value_0 ** 2).mean() + self.alpha * (q_value_1 ** 2).mean()
            if self.space.symmetric:
                actor_loss = actor_loss + (self.critic(middles, self.actor(states_1, states_0)) ** 2).mean()

        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        return actor_loss.item()

    def _insert_midpoints(self, states, **actor_kwargs):
        """Return ``states`` with a clamped actor middle between each neighbour pair."""
        with th.no_grad():
            middles = self.actor(states[:-1], states[1:], **actor_kwargs)
        middles = self.space.clamp(middles)
        refined = th.zeros(
            (states.shape[0] + middles.shape[0], self.space.dim), device=self.device
        )
        # Original states keep the even slots, proposed middles fill the odd ones.
        refined[::2] = states
        refined[1::2] = middles
        return refined

    def rollout(self):
        """Generate one set of state pairs by recursive midpoint insertion.

        Starting from two random states, the sequence is refined
        ``get_depth()`` times.  The returned tensors contain, in order, the
        neighbouring pairs of every refinement level (coarsest first)
        followed by each final state paired with itself, i.e.
        ``3 * 2**depth`` rows in total.  Increments ``self.num_cycle``.

        Returns:
            Tuple ``(states_0, states_1)`` of equally shaped tensors.
        """
        states = th.zeros((2, self.space.dim), device=self.device)
        states[0] = self.space.get_random_state()
        states[1] = self.space.get_random_state()

        pair_starts = []
        pair_ends = []
        for _ in range(self.get_depth()):
            pair_starts.append(states[:-1])
            pair_ends.append(states[1:])
            states = self._insert_midpoints(states)

        # Finest-level neighbouring pairs.
        pair_starts.append(states[:-1])
        pair_ends.append(states[1:])
        # Every final state paired with itself.
        pair_starts.append(states)
        pair_ends.append(states)
        self.num_cycle += 1

        return th.cat(pair_starts), th.cat(pair_ends)

    def calc_TD(self, states_0, states_1):
        """Compute critic targets for pairs laid out as produced by ``rollout``.

        Finest-level pairs get ``space.calc_deltas`` as target.  Each coarser
        pair gets a ``gae_lambda``-weighted mix of the summed targets of its
        two children and the summed critic predictions for those children.
        The trailing self-pairs keep a target of zero.  Adds the number of
        finest-level pairs to ``self.timesteps``.

        Returns:
            1-D tensor with one target per row of ``states_0``.
        """
        advantages = th.zeros(states_0.shape[0], device=self.device)
        # A rollout of depth d has 3 * 2**d rows, which lets us recover d.
        depth = round(math.log2(states_0.shape[0] // 3))
        leaf_start = 2 ** depth - 1
        leaf_end = 2 ** (depth + 1) - 1
        advantages[leaf_start:leaf_end] = self.space.calc_deltas(
            states_0[leaf_start:leaf_end], states_1[leaf_start:leaf_end]
        )
        self.timesteps += leaf_end - leaf_start

        # Propagate from the second-finest level up to the root pair.
        for level in range(depth - 1, -1, -1):
            parent_start = 2 ** level - 1
            child_start = 2 ** (level + 1) - 1
            child_end = child_start + 2 ** (level + 1)
            left = slice(child_start, child_end - 1, 2)
            right = slice(child_start + 1, child_end, 2)
            with th.no_grad():
                predicted = (self.critic(states_0[left], states_1[left])
                             + self.critic(states_0[right], states_1[right]))
                accumulated = advantages[left] + advantages[right]
                advantages[parent_start:parent_start + 2 ** level] = (
                    (1 - self.gae_lambda) * predicted + self.gae_lambda * accumulated
                )
        return advantages

    def _collect_batch(self):
        """Gather rollouts (with targets) until roughly ``batch_size`` rows exist."""
        states_0, states_1 = self.rollout()
        target_v_values = self.calc_TD(states_0, states_1)
        while states_0.shape[0] + 2 ** (self.get_depth() + 1) - 1 <= self.batch_size:
            extra_0, extra_1 = self.rollout()
            extra_targets = self.calc_TD(extra_0, extra_1)
            states_0 = th.cat((states_0, extra_0))
            states_1 = th.cat((states_1, extra_1))
            target_v_values = th.cat((target_v_values, extra_targets))
        return states_0, states_1, target_v_values

    def _log_evaluation(self):
        """Evaluate, record the metrics and checkpoint the model on improvement."""
        evaluation, min_delta, success_count, average_cost = self.evaluate()
        success_rate = success_count / len(self.eval_episodes)
        self.summary_writer.add_scalar('evaluation', evaluation, self.timesteps)
        self.summary_writer.add_scalar('min_delta', min_delta, self.timesteps)
        self.summary_writer.add_scalar('success_rate', success_rate, self.timesteps)
        self.summary_writer.add_scalar('average_cost', average_cost, self.timesteps)
        self.evaluations.append(evaluation)
        self.min_deltas.append(min_delta)
        self.success_rates.append(success_rate)
        self.average_costs.append(average_cost)
        self.eval_timesteps.append(self.timesteps)
        # More successes wins; ties are broken by the lower average cost.
        improved = success_count > self.best_success_count or (
            success_count == self.best_success_count
            and average_cost < self.best_average_cost
        )
        if improved:
            self.save_model()
            self.best_success_count = success_count
            self.best_average_cost = average_cost

    def learn(self):
        """Train until ``self.timesteps`` reaches ``total_timesteps``.

        Each iteration gathers rollout data, runs ``n_epochs`` passes of
        critic and actor updates over it in mini-batches, and records the
        last losses.  Losses go to TensorBoard every 100 iterations.  Every
        ``eval_freq`` iterations the model is evaluated (skipped when there
        are no evaluation episodes, in which case no checkpoint is written)
        and the histories are saved to disk.
        """
        self.critic_losses = []
        self.actor_losses = []
        self.evaluations = []
        self.min_deltas = []
        self.success_rates = []
        self.average_costs = []
        self.num_cycle = 0
        self.timesteps = 0
        self.loss_timesteps = []
        self.eval_timesteps = []
        self.best_evaluation = np.inf
        iteration = 0
        while self.timesteps < self.total_timesteps:
            states_0, states_1, target_v_values = self._collect_batch()
            num_data = states_0.shape[0]
            for _ in range(self.n_epochs):
                for begin in range(0, num_data, self.batch_size):
                    end = begin + self.batch_size
                    critic_loss = self.train_critic_target(
                        states_0[begin:end], states_1[begin:end], target_v_values[begin:end]
                    )
                    actor_loss = self.train_actor(states_0[begin:end], states_1[begin:end])
            self.critic_losses.append(critic_loss)
            self.actor_losses.append(actor_loss)
            self.loss_timesteps.append(self.timesteps)
            iteration += 1
            if iteration % 100 == 0:
                self.summary_writer.add_scalar('critic_loss', critic_loss, self.timesteps)
                self.summary_writer.add_scalar('actor_loss', actor_loss, self.timesteps)
            if iteration % self.eval_freq == 0:
                # Without evaluation episodes there is nothing to average over.
                if self.eval_episodes:
                    self._log_evaluation()
                self.save_evals()
        # Make sure buffered scalars reach the event file.
        self.summary_writer.flush()

    def save_model(self):
        """Write the actor and critic state dicts into ``log_dir``."""
        th.save(self.actor.state_dict(), self.log_dir + '/actor_model.pt')
        th.save(self.critic.state_dict(), self.log_dir + '/critic_model.pt')

    def save_evals(self):
        """Write the loss and evaluation histories as ``.npz`` files into ``log_dir``."""
        np.savez(self.log_dir + '/train_losses.npz',
                 timesteps=self.loss_timesteps,
                 critic_losses=self.critic_losses,
                 actor_losses=self.actor_losses)
        np.savez(self.log_dir + '/evaluations.npz',
                 timesteps=self.eval_timesteps,
                 evaluations=self.evaluations,
                 success_rates=self.success_rates,
                 min_deltas=self.min_deltas,
                 average_costs=self.average_costs)

    def evaluate(self):
        """Evaluate the deterministic actor on the stored evaluation episodes.

        Every episode is refined ``self.depth`` times; an episode succeeds
        when its largest segment delta is below ``self.eps``.

        Returns:
            Tuple ``(mean_total_delta, mean_max_delta, success_count,
            mean_total_delta_of_successes)``.  The last entry is ``inf`` when
            no episode succeeded.  With no evaluation episodes the two means
            are ``nan``.
        """
        num_episodes = len(self.eval_episodes)
        if num_episodes == 0:
            # Avoid dividing by zero; there is nothing to average.
            return float('nan'), float('nan'), 0, np.inf

        length_sum = 0.
        max_delta_sum = 0.
        success_count = 0
        success_sum = 0.
        for state_0, state_1 in self.eval_episodes:
            states = th.stack((state_0, state_1), 0)
            for _ in range(self.depth):
                states = self._insert_midpoints(states, deterministic=True)
            deltas = self.space.calc_deltas(states[:-1], states[1:])
            total = deltas.sum().item()
            if deltas.max() < self.eps:
                success_count += 1
                success_sum += total
            length_sum += total
            max_delta_sum += deltas.max().item()

        if success_count == 0:
            success_sum = np.inf
        else:
            success_sum /= success_count
        return length_sum / num_episodes, max_delta_sum / num_episodes, success_count, success_sum
