import os
import numpy as np
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import wandb

from tqdm import tqdm

from layer_registry import REPLACERS
from rep_sims import CKA, Procrustes
from modules import MemEffAttention
from dataloader import get_imagenet_dataloaders, cifar_dataset, get_text_dataloaders
from plotting import plot_cka
from resnet_utils import *
from modules import BasicBlockCompat
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from torch.cuda.amp import GradScaler, autocast

# Module-wide default device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def list_target_layers(model, target_type):
    '''
    Enumerate the submodules of ``model`` that match ``target_type``.

    A module matches when it is an instance of ``target_type`` or when its class
    name equals ``target_type.__name__`` (so a same-named class imported from a
    different path also counts).

    Returns:
        List of ``(position, qualified_name)`` pairs. ``position`` is the 1-based
        index of the module in ``model.named_modules()`` order, where the root
        module itself is position 1.
    '''
    return [
        (position, qual_name)
        for position, (qual_name, submodule) in enumerate(model.named_modules(), start=1)
        if isinstance(submodule, target_type)
        or submodule.__class__.__name__ == target_type.__name__
    ]

def extract_input_shapes(model, dummy_input, target_type):
    '''
    Record the per-sample input shape of every ``target_type`` layer in ``model``.

    One forward pass is run on ``dummy_input`` with temporary forward hooks on each
    matching module (an instance of ``target_type`` or a class with the same name).
    The shapes are used to size the linear replacement for each layer.

    Args:
        model: network to probe. It is switched to eval mode and left there.
        dummy_input: batch-first tensor; moved to the device of ``model``'s first parameter.
        target_type: module class whose inputs are recorded.

    Returns:
        List of shape tuples (batch dimension dropped), one per hook call, in the
        order the hooks fired during the forward pass.
    '''
    shapes = []

    def hook(module, inp, out):
        # inp is the tuple of positional forward() args. Everything except the
        # batch dim is kept, so inputs must be batch-first.
        shapes.append(tuple(inp[0].shape[1:]))

    hooks = [
        m.register_forward_hook(hook)
        for m in model.modules()
        if isinstance(m, target_type) or m.__class__.__name__ == target_type.__name__
    ]
    try:
        model_device = next(model.parameters()).device
        dummy_input = dummy_input.to(model_device)
        model.eval()
        with torch.no_grad():
            model(dummy_input)
    finally:
        # Always detach the probes, even if the forward pass raises, so the
        # model is not left carrying stale hooks.
        for h in hooks:
            h.remove()
    return shapes

def replace_nth_module(model, target_idx, input_shape, device, target_types, use_low_rank = True, rank = 1024, from_attn = False):
    '''
    Replace the ``target_idx``-th (1-based) matching layer of ``model`` in place.

    Children are visited depth-first in ``named_children`` order. Both original
    target layers and already-installed replacements (the ``REPLACERS`` values) are
    counted, so indices stay stable after earlier replacements. Only an original
    target layer can actually be swapped. A child that matches is not descended into.
    Layers are also matched by class name, which catches the same class imported
    through a different path.

    Args:
        model: root module, modified in place.
        target_idx: 1-based position of the layer to replace.
        input_shape: per-sample input shape passed to the replacer constructor.
        device: device passed to the replacer constructor.
        target_types: a module class or tuple of classes; each must be a key of REPLACERS.
        use_low_rank, rank: passed to the replacer constructor.
        from_attn: if True, build the replacement with
            ``replacer.from_attn(child, target_idx)`` and cast it to the child's dtype.

    Returns:
        ``model`` (the same object).

    Raises:
        IndexError: fewer than ``target_idx`` matching layers were found.
    '''
    counter = {'i': 0}
    if not isinstance(target_types, tuple):
        target_types = (target_types,)
    replacers = tuple([REPLACERS[t] for t in target_types])
    target_names = tuple(x.__name__ for x in target_types)

    def build_replacement(replacement, child):
        if not from_attn:
            return replacement(child, input_shape, device, low_rank = use_low_rank, rank = rank)
        new_module = replacement.from_attn(child, target_idx)
        # Generic nn.Modules have no ``dtype`` attribute; fall back to the child's parameters.
        dtype = getattr(child, 'dtype', None)
        if dtype is None:
            first_param = next(child.parameters(), None)
            dtype = first_param.dtype if first_param is not None else None
        if dtype is not None:
            new_module = new_module.to(dtype = dtype)
        return new_module

    def recurse(mod):
        for name, child in list(mod.named_children()):
            if isinstance(child, (target_types, replacers)):
                counter['i'] += 1
                if counter['i'] == target_idx and isinstance(child, target_types):
                    # Prefer an exact-type entry; subclasses of a target type
                    # (e.g. NonDynamicallyQuantizableLinear) use their base type's replacer.
                    if type(child) in REPLACERS:
                        key = type(child)
                    else:
                        key = next(t for t in target_types if isinstance(child, t))
                    setattr(mod, name, build_replacement(REPLACERS[key], child))
                    print(f'Replaced {type(child)} #{target_idx} at {mod}.{name}')
                    return True
            elif child.__class__.__name__ in target_names:
                counter['i'] += 1
                if counter['i'] == target_idx:
                    key = next(x for x in target_types if x.__name__ == child.__class__.__name__)
                    setattr(mod, name, build_replacement(REPLACERS[key], child))
                    print(f'Replaced {type(child)} #{target_idx} at {mod}.{name}')
                    return True
            elif recurse(child):
                return True
        return False

    if not recurse(model):
        prt_counter = counter['i']
        raise IndexError(f'Only {prt_counter} {target_types} layers; cannot replace #{target_idx} (counting {target_types}+{replacers})')
    return model

def get_nth_module(model, n, target_type):
    '''
    Return the ``n``-th (1-based) module of ``model`` that is a target layer or its replacement.

    Modules are scanned in ``model.modules()`` order, which includes the root and
    nested children. A module counts if it is an instance of ``target_type``, an
    instance of the corresponding ``REPLACERS`` value, or has the same class name as
    a target type.

    Args:
        model: module to search.
        n: 1-based position among matching modules.
        target_type: a module class or tuple of classes; each must be a key of REPLACERS.

    Raises:
        IndexError: fewer than ``n`` matching modules exist.
    '''
    if not isinstance(target_type, tuple):
        target_type = (target_type,)
    replacers = tuple(REPLACERS[t] for t in target_type)
    match_types = (target_type, replacers)
    target_names = {t.__name__ for t in target_type}
    count = 0
    for module in model.modules():
        if isinstance(module, match_types) or module.__class__.__name__ in target_names:
            count += 1
            if count == n:
                return module
    raise IndexError(f'Only {count} {target_type}; cannot find {target_type} #{n}.')

def get_norm_after_module(model, n, target_type):
    '''
    Return the first BatchNorm2d/LayerNorm that follows the ``n``-th target layer.

    Modules are scanned in ``model.modules()`` order. A module counts as a target
    if it is an instance of ``target_type``, an instance of the corresponding
    ``REPLACERS`` value, or has the same class name as a target type. Once the
    ``n``-th (1-based) target is found, the first later module that is a
    ``nn.BatchNorm2d`` or ``nn.LayerNorm`` is returned. If there is none, the
    target module itself is returned.

    Raises:
        IndexError: fewer than ``n`` matching modules exist.
    '''
    target_types = target_type if isinstance(target_type, tuple) else (target_type,)
    replacer_types = tuple([REPLACERS[t] for t in target_types])
    target_names = tuple(t.__name__ for t in target_types)
    flat = list(model.modules())
    hits = 0
    for pos, candidate in enumerate(flat):
        is_match = (isinstance(candidate, (target_types, replacer_types))
                    or candidate.__class__.__name__ in target_names)
        if not is_match:
            continue
        hits += 1
        if hits != n:
            continue
        return next(
            (later for later in flat[pos + 1:] if isinstance(later, (nn.BatchNorm2d, nn.LayerNorm))),
            candidate,
        )
    raise IndexError(f'Only {hits} {target_types}; cannot find {target_types} #{n}.')

def match_single_layer(args, index, orig_model, input_shapes, train_loader, device, target_type = nn.Conv2d, repdist = 'CKA', low_rank=True, rank=1024, epochs=5, lr=5e-3,
                       distillation = False, temperature = 4, alpha = 0.5, from_attn = False):
    '''
    Replace one layer in a copy of ``orig_model`` and train only that replacement so
    its output matches the corresponding layer of ``orig_model``.

    The ``index``-th (1-based) ``target_type`` layer of a deep copy is swapped using
    ``replace_nth_module``. Every parameter except those of the replacement is frozen.
    Forward hooks capture both layers' outputs on each batch.

    Args:
        args: run config. Reads ``setting`` ('imagenet' | 'cifar' | 'text') and ``use_amp``.
        index: 1-based position of the layer to replace.
        orig_model: reference model; put in eval mode and never updated.
        input_shapes: per-layer input shapes; ``input_shapes[index - 1]`` sizes the replacement.
        train_loader: training batches. ImageNet batches are read as ``batch[0]['images']``
            (NHWC, permuted to NCHW), CIFAR batches as ``(inputs, labels)``, and text
            batches as ``batch['input_ids']``.
        device: device for the copy and the inputs.
        target_type: layer class to replace (must be a key of REPLACERS).
        repdist: 'CKA', 'KNN' or 'Procrustes' computes 1 - similarity of the per-sample
            flattened, L2-normalised outputs. 'MSE' computes the mean squared error of
            the raw outputs.
        low_rank, rank, from_attn: passed to ``replace_nth_module``.
        epochs, lr: number of passes over ``train_loader`` and the AdamW learning rate.
        distillation: if True, mix in a temperature-scaled KL term between the two
            models' outputs: ``alpha * rep_loss + (1 - alpha) * T**2 * KL``.
        temperature, alpha: distillation temperature and mixing weight.

    Returns:
        ``(state_dict of the trained replacement module, list of per-step loss values)``.
    '''
    replaced_model = copy.deepcopy(orig_model).to(device)
    input_shape = input_shapes[index - 1]
    # target_type must be passed: replace_nth_module has no default for target_types.
    replace_nth_module(replaced_model, index, input_shape, device, target_type,
                       use_low_rank=low_rank, rank=rank, from_attn=from_attn)
    replaced_model = replaced_model.to(device)

    module_o = get_nth_module(orig_model, index, target_type)
    module_l = get_nth_module(replaced_model, index, target_type)
    activations_o = {}
    activations_l = {}
    # setdefault keeps the first output if a module runs more than once per forward pass.
    h_o = module_o.register_forward_hook(lambda m, i, o: activations_o.setdefault('feat', o))
    h_l = module_l.register_forward_hook(lambda m, i, o: activations_l.setdefault('feat', o))

    for p in replaced_model.parameters():
        p.requires_grad = False
    for p in module_l.parameters():
        p.requires_grad = True
    trainable = [p for p in replaced_model.parameters() if p.requires_grad]

    optimizer = optim.AdamW(trainable, lr=lr)
    cka = CKA(device)
    proc = Procrustes(device)

    replaced_model.train()
    orig_model.eval()
    scaler = GradScaler(enabled=args.use_amp)
    cka_losses = []
    try:
        for epoch in range(epochs):
            epoch_cka_losses = []
            for i, batch in enumerate(tqdm(train_loader, desc = f'Layer {index} CKA Epoch {epoch}', dynamic_ncols = True)):
                if args.setting == 'imagenet':
                    inputs = batch[0]['images'].permute(0, 3, 1, 2).to(device)
                elif args.setting == 'cifar':
                    inputs, _ = batch
                    inputs = inputs.to(device)
                elif args.setting == 'text':
                    inputs = {'input_ids': batch['input_ids'].to(device)}
                else:
                    raise NotImplementedError()
                activations_o.clear(); activations_l.clear()
                # Autocast follows args.use_amp, matching the GradScaler above.
                with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = args.use_amp):
                    if isinstance(inputs, dict):
                        with torch.no_grad():
                            orig_output = orig_model(**inputs, use_cache = False)
                        new_output = replaced_model(**inputs, use_cache = False)
                        inputs = inputs['input_ids']
                    else:
                        with torch.no_grad():
                            orig_output = orig_model(inputs)
                        new_output = replaced_model(inputs)

                    feat_o = activations_o['feat']
                    feat_l = activations_l['feat']
                    # Some modules (e.g. attention blocks) return tuples; compare the first element.
                    if isinstance(feat_o, tuple):
                        feat_o = feat_o[0]
                    if isinstance(feat_l, tuple):
                        feat_l = feat_l[0]

                    if repdist == 'MSE':
                        loss = F.mse_loss(feat_o, feat_l)
                    else:
                        B = inputs.size(0)
                        fo_flat = feat_o.reshape(B, -1)
                        fo_flat = fo_flat / (fo_flat.norm(dim = 1, keepdim = True) + 1e-8)
                        fl_flat = feat_l.reshape(B, -1)
                        fl_flat = fl_flat / (fl_flat.norm(dim = 1, keepdim = True) + 1e-8)
                        if repdist == 'CKA':
                            loss = 1.0 - cka.linear_CKA(fo_flat, fl_flat)
                        elif repdist == 'KNN':
                            # NOTE(review): `knn` is not defined or imported in the visible source.
                            loss = 1.0 - knn.soft_knn_alignment_topk(fo_flat, fl_flat)
                        elif repdist == 'Procrustes':
                            loss = 1.0 - proc.orthogonal_procrustes_similarity(fo_flat, fl_flat)
                        else:
                            raise NotImplementedError()

                    if distillation:
                        # Output objects exposing `.logits` are unwrapped; plain tensors are used directly.
                        orig_logits = getattr(orig_output, 'logits', orig_output)
                        new_logits = getattr(new_output, 'logits', new_output)
                        kl_loss_fn = nn.KLDivLoss(reduction = 'batchmean', log_target = True)
                        # Softmax over the last axis; identical to dim=1 for (B, C) logits.
                        orig_log_probs = F.log_softmax(orig_logits.detach() / temperature, dim = -1)
                        new_log_probs = F.log_softmax(new_logits / temperature, dim = -1)
                        kl = kl_loss_fn(new_log_probs, orig_log_probs) * (temperature ** 2)
                        loss = loss * alpha + (1 - alpha) * kl
                loss_value = loss.item()
                if i == 0 and epoch == 0:
                    print('Starting Similarity...', loss_value)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients
                # (this is a no-op when the scaler is disabled).
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                cka_losses.append(loss_value)
                epoch_cka_losses.append(loss_value)
            avg_train_loss = np.mean(epoch_cka_losses)
            print(f'Epoch {epoch}, Avg Loss {avg_train_loss}')
    finally:
        h_o.remove()
        h_l.remove()

    return module_l.state_dict(), cka_losses

def match_all_layers(args, indices, orig, replaced_model, input_shapes, target_type, train_loader, device,
                     lr = 1e-4, num_epochs = 10, cka_scheduler = False, token_avg = True):
    '''
    Jointly train the layers at ``indices`` of ``replaced_model`` so their outputs match ``orig``.

    For every index, a forward hook is placed on the matching module of both models.
    That module is the ``idx``-th target/replacement layer, or, when
    ``args.replace_norm`` is set, the first BatchNorm2d/LayerNorm after it. Only the
    ``idx``-th target/replacement layers of ``replaced_model`` are trainable.

    For each layer, the outputs are optionally averaged over dim 1 (``token_avg``),
    flattened per sample, L2-normalised, and scored as ``1 - similarity`` using
    ``args.repdist`` ('CKA', 'KNN', otherwise Procrustes). The layer losses are
    averaged, or weighted when ``args.layer_weight`` is set. Before the backward pass,
    ``args.local_loss`` adds the per-layer MSE of the normalised features, and
    ``args.task_loss`` adds cross-entropy against the batch labels (imagenet/cifar only).

    Args:
        args: run config. Reads ``setting``, ``replace_norm``, ``repdist``, ``layer_weight``,
            ``weighting_strategy``, ``task_loss``, ``local_loss`` and ``use_amp``.
        indices: 1-based layer positions to match jointly.
        orig: reference model; put in eval mode and never updated.
        replaced_model: model whose layers are trained (modified in place).
        input_shapes: unused; kept for interface compatibility.
        target_type: layer class used to locate the layers.
        train_loader, device: data and device.
        lr, num_epochs: AdamW learning rate and number of passes over ``train_loader``.
        cka_scheduler: if True, use linear warmup followed by cosine annealing, stepped per batch.
        token_avg: average features over dim 1 before flattening.

    Returns:
        ``(replaced_model, layer_losses)``, where ``layer_losses['Layer_<idx>']`` lists the
        per-step representation loss of that layer.
    '''
    activations_o, activations_l = {}, {}
    hooks = []
    for idx in indices:
        if args.replace_norm:
            module_o = get_norm_after_module(orig, idx, target_type)
            module_l = get_norm_after_module(replaced_model, idx, target_type)
        else:
            module_o = get_nth_module(orig, idx, target_type)
            module_l = get_nth_module(replaced_model, idx, target_type)
        # Capture the output tensor of this idx (idx bound as a default to avoid late binding).
        hooks.append(module_o.register_forward_hook(lambda m, i, o, idx=idx: activations_o.__setitem__(idx, o)))
        hooks.append(module_l.register_forward_hook(lambda m, i, o, idx=idx: activations_l.__setitem__(idx, o)))

    for p in replaced_model.parameters():
        p.requires_grad = False
    for idx in indices:
        # Even with replace_norm, only the idx-th target/replacement layer is trained.
        for p in get_nth_module(replaced_model, idx, target_type).parameters():
            p.requires_grad = True
    trainable = [p for p in replaced_model.parameters() if p.requires_grad]

    optimizer = optim.AdamW(trainable, lr = lr, weight_decay = 0.0)
    scheduler = None
    if cka_scheduler:
        steps_per_epoch = len(train_loader)
        total_steps = num_epochs * steps_per_epoch
        if num_epochs > 5:
            warmup_epochs = 5
            print(f'Building scheduler with {warmup_epochs} epochs of warmup')
            warmup_steps = warmup_epochs * steps_per_epoch
        else:
            warmup_ratio = 0.03
            print(f'Building scheduler with {warmup_ratio:.0%} steps of warmup')
            warmup_steps = int(warmup_ratio * total_steps)
        # Avoid zero-length phases on tiny runs (CosineAnnealingLR with T_max=0 divides by zero).
        warmup_steps = max(1, warmup_steps)
        cosine_steps = max(1, total_steps - warmup_steps)
        sched_warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor = 0.1, end_factor = 1.0, total_iters = warmup_steps)
        sched_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = cosine_steps)
        scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers = [sched_warmup, sched_cosine], milestones = [warmup_steps])
    cka = CKA(device)
    proc = Procrustes(device)
    layer_losses = {}
    scaler = GradScaler(enabled=args.use_amp)
    wandb_indices = '_'.join(str(x) for x in indices)

    replaced_model.train()
    orig.eval()
    try:
        for epoch in range(num_epochs):
            epoch_losses = []
            for batch in tqdm(train_loader, desc = f'Joint CKA Epoch {epoch}', dynamic_ncols = True):
                labels = None
                if args.setting == 'imagenet':
                    inputs = batch[0]['images'].permute(0, 3, 1, 2).to(device)
                    labels = batch[0].get('labels', None)
                    if labels is not None:
                        labels = labels.squeeze().long().to(device)
                elif args.setting == 'cifar':
                    inputs, labels = batch
                    inputs = inputs.to(device)
                    if labels is not None:
                        labels = labels.to(device)
                elif args.setting == 'text':
                    inputs = {'input_ids': batch['input_ids'].to(device)}
                else:
                    raise NotImplementedError()
                activations_o.clear(); activations_l.clear()
                # Autocast follows args.use_amp, matching the GradScaler above.
                with torch.autocast(device_type = 'cuda', dtype = torch.float16, enabled = args.use_amp):
                    if isinstance(inputs, dict):
                        with torch.no_grad():
                            orig(**inputs, use_cache = False)
                        output = replaced_model(**inputs, use_cache = False)
                        inputs = inputs['input_ids']
                    else:
                        with torch.no_grad():
                            orig(inputs)
                        output = replaced_model(inputs)

                    batch_size = inputs.size(0)
                    per_layer_loss = []
                    local_mse = 0.0
                    for idx in indices:
                        feat_o = activations_o[idx]
                        feat_l = activations_l[idx]
                        # Some modules (e.g. attention blocks) return tuples; compare the first element.
                        if isinstance(feat_o, tuple):
                            feat_o = feat_o[0]
                        if isinstance(feat_l, tuple):
                            feat_l = feat_l[0]
                        if token_avg:
                            fo = torch.mean(feat_o, dim = 1).reshape(batch_size, -1)
                            fl = torch.mean(feat_l, dim = 1).reshape(batch_size, -1)
                        else:
                            fo = feat_o.reshape(batch_size, -1)
                            fl = feat_l.reshape(batch_size, -1)
                        fo = fo / (fo.norm(dim=1, keepdim=True) + 1e-8)
                        fl = fl / (fl.norm(dim=1, keepdim=True) + 1e-8)
                        if args.repdist == 'CKA':
                            rep_sim_loss = 1.0 - cka.linear_CKA(fo, fl)
                        elif args.repdist == 'KNN':
                            # NOTE(review): `knn` is not defined or imported in the visible source.
                            rep_sim_loss = 1.0 - knn.soft_knn_alignment_topk(fo, fl)
                        else:
                            rep_sim_loss = 1.0 - proc.orthogonal_procrustes_similarity(fo, fl)
                        per_layer_loss.append(rep_sim_loss)
                        layer_losses.setdefault(f'Layer_{idx}', []).append(rep_sim_loss.item())
                        if args.local_loss:
                            local_mse = local_mse + F.mse_loss(fo, fl)
                    per_layer_loss = torch.stack(per_layer_loss, dim = 0)

                    weights = None
                    if args.layer_weight:
                        if args.weighting_strategy == 'complicated':
                            # Detached weights: proportional to each layer's loss, mixed with a uniform floor.
                            with torch.no_grad():
                                Lvec = per_layer_loss.detach()
                                rho = 0.20
                                w_prop = Lvec / (Lvec.sum() + 1e-12)
                                weights = (1.0 - rho) * w_prop + rho * (1.0 / Lvec.numel())
                        else:
                            # Loss-proportional weights, NOT detached: gradients also flow through them.
                            weights = per_layer_loss / (torch.sum(per_layer_loss) + 1e-8)
                        loss = torch.sum(weights * per_layer_loss)
                    else:
                        loss = torch.mean(per_layer_loss)

                    # Auxiliary terms are added before backward() so they actually train the model.
                    if args.local_loss:
                        loss = loss + local_mse
                    if args.task_loss and labels is not None:
                        loss = loss + F.cross_entropy(output, labels)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(trainable, max_norm=0.5)
                scaler.step(optimizer)
                scaler.update()
                if scheduler:
                    scheduler.step()

                loss_value = loss.item()
                latest = [layer_losses[f'Layer_{wan_idx}'][-1] for wan_idx in indices]
                # One wandb.log call per training step.
                metrics = {f'{wandb_indices}_total_loss': loss_value}
                for i, wan_idx in enumerate(indices):
                    metrics[f'{wandb_indices}_per_layer_loss/{wan_idx}'] = latest[i]
                    if weights is not None:
                        metrics[f'{wandb_indices}_per_layer_weight/{wan_idx}'] = weights[i].item()
                metrics[f'{wandb_indices}_unweighted_avg_loss'] = float(np.mean(latest))
                wandb.log(metrics)
                epoch_losses.append(loss_value)
            epoch_layer_losses = {layer: np.mean(values[-len(train_loader):]) for layer, values in layer_losses.items()}
            print(f'Epoch {epoch}, layer losses: {epoch_layer_losses}')
            print(f'Epoch {epoch} avg loss: {np.mean(epoch_losses):.4f}')
    finally:
        for h in hooks:
            h.remove()

    return replaced_model, layer_losses

def match_all_groups(args, group_ids, orig, replaced_model, orig_hook_modules, repl_hook_modules,
                     train_loader, device, lr=1e-4, num_epochs=10, cka_scheduler=False,
                     extra_trainable_modules = False):
    '''
    Jointly train the replacement modules of ``group_ids`` so their outputs match ``orig``.

    For each group id ``gid``, forward hooks capture the outputs of
    ``orig_hook_modules[gid]`` and ``repl_hook_modules[gid]``. All parameters of
    ``replaced_model`` are frozen except those of the hooked replacement modules and of
    ``extra_trainable_modules``.

    For each group, the outputs are flattened per sample (and passed through ReLU when
    ``args.post_activation`` is set) and scored as ``1 - similarity`` (CKA when
    ``args.repdist == 'CKA'``, otherwise Procrustes). The group losses are averaged, or
    weighted when ``args.layer_weight`` is set. ``args.local_loss`` adds the summed
    per-group MSE, and ``args.task_loss`` adds cross-entropy against the labels.

    Args:
        args: run config. Reads ``setting`` ('imagenet' | 'cifar'), ``post_activation``,
            ``repdist``, ``local_loss``, ``layer_weight``, ``weighting_strategy`` and ``task_loss``.
        group_ids: keys/indices into the two hook-module containers.
        orig: reference model; put in eval mode and never updated.
        replaced_model: model being trained (modified in place).
        orig_hook_modules, repl_hook_modules: per-group modules whose outputs are compared.
        train_loader, device: data and device.
        lr, num_epochs: AdamW learning rate and number of passes over ``train_loader``.
        cka_scheduler: if True, use linear warmup followed by cosine annealing, stepped per batch.
        extra_trainable_modules: optional iterable of extra modules to unfreeze.

    Returns:
        ``(replaced_model, group_losses)``, where ``group_losses['Group_<gid>']`` lists the
        per-step representation loss of that group.
    '''
    activations_o, activations_l = {}, {}
    hooks = []
    for gid in group_ids:
        mo = orig_hook_modules[gid]
        ml = repl_hook_modules[gid]
        # gid is bound as a default to avoid late binding in the closures.
        hooks.append(mo.register_forward_hook(lambda m, i, o, gid=gid: activations_o.__setitem__(gid, o)))
        hooks.append(ml.register_forward_hook(lambda m, i, o, gid=gid: activations_l.__setitem__(gid, o)))

    for p in replaced_model.parameters():
        p.requires_grad = False
    for gid in group_ids:
        for p in repl_hook_modules[gid].parameters():
            p.requires_grad = True

    if extra_trainable_modules:
        for mod in extra_trainable_modules:
            for p in mod.parameters():
                p.requires_grad = True

    optimizer = optim.AdamW([p for p in replaced_model.parameters() if p.requires_grad],
                            lr=lr, weight_decay=0.0)

    scheduler = None
    if cka_scheduler:
        warmup_epochs = 5
        steps_per_epoch = len(train_loader)
        total_steps = num_epochs * steps_per_epoch
        warmup_steps = min(warmup_epochs * steps_per_epoch, max(1, total_steps // 3))
        cosine_steps = max(1, total_steps - warmup_steps)
        sched_warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
        sched_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_steps)
        scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[sched_warmup, sched_cosine],
                                                    milestones=[warmup_steps])

    cka = CKA(device)
    proc = Procrustes(device)
    group_losses = {}
    wandb_indices = '_'.join(str(x) for x in group_ids)
    wandb_warned = False
    replaced_model.train()
    orig.eval()

    try:
        for epoch in range(num_epochs):
            epoch_losses = []
            for batch in tqdm(train_loader, desc=f'Joint CKA (groups) Epoch {epoch}', dynamic_ncols=True):
                if args.setting == 'imagenet':
                    inputs = batch[0]['images'].permute(0, 3, 1, 2).to(device)
                    labels = batch[0].get('labels', None)
                    if labels is not None:
                        labels = labels.squeeze().long().to(device)
                elif args.setting == 'cifar':
                    inputs, labels = batch
                    inputs = inputs.to(device)
                    if labels is not None:
                        labels = labels.to(device)
                else:
                    raise NotImplementedError()

                with torch.no_grad():
                    orig(inputs)
                output = replaced_model(inputs)

                per_group_loss = []
                local_mse = 0.0
                for gid in group_ids:
                    fo = activations_o[gid].reshape(inputs.size(0), -1)
                    fl = activations_l[gid].reshape(inputs.size(0), -1)
                    if args.post_activation:
                        fo = F.relu(fo)
                        fl = F.relu(fl)
                    if args.repdist == 'CKA':
                        rep_sim_loss = 1.0 - cka.linear_CKA(fo, fl)
                    else:
                        rep_sim_loss = 1.0 - proc.orthogonal_procrustes_similarity(fo, fl)
                    group_losses.setdefault(f'Group_{gid}', []).append(rep_sim_loss.item())
                    per_group_loss.append(rep_sim_loss)
                    if args.local_loss:
                        # Kept separate so the weighting below does not overwrite it.
                        local_mse = local_mse + F.mse_loss(fo, fl)
                per_group_loss = torch.stack(per_group_loss, dim=0)

                weights = None
                if args.layer_weight:
                    if args.weighting_strategy == 'complicated':
                        # Detached weights: proportional to each group's loss, mixed with a uniform floor.
                        with torch.no_grad():
                            Lvec = per_group_loss.detach()
                            rho = 0.20
                            w_prop = Lvec / (Lvec.sum() + 1e-12)
                            weights = (1.0 - rho) * w_prop + rho * (1.0 / Lvec.numel())
                    else:
                        # Loss-proportional weights, NOT detached: gradients also flow through them.
                        weights = per_group_loss / (torch.sum(per_group_loss) + 1e-8)
                    loss = torch.sum(weights * per_group_loss)
                else:
                    loss = torch.mean(per_group_loss)

                if args.local_loss:
                    loss = loss + local_mse
                if args.task_loss and labels is not None:
                    loss = loss + F.cross_entropy(output, labels)

                optimizer.zero_grad()
                loss.backward()
                for gid in group_ids:
                    nn.utils.clip_grad_norm_(repl_hook_modules[gid].parameters(), max_norm=1.0)
                optimizer.step()
                if scheduler:
                    scheduler.step()

                loss_value = loss.item()
                latest = [group_losses[f'Group_{gid}'][-1] for gid in group_ids]
                metrics = {f'{wandb_indices}_total_loss': loss_value}
                for i, wan_idx in enumerate(group_ids):
                    metrics[f'{wandb_indices}_per_group_loss/{wan_idx}'] = latest[i]
                    if weights is not None:
                        metrics[f'{wandb_indices}_per_layer_weight/{wan_idx}'] = weights[i].item()
                metrics[f'{wandb_indices}_unweighted_avg_loss'] = float(np.mean(latest))
                try:
                    wandb.log(metrics)
                except Exception as err:
                    # Logging is best-effort (e.g. wandb may not be initialised for this run),
                    # but report it once instead of failing silently.
                    if not wandb_warned:
                        print(f'wandb.log failed; continuing without logging: {err}')
                        wandb_warned = True

                epoch_losses.append(loss_value)
            print(f'Epoch {epoch} avg loss: {np.mean(epoch_losses):.4f}')
    finally:
        for h in hooks:
            h.remove()
    return replaced_model, group_losses

def progressive_rn50_to_rn18(args, orig, train_loader, device):
    '''
    Progressively replace groups of RN50 Bottlenecks with basic blocks and CKA-match them.

    At each step of the schedule built by ``build_progressive_schedule_and_orig_hooks``,
    the leftmost ``k`` Bottlenecks of one layer are replaced by a single new block. Then
    all blocks introduced so far are jointly matched to their paired ``orig`` modules
    via ``match_all_groups``. A checkpoint is saved after every step.

    Resuming: with ``args.reload_progressive`` and ``args.prog_ckpt`` set, the stage
    number is parsed from the checkpoint name (``stage<N>``). The first N structural
    replacements are replayed, then the saved weights are loaded and training resumes
    from step N.

    Returns:
        ``(final replaced model, {stage name: per-group losses})``.
    '''
    schedule, orig_hook_modules = build_progressive_schedule_and_orig_hooks(orig)
    print('[RN50→RN18] Progressive schedule (layer, group_size):', schedule)

    replace_model_current = copy.deepcopy(orig).to(device)
    replaced_hook_modules = []   # list[BasicBlockCompat], aligned with schedule
    log_json = {}

    num_steps = len(schedule)
    # Copy before padding so the lists stored on `args` are not mutated.
    lrs = list(args.progressive_lrs)
    epochs_per_stage = list(args.progressive_epochs)
    if len(lrs) < num_steps:
        lrs.extend([lrs[-1]] * (num_steps - len(lrs)))
    if len(epochs_per_stage) < num_steps:
        epochs_per_stage.extend([epochs_per_stage[-1]] * (num_steps - len(epochs_per_stage)))

    start_step = 0
    if args.reload_progressive:
        ckpt_name = args.prog_ckpt
        if ckpt_name:
            import re
            ckpt_path = f'saved_models/{args.exp_name}/{ckpt_name}.pt'
            assert os.path.exists(ckpt_path), f'Checkpoint not found: {ckpt_path}'
            m = re.search(r'stage(\d+)', ckpt_name)
            if m:
                start_step = int(m.group(1))
                print(f'[RN50→RN18] Resuming from stage {start_step}')
            # The checkpoint was saved after `start_step` structural replacements. Replay
            # them first: otherwise its keys do not match the fresh RN50 copy and strict
            # loading fails. Replaying also returns the exact block for each completed
            # step, even when several steps target the same layer.
            for step_idx in range(start_step):
                lname, gsz = schedule[step_idx]
                replaced_hook_modules.append(
                    replace_leftmost_k_bottlenecks_with_basic(replace_model_current, lname, k=gsz))
            replace_model_current.load_state_dict(torch.load(ckpt_path, map_location=device))
            replace_model_current = replace_model_current.to(device)

    os.makedirs(f'saved_models/{args.exp_name}', exist_ok=True)

    for step_idx in range(start_step, num_steps):
        lname, group_k = schedule[step_idx]
        current_lr = lrs[step_idx]
        current_epochs = epochs_per_stage[step_idx]
        print(f'\n================= RN50→RN18 Progressive Step {step_idx+1}/{num_steps} =================')
        print(f'Layer: {lname} | Replacing next group of {group_k} Bottlenecks | LR={current_lr} | Epochs={current_epochs}')

        new_block = replace_leftmost_k_bottlenecks_with_basic(replace_model_current, lname, k=group_k)
        replace_model_current = replace_model_current.to(device)
        replaced_hook_modules.append(new_block)

        # Every block introduced so far is matched jointly.
        active_group_ids = list(range(len(replaced_hook_modules)))
        replace_model_current, stage_losses = match_all_groups(
            args, active_group_ids, orig, replace_model_current,
            orig_hook_modules, replaced_hook_modules,
            train_loader, device, num_epochs=current_epochs, lr=current_lr,
            cka_scheduler=getattr(args, 'cka_scheduler', False)
        )

        log_json[f'Stage_{step_idx+1}_{lname}_k{group_k}'] = stage_losses

        ckpt_file = f'saved_models/{args.exp_name}/replace_progressive_rn50to18_stage{step_idx+1}.pt'
        torch.save(replace_model_current.state_dict(), ckpt_file)
        print(f'[RN50→RN18] Saved {ckpt_file}')

    replace_model_final = copy.deepcopy(replace_model_current)
    return replace_model_final, log_json

def progressive_align_rn18_to_rn50(args, teacher_rn50, student_rn18, train_loader, device):
    '''
    Progressive alignment: introduce RN18 blocks one by one (unfreezing them). At each
    step, jointly CKA-match all introduced RN18 blocks to their paired RN50 groups.
    RN50 stays fixed; RN18 is the model being trained.

    The student's stem (``conv1``/``bn1``, when present) is trainable from the start.
    A student checkpoint is saved after every step.

    Returns:
        ``(trained student, {stage name: per-group losses})``.
    '''
    schedule, teacher_hook_modules, student_hook_modules = build_alignment_rn50_rn18(teacher_rn50, student_rn18)
    print(f'[RN18←RN50] Progressive schedule ({len(schedule)} steps):', schedule)

    teacher_rn50 = teacher_rn50.to(device).eval()
    student_rn18 = student_rn18.to(device)

    # Allow the stem (conv1/bn1) to adapt from the start.
    extra_trainables = []
    if hasattr(student_rn18, 'conv1'):
        extra_trainables.append(student_rn18.conv1)
    if hasattr(student_rn18, 'bn1'):
        extra_trainables.append(student_rn18.bn1)

    num_steps = len(schedule)
    # Copy before padding so the lists stored on `args` are not mutated.
    lrs = list(args.progressive_lrs)
    epochs = list(args.progressive_epochs)
    if len(lrs) < num_steps:
        lrs += [lrs[-1]] * (num_steps - len(lrs))
    if len(epochs) < num_steps:
        epochs += [epochs[-1]] * (num_steps - len(epochs))

    os.makedirs(f'saved_models/{args.exp_name}', exist_ok=True)
    log_json = {}

    for step_idx in range(num_steps):
        # Groups introduced so far, in schedule order.
        active_gids = schedule[:step_idx + 1]
        print(f'\n===== RN18←RN50 Alignment Step {step_idx+1}/{num_steps} | active groups: {active_gids} =====')
        student_rn18, stage_losses = match_all_groups(
            args, active_gids, teacher_rn50, student_rn18,
            teacher_hook_modules, student_hook_modules,
            train_loader, device,
            lr=lrs[step_idx], num_epochs=epochs[step_idx],
            cka_scheduler=getattr(args, 'cka_scheduler', False),
            extra_trainable_modules=extra_trainables
        )
        log_json[f'Stage_{step_idx+1}'] = stage_losses

        ckpt_file = f'saved_models/{args.exp_name}/rn18_align_stage{step_idx+1}.pt'
        torch.save(student_rn18.state_dict(), ckpt_file)
        print(f'[RN18←RN50] Saved {ckpt_file}')

    return student_rn18, log_json

def _layer_replace_train_loader(args):
    '''
    Build the training loader used for representation matching.

    Dispatches on ``args.setting`` ('imagenet', 'cifar' or 'text') and returns only the
    training loader; the other loaders returned by the dataset helpers are discarded.

    Raises:
        NotImplementedError: for any other ``args.setting``.
    '''
    if args.setting == 'imagenet':
        train_loader, _ = get_imagenet_dataloaders(batch_size = args.cka_batch)
    elif args.setting == 'cifar':
        train_loader, _ = cifar_dataset(batch_size = args.cka_batch)
    elif args.setting == 'text':
        train_loader, _, _, _ = get_text_dataloaders(args.model, batch_size = args.cka_batch, seq_len = args.seq_len, num_workers = args.num_workers)
    else:
        raise NotImplementedError()
    return train_loader


def _layer_replace_target_type(args):
    '''
    Map ``args.target_type`` to ``(module class to replace, from_attn, collect_input)``.

    ``collect_input`` is False for the HF attention classes, whose input shapes are not
    extracted with a dummy forward pass; ``from_attn`` is True only for 'llm_attention'.

    Raises:
        NotImplementedError: for an unknown ``args.target_type``.
    '''
    if args.target_type == 'conv':
        return nn.Conv2d, False, True
    if args.target_type == 'mha':
        return nn.MultiheadAttention, False, True
    if args.target_type == 'vit':
        return nn.Linear, False, True
    if args.target_type == 'mematt':
        return MemEffAttention, False, True
    if args.target_type == 'llm_attention':
        return Qwen2Attention, True, False
    if args.target_type == 'attention':
        return GPT2Attention, False, False
    raise NotImplementedError


def layer_replace(args, orig, indices, replacement_type = 'independent', low_rank = True, rank = 1024, reload_ft = False, reload_cka = False, reload_cka_model = None):
    '''
    Replace target layers of ``orig`` and fit the replacements by representation matching.

    Args:
        args: experiment namespace (setting, target_type, exp_name, CKA / progressive
            hyper-parameters, ...). ``args`` itself is not modified.
        orig: reference model. It is deep-copied before layers are replaced, except for
            'progressive_rn50_to_rn18' (handed to the helper as-is) and
            'progressive_align_rn18_to_rn50' (used directly as the teacher and moved to
            ``device``).
        indices: 1-based positions of the layers to replace within the list of
            ``target_type`` layers (they index ``input_shapes[idx - 1]``).
        replacement_type: 'independent', 'sequential', 'joint', 'progressive',
            'progressive_rn50_to_rn18' or 'progressive_align_rn18_to_rn50'.
        low_rank, rank: forwarded to the replacement / matching helpers.
        reload_ft: build the replaced architecture without any matching and return it
            immediately; ``args.baseline`` has the same effect.
        reload_cka: start from the state dict in ``saved_models/orig_<reload_cka_model>.pt``
            (replaced layers listed in ``logs/<reload_cka_model>/args.json``) and only
            match the remaining indices. Every strategy continues from that model.
        reload_cka_model: experiment name to reload when ``reload_cka`` is set.

    Returns:
        The model with replaced layers. Except on the early-return ``reload_ft`` /
        ``args.baseline`` path, the matching losses are written to
        ``logs/<args.exp_name>/rep_sim_matches.json``.
    '''
    log_json = {}
    torch.cuda.empty_cache()
    train_loader = _layer_replace_train_loader(args)
    target_type, from_attn, collect_input = _layer_replace_target_type(args)

    print(f'{args.target_type} layers:')
    potential_layers = list_target_layers(orig, target_type)
    print(f'{len(potential_layers)} layers for replacement')
    for idx, name in potential_layers:
        print(f'  {idx}: {name}')
    print('Replacing indices', indices)
    print('Using low_rank:', low_rank, 'Rank:', rank)

    replace_model_final = None
    if collect_input:
        batch = next(iter(train_loader))
        if args.setting == 'imagenet':
            dummy_input = torch.randn(size = batch[0]['images'].permute(0, 3, 1, 2).shape)
        else:
            dummy_input = torch.randn(size = batch[0].shape)
        input_shapes = extract_input_shapes(orig, dummy_input = dummy_input, target_type = target_type)
        print('Extracted following input shapes:', input_shapes, len(input_shapes))
    else:
        print('No input shape extraction')
        # One placeholder per target layer (mirroring extract_input_shapes) so that
        # input_shapes[idx - 1] is valid for any layer position, not just the first
        # len(indices) of them.
        input_shapes = [None] * max([len(potential_layers), *indices])

    if reload_ft or args.baseline:
        if replacement_type != 'progressive_align_rn18_to_rn50':
            replace_model_final = copy.deepcopy(orig).to(device)
            for idx in indices:
                input_shape = input_shapes[idx - 1]
                replace_nth_module(replace_model_final, idx, input_shape, device, target_type, use_low_rank=low_rank, rank=rank, from_attn = from_attn)
            return replace_model_final
        else:
            replace_model_final = torchvision.models.resnet18(pretrained = False).to(device)
            return replace_model_final
    elif reload_cka:
        print(f'Reloading {reload_cka_model}')
        replace_model_final = copy.deepcopy(orig).to(device)
        with open(f'logs/{reload_cka_model}/args.json', 'r') as f:
            reload_args = json.load(f)
        reload_indices = list(reload_args['indices'])
        print(f'Reloading indices... {reload_indices}')
        for idx in reload_indices:
            input_shape = input_shapes[idx - 1]
            replace_nth_module(replace_model_final, idx, input_shape, device, target_type,
                            use_low_rank=low_rank, rank=rank, from_attn = from_attn)
        chkpt = torch.load(f'saved_models/orig_{reload_cka_model}.pt', map_location = device)
        replace_model_final.load_state_dict(chkpt)
        indices = [idx for idx in indices if idx not in reload_indices]
        print(f'Remaining indices...{indices}')

    if replacement_type == 'independent':
        saved_states = {}
        for idx in indices:
            print(f'\n=== Independently Matching layer {idx} ===')
            state, cka_losses = match_single_layer(idx, orig, input_shapes,
                                                train_loader, device, repdist = args.repdist,
                                                low_rank = low_rank, rank = rank, lr = args.cka_lr, epochs = args.cka_epochs)
            log_json[f'Layer_{idx}'] = cka_losses
            saved_states[idx] = state
        if replace_model_final is None:
            assert not reload_cka, 'replace_model_final should already hold the reloaded model when reload_cka is True!'
            replace_model_final = copy.deepcopy(orig).to(device)
        for idx in indices:
            input_shape = input_shapes[idx - 1]
            replace_nth_module(replace_model_final, idx, input_shape, device, target_type,
                            use_low_rank=low_rank, rank=rank, from_attn = from_attn)
            replacement_layer = get_nth_module(replace_model_final, idx, target_type)
            replacement_layer.load_state_dict(saved_states[idx])

    elif replacement_type == 'sequential':
        # Continue from the reloaded CKA model if there is one, otherwise from a copy of orig.
        if replace_model_final is None:
            replace_model_final = copy.deepcopy(orig).to(device)
        for idx in indices:
            print(f'\n=== Sequentially matching layer {idx} ===')
            # Layers replaced in earlier iterations stay in the model being matched.
            state, losses = match_single_layer(idx, replace_model_final, input_shapes, train_loader,
                                               device, repdist = args.repdist, low_rank = low_rank, rank = rank,
                                               lr = args.cka_lr, epochs = args.cka_epochs)
            input_shape = input_shapes[idx - 1]
            replace_nth_module(replace_model_final, idx, input_shape, device, target_type, use_low_rank = low_rank, rank = rank, from_attn = from_attn)
            replacement_layer = get_nth_module(replace_model_final, idx, target_type)
            replacement_layer.load_state_dict(state)
            replace_model_final = replace_model_final.to(device)
            log_json[f'Layer_{idx}'] = losses

    elif replacement_type == 'joint':
        prt_indices = ' '.join([str(x) for x in indices])
        print(f'\n=== Jointly matching layers {prt_indices} ===')
        # Continue from the reloaded CKA model if there is one, otherwise from a copy of orig.
        replace_model_current = replace_model_final if replace_model_final is not None else copy.deepcopy(orig).to(device)
        for idx in indices:
            input_shape = input_shapes[idx - 1]
            replace_nth_module(replace_model_current, idx, input_shape, device, target_type, use_low_rank = low_rank, rank = rank, from_attn = from_attn)
        print(replace_model_current)
        replace_model_current = replace_model_current.to(device)
        replace_model_final, log_json = match_all_layers(args, indices, orig, replace_model_current, input_shapes, target_type, train_loader, device)

    elif replacement_type == 'progressive':
        prt_indices = ' '.join([str(x) for x in indices])
        print(f'\n=== Progressively matching layers {prt_indices} ===')
        indices = sorted(indices)
        # Continue from the reloaded CKA model if there is one, otherwise from a copy of orig.
        replace_model_current = replace_model_final if replace_model_final is not None else copy.deepcopy(orig).to(device)
        # Copies, so padding them below does not mutate args.progressive_lrs / progressive_epochs.
        lrs = list(args.progressive_lrs)
        epochs_per_stage = list(args.progressive_epochs)

        cumulative_indices = []

        if args.reload_progressive:
            print(f'Reloading model: {args.prog_ckpt}')
            ckpt_path = f'saved_models/{args.exp_name}/{args.prog_ckpt}' + '.pt'
            if not os.path.exists(ckpt_path):
                raise FileNotFoundError(f'Progressive checkpoint not found: {ckpt_path}')
            # Checkpoint names end in `<first idx>_<last idx - first idx>` (see the save below).
            ckpt_start_idx = int(args.prog_ckpt.split('_')[-2])
            length = int(args.prog_ckpt.split('_')[-1])
            assert ckpt_start_idx == indices[0]
            end_idx = ckpt_start_idx + length
            # Layers already replaced in the checkpoint are the requested indices up to
            # end_idx; a plain range would also pick up positions that were never requested.
            cumulative_indices = [x for x in indices if x <= end_idx]
            for j in cumulative_indices:
                input_shape = input_shapes[j - 1]
                replace_nth_module(replace_model_current, j, input_shape, device, target_type,  use_low_rank = low_rank, rank = rank, from_attn = from_attn)
            replace_model_current.load_state_dict(torch.load(ckpt_path, map_location = device))
            replace_model_current = replace_model_current.to(device)
            indices = [idx for idx in indices if idx not in cumulative_indices]

        num_stages = len(indices)

        if len(lrs) < num_stages:
            print(f'Warning: Number of LRs ({len(lrs)}) is less than number of stages ({num_stages}). Extending with last LR value.')
            lrs.extend([lrs[-1]] * (num_stages - len(lrs)))

        if len(epochs_per_stage) < num_stages:
            print(f'Warning: Number of Epochs ({len(epochs_per_stage)}) is less than number of stages ({num_stages}). Extending with last Epoch value.')
            epochs_per_stage.extend([epochs_per_stage[-1]] * (num_stages - len(epochs_per_stage)))

        os.makedirs(f'saved_models/{args.exp_name}', exist_ok = True)
        cka_scheduler = args.cka_scheduler
        for stage_idx, new_idx_to_add in enumerate(indices):
            current_lr = lrs[stage_idx]
            current_epochs = epochs_per_stage[stage_idx]
            print(f'\n=================Progressive Stage {stage_idx + 1}/{num_stages} =================')
            print(f'Adding layer: {new_idx_to_add} | LR for this stage: {current_lr} | Numbers of epochs: {current_epochs}')

            input_shape = input_shapes[new_idx_to_add - 1]
            replace_nth_module(replace_model_current, new_idx_to_add, input_shape, device, target_type,
                               use_low_rank = low_rank, rank = rank, from_attn = from_attn)
            replace_model_current = replace_model_current.to(device)
            # replace_model_current = replace_model_current.to(dtype = orig.dtype)
            if from_attn or args.model == 'gpt2':
                # HF language models: disable the KV cache and enable non-reentrant gradient
                # checkpointing before training the replaced layers.
                replace_model_current.config.use_cache = False
                replace_model_current.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs = {'use_reentrant': False}
                )
                replace_model_current.enable_input_require_grads()
                replace_model_current.train()
            cumulative_indices.append(new_idx_to_add)
            print(f'--- Jointly training all modified layers: {cumulative_indices} ---')

            replace_model_current, stage_losses = match_all_layers(args, cumulative_indices, orig, replace_model_current, input_shapes, target_type, train_loader, device,
                num_epochs = current_epochs, lr = current_lr, cka_scheduler = cka_scheduler, token_avg = False)
            log_json[f'Stage_{stage_idx+1}_Layer_{new_idx_to_add}'] = stage_losses
            save_idx = new_idx_to_add - cumulative_indices[0]
            model_save_str = f'replace_progressive_{cumulative_indices[0]}_{save_idx}.pt'
            torch.save(replace_model_current.state_dict(), f'saved_models/{args.exp_name}/{model_save_str}')
        # Return the trained model directly: deep-copying it after every stage doubled memory
        # for large models, and left nothing to return when every stage was reloaded.
        replace_model_final = replace_model_current

    elif replacement_type == 'progressive_rn50_to_rn18':
        print('[RN50→RN18] Starting progressive, block-level replacement w/ CKA.')
        replace_model_final, log_json = progressive_rn50_to_rn18(args, orig, train_loader, device)

    elif replacement_type == 'progressive_align_rn18_to_rn50':
        print('[RN18←RN50] Starting progressive student alignment (teacher=ResNet-50 fixed).')
        teacher = orig.to(device)
        student = torchvision.models.resnet18(pretrained = False)
        student = student.to(device)
        replace_model_final, log_json = progressive_align_rn18_to_rn50(args, teacher, student, train_loader, device)

    else:
        raise NotImplementedError(
            f'Unknown replacement_type {replacement_type!r}. Select an optimization strategy among '
            'independent, sequential, joint, progressive, progressive_rn50_to_rn18, or progressive_align_rn18_to_rn50'
        )

    os.makedirs(f'logs/{args.exp_name}', exist_ok = True)
    with open(f'logs/{args.exp_name}/rep_sim_matches.json', 'w') as f:
        json.dump(log_json, f)
    # plot_cka(args, log_json, rank if low_rank else None)
    return replace_model_final