import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import pickle as pkl


def gen_net(in_size=1, out_size=1, hidden_size=128, n_layers=3, activation='tanh'):
    """Build the layer list for a fully connected network.

    The network consists of ``n_layers`` hidden ``Linear -> LeakyReLU`` stages of
    width ``hidden_size``, then a ``Linear`` projection to ``out_size`` followed by
    an output activation: ``Tanh`` for ``'tanh'``, ``Sigmoid`` for ``'sig'`` and
    ``ReLU`` for any other value.

    Returns:
        A list of ``nn.Module`` objects, meant to be unpacked into ``nn.Sequential``.
    """
    # Consecutive pairs of this list are the (fan_in, fan_out) of each hidden layer.
    widths = [in_size] + [hidden_size] * n_layers
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU()]
    layers.append(nn.Linear(widths[-1], out_size))

    output_activations = {'tanh': nn.Tanh, 'sig': nn.Sigmoid}
    layers.append(output_activations.get(activation, nn.ReLU)())
    return layers

class RewardModel:
    """Ensemble of per-step reward networks fitted to soft pairwise preference labels.

    Preference data is loaded from disk by ``add_data`` and served in a shuffled
    order; each ``uniform_sampling`` call moves the next ``mb_size`` queries into a
    fixed-capacity buffer, and ``train_soft_reward`` fits every ensemble member on
    the buffered queries with a soft cross-entropy over summed segment rewards.
    """

    def __init__(self, obs_dim, action_dim,
                 ensemble_size=3, lr=3e-4, mb_size=128, size_segment=1, max_size=100,
                 activation='tanh', capacity=1e3, large_batch=1, label_margin=0.0, device='cuda',
                 train_batch_size=128):
        """Allocate the query buffers and build the reward ensemble.

        Args:
            obs_dim: observation dimensionality.
            action_dim: action dimensionality; network inputs are ``obs_dim + action_dim`` wide.
            ensemble_size: number of independently initialised reward networks.
            lr: Adam learning rate shared by all members.
            mb_size: number of queries moved into the buffer per ``uniform_sampling`` call.
            size_segment: number of steps per query segment.
            max_size: stored on the instance; not otherwise used in this class.
            activation: output activation passed to ``gen_net``.
            capacity: number of query pairs in the dataset and in the buffer.
            large_batch: stored on the instance; not otherwise used in this class.
            label_margin: label-smoothing margin; ``label_target`` is derived from it
                but is not used by the soft-label training path.
            device: torch device the networks live on.
            train_batch_size: minibatch size used by ``train_soft_reward``.
        """
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.ensemble_size = ensemble_size
        self.lr = lr
        self.ensemble = []
        self.paramlst = []
        self.optimizer = None
        self.max_size = max_size
        self.activation = activation
        self.size_segment = size_segment

        self.capacity = int(capacity)
        self.buffer_seg1 = np.empty((self.capacity, size_segment, self.obs_dim + self.action_dim), dtype=np.float32)
        self.buffer_seg2 = np.empty((self.capacity, size_segment, self.obs_dim + self.action_dim), dtype=np.float32)
        self.buffer_label = np.empty((self.capacity, 2), dtype=np.float32)  # soft label per query pair
        self.buffer_index = 0
        self.buffer_full = False
        self.device = device  # must be set before construct_ensemble() moves models

        self.construct_ensemble()
        self.inputs = []
        self.targets = []
        self.mb_size = mb_size
        self.origin_mb_size = mb_size
        self.train_batch_size = train_batch_size
        self.CELoss = nn.CrossEntropyLoss()
        self.large_batch = large_batch

        # label smoothing ("new teacher")
        self.label_margin = label_margin
        self.label_target = 1 - 2 * self.label_margin

    def eval(self):
        """Put every ensemble member into eval mode."""
        for model in self.ensemble:
            model.eval()

    def construct_ensemble(self):
        """Create ``ensemble_size`` reward networks and one Adam optimizer over all of them."""
        for i in range(self.ensemble_size):
            model = nn.Sequential(*gen_net(in_size=self.obs_dim + self.action_dim,
                                           out_size=1, hidden_size=256, n_layers=3,
                                           activation=self.activation)).float().to(self.device)
            self.ensemble.append(model)
            self.paramlst.extend(model.parameters())
        self.optimizer = torch.optim.Adam(self.paramlst, lr=self.lr)

    def softXEnt_loss(self, input, target):
        """Cross-entropy between logits and soft targets, averaged over the batch.

        Args:
            input: logits of shape ``(batch, n_classes)``; softmax is taken over dim 1.
            target: per-class target probabilities with the same shape as ``input``.

        Returns:
            Scalar tensor ``-sum(target * log_softmax(input)) / batch``.
        """
        log_probs = torch.nn.functional.log_softmax(input, dim=1)
        return - (target * log_probs).sum() / input.shape[0]

    def change_batch(self, new_frac):
        """Set ``mb_size`` to ``new_frac`` times the batch size given at construction."""
        self.mb_size = int(self.origin_mb_size * new_frac)

    def set_batch(self, new_batch):
        """Set ``mb_size`` to ``new_batch`` (truncated to int)."""
        self.mb_size = int(new_batch)

    def add_data(self, input_path, label_path):
        """Load ``capacity`` query/label pickles and shuffle the order they are served in.

        Reads ``{input_path}/inputsNNN.pkl`` and ``{label_path}/labelsNNN.pkl`` for
        ``NNN`` in ``0 .. capacity-1``. Each loaded input is indexed by ``'traj1'`` and
        ``'traj2'`` in ``get_queries``.

        NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the file;
        only point this at trusted data.
        """
        for idx in range(int(self.capacity)):
            with open(f'{input_path}/inputs{idx:03d}.pkl', 'rb') as f:
                query = pkl.load(f)
            with open(f'{label_path}/labels{idx:03d}.pkl', 'rb') as f:
                label = pkl.load(f)
            self.inputs.append(query)
            self.targets.append(label)

        self.data_indices = np.random.permutation(int(self.capacity))

    def p_hat_member(self, x_1, x_2, member=-1):
        """Probability (under one member) that segment ``x_1`` is preferred over ``x_2``.

        x_1, x_2 : (mb_size * large_batch, segment_length, obs_dim + action_dim)
        return (mb_size * large_batch,) -- softmax over the two summed segment rewards, column 0
        """
        with torch.no_grad():
            r_hat1 = self.r_hat_member(x_1, member=member)  # (mb_size * large_batch, segment_length, 1)
            r_hat2 = self.r_hat_member(x_2, member=member)  # (mb_size * large_batch, segment_length, 1)
            r_hat1 = r_hat1.sum(axis=1)  # (mb_size * large_batch, 1)
            r_hat2 = r_hat2.sum(axis=1)  # (mb_size * large_batch, 1)
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
        return F.softmax(r_hat, dim=-1)[:, 0]

    def r_hat_member(self, x, member=-1):
        """Run ensemble member ``member`` on ``x`` (numpy array or tensor), on ``self.device``."""
        if isinstance(x, np.ndarray):
            return self.ensemble[member](torch.from_numpy(x).float().to(self.device))
        else:
            return self.ensemble[member](x.float().to(self.device))

    def r_hat(self, x):
        """Ensemble-mean reward for ``x`` as a scalar (mean over members and all outputs).

        1) x : (obs_dim + action_dim, ) -> (ensemble_size, 1) -> scalar
        """
        r_hats = []
        for member in range(self.ensemble_size):
            r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
        r_hats = np.array(r_hats)
        return np.mean(r_hats)

    def r_hat_batch(self, x):
        """Ensemble-mean reward per row of ``x``.

        x : (batch_size, obs_dim + action_dim) -> (ensemble_size, batch_size, 1) -> (batch_size, 1)
        """
        r_hats = []
        for member in range(self.ensemble_size):
            r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
        r_hats = np.array(r_hats)
        return np.mean(r_hats, axis=0)

    def save(self, model_dir, step):
        """Write each member's state dict to ``{model_dir}/reward_model_{step}_{member}.pt``."""
        for member in range(self.ensemble_size):
            torch.save(self.ensemble[member].state_dict(), '%s/reward_model_%s_%s.pt' % (model_dir, step, member))

    def load(self, model_dir, step):
        """Load the state dicts written by ``save`` for ``step``.

        ``map_location`` remaps stored tensors onto ``self.device`` so checkpoints
        saved on a GPU can be restored on a CPU-only machine (and vice versa).
        """
        for member in range(self.ensemble_size):
            path = '%s/reward_model_%s_%s.pt' % (model_dir, step, member)
            self.ensemble[member].load_state_dict(torch.load(path, map_location=self.device))

    def _pending_indices(self, mb_size):
        """Dataset indices of the next (at most) ``mb_size`` unserved queries; may be empty."""
        next_index = min(self.buffer_index + mb_size, self.capacity)
        return self.data_indices[self.buffer_index:next_index]

    def get_queries(self, mb_size=20):
        """Return the next ``mb_size`` query pairs ``(sa_t_1, sa_t_2)`` in shuffled order.

        Fewer pairs are returned near the end of the dataset, and empty
        ``(0, size_segment, obs_dim + action_dim)`` arrays once it is exhausted.
        """
        indices = self._pending_indices(mb_size)
        if len(indices) == 0:
            # np.stack cannot stack an empty list, so build the empty result explicitly.
            empty = np.empty((0, self.size_segment, self.obs_dim + self.action_dim), dtype=np.float32)
            return empty, empty.copy()
        sa_t_1 = np.stack([self.inputs[i]['traj1'] for i in indices], axis=0)
        sa_t_2 = np.stack([self.inputs[i]['traj2'] for i in indices], axis=0)
        return sa_t_1, sa_t_2

    def get_label(self, mb_size=20):
        """Return the labels matching ``get_queries(mb_size)``; a ``(0, 2)`` array once exhausted."""
        indices = self._pending_indices(mb_size)
        if len(indices) == 0:
            return np.empty((0, 2), dtype=np.float32)
        return np.stack([self.targets[i] for i in indices], axis=0)

    def put_queries(self, sa_t_1, sa_t_2, labels):
        """Copy a batch of queries and labels into the buffer starting at ``buffer_index``.

        Rows that would overflow ``capacity`` are dropped. ``buffer_index`` always
        advances past the written rows; once it reaches ``capacity`` the buffer is
        marked full and later ``get_queries``/``get_label`` calls report the dataset
        as exhausted instead of re-serving the same tail.
        """
        mb_size = sa_t_1.shape[0]
        next_index = min(self.buffer_index + mb_size, self.capacity)
        n_rows = next_index - self.buffer_index
        np.copyto(self.buffer_seg1[self.buffer_index:next_index], sa_t_1[:n_rows])
        np.copyto(self.buffer_seg2[self.buffer_index:next_index], sa_t_2[:n_rows])
        np.copyto(self.buffer_label[self.buffer_index:next_index], labels[:n_rows])
        if next_index >= self.capacity:
            self.buffer_full = True
        self.buffer_index = next_index

    def uniform_sampling(self):
        """Move the next ``mb_size`` labelled queries into the buffer.

        Buffered layout:
        sa_t : (batch_size, size_segment, obs_dim + action_dim)
        label : (batch_size, 2) soft preference targets

        Returns:
            Number of queries added (0 once the dataset is exhausted).
        """
        sa_t_1, sa_t_2 = self.get_queries(mb_size=self.mb_size)
        labels = self.get_label(mb_size=self.mb_size)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)
        return len(labels)

    def train_soft_reward(self):
        """Run one epoch over the filled part of the buffer for every ensemble member.

        Each member iterates over its own random permutation of the buffered
        queries. Per minibatch, the members' losses are summed and optimised
        jointly with a single optimizer step.

        Returns:
            Soft cross-entropy averaged over all buffered samples and ensemble
            members (computed before each step), or 0.0 if the buffer is empty.
        """
        max_len = self.capacity if self.buffer_full else self.buffer_index
        if max_len == 0:
            return 0.0

        total_batch_index = [np.random.permutation(max_len) for _ in range(self.ensemble_size)]
        num_iters = int(np.ceil(max_len / self.train_batch_size))
        total_loss = 0.

        for it in range(num_iters):
            start = it * self.train_batch_size
            end = min(start + self.train_batch_size, max_len)
            self.optimizer.zero_grad()
            loss = 0.0

            for member in range(self.ensemble_size):
                idxs = total_batch_index[member][start:end]
                sa_t_1 = torch.FloatTensor(self.buffer_seg1[idxs])  # (batch_size, size_segment, obs_dim + action_dim)
                sa_t_2 = torch.FloatTensor(self.buffer_seg2[idxs])  # (batch_size, size_segment, obs_dim + action_dim)
                labels = torch.FloatTensor(self.buffer_label[idxs]).to(self.device)  # (batch_size, 2)

                r_hat1 = self.r_hat_member(sa_t_1, member=member).sum(axis=1)  # (batch_size, 1)
                r_hat2 = self.r_hat_member(sa_t_2, member=member).sum(axis=1)  # (batch_size, 1)
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)  # (batch_size, 2)

                curr_loss = self.softXEnt_loss(r_hat, labels)
                loss += curr_loss
                # softXEnt_loss is a batch mean; re-weight by batch size so the final
                # division by the sample count yields a true per-sample average.
                total_loss += curr_loss.item() * len(idxs)

            loss.backward()
            self.optimizer.step()

        return total_loss / (max_len * self.ensemble_size)