"""
FlexLoRA Training Script
Dynamic rank allocation with orthogonal regularization
"""
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from tqdm import tqdm
from timm.models import create_model
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.loss import LabelSmoothingCrossEntropy
from argparse import ArgumentParser
from utils import *
import numpy as np
import psutil
import os

from models import vision_transformer_flexlora
from models.flexlora_layers import compute_orth_regu
from models.rank_allocator import RankAllocator

from avalanche.evaluation.metrics.accuracy import Accuracy


# Substrings identifying parameters the optimizer trains (LoRA factors and the classifier head).
_TRAINABLE_NAME_KEYS = ('lora_A', 'lora_B', 'lora_E', 'head')

# (AMS)Adam moment buffers whose shapes must match their parameter for the state to be reusable.
_ADAM_STATE_BUFFERS = ('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq')


def _prepare_batch(batch, task):
    """Return ``(x, y)`` for one training batch, moved to the GPU.

    Args:
        batch: For 'vtab', an indexable whose items 0 and 1 are inputs and labels.
            For 'fgvc', a dict with "image" and "label" entries; if "image" is not
            a tensor, every entry is converted in place with ``torch.from_numpy``.
        task: 'vtab' or 'fgvc'.

    Raises:
        ValueError: If ``task`` is not a supported task name.
    """
    if task == 'vtab':
        return batch[0].cuda(), batch[1].cuda()
    if task == 'fgvc':
        if not isinstance(batch["image"], torch.Tensor):
            for key, value in batch.items():
                batch[key] = torch.from_numpy(value)
        return batch["image"].float().cuda(), batch["label"].cuda()
    raise ValueError(f"Unknown task name {task!r}; expected 'vtab' or 'fgvc'")


def _rebuild_optimizer(model, opt):
    """Build a fresh AdamW over the model's current LoRA/head parameters.

    Used after the rank allocator has changed the LoRA parameters. The new
    optimizer is constructed from ``opt.defaults``. It then receives every
    hyperparameter of ``opt``'s first param group, including the current
    (possibly scheduled) lr and weight decay.

    AdamW state is carried over for each parameter that is the *same tensor
    object* as before and whose moment buffers still match its shape. Every
    other parameter starts with fresh state.

    Returns:
        Tuple ``(new_opt, preserved_count, trainable_params)``.
    """
    trainable_params = [
        p for n, p in model.named_parameters()
        if any(key in n for key in _TRAINABLE_NAME_KEYS)
    ]

    # 'differentiable' is dropped because passing it back to the constructor
    # raises a TypeError in some PyTorch versions.
    optimizer_kwargs = {k: v for k, v in opt.defaults.items() if k != 'differentiable'}
    new_opt = AdamW(trainable_params, **optimizer_kwargs)
    # Mirror the live hyperparameters of the old group (the LR scheduler may have
    # changed 'lr' since construction).
    for key, value in opt.param_groups[0].items():
        if key != 'params':
            new_opt.param_groups[0][key] = value

    # Optimizer.state is keyed by the parameter tensor itself. (state_dict()
    # re-keys it by integer position, so an id()-based lookup there never
    # matches.) A direct lookup therefore finds state only for unchanged
    # parameter objects.
    preserved_count = 0
    for p in trainable_params:
        old_state = opt.state.get(p)
        if not old_state:
            continue
        shapes_match = all(
            old_state[k].shape == p.shape
            for k in _ADAM_STATE_BUFFERS
            if k in old_state
        )
        if shapes_match:
            new_opt.state[p] = old_state
            preserved_count += 1
    return new_opt, preserved_count, trainable_params


def train_flexlora(config, model, criterion, train_dl, test_dl, opt, scheduler, 
                   logger, epochs, task, rank_allocator=None, orth_weight=0.1):
    """
    FlexLoRA training loop with dynamic rank allocation.

    Before each forward pass the rank allocator (if given) may adjust LoRA ranks.
    When it reports an adjustment, the AdamW optimizer is rebuilt over the current
    LoRA/head parameters. Moment state is kept for parameters that were not
    replaced, and the scheduler is re-pointed at the new optimizer. The caller's
    ``opt`` reference is not updated.

    Args:
        config: Configuration dict; ``config['best_acc']`` is read and updated
            with the best test accuracy seen.
        model: Model with LoRA layers
        criterion: Loss function
        train_dl: Training dataloader
        test_dl: Test dataloader
        opt: AdamW optimizer over the trainable parameters
        scheduler: Epoch-based learning rate scheduler (or None)
        logger: Logger
        epochs: Number of epochs
        task: Task name ('vtab' or 'fgvc')
        rank_allocator: RankAllocator instance; a non-None second value returned
            by ``update_and_mask`` is treated as a rank adjustment
        orth_weight: Orthogonal regularization weight (<= 0 disables it)

    Returns:
        The trained model, moved back to the CPU.

    Raises:
        ValueError: If ``task`` is not 'vtab' or 'fgvc'.
    """
    if task not in ('vtab', 'fgvc'):
        raise ValueError(f"Unknown task name {task!r}; expected 'vtab' or 'fgvc'")

    model.train()
    model = model.cuda()

    global_step = 0
    n_batches = len(train_dl)
    total_steps = epochs * n_batches

    if rank_allocator is not None:
        rank_allocator.set_total_step(total_steps)
        logger.info(f"RankAllocator initialized with {total_steps} total steps")

    for ep in tqdm(range(epochs), desc="Epochs"):
        # test() switches the model to eval mode, so re-enter train mode every epoch.
        model.train()
        model = model.cuda()

        epoch_loss = 0.0
        epoch_orth_loss = 0.0

        for batch in tqdm(train_dl, desc=f"Epoch {ep+1}", leave=False):
            x, y = _prepare_batch(batch, task)

            # Dynamic rank allocation BEFORE forward pass
            if rank_allocator is not None:
                _, mask_result = rank_allocator.update_and_mask(model, global_step)
                if mask_result is not None:
                    # Note: ranknum is deliberately NOT updated here, matching the
                    # reference implementation; the scaling factor (alpha/ranknum)
                    # remains based on the initial rank.
                    opt, preserved_count, trainable_params = _rebuild_optimizer(model, opt)
                    current_lr = opt.param_groups[0]['lr']

                    # Re-point the scheduler so later scheduler steps set the lr of
                    # the new optimizer rather than the discarded one.
                    if scheduler is not None:
                        scheduler.optimizer = opt

                    logger.info(f"Step {global_step}: Rank adjusted, optimizer reinitialized "
                              f"(lr={current_lr:.6f}, {preserved_count}/{len(trainable_params)} states preserved)")

            # Forward pass and main loss
            out = model(x)
            main_loss = criterion(out, y)

            # Orthogonal regularization on the LoRA factors
            if orth_weight > 0:
                orth_loss = compute_orth_regu(model, regu_weight=orth_weight)
                loss = main_loss + orth_loss
                epoch_orth_loss += orth_loss.item()
            else:
                loss = main_loss

            # Backward pass
            opt.zero_grad()
            loss.backward()
            opt.step()

            epoch_loss += main_loss.item()
            global_step += 1

        # NOTE(review): step(ep) re-applies this epoch's lr for the next epoch
        # (timm's reference loop uses step(epoch + 1)); kept as-is to preserve
        # the existing schedule — confirm this is intended.
        if scheduler is not None:
            scheduler.step(ep)

        # Per-epoch logging (max(..., 1) guards against an empty dataloader).
        avg_loss = epoch_loss / max(n_batches, 1)
        avg_orth_loss = epoch_orth_loss / max(n_batches, 1)
        ram_used = psutil.virtual_memory().used / (1024.0 * 1024.0)
        memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
        logger.info(f"Epoch {ep}: loss={avg_loss:.4f} orth_loss={avg_orth_loss:.4f} "
                    f"lr={opt.param_groups[0]['lr']:.6f} ram: {ram_used:.0f}MB")

        # Evaluate every 10 epochs, and always after the final epoch so it is never skipped.
        if ep % 10 == 9 or ep == epochs - 1:
            acc = test(model, test_dl, task)
            if acc > config['best_acc']:
                config['best_acc'] = acc
                print(acc)
            logger.info(str(ep)+' '+str(acc)+' memory: '+str(memory_used)+'MB')

    # Final rank pattern
    if rank_allocator is not None:
        rank_pattern = rank_allocator.get_rank_pattern()
        logger.info("Final rank pattern:")
        for name, rank in rank_pattern.items():
            logger.info(f"  {name}: {rank}")

    model = model.cpu()
    return model


@torch.no_grad()
def test(model, dl, task):
    """Evaluate ``model`` on ``dl`` and return its accuracy.

    The model is switched to eval mode and moved to the GPU; it is left that way
    on return.

    Args:
        model: Model to evaluate.
        dl: Evaluation dataloader. For 'vtab', batches are indexable
            (inputs, labels) pairs. For 'fgvc', batches are dicts with "image"
            and "label" entries; numpy entries are converted in place to tensors.
        task: 'vtab' or 'fgvc'.

    Returns:
        The value reported by avalanche's ``Accuracy.result()``. If that is a
        list/tuple, its first element is returned instead.

    Raises:
        ValueError: If ``task`` is not 'vtab' or 'fgvc'.
    """
    if task not in ('vtab', 'fgvc'):
        raise ValueError(f"Unknown task name {task!r}; expected 'vtab' or 'fgvc'")

    model.eval()
    model = model.cuda()
    # Release cached allocator blocks once, instead of on every batch (which slows evaluation).
    torch.cuda.empty_cache()

    # Use avalanche Accuracy (original method)
    acc = Accuracy()
    for batch in dl:
        if task == 'vtab':
            x, y = batch[0].cuda(), batch[1].cuda()
        else:  # 'fgvc'
            if not isinstance(batch["image"], torch.Tensor):
                for k, v in batch.items():
                    batch[k] = torch.from_numpy(v)
            x = batch["image"].float().cuda()
            y = batch["label"].cuda()

        out = model(x)
        acc.update(out.argmax(dim=1).view(-1), y)

    # Handle different avalanche API versions
    result = acc.result()
    if isinstance(result, (list, tuple)):
        return result[0]  # Old API returns list
    return result  # New API returns float directly
    


if __name__ == '__main__':
    # Script entry point: parse CLI args, build data/model/optimizer/scheduler,
    # optionally set up FlexLoRA rank allocation, then train and report results.
    parser = ArgumentParser(description='FlexLoRA Training')

    # Basic arguments
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=1e-4)
    parser.add_argument('--eval', type=str, default='True')
    parser.add_argument('--dpr', type=float, default=0.1)
    parser.add_argument('--topN', type=int, default=None)
    parser.add_argument('--model', type=str, default='vit_base_patch16_224_in21k_flexlora')
    parser.add_argument('--model_checkpoint', type=str, default='./released_models/ViT-B_16.npz')
    parser.add_argument('--model_type', type=str, default='vit_flexlora')
    # Only these two tasks have data-loading paths below; reject anything else at parse time
    # instead of failing later with an undefined dataloader.
    parser.add_argument('--task', type=str, default='vtab', choices=['vtab', 'fgvc'])
    parser.add_argument('--dataset', type=str, default='cifar100')
    parser.add_argument('--tuning_mode', type=str, default='flexlora')

    # FlexLoRA specific arguments
    parser.add_argument('--enable_flexlora', action='store_true', 
                       help='Enable FlexLoRA dynamic rank allocation')
    parser.add_argument('--orth_weight', type=float, default=0.1,
                       help='Orthogonal regularization weight')
    parser.add_argument('--target_rank', type=int, default=8,
                       help='Average target rank per layer')
    parser.add_argument('--init_warmup', type=int, default=1000,
                       help='Initial warmup steps before rank adjustment')
    parser.add_argument('--final_warmup', type=int, default=1000,
                       help='Final warmup steps (no rank adjustment)')
    parser.add_argument('--mask_interval', type=int, default=100,
                       help='Interval between rank adjustments')
    parser.add_argument('--beta1', type=float, default=0.85,
                       help='EMA coefficient for importance')
    parser.add_argument('--beta2', type=float, default=0.85,
                       help='EMA coefficient for uncertainty')
    parser.add_argument('--b', type=int, default=4,
                       help='Number of layers to adjust each time')
    parser.add_argument('--enable_scheduler', action='store_true',
                       help='Enable b parameter scheduler')
    parser.add_argument('--importance_mode', type=str, default='entropy',
                       choices=['entropy', 'nuclear', 'frobenius'],
                       help='Importance computation mode')

    args = parser.parse_args()
    print(args)

    set_seed(args.seed)
    config = get_config('model_lora', args.task, args.dataset)

    # A --topN given on the command line overrides the config value.
    topN = args.topN if args.topN is not None else config['topN']

    exp_base_path = './output/%s/%s/%s' % (
        args.model_type, args.task, 
        config['name'] + '_dim_%d' % topN
    )
    mkdirss(exp_base_path)
    logger = create_logger(log_path=exp_base_path, log_name='training')

    logger.info(args)
    logger.info(config)

    # Prepare training data
    evalflag = args.eval == 'True'
    train_aug = config.get('train_aug', False)

    if args.task == 'vtab':
        from dataloader.vtab import get_data
        basedir = './vtab-1k'
        train_dl, test_dl = get_data(
            basedir, args.dataset, logger, 
            evaluate=evalflag, train_aug=train_aug, 
            batch_size=config['batch_size']
        )
    elif args.task == 'fgvc':
        from dataloader.loader import construct_train_loader, construct_test_loader
        train_dl = construct_train_loader(args.dataset, batch_size=config['batch_size'])
        test_dl = construct_test_loader(args.dataset, batch_size=config['batch_size'])
        print(len(train_dl), len(test_dl))

    # Create model
    if 'swin' in args.model:
        model = create_model(
            args.model, pretrained=False, drop_path_rate=args.dpr, 
            tuning_mode=args.tuning_mode, topN=topN
        )
        # The model is still on CPU here; load the checkpoint there too so
        # tensors saved from a GPU do not require (or occupy) CUDA memory.
        checkpoint = torch.load(args.model_checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint['model'], strict=False)
    else:
        model = create_model(
            args.model, checkpoint_path=args.model_checkpoint, 
            drop_path_rate=args.dpr, tuning_mode=args.tuning_mode, topN=topN
        )

    model.reset_classifier(config['class_num'])
    logger.info(str(model))

    config['best_acc'] = 0
    config['task'] = args.task

    # Train only LoRA factors and the classifier head; freeze everything else.
    trainable = []
    for n, p in model.named_parameters():
        if 'lora_A' in n or 'lora_B' in n or 'lora_E' in n or 'head' in n:
            trainable.append(p)
            logger.info(str(n))
        else:
            p.requires_grad = False

    opt = AdamW(trainable, lr=args.lr, weight_decay=args.wd)

    cycle_decay = config.get('cycle_decay', 0.1)

    scheduler = CosineLRScheduler(
        opt, t_initial=config['epochs'],
        warmup_t=config.get('warmup_epochs', 10), 
        lr_min=1e-5, warmup_lr_init=1e-6, 
        cycle_decay=cycle_decay
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of extra params: {}M".format(n_parameters / 1000000))
    logger.info(f"number of extra params: {n_parameters}")

    # Loss function
    if config.get('labelsmoothing', 0) > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config['labelsmoothing'])
        logger.info('label smoothing')
    else:
        criterion = torch.nn.CrossEntropyLoss()
        logger.info('CrossEntropyLoss')

    # Initialize RankAllocator if FlexLoRA is enabled
    rank_allocator = None
    if args.enable_flexlora:
        rank_allocator = RankAllocator(
            model=model,
            lora_r=topN,
            target_rank=args.target_rank,
            init_warmup=args.init_warmup,
            final_warmup=args.final_warmup,
            mask_interval=args.mask_interval,
            beta1=args.beta1,
            beta2=args.beta2,
            k=2,
            b=args.b,
            output_dir=exp_base_path,
            enable_scheduler=args.enable_scheduler,
            mode=args.importance_mode
        )
    else:
        logger.info("Standard LoRA training (FlexLoRA disabled)")
        # Plain LoRA baseline: also disable the orthogonal regularizer.
        args.orth_weight = 0.0

    # Train
    model = train_flexlora(
        config, model, criterion, train_dl, test_dl, opt, scheduler, 
        logger, config['epochs'], args.task, 
        rank_allocator=rank_allocator, 
        orth_weight=args.orth_weight
    )

    # Final parameter count after training
    final_n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Final number of trainable params: {final_n_parameters / 1000000:.4f}M")
    logger.info(f"Initial trainable params: {n_parameters} ({n_parameters / 1000000:.4f}M)")
    logger.info(f"Final trainable params: {final_n_parameters} ({final_n_parameters / 1000000:.4f}M)")
    logger.info(f"Parameter change: {(final_n_parameters - n_parameters) / 1000000:+.4f}M")

    print(f"Best accuracy: {config['best_acc']:.4f}")
    logger.info(f"Training completed. Best accuracy: {config['best_acc']:.4f}")
    logger.info('end')
