import torch
import time
import torch.optim as optim
from collections import OrderedDict
from utils.utils import print_current_loss
from os.path import join as pjoin

from diffusers import  DDPMScheduler
from torch.utils.tensorboard import SummaryWriter
import time
import pdb
import sys
import os
from torch.optim.lr_scheduler import ExponentialLR
import torch_dct
import torch.nn as nn
import torch.nn.functional as F

# Import MoCLIP MotionEncoder for representation alignment
try:
    from MoCLIP import MotionEncoder, _init_clip_motion_model, GLOBAL_CACHE
    MOCLIP_AVAILABLE = True
except ImportError:
    print("Warning: MoCLIP not available. Representation alignment loss will be disabled.")
    MOCLIP_AVAILABLE = False 

class DDPMTrainer(object):
    """Train a text-conditioned motion diffusion model with LUMA dual-anchor losses.

    The objective assembled in ``backward_G`` is
    ``L = L_DDPM + zeta(n) * (lambda_fre * L_fre + lambda_tem * L_tem)``:

    * ``L_DDPM``: masked MSE between the model prediction and the scheduler target.
    * ``L_fre``:  MSE between the model's ``z_fre`` output and the first
      ``FRE_ANCHOR_DIM`` DCT coefficients (along time) of the clean motion.
    * ``L_tem``:  ``1 - cos(z_tem, f_tem(x_0))`` where ``f_tem`` is an optional
      MoCLIP ``MotionEncoder``.
    * ``zeta(n)``: cosine annealing from 1 to 0 over ``decay_threshold`` steps.
    """

    # Number of leading DCT coefficients (along the time axis) used as the
    # frequency-anchor target DCT_k(x_0); expected to match the model's z_fre
    # length. Override on a subclass to change it.
    FRE_ANCHOR_DIM = 64

    def __init__(self, args, model,accelerator, model_ema=None):
        """Build the noise scheduler, anchor-loss components, logger and optimizer.

        Args:
            args: options namespace. Read here: ``diffusion_steps``,
                ``beta_schedule``, ``prediction_type``, ``is_train``,
                ``save_root``, ``lr``, ``weight_decay``, ``decay_rate`` plus the
                optional anchor-loss attributes read via ``getattr``.
            model: denoiser called as ``model(x_t, t, text=caption)`` and
                expected to return ``(prediction, z_fre, z_tem)``.
            accelerator: Accelerate-style object providing ``device``,
                ``print``, ``prepare``, ``backward`` etc.
            model_ema: optional EMA copy of ``model`` (``None`` disables EMA).
        """
        self.opt = args
        self.accelerator = accelerator
        self.device = self.accelerator.device
        self.model = model
        self.diffusion_steps = args.diffusion_steps
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=self.diffusion_steps,
            beta_schedule=args.beta_schedule,
            variance_type="fixed_small",
            prediction_type=args.prediction_type,
            clip_sample=False,
        )
        self.model_ema = model_ema
        if args.is_train:
            self.mse_criterion = torch.nn.MSELoss(reduction='none')

        accelerator.print('Diffusion_config:\n', self.noise_scheduler.config)

        # Sets lambda_fre / lambda_tem / decay_threshold and, when possible,
        # the temporal-anchor MotionEncoder.
        self._init_repr_align_loss(args)

        # Step counter driving the cosine-annealed anchor weight zeta(n).
        self.current_step = 0

        if self.accelerator.is_main_process:
            starttime = time.strftime("%Y-%m-%d_%H:%M:%S")
            print("Start experiment:", starttime)
            # Log dir is named by experiment start time (truncated to minutes).
            self.writer = SummaryWriter(
                log_dir=pjoin(args.save_root, 'logs_') + starttime[:16],
                comment=starttime[:16],
                flush_secs=60,
            )
        self.accelerator.wait_for_everyone()

        params = list(self.model.parameters())
        if self._motion_encoder_trainable():
            # A "trainable" encoder must be optimized too. Otherwise it gets
            # gradients but is never stepped, and zero_grad() (which only
            # clears optimizer-owned params) never resets its .grad.
            # NOTE(review): the encoder is not passed through
            # accelerator.prepare, so its gradients are not synchronized across
            # processes in multi-GPU runs. Confirm whether that is acceptable.
            params += list(self.motion_encoder.parameters())
        self.optimizer = optim.AdamW(params, lr=self.opt.lr, weight_decay=self.opt.weight_decay)
        self.scheduler = ExponentialLR(self.optimizer, gamma=args.decay_rate) if args.decay_rate > 0 else None

    def _motion_encoder_trainable(self):
        """Return True when a MotionEncoder exists and is configured to be trained."""
        return (
            self.dual_anchor_enabled
            and self.motion_encoder is not None
            and not getattr(self.opt, 'freeze_motion_encoder', True)
        )

    def _init_repr_align_loss(self, args):
        """Initialize LUMA dual anchor loss components (weights, schedule, encoder)."""
        # Loss weights for the frequency and temporal anchors.
        self.lambda_fre = getattr(args, 'lambda_fre', 0.2)  # lambda_fre: frequency anchor weight
        self.lambda_tem = getattr(args, 'lambda_tem', 0.3)  # lambda_tem: temporal anchor weight
        self.decay_threshold = getattr(args, 'decay_threshold', 50000)  # N in cosine annealing

        # Temporal semantic anchor encoder; stays None/disabled unless MoCLIP
        # is importable, we are training, and initialization succeeds.
        self.motion_encoder = None
        self.dual_anchor_enabled = False

        if MOCLIP_AVAILABLE and args.is_train:
            self._init_motion_encoder(args)

        self.accelerator.print(f"LUMA Dual Anchor Loss - Enabled: {self.dual_anchor_enabled}")
        self.accelerator.print(f"Lambda_fre: {self.lambda_fre}, Lambda_tem: {self.lambda_tem}")
        self.accelerator.print(f"Decay threshold N: {self.decay_threshold}")

    def _init_motion_encoder(self, args):
        """Build the MoCLIP MotionEncoder used as the temporal anchor f_tem(x_0).

        On success, sets ``self.dual_anchor_enabled`` to True. Any exception is
        reported and leaves ``self.motion_encoder`` as None. This is deliberate
        best-effort behavior: training then proceeds without the temporal anchor.
        """
        try:
            # motion_encoder_input_dim is expected to be set per dataset in
            # train_options.py; otherwise fall back to dim_pose, then 263.
            input_dim = getattr(args, 'motion_encoder_input_dim', args.dim_pose if hasattr(args, 'dim_pose') else 263)
            embed_dim = getattr(args, 'motion_encoder_embed_dim', 768)   # MoCLIP output dimension D_a
            max_seq_len = getattr(args, 'motion_encoder_max_seq_len', 196)

            self.motion_encoder = MotionEncoder(
                input_dim=input_dim,
                embed_dim=embed_dim,
                max_seq_length=max_seq_len
            ).to(self.device)

            # Best effort: on failure the encoder keeps its random init.
            self._load_motion_encoder_weights(args)

            if getattr(args, 'freeze_motion_encoder', True):
                for param in self.motion_encoder.parameters():
                    param.requires_grad = False
                self.motion_encoder.eval()
                self.accelerator.print("MotionEncoder weights frozen for stable features")
            else:
                self.accelerator.print("MotionEncoder weights will be trained")

            self.dual_anchor_enabled = True
            dataset_name = getattr(args, 'dataset_name', 'unknown')
            self.accelerator.print(f"Successfully initialized MotionEncoder (dataset={dataset_name}, input_dim={input_dim}, embed_dim={embed_dim})")

        except Exception as e:
            self.accelerator.print(f"Failed to initialize MotionEncoder: {e}")
            self.motion_encoder = None

    def _load_motion_encoder_weights(self, args):
        """Copy pre-trained MoCLIP MotionEncoder weights into ``self.motion_encoder``.

        Uses MoCLIP's module-level ``GLOBAL_CACHE`` as a scratch area. The cache
        entries this method touches are cleared, and the previous ``device``
        entry is restored, even when loading fails.
        """
        # moclip_model_path is expected to be set per dataset in train_options.py.
        default_path = './checkpoints/moclip_training/checkpoint_epoch_20.pt' if not hasattr(args, 'dataset_name') or args.dataset_name == 't2m' else './checkpoints/moclip_kit_training/best_model.pt'
        moclip_path = getattr(args, 'moclip_model_path', default_path)
        load_weights = getattr(args, 'load_motion_encoder_weights', True)

        if not load_weights:
            self.accelerator.print("Skipping MotionEncoder pre-trained weight loading")
            return

        self.accelerator.print(f"Loading MoCLIP model from: {moclip_path}")

        original_device = GLOBAL_CACHE.get("device", None)
        try:
            # Temporarily point MoCLIP at our device while it builds its model.
            GLOBAL_CACHE["device"] = self.device
            _init_clip_motion_model(moclip_path)

            pretrained = GLOBAL_CACHE.get("motion_encoder")
            if pretrained is not None:
                self.motion_encoder.load_state_dict(pretrained.state_dict())
                self.accelerator.print("Successfully loaded MoCLIP MotionEncoder pre-trained weights")
            else:
                self.accelerator.print("Warning: Failed to extract MotionEncoder from MoCLIP")
        except Exception as e:
            self.accelerator.print(f"Failed to load MotionEncoder weights: {e}")
            self.accelerator.print("Will use randomly initialized weights")
        finally:
            # Always drop the temporary MoCLIP models and restore the device,
            # so a failed load does not leave the shared cache modified.
            GLOBAL_CACHE["clip_model"] = None
            GLOBAL_CACHE["clip_tokenizer"] = None
            GLOBAL_CACHE["motion_encoder"] = None
            GLOBAL_CACHE["clip_motion_align_model"] = None
            if original_device is not None:
                GLOBAL_CACHE["device"] = original_device
            else:
                GLOBAL_CACHE.pop("device", None)

    @staticmethod
    def zero_grad(opt_list):
        """Call ``zero_grad()`` on every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    def clip_norm(self,network_list):
        """Clip each network's gradient norm independently to ``opt.clip_grad_norm``."""
        for network in network_list:
            self.accelerator.clip_grad_norm_(network.parameters(), self.opt.clip_grad_norm)

    @staticmethod
    def step(opt_list):
        """Call ``step()`` on every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    def forward(self, batch_data):
        """Run one forward-diffusion + denoising pass and cache tensors for the losses.

        Args:
            batch_data: ``(caption, motions, m_lens)``. ``motions`` is used as a
                3-D ``[B, T, D]`` tensor (see the ``permute(0, 2, 1)`` below).

        Sets ``src_mask``, ``current_m_lens``, ``x_start``, ``timesteps``,
        ``prediction``, ``z_fre``, ``z_tem``, ``prediction_dct``, ``target``,
        ``x_start_dct`` and ``dct_k_x0`` on ``self``.

        Raises:
            ValueError: if ``opt.prediction_type`` is not one of
                ``'sample'``, ``'epsilon'``, ``'v_prediction'``.
        """
        caption, motions, m_lens = batch_data
        x_start = motions.detach().float()
        B, T = x_start.shape[:2]

        cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
        self.src_mask = self.generate_src_mask(T, cur_len).to(x_start.device)

        # Kept for the temporal anchor loss (encoder input and valid lengths).
        self.current_m_lens = cur_len
        self.x_start = x_start

        # 1. Noise, 2. per-sample timestep, 3. forward diffusion q(x_t | x_0).
        real_noise = torch.randn_like(x_start)
        t = torch.randint(0, self.diffusion_steps, (B,), device=self.device)
        self.timesteps = t
        x_t = self.noise_scheduler.add_noise(x_start, real_noise, t)

        # 4. Network prediction plus the two anchor outputs.
        self.prediction, self.z_fre, self.z_tem = self.model(x_t, t, text=caption)

        # DCT along time (torch_dct.dct acts on the last dim, hence the permutes).
        # Not consumed by the losses in this class; kept as a public attribute.
        self.prediction_dct = torch_dct.dct(self.prediction.permute(0, 2, 1), norm='ortho').permute(0, 2, 1)

        prediction_type = self.opt.prediction_type
        if prediction_type == 'sample':
            self.target = x_start
        elif prediction_type == 'epsilon':
            self.target = real_noise
        elif prediction_type == 'v_prediction':
            self.target = self.noise_scheduler.get_velocity(x_start, real_noise, t)
        else:
            # Without this, self.target would be missing or stale from a previous step.
            raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

        # Frequency-anchor ground truth DCT_k(x_0): first FRE_ANCHOR_DIM
        # time-DCT coefficients of the clean motion, shape [B, k, D].
        self.x_start_dct = torch_dct.dct(x_start.permute(0, 2, 1), norm='ortho').permute(0, 2, 1)
        self.dct_k_x0 = self.x_start_dct[:, :self.FRE_ANCHOR_DIM, :]

    def masked_l2(self, a, b, mask, weights):
        """Return the weighted mean over the batch of per-sample masked MSE.

        Args:
            a, b: tensors compared element-wise, reduced over the last dim to (B, T).
            mask: (B, T) 0/1 mask of valid frames.
            weights: per-sample weights broadcast against the (B,) losses.

        A sample whose mask is entirely zero contributes 0 instead of NaN.
        """
        per_frame = self.mse_criterion(a, b).mean(dim=-1)  # (B, T)
        # clamp(min=1) avoids 0/0 -> NaN for fully-masked samples; it has no
        # effect whenever at least one frame is valid.
        valid_frames = mask.sum(-1).clamp(min=1)
        per_sample = (per_frame * mask).sum(-1) / valid_frames  # (B,)
        return (per_sample * weights).mean()

    def backward_G(self):
        """Compute ``self.loss`` and return detached per-term values for logging.

        ``self.loss = L_DDPM + zeta(n) * (lambda_fre * L_fre + lambda_tem * L_tem)``.
        Returned tensors are detached, so logging does not keep autograd graphs alive.
        """
        mse_loss_weights = torch.ones_like(self.timesteps)

        loss_rec = self.masked_l2(self.prediction, self.target, self.src_mask, mse_loss_weights)
        loss_fre = self._compute_frequency_anchor_loss()
        loss_tem = self._compute_temporal_anchor_loss()
        zeta_n = self._compute_cosine_annealing_factor()

        dal_loss = self.lambda_fre * loss_fre + self.lambda_tem * loss_tem
        self.loss = loss_rec + zeta_n * dal_loss

        loss_logs = OrderedDict()
        loss_logs['loss_mot_rec'] = loss_rec.detach()
        loss_logs['loss_fre'] = loss_fre.detach()
        loss_logs['loss_tem'] = loss_tem.detach()
        loss_logs['zeta_n'] = zeta_n  # annealing factor, logged for monitoring
        return loss_logs

    def _compute_frequency_anchor_loss(self):
        """Return ``L_fre = MSE(z_fre, DCT_k(x_0))``, or 0 when either side is unavailable."""
        z_fre = getattr(self, 'z_fre', None)
        dct_k_x0 = getattr(self, 'dct_k_x0', None)
        if z_fre is None or dct_k_x0 is None:
            return torch.tensor(0.0, device=self.device)
        return F.mse_loss(z_fre, dct_k_x0)

    def _compute_temporal_anchor_loss(self):
        """Return ``L_tem = 1 - mean cos(z_tem, f_tem(x_0))``, or 0 when disabled.

        Gradients flow into the encoder only when it is trainable. Errors are
        reported and yield 0, so a broken anchor does not stop training.
        """
        if not self.dual_anchor_enabled or self.motion_encoder is None:
            return torch.tensor(0.0, device=self.device)

        try:
            # Always encode the clean motion x_0, with the clipped valid lengths.
            with torch.set_grad_enabled(self._motion_encoder_trainable()):
                f_tem_x0 = self.motion_encoder(self.x_start, self.current_m_lens)

            cosine_sim = F.cosine_similarity(self.z_tem, f_tem_x0, dim=-1).mean()
            return 1.0 - cosine_sim

        except Exception as e:
            self.accelerator.print(f"Error computing temporal anchor loss: {e}")
            return torch.tensor(0.0, device=self.device)

    def _compute_cosine_annealing_factor(self):
        """Return ``zeta(n) = 0.5 * (1 + cos(pi * min(n / N, 1)))``.

        ``n`` is ``self.current_step`` and ``N`` is ``self.decay_threshold``.
        The value falls from 1 at ``n = 0`` to 0 at ``n >= N``. A non-positive
        ``N`` is treated as an already-finished schedule and returns 0.
        """
        import math

        n = self.current_step
        N = self.decay_threshold
        if N <= 0:
            return 0.0
        ratio = min(n / N, 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * ratio))

    def update(self):
        """Run one optimization step and return the logged loss terms."""
        self.zero_grad([self.optimizer])
        loss_logs = self.backward_G()
        self.accelerator.backward(self.loss)

        # Clip the main model, plus the motion encoder when it is being trained.
        models_to_clip = [self.model]
        if self._motion_encoder_trainable():
            models_to_clip.append(self.motion_encoder)

        self.clip_norm(models_to_clip)
        self.step([self.optimizer])

        return loss_logs

    def generate_src_mask(self, T, length):
        """Return a (B, T) mask with 1 for positions ``j < length[i]`` and 0 elsewhere.

        The result uses the default float dtype and lives on ``length``'s device
        (CPU for Python sequences).
        """
        length = torch.as_tensor(length)
        positions = torch.arange(T, device=length.device)
        return (positions.unsqueeze(0) < length.unsqueeze(1)).to(torch.get_default_dtype())

    def train_mode(self):
        """Put the model (and EMA, and a trainable motion encoder) in train mode."""
        self.model.train()
        if self.model_ema is not None:
            self.model_ema.train()
        # A frozen encoder stays in eval mode.
        if self._motion_encoder_trainable():
            self.motion_encoder.train()

    def eval_mode(self):
        """Put the model, EMA and motion encoder (if any) in eval mode."""
        self.model.eval()
        if self.model_ema is not None:
            self.model_ema.eval()
        if self.dual_anchor_enabled and self.motion_encoder is not None:
            self.motion_encoder.eval()

    def save(self, file_name,total_it):
        """Write optimizer, model, EMA and (if trainable) motion-encoder state to ``file_name``."""
        state = {
            'opt_encoder': self.optimizer.state_dict(),
            'total_it': total_it,
            'encoder': self.accelerator.unwrap_model(self.model).state_dict(),
        }
        if self.model_ema is not None:
            # Stores only the EMA wrapper's inner ``.module`` weights.
            # NOTE(review): any averaging counter kept on the wrapper itself is
            # not saved. Confirm resume behaviour for the EMA class in use.
            state["model_ema"] = self.accelerator.unwrap_model(self.model_ema).module.state_dict()

        if self._motion_encoder_trainable():
            state["motion_encoder"] = self.motion_encoder.state_dict()

        torch.save(state, file_name)

    def load(self, model_dir):
        """Restore state written by ``save``. Returns the stored iteration (default 0)."""
        checkpoint = torch.load(model_dir, map_location=self.device)
        self.optimizer.load_state_dict(checkpoint['opt_encoder'])
        if self.model_ema is not None:
            # ``save`` stores the wrapper's inner ``.module`` weights, so load
            # them into the same submodule. Loading into the wrapper itself
            # would fail strict key matching.
            ema = self.accelerator.unwrap_model(self.model_ema)
            getattr(ema, 'module', ema).load_state_dict(checkpoint["model_ema"], strict=True)
        self.model.load_state_dict(checkpoint['encoder'], strict=True)

        if self._motion_encoder_trainable() and "motion_encoder" in checkpoint:
            self.motion_encoder.load_state_dict(checkpoint["motion_encoder"], strict=True)
            self.accelerator.print("Successfully loaded motion encoder state")

        return checkpoint.get('total_it', 0)

    def train(self, train_loader):
        """Main training loop: optional resume, Accelerate prepare, epochs of updates.

        Handles EMA updates, periodic logging to stdout/TensorBoard, checkpointing
        to ``<model_dir>/latest.tar`` and LR decay every ``update_lr_steps``.
        """
        it = 0
        if self.opt.is_continue:
            model_path = pjoin(self.opt.model_dir, self.opt.continue_ckpt)
            it = self.load(model_path)
            self.current_step = it  # resume the annealing schedule where it stopped
            self.accelerator.print(f'Continue training from {it} iterations in {model_path}')
        start_time = time.time()

        logs = OrderedDict()
        self.dataset = train_loader.dataset
        self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema = \
            self.accelerator.prepare(self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema)

        num_epochs = (self.opt.num_train_steps - it) // len(train_loader) + 1
        self.accelerator.print(f'Need to train for {num_epochs} epochs...')

        for epoch in range(num_epochs):
            self.train_mode()
            for i, batch_data in enumerate(train_loader):
                self.forward(batch_data)
                log_dict = self.update()
                it += 1
                self.current_step = it  # drives zeta(n)

                if self.model_ema is not None and it % self.opt.model_ema_steps == 0:
                    self.accelerator.unwrap_model(self.model_ema).update_parameters(self.model)

                # Accumulate detached values out of place, so no returned
                # tensor is mutated and no autograd graph is retained.
                for k, v in log_dict.items():
                    if torch.is_tensor(v):
                        v = v.detach()
                    logs[k] = logs[k] + v if k in logs else v

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict(
                        (tag, value / self.opt.log_every) for tag, value in logs.items()
                    )
                    logs = OrderedDict()
                    print_current_loss(self.accelerator, start_time, it, mean_loss, epoch, inner_iter=i)
                    if self.accelerator.is_main_process:
                        self.writer.add_scalar("loss", mean_loss['loss_mot_rec'], it)
                        if 'loss_fre' in mean_loss:
                            self.writer.add_scalar("loss_fre", mean_loss['loss_fre'], it)
                        if 'loss_tem' in mean_loss:
                            self.writer.add_scalar("loss_tem", mean_loss['loss_tem'], it)
                        if 'zeta_n' in mean_loss:
                            self.writer.add_scalar("zeta_n", mean_loss['zeta_n'], it)
                    self.accelerator.wait_for_everyone()

                if it % self.opt.save_interval == 0:
                    if self.accelerator.is_main_process:
                        self.save(pjoin(self.opt.model_dir, 'latest.tar'), it)
                    # Barrier only on save steps. `it` is identical on every
                    # rank, so all processes reach this call together.
                    self.accelerator.wait_for_everyone()

                if (self.scheduler is not None) and (it % self.opt.update_lr_steps == 0):
                    self.scheduler.step()

        # Save the last checkpoint if it wasn't already saved.
        if it % self.opt.save_interval != 0 and self.accelerator.is_main_process:
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), it)

        if self.accelerator.is_main_process:
            self.writer.flush()  # make sure the final scalars reach disk
        self.accelerator.wait_for_everyone()
        self.accelerator.print('FINISH')

 