import torch
from torch.optim import Adam
#import os

from distributionalrl.model import Ensemble, MOGN
from distributionalrl.utils import disable_gradients, update_params, calculate_JT_gaussian_loss, calculate_JTxpow2_gaussian_loss, sampleGMM, normal_density, calculate_MMD_gaussian_loss, calculate_SSBU_loss, calculate_MW2_loss, mse_gauss

from .base_agent import BaseAgent


class MOGNAgent(BaseAgent):
    """Distributional RL agent with mixture-of-Gaussians return distributions.

    Return distributions come from an ``Ensemble`` of ``MOGN`` networks via
    ``calculate_density_parameters``; each ensemble member ``i`` yields a
    triple ``density[i][0..2]``. Component 0 is re-weighted by ``1 - beta`` /
    ``beta`` when two mixtures are merged, component 1 is shifted by the
    reward and scaled by ``gamma_n``, and component 2 is scaled by
    ``gamma_n`` -- presumably mixture weights, means and standard deviations
    respectively (TODO confirm against ``MOGN``).

    ``algorithm`` selects how the training target is built:

    * ``"TD"``: bootstrap the target network's density at the greedy next
      action (the greedy action is taken from the online net when
      ``double_q_learning`` is set).
    * ``"SA"``: merge the current online density (weight ``1 - beta``) with
      the bootstrapped online next-state density (weight ``beta``).
    * ``"SAT"``: as ``"SA"``, but both parts come from the target network.

    ``loss`` names the distance between online and target mixtures
    (``"JT2"``, ``"JTxpow2"``, ``"MMD"`` or ``"MW2"``) and may carry a
    ``":<approx>"`` suffix, which is stored in ``self.approx``.
    """

    def __init__(self, env, test_env, log_dir, frame_stack=4, num_steps=5*(10**7), algorithm="TD",
                 batch_size=32, K=32, prop_lr=1e-8, moment_lr=5e-5, memory_size=10**6, epi_memory_size=10**5, gamma=0.99,
                 multi_step=1, update_interval=4, target_update_interval=10000,
                 start_steps=50000, epsilon_train=0.01, epsilon_eval=0.001, epsilon_decay_steps=250000,
                 double_q_learning=False, decision_function="greedy", dueling_net=False, noisy_net=False, alea_comparison_measure="wang",
                 alea_comparison_hyps=0.0, epi_comparison_measure="wang", epi_comparison_hyps=0.0, epistemic_method="ensemble", num_ensembles=1,
                 use_per=False, log_interval=100, eval_interval=250000, num_eval_steps=125000, max_episode_steps=27000, grad_cliping=None,
                 cuda=True, seed=0, loss="JT2", beta=0.8, theta=(0,0,1), gammas=(1,1), noise_hyps=(0.0, 0.0)):
        """Build the online/target networks, the optional epinet and optimizers.

        Most arguments are forwarded unchanged to ``BaseAgent``. Those used
        directly here:

        Args:
            log_dir: Prefix of the log directory; a suffix encoding ``loss``,
                ``theta``, ``beta``, ``gammas``, the epistemic settings and the
                comparison measures/hyper-parameters is appended to it.
            algorithm: ``"TD"``, ``"SA"`` or ``"SAT"``; a target network is
                built only for ``"TD"`` and ``"SAT"``.
            K: Passed to every ``MOGN`` (presumably the number of mixture
                components -- TODO confirm).
            prop_lr: Passed to ``online_net.params`` to build the optimizer's
                parameter groups.
            moment_lr: Adam learning rate.
            alea_comparison_measure, alea_comparison_hyps: Forwarded to each
                ensemble ``MOGN`` as ``comparison_measure`` /
                ``comparison_hyps``. The hyps may be a scalar or an iterable.
            epi_comparison_measure, epi_comparison_hyps: Forwarded to the
                ``Ensemble`` when ``epistemic_method == "ensemble"`` (otherwise
                it gets ``"wang"`` / ``0.0``), and to the epinet when
                ``epistemic_method == "epinet"``.
            epistemic_method: ``"ensemble"`` uses ``num_ensembles`` members;
                any other value uses a single member. ``"epinet"`` also builds
                a separate ``MOGN`` epistemic network with its own optimizer.
            loss: Loss name, optionally of the form ``"<name>:<approx>"``.
            beta: Weight of the bootstrapped part in ``"SA"``/``"SAT"`` targets.
            theta, gammas: Extra arguments of the ``"MMD"`` loss.
        """
        log_dir = (
            log_dir + loss
            + f'-theta={theta[0]}_{theta[1]}_{theta[2]}-beta={beta}-gammas={gammas[0]}_{gammas[1]}'
            + 'method=' + epistemic_method
            + '-num_ensembles=' + str(num_ensembles)
            + '-' + 'epi=' + epi_comparison_measure + '_' + self._format_hyps(epi_comparison_hyps)
            + '_' + '-alea=' + alea_comparison_measure + '_' + self._format_hyps(alea_comparison_hyps))

        super().__init__(
            env, test_env, log_dir, frame_stack, num_steps, algorithm, batch_size, memory_size,
            gamma, multi_step, update_interval, target_update_interval,
            start_steps, epsilon_train, epsilon_eval, epsilon_decay_steps,
            double_q_learning, decision_function, dueling_net, noisy_net, epistemic_method, use_per,
            log_interval, eval_interval, num_eval_steps,
            max_episode_steps, grad_cliping, cuda, seed, epi_memory_size, noise_hyps)

        # Only the "ensemble" method uses several members and the epistemic
        # comparison settings; every other method uses one plain member.
        use_ensemble = epistemic_method == "ensemble"
        ensemble_comparison_measure = epi_comparison_measure if use_ensemble else "wang"
        ensemble_comparison_hyps = epi_comparison_hyps if use_ensemble else 0.0
        self.num_ensembles = num_ensembles if use_ensemble else 1

        # Online and target networks share an identical configuration.
        net_kwargs = dict(
            epi_comparison_measure=ensemble_comparison_measure,
            epi_comparison_hyps=ensemble_comparison_hyps,
            num_ensembles=self.num_ensembles,
            state_shape=env.observation_space.shape,
            num_actions=self.num_actions, K=K, dueling_net=dueling_net, noisy_net=noisy_net,
            comparison_measure=alea_comparison_measure, comparison_hyps=alea_comparison_hyps)

        # Online network.
        self.online_net = Ensemble(MOGN, self.device, **net_kwargs)

        # Target network (only the TD and SAT targets read from it).
        if algorithm in ("TD", "SAT"):
            self.target_net = Ensemble(MOGN, self.device, **net_kwargs)
            # Copy parameters of the learning network to the target network.
            self.update_target()
            # Disable calculations of gradients of the target network.
            disable_gradients(self.target_net.ensemble)

        # Load model
        #self.online_net.load_state_dict(torch.load(os.path.join('logs', 'QbertNoFrameskip-v4', 'mogn2-rbf-seed0-20240705-1729', 'model', 'final', 'online_net.pth'), map_location=torch.device('cpu')))
        #self.target_net.load_state_dict(torch.load(os.path.join('logs', 'MontezumaRevengeNoFrameskip-v4', 'mogn2-seed0-20240704-1608', 'model', 'final', 'target_net.pth'), map_location=torch.device('cpu')))

        self.optim = Adam(
            self.online_net.params(prop_lr),
            lr=moment_lr, eps=1e-2/batch_size)

        # Epistemic network.
        self.epistemic_net = None
        self.epi_optim = None
        if epistemic_method == "epinet":
            self.epistemic_net = MOGN(
                state_shape=env.observation_space.shape,
                num_actions=self.num_actions, K=K, is_embedding=False, dueling_net=dueling_net, noisy_net=noisy_net,
                comparison_measure=epi_comparison_measure, comparison_hyps=epi_comparison_hyps).to(self.device)
            # ``self.optim`` only holds the online ensemble's parameters, so
            # the epinet needs its own optimizer to be trained at all.
            self.epi_optim = Adam(
                self.epistemic_net.parameters(),
                lr=moment_lr, eps=1e-2/batch_size)

        self.K = K
        self.beta = beta
        self.theta = theta
        self.gammas = gammas

        # ``loss`` may carry an approximation suffix: "<name>:<approx>".
        parts = loss.split(':')
        self.loss = parts[0]
        self.approx = parts[1] if len(parts) >= 2 else ""

    @staticmethod
    def _format_hyps(hyps):
        """Join comparison hyper-parameters with ``'_'`` for ``log_dir``.

        Accepts an iterable or a single scalar. The defaults are plain floats
        (``(0.0)`` is not a tuple), which ``'_'.join(map(str, ...))`` cannot
        iterate, so scalars are rendered with ``str`` instead.
        """
        try:
            return '_'.join(map(str, hyps))
        except TypeError:
            return str(hyps)

    def _ensemble_slice(self, tensor, i):
        """Return the ``batch_size`` rows of ``tensor`` owned by ensemble member ``i``."""
        return tensor[i*self.batch_size:(i+1)*self.batch_size]

    def _sample_transitions(self, memory):
        """Draw ``batch_size * num_ensembles`` transitions from ``memory``.

        Returns ``(states, actions, rewards, next_states, dones, weights)``;
        ``weights`` is ``None`` unless prioritized replay (``use_per``) is on.
        """
        batch = memory.sample(self.batch_size*self.num_ensembles)
        if self.use_per:
            (states, actions, rewards, next_states, dones), weights = batch
        else:
            (states, actions, rewards, next_states, dones), weights = batch, None
        return states, actions, rewards, next_states, dones, weights

    def _bootstrap_density(self, density, rewards):
        """Apply the n-step Bellman map to one member's mixture parameters.

        Component 0 is kept, component 1 becomes ``rewards + gamma_n * c1``
        and component 2 becomes ``gamma_n * c2``. Returns a new 3-item list.
        """
        return [density[0],
                rewards + self.gamma_n*density[1],
                self.gamma_n*density[2]]

    def learn(self):
        """Run one gradient step on the online ensemble (and on the epinet, if any)."""
        self.learning_steps += 1
        self.online_net.sample_noise()

        if self.algorithm in ("TD", "SAT"):
            self.target_net.sample_noise()

        states, actions, rewards, next_states, dones, weights = \
            self._sample_transitions(self.memory)
        online_loss, mean_q = self.calculate_loss(
            states, actions, rewards, next_states, dones, weights)

        update_params(
            self.optim, online_loss,
            networks=self.online_net.ensemble,
            retain_graph=False, grad_cliping=self.grad_cliping)

        if self.epistemic_method == "epinet":
            # The epinet is fit on its own replay buffer.
            states, actions, rewards, next_states, dones, weights = \
                self._sample_transitions(self.epi_memory)
            epistemic_loss = self.calculate_epistemic_loss(
                states, actions, rewards, next_states, dones, weights)
            update_params(
                self.epi_optim, epistemic_loss,
                networks=[self.epistemic_net],
                retain_graph=False, grad_cliping=self.grad_cliping)

        # NOTE(review): this logs on every learning step of any episode whose
        # index is a multiple of log_interval -- confirm that is intended.
        if self.episodes % self.log_interval == 0:
            self.writer.add_scalar(
                'loss/'+self.loss, online_loss.detach().item(),
                4*self.steps)
            self.writer.add_scalar('stats/mean_Q', mean_q, 4*self.steps)

    def calculate_loss(self, states, actions, rewards, next_states,
                       dones, weights):
        """Return ``(loss, mean next-state Q)`` for one batch.

        The batch holds ``batch_size`` transitions per ensemble member;
        member ``i`` is trained on rows ``[i*batch_size, (i+1)*batch_size)``.
        The target is built without gradients according to ``self.algorithm``
        and compared to the online density with ``self.loss``; the
        per-member losses are summed.

        NOTE(review): ``dones`` and ``weights`` are not used, so terminal
        transitions still bootstrap from ``next_states`` and PER importance
        weights are ignored -- confirm this is intended.

        Raises:
            ValueError: if ``self.algorithm`` or ``self.loss`` is unknown.
        """
        if self.algorithm not in ("TD", "SA", "SAT"):
            raise ValueError(f"Unknown algorithm: {self.algorithm!r}")
        loss_fns = {
            "JT2": calculate_JT_gaussian_loss,
            "JTxpow2": calculate_JTxpow2_gaussian_loss,
            "MMD": lambda cur, tgt: calculate_MMD_gaussian_loss(cur, tgt, self.theta, self.gammas),
            "MW2": calculate_MW2_loss,
        }
        if self.loss not in loss_fns:
            raise ValueError(f"Unknown loss: {self.loss!r}")
        loss_fn = loss_fns[self.loss]

        # Online mixture parameters (with gradients).
        current_density = self.online_net.calculate_density_parameters(states, actions, batch_size=self.batch_size)

        with torch.no_grad():
            if self.algorithm == "TD":
                if self.double_q_learning:
                    # Sample the noise of online network to decorrelate between
                    # the action selection and the density calculation.
                    self.online_net.sample_noise()
                    next_q = self.online_net.calculate_q(states=next_states)
                else:
                    next_q = self.target_net.calculate_q(states=next_states)
                next_actions = torch.argmax(next_q, dim=1, keepdim=True)
                next_density = self.target_net.calculate_density_parameters(next_states, next_actions, batch_size=self.batch_size)

                target_density = [
                    self._bootstrap_density(next_density[i], self._ensemble_slice(rewards, i))
                    for i in range(self.num_ensembles)]
            else:
                # "SA" bootstraps from the online net, "SAT" from the target net.
                net = self.online_net if self.algorithm == "SA" else self.target_net
                next_q = net.calculate_q(states=next_states)
                next_actions = torch.argmax(next_q, dim=1, keepdim=True)
                next_density = net.calculate_density_parameters(next_states, next_actions, batch_size=self.batch_size)
                if self.algorithm == "SA":
                    base_density = current_density
                else:
                    base_density = net.calculate_density_parameters(states, actions, batch_size=self.batch_size)

                # Merge the current-state mixture (weight 1-beta) with the
                # bootstrapped next-state mixture (weight beta) by
                # concatenating their components along dim 1.
                target_density = []
                for i in range(self.num_ensembles):
                    boot = self._bootstrap_density(next_density[i], self._ensemble_slice(rewards, i))
                    target_density.append([
                        torch.cat(((1-self.beta)*base_density[i][0], self.beta*boot[0]), dim=1),
                        torch.cat((base_density[i][1], boot[1]), dim=1),
                        torch.cat((base_density[i][2], boot[2]), dim=1),
                    ])

        online_loss = 0.0
        for i in range(self.num_ensembles):
            online_loss += loss_fn(current_density[i], target_density[i])

        return online_loss, next_q.detach().mean().item()

    def calculate_epistemic_loss(self, states, actions, rewards, next_states,
                                 dones, weights):
        """Return the epinet's negative log-likelihood on one batch.

        For each ensemble member the TD-bootstrapped target mixture is turned
        into Q values by the target network; the epinet is scored on how
        likely those values are given the online net's state embeddings and
        the batch actions. Everything except the epinet runs without
        gradients.

        NOTE(review): this reads ``self.target_net``, which only exists for
        ``algorithm`` "TD"/"SAT"; ``dones`` and ``weights`` are unused.
        """
        with torch.no_grad():
            if self.double_q_learning:
                # Sample the noise of online network to decorrelate between
                # the action selection and the density calculation.
                self.online_net.sample_noise()
                next_q = self.online_net.calculate_q(states=next_states)
            else:
                next_q = self.target_net.calculate_q(states=next_states)
            next_actions = torch.argmax(next_q, dim=1, keepdim=True)
            next_density = self.target_net.calculate_density_parameters(next_states, next_actions, batch_size=self.batch_size)

            # Regression targets: Q values of the bootstrapped target mixture.
            # (An earlier variant used the difference to the online Q instead.)
            error_samples = []
            for i in range(self.num_ensembles):
                target_density = self._bootstrap_density(next_density[i], self._ensemble_slice(rewards, i))
                error_samples.append(
                    self.target_net.ensemble[i].calculate_q(states=None, density=target_density))

            states_embeddings = [self.online_net.ensemble[i].calculate_state_embeddings(states)
                                 for i in range(self.num_ensembles)]

        epistemic_loss = 0.0
        for i in range(self.num_ensembles):
            epistemic_loss -= self.epistemic_net.log_likelihood(
                states_embeddings[i], self._ensemble_slice(actions, i), error_samples[i])

        return epistemic_loss
    