import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset

import torch.nn.functional as F
from tqdm import tqdm
import pickle

from pantheonrl.common.agents import StaticPolicyAgent
from pantheonrl.common.util import distribution_from_policy


def params_from_policy(policy_batch, obs_dim=None, use_std=False):
    """
    Get a dict of batched params from a list of policies.

    Each policy must expose ``mlp_extractor.policy_net[0]`` and
    ``mlp_extractor.policy_net[2]`` (hidden layers) and ``action_net``
    (output layer), each with ``'weight'`` and ``'bias'`` entries in
    ``_parameters``. It must also expose ``log_std`` when ``use_std`` is True.

    Args:
        policy_batch (iterable): the policies forming the batch.
        obs_dim (int, optional): when given and different from the input
            width of the first layer, that layer's weight is padded with zeros
            on its last dimension to ``obs_dim`` columns. A smaller
            ``obs_dim`` gives a negative pad, which ``F.pad`` treats as
            cropping.
        use_std (bool): if True, append ``policy.log_std`` (flattened) to
            each policy's concatenated bias vector.

    Returns:
        dict: ``'fc_1_weight'``, ``'fc_2_weight'`` and ``'fc_3_weight'``, each
        of shape ``(batch, 1, out, in)``, and ``'bias'`` of shape
        ``(batch, total_bias_len)``.
    """
    fc_1_weights = []
    fc_2_weights = []
    fc_3_weights = []
    biases = []
    for policy in policy_batch:
        fc_1 = policy.mlp_extractor.policy_net[0]
        fc_2 = policy.mlp_extractor.policy_net[2]
        fc_3 = policy.action_net

        fc_1_weight = fc_1._parameters['weight']
        in_dim = fc_1_weight.shape[-1]
        if obs_dim is not None and obs_dim != in_dim:
            # F.pad's pad tuple starts at the last dimension: (left, right).
            # Pad only on the right so that the original columns keep their
            # positions.
            fc_1_weight = F.pad(fc_1_weight, pad=(0, obs_dim - in_dim), mode='constant', value=0)

        fc_1_weights.append(fc_1_weight.unsqueeze(0).unsqueeze(0))
        fc_2_weights.append(fc_2._parameters['weight'].unsqueeze(0).unsqueeze(0))
        fc_3_weights.append(fc_3._parameters['weight'].unsqueeze(0).unsqueeze(0))

        bias_parts = [fc_1._parameters['bias'], fc_2._parameters['bias'], fc_3._parameters['bias']]
        if use_std:
            # reshape(-1) instead of `.T`: `.T` on a 1-D tensor is deprecated
            # and is a no-op anyway.
            bias_parts.append(policy.log_std.reshape(-1))
        biases.append(torch.concat(bias_parts, dim=0).unsqueeze(0))

    return {
        'fc_1_weight': torch.concat(fc_1_weights, dim=0),
        'fc_2_weight': torch.concat(fc_2_weights, dim=0),
        'fc_3_weight': torch.concat(fc_3_weights, dim=0),
        'bias': torch.concat(biases, dim=0),
    }

def process_recon_params(recon_params, bias_dims=(64, 64, 6)):
    """
    Process the reconstructed params from the VAE and organize them into
    loadable per-policy params.

    Args:
        recon_params (dict): batched tensors under the keys ``'fc_1_weight'``,
            ``'fc_2_weight'``, ``'fc_3_weight'`` and ``'bias'``. The leading
            dimension of each is the batch dimension.
        bias_dims (sequence of int): section sizes used to split each
            flattened bias vector into per-layer biases. Their sum must equal
            the bias length. Only the first three sections are kept.

    Returns:
        list[dict]: one dict per batch element, with keys ``'fc_1_weight'``,
        ``'fc_1_bias'``, ``'fc_2_weight'``, ``'fc_2_bias'``, ``'fc_3_weight'``
        and ``'fc_3_bias'``.
    """
    # Immutable default avoids the shared-mutable-default pitfall. Normalize
    # to a list because torch.split documents a list of section sizes.
    sections = list(bias_dims)
    params_list = []
    # Iterating a tensor yields its slices along dim 0, i.e. one per policy.
    for fc_1_weight, fc_2_weight, fc_3_weight, bias in zip(
        recon_params['fc_1_weight'],
        recon_params['fc_2_weight'],
        recon_params['fc_3_weight'],
        recon_params['bias'],
    ):
        bias_list = torch.split(bias, sections)
        params_list.append({
            'fc_1_weight': fc_1_weight, 'fc_1_bias': bias_list[0],
            'fc_2_weight': fc_2_weight, 'fc_2_bias': bias_list[1],
            'fc_3_weight': fc_3_weight, 'fc_3_bias': bias_list[2],
        })
    return params_list

def params_from_modules(modules, idx):
    """
    Collect a loadable parameter dict from three layers of one model.

    Args:
        modules: an indexable collection of the model's modules.
        idx: indices of the first, second and third layers within
            ``modules``. Only ``idx[0]``, ``idx[1]`` and ``idx[2]`` are read.

    Returns:
        dict: the layers' ``'weight'`` / ``'bias'`` entries (the tensors
        themselves, not copies), under ``'fc_<k>_weight'`` /
        ``'fc_<k>_bias'``.
    """
    key_pairs = (
        ('fc_1_weight', 'fc_1_bias'),
        ('fc_2_weight', 'fc_2_bias'),
        ('fc_3_weight', 'fc_3_bias'),
    )
    positions = (idx[0], idx[1], idx[2])

    collected = {}
    for (weight_key, bias_key), position in zip(key_pairs, positions):
        layer_params = modules[position]._parameters
        collected[weight_key] = layer_params['weight']
        collected[bias_key] = layer_params['bias']
    return collected

def process_recon_params_list(recon_params_list):
    """Return ``recon_params_list`` unchanged (the same object, no copy)."""
    passthrough = recon_params_list
    return passthrough
    
def set_params(policy, params, use_std=False):
    """
    Write ``params`` into the policy's layers and return the same policy.

    Each source tensor is sliced down to the current shape of the target
    entry, so larger (e.g. padded) tensors are truncated. The slices are
    stored directly in each module's ``_parameters`` dict, replacing the
    previous entries.

    Args:
        policy (ActorCriticPolicy): the policy to modify in place and return.
        params (dict): tensors under ``'fc_1_weight'``, ``'fc_1_bias'``,
            ``'fc_2_weight'``, ``'fc_2_bias'``, ``'fc_3_weight'`` and
            ``'fc_3_bias'``.
        use_std (bool): if True, ``policy.log_std`` is set to a new
            ``nn.Parameter`` built from the ``'fc_3_bias'`` elements that
            follow the action bias (the next ``n_actions`` entries).
    """
    extractor_net = policy.mlp_extractor.policy_net
    targets = (
        (extractor_net[0], 'fc_1_weight', 'fc_1_bias'),
        (extractor_net[2], 'fc_2_weight', 'fc_2_bias'),
        (policy.action_net, 'fc_3_weight', 'fc_3_bias'),
    )

    for layer, weight_key, bias_key in targets:
        slots = layer._parameters
        n_rows = slots['weight'].shape[0]
        n_cols = slots['weight'].shape[1]
        slots['weight'] = params[weight_key][:n_rows, :n_cols]
        slots['bias'] = params[bias_key][:slots['bias'].shape[0]]

    if use_std:
        # Measured after the action bias has been replaced above.
        n_actions = policy.action_net._parameters['bias'].shape[0]
        policy.log_std = torch.nn.Parameter(params['fc_3_bias'][n_actions: n_actions * 2])
    return policy

def params_list_to_batch(recon_params_list):
    """
    Stack a list of per-policy param dicts into one batched dict.

    Input: a list of n policy param dicts with keys ``'fc_1_weight'``,
    ``'fc_1_bias'``, ``'fc_2_weight'``, ``'fc_2_bias'``, ``'fc_3_weight'``
    and ``'fc_3_bias'``.

    Returns:
        dict: ``'fc_<k>_weight'`` stacked along a new leading batch dimension,
        and ``'bias'`` holding each policy's three biases concatenated, stacked
        the same way.
    """
    # Materialize once: the entries are visited once per output key below.
    records = list(recon_params_list)

    batched = {
        key: torch.stack([record[key] for record in records], dim=0)
        for key in ('fc_1_weight', 'fc_2_weight', 'fc_3_weight')
    }
    flat_biases = [
        torch.concat([record['fc_1_bias'], record['fc_2_bias'], record['fc_3_bias']])
        for record in records
    ]
    batched['bias'] = torch.stack(flat_biases, dim=0)
    return batched

def set_random_params(policy):
    """
    Replace the policy's layer weights and biases with standard-normal noise.

    Each new tensor matches the shape, dtype and device of the entry it
    replaces. The draw order is layer 0 weight, layer 0 bias, layer 2 weight,
    layer 2 bias, action weight, action bias. Returns the same policy object.
    """
    layers = (
        policy.mlp_extractor.policy_net[0],
        policy.mlp_extractor.policy_net[2],
        policy.action_net,
    )
    for layer in layers:
        for entry in ('weight', 'bias'):
            layer._parameters[entry] = torch.randn_like(layer._parameters[entry])

    return policy


def get_recon_loss(policy_1, policy_2, transitions, recon_loss_fn, steps=400, sample=False):
    """
    Calculate the distance between policies 1 and 2 on the observations
    in ``transitions``.

    Args:
        policy_1 (ActorCriticPolicy): policy of ego agent 1, usually the
            reconstructed policy (first argument to ``recon_loss_fn``).
        policy_2 (ActorCriticPolicy): policy of ego agent 2 (second argument
            to ``recon_loss_fn``).
        transitions (SimultaneousTransitions): joint transitions of the ego
            and alt agents. Each element's ``.obs`` is used.
        recon_loss_fn: callable taking (predicted, target) tensors, e.g. a
            KL-divergence loss.
        steps (int): number of observations to sample when ``sample`` is
            True. It is capped at ``len(transitions)``.
        sample (bool): if True, use a random subset of observations (drawn
            without replacement); otherwise use all of them.

    Returns:
        Whatever ``recon_loss_fn`` returns. For ``torch.distributions.Normal``
        outputs it is applied to the means (``loc``); otherwise to ``logits``.
    """
    if sample:
        # replace=False raises ValueError when asked for more samples than
        # exist, so cap the request at the number of transitions available.
        n_samples = min(steps, len(transitions))
        obs_indices = np.random.choice(len(transitions), n_samples, replace=False)
        observations = np.array([transitions[i].obs for i in obs_indices])
    else:
        observations = np.array([transition.obs for transition in transitions])

    distribution_1 = distribution_from_policy(observations, policy_1).distribution  # predicted
    distribution_2 = distribution_from_policy(observations, policy_2).distribution  # target

    if isinstance(distribution_1, torch.distributions.Normal):
        # Continuous actions: only the means are compared, not the stds.
        return recon_loss_fn(distribution_1.loc, distribution_2.loc)
    return recon_loss_fn(distribution_1.logits, distribution_2.logits)

