import torch
import numpy as np

#from distributionalrl.model import MOGN, FQF, IQN, QRDQN

class Ensemble:
    """Hold ``num_ensembles`` independently constructed copies of a model.

    Every member is built as ``model(**args).to(device)``. Member Q-values are
    combined in :meth:`calculate_q` according to ``epi_comparison_measure``.
    Checkpoints produced by :meth:`state_dict` prefix each member key with the
    member index.

    Args:
        model: Callable (typically a model class) invoked as ``model(**args)``;
            its result must support ``.to(device)``.
        device: Device handed to ``.to`` for every member.
        epi_comparison_measure: How member Q-values are combined in
            :meth:`calculate_q`. One of ``"oneversusall"``,
            ``"trueoneversusone"``, ``"oneversusone"`` or ``"ucb"``; any other
            value falls back to the plain ensemble mean.
        epi_comparison_hyps: Hyper-parameters of the comparison measure.
            Index 0 is the tie margin ``epsilon`` used by ``"oneversusone"``.
            Index 1 is the exponent applied to the probability-style scores;
            when absent, the exponent is 1 (a no-op). ``None`` means ``[0.0]``.
        num_ensembles: Number of members to construct.
        **args: Keyword arguments forwarded to ``model``.
    """

    # Separator between the member index and the member's own key in state_dict().
    _KEY_SEPARATOR = "____"

    def __init__(self, model, device, epi_comparison_measure="oneversusall",
                 epi_comparison_hyps=None, num_ensembles=1, **args):
        self.comparison_measure = epi_comparison_measure
        # Fresh default per instance. A caller-supplied list is kept by reference
        # (as before), so external updates to it remain visible.
        self.comparison_hyps = [0.0] if epi_comparison_hyps is None else epi_comparison_hyps
        self.ensemble = [model(**args).to(device) for _ in range(num_ensembles)]

    def _comparison_exponent(self):
        """Return the exponent for comparison scores: ``comparison_hyps[1]``, or 1 if unset."""
        return self.comparison_hyps[1] if len(self.comparison_hyps) > 1 else 1.0

    def calculate_density_parameters(self, states, actions=None, batch_size=None):
        """Give each member its own contiguous slice of ``states`` (and ``actions``).

        Member ``i`` receives rows ``[i * batch_size, (i + 1) * batch_size)``.

        Args:
            states: Sliceable batch (e.g. a tensor) holding the members'
                batches back to back.
            actions: Sliceable batch aligned with ``states``, or ``None``. When
                ``None``, every member is called with ``actions=None``.
            batch_size: Rows per member. When ``None``, ``states`` is split
                equally: ``len(states) // num_members``.

        Returns:
            list: One result per member, in member order.
        """
        if batch_size is None:
            batch_size = len(states) // max(len(self.ensemble), 1)

        results = []
        for i, member in enumerate(self.ensemble):
            rows = slice(i * batch_size, (i + 1) * batch_size)
            member_actions = None if actions is None else actions[rows]
            results.append(member.calculate_density_parameters(states[rows], actions=member_actions))
        return results

    def calculate_q(self, states=None, step=None):
        """Combine the members' Q-values according to ``self.comparison_measure``.

        Member outputs are stacked along a new leading axis of size ``E``
        (number of members).

        The three comparison measures index the stacked tensor with rank 3, so
        each member's ``calculate_q`` must return a 2-D tensor there. This is
        presumably ``(batch, actions)``; TODO confirm. In the comments below,
        ``x`` indexes its first axis and ``a``/``b`` index its second.

        Args:
            states: Forwarded unchanged to every member's ``calculate_q``.
            step: Only used by ``"ucb"``. Scales the exploration bonus by
                ``sqrt(log(step) / step)``, so it presumably must be >= 1;
                TODO confirm. If ``None``, ``"ucb"`` falls back to the mean.

        Returns:
            torch.Tensor: Combined scores, with the member axis reduced away.
        """
        q_values = torch.stack([member.calculate_q(states=states) for member in self.ensemble])
        measure = self.comparison_measure

        if measure in ("oneversusall", "trueoneversusone", "oneversusone"):
            exponent = self._comparison_exponent()
            # leq[e, f, x, a, b] == 1.0 iff q_values[e, x, b] <= q_values[f, x, a].
            leq = (q_values[:, None, :, None, :] <= q_values[None, :, :, :, None]) * 1.0

            if measure == "oneversusall":
                # cdf[f, x, a, b]: fraction of members e with q[e, x, b] <= q[f, x, a],
                # i.e. the empirical CDF of action b evaluated at member f's value for a.
                cdf = torch.mean(leq, dim=0)
                # Product over all b, then divide out the b == a factor, leaving the
                # product over b != a. The diagonal is always >= 1/E (e == f compares
                # a value with itself); 1e-8 is an extra guard.
                beats_all = torch.prod(cdf, dim=3) / (torch.diagonal(cdf, dim1=2, dim2=3) + 1e-8)
                return torch.mean(beats_all, dim=0) ** exponent

            # pairwise[x, a, b]: fraction of member pairs (e, f) with q[e, x, b] <= q[f, x, a].
            pairwise = torch.mean(torch.mean(leq, dim=0), dim=0)
            if measure == "trueoneversusone":
                return torch.mean(pairwise, dim=2) ** exponent

            # "oneversusone": count a win when pairwise >= 0.5 + epsilon, and half a
            # win when pairwise lies strictly within epsilon of 0.5.
            epsilon = self.comparison_hyps[0]
            wins = (pairwise >= 0.5 + epsilon) + 0.5 * (torch.abs(pairwise - 0.5) < epsilon)
            return torch.mean(wins, dim=2) ** exponent

        if measure == "ucb" and step is not None:
            mean_q = torch.mean(q_values, dim=0)
            # E[q^2] - E[q]^2 can round to a tiny negative number when members agree;
            # clamp so sqrt never produces NaN.
            variance = torch.clamp(torch.mean(q_values ** 2, dim=0) - mean_q ** 2, min=0.0)
            return mean_q + 0.1 * torch.sqrt((np.log(step) / step) * variance)

        return torch.mean(q_values, dim=0)

    def params(self, *args):
        """Return all members' ``params(*args)`` results concatenated into one list."""
        return [param for member in self.ensemble for param in member.params(*args)]

    def sample_noise(self):
        """Call ``sample_noise()`` on every member."""
        for member in self.ensemble:
            member.sample_noise()

    def state_dict(self):
        """Return every member's state dict merged into one flat dict.

        Keys have the form ``"<member index>____<member key>"``.
        """
        return {
            f"{i}{self._KEY_SEPARATOR}{key}": value
            for i, member in enumerate(self.ensemble)
            for key, value in member.state_dict().items()
        }

    def load_state_dict(self, dict, undesired_params=()):
        """Split a dict produced by :meth:`state_dict` and load each member's share.

        Args:
            dict: Mapping with keys of the form ``"<member index>____<member key>"``.
                The name shadows the builtin but is kept for caller compatibility.
            undesired_params: Prefixes; any member key starting with one of
                them is skipped.

        Raises:
            ValueError: If a key has no ``"____"`` separator or a non-integer index.
            IndexError: If a key refers to a member index that does not exist.
        """
        prefixes = tuple(undesired_params)
        member_dicts = [{} for _ in self.ensemble]
        for key, value in dict.items():
            # Split only at the first separator: the index never contains "____",
            # but the member's own key might.
            index, true_key = key.split(self._KEY_SEPARATOR, 1)
            if prefixes and true_key.startswith(prefixes):
                continue
            member_dicts[int(index)][true_key] = value

        # NOTE(review): if members use torch.nn.Module.load_state_dict with its
        # default strict=True, skipping keys via undesired_params will raise
        # "Missing key(s)". Confirm that members tolerate missing keys.
        for member, member_dict in zip(self.ensemble, member_dicts):
            member.load_state_dict(member_dict)

    def train(self):
        """Put every member in training mode."""
        for member in self.ensemble:
            member.train()

    def eval(self):
        """Put every member in evaluation mode."""
        for member in self.ensemble:
            member.eval()