import os
import sys 
sys.path.append("..") 
import math
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils.loss import SupConLoss, IRD_distill_loss


class MAB(nn.Module):
    """Multihead Attention Block (Set Transformer style).

    Projects queries from ``Q`` and keys/values from ``K``, runs multi-head
    scaled dot-product attention, adds the result to the projected queries,
    then applies a residual ReLU feed-forward layer. Optional BatchNorm1d
    layers are applied after each residual when ``bn`` is True.

    Args:
        dim_Q: feature size of the query input.
        dim_K: feature size of the key/value input.
        dim_V: output feature size. Must be divisible by ``num_heads``
            (the heads are formed by splitting this axis into equal chunks).
        num_heads: number of attention heads.
        bn: if True, create ``ln0``/``ln1`` BatchNorm1d layers over the
            ``dim_V`` feature axis.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, bn=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        # bn1/bn2/bn3 are never used in forward(); they are kept only so that
        # state_dicts saved by earlier versions still load with strict=True.
        self.bn1 = nn.BatchNorm1d(dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.bn2 = nn.BatchNorm1d(dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        self.bn3 = nn.BatchNorm1d(dim_V)
        self.dropout = nn.Dropout(0.1)
        if bn:
            self.ln0 = nn.BatchNorm1d(dim_V)
            self.ln1 = nn.BatchNorm1d(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    @staticmethod
    def _norm(norm, O):
        """Apply an optional BatchNorm1d to the last (feature) axis of a (B, N, D) tensor."""
        if norm is None:
            return O
        # BatchNorm1d normalises dim 1 of a (B, C, L) input, but O is (B, N, D):
        # move the feature axis to dim 1 and back so statistics are per feature.
        return norm(O.transpose(1, 2)).transpose(1, 2)

    def forward(self, Q, K):
        """Attend from Q (B, Nq, dim_Q) to K (B, Nk, dim_K); returns (B, Nq, dim_V)."""
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)
        dim_split = self.dim_V // self.num_heads
        # Fold heads into the batch axis: (B, N, dim_V) -> (heads*B, N, dim_split).
        Q_ = T.cat(Q.split(dim_split, 2), 0)
        K_ = T.cat(K.split(dim_split, 2), 0)
        V_ = T.cat(V.split(dim_split, 2), 0)

        # NOTE: scaled by sqrt(dim_V) (full width, not per-head width); kept
        # unchanged so trained weights keep producing the same outputs.
        A = T.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        # Residual on the projected queries, then unfold heads back into features.
        O = T.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = self.dropout(O)
        O = self._norm(getattr(self, 'ln0', None), O)

        O2 = F.relu(self.fc_o(O))
        O2 = self.dropout(O2)

        O = O + O2
        O = self._norm(getattr(self, 'ln1', None), O)
        return O


class SAB(nn.Module):
    """Set Attention Block: a MAB in which a set attends to itself.

    Args:
        dim_in: feature size of the input set.
        dim_out: feature size of the output set.
        num_heads: number of attention heads.
        bn: forwarded to MAB to enable its optional normalisation layers.
    """

    def __init__(self, dim_in, dim_out, num_heads, bn=False):
        super().__init__()
        # Queries and keys come from the same set, so both use dim_in.
        self.mab = MAB(dim_Q=dim_in, dim_K=dim_in, dim_V=dim_out, num_heads=num_heads, bn=bn)

    def forward(self, X):
        """Return the self-attention of X over its own elements."""
        queries = keys = X
        return self.mab(queries, keys)


class ISAB(nn.Module):
    """Induced Set Attention Block.

    A learnable set of ``num_inds`` inducing points first attends to the
    input set (mab0). The input set then attends to that summary (mab1).

    Args:
        dim_in: feature size of the input set.
        dim_out: feature size of the inducing points and of the output.
        num_heads: number of attention heads in both MABs.
        num_inds: number of inducing points.
        bn: forwarded to both MABs.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, bn=False):
        super().__init__()
        inducing_points = T.Tensor(1, num_inds, dim_out)
        self.I = nn.Parameter(inducing_points)
        nn.init.xavier_uniform_(self.I)
        # Inducing points (dim_out) query the input set (dim_in).
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, bn=bn)
        # Input set (dim_in) queries the induced summary (dim_out).
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, bn=bn)

    def forward(self, X):
        """Attend X to a learned summary of itself."""
        batch_size = X.size(0)
        summary = self.mab0(self.I.repeat(batch_size, 1, 1), X)
        return self.mab1(X, summary)


class PMA(nn.Module):
    """Pooling by Multihead Attention.

    ``num_seeds`` learnable seed vectors attend to the input set. The output
    has one row per seed: (batch, num_seeds, dim).

    Args:
        dim: feature size of the input set and of the seeds.
        num_heads: number of attention heads.
        num_seeds: number of pooled output vectors.
        bn: forwarded to the underlying MAB.
    """

    def __init__(self, dim, num_heads, num_seeds, bn=False):
        super().__init__()
        seed_vectors = T.Tensor(1, num_seeds, dim)
        self.S = nn.Parameter(seed_vectors)
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, bn=bn)

    def forward(self, X):
        """Pool X into num_seeds vectors per batch element."""
        seeds = self.S.repeat(X.size(0), 1, 1)
        return self.mab(seeds, X)


class PositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq, d_model) input.

    Even feature columns hold sin(pos * f_i) and odd columns cos(pos * f_i),
    with frequencies f_i = 10000^(-2i/d_model).

    Args:
        d_model: feature size of the input (its last dimension).
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions precomputed into the ``pe`` buffer.
            Longer inputs are handled by computing the table on the fly
            with the same formula.
    """

    def __init__(self, d_model, dropout=0.1, max_len=50):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # A registered buffer is saved in state_dict and follows .to(device),
        # but is not a Parameter, so it is never updated by an optimizer.
        self.register_buffer('pe', self._sinusoid_table(max_len, d_model))

    @staticmethod
    def _sinusoid_table(length, d_model, device=None):
        """Return a float32 (1, length, d_model) table of sinusoidal encodings."""
        pe = T.zeros(length, d_model, dtype=T.float32, device=device)
        position = T.arange(0, length, device=device).float().unsqueeze(1)
        # Frequencies computed in log space for numerical stability.
        div_term = (T.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)).exp()
        pe[:, 0::2] = T.sin(position * div_term)
        # With odd d_model there is one fewer cos column than sin column.
        pe[:, 1::2] = T.cos(position * div_term)[:, :d_model // 2]
        return pe.unsqueeze(0)

    def forward(self, x):
        """Return dropout(x + pe[:, :seq_len]) for x of shape (batch, seq_len, d_model)."""
        seq_len = x.shape[1]
        if seq_len <= self.pe.size(1):
            pe = self.pe[:, :seq_len]
        else:
            # Beyond the cached table: same formula, computed for the full length.
            pe = self._sinusoid_table(seq_len, self.pe.size(-1), device=self.pe.device).to(self.pe.dtype)
        return self.dropout(x + pe)
        

class TrajectoryEncoder(nn.Module):
    """Encode a trajectory tensor into one pooled embedding.

    Sinusoidal positions are added to the input. Two self-attention blocks
    (width 128, 4 heads) then process the timesteps, and a single-seed PMA
    followed by a linear layer pools them into ``out_dim`` features.

    Args:
        input_dim: per-timestep feature size.
        out_dim: size of the returned embedding.
        device: the module is moved to this device at construction.
    """

    def __init__(self, input_dim, out_dim, device):
        super().__init__()
        width, heads = 128, 4
        self.enc = nn.Sequential(
            SAB(dim_in=input_dim, dim_out=width, num_heads=heads),
            SAB(dim_in=width, dim_out=width, num_heads=heads),
        )
        self.pooling = nn.Sequential(
            PMA(dim=width, num_heads=heads, num_seeds=1),
            nn.Linear(in_features=width, out_features=out_dim),
        )
        self.pos_embedding = PositionalEmbedding(input_dim)
        self.to(device)

    def forward(self, trajectory):
        """Map a (batch, time, input_dim) tensor to a (batch, 1, out_dim) embedding."""
        with_positions = self.pos_embedding(trajectory)
        hidden = self.enc(with_positions)
        return self.pooling(hidden)


class Discriminator(nn.Module):
    """Two-layer MLP producing ``z_dim`` outputs from a feature (and optional action).

    Args:
        feature_dim: size of the state/feature input.
        z_dim: size of the output.
        action_dim: if given, an action of this size is concatenated to the
            feature before the MLP, and ``forward`` requires ``action``.
    """

    def __init__(self, feature_dim, z_dim, action_dim=None):
        super().__init__()
        self.use_action = action_dim is not None  # decide whether to use action
        in_features = feature_dim
        if self.use_action:
            in_features += action_dim
        self.fc1 = nn.Linear(in_features, 64)
        self.fc2 = nn.Linear(64, z_dim)

    def forward(self, state, action=None):
        """Return fc2(relu(fc1([state, action]))); a size-1 dim 1 of a >2-D state is squeezed."""
        features = state.squeeze(1) if state.dim() > 2 else state
        if self.use_action:
            assert action is not None, "Action is required when use_action=True"
            features = T.cat((features, action), dim=-1)
        hidden = T.relu(self.fc1(features))
        return self.fc2(hidden)
  

class StylePredictor(nn.Module):
    """Trajectory style classifier.

    A TrajectoryEncoder embeds each trajectory. A Discriminator then maps the
    embedding to ``style_num`` logits, which are trained with cross-entropy.

    Args:
        state_dim: per-timestep feature size of the input trajectories.
        ecd_output_dim: size of the trajectory embedding.
        style_num: number of style classes (logit dimension).
        learning_rate: Adam learning rate for encoder and discriminator parameters.
        device: device the module is moved to at construction.
    """

    def __init__(self, state_dim, ecd_output_dim, style_num, learning_rate, device):
        super(StylePredictor, self).__init__()
        # Trajectory Encoder
        self.traj_encoder = TrajectoryEncoder(state_dim, ecd_output_dim, device)

        # Discriminator
        self.discriminator = Discriminator(ecd_output_dim, style_num)

        params = list(self.traj_encoder.parameters()) + list(self.discriminator.parameters())
        self.optimizer = optim.Adam(params, lr=learning_rate)

        self.criterion = nn.CrossEntropyLoss()
        self.to(device)

    def forward(self, trajectory):
        """Return (style logits [B, style_num], trajectory embedding [B, 1, ecd_output_dim])."""
        encoded_traj = self.traj_encoder(trajectory)
        z_pred = self.discriminator(encoded_traj)
        return z_pred, encoded_traj

    def train_step(self, traj, label_curr):
        """Run one cross-entropy optimisation step and return the loss as a float.

        Args:
            traj: batch of trajectories accepted by ``forward``.
            label_curr: integer class indices, shape [B] or [B, 1].
        """
        # Like Policy_Encoder.train_step: make sure dropout is active while
        # training, even if eval() was called earlier.
        self.train()
        z_pred, _ = self.forward(traj)

        # reshape(-1) rather than squeeze(): squeeze() would turn a batch of one
        # into a 0-d target that no longer matches the [1, style_num] logits.
        target = label_curr.long().reshape(-1)
        total_loss = self.criterion(z_pred, target)

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item()

    def save_model(self, path):
        """Save encoder, discriminator and optimizer state dicts to ``path``."""
        checkpoint = {
            "traj_encoder_state_dict": self.traj_encoder.state_dict(),
            "discriminator_state_dict": self.discriminator.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        T.save(checkpoint, path)

    def load_model(self, load_path):
        """Load the model parameters and optimizer state; no-op if the file is missing."""
        if not os.path.exists(load_path):
            print(f"No saved model found at {load_path}")
            return

        # Map tensors onto the device this model currently lives on, so that a
        # checkpoint saved on GPU can be loaded on a CPU-only machine.
        map_location = next(self.parameters()).device
        checkpoint = T.load(load_path, map_location=map_location)

        self.traj_encoder.load_state_dict(checkpoint["traj_encoder_state_dict"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        print(f"Model loaded from {load_path}")
        

class ProjectionHead(nn.Module):
    """MLP head: Linear -> LayerNorm -> ReLU -> Linear, with a hidden width of 128.

    Args:
        input_dim: size of the incoming feature.
        output_dim: size of the projected feature.
        device: the module is moved to this device at construction.
    """

    def __init__(self, input_dim, output_dim, device):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(input_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.projector = nn.Sequential(*layers)
        self.to(device)

    def forward(self, feature):
        """Project ``feature`` through the MLP."""
        projected = self.projector(feature)
        return projected
    
     
class SelfAttentionAggregator(nn.Module):
    """Pool a set of feature vectors into one vector with single-head self-attention.

    Every element attends to every other element (scaled dot product). The
    attended values are then averaged over the set axis.

    Args:
        input_dim: feature size of each set element.
        output_dim: size of the query/key/value projections and of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.query = nn.Linear(input_dim, output_dim)
        self.key = nn.Linear(input_dim, output_dim)
        self.value = nn.Linear(input_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, local_policies):
        """Map [B, n, input_dim] to [B, output_dim]."""
        q = self.query(local_policies)
        k = self.key(local_policies)
        v = self.value(local_policies)
        scale = k.size(-1) ** 0.5
        # [B, n, n]: softmax over the key axis for each query element.
        weights = self.softmax(T.bmm(q, k.transpose(1, 2)) / scale)
        attended = T.bmm(weights, v)  # [B, n, output_dim]
        return attended.mean(dim=1)
     

class Policy_Encoder(nn.Module):
    """Contrastive encoder for joint multi-agent trajectories.

    Each agent's slice of the trajectory is embedded by a shared
    TrajectoryEncoder, and the per-agent features are pooled by a
    SelfAttentionAggregator. A ProjectionHead then produces L2-normalised
    embeddings, which are trained with SupConLoss plus an optional IRD
    distillation term against a past encoder.

    Args:
        input_dim: per-agent observation size (width of each agent's slice).
        ecd_output_dim: size of each per-agent trajectory feature.
        output_dim: size of the pooled representation and of the projection.
        learning_rate: Adam learning rate for encoder, pooling and projector.
        device: device the module is moved to and checkpoints are mapped onto.
        feature_space: stored on the instance; not read by this class.
        bank_size: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, input_dim, ecd_output_dim, output_dim, learning_rate, device, feature_space=False, bank_size=8192):
        super(Policy_Encoder, self).__init__()
        self.device = device
        self.feature_space = feature_space
        # Trajectory encoder, input: [B, T, obs_dim], output: [B, 1, ecd_output_dim]
        self.local_encoder = TrajectoryEncoder(input_dim, ecd_output_dim, device)

        # Attention-based pooling, input: [B, n_agent, ecd_output_dim], output: [B, output_dim]
        self.att_pool = SelfAttentionAggregator(ecd_output_dim, output_dim)

        # Projection head, input: [B, output_dim], output: [B, output_dim]
        self.projector = ProjectionHead(output_dim, output_dim, device)

        # Optimizer over all trainable sub-modules.
        params = list(self.local_encoder.parameters()) + list(self.att_pool.parameters()) + list(self.projector.parameters())
        self.optimizer = optim.Adam(params, lr=learning_rate)

        # Loss function
        self.contrast_loss = SupConLoss(device=device)
        self.to(device)

    def forward(self, trajs, n_agent, feature_space=False):
        """Encode a batch of joint trajectories.

        Args:
            trajs: tensor [B, T, obs_dim * n_agent] with agents concatenated
                along the last dimension.
            n_agent: number of agents; the last dimension is split into this
                many equal slices.
            feature_space: unused; kept for call-site compatibility.

        Returns:
            Tuple ``(contrastive_rep, global_representation, local_features)``
            with shapes [B, output_dim] (L2-normalised), [B, output_dim] and
            [B, n_agent, ecd_output_dim].
        """
        # [B, T, obs_dim*n_agent] -> n_agent x [B, T, obs_dim] -> [n_agent*B, T, obs_dim]
        per_agent = T.split(trajs, trajs.shape[-1] // n_agent, dim=-1)
        stacked = T.cat(per_agent, dim=0)
        local_features = self.local_encoder(stacked).squeeze(1)  # [n_agent*B, ecd_output_dim]
        # Undo the agent-major batching: -> [B, n_agent, ecd_output_dim]
        local_features = T.split(local_features, local_features.shape[0] // n_agent, dim=0)
        local_features = T.stack(local_features, dim=1)
        global_representation = self.att_pool(local_features)  # [B, output_dim]
        contrastive_rep = F.normalize(self.projector(global_representation), dim=-1)
        return contrastive_rep, global_representation, local_features

    @staticmethod
    def _stack_views(rep):
        """Reshape view-major [2*B, d] into [B, 2, d] (first half = view 0, second = view 1)."""
        return T.stack(T.split(rep, rep.shape[0] // 2, dim=0), dim=1)

    def train_step(self, trajs, label_contrast, n_agent, k=0.2, past_encoder=None):
        """Run one optimisation step and return the loss as a float.

        Args:
            trajs: tensor [2*B, T, obs_dim*n_agent]. The first and second halves
                along dim 0 are treated as two views of the same B samples.
            label_contrast: labels forwarded unchanged to SupConLoss
                (presumably shape [B] — TODO confirm against SupConLoss).
            n_agent: number of agents concatenated in the last dim of ``trajs``.
            k: weight of the IRD distillation term.
            past_encoder: optional earlier encoder with the same
                ``forward(trajs, n_agent)`` signature, used only as a
                distillation target. Assumed not to share parameters with self.
        """
        self.train()

        contrast_rep, _, _ = self.forward(trajs, n_agent)
        contrast_rep = self._stack_views(contrast_rep)  # [B, view, d]
        loss = self.contrast_loss(contrast_rep, label_contrast)

        if past_encoder is not None:
            # The past encoder supplies fixed targets: pass n_agent (the original
            # call omitted it and always raised TypeError), and skip building a graph
            # for it. Its train/eval mode is left to the caller.
            with T.no_grad():
                past_rep, _, _ = past_encoder(trajs, n_agent)
            past_rep = self._stack_views(past_rep)
            loss = loss + k * IRD_distill_loss(contrast_rep, past_rep)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def save_model(self, path):
        """Save encoder, pooling, projector and optimizer state dicts to ``path``."""
        checkpoint = {
            "traj_encoder_state_dict": self.local_encoder.state_dict(),
            "attention_pooling_state_dict": self.att_pool.state_dict(),
            "projector_state_dict": self.projector.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        T.save(checkpoint, path)

    def load_model(self, load_path):
        """Load the model parameters and optimizer state; no-op if the file is missing."""
        if not os.path.exists(load_path):
            print(f"No saved model found at {load_path}")
            return

        checkpoint = T.load(load_path, map_location=self.device)

        self.local_encoder.load_state_dict(checkpoint["traj_encoder_state_dict"])
        self.att_pool.load_state_dict(checkpoint["attention_pooling_state_dict"])
        self.projector.load_state_dict(checkpoint["projector_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    def eval(self):
        """Switch to evaluation mode and return self, matching nn.Module.eval()."""
        # nn.Module.eval() already recurses into local_encoder, att_pool and projector.
        return super().eval()
    
