import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import logging
from quant_layers.linear import MinMaxQuantLinear
from torch.optim.lr_scheduler import CosineAnnealingLR


class HeadWrapper(nn.Module):
    """Final norm + pooling + classifier of a vision model, as one module.

    The wrapper holds references to the submodules of ``model`` (nothing is
    copied), so it can be appended to a partial forward pass.
    """

    def __init__(self, model):
        super().__init__()
        # ``model.norm`` wins; ``model.fc_norm`` is only consulted when
        # ``norm`` is missing or None.
        final_norm = getattr(model, 'norm', None)
        if final_norm is None and hasattr(model, 'fc_norm'):
            final_norm = model.fc_norm
        self.norm = final_norm

        self.head = model.head
        # A non-None ``head_dist`` selects the two-head code path in forward().
        self.is_deit = getattr(model, 'head_dist', None) is not None
        if self.is_deit:
            self.head_dist = model.head_dist
        self.global_pool = getattr(model, 'global_pool', 'token')

    def forward(self, x):
        """Map features ``x`` to logits.

        With two heads, token 0 goes through ``head`` and token 1 through
        ``head_dist`` and the results are averaged. With a plain ``nn.Linear``
        head the features are pooled first ('avg': mean over dim 1 for 3-D
        input or dims 1 and 2 for 4-D input; 'token': index 0 along dim 1;
        anything else: no pooling). Any other head type receives the
        (normalized) features unchanged.
        """
        if self.norm is not None:
            x = self.norm(x)

        if self.is_deit:
            cls_logits = self.head(x[:, 0])
            dist_logits = self.head_dist(x[:, 1])
            return (cls_logits + dist_logits) / 2.0

        # Non-linear heads are expected to do their own pooling.
        if not isinstance(self.head, nn.Linear):
            return self.head(x)

        pool = self.global_pool
        if pool == 'avg':
            rank = x.dim()
            if rank == 3:
                x = x.mean(dim=1)
            elif rank == 4:
                x = x.mean(dim=[1, 2])
        elif pool == 'token':
            x = x[:, 0]
        return self.head(x)


def _force_raw_mode(tail):
    """Set ``mode = 'raw'`` on every submodule of ``tail`` that has a ``mode`` attribute."""
    for sub in tail.modules():
        if hasattr(sub, 'mode'):
            sub.mode = 'raw'


def get_tail_network(model, block_name):
    """Assemble the layers of ``model`` that run after the block ``block_name``.

    Args:
        model: Network exposing ``blocks`` (called "ViT" in the error messages)
            or ``layers`` (called "Swin"), plus the attributes ``HeadWrapper``
            reads from it.
        block_name: Dotted module name. The following forms are checked in
            order: any name containing ``head``; a name containing ``blocks``
            whose last component is the block index (model must have
            ``blocks``); ``layers.<idx>`` optionally followed by
            ``.blocks.<idx>`` (model must have ``layers``); any name containing
            ``patch_embed``.

    Returns:
        ``nn.Identity`` for a head name, otherwise an ``nn.Sequential`` that
        ends with a ``HeadWrapper``. The container shares its modules with
        ``model``; in the ``layers`` and ``patch_embed`` branches every shared
        submodule that has a ``mode`` attribute is switched to ``'raw'``.

    Raises:
        ValueError: An index inside ``block_name`` cannot be parsed.
        NotImplementedError: ``block_name`` matches none of the known forms.
    """
    if 'head' in block_name:
        return nn.Identity()

    if hasattr(model, 'blocks') and 'blocks' in block_name:
        try:
            block_idx = int(block_name.split('.')[-1])
        except ValueError as err:
            raise ValueError(f"Cannot parse block index from ViT block name: {block_name}") from err

        tail_blocks = model.blocks[block_idx+1:]
        head = HeadWrapper(model)
        # NOTE(review): unlike the two branches below, this one does not reset
        # ``mode`` to 'raw' on the tail modules -- confirm that is intended.
        if isinstance(model.blocks, nn.ModuleList):
            # A sliced ModuleList has no forward(); unpack it into the Sequential.
            return nn.Sequential(*tail_blocks, head)
        return nn.Sequential(tail_blocks, head)

    if hasattr(model, 'layers') and 'layers' in block_name:
        parts = block_name.split('.')
        try:
            layer_idx = int(parts[1])
        except (IndexError, ValueError) as err:
            raise ValueError(f"Cannot parse layer index from Swin block name: {block_name}") from err

        layers_list = []

        if 'blocks' in parts:
            try:
                block_idx = int(parts[3])
            except (IndexError, ValueError) as err:
                raise ValueError(f"Cannot parse block index from Swin block name: {block_name}") from err

            stage = model.layers[layer_idx]

            # Rest of the current stage after the named block.
            remaining_blocks = stage.blocks[block_idx+1:]
            if len(remaining_blocks) > 0:
                layers_list.append(nn.Sequential(*remaining_blocks))

            if getattr(stage, 'downsample', None) is not None:
                # Registration order tells whether the stage downsamples after
                # its blocks; only then does the downsample belong to the tail.
                keys = list(stage._modules.keys())
                try:
                    is_post_downsample = keys.index('downsample') > keys.index('blocks')
                except ValueError:
                    is_post_downsample = True
                if is_post_downsample:
                    layers_list.append(stage.downsample)

        # All later stages run in full.
        layers_list.extend(model.layers[l_idx] for l_idx in range(layer_idx + 1, len(model.layers)))
        layers_list.append(HeadWrapper(model))
        tail = nn.Sequential(*layers_list)
        _force_raw_mode(tail)
        return tail

    if 'patch_embed' in block_name:
        layers_list = []
        if hasattr(model, 'layers'):
            layers_list.extend(model.layers)
        elif hasattr(model, 'blocks'):
            if isinstance(model.blocks, nn.ModuleList):
                # ModuleList cannot be called; add its children individually
                # (same handling as the ``blocks`` branch above).
                layers_list.extend(model.blocks)
            else:
                layers_list.append(model.blocks)

        layers_list.append(HeadWrapper(model))
        tail = nn.Sequential(*layers_list)
        _force_raw_mode(tail)
        return tail

    raise NotImplementedError(f"Unknown model structure or block name: {block_name}")


def get_act_stats(model, calib_loader, device='cuda', num_batches=20):
    """Collect the running max of ``|input|`` along the last dim for linear layers.

    A forward hook is attached to every ``nn.Linear`` / ``MinMaxQuantLinear``
    in ``model``; the model is put in eval mode and run without gradients on
    at most ``num_batches`` batches.

    Args:
        model: Module to profile. It is left in eval mode.
        calib_loader: Iterable yielding ``(data, target)`` pairs; only ``data``
            is used.
        device: Device each batch is moved to before the forward pass.
        num_batches: Maximum number of batches to run.

    Returns:
        Dict mapping module name (as given by ``named_modules``) to a CPU
        tensor holding, for each index of the input's last dimension, the
        largest absolute value observed.
    """
    act_stats = {}
    hooks = []

    def hook_fn(name):
        def _hook(module, inp, out):
            x = inp[0].detach()
            # Collapse all leading dims so the reduction below runs over
            # every position and leaves only the last dimension.
            if x.dim() > 2:
                x = x.flatten(0, -2)
            current_max = x.abs().amax(dim=0).cpu()
            if name in act_stats:
                current_max = torch.max(act_stats[name], current_max)
            act_stats[name] = current_max
        return _hook

    try:
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, MinMaxQuantLinear)):
                hooks.append(module.register_forward_hook(hook_fn(name)))
        model.eval()
        with torch.no_grad():
            for i, (data, _target) in enumerate(calib_loader):
                if i >= num_batches:
                    break
                model(data.to(device))
    finally:
        # Always detach the hooks, even if a forward pass fails, so the
        # model is not left with stale statistics collectors.
        for h in hooks:
            h.remove()
    return act_stats


def init_rectification_params(block, prefix, act_stats, alpha=0.5, device='cuda'):
    """Attach an initial ``rect_p`` parameter to each flagged layer in ``block``.

    For every submodule with a truthy ``is_reparam_layer`` attribute whose
    qualified name is present in ``act_stats``, the initial value is
    ``0.5 * log(act_max**alpha / weight_max**(1 - alpha))``, where
    ``weight_max`` is the max of ``|weight|`` over dim 0. The value is stored
    as ``module.rect_p`` (trainable) and a detached copy as ``module.p_init``.

    Args:
        block: Module whose submodules are scanned.
        prefix: Name of ``block`` in the full model; joined with the submodule
            name by ``.`` to look up ``act_stats``.
        act_stats: Dict of qualified module name -> activation max tensor.
        alpha: Exponent balancing activation and weight magnitudes.
        device: Kept for backward compatibility. Statistics are moved to the
            device of each layer's weight instead, so the new parameter always
            sits next to the weight it is combined with and the function also
            works when ``device`` is unavailable or differs from the weight's.
    """
    count = 0
    for name, module in block.named_modules():
        if not getattr(module, 'is_reparam_layer', False):
            continue

        # Join only the non-empty parts: "" + name, prefix + "", or both.
        full_name = ".".join(part for part in (prefix, name) if part)
        if full_name not in act_stats:
            continue

        weight = module.weight.data
        # The epsilon keeps the log/ratio finite for all-zero channels.
        weight_max = weight.abs().amax(dim=0) + 1e-6
        act_max = act_stats[full_name].to(weight.device) + 1e-6

        if act_max.shape != weight_max.shape:
            logging.warning(f"[HAR Mismatch] '{full_name}': Act {act_max.shape} != Weight {weight_max.shape}. Skipping.")
            continue

        numerator = act_max.pow(alpha)
        denominator = weight_max.pow(1 - alpha)
        init_val = torch.log(numerator / denominator) * 0.5

        module.rect_p = nn.Parameter(init_val)
        module.p_init = init_val.clone().detach()
        count += 1

    logging.info(f"[HAR] Initialized rectification params for {count} modules.")

def absorb_rectification(block):
    """Fold each learned ``rect_p`` into its layer and remove the parameter.

    For every submodule that has a tensor ``rect_p`` and a ``weight``:
    ``gamma = exp(rect_p)`` is multiplied into the weight along its last
    dimension (in place), and accumulated into the ``input_scaling_factor``
    buffer (created on first use). ``rect_p`` and ``p_init`` are then deleted
    from the module.

    Args:
        block: Module whose submodules are scanned and modified in place.
    """
    count = 0
    for name, module in block.named_modules():
        p = getattr(module, 'rect_p', None)
        # ``None`` fails the isinstance check too, so one test covers both.
        if not isinstance(p, torch.Tensor):
            continue
        if not hasattr(module, 'weight'):
            continue

        gamma = torch.exp(p).detach()
        module.weight.data.mul_(gamma.view(1, -1))

        current_factor = getattr(module, 'input_scaling_factor', None)
        if current_factor is None:
            # register_buffer() raises KeyError when the name already exists
            # as a plain attribute (e.g. a ``None`` placeholder), so drop
            # such a placeholder before registering the real buffer.
            if 'input_scaling_factor' in vars(module):
                delattr(module, 'input_scaling_factor')
            module.register_buffer('input_scaling_factor', gamma.clone())
        else:
            if gamma.device != current_factor.device:
                gamma = gamma.to(current_factor.device)
            current_factor.mul_(gamma)

        module.rect_p = None
        if hasattr(module, 'rect_p'):
            del module.rect_p
        if hasattr(module, 'p_init'):
            del module.p_init
        count += 1
    logging.info(f"[HAR] Absorbed rectification params for {count} modules.")


class RectificationDistiller:
    """Tunes the ``rect_p`` parameters of a block by distilling teacher logits.

    The block's output is passed through ``tail_network`` to obtain logits,
    which are matched to the teacher's with a temperature-scaled KL loss plus
    a quadratic penalty (weighted by ``config.rect_geo``) that keeps each
    ``rect_p`` close to its ``p_init`` snapshot.
    """

    def __init__(self, student_block, tail_network, config):
        """Collect every ``rect_p`` parameter in ``student_block`` and build the optimizer.

        Args:
            student_block: Module containing layers with an ``nn.Parameter``
                attribute ``rect_p``. Must own at least one parameter (its
                device is taken from the first one).
            tail_network: Module mapping the block's output to logits.
            config: Object exposing ``rect_geo``, the penalty weight.
        """
        self.student = student_block
        self.tail = tail_network
        self.config = config
        self.device = next(self.student.parameters()).device
        self.modules_with_p = []
        self.params = []

        for name, m in self.student.named_modules():
            rect_p = getattr(m, 'rect_p', None)
            if rect_p is not None and isinstance(rect_p, torch.nn.Parameter):
                self.modules_with_p.append(m)
                self.params.append(rect_p)

                # Layers initialized elsewhere already carry a snapshot;
                # otherwise anchor the penalty at the current value.
                if not hasattr(m, 'p_init'):
                    m.p_init = rect_p.detach().clone()

        self.optimizer = torch.optim.Adam(self.params, lr=5e-3)
        self.kl_loss_func = nn.KLDivLoss(reduction='batchmean')

    def _snapshot_params(self):
        """Return ``{module name: copy of rect_p}`` for every layer that has one."""
        return {
            name: m.rect_p.data.clone()
            for name, m in self.student.named_modules()
            if hasattr(m, 'rect_p') and isinstance(m.rect_p, torch.nn.Parameter)
        }

    def _objective(self, block_input, targets, temperature):
        """Return KL(student || teacher) at ``temperature`` plus the ``p_init`` penalty."""
        student_logits = self.tail(self.student(block_input))
        loss_kl = self.kl_loss_func(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(targets / temperature, dim=-1)
        ) * (temperature ** 2)

        loss_geo = 0.0
        for m in self.modules_with_p:
            p_init = m.p_init.to(self.device)
            loss_geo += self.config.rect_geo * (m.rect_p - p_init).pow(2).sum()
        return loss_kl + loss_geo

    def _recalibrate(self, input_batch):
        """Refresh quantizer state for the current ``rect_p`` values (no gradients)."""
        with torch.no_grad():
            for m in self.modules_with_p:
                gamma = torch.exp(m.rect_p).view(1, -1)
                w_rect = m.weight * gamma
                if hasattr(m, 'update_weight_quant_params'):
                    m.update_weight_quant_params(w_rect)

            # One forward pass with ``search_mode`` enabled on every module
            # that exposes the flag, then switch it back off.
            for m in self.student.modules():
                if hasattr(m, 'search_mode'):
                    m.search_mode = True
            self.student(input_batch)
            for m in self.student.modules():
                if hasattr(m, 'search_mode'):
                    m.search_mode = False

    def train(self, inputs, teacher_logits, batch_size=64, iters=None):
        """Optimize the ``rect_p`` parameters and keep the best-validating state.

        Args:
            inputs: Block inputs, indexed along dim 0 by sample.
            teacher_logits: Target logits aligned with ``inputs``.
            batch_size: Samples drawn (with replacement) per step; the fixed
                validation batch holds ``2 * batch_size`` samples.
            iters: Number of optimization steps. ``None`` keeps the default
                of 200.

        The parameters are left at the values that achieved the lowest
        validation loss (checked every 5 steps and on the last step).
        ``end_training()`` is always called on modules that received
        ``init_training()``, even if a step raises.
        """
        self.student.eval()
        self.tail.eval()

        total_iters = 200 if iters is None else int(iters)
        calib_interval = 10
        val_interval = 5
        temperature = 4.0

        for m in self.student.modules():
            if hasattr(m, 'training_mode'):
                m.init_training()

        try:
            scheduler = CosineAnnealingLR(
                self.optimizer,
                T_max=max(total_iters, 1),
                eta_min=1e-7
            )

            all_inputs = inputs.to(self.device)
            all_targets = teacher_logits.to(self.device)
            total_samples = all_inputs.size(0)

            # Fixed validation batch, drawn once before any training step.
            val_indices = torch.randint(0, total_samples, (2 * batch_size,), device=self.device)
            val_input = all_inputs[val_indices]
            val_target = all_targets[val_indices]

            best_loss = float('inf')
            best_p_state = self._snapshot_params()

            for i in tqdm(range(total_iters), desc="HAR Training", leave=False):
                indices = torch.randint(0, total_samples, (batch_size,), device=self.device)
                input_batch = all_inputs[indices]
                target_batch = all_targets[indices]

                if i % calib_interval == 0:
                    self._recalibrate(input_batch)

                self.optimizer.zero_grad()
                loss = self._objective(input_batch, target_batch, temperature)
                loss.backward()
                self.optimizer.step()
                scheduler.step()

                if i % val_interval == 0 or i == total_iters - 1:
                    with torch.no_grad():
                        val_loss = self._objective(val_input, val_target, temperature).item()
                    if val_loss < best_loss:
                        best_loss = val_loss
                        best_p_state = self._snapshot_params()

            # Roll back to the best-validating parameters.
            for name, m in self.student.named_modules():
                if hasattr(m, 'rect_p') and name in best_p_state:
                    m.rect_p.data.copy_(best_p_state[name])
        finally:
            for m in self.student.modules():
                if hasattr(m, 'training_mode'):
                    m.end_training()

