import random
from copy import deepcopy
import numpy as np
import torch
from torch.optim import Adam
from numpy import linalg as LA
import gym
import d4rl
import argparse
import json
from utils import redirect_stdout, TrainLogger
import torch.nn as nn
import itertools
import torch.nn.functional as F
from torch.distributions.normal import Normal
import os
from tqdm import tqdm

def mlp(sizes, activation, output_activation=nn.Identity, device=None):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: Layer widths, input width first and output width last.
            ``len(sizes) - 1`` linear layers are created.
        activation: Module class (instantiated with no arguments) placed
            after every linear layer except the last one.
        output_activation: Module class placed after the last linear layer.
        device: Device the network is moved to. When ``None`` the network
            goes to CUDA if it is available and to the CPU otherwise, so
            the builder no longer fails on CPU-only hosts.

    Returns:
        The assembled ``nn.Sequential`` on the selected device.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    last = len(sizes) - 2
    layers = []
    for j, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act = activation if j < last else output_activation
        layers += [nn.Linear(n_in, n_out), act()]
    return nn.Sequential(*layers).to(device)

class reward_nn(nn.Module):
    """Reward network: an MLP over concatenated (obs, act), squashed by tanh."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation = nn.ReLU):
        """Create the network.

        Args:
            obs_dim: Width of one observation vector.
            act_dim: Width of one action vector.
            hidden_sizes: Iterable of hidden layer widths.
            activation: Module class used between hidden layers.
        """
        super().__init__()
        self.r = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)
        # Read the device from the weights instead of hard-coding it, so
        # inputs are always moved to wherever the network actually lives.
        self.device = next(self.r.parameters()).device

    def forward(self, obs, act):
        """Return tanh-bounded rewards for the given observation/action pairs.

        ``obs`` and ``act`` are concatenated along their last dimension, so
        their leading dimensions must match. The trailing size-1 output
        dimension is squeezed away.
        """
        joint = torch.cat([obs, act], dim=-1)
        return torch.tanh(self.r(joint)).squeeze(-1)

class reward_model(object):
    """Owns a ``reward_nn`` and its Adam optimizer for one gym environment.

    Offers a pairwise preference loss, a supervised loss against recorded
    rewards, an optimizer step, and a debugging printout of predictions.
    """

    def __init__(self, env_name, ac_kwargs=None, lr=1e-3):
        """Create the environment, the reward network and its optimizer.

        Args:
            env_name: Id passed to ``gym.make``.
            ac_kwargs: Extra keyword arguments forwarded to ``reward_nn``
                (e.g. ``hidden_sizes``). ``None`` means no extra arguments.
            lr: Adam learning rate.
        """
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.obs_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]
        self.reward_nn = reward_nn(self.obs_dim, self.act_dim, **(ac_kwargs or {}))
        self.optimizer = Adam(self.reward_nn.r.parameters(), lr=lr, weight_decay=1e-5)

    def _to_tensor(self, values):
        """Convert ``values`` to a detached float32 tensor on the network's device."""
        return torch.as_tensor(values, dtype=torch.float32).detach().to(self.reward_nn.device)

    def _predict(self, traj_obs, traj_act):
        """Return the network's per-step rewards for one trajectory.

        The last observation is dropped before pairing with the actions, so
        ``traj_obs`` must hold exactly one more entry than ``traj_act``.
        """
        obs = self._to_tensor(traj_obs)
        act = self._to_tensor(traj_act)
        return self.reward_nn(obs[:-1], act)

    def compute_loss(self, data):
        """Negative log-likelihood of a pairwise trajectory preference.

        Args:
            data: ``[traj_obs1, traj_act1, traj_obs2, traj_act2, pref]``,
                where ``pref`` is 1 or 2 and names the preferred trajectory.

        Returns:
            Scalar tensor ``-log(exp(u_pref) / (exp(u_1) + exp(u_2)))`` where
            ``u_k`` is the sum of predicted rewards along trajectory ``k``.

        Raises:
            ValueError: If ``pref`` is neither 1 nor 2.
        """
        traj_obs1, traj_act1, traj_obs2, traj_act2, pref = data
        utility_1 = torch.sum(self._predict(traj_obs1, traj_act1))
        utility_2 = torch.sum(self._predict(traj_obs2, traj_act2))

        if pref == 1:
            preferred = utility_1
        elif pref == 2:
            preferred = utility_2
        else:
            raise ValueError(f"pref must be 1 or 2, got {pref!r}")

        # log(exp(u_1) + exp(u_2)) - u_pref is the same quantity as the naive
        # -log(exp(u_pref) / (exp(u_1) + exp(u_2))), but logsumexp stays
        # finite when the summed utilities are large, where exp() overflows
        # and the ratio becomes NaN.
        return torch.logsumexp(torch.stack([utility_1, utility_2]), dim=0) - preferred

    def compute_true_loss(self, data):
        """Mean squared error between predicted and recorded rewards.

        Args:
            data: ``[traj_obs, traj_act, traj_true_rew]`` for one trajectory.
                Dimension 1 of ``traj_true_rew`` is squeezed before the
                comparison (assumes rewards are stored as a column; verify
                against the dataset loader).

        Returns:
            Scalar MSE tensor.
        """
        traj_obs1, traj_act1, traj_true_rew1 = data
        predicted = self._predict(traj_obs1, traj_act1)
        target = self._to_tensor(traj_true_rew1).squeeze(1)
        return nn.MSELoss()(predicted, target)

    def update(self, loss):
        """Run one optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def test(self, dataset):
        """Print predicted per-step rewards for up to 200 random trajectory pairs.

        Args:
            dataset: ``(traj_obs, traj_act, traj_rew, traj_idx_1, traj_idx_2,
                prefs)``; only the observations, actions and the two index
                lists are read.
        """
        traj_obs, traj_act, _traj_rew, traj_idx_1, traj_idx_2, _prefs = dataset
        total_num = len(traj_idx_1)
        # Cap the sample size so datasets with fewer than 200 pairs work.
        sample_size = min(200, total_num)
        with torch.no_grad():
            for idx in random.sample(range(total_num), sample_size):
                for traj in (traj_idx_1[idx], traj_idx_2[idx]):
                    print(self._predict(traj_obs[traj], traj_act[traj]))

def training(reward_model, dataset, epoch=30, epoch_steps=100, batch_size=64, reward_path=None, logger=None):
    """Fit the reward model on pairwise trajectory preferences.

    The preference pairs are shuffled and split 80/20 into training and
    held-out sets. Each step sums ``reward_model.compute_loss`` over a random
    batch of training pairs and applies one optimizer update. After every
    epoch a random sample of held-out pairs is scored and both mean losses
    are logged.

    Args:
        reward_model: Object exposing ``compute_loss``, ``update`` and a
            ``reward_nn`` module.
        dataset: ``(traj_obs, traj_act, traj_rew, traj_idx_1, traj_idx_2,
            prefs)``. ``traj_idx_1[p]`` and ``traj_idx_2[p]`` index the two
            trajectories of pair ``p`` and ``prefs[p]`` is its label.
        epoch: Number of epochs.
        epoch_steps: Optimizer updates per epoch.
        batch_size: Pairs per update; up to ``20 * batch_size`` held-out
            pairs are scored per epoch.
        reward_path: Where to save the network's ``state_dict``; ``None``
            disables saving.
        logger: Object with a ``log(dict)`` method. A new ``TrainLogger`` is
            created per call when omitted, rather than one instance built at
            import time and shared by every call.

    Raises:
        ValueError: If the dataset is too small to leave any training pair.
    """
    if logger is None:
        logger = TrainLogger()

    traj_obs, traj_act, _traj_rew, traj_idx_1, traj_idx_2, prefs = dataset
    total_num = len(traj_idx_1)
    training_num = int(total_num * 0.8)
    if training_num == 0:
        raise ValueError(f"need at least 2 preference pairs to train, got {total_num}")

    split = list(range(total_num))
    random.shuffle(split)
    training_set = split[:training_num]
    testing_set = split[training_num:]

    # Actual sample sizes; these (not the nominal batch sizes) are used to
    # normalise the logged losses.
    train_batch = min(batch_size, training_num)
    test_batch = min(len(testing_set), 20 * batch_size)

    def pair_loss(pair):
        """Preference loss for the pair stored at dataset position ``pair``."""
        first, second = traj_idx_1[pair], traj_idx_2[pair]
        return reward_model.compute_loss(
            [traj_obs[first], traj_act[first], traj_obs[second], traj_act[second], prefs[pair]]
        )

    best_test_loss = float('inf')
    for i in tqdm(range(epoch), desc = 'Reward Num Epochs'):
        training_loss = 0
        for _ in tqdm(range(epoch_steps), desc = 'Num Steps', colour = 'blue'):
            loss = 0
            for pair in random.sample(training_set, train_batch):
                loss += pair_loss(pair)
            reward_model.update(loss)
            training_loss += loss.item()

        with torch.no_grad():
            test_loss = 0.0
            for pair in random.sample(testing_set, test_batch):
                test_loss += pair_loss(pair).item()

        logger.log({
            'epoch': i,
            'reward_training_loss': training_loss / max(epoch_steps * train_batch, 1),
            'reward_test_loss': test_loss / max(test_batch, 1),
        })
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            # NOTE(review): the best loss is tracked from epoch 0 but saving
            # only starts at epoch 10, so nothing is written if the best
            # epoch falls inside the first ten. Confirm this is intended.
            if i >= 10 and reward_path is not None:
                torch.save(reward_model.reward_nn.state_dict(), reward_path)



def training_true_reward(reward_model, dataset, epoch=30, epoch_steps=100, batch_size=64, reward_path=None, logger=None):
    """Fit the reward model by regression onto the recorded rewards.

    The preference pairs are shuffled and split 80/20 into training and
    held-out sets. For every sampled pair one of its two trajectories is
    chosen with probability 0.5 and scored with
    ``reward_model.compute_true_loss``. Each step sums that loss over a
    random batch and applies one optimizer update; after every epoch a
    sample of held-out pairs is scored the same way and both mean losses
    are logged.

    Args:
        reward_model: Object exposing ``compute_true_loss``, ``update`` and a
            ``reward_nn`` module.
        dataset: ``(traj_obs, traj_act, traj_rew, traj_idx_1, traj_idx_2,
            prefs)``. ``traj_idx_1[p]`` and ``traj_idx_2[p]`` index the two
            trajectories of pair ``p``; ``prefs`` is not used here.
        epoch: Number of epochs.
        epoch_steps: Optimizer updates per epoch.
        batch_size: Pairs per update; up to ``20 * batch_size`` held-out
            pairs are scored per epoch.
        reward_path: Where to save the network's ``state_dict``; ``None``
            disables saving.
        logger: Object with a ``log(dict)`` method. A new ``TrainLogger`` is
            created per call when omitted, rather than one instance built at
            import time and shared by every call.

    Raises:
        ValueError: If the dataset is too small to leave any training pair.
    """
    if logger is None:
        logger = TrainLogger()

    traj_obs, traj_act, traj_rew, traj_idx_1, traj_idx_2, _prefs = dataset
    total_num = len(traj_idx_1)
    training_num = int(total_num * 0.8)
    if training_num == 0:
        raise ValueError(f"need at least 2 preference pairs to train, got {total_num}")

    split = list(range(total_num))
    random.shuffle(split)
    training_set = split[:training_num]
    testing_set = split[training_num:]

    # Actual sample sizes; these (not the nominal batch sizes) are used to
    # normalise the logged losses.
    train_batch = min(batch_size, training_num)
    test_batch = min(len(testing_set), 20 * batch_size)

    def trajectory_loss(pair):
        """Regression loss on one randomly chosen trajectory of ``pair``."""
        source = traj_idx_1 if random.random() < 0.5 else traj_idx_2
        traj = source[pair]
        return reward_model.compute_true_loss([traj_obs[traj], traj_act[traj], traj_rew[traj]])

    best_test_loss = float('inf')
    for i in tqdm(range(epoch), desc = 'Reward Num Epochs'):
        training_loss = 0
        for _ in tqdm(range(epoch_steps), desc = 'Num Steps', colour = 'blue'):
            loss = 0
            for pair in random.sample(training_set, train_batch):
                loss += trajectory_loss(pair)
            reward_model.update(loss)
            training_loss += loss.item()

        with torch.no_grad():
            test_loss = 0.0
            # The trajectory chosen inside trajectory_loss is the one that is
            # scored; the held-out loss never reuses variables left over from
            # an earlier iteration or from the training loop.
            for pair in random.sample(testing_set, test_batch):
                test_loss += trajectory_loss(pair).item()

        logger.log({
            'epoch': i,
            'reward_training_loss': training_loss / max(epoch_steps * train_batch, 1),
            'reward_test_loss': test_loss / max(test_batch, 1),
        })
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            # NOTE(review): the best loss is tracked from epoch 0 but saving
            # only starts at epoch 10, so nothing is written if the best
            # epoch falls inside the first ten. Confirm this is intended.
            if i >= 10 and reward_path is not None:
                torch.save(reward_model.reward_nn.state_dict(), reward_path)


