import random
import pandas as pd
import gym
from copy import deepcopy
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import imageio
import os

# Run on the GPU when one is available; otherwise fall back to the CPU so that
# every `.to(device)` in this file still works on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Debugging aid: lets autograd point at the forward op behind a failing
# backward pass. PyTorch documents that enabling it slows execution down.
torch.autograd.set_detect_anomaly(True)

# When True, observations are standardized with `mean` / `std` below.
Normalization = True
# Per-dimension observation statistics; assigned by the entry point from the
# expert dataset before any evaluation runs.
std = None
mean = None
# Number of expert samples drawn per gradient step by the IL methods.
batch_size = 256
########################################################################################
# Eval
########################################################################################


def evaluate_real_return(policy, env, n_episodes, horizon, deterministic):
    """Roll out `policy` in `env` and return the mean undiscounted return.

    Args:
        policy: callable invoked as `policy(state, deterministic,
            with_logprob=False)`; the first element of its result is used as
            the action. Must expose a `device` attribute.
        env: environment with `reset()` returning an observation and
            `step(action)` returning `(obs, reward, done, info)`.
        n_episodes: number of episodes to run.
        horizon: maximum number of steps per episode.
        deterministic: forwarded to the policy.

    Returns:
        `np.mean` of the per-episode summed rewards.

    When the module-level `Normalization` flag is set, each observation is
    standardized with the module-level `mean` and `std` before being fed to
    the policy.
    """
    returns = []
    for _ in range(n_episodes):
        obs = env.reset()
        ret = 0

        for _step in range(horizon):
            if Normalization:
                obs = (obs - mean) / std
            state = torch.FloatTensor(obs).to(policy.device)

            # Evaluation never back-propagates, so do not record an autograd
            # graph for the forward pass.
            with torch.no_grad():
                action = policy(state, deterministic, with_logprob=False)
            action = (action[0]).cpu().detach().numpy()
            # Strip up to two leading singleton dimensions so the environment
            # receives a flat action vector.
            if len(action.shape) == 3:
                action = action.squeeze(0)
            if len(action.shape) == 2:
                action = action.squeeze(0)
            # NOTE: assume rew=0 after done=True for evaluation
            obs, rew, done, _ = env.step(action)
            ret += rew
            if done:
                break
        returns.append(ret)

    return np.mean(returns)

########################################################################################
# Model definition
########################################################################################


def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an `nn.Sequential`.

    Consecutive entries of `sizes` define one `nn.Linear` each. Every linear
    layer is followed by an instance of `activation`, except the last one,
    which is followed by an instance of `output_activation`.
    """
    final_index = len(sizes) - 2
    modules = []
    for index, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        nonlinearity = activation if index < final_index else output_activation
        modules.append(nonlinearity())
    return nn.Sequential(*modules)


class Actor(nn.Module):
    """Tanh-squashed diagonal-Gaussian policy.

    A shared MLP trunk feeds two linear heads producing the mean and the
    (clamped) log standard deviation of a Normal over pre-squash actions.
    Actions are `act_limit * tanh(u)` with `u` drawn from that Normal.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation=nn.ReLU, act_limit=1):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit
        self.LOG_STD_MAX = 2
        self.LOG_STD_MIN = -20
        self.device = device

    def _distribution(self, obs):
        """Return the pre-squash Normal distribution for `obs`."""
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)
        return Normal(mu, torch.exp(log_std))

    @staticmethod
    def _squash_correction(pre_tanh):
        """Return sum over the last axis of log(1 - tanh(u)^2).

        Uses the identity log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u)),
        which avoids evaluating tanh close to +/-1.
        """
        # Reduce over the last axis (the action dimension) so that batched and
        # unbatched inputs are both supported, matching the Gaussian term.
        return (2 * (np.log(2) - pre_tanh - F.softplus(-2 * pre_tanh))).sum(axis=-1)

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Sample (or take the mean) action for `obs`.

        Args:
            obs: observation tensor whose last axis has size `obs_dim`.
            deterministic: use the distribution mean instead of sampling.
            with_logprob: also return the log-probability of the action.

        Returns:
            `(action, logp)` where `action` is `act_limit * tanh(u)` and
            `logp` is the log-probability summed over the last axis, or
            `None` when `with_logprob` is False.
        """
        pi_distribution = self._distribution(obs)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = pi_distribution.mean
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi = logp_pi - self._squash_correction(pi_action)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi

    def log_prob(self, obs, act):
        """Return the log-probability of squashed actions `act` given `obs`.

        Args:
            obs: observation tensor whose last axis has size `obs_dim`.
            act: actions in `[-act_limit, act_limit]`; a tensor or anything
                `torch.as_tensor` accepts (e.g. a NumPy array).

        Returns:
            Tensor of log-probabilities summed over the last axis.
        """
        pi_distribution = self._distribution(obs)

        # as_tensor accepts arrays and tensors alike and places the result on
        # the same device as the observations.
        act = torch.as_tensor(act, dtype=torch.float32, device=obs.device)
        act = act / self.act_limit
        # atanh is infinite at +/-1; keep the argument strictly inside the
        # open interval so the result stays finite.
        bound = 1 - 1e-6
        act = torch.atanh(torch.clamp(act, -bound, bound))

        logp_pi = pi_distribution.log_prob(act).sum(axis=-1)
        logp_pi = logp_pi - self._squash_correction(act)

        return logp_pi


class Agent(nn.Module):
    """Bundle of a policy network (`ac`) and its Adam optimizer.

    Attributes:
        ac: the `Actor` policy, placed on the module-level `device`.
        pi_optimizer: Adam optimizer over the policy parameters.
        device: device that callers move input batches to.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes=(256, 256), activation=nn.ReLU, act_limit=1, weight_decay=0, lr=3e-4):
        super().__init__()
        # Callers send their batches to `self.device`, so the policy weights
        # must live there too. Move them before the optimizer is built so the
        # optimizer is created over the on-device parameters.
        self.ac = Actor(obs_dim, act_dim, hidden_sizes,
                        activation, act_limit).to(device)
        self.pi_optimizer = torch.optim.Adam(
            self.ac.parameters(), lr=lr, weight_decay=weight_decay)
        self.device = device

########################################################################################
# IL Methods
########################################################################################


def POIL(agent, expert_states, expert_actions, greedy=False, steps=100, beta=0.1, lambda_value=0):
    """Run `steps` POIL gradient updates on `agent`.

    Each step samples `batch_size` expert pairs, draws a "rejected" action
    from the current policy for the same states, and minimizes
    `-logsigmoid(beta * (logp_chosen - logp_rejected)) - lambda_value * logp_chosen`.

    Args:
        agent: object exposing `ac` (policy with `log_prob`), `pi_optimizer`
            and `device`.
        expert_states: expert observations, indexable by an integer array.
        expert_actions: expert actions aligned with `expert_states`.
        greedy: unused; kept for interface compatibility.
        steps: number of gradient updates.
        beta: scale applied to the log-probability margin.
        lambda_value: weight of the extra chosen-action likelihood term.

    Returns:
        `(total_loss, total_margin, total_positive_reward,
        total_negative_reward)`, each summed over the `steps` updates.
    """
    assert expert_states.shape[0] == expert_actions.shape[0]

    # Accumulate plain floats: summing graph-attached tensors would keep every
    # step's autograd graph alive, and an int start value has no `.mean()`
    # when `steps` is 0.
    total_loss = 0.0
    total_margin = 0
    total_positive_reward = 0
    total_negative_reward = 0
    epsilon = 1e-3

    for _ in range(steps):
        # random sample a batch of expert states and actions
        idx = np.random.randint(0, expert_states.shape[0], batch_size)
        # as_tensor handles arrays and tensors, including data that is
        # already on the target device.
        state = torch.as_tensor(
            expert_states[idx], dtype=torch.float32, device=agent.device)
        chosen_act = torch.as_tensor(
            expert_actions[idx], dtype=torch.float32, device=agent.device)

        # The rejected action is a sample from the current policy; it is
        # treated as data, so no gradient (and no log-prob) is needed here.
        with torch.no_grad():
            reject_act, _ = agent.ac(
                state, deterministic=False, with_logprob=False)

        # Keep both actions strictly inside (-1, 1) so atanh stays finite
        chosen_act = torch.clamp(chosen_act, -1+epsilon, 1-epsilon)
        reject_act = torch.clamp(reject_act, -1+epsilon, 1-epsilon)

        # Log probabilities of both actions under the current policy
        policy_chosen_logps = agent.ac.log_prob(state, chosen_act)
        policy_reject_logps = agent.ac.log_prob(state, reject_act)

        positive_reward = policy_chosen_logps.detach().mean().item()
        negative_reward = policy_reject_logps.detach().mean().item()
        margin = positive_reward - negative_reward

        total_positive_reward += positive_reward
        total_negative_reward += negative_reward
        total_margin += margin

        logits = policy_chosen_logps - policy_reject_logps
        losses = -F.logsigmoid(beta * logits) - lambda_value * policy_chosen_logps
        loss = losses.mean()
        total_loss += loss.item()

        agent.pi_optimizer.zero_grad()
        loss.backward()
        agent.pi_optimizer.step()

    return total_loss, total_margin, total_positive_reward, total_negative_reward


def ORPO(agent, expert_states, expert_actions, greedy=False, steps=100, beta=0.1):
    """Run `steps` ORPO-style gradient updates on `agent`.

    Each step samples `batch_size` expert pairs, draws a "rejected" action
    from the current policy, and minimizes the chosen-action negative
    log-likelihood plus `beta` times the odds-ratio penalty
    `-log(sigmoid(log_odds_chosen - log_odds_rejected) + epsilon)`, where
    `log_odds = logp - log(1 - p)`.

    Args:
        agent: object exposing `ac` (policy with `log_prob`), `pi_optimizer`
            and `device`.
        expert_states: expert observations, indexable by an integer array.
        expert_actions: expert actions aligned with `expert_states`.
        greedy: unused; kept for interface compatibility.
        steps: number of gradient updates.
        beta: weight of the odds-ratio term.

    Returns:
        `(total_loss, total_margin, total_positive_reward,
        total_negative_reward)`, each summed over the `steps` updates. The
        rewards are the mean log-odds of the chosen / rejected actions.
    """
    assert expert_states.shape[0] == expert_actions.shape[0]

    # Accumulate plain floats so per-step autograd graphs are released and
    # the function also works for `steps == 0`.
    total_loss = 0.0
    total_margin = 0
    total_positive_reward = 0
    total_negative_reward = 0
    epsilon = 1e-3
    # Upper bound on log p so that 1 - p stays >= epsilon. The policy is a
    # continuous density, so p = exp(logp) can exceed 1 and log(1 - p) would
    # otherwise be NaN. Clamping in log space also avoids overflow in exp.
    max_logp = float(np.log1p(-epsilon))

    for _ in range(steps):
        # random sample a batch of expert states and actions
        idx = np.random.randint(0, expert_states.shape[0], batch_size)
        # as_tensor handles arrays and tensors, including data that is
        # already on the target device.
        state = torch.as_tensor(
            expert_states[idx], dtype=torch.float32, device=agent.device)
        chosen_act = torch.as_tensor(
            expert_actions[idx], dtype=torch.float32, device=agent.device)

        # The rejected action is a sample from the current policy, used as data.
        with torch.no_grad():
            reject_act, _ = agent.ac(
                state, deterministic=False, with_logprob=False)

        # Keep both actions strictly inside (-1, 1) so atanh stays finite
        chosen_act = torch.clamp(chosen_act, -1+epsilon, 1-epsilon)
        reject_act = torch.clamp(reject_act, -1+epsilon, 1-epsilon)

        # Log probabilities of both actions under the current policy
        policy_chosen_logps = agent.ac.log_prob(state, chosen_act)
        policy_rejected_logps = agent.ac.log_prob(state, reject_act)

        # log(1 - p) for each action, with p capped at 1 - epsilon.
        chosen_log1m = torch.log1p(
            -torch.exp(torch.clamp(policy_chosen_logps, max=max_logp)))
        rejected_log1m = torch.log1p(
            -torch.exp(torch.clamp(policy_rejected_logps, max=max_logp)))

        # Logged rewards are the log-odds of each action; the rejected reward
        # uses the rejected action's own log(1 - p) term.
        positive_reward = (policy_chosen_logps -
                           chosen_log1m).detach().mean().item()
        negative_reward = (policy_rejected_logps -
                           rejected_log1m).detach().mean().item()
        margin = positive_reward - negative_reward

        total_positive_reward += positive_reward
        total_negative_reward += negative_reward
        total_margin += margin

        log_odds = (policy_chosen_logps - policy_rejected_logps) - \
            (chosen_log1m - rejected_log1m)
        ratio = torch.log(torch.sigmoid(log_odds) + epsilon)

        # Negative log-likelihood of the expert action plus the weighted
        # odds-ratio penalty. The likelihood term enters with a minus sign so
        # that minimizing the loss raises the expert action's probability.
        loss = (-policy_chosen_logps - beta * ratio).mean()
        total_loss += loss.item()

        agent.pi_optimizer.zero_grad()
        loss.backward()
        agent.pi_optimizer.step()

    return total_loss, total_margin, total_positive_reward, total_negative_reward


def RRHF(agent, expert_states, expert_actions, greedy=False, steps=100):
    """Run `steps` RRHF-style gradient updates on `agent`.

    Each step samples `batch_size` expert pairs, draws a "rejected" action
    from the current policy, and minimizes
    `max(0, logp_rejected - logp_chosen) - logp_chosen`.

    Args:
        agent: object exposing `ac` (policy with `log_prob`), `pi_optimizer`
            and `device`.
        expert_states: expert observations, indexable by an integer array.
        expert_actions: expert actions aligned with `expert_states`.
        greedy: unused; kept for interface compatibility.
        steps: number of gradient updates.

    Returns:
        `(total_loss, total_margin, total_positive_reward,
        total_negative_reward)`, each summed over the `steps` updates.
    """
    assert expert_states.shape[0] == expert_actions.shape[0]

    # Accumulate plain floats so per-step autograd graphs are released and
    # the function also works for `steps == 0`.
    total_loss = 0.0
    total_margin = 0
    total_positive_reward = 0
    total_negative_reward = 0
    epsilon = 1e-3

    for _ in range(steps):
        # random sample a batch of expert states and actions
        idx = np.random.randint(0, expert_states.shape[0], batch_size)
        # as_tensor handles arrays and tensors, including data that is
        # already on the target device.
        state = torch.as_tensor(
            expert_states[idx], dtype=torch.float32, device=agent.device)
        chosen_act = torch.as_tensor(
            expert_actions[idx], dtype=torch.float32, device=agent.device)

        # The rejected action is a sample from the current policy, used as data.
        with torch.no_grad():
            reject_act, _ = agent.ac(
                state, deterministic=False, with_logprob=False)

        # Keep both actions strictly inside (-1, 1) so atanh stays finite
        chosen_act = torch.clamp(chosen_act, -1+epsilon, 1-epsilon)
        reject_act = torch.clamp(reject_act, -1+epsilon, 1-epsilon)

        # Log probabilities of both actions under the current policy
        chosen_logratios = agent.ac.log_prob(state, chosen_act)
        reject_logratios = agent.ac.log_prob(state, reject_act)

        positive_reward = chosen_logratios.detach().mean().item()
        negative_reward = reject_logratios.detach().mean().item()
        margin = positive_reward - negative_reward

        total_positive_reward += positive_reward
        total_negative_reward += negative_reward
        total_margin += margin

        # Zero-margin ranking hinge plus the chosen-action likelihood term.
        losses = torch.clamp(-chosen_logratios +
                             reject_logratios, min=0) - chosen_logratios

        loss = losses.mean()
        total_loss += loss.item()

        agent.pi_optimizer.zero_grad()
        loss.backward()
        agent.pi_optimizer.step()

    return total_loss, total_margin, total_positive_reward, total_negative_reward


def SLiC_HF(agent, expert_states, expert_actions, greedy=False, steps=100):
    """Run `steps` SLiC-HF-style gradient updates on `agent`.

    Each step samples `batch_size` expert pairs, draws a "rejected" action
    from the current policy, and minimizes
    `max(0, 1 - logp_chosen + logp_rejected) - logp_chosen`.

    Args:
        agent: object exposing `ac` (policy with `log_prob`), `pi_optimizer`
            and `device`.
        expert_states: expert observations, indexable by an integer array.
        expert_actions: expert actions aligned with `expert_states`.
        greedy: unused; kept for interface compatibility.
        steps: number of gradient updates.

    Returns:
        `(total_loss, total_margin, total_positive_reward,
        total_negative_reward)`, each summed over the `steps` updates.
    """
    assert expert_states.shape[0] == expert_actions.shape[0]

    # Accumulate plain floats so per-step autograd graphs are released and
    # the function also works for `steps == 0`.
    total_loss = 0.0
    total_margin = 0
    total_positive_reward = 0
    total_negative_reward = 0
    epsilon = 1e-3

    for _ in range(steps):
        # random sample a batch of expert states and actions
        idx = np.random.randint(0, expert_states.shape[0], batch_size)
        # as_tensor handles arrays and tensors, including data that is
        # already on the target device.
        state = torch.as_tensor(
            expert_states[idx], dtype=torch.float32, device=agent.device)
        chosen_act = torch.as_tensor(
            expert_actions[idx], dtype=torch.float32, device=agent.device)

        # The rejected action is a sample from the current policy, used as data.
        with torch.no_grad():
            reject_act, _ = agent.ac(
                state, deterministic=False, with_logprob=False)

        # Keep both actions strictly inside (-1, 1) so atanh stays finite
        chosen_act = torch.clamp(chosen_act, -1+epsilon, 1-epsilon)
        reject_act = torch.clamp(reject_act, -1+epsilon, 1-epsilon)

        # Log probabilities of both actions under the current policy
        chosen_logratios = agent.ac.log_prob(state, chosen_act)
        reject_logratios = agent.ac.log_prob(state, reject_act)

        positive_reward = chosen_logratios.detach().mean().item()
        negative_reward = reject_logratios.detach().mean().item()
        margin = positive_reward - negative_reward

        total_positive_reward += positive_reward
        total_negative_reward += negative_reward
        total_margin += margin

        # Unit-margin ranking hinge plus the chosen-action likelihood term.
        losses = torch.clamp(1 - chosen_logratios +
                             reject_logratios, min=0) - chosen_logratios

        loss = losses.mean()
        total_loss += loss.item()

        agent.pi_optimizer.zero_grad()
        loss.backward()
        agent.pi_optimizer.step()

    return total_loss, total_margin, total_positive_reward, total_negative_reward


def SimPO(agent, expert_states, expert_actions, greedy=False, steps=100, beta=2.0, gamma=1):
    """Run `steps` SimPO-style gradient updates on `agent`.

    Each step samples `batch_size` expert pairs, draws a "rejected" action
    from the current policy, and minimizes
    `-logsigmoid(beta * logp_chosen - beta * logp_rejected - gamma)`.

    Args:
        agent: object exposing `ac` (policy with `log_prob`), `pi_optimizer`
            and `device`.
        expert_states: expert observations, indexable by an integer array.
        expert_actions: expert actions aligned with `expert_states`.
        greedy: unused; kept for interface compatibility.
        steps: number of gradient updates.
        beta: scale applied to both log-probabilities.
        gamma: target margin subtracted inside the sigmoid.

    Returns:
        `(total_loss, total_margin, total_positive_reward,
        total_negative_reward)`, each summed over the `steps` updates.
    """
    assert expert_states.shape[0] == expert_actions.shape[0]

    # Accumulate plain floats so per-step autograd graphs are released and
    # the function also works for `steps == 0`.
    total_loss = 0.0
    total_margin = 0
    total_positive_reward = 0
    total_negative_reward = 0
    epsilon = 1e-3

    for _ in range(steps):
        # random sample a batch of expert states and actions
        idx = np.random.randint(0, expert_states.shape[0], batch_size)
        # as_tensor handles arrays and tensors, including data that is
        # already on the target device.
        state = torch.as_tensor(
            expert_states[idx], dtype=torch.float32, device=agent.device)
        chosen_act = torch.as_tensor(
            expert_actions[idx], dtype=torch.float32, device=agent.device)

        # The rejected action is a sample from the current policy, used as data.
        with torch.no_grad():
            reject_act, _ = agent.ac(
                state, deterministic=False, with_logprob=False)

        # Keep both actions strictly inside (-1, 1) so atanh stays finite
        chosen_act = torch.clamp(chosen_act, -1+epsilon, 1-epsilon)
        reject_act = torch.clamp(reject_act, -1+epsilon, 1-epsilon)

        # Log probabilities of both actions under the current policy
        chosen_logratios = agent.ac.log_prob(state, chosen_act)
        reject_logratios = agent.ac.log_prob(state, reject_act)

        positive_reward = chosen_logratios.detach().mean().item()
        negative_reward = reject_logratios.detach().mean().item()
        margin = positive_reward - negative_reward

        total_positive_reward += positive_reward
        total_negative_reward += negative_reward
        total_margin += margin

        losses = -F.logsigmoid(
            beta * chosen_logratios - beta * reject_logratios - gamma)

        loss = losses.mean()
        total_loss += loss.item()

        agent.pi_optimizer.zero_grad()
        loss.backward()
        agent.pi_optimizer.step()

    return total_loss, total_margin, total_positive_reward, total_negative_reward


########################################################################################
# Entrypoint
########################################################################################


def set_seed(seed):
    """Seed the torch (CPU and CUDA), NumPy and `random` generators.

    Also switches cuDNN to its deterministic mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


def parse_args(argv=None):
    """Parse the command-line configuration for a training run.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the arguments from `sys.argv`.

    Returns:
        The populated `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(
        description="Configuration for expert dataset and model parameters")

    # Dataset, method and environment selection (flags marked required=True
    # must be given), plus run-length settings
    parser.add_argument("--expert_path", type=str,
                        required=True, help="Path to the expert dataset")
    parser.add_argument("--load_freq", type=int, default=0,
                        help="Frequency for loading previous model")
    parser.add_argument("--method", type=str, required=True,
                        choices=["POIL", "SimPO", "ORPO", "SLiC_HF", "RRHF"], help="Method to use")
    parser.add_argument("--weight_decay", type=float, default=1e-3,
                        help="Weight decay coefficient for the optimizer")
    parser.add_argument("--env_name", type=str, required=True,
                        help="Name of the environment")
    parser.add_argument("--total_steps", type=int,
                        default=100000, help="Total training steps")
    parser.add_argument("--eval_freq", type=int,
                        default=500, help="Evaluation frequency")

    # Method hyper-parameters
    parser.add_argument("--beta", type=float, default=0.1,
                        help="Beta parameter (optional)")
    parser.add_argument("--lr", type=float, default=3e-4,
                        help="Learning rate (optional)")
    parser.add_argument("--gamma", type=float, default=1.0,
                        help="Gamma parameter (optional)")
    parser.add_argument("--Lambda", type=float, default=0,
                        help="Lambda parameter (optional)")
    # seed
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed for reproducibility")

    return parser.parse_args(argv)


def get_log_path(args):
    """Return the CSV log path for this run and create its directory.

    The path is `logs/<env_name>/<dataset dir name>/<filename>.csv`, where
    the filename encodes the method, weight decay, learning rate, seed and
    the method-specific hyper-parameters.

    Args:
        args: namespace with `env_name`, `expert_path`, `method`,
            `weight_decay`, `lr`, `seed`, and (depending on the method)
            `beta`, `Lambda`, `gamma`.

    Returns:
        The log file path. The containing directory is created if missing.
    """
    # One directory per dataset. normpath drops a trailing separator first;
    # without it basename("data/run/") is "" and every dataset given with a
    # trailing slash would share the same log directory.
    dataset_name = os.path.basename(os.path.normpath(args.expert_path))
    base_dir = os.path.join("logs", args.env_name, dataset_name)

    # Parts common to every method
    filename_parts = [
        f"{args.method}",
        f"weight-decay_{args.weight_decay:.0e}",
        f"lr_{args.lr:.0e}",
        f"seed_{args.seed}",
    ]

    # Add method-specific parameters
    if args.method == "POIL":
        filename_parts.append(f"beta_{args.beta:.0e}")
        filename_parts.append(f"lambda_{args.Lambda:.0e}")
    elif args.method == "ORPO":
        filename_parts.append(f"beta_{args.beta:.0e}")
    elif args.method == "SimPO":
        filename_parts.append(f"beta_{args.beta:.0e}")
        filename_parts.append(f"gamma_{args.gamma:.0e}")

    filename = "_".join(filename_parts) + ".csv"

    # Ensure the directory exists
    os.makedirs(base_dir, exist_ok=True, mode=0o777)

    return os.path.join(base_dir, filename)


if __name__ == "__main__":
    # Script entry point: train the selected IL method on the expert dataset,
    # evaluate the deterministic policy after every `eval_freq` updates, and
    # write the collected metrics to a CSV file.
    args = parse_args()

    # Print configuration
    print("Configuration:")
    for key, value in vars(args).items():
        print(f"  {key}: {value}")

    # Validate the inputs up front, before any directory is created or any
    # data is loaded. Explicit raises are used because asserts are stripped
    # under `python -O`.
    supported_envs = ['Hopper-v2', 'HalfCheetah-v2', 'Walker2d-v2']
    if args.env_name not in supported_envs:
        raise ValueError(
            f"Unsupported environment {args.env_name!r}; expected one of {supported_envs}")
    if args.eval_freq <= 0:
        raise ValueError("--eval_freq must be a positive integer")

    # Skip runs whose results already exist
    log_path = get_log_path(args)
    if os.path.exists(log_path):
        print(f"Log path {log_path} already exists. Exiting...")
        raise SystemExit(0)
    set_seed(args.seed)

    # Load expert dataset
    expert_obs = np.load(os.path.join(args.expert_path, "expert_obs.npy"))
    expert_act = np.load(os.path.join(args.expert_path, "expert_act.npy"))

    # Standardize observations. `mean` and `std` are module-level names that
    # evaluate_real_return reads to normalize environment observations.
    if Normalization:
        mean = expert_obs.mean(axis=0)
        std = expert_obs.std(axis=0)
        # A dimension that never varies has std 0; divide by 1 there instead
        # of producing inf/NaN observations.
        std = np.where(std > 0, std, 1.0)
        expert_obs = (expert_obs-mean)/std

    # Initialize environment
    env = gym.make(args.env_name)
    env.seed(args.seed)

    # Initialize agent
    agent = Agent(
        env.observation_space.shape[0], env.action_space.shape[0], lr=args.lr, weight_decay=args.weight_decay)
    # Batches are sent to `device`, so the policy weights must live there too.
    agent.to(device)
    all_expert_actions = torch.FloatTensor(expert_act).to(device=device)
    all_expert_states = torch.FloatTensor(expert_obs).to(device=device)

    # Training routine and method-specific keyword arguments for each method
    methods = {
        "POIL": (POIL, {"beta": args.beta, "lambda_value": args.Lambda}),
        "ORPO": (ORPO, {"beta": args.beta}),
        "RRHF": (RRHF, {}),
        "SLiC_HF": (SLiC_HF, {}),
        "SimPO": (SimPO, {"beta": args.beta, "gamma": args.gamma}),
    }
    if args.method not in methods:
        raise ValueError("Invalid method")
    train_fn, train_kwargs = methods[args.method]

    # Initialize tracking variables
    loss_list = []
    margin_list = []
    positive_reward_list = []
    negative_reward_list = []
    return_det_list = []

    # Initial evaluation
    print("Step 0")
    print("Evaluating real return")
    horizon = 1000
    n_episodes = 1
    real_return_det = evaluate_real_return(
        agent.ac, env, n_episodes, horizon, deterministic=True)
    print(f"Deterministic real return: {real_return_det}")
    return_det_list.append(real_return_det)
    # Plain floats keep the metric columns homogeneous.
    loss_list.append(0.0)
    margin_list.append(0)
    positive_reward_list.append(0)
    negative_reward_list.append(0)

    # Main training loop: one round = `eval_freq` updates + one evaluation
    n_rounds = args.total_steps // args.eval_freq
    for step in range(1, 1 + n_rounds):
        loss, margin, positive_reward, negative_reward = train_fn(
            agent, all_expert_states, all_expert_actions,
            steps=args.eval_freq, **train_kwargs)

        # Update tracking variables
        loss_list.append(loss)
        margin_list.append(margin)
        positive_reward_list.append(positive_reward)
        negative_reward_list.append(negative_reward)

        # Report training metrics, then evaluate
        print(
            f"---------------------------------\nStep {step * args.eval_freq}")
        print(f"Loss: {loss}")
        print(f"Margin: {margin}")
        print(f"Positive reward: {positive_reward}")
        print(f"Negative reward: {negative_reward}")
        print("---------------------------------")
        print("Evaluating real return")
        real_return_det = evaluate_real_return(
            agent.ac, env, n_episodes, horizon, deterministic=True)
        print(f"Deterministic real return: {real_return_det}")
        return_det_list.append(real_return_det)

    # Convert the metric histories to arrays
    loss_list = np.array(loss_list)
    margin_list = np.array(margin_list)
    positive_reward_list = np.array(positive_reward_list)
    negative_reward_list = np.array(negative_reward_list)
    return_det_list = np.array(return_det_list)

    # Save results
    df = pd.DataFrame({
        'loss': loss_list,
        'margin': margin_list,
        'positive_reward': positive_reward_list,
        'negative_reward': negative_reward_list,
        'deterministic_return': return_det_list,
    })
    df.to_csv(log_path, index=False)
    print(f"Results saved to {log_path}")
    best_return = max(return_det_list)
    print(f"Best deterministic return: {best_return}")
