from collections import deque, namedtuple
import math
import random
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from scipy.spatial import distance_matrix
import gym
import d4rl
import matplotlib.pyplot as plt
import matplotlib
from two_create_dataset import *
from two_networks import *
from two_running_stats import *
from timeit import default_timer as timer
from two_sac import SAC
import sys
from torch.utils.data import TensorDataset, DataLoader

class Planning_Agent():
    def __init__(
        self,
        env_name,
        device,
        q_name2,
        epochs=40,
        batch_size=512,
        lr=0.001,
        tau=0.03,
        discount=0.99,
        H=10,
        N_traj=100,
        beta=0.0,
        kappa=1,
        normalize=False,
        var_scale=1,
        sphere_norm=True,
    ):
        """Create the gym environment, load the offline dataset and build all networks.

        Args:
            env_name: Id passed to ``gym.make``; also embedded in checkpoint file names.
            device: Device handle forwarded to the dataset and to every network.
            q_name2: Stored on the instance; not used elsewhere in this constructor.
            epochs: Number of passes over the data in each training routine.
            batch_size: Mini-batch size used by the training data loaders.
            lr: Learning rate forwarded to every network constructor.
            tau: Base soft-update rate for target networks.
            discount: Discount factor used when building Q-learning targets.
            H: Planning horizon (number of simulated steps per trajectory).
            N_traj: Number of trajectories simulated per planning call.
            beta: Weight of the previously planned action when mixing actions.
            kappa: Scale applied to returns before exponential weighting.
            normalize: If true, states/actions/rewards are (de)normalized with
                the dataset statistics.
            var_scale: Second argument passed to the latent behaviour policy
                when sampling actions.
            sphere_norm: Forwarded to ``LatentDynamicsModel``.
        """
        # --- plain configuration -------------------------------------------
        self.env_name = env_name
        self.device = device
        self.q_name2 = q_name2
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.tau = tau
        self.discount = discount
        self.H = H
        self.N_traj = N_traj
        self.beta = beta
        self.kappa = kappa
        self.normalize = normalize
        self.var_scale = var_scale
        self.sphere_norm = sphere_norm

        # Fixed settings that are not exposed as constructor arguments.
        self.debug_n = 1000  # print training diagnostics every `debug_n` batches
        self.K_Q = 10        # candidate actions scored per trajectory and step
        self.sigma = 0.1     # std of the noise added to agent_c's actions

        # --- environment and data ------------------------------------------
        self.env = gym.make(env_name)
        self.env.reset()
        self.offline_dataset = OfflineDataset(device, self.env, normalize)

        # Dimensions are read off the first stored transition.
        self.state_dim = len(self.offline_dataset.data_s1_f_m[0])
        self.action_dim = len(self.offline_dataset.data_a1_f_m[0])
        self.latent_state_dim = self.state_dim - 2

        self.planned_traj = np.zeros((H, self.action_dim))

        # --- networks (construction order kept as in the original) ----------
        self.q_fn = QNetwork(device, lr, self.state_dim, self.action_dim, neurons=500)
        self.q_fn_target = QNetwork(device, lr, self.state_dim, self.action_dim, neurons=500)
        self.f_b_s = StochasticMLP(device, lr, self.state_dim, self.action_dim)

        self.q_fn_latent = LatentQMLP(device, lr, self.latent_state_dim, self.action_dim)
        self.q_fn_latent_target = LatentQMLP(device, lr, self.latent_state_dim, self.action_dim)
        self.f_b_s_latent = LatentStochasticPolicyMLP(device, lr, self.latent_state_dim, self.action_dim)

        self.latent_dynamics_model = LatentDynamicsModel(
            device,
            lr,
            self.state_dim,
            self.latent_state_dim,
            self.action_dim,
            neurons=500,
            sphere_norm=sphere_norm,
        )

        self.f_m = DynamicsMLP(device, lr, self.state_dim, self.action_dim)

        # --- planner state --------------------------------------------------
        self.running_stats_b = RunningStats()
        self.running_stats_c = RunningStats()

        # False until plan_action_agent has been called once since the last reset.
        self.first_action = False

    
    def reset(self):
        """Clear the per-episode planner state.

        Replaces the stored plan with zeros of shape ``(H, action_dim)``,
        resets both running-statistics trackers and clears ``first_action``.
        """
        self.planned_traj = np.zeros((self.H, self.action_dim))
        for tracker in (self.running_stats_b, self.running_stats_c):
            tracker.reset()
        self.first_action = False
    
    def normalize_states(self, states, denormalize=False):
        """Apply or undo the dataset's state standardisation.

        Args:
            states: Value(s) to transform.
            denormalize: If true, map standardised values back to raw scale
                (``x * std + mean``); otherwise compute ``(x - mean) / std``.

        Returns:
            The transformed states, or ``states`` unchanged when
            ``self.normalize`` is false.
        """
        if not self.normalize:
            return states

        mean = self.offline_dataset.state_mean
        std = self.offline_dataset.state_std
        if denormalize:
            return states * std + mean
        return (states - mean) / std
    
    def normalize_actions(self, actions, denormalize=False):
        """Apply or undo the dataset's action standardisation.

        Args:
            actions: Value(s) to transform.
            denormalize: If true, map standardised values back to raw scale
                (``x * std + mean``); otherwise compute ``(x - mean) / std``.

        Returns:
            The transformed actions, or ``actions`` unchanged when
            ``self.normalize`` is false.
        """
        if not self.normalize:
            return actions

        mean = self.offline_dataset.action_mean
        std = self.offline_dataset.action_std
        if denormalize:
            return actions * std + mean
        return (actions - mean) / std

    def normalize_rewards(self, rewards, denormalize=False):
        """Apply or undo the dataset's reward standardisation.

        Args:
            rewards: Value(s) to transform.
            denormalize: If true, map standardised values back to raw scale
                (``x * std + mean``); otherwise compute ``(x - mean) / std``.

        Returns:
            The transformed rewards, or ``rewards`` unchanged when
            ``self.normalize`` is false.
        """
        if not self.normalize:
            return rewards

        mean = self.offline_dataset.reward_mean
        std = self.offline_dataset.reward_std
        if denormalize:
            return rewards * std + mean
        return (rewards - mean) / std
    
    def encode_state(self, state_tensor, decode = False):
        with torch.no_grad():
            if decode:
                state_tensor = self.latent_dynamics_model.state_decoder(state_tensor)
            else:
                state_tensor = self.latent_dynamics_model.state_encoder(state_tensor)
        return state_tensor
    
    def train_latent_dynamics_model(self):
        max_n = self.epochs * int(self.offline_dataset.n_data_f_m/self.batch_size)
        loss_function = nn.MSELoss()
        curr_batch = 0
        self.tau /= 0.2
        self.latent_dynamics_model.update_target_network(1.0)
        dataset = TensorDataset(self.offline_dataset.data_s1_f_m, self.offline_dataset.data_a1_f_m, self.offline_dataset.data_r_f_m, self.offline_dataset.data_s2_f_m)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        for epoch in range(self.epochs):
            if epoch % (self.epochs/2) == 0:
                self.tau *= 0.2
            
        
            for i, (batch_states, batch_actions, batch_rewards, batch_next_states) in enumerate(dataloader):
                self.latent_dynamics_model.optimizer.zero_grad()
                self.latent_dynamics_model.update_target_network(self.tau)
        
                latent_states = self.latent_dynamics_model.state_encoder(batch_states)
                decoded_latent_states = self.latent_dynamics_model.state_decoder(latent_states)
                predicted_latent_next_states, predicted_rewards = self.latent_dynamics_model.dynamics_model(latent_states, batch_actions)
        
                with torch.no_grad():
                    target_latent_next_states = self.latent_dynamics_model.state_encoder_target(batch_next_states)

                f_m_loss = loss_function(target_latent_next_states, predicted_latent_next_states)
                r_loss = loss_function(batch_rewards.squeeze(), predicted_rewards.squeeze())
                state_decoder_loss = loss_function(batch_states, decoded_latent_states)
                loss = 3*f_m_loss + 0.5 * r_loss + state_decoder_loss
        
                if curr_batch % self.debug_n == 0:
                    print("-------------")
                    print("iteration " + str(curr_batch) +"/"+str(max_n)+" (epoch "+str(epoch)+"):")
                    print("loss: " + str(loss.item()))
                    print("f_m_loss: " + str(f_m_loss.item()))
                    print("r_loss: " + str(r_loss.item()))
                    print("state_decoder_loss: " + str(state_decoder_loss.item()))
        
                loss.backward()
                curr_batch += 1

                self.latent_dynamics_model.optimizer.step()
        filename = "_latent"
        filename += "_" + self.env_name
        if self.normalize:
            filename += "_normalized"
        #if self.sphere_norm:
        #    filename+="_snorm"
        torch.save(self.latent_dynamics_model.state_dict(), "f_ms/dynamics"  + filename + ".pth")

    def train_latent_f_b_s(self):
        curr_batch = 0
        self.batch_size=512
        max_n = self.epochs * int(self.offline_dataset.n_data_f_b/self.batch_size)
        with torch.no_grad():
            dataset = TensorDataset(self.encode_state(self.offline_dataset.data_s1_f_b), self.offline_dataset.data_a1_f_b)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        for epoch in range(self.epochs):
            for i, (batch_states, batch_actions) in enumerate(dataloader):
                self.f_b_s_latent.optimizer.zero_grad()
                dist = self.f_b_s_latent(batch_states)
                loss = -dist.log_prob(batch_actions).mean()
                loss += 0.01 * torch.mean(self.f_b_s_latent.max_logstd) - 0.01 * torch.mean(self.f_b_s_latent.min_logstd)
                if curr_batch%self.debug_n==0:
                    print("-------------")
                    print("iteration " + str(curr_batch) +"/"+str(max_n)+" (epoch "+str(epoch)+"):")
                    print("loss: " + str(loss.item()))
                loss.backward()
                self.f_b_s_latent.optimizer.step()
                curr_batch += 1
        filename = "_latent"
        filename += "_" + self.env_name
        if self.normalize:
            filename += "_normalized"
        #if self.sphere_norm:
        #    filename+="_snorm"
        torch.save(self.f_b_s_latent.state_dict(), "f_bs/f_b_s"  + filename + ".pth")
    
    def train_latent_q_network(self):
        q_vals = []
        self.batch_size=512
        max_n = self.epochs * int(self.offline_dataset.n_data_q/self.batch_size)
        curr_batch = 0
        loss_function = nn.MSELoss()
        with torch.no_grad():
            dataset = TensorDataset(self.encode_state(self.offline_dataset.data_s1_q), self.offline_dataset.data_a1_q, self.offline_dataset.data_r_q, self.encode_state(self.offline_dataset.data_s2_q), self.offline_dataset.data_a2_q, self.offline_dataset.data_d_q)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        with torch.no_grad():
            initial_state = self.encode_state(self.offline_dataset.data_s1_q)[0]
            initial_action = self.offline_dataset.data_a1_q[0]
  
        self.tau = 0.005
        self.tau /= 0.2
        for epoch in range(self.epochs):
            if epoch % (self.epochs/2) == 0:
                self.tau *= 0.2
            for i, (batch_states, batch_actions, batch_rewards, batch_next_states, batch_next_actions, batch_terminals) in enumerate(dataloader):
                self.q_fn_latent.optimizer.zero_grad()
                with torch.no_grad():
                    batch_rewards=self.normalize_rewards(batch_rewards, True)
                    q_targs = self.q_fn_latent_target(batch_next_states, batch_next_actions)
                        
                batch_targets = batch_rewards + self.discount * (1 - batch_terminals) * q_targs
                batch_eval = self.q_fn_latent(batch_states, batch_actions)
                loss = loss_function(batch_targets, batch_eval)
                if curr_batch%50==0:
                    with torch.no_grad():
                        q_vals.append(self.q_fn_latent(initial_state, initial_action).item())
                if curr_batch%500==0:
                    print("-------------")
                    print("iteration " + str(curr_batch) +"/"+str(max_n)+" (epoch "+str(epoch)+"):" + str(q_vals[len(q_vals) - 1]))
                    print("loss: " + str(loss.item()))
                
                loss.backward()
                self.q_fn_latent.optimizer.step()
                curr_batch += 1
                
                curr_state_dict = self.q_fn_latent.state_dict()
                target_state_dict = self.q_fn_latent_target.state_dict()
                with torch.no_grad():
                    for key in target_state_dict:
                        target_state_dict[key].mul_(1-self.tau)
                        target_state_dict[key].add_(curr_state_dict[key].mul(self.tau))
                self.q_fn_latent_target.load_state_dict(target_state_dict)
        filename = "_latent"
        filename += "_" + self.env_name
        if self.normalize:
            filename += "_normalized"
        #if self.sphere_norm:
        #    filename+="_snorm"
        torch.save(self.q_fn_latent.state_dict(), "q_fns/q_fn" + filename + ".pth")
    
    def init(self, train_q, train_f_b, train_f_m):
        filename = "_latent"
        filename += "_" + self.env_name
        if self.normalize:
            filename += "_normalized"
        if train_f_m:
            self.train_latent_dynamics_model()
        else:
            self.latent_dynamics_model.load_state_dict(torch.load("f_ms/dynamics"  + filename + ".pth", "cuda"))
        if train_f_b:
            self.train_latent_f_b_s()
        else:
            self.f_b_s_latent.load_state_dict(torch.load("f_bs/f_b_s"  + filename + ".pth", "cuda"))
        if train_q:
            self.train_latent_q_network()
        else:
            self.q_fn_latent.load_state_dict(torch.load("q_fns/q_fn" + filename + ".pth", "cuda"))
        return
    
    def sample_action(self, state_tensor):
        with torch.no_grad():
            dist = self.f_b_s_latent(state_tensor, self.var_scale)
            action = dist.sample()
        return action
    
    def predict_dynamics(self, state_tensor, action_tensor):
        with torch.no_grad():
            new_state, R = self.latent_dynamics_model.dynamics_model(state_tensor, action_tensor)
        return new_state, R

    def terminal_cost(self, observation, action):
        with torch.no_grad():
            cost = (self.q_fn_latent(observation, action)).squeeze(1)
        return cost

    def optimize_trajectories(self, R_n, A_n):
        R_n = self.kappa * R_n
        R_n = R_n - np.max(R_n)
        expR_n = np.exp(R_n)
        return np.average(A_n, weights=expR_n, axis=0)
    
    def plan_action_agent(self, observation, agent_c):
        """Plan one action by simulating trajectories from two action sources.

        Trajectories are rolled out for ``H`` steps in the latent dynamics
        model. Group "b" picks actions from the latent behaviour policy scored
        by the latent Q-network; group "c" picks actions from ``agent_c``'s
        policy (plus noise) scored by ``agent_c``'s critic. Each trajectory's
        accumulated reward plus a terminal value is used to weight its action
        sequence in ``optimize_trajectories``.

        Args:
            observation: Current environment observation (converted with
                ``torch.tensor``).
            agent_c: Object exposing ``policy(states)`` and
                ``critic(states, actions)``; element ``[0]`` of the critic's
                return value is used as the Q estimate.

        Returns:
            The first action of the newly planned trajectory. The full plan is
            stored in ``self.planned_traj``.
        """
        count_b = 0
        count_c = 0
        if not self.first_action:
            # First call since construction/reset: split the budget evenly.
            # With an odd N_traj the two halves sum to N_traj - 1.
            count_b = int(self.N_traj / 2)
            count_c = int(self.N_traj / 2)
        else:
            # Later calls: allocate trajectories by sampling from normal
            # distributions built from each group's running return statistics
            # and counting how often group b's sample is larger.
            mean1, std_dev1 = self.running_stats_b.running_mean(), self.running_stats_b.running_std_dev()
            mean2, std_dev2 = self.running_stats_c.running_mean(), self.running_stats_c.running_std_dev()

            samples1 = torch.normal(mean1, std_dev1, size=(self.N_traj,))
            samples2 = torch.normal(mean2, std_dev2, size=(self.N_traj,))

            count_b = torch.sum(samples1 > samples2).item()
            count_c = self.N_traj - count_b
        self.first_action = True
        # Per-group accumulators: returns R_*, action sequences A_*, and the
        # latent start state (normalized observation, encoded, one row per trajectory).
        R_b = torch.tensor(np.zeros((count_b)), dtype=torch.float32).to(self.device)
        state_tensor_b = self.encode_state(self.normalize_states(torch.tensor(observation, dtype=torch.float32).repeat(count_b, 1).to(self.device)))
        A_b = torch.tensor(np.zeros((count_b,self.H,self.action_dim)), dtype=torch.float32).to(self.device)
        R_c = torch.tensor(np.zeros((count_c)), dtype=torch.float32).to(self.device)
        state_tensor_c = self.encode_state(self.normalize_states(torch.tensor(observation, dtype=torch.float32).repeat(count_c, 1).to(self.device)))
        A_c = torch.tensor(np.zeros((count_c,self.H,self.action_dim)), dtype=torch.float32).to(self.device)
        # NOTE(review): step t of the previous plan is mixed into step t of the
        # new plan (no shift by one step) - confirm this is intended.
        prev_planned_traj = torch.tensor(self.planned_traj, dtype=torch.float32).to(self.device)
        for t in range(0, self.H):
            with torch.no_grad():
                if count_b > 0:
                    # Sample K_Q candidate actions per trajectory, score them with
                    # the latent Q-network and keep the best-scoring one.
                    action_bs = self.sample_action(state_tensor_b.unsqueeze(1).repeat(1, self.K_Q, 1))
                    q_bs = self.terminal_cost(state_tensor_b.unsqueeze(1).repeat(1, self.K_Q, 1), action_bs)
                    _, action_ind = torch.max(q_bs, dim=1)
                    action_b = action_bs[torch.arange(0,count_b), action_ind.view(-1)]
                    # Mix the denormalized action with the previous plan, then step
                    # the latent dynamics with the re-normalized mixture.
                    planned_actions = (1-self.beta)*self.normalize_actions(action_b, True) + self.beta * prev_planned_traj[t].repeat(count_b, 1)
                    state_tensor_b, R = self.predict_dynamics(state_tensor_b, self.normalize_actions(planned_actions))
                    A_b[:,t] = planned_actions
                    # NOTE(review): `[0]` takes the first element along dim 0 of the
                    # denormalized reward; whether that is one trajectory's reward
                    # or the whole batch depends on the shape returned by the
                    # dynamics model - confirm.
                    R = self.normalize_rewards(R, True)[0]
                    R_b += R
                if count_c > 0:
                    # agent_c is fed raw observations: the untouched observation at
                    # t == 0, otherwise the decoded and denormalized latent state.
                    if t==0:
                        state_tensor_c_temp=torch.tensor(observation, dtype=torch.float32).repeat(count_c, 1).to(self.device)
                    else:
                        state_tensor_c_temp = self.normalize_states(self.encode_state(state_tensor_c, True), True)
                    action_c = agent_c.policy(state_tensor_c_temp)
                    # Build K_Q noisy copies of the policy action; the noise (std
                    # sigma) is added in normalized action space.
                    action_cs = action_c.unsqueeze(1).repeat(1, self.K_Q, 1)
                    action_cs = self.normalize_actions(action_cs)
                    action_cs = action_cs + self.sigma * torch.as_tensor(np.random.normal(0, 1, action_cs.shape), dtype=torch.float32, device=self.device)
                    action_cs = self.normalize_actions(action_cs, True)
                    # Score the candidates with agent_c's critic and keep the best.
                    q_cs = agent_c.critic(state_tensor_c_temp.unsqueeze(1).repeat(1, self.K_Q, 1), action_cs)[0]
                    _, action_ind = torch.max(q_cs, dim=1)
                    action_c = action_cs[torch.arange(0,count_c), action_ind.view(-1)]
                    planned_actions = (1-self.beta)*action_c + self.beta * prev_planned_traj[t].repeat(count_c, 1)
                    A_c[:,t] = planned_actions
                    state_tensor_c, R = self.predict_dynamics(state_tensor_c, self.normalize_actions(planned_actions))
                    # NOTE(review): same `[0]` indexing as in the b-branch - confirm.
                    R = self.normalize_rewards(R, True)[0]
                    R_c += R

        # Add a terminal value estimate at the final simulated state of each group.
        with torch.no_grad():
            if count_b > 0:
                action_last_b = self.sample_action(state_tensor_b)
                V_b = self.terminal_cost(state_tensor_b, action_last_b)
                R_b += V_b
            if count_c > 0:
                state_tensor_c_temp = self.normalize_states(self.encode_state(state_tensor_c, True), True)
                action_last_c = agent_c.policy(state_tensor_c_temp)
                V_c = agent_c.critic(state_tensor_c_temp, action_last_c)[0].squeeze()
                R_c += V_c

        # Feed this call's returns into the statistics that drive the next
        # call's b/c allocation (also called when a group is empty).
        self.running_stats_b.update(R_b)
        self.running_stats_c.update(R_c)
        # Combine both groups and compute the return-weighted average plan.
        R_n = torch.cat((R_b,R_c), dim=0).to("cpu").numpy()
        A_n = torch.cat((A_b,A_c), dim=0).to("cpu").numpy()
        self.planned_traj = self.optimize_trajectories(R_n, A_n)
        return self.planned_traj[0]