import torch
import copy
import torch.nn.functional as F
import torch.nn as nn
from utils.loss import  manifold_alignment_loss, individual_reconstruction_loss, manifold_prediction_loss, manifold_contrastive_loss_sameseries, ManifoldTopologyLoss


def get_param_groups_with_weight_decay(named_params, weight_decay: float):
    """Split trainable parameters into a decayed and a non-decayed optimizer group.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs. It is consumed
            in a single pass, so a generator is acceptable.
        weight_decay: Weight-decay value assigned to the first (decayed) group.

    Returns:
        A two-element list of param-group dicts. The first holds the decayed
        parameters with ``weight_decay``; the second holds the exempt
        parameters with ``weight_decay`` fixed at ``0.0``. Parameters whose
        ``requires_grad`` is false appear in neither group. Either list may
        be empty.
    """

    def _is_exempt(param_name):
        # Exempt: names ending in ".bias", or containing "norm" / "bn"
        # anywhere (case-insensitive substring match).
        lowered = param_name.lower()
        return param_name.endswith(".bias") or "norm" in lowered or "bn" in lowered

    # Keyed by "is exempt from decay"; input order is kept inside each bucket.
    buckets = {False: [], True: []}
    for param_name, tensor in named_params:
        if tensor.requires_grad:
            buckets[_is_exempt(param_name)].append(tensor)

    return [
        {"params": buckets[False], "weight_decay": weight_decay},
        {"params": buckets[True], "weight_decay": 0.0},
    ]

def compute_batch_smoothness(
    embeddings: torch.Tensor, 
    num_series: int, 
    batch_size_time: int
) -> float:
    """Return the mean L2 distance between consecutive time steps of a batch.

    The first axis of ``embeddings`` is treated as ``time_steps * num_series``
    rows laid out time-major (all series of step 0, then all series of step 1,
    ...). The tensor is reshaped to ``(time_steps, num_series, -1)``, adjacent
    time steps are subtracted, and the L2 norm over the last axis is averaged.

    Args:
        embeddings: Tensor whose first dimension indexes the stacked
            (time, series) rows; all remaining dimensions are flattened into
            the feature axis.
        num_series: Number of series stacked within each time step.
        batch_size_time: Number of time steps the caller expects. Kept for
            interface compatibility: the step count is derived from the row
            count, which equals this value whenever it is consistent with
            ``embeddings``.

    Returns:
        The smoothness value as a Python float, or ``0.0`` when the row count
        is not a multiple of ``num_series`` or fewer than two time steps are
        available.
    """
    total_rows = embeddings.shape[0]

    # Decide from the row count rather than letting the reshape succeed or
    # fail: a reshape only checks the total element count, so a mismatched
    # row count could silently fold feature values into the time/series axes.
    if total_rows % num_series != 0:
        return 0.0

    time_steps = total_rows // num_series
    if time_steps < 2:
        # Covers empty and single-step batches before any reshape is tried.
        return 0.0

    # reshape (not view) so non-contiguous inputs are handled as well.
    stacked = embeddings.reshape(time_steps, num_series, -1)

    step_diff = stacked[1:] - stacked[:-1]
    step_norm = torch.norm(step_diff, p=2, dim=-1)

    return step_norm.mean().item()


class TwoStage_Trainer:
    """Two-stage training driver for a model with global and individual embedders.

    Stage 1 (``train_global_phase``) optimizes ``model.global_embedder`` and
    ``model.global_predictor`` on a prediction loss. Stage 2
    (``train_individual_phase``) freezes the global embedder and optimizes the
    individual-side modules on a contrastive plus topology loss. Both stages
    early-stop on a smoothness metric measured on the training batches and
    restore the best-scoring weights at the end.

    NOTE(review): ``test_loader``, ``alignment_coeff`` and
    ``individual_reconstruction_coeff`` are stored but not read anywhere in
    this class.
    """

    def __init__(self, model, train_loader, test_loader, args, alignment_coeff=1.0, contrastive_coeff=1.0, reconstruction_coeff=1.0, topology_coeff=1.0, train_global_lr=1e-4, train_individual_lr=1e-4, patience = 10, device=None):
        """Store the model, data loaders, loss coefficients and optimizer settings.

        Args:
            model: Model exposing the sub-modules used by the two phases.
            train_loader: Training loader; ``train_loader.dataset.num_series``
                is read here.
            test_loader: Stored only.
            args: Namespace-like config. Read for ``device`` (when ``device``
                is None), the optional ``weight_decay_global`` /
                ``weight_decay_individual``, and the contrastive-loss settings.
            alignment_coeff: Stored only.
            contrastive_coeff: Weight of the contrastive loss in stage 2.
            reconstruction_coeff: Stored as ``individual_reconstruction_coeff``.
            topology_coeff: Weight of the topology loss in stage 2.
            train_global_lr: Adam learning rate for stage 1.
            train_individual_lr: Adam learning rate for stage 2.
            patience: Number of consecutive non-improving epochs tolerated
                before early stopping.
            device: Device to use as-is; when None, ``torch.device(args.device)``.
        """
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.alignment_coeff = alignment_coeff
        self.contrastive_coeff = contrastive_coeff
        self.individual_reconstruction_coeff = reconstruction_coeff
        self.topology_coeff = topology_coeff
        self.train_global_lr = train_global_lr
        self.train_individual_lr = train_individual_lr
        self.args = args
        self.num_series = train_loader.dataset.num_series
        self.patience = patience

        if device is None:
            self.device = torch.device(args.device)
        else:
            self.device = device

    def train_global_phase(self, num_epochs: int):
        """Run stage 1: train the global embedder and global predictor.

        Each batch adds Gaussian noise (std 0.05) to the global windows,
        embeds them, predicts the future targets and minimizes
        ``manifold_prediction_loss``. After every epoch the mean distance
        between consecutive embeddings in a batch is used as the
        early-stopping metric (lower is better); the best state dict is
        reloaded when training ends.

        Args:
            num_epochs: Maximum number of epochs to run.
        """
        print("=== stage-1 training global embedder ===")

        named_params = list(self.model.global_embedder.named_parameters()) + \
                       list(self.model.global_predictor.named_parameters())

        # Falls back to 1e-4 when args does not define weight_decay_global.
        weight_decay_global = getattr(self.args, "weight_decay_global", 1e-4)
        param_groups = get_param_groups_with_weight_decay(
            named_params,
            weight_decay=weight_decay_global
        )

        optimizer = torch.optim.Adam(
            param_groups,
            lr=self.train_global_lr
        )

        best_metric = float('inf')
        best_model_state = None
        counter = 0

        for epoch in range(num_epochs):
            self.model.train()
            total_loss = 0
            epoch_smoothness = 0

            for batch in self.train_loader:
                global_windows = batch.global_windows.to(self.device, non_blocking=True)
                global_future_targets = batch.global_future_targets.to(self.device, non_blocking=True)

                # Input-noise augmentation applied only on the embedder input.
                noise = torch.randn_like(global_windows) * 0.05
                global_embeddings = self.model.global_embedder(global_windows + noise)

                prediction_global = self.model.global_predictor(global_embeddings)

                global_prediction_loss = manifold_prediction_loss(prediction_global,global_future_targets)

                optimizer.zero_grad()
                global_prediction_loss.backward()
                optimizer.step()

                total_loss += global_prediction_loss.item()

                with torch.no_grad():
                    # Distance between adjacent rows of the batch; a batch
                    # with a single row contributes 0 to the epoch sum.
                    # NOTE(review): presumes adjacent rows are adjacent in
                    # time -- confirm against the loader's ordering.
                    if global_embeddings.size(0) > 1:
                        diff = global_embeddings[1:] - global_embeddings[:-1]
                        smoothness = torch.norm(diff, p=2, dim=1).mean().item()
                        epoch_smoothness += smoothness

            avg_loss = total_loss/len(self.train_loader)
            avg_smoothness = epoch_smoothness / len(self.train_loader)

            print(f'Global Phase - Epoch {epoch} completed. Average Loss: {avg_loss:.4f}, Smoothness={avg_smoothness:.4f}')

            # Early stopping tracks smoothness only; avg_loss is just logged.
            current_metric = avg_smoothness

            if current_metric < best_metric:
                best_metric = current_metric
                # Snapshot the whole model so the best epoch can be restored.
                best_model_state = copy.deepcopy(self.model.state_dict())
                counter = 0
            else:
                counter += 1

            if counter >= self.patience:
                print(f"Global Phase - Early stopping triggered at epoch {epoch}.")
                break

        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print("Global Phase - Loaded best model based on Smoothness-Regularized Metric.")

    def train_individual_phase(self, num_epochs: int):
        """Run stage 2: train the individual-side modules against a frozen global embedder.

        Freezes ``global_embedder`` and ``reconstructor``, then optimizes
        ``shared_embedder``, ``individual_reconstructor``, ``shared_mapper``,
        ``sequence_adapters`` and ``adapter_projector`` on
        ``contrastive_coeff * contrastive + topology_coeff * topology``.
        Early stopping uses the per-epoch mean of ``compute_batch_smoothness``
        on the individual embeddings (lower is better); the best state dict is
        reloaded when training ends.

        Args:
            num_epochs: Maximum number of epochs to run.
        """
        print("=== stage-2 training individual embedder ===")

        # The freeze is never undone in this method.
        for param in self.model.global_embedder.parameters():
            param.requires_grad = False
        for param in self.model.reconstructor.parameters():
            param.requires_grad = False

        indiv_named_params = []
        indiv_named_params += list(self.model.shared_embedder.named_parameters())
        indiv_named_params += list(self.model.individual_reconstructor.named_parameters())
        indiv_named_params += list(self.model.shared_mapper.named_parameters())
        # NOTE(review): passed as a single named tensor, so this assumes
        # sequence_adapters is one nn.Parameter rather than a module or a
        # ParameterList -- confirm against the model definition.
        indiv_named_params += [("sequence_adapters", self.model.sequence_adapters)]
        indiv_named_params += list(self.model.adapter_projector.named_parameters())

        # Falls back to 1e-4 when args does not define weight_decay_individual.
        weight_decay_indiv = getattr(self.args, "weight_decay_individual", 1e-4)
        param_groups = get_param_groups_with_weight_decay(
            indiv_named_params,
            weight_decay=weight_decay_indiv
        )

        optimizer = torch.optim.Adam(
            param_groups,
            lr=self.train_individual_lr
        )

        topology_loss = ManifoldTopologyLoss(temperature=0.1).to(self.device)

        best_metric = float('inf')
        # Must exist before the loop: with num_epochs == 0 the loop body never
        # runs and the final restore check would otherwise hit an unbound name.
        best_model_state = None
        counter = 0

        for epoch in range(num_epochs):
            # NOTE(review): this also puts the frozen global embedder in train
            # mode; if it holds dropout/batch-norm layers they stay active --
            # confirm this is intended.
            self.model.train()
            epoch_total_loss = 0
            epoch_smoothness = 0

            for batch in self.train_loader:

                global_windows = batch.global_windows.to(self.device, non_blocking=True)
                individual_windows = batch.individual_windows.to(self.device, non_blocking=True)
                series_ids = batch.series_ids.to(self.device, non_blocking=True)
                time_indices = batch.time_indices.to(self.device, non_blocking=True)

                batch_data = {
                    'global_windows': global_windows,
                    'individual_windows': individual_windows,
                    'series_ids': series_ids,
                    'time_indices': time_indices
                }

                outputs = self.model(batch_data)

                global_emb = outputs['global_embeddings']
                indiv_emb = outputs['individual_embeddings']
                topology_loss_value = topology_loss(global_emb, indiv_emb, num_series=self.num_series)

                # NOTE(review): computed but never added to total_loss_value
                # below, so self.alignment_coeff has no effect on training.
                # Either fold it into the objective or drop the call.
                manifold_alignment_loss_value = manifold_alignment_loss(
                    outputs['observed_embeddings'],
                    outputs['individual_embeddings']
                )
                manifold_contrastive_loss_value = manifold_contrastive_loss_sameseries(
                    outputs['individual_embeddings'],
                    series_ids,
                    time_indices,
                    pos_time_threshold=self.args.pos_time_threshold,
                    neg_time_threshold=self.args.neg_time_threshold,
                    temp=self.args.contrastive_temp
                )

                total_loss_value = self.contrastive_coeff * manifold_contrastive_loss_value +\
                    self.topology_coeff * topology_loss_value

                optimizer.zero_grad()
                total_loss_value.backward()
                optimizer.step()

                epoch_total_loss += total_loss_value.item()

                # The global embedding's first dimension is passed as the
                # number of time steps in this batch.
                current_batch_time_len = global_emb.shape[0]
                with torch.no_grad():
                    batch_smoothness = compute_batch_smoothness(
                        indiv_emb,
                        num_series=self.num_series,
                        batch_size_time=current_batch_time_len
                    )
                    epoch_smoothness += batch_smoothness

            avg_loss = epoch_total_loss/len(self.train_loader)
            avg_smoothness = epoch_smoothness /len(self.train_loader)
            print(f'Individual Phase - Epoch {epoch} completed. Average Loss: {avg_loss:.4f}, Smoothness={avg_smoothness:.4f}')

            # Early stopping tracks smoothness only; avg_loss is just logged.
            current_metric = avg_smoothness

            if current_metric < best_metric:
                best_metric = current_metric
                # Snapshot the whole model so the best epoch can be restored.
                best_model_state = copy.deepcopy(self.model.state_dict())
                counter = 0
            else:
                counter += 1

            if counter >= self.patience:
                print(f"Early stopping triggered at epoch {epoch}.")
                break

        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print("Loaded best model based on Smoothness-Regularized Metric.")