import os
import time
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import rl.successor_features.pytorch_util as ptu
from rl.successor_features.risk import distortion_de
from rl.rl_algorithm import RLAlgorithm
from rl.successor_features.dgpi import DGPI
from rl.utils.buffer import ReplayBuffer
from rl.utils.nets import (mlp, FlattenMlp)
from rl.utils.utils import (get_tau, dsf_reward_eval_mo, huber, layer_init,
                            linearly_decaying_epsilon, polyak_update)
from torch.utils.tensorboard import SummaryWriter


class LinearSchedule(object):
    """Linear interpolation from ``initial`` to ``final`` over a fixed horizon.

    Calling the schedule with a step ``t`` returns ``initial`` at ``t == 0``
    and ``final`` once ``t >= schedule_timesteps``; in between the value
    moves linearly.

    Args:
        schedule_timesteps: Number of steps over which to interpolate. A
            non-positive value means the schedule is already finished, so
            ``final`` is returned for every ``t``.
        initial: Value at the start of the schedule.
        final: Value at (and after) the end of the schedule.
    """

    def __init__(self, schedule_timesteps, initial=1., final=0.):
        self.schedule_timesteps = schedule_timesteps
        self.final = final
        self.initial = initial

    def __call__(self, t):
        """Return the scheduled value at step ``t``."""
        if self.schedule_timesteps <= 0:
            # Degenerate horizon: avoid dividing by zero and jump to the end.
            fraction = 1.0
        else:
            fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial + fraction * (self.final - self.initial)


def softmax(x):
    """Apply a softmax over the last dimension of ``x``."""
    normalized = F.softmax(input=x, dim=-1)
    return normalized


class DPsi(nn.Module):
    """Quantile network producing per-action feature quantiles.

    The state is encoded by ``base_fc``; the quantile fractions are embedded
    with a cosine basis and ``tau_fc``; both are combined multiplicatively and
    passed through ``merge_fc`` and ``last_fc``, which emits
    ``action_dim * phi_dim`` values per quantile fraction.

    Args:
        obs_dim: Size of the (flat) observation fed to ``base_fc``.
        action_dim: Number of discrete actions.
        phi_dim: Dimensionality of the feature vector.
        embedding_size: Number of cosine basis functions for the fractions.
        num_quantiles: Nominal number of quantile fractions. Stored for
            reference; ``forward`` sizes its output from ``tau`` itself.
        layer_norm: Insert ``nn.LayerNorm`` after each linear layer when True.
        hidden_sizes: All but the last entry size the state encoder; the last
            entry sizes the merge layer. Must contain at least one entry.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                obs_dim,
                action_dim,
                phi_dim,
                embedding_size=64,
                num_quantiles=32,
                layer_norm=True,
                hidden_sizes=[256,1024],
                **kwargs,):

        super(DPsi, self).__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.phi_dim = phi_dim
        self.layer_norm = layer_norm

        # State encoder: one Linear/(LayerNorm)/ReLU triple per hidden size
        # except the last one, which is reserved for the merge layer.
        encoder_layers = []
        last_size = self.obs_dim
        for next_size in hidden_sizes[:-1]:
            encoder_layers += [
                nn.Linear(last_size, next_size),
                nn.LayerNorm(next_size) if layer_norm else nn.Identity(),
                nn.ReLU(inplace=True),
            ]
            last_size = next_size
        self.base_fc = nn.Sequential(*encoder_layers)

        self.num_quantiles = num_quantiles
        self.embedding_size = embedding_size
        # Maps the cosine embedding of each fraction to the encoder width so
        # it can gate the state encoding element-wise.
        self.tau_fc = nn.Sequential(
            nn.Linear(embedding_size, last_size),
            nn.LayerNorm(last_size) if layer_norm else nn.Identity(),
            nn.Sigmoid(),
        )
        self.merge_fc = nn.Sequential(
            nn.Linear(last_size, hidden_sizes[-1]),
            nn.LayerNorm(hidden_sizes[-1]) if layer_norm else nn.Identity(),
            nn.ReLU(inplace=True),
        )
        self.last_fc = nn.Linear(hidden_sizes[-1], self.action_dim*self.phi_dim)
        # Frequencies 1..embedding_size of the cosine basis.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # it is not moved by ``.to(device)``; it lives wherever
        # ``ptu.from_numpy`` places it - confirm that matches the module device.
        self.const_vec = ptu.from_numpy(np.arange(1, 1 + self.embedding_size))

    def forward(self, state, tau):
        """Evaluate feature quantiles for every action.

        Args:
            state: Observation tensor whose last dimension is ``obs_dim``.
            tau: Quantile fractions; the last dimension indexes fractions.

        Returns:
            Tensor reshaped to ``(-1, n_fractions, action_dim, phi_dim)`` and
            then transposed on axes 1 and 2, i.e.
            ``(batch, action_dim, n_fractions, phi_dim)``.
        """
        state = self.base_fc(state)

        # Cosine embedding of each fraction: cos(pi * i * tau), i = 1..E.
        x = th.cos(tau.unsqueeze(-1) * self.const_vec * np.pi)
        x = self.tau_fc(x)

        # Broadcast the state encoding over the fraction axis and gate it.
        state = th.mul(x, state.unsqueeze(-2))
        state = self.merge_fc(state)
        output = self.last_fc(state)
        # Size the quantile axis from ``tau`` rather than the constructor's
        # ``num_quantiles`` so a different number of fractions is reshaped
        # correctly instead of raising or silently scrambling the layout.
        output = output.view(-1, tau.shape[-1], self.action_dim, self.phi_dim)
        return output.transpose(1,2)


class RDSFOLS(RLAlgorithm):

    def __init__(self,
                env,
                dgpi: Optional[DGPI] = None,
                learning_rate: float = 3e-4,
                initial_epsilon: float = 0.01,
                final_epsilon: float = 0.01,
                epsilon_decay_steps: Optional[int] = None,
                tau_target: float = 1.0,
                target_net_update_freq: int = 1000,
                buffer_size: int = int(1e6),
                embedding_size: int = 64,
                num_quantiles: int = 32,
                layer_norm: bool = True,
                hidden_sizes: List = [256, 1024],
                fp_hidden_sizes: List = [128, 128],
                model_arch: List = [200, 200, 200, 200],
                tau_type: str = 'iqn',
                lr_fp: float = float(1e-5),
                risk_type: str = 'cvar',
                risk_param: float = 0.25,
                risk_param_final: Optional[float] = None,
                risk_schedule_timesteps: int = 1,
                batch_size: int = 256,
                fp: Optional[nn.Module] = None,
                target_fp: Optional[nn.Module] = None,
                learning_starts: int = 100,
                gradient_updates: int = 1,
                gamma: float = 0.99,
                min_priority: float = 1.0,
                project_name: str = 'RDSFOLS',
                experiment_name: str = 'RDSFOLS',
                device: Union[th.device, str] = 'auto'):
        """Build the networks, optimizers, replay buffer and schedules.

        Args:
            env: Environment; must expose ``w`` (its length sets ``phi_dim``).
            dgpi: Optional policy set used for action selection and targets.
            learning_rate: Adam learning rate for the Psi network.
            initial_epsilon / final_epsilon / epsilon_decay_steps: Exploration
                schedule; epsilon stays at ``initial_epsilon`` when
                ``epsilon_decay_steps`` is None.
            tau_target: Coefficient passed to ``polyak_update``.
            target_net_update_freq: Target update period (in timesteps) used
                when ``tau_target == 1.0``.
            buffer_size: Replay buffer capacity.
            embedding_size, num_quantiles, layer_norm, hidden_sizes:
                Forwarded to ``DPsi``.
            fp_hidden_sizes: Hidden sizes of the fraction proposal networks.
            model_arch: Stored and reported by ``get_config`` only.
            tau_type: Fraction sampling scheme passed to ``get_tau``; ``'fqf'``
                additionally creates fraction proposal networks.
            lr_fp: Adam learning rate for the fraction proposal network.
            risk_type / risk_param: Risk measure used by ``q_values``.
            risk_param_final / risk_schedule_timesteps: End value and horizon
                of ``risk_schedule``; the schedule is constant when
                ``risk_param_final`` is None.
            batch_size: Replay sample size per gradient update.
            fp / target_fp: Externally supplied fraction proposal objects;
                replaced by fresh networks when ``tau_type == 'fqf'``.
            learning_starts: Timestep from which training begins.
            gradient_updates: Gradient steps per call to ``train``.
            gamma: Discount factor.
            min_priority: Forwarded to ``huber``.
            project_name / experiment_name: Run identifiers;
                ``experiment_name`` names the checkpoint written by ``save``.
            device: Forwarded to ``RLAlgorithm``.
        """
        super(RDSFOLS, self).__init__(env, device)
        self.dgpi = dgpi
        self.phi_dim = len(self.env.w)
        self.learning_rate = learning_rate
        self.initial_epsilon = initial_epsilon
        self.epsilon = initial_epsilon
        self.epsilon_decay_steps = epsilon_decay_steps
        self.final_epsilon = final_epsilon
        self.tau_target = tau_target
        self.target_net_update_freq = target_net_update_freq
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.embedding_size = embedding_size
        self.num_quantiles = num_quantiles
        self.layer_norm = layer_norm
        self.hidden_sizes = hidden_sizes
        self.model_arch = model_arch
        self.tau_type = tau_type
        self.lr_fp = lr_fp
        self.risk_type = risk_type
        self.risk_param = risk_param
        self.risk_param_final = risk_param_final
        self.risk_schedule_timesteps = risk_schedule_timesteps
        self.learning_starts = learning_starts
        self.batch_size = batch_size
        self.gradient_updates = gradient_updates
        # Kept on the instance because ``save`` builds the checkpoint file
        # name from ``experiment_name``.
        self.project_name = project_name
        self.experiment_name = experiment_name

        self.dpsi_net = DPsi(self.observation_dim, self.action_dim, self.phi_dim, embedding_size=embedding_size, num_quantiles=num_quantiles,
                             layer_norm=layer_norm, hidden_sizes=hidden_sizes).to(self.device)
        self.target_dpsi_net = DPsi(self.observation_dim, self.action_dim, self.phi_dim, embedding_size=embedding_size, num_quantiles=num_quantiles,
                                    layer_norm=layer_norm, hidden_sizes=hidden_sizes).to(self.device)

        # The target network starts as an exact, frozen copy of the online one.
        self.target_dpsi_net.load_state_dict(self.dpsi_net.state_dict())
        for param in self.target_dpsi_net.parameters():
            param.requires_grad = False
        self.dpsi_optim = optim.Adam(self.dpsi_net.parameters(), lr=self.learning_rate)

        self.replay_buffer = ReplayBuffer(self.observation_dim, 1, rew_dim=self.phi_dim, max_size=buffer_size, action_dtype=np.uint8)
        self.min_priority = min_priority
        self.alpha = 0.6

        # fraction proposal network
        self.fp = fp
        self.target_fp = target_fp
        if self.tau_type == 'fqf':
            self.fp = FlattenMlp(
                input_size=self.observation_dim + 1,
                output_size=num_quantiles,
                hidden_sizes=fp_hidden_sizes,
                output_activation=softmax,
            ).to(self.device)
            self.target_fp = FlattenMlp(
                input_size=self.observation_dim + 1,
                output_size=num_quantiles,
                hidden_sizes=fp_hidden_sizes,
                output_activation=softmax,
            ).to(self.device)
            # Start the target proposal network from the online weights, as is
            # done for the Psi networks, instead of an unrelated random init.
            self.target_fp.load_state_dict(self.fp.state_dict())

            self.fp_optimizer = optim.Adam(self.fp.parameters(), lr=lr_fp)

        # Constant schedule unless an explicit final risk parameter is given.
        self.risk_schedule = LinearSchedule(risk_schedule_timesteps, risk_param,
                                            risk_param if risk_param_final is None else
                                            risk_param_final)
 

    def get_config(self):
        """Return the agent's hyper-parameters as a flat dictionary.

        The environment id is read from ``self.env.unwrapped.spec.id``; every
        other entry mirrors an attribute set in the constructor.
        """
        return {'env_id': self.env.unwrapped.spec.id,
                'learning_rate': self.learning_rate,
                'initial_epsilon': self.initial_epsilon,
                'final_epsilon': self.final_epsilon,
                # Key previously carried a stray trailing colon.
                'epsilon_decay_steps': self.epsilon_decay_steps,
                'batch_size': self.batch_size,
                'tau_target': self.tau_target,
                'target_net_update_freq': self.target_net_update_freq,
                'gamma': self.gamma,
                'embedding_size': self.embedding_size,
                'num_quantiles': self.num_quantiles,
                'layer_norm': self.layer_norm,
                'hidden_sizes': self.hidden_sizes,
                'model_arch': self.model_arch,
                'tau_type': self.tau_type,
                'lr_fp': self.lr_fp,
                'risk_type': self.risk_type,
                'risk_param': self.risk_param,
                'risk_param_final': self.risk_param_final,
                'risk_schedule_timesteps': self.risk_schedule_timesteps,
                'gradient_updates': self.gradient_updates,
                'buffer_size': self.buffer_size,
                'learning_starts': self.learning_starts}

    def save(self, save_replay_buffer=True, save_dir='weights/'):
        """Write a checkpoint to ``<save_dir>/<experiment_name>.tar``.

        The checkpoint holds the online and target Psi networks and the Psi
        optimizer. When ``tau_type == 'fqf'`` the fraction proposal networks
        and their optimizer are stored as well.

        Args:
            save_replay_buffer: Also pickle the replay buffer object.
            save_dir: Destination directory; created if missing.
        """
        # exist_ok avoids the check-then-create race of a separate isdir test.
        os.makedirs(save_dir, exist_ok=True)

        saved_params = {
            'dpsi_net_state_dict': self.dpsi_net.state_dict(),
            'target_dpsi_net_state_dict': self.target_dpsi_net.state_dict(),
            'dpsi_nets_optimizer_state_dict': self.dpsi_optim.state_dict(),
        }
        if self.tau_type == 'fqf':
            # The fraction proposal networks only exist in FQF mode.
            saved_params['fp_state_dict'] = self.fp.state_dict()
            saved_params['target_fp_state_dict'] = self.target_fp.state_dict()
            saved_params['fp_optimizer_state_dict'] = self.fp_optimizer.state_dict()
        if save_replay_buffer:
            saved_params['replay_buffer'] = self.replay_buffer

        # Fall back to the constructor's default name if the attribute was
        # never set on this instance.
        experiment_name = getattr(self, 'experiment_name', 'RDSFOLS')
        th.save(saved_params, os.path.join(save_dir, experiment_name + '.tar'))
    
    def load(self, path, load_replay_buffer=True):
        """Restore networks, optimizers and optionally the replay buffer.

        Args:
            path: Checkpoint file produced by ``save``.
            load_replay_buffer: Replace the replay buffer with the stored one
                when the checkpoint contains it.
        """
        # NOTE: checkpoints may embed a pickled replay buffer; only load files
        # from trusted sources.
        # map_location lets a checkpoint written on another device be loaded.
        params = th.load(path, map_location=self.device)
        self.dpsi_net.load_state_dict(params['dpsi_net_state_dict'])
        self.target_dpsi_net.load_state_dict(params['target_dpsi_net_state_dict'])
        self.dpsi_optim.load_state_dict(params['dpsi_nets_optimizer_state_dict'])
        # Fraction proposal state is optional: older checkpoints do not
        # contain it, and it only applies in FQF mode.
        if self.tau_type == 'fqf':
            if 'fp_state_dict' in params:
                self.fp.load_state_dict(params['fp_state_dict'])
            if 'target_fp_state_dict' in params:
                self.target_fp.load_state_dict(params['target_fp_state_dict'])
            if 'fp_optimizer_state_dict' in params:
                self.fp_optimizer.load_state_dict(params['fp_optimizer_state_dict'])
        if load_replay_buffer and 'replay_buffer' in params:
            self.replay_buffer = params['replay_buffer']
    
    def sample_batch_experiences(self):
        """Draw ``batch_size`` transitions from the replay buffer as tensors on ``self.device``."""
        batch = self.replay_buffer.sample(
            self.batch_size,
            to_tensor=True,
            device=self.device,
        )
        return batch
    
    def train(self, w: th.tensor, writer):
        """Run ``gradient_updates`` TD steps, then refresh targets and logs.

        After updating this agent, the method optionally applies one extra TD
        step to the DGPI policy that selected the most recent action, when
        that policy is not the last one in ``self.dgpi.policies``.

        Args:
            w: Weight vector used to scalarise feature quantiles.
            writer: Object with ``add_scalar`` used for logging every 100
                timesteps.
        """
        # Policies maximised over when building the bootstrap target. Without
        # a DGPI set the agent bootstraps from itself instead of failing on
        # ``None.policies``.
        policies = self.dgpi.policies if self.dgpi is not None else [self]

        for _ in range(self.gradient_updates):
            s_obs, s_actions, s_rewards, s_next_obs, s_dones = self.sample_batch_experiences()
            # NOTE(review): fractions come from ``target_fp`` while
            # ``fp_optimizer`` steps ``fp``; confirm the FQF update below
            # actually reaches the parameters being optimised.
            tau, tau_hat, presum_tau = get_tau(s_obs, s_actions, tau_type=self.tau_type, num_quantiles=self.num_quantiles, fp=self.target_fp)

            with th.no_grad():
                # Greedy next action under the online network; only used to
                # condition the fractions sampled for the next state.
                dpsi_actions_for_next_tau = self.dpsi_net(s_next_obs, tau_hat)
                z_actions_for_next_tau = th.einsum('r,bsar->bsa', w, dpsi_actions_for_next_tau)
                q_actions_for_next_tau = th.sum(th.unsqueeze(presum_tau, dim=1) * z_actions_for_next_tau, dim=-1, keepdims=True)
                new_next_actions_for_next_tau = th.argmax(q_actions_for_next_tau, dim=1)

                next_tau, next_tau_hat, next_presum_tau = get_tau(s_next_obs, new_next_actions_for_next_tau, tau_type=self.tau_type, num_quantiles=self.num_quantiles, fp=self.target_fp)
                # Max over policies, then over actions.
                q_values = th.stack([policy.q_values(s_next_obs, w, next_tau_hat, next_presum_tau) for policy in policies]) # [p, 256, 9]
                bq, bp = th.max(q_values, dim=0)
                max_bq, new_next_actions = th.max(bq, dim=1)
                target_dpsi = self.target_dpsi_net(s_next_obs, next_tau_hat)
                # Keep only the quantiles of the selected next action.
                target_dpsi = target_dpsi.gather(1, new_next_actions.long().reshape(-1,1,1,1).expand(target_dpsi.size(0), 1, target_dpsi.size(2), target_dpsi.size(3)))
                # Online prediction at tau_hat, reused by the FQF updates.
                z_pred = self.dpsi_net(s_obs, tau_hat)
                dpsi_targets = target_dpsi.reshape(-1, self.num_quantiles*self.phi_dim)

                # Repeat the feature reward once per quantile and flatten it
                # the same way as ``dpsi_targets``.
                rewards_repeat = th.unsqueeze(s_rewards, dim=1).repeat(1, self.num_quantiles, 1)
                rewards_target = th.unsqueeze(rewards_repeat, dim=1)
                rewards_target = rewards_target.reshape(-1, self.num_quantiles*self.phi_dim)

                target_dpsi = (rewards_target + (1 - s_dones) * self.gamma * dpsi_targets).detach()
                # Prediction and target receive the identical reshape, so the
                # element-wise difference below stays aligned.
                target_dpsi = target_dpsi.reshape(-1, 1, self.phi_dim, self.num_quantiles)

            dpsi_values = self.dpsi_net(s_obs, tau_hat)
            dpsi_value = dpsi_values.gather(1, s_actions.long().reshape(-1,1,1,1).expand(dpsi_values.size(0), 1, dpsi_values.size(2), dpsi_values.size(3)))
            dpsi_value = dpsi_value.reshape(-1, 1, self.phi_dim, self.num_quantiles)

            td_error = (dpsi_value - target_dpsi)
            dpsi_loss = huber(td_error.abs(), min_priority=self.min_priority)

            self.dpsi_optim.zero_grad()
            dpsi_loss.backward()
            self.dpsi_optim.step()

            # Update the fraction proposal network (FQF only).
            if self.tau_type == 'fqf':
                with th.no_grad():
                    # NOTE(review): ``z_pred`` is subtracted twice; confirm
                    # this is the intended gradient estimate.
                    dWdtau = 2 * self.dpsi_net(s_obs, tau) - z_pred - z_pred
                    dWdtau /= dWdtau.shape[0]
                    dWdtau = th.einsum('r,bsar->bsa', w, dWdtau)
                    dWdtau, _ = th.max(dWdtau, dim=1)

                self.fp_optimizer.zero_grad()
                # retain_graph: ``tau`` is back-propagated again further down
                # for the secondary policy.
                tau.backward(gradient=dWdtau, retain_graph=True)
                self.fp_optimizer.step()

        # Target refresh: every call when ``tau_target != 1.0``, otherwise
        # every ``target_net_update_freq`` timesteps.
        if self.tau_target != 1.0 or self.num_timesteps % self.target_net_update_freq == 0:
            polyak_update(self.dpsi_net.parameters(), self.target_dpsi_net.parameters(), self.tau_target)
            if self.tau_type == 'fqf':
                polyak_update(self.fp.parameters(), self.target_fp.parameters(), self.tau_target)

        if self.epsilon_decay_steps is not None:
            self.epsilon = linearly_decaying_epsilon(self.initial_epsilon, self.epsilon_decay_steps, self.num_timesteps, self.learning_starts, self.final_epsilon)

        if self.num_timesteps % 100 == 0:
            writer.add_scalar("losses/dpsi_loss", dpsi_loss.item(), self.num_timesteps)
            writer.add_scalar("metrics/epsilon", self.epsilon, self.num_timesteps)

        # No DGPI-selected action recorded in this episode yet: nothing more
        # to update.
        if not self.police_indices:
            return
        this_policy_ind = self.police_indices[-1]
        # Extra TD step on the policy that chose the latest action, reusing
        # the last sampled batch and fractions, unless it is the last policy
        # in the set.
        if self.dgpi is not None and this_policy_ind != len(self.dgpi.policies) - 1:
            this_task = th.tensor(self.dgpi.tasks[this_policy_ind]).float().to(self.device)
            this_policy = self.dgpi.policies[this_policy_ind]
            this_policy.num_timesteps += 1

            with th.no_grad():
                q_values = th.stack([policy.q_values(s_next_obs, this_task, next_tau_hat, next_presum_tau) for policy in self.dgpi.policies])
                bq, bp = th.max(q_values, dim=0)
                max_bq, new_next_actions = th.max(bq, dim=1)
                target_dpsi = this_policy.target_dpsi_net(s_next_obs, next_tau_hat)
                target_dpsi = target_dpsi.gather(1, new_next_actions.long().reshape(-1,1,1,1).expand(target_dpsi.size(0), 1, target_dpsi.size(2), target_dpsi.size(3)))
                dpsi_targets = target_dpsi.reshape(-1, self.num_quantiles*self.phi_dim)

                rewards_repeat = th.unsqueeze(s_rewards, dim=1).repeat(1, self.num_quantiles, 1)
                rewards_target = th.unsqueeze(rewards_repeat, dim=1)
                rewards_target = rewards_target.reshape(-1, self.num_quantiles*self.phi_dim)

                target_dpsi = (rewards_target + (1 - s_dones) * self.gamma * dpsi_targets).detach()
                target_dpsi = target_dpsi.reshape(-1, 1, self.phi_dim, self.num_quantiles)

            dpsi_values = this_policy.dpsi_net(s_obs, tau_hat)
            dpsi_value = dpsi_values.gather(1, s_actions.long().reshape(-1,1,1,1).expand(dpsi_values.size(0), 1, dpsi_values.size(2), dpsi_values.size(3)))
            dpsi_value = dpsi_value.reshape(-1, 1, self.phi_dim, self.num_quantiles)

            td_error = (dpsi_value - target_dpsi)
            dpsi_loss = huber(td_error.abs(), min_priority=self.min_priority)

            this_policy.dpsi_optim.zero_grad()
            dpsi_loss.backward()
            this_policy.dpsi_optim.step()

            if this_policy.tau_type == 'fqf':
                with th.no_grad():
                    # NOTE(review): uses this agent's ``z_pred`` and ``w``
                    # rather than ``this_policy``'s prediction and
                    # ``this_task``; confirm that is intended.
                    dWdtau = 2 * this_policy.dpsi_net(s_obs, tau) - z_pred - z_pred
                    dWdtau /= dWdtau.shape[0]
                    dWdtau = th.einsum('r,bsar->bsa', w, dWdtau)
                    dWdtau, _ = th.max(dWdtau, dim=1)

                this_policy.fp_optimizer.zero_grad()
                tau.backward(gradient=dWdtau, retain_graph=True)
                this_policy.fp_optimizer.step()

            if this_policy.tau_target != 1.0 or this_policy.num_timesteps % this_policy.target_net_update_freq == 0:
                polyak_update(this_policy.dpsi_net.parameters(), this_policy.target_dpsi_net.parameters(), this_policy.tau_target)
                if this_policy.tau_type == 'fqf':
                    polyak_update(this_policy.fp.parameters(), this_policy.target_fp.parameters(), this_policy.tau_target)

    def q_values(self, obs: th.tensor, w: th.tensor, tau_hat, presum_tau) -> th.tensor:
        """Return risk-adjusted action values for ``obs`` under weights ``w``.

        Feature quantiles from ``dpsi_net`` are scalarised with ``w`` and
        reduced over the quantile axis according to ``self.risk_type``:

        * ``'VaR'``: the network is queried at the single fraction
          ``risk_param``.
        * ``'neutral'``: ``presum_tau``-weighted mean of the quantiles.
        * ``'std'``: that mean minus ``risk_param`` times the weighted
          standard deviation.
        * anything else: weights from ``distortion_de`` are applied on top of
          ``presum_tau``.

        Args:
            obs: Observation tensor accepted by ``dpsi_net``.
            w: Weight vector of length ``phi_dim``.
            tau_hat: Quantile fractions at which to evaluate the network.
            presum_tau: Per-fraction weights; assumed to share ``tau_hat``'s
                layout (quantiles on the last axis) - TODO confirm.

        Returns:
            Tensor of action values indexed ``[batch, action]``; computed
            without gradient tracking.
        """
        with th.no_grad():
            if self.risk_type == 'VaR':
                # Query every fraction slot at the same risk level.
                tau_ = ptu.ones_like(tau_hat) * self.risk_param
                dpsi_actions = self.dpsi_net(obs, tau_)
                q_actions = th.einsum('r,bsar->bsa', w, dpsi_actions)
            else:
                dpsi_actions = self.dpsi_net(obs, tau_hat)
                # Result is indexed [batch, action, quantile].
                z_actions = th.einsum('r,bsar->bsa', w, dpsi_actions)

                if self.risk_type in ['neutral', 'std']:
                    q_actions = th.sum(th.unsqueeze(presum_tau, dim=1) * z_actions, dim=-1, keepdims=True)
                    if self.risk_type == 'std':
                        # Insert the action axis so the weights broadcast over
                        # actions, matching the mean computed just above.
                        q_std = presum_tau.unsqueeze(-2) * (z_actions - q_actions).pow(2)
                        q_actions -= self.risk_param * q_std.sum(dim=-1, keepdims=True).sqrt()
                else:
                    risk_weights = distortion_de(tau_hat, self.risk_type, self.risk_param)
                    # Same broadcasting fix: weights lack the action axis.
                    q_actions = th.sum((risk_weights * presum_tau).unsqueeze(-2) * z_actions, dim=-1, keepdims=True)
            # Drop the trailing axis (size 1 after the reductions above).
            q_actions = q_actions[:, :, -1]
            return q_actions

    def eval(self, obs, w: np.array, tau_hat, presum_tau) -> int:
        """Pick the evaluation action for ``obs`` under weights ``w``.

        Both inputs are converted to float tensors on ``self.device``. With a
        DGPI set attached the decision is delegated to ``self.dgpi.eval``;
        otherwise the arg-max of this agent's own ``q_values`` is returned.
        """
        obs = th.tensor(obs).float().to(self.device)
        w = th.tensor(w).float().to(self.device)

        if self.dgpi is None:
            own_q = self.q_values(obs, w, tau_hat, presum_tau)
            return th.argmax(own_q, dim=1).item()

        return self.dgpi.eval(obs, w, tau_hat, presum_tau)

    def act(self, obs: th.tensor, w: th.tensor, tau_hat, presum_tau) -> np.array:
        """Epsilon-greedy action selection.

        With probability ``self.epsilon`` a random action is sampled from the
        environment's action space. Otherwise the action comes from the DGPI
        set (recording which policy chose it in ``self.police_indices``) or,
        without a DGPI set, from the arg-max of this agent's ``q_values``.
        """
        explore = np.random.random() < self.epsilon
        if explore:
            return self.env.action_space.sample()

        if self.dgpi is None:
            greedy = th.argmax(self.q_values(obs, w, tau_hat, presum_tau), dim=1)
            return greedy.item()

        action, policy_index = self.dgpi.eval(obs, w, tau_hat, presum_tau, return_policy_index=True)
        self.police_indices.append(policy_index)
        return action

    def learn(self, total_timesteps, writer, total_episodes=None, reset_num_timesteps=True, eval_env=None, eval_freq=1000, w=np.array([1.0,0.0]), M=[np.array([1.0,0.0]), np.array([0.0,1.0]), np.array([0.5,0.5])]):
        """Interact with the environment and train for up to ``total_timesteps`` steps.

        Args:
            total_timesteps: Maximum number of environment steps in this call.
            writer: Object with ``add_scalar`` used for logging.
            total_episodes: Optional cap on episodes completed in this call.
            reset_num_timesteps: Reset ``num_timesteps`` and ``num_episodes``
                to zero before starting.
            eval_env: Optional environment evaluated every ``eval_freq`` steps.
            eval_freq: Evaluation period in timesteps.
            w: Weight vector; written to ``env.unwrapped.w`` and used for
                action selection and training.
            M: Not used in this method.

        Returns:
            ``self.target_fp``.
        """
        episode_reward = 0.0
        episode_vec_reward = np.zeros_like(w)
        num_episodes = 0
        # Indices of the DGPI policies that chose actions in the current episode.
        self.police_indices = []
        obs, done = self.env.reset()[0], False
        obs = obs.reshape(-1)
        
        self.env.unwrapped.w = w
        tensor_w = th.tensor(w).float().to(self.device)

        self.num_timesteps = 0 if reset_num_timesteps else self.num_timesteps
        self.num_episodes = 0 if reset_num_timesteps else self.num_episodes
        for _ in range(1, total_timesteps+1):
            # NOTE(review): these are module-level globals, so values persist
            # between calls and are shared by every agent in this module.
            # ``tau_hat``/``presum_tau`` are first assigned after the first
            # environment step; if ``act`` is reached before that (and no
            # earlier call set them) a NameError is raised.
            global action, tau_hat, presum_tau

            if total_episodes is not None and num_episodes == total_episodes:
                break

            self.num_timesteps += 1
            
            # Uniform random actions until ``learning_starts`` is reached.
            if self.num_timesteps < self.learning_starts:
                action = self.env.action_space.sample()
            else:
                action = self.act(th.tensor(obs).float().to(self.device), tensor_w, tau_hat, presum_tau)

            # NOTE(review): ``step`` is unpacked as a 4-tuple while ``reset``
            # is indexed with [0]; confirm the environment wrapper provides
            # exactly this combination.
            next_obs, reward, done, info = self.env.step(action)
            next_obs = next_obs.reshape(-1)
            # Fractions computed here are used for the next action and for evaluation.
            tau, tau_hat, presum_tau = get_tau(obs, action, tau_type=self.tau_type, num_quantiles=self.num_quantiles, fp=self.target_fp)
            
            # When the info dict flags a time-limit truncation, the stored
            # terminal flag is the negation of that flag; otherwise it is ``done``.
            terminal = done if 'TimeLimit.truncated' not in info else not info['TimeLimit.truncated']
            # The feature vector ``info['phi']`` is stored in the reward slot.
            self.replay_buffer.add(obs, action, info['phi'], next_obs, terminal)

            if self.num_timesteps >= self.learning_starts:
                self.train(tensor_w, writer)
            
            if eval_env is not None and self.num_timesteps % eval_freq == 0:
                total_reward, discounted_return, total_vec_r, total_vec_return = dsf_reward_eval_mo(self, eval_env, w, tau_hat, presum_tau)
                writer.add_scalar("eval/total_reward", total_reward, self.num_timesteps)
                writer.add_scalar("eval/discounted_return", discounted_return, self.num_timesteps)
                for i in range(episode_vec_reward.shape[0]):
                    writer.add_scalar(f"eval/total_reward_obj{i}", total_vec_r[i], self.num_timesteps)
                    writer.add_scalar(f"eval/return_obj{i}", total_vec_return[i], self.num_timesteps)

            episode_reward += reward
            episode_vec_reward += info['phi']
            if done:
                # Episode finished: reset the environment, log, and clear the
                # per-episode accumulators.
                obs, done = self.env.reset()[0], False
                obs = obs.reshape(-1)
                num_episodes += 1
                self.num_episodes += 1
                
                # Always true; prints a summary after every episode.
                if num_episodes % 1 == 0:
                    print(f"Episode: {self.num_episodes} Step: {self.num_timesteps}, Ep. Total Reward: {episode_reward}")

                self.police_indices = []
                writer.add_scalar("metrics/episode", self.num_episodes, self.num_timesteps)
                writer.add_scalar("metrics/episode_reward", episode_reward, self.num_timesteps)
                for i in range(episode_vec_reward.shape[0]):
                    writer.add_scalar(f"metrics/episode_reward_obj{i}", episode_vec_reward[i], self.num_timesteps)

                episode_reward = 0.0
                episode_vec_reward = np.zeros_like(w)
            else:
                obs = next_obs
        return self.target_fp
