import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from models.vae import VAE, DynamicalModel
from torch.utils.data import DataLoader, Dataset


# ===== Dataset and Dataloader =====
class StateDataset(Dataset):
    """Transition dataset yielding ``(state, next_state, action, action_2)`` tensors.

    Each element ``item`` of ``data`` is read as:

    * ``item[0][0]`` -> current state (stored in ``self.state_1``)
    * ``item[1][0]`` -> next state    (stored in ``self.state_4``)
    * ``item[0][1]`` -> first action  (stored in ``self.actions``)
    * ``item[0][2]`` -> second action (stored in ``self.actions_2``)

    The fields are stacked into NumPy arrays once at construction time; no
    normalization is applied.
    """

    def __init__(self, data):
        self.state_1, self.state_4, self.actions, self.actions_2 = [], [], [], []
        # Single pass over ``data`` so a one-shot iterable works as well.
        for item in data:
            self.state_1.append(item[0][0])
            self.state_4.append(item[1][0])
            self.actions.append(item[0][1])
            self.actions_2.append(item[0][2])

        self.state_1 = np.array(self.state_1)
        self.state_4 = np.array(self.state_4)
        self.actions = np.array(self.actions)
        self.actions_2 = np.array(self.actions_2)

    def __len__(self):
        # Each stored item is a complete (state, next_state) pair and
        # ``__getitem__`` only reads index ``idx``, so every item is a valid
        # sample. This is also 0 (not -1) for an empty dataset.
        return len(self.state_1)

    def __getitem__(self, idx):
        """Return sample ``idx`` as float32 tensors; both actions are flattened to 1-D."""
        state_0 = self.state_1[idx]
        state_1 = self.state_4[idx]
        action = self.actions[idx]
        action_2 = self.actions_2[idx]
        return (
            torch.tensor(state_0, dtype=torch.float32),
            torch.tensor(state_1, dtype=torch.float32),
            torch.tensor(action, dtype=torch.float32).reshape(-1),
            torch.tensor(action_2, dtype=torch.float32).reshape(-1),
        )


# ===== Prediction Class =====
class FutureStatePredictor:
    """Predict the reward of applying an action pair from a given state.

    Pipeline used by :meth:`predict`:

    1. ``vae1`` encodes the current state into a latent code.
    2. ``dynamical_model`` advances that code given the concatenated actions.
    3. ``vae2`` decodes the result into a flat future-state vector.
    4. :meth:`eval_score_of_future_state` writes that vector into a cloned
       environment and returns its dense reward.
    """

    def __init__(self, vae1, vae2, dynamical_model, device="cuda:0"):
        """Store the models, switch them to eval mode and move them to ``device``.

        Args:
            vae1: encoder-side VAE (its ``encode`` and ``reparameterize`` are used).
            vae2: decoder-side VAE (its ``decode`` is used).
            dynamical_model: callable ``(z, combined_action) -> z_next``.
            device: device for the models and for every tensor built in
                :meth:`predict`. Defaults to ``"cuda:0"``.
        """
        self.vae1 = vae1
        self.vae2 = vae2
        self.dynamical_model = dynamical_model
        self.device = torch.device(device)

        self.vae1.eval().to(self.device)
        self.vae2.eval().to(self.device)
        self.dynamical_model.eval().to(self.device)
        # Deep copy of ``env.distance_cache``, taken the first time an env
        # reports ``if_cache_founded`` and then passed to every clone.
        self.cache_file = None
        # NOTE(review): state_mean / state_std / action_mean / action_std are
        # not set here, so the normalize/denormalize helpers below raise
        # AttributeError unless a caller assigns those attributes first.

    def normalize_state(self, state):
        """Standardize ``state`` with ``self.state_mean`` / ``self.state_std`` (must be set externally)."""
        return (state - self.state_mean) / self.state_std

    def normalize_action(self, action):
        """Standardize ``action`` with ``self.action_mean`` / ``self.action_std`` (must be set externally)."""
        return (action - self.action_mean) / self.action_std

    def denormalize_state(self, state):
        """Invert :meth:`normalize_state`."""
        return state * self.state_std + self.state_mean

    def eval_score_of_future_state(self, future_state, env, state_encoder):
        """Write a decoded state vector into a clone of ``env`` and return its reward.

        Args:
            future_state: decoder output; only row 0 is used, after moving it
                to CPU and converting it to a Python list.
            env: environment providing ``get_all_states``, ``clone``,
                ``if_cache_founded`` and ``distance_cache``.
            state_encoder: unused; kept for interface compatibility.

        Returns:
            The value of ``global_dense_reward()`` on the cloned environment.
        """
        values = list(future_state[0].cpu().numpy())
        # Rotations are decoded as (cos, sin) pairs and recovered with atan2.
        # NOTE(review): this index layout (robots at 0 and 8, debris i at
        # 16 + 8*i - 1) is assumed to match the decoder's output format;
        # the "- 1" offset in particular should be confirmed.
        robot_1_position = values[0:2]
        robot_1_rotation = values[2:4]
        robot_2_position = values[8:10]
        robot_2_rotation = values[10:12]
        # Only 7 debris slots are decoded; an env with more debris raises IndexError.
        pos_ang = [values[16 + 8 * i - 1 : 16 + 8 * i + 4 - 1] for i in range(7)]

        all_states = env.get_all_states()
        all_states["robot_1"]["pos"] = robot_1_position
        all_states["robot_1"]["angle"] = math.atan2(
            robot_1_rotation[1], robot_1_rotation[0]
        )
        all_states["robot_2"]["pos"] = robot_2_position
        all_states["robot_2"]["angle"] = math.atan2(
            robot_2_rotation[1], robot_2_rotation[0]
        )

        for i, debris in enumerate(all_states["debris"]):
            debris["pos"] = pos_ang[i][0:2]
            debris["angle"] = math.atan2(pos_ang[i][3], pos_ang[i][2])

        if self.cache_file is None and env.if_cache_founded:
            self.cache_file = copy.deepcopy(env.distance_cache)
        cloned_env = env.clone(self.cache_file)
        cloned_env.update_env_by_given_state({"states": all_states})
        return cloned_env.global_dense_reward()

    def predict(self, current_state, action, action_2, env, state_encoder):
        """Predict the reward of applying ``action`` and ``action_2`` from ``current_state``.

        Args:
            current_state: state vector for a single sample (a batch axis is added).
            action, action_2: action arrays; each is flattened to one row and
                the two rows are concatenated before the dynamics step.
            env: environment cloned to evaluate the predicted state.
            state_encoder: passed through, unused, to ``eval_score_of_future_state``.

        Returns:
            The reward from :meth:`eval_score_of_future_state`.
        """
        # Build inputs directly on the models' device: feeding CPU tensors to
        # CUDA-resident modules raises a device-mismatch RuntimeError.
        current_state = torch.tensor(
            current_state, dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        action = torch.tensor(action, dtype=torch.float32, device=self.device).reshape(1, -1)
        action_2 = torch.tensor(
            action_2, dtype=torch.float32, device=self.device
        ).reshape(1, -1)
        with torch.no_grad():
            mu1, logvar1 = self.vae1.encode(current_state)
            z1 = self.vae1.reparameterize(mu1, logvar1)
            combined_action = torch.cat((action, action_2), dim=1)
            z1_next = self.dynamical_model(z1, combined_action)
            future_state = self.vae2.decode(z1_next)
            return self.eval_score_of_future_state(future_state, env, state_encoder)


# ===== Define VAE =====
# ===== Loss Function =====
def vae_loss(recon_x, x, mu, logvar, kl_weight=0.0):
    """Return reconstruction MSE plus a weighted KL term for a VAE batch.

    Args:
        recon_x: decoder reconstruction of ``x``.
        x: target batch; ``x.size(0)`` is used as the batch size for the KL term.
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder.
        kl_weight: multiplier on the KL divergence between N(mu, exp(logvar))
            and N(0, 1). The default 0.0 disables the KL term.

    Returns:
        Scalar tensor ``MSE(recon_x, x) + kl_weight * KL``.
    """
    recon_loss = nn.MSELoss()(recon_x, x)
    if kl_weight == 0:
        # Skip the KL term rather than multiplying it by zero: if
        # logvar.exp() overflows, KL is infinite and inf * 0 is NaN, which
        # would poison the loss and gradients even though the term is "off".
        return recon_loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl_weight * kl_loss


# ===== Unified Training Function =====
def train_combined_model(
    data_loader, vae1, vae2, dynamical_model, optimizer, epochs, model_save_path
):
    """Jointly train both VAEs and the latent dynamics model, then save them.

    For each batch the optimized loss is::

        vae_loss(vae1(state_0)) + vae_loss(vae2(state_1)) + MSE(dyn(z1, a), mu2)

    ``z1`` is sampled from vae1's posterior for ``state_0``. ``mu2`` is vae2's
    posterior mean for ``state_1``. ``a`` is the two actions concatenated.

    Args:
        data_loader: yields ``(state_0, state_1, action, action_2)`` batches.
        vae1, vae2, dynamical_model: modules being trained.
        optimizer: optimizer over the parameters of all three modules.
        epochs: number of passes over ``data_loader``.
        model_save_path: destination for ``torch.save``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError("data_loader yields no batches; nothing to train on")

    # Callers may pass modules left in eval mode (e.g. after loading);
    # dropout/batch-norm must be in training mode here.
    for model in (vae1, vae2, dynamical_model):
        model.train()
    mse = nn.MSELoss()

    for epoch in range(epochs):
        total_loss = 0.0
        for state_0, state_1, action, action_2 in data_loader:
            combined_action = torch.cat((action, action_2), dim=1)

            recon_x1, mu1, logvar1 = vae1(state_0)
            loss_vae1 = vae_loss(recon_x1, state_0, mu1, logvar1)

            z1 = vae1.reparameterize(mu1, logvar1)
            z1_next_pred = dynamical_model(z1, combined_action)

            recon_x2, mu2, logvar2 = vae2(state_1)
            loss_vae2 = vae_loss(recon_x2, state_1, mu2, logvar2)

            # mu2 is not detached, so this term also shapes vae2's encoder.
            loss_dyn = mse(z1_next_pred, mu2)

            total_batch_loss = loss_vae1 + loss_vae2 + loss_dyn

            optimizer.zero_grad()
            total_batch_loss.backward()
            optimizer.step()

            total_loss += total_batch_loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/num_batches:.4f}")

    # Whole module objects (not state_dicts) are pickled, so loading this file
    # requires the model classes to be importable.
    torch.save(
        {
            "vae1": vae1,
            "vae2": vae2,
            "dynamical_model": dynamical_model,
        },
        model_save_path,
    )


# ===== Main =====
if __name__ == "__main__":
    import argparse

    # Hyperparameters
    state_dim = 142
    latent_dim = 32
    action_dim = 12 * 2  # two 12-value (4x3) actions, concatenated
    hidden_dim = 128
    epochs = 200
    batch_size = 128
    learning_rate = 0.001

    parser = argparse.ArgumentParser(
        description="Train the VAE pair and latent dynamical model on transition data."
    )
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Path to the .npy file holding the transition dataset",
    )
    parser.add_argument("--model_save_path", required=True, type=str, help="model path")
    # Let argparse report errors and --help itself. The old fallback to
    # parse_args([]) could never succeed (both args are required) and turned
    # `--help` into an error exit.
    args = parser.parse_args()
    data_path = args.data_path
    model_save_path = args.model_save_path

    # Load Dataset
    # NOTE: allow_pickle=True can execute arbitrary code from the file; only
    # load datasets from trusted sources.
    data = np.load(data_path, allow_pickle=True)
    dataset = StateDataset(data)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize Models
    vae1 = VAE(state_dim, latent_dim, hidden_dim)
    vae2 = VAE(state_dim, latent_dim, hidden_dim)
    dynamical_model = DynamicalModel(latent_dim, action_dim, hidden_dim)

    # One optimizer over all three modules so they are trained jointly.
    optimizer = optim.Adam(
        list(vae1.parameters())
        + list(vae2.parameters())
        + list(dynamical_model.parameters()),
        lr=learning_rate,
    )

    train_combined_model(
        data_loader, vae1, vae2, dynamical_model, optimizer, epochs, model_save_path
    )
