import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import random_split
from tqdm import tqdm
from fvcore.nn import FlopCountAnalysis
import argparse

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from models.model import NetworkCIFAR
from models.utils import AverageMeter, accuracy, Cutout, _data_transforms_cifar10
from models.genotypes import DARTS
from models.operations import SEAttentionWeightedAggregation 


# Command-line interface of the final-training script. Flags are registered
# in a fixed order so that the generated --help listing stays stable.
parser = argparse.ArgumentParser("NR-DARTS Final Training")

# -- run setup: output location, reproducibility, data, device, schedule --
parser.add_argument(
    '--save', default='./exp/nrdarts', type=str,
    help='experiment save directory')
parser.add_argument(
    '--seed', default=42, type=int,
    help='random seed')
parser.add_argument(
    '--data', default='./data', type=str,
    help='location of the data corpus')
parser.add_argument(
    '--batch_size', default=64, type=int,
    help='batch size')
parser.add_argument(
    '--gpu', default=0, type=int,
    help='gpu device id')
parser.add_argument(
    '--epochs', default=100, type=int,
    help='num of training epochs')

# -- network shape and regularisation --
parser.add_argument(
    '--init_channels', default=16, type=int,
    help='num of init channels')
parser.add_argument(
    '--layers', default=8, type=int,
    help='total number of layers')
parser.add_argument(
    '--drop_path_prob', default=0.3, type=float,
    help='drop path probability')

# -- node pruning driven by a previously trained search model --
parser.add_argument(
    '--search_model_path', default=None, type=str,
    help='path to the trained search model (optional)')
parser.add_argument(
    '--num_prune_nodes', default=15, type=int,
    help='number of nodes to prune')

# -- optimisation hyper-parameters --
parser.add_argument(
    '--learning_rate', default=0.025, type=float,
    help='init learning rate for base model')
parser.add_argument(
    '--se_learning_rate', default=0.02, type=float,
    help='learning rate for SE modules')
parser.add_argument(
    '--momentum', default=0.9, type=float,
    help='momentum')
parser.add_argument(
    '--weight_decay', default=3e-4, type=float,
    help='weight decay')
parser.add_argument(
    '--warmup_epochs', default=10, type=int,
    help='num of warmup epochs for training base model')
parser.add_argument(
    '--dropout_p', default=0.3, type=float,
    help='dropout probability for SE module')

# -- boolean switches (store_true flags default to False) and cutout size --
parser.add_argument(
    '--cutout', action='store_true',
    help='use cutout')
parser.add_argument(
    '--cutout_length', default=16, type=int,
    help='cutout length')
parser.add_argument(
    '--auxiliary', action='store_true',
    help='use auxiliary tower')

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global generator and PyTorch. When CUDA
    is available the CUDA generators are seeded as well and cuDNN is switched
    to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Prefer reproducible kernels over cuDNN's per-shape autotuning.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def measure_flops(model, input_size=(1, 3, 32, 32), device='cuda'):
    """Estimate the cost of one forward pass in millions of operations.

    The model is switched to eval mode and analysed with fvcore's
    ``FlopCountAnalysis`` on a random input. Custom handlers are registered
    for a number of element-wise, pooling and linear operators; their
    per-element multipliers are this script's own fixed estimates.

    Only the dummy input is moved to ``device``; the model itself is not
    moved here, so it must already be on that device.

    Args:
        model: Module to analyse. Left in eval mode on return.
        input_size: Shape of the random input tensor.
        device: Device the dummy input is created on.

    Returns:
        ``flops.total() / 1e6`` as reported by fvcore with the handlers below.
    """
    model.eval()
    dummy_input = torch.randn(*input_size).to(device)

    def get_tensor_sizes(tensor_value):
        """Return the statically known shape of a traced value, or None."""
        try:
            tensor_type = tensor_value.type()
            if not hasattr(tensor_type, 'sizes'):
                return None
            sizes = tensor_type.sizes()
        except Exception:
            # Best effort: shape lookup on a traced value must never abort
            # the whole analysis. (Unlike a bare ``except`` this still lets
            # KeyboardInterrupt / SystemExit propagate.)
            return None
        if sizes is None or any(dim is None for dim in sizes):
            return None
        return list(sizes)

    def get_tensor_numel(tensor_value):
        """Return the element count of a traced value, or 1 if unknown."""
        sizes = get_tensor_sizes(tensor_value)
        if sizes is None:
            return 1
        numel = 1
        for size in sizes:
            numel *= size
        return numel

    def aten_softmax_flop_jit(inputs, outputs):
        return get_tensor_numel(outputs[0]) * 5

    def aten_mul_flop_jit(inputs, outputs):
        return get_tensor_numel(outputs[0])

    def aten_add_flop_jit(inputs, outputs):
        return get_tensor_numel(outputs[0])

    def aten_softplus_flop_jit(inputs, outputs):
        return get_tensor_numel(outputs[0]) * 4

    def aten_max_pool2d_flop_jit(inputs, outputs):
        # NOTE(review): the constant 4 looks like a 2x2 window assumption;
        # confirm it matches the pooling kernels actually used by the model.
        kernel_flops = 4
        return get_tensor_numel(outputs[0]) * kernel_flops

    def aten_adaptive_avg_pool2d_flop_jit(inputs, outputs):
        # Every input element is read once, so cost scales with the input.
        return get_tensor_numel(inputs[0])

    def aten_linear_flop_jit(inputs, outputs):
        # aten::linear(input, weight, bias): weight is (out_features, in_features),
        # so numel(input) * out_features == batch * in_features * out_features.
        if len(inputs) < 2:
            return 0
        weight_sizes = get_tensor_sizes(inputs[1])
        if weight_sizes is None or len(weight_sizes) < 2:
            return 0
        return get_tensor_numel(inputs[0]) * weight_sizes[0]

    def aten_addmm_flop_jit(inputs, outputs):
        # aten::addmm(bias, mat1, mat2) computes bias + mat1 @ mat2 with
        # mat1 of shape (n, k) and mat2 of shape (k, m). The operand order
        # differs from aten::linear, so it needs its own handler:
        # numel(mat1) * m == n * k * m.
        if len(inputs) < 3:
            return 0
        mat2_sizes = get_tensor_sizes(inputs[2])
        if mat2_sizes is None or len(mat2_sizes) < 2:
            return 0
        return get_tensor_numel(inputs[1]) * mat2_sizes[1]

    def aten_relu_flop_jit(inputs, outputs):
        return get_tensor_numel(outputs[0])

    def aten_sigmoid_flop_jit(inputs, outputs):
        return get_tensor_numel(outputs[0]) * 4

    def aten_view_flop_jit(inputs, outputs):
        # Pure metadata change: no arithmetic.
        return 0

    flops = FlopCountAnalysis(model, dummy_input)

    op_handles = {
        "aten::softmax": aten_softmax_flop_jit,
        "aten::mul": aten_mul_flop_jit,
        "aten::add": aten_add_flop_jit,
        "aten::softplus": aten_softplus_flop_jit,
        "aten::max_pool2d": aten_max_pool2d_flop_jit,
        "aten::adaptive_avg_pool2d": aten_adaptive_avg_pool2d_flop_jit,
        "aten::linear": aten_linear_flop_jit,
        "aten::addmm": aten_addmm_flop_jit,
        "aten::relu": aten_relu_flop_jit,
        "aten::relu_": aten_relu_flop_jit,
        "aten::sigmoid": aten_sigmoid_flop_jit,
        "aten::view": aten_view_flop_jit,
        "aten::reshape": aten_view_flop_jit,
    }
    for op_name, handle in op_handles.items():
        flops.set_op_handle(op_name, handle)

    return flops.total() / 1e6


def measure_inference_latency(model, device, input_size=(1, 3, 32, 32), iterations=100, warmup=10):
    """Measure the mean forward-pass latency of ``model`` in milliseconds.

    The model is put in eval mode and moved to ``device``. After ``warmup``
    untimed forward passes on a random input, ``iterations`` passes are timed:
    with CUDA events on a GPU, with a monotonic wall clock on CPU.

    Args:
        model: Module to benchmark. Left in eval mode on ``device``.
        device: ``torch.device`` or a device string such as ``'cpu'``.
        input_size: Shape of the random input tensor.
        iterations: Number of timed forward passes; must be positive.
        warmup: Number of untimed forward passes run first.

    Returns:
        Average latency per forward pass, in milliseconds.
    """
    device = torch.device(device)  # also accept plain strings
    model.eval()
    model.to(device)
    dummy_input = torch.randn(*input_size).to(device)

    with torch.no_grad():
        for _ in range(warmup):
            model(dummy_input)

    if device.type == 'cuda':
        # CUDA events are recorded on the *current* device's stream, so the
        # target GPU has to be current while timing; otherwise a non-default
        # --gpu would be timed against the wrong device.
        with torch.cuda.device(device), torch.no_grad():
            # Drain the asynchronously queued warmup kernels first.
            torch.cuda.synchronize()
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            timings = []
            for _ in range(iterations):
                starter.record()
                model(dummy_input)
                ender.record()
                # elapsed_time() is only valid once both events completed.
                torch.cuda.synchronize()
                timings.append(starter.elapsed_time(ender))
        return sum(timings) / len(timings)

    # perf_counter is monotonic and higher resolution than time.time().
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(iterations):
            model(dummy_input)
    elapsed = time.perf_counter() - start
    return (elapsed / iterations) * 1000


def measure_peak_memory(model, device, input_size=(1, 3, 32, 32)):
    """Return the peak CUDA memory of one forward pass, in MiB.

    On a non-CUDA device nothing is measured, the model is left untouched and
    ``0.0`` is returned. Otherwise the model is put in eval mode on ``device``,
    the CUDA cache is emptied, the peak statistics are reset and a single
    no-grad forward pass on a random input is executed.

    Args:
        model: Module to profile.
        device: ``torch.device`` the measurement runs on.
        input_size: Shape of the random input tensor.

    Returns:
        ``torch.cuda.max_memory_allocated`` for ``device`` divided by 1024**2.
    """
    running_on_gpu = device.type == 'cuda'
    if not running_on_gpu:
        return 0.0

    model.eval()
    model.to(device)

    # Start from a clean slate so the peak reflects this forward pass only.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device=device)

    probe = torch.randn(*input_size).to(device)
    with torch.no_grad():
        model(probe)

    peak_bytes = torch.cuda.max_memory_allocated(device=device)
    return peak_bytes / (1024 ** 2)


def train_one_epoch(loader, model, criterion, optimizer, device, epoch_desc):
    """Run one optimisation pass over ``loader``.

    The model is called as ``logits, _ = model(x)``: it must return a pair
    whose first element holds the class logits; the second is ignored.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches.
        model: Network being trained; switched to train mode.
        criterion: Loss applied to ``(logits, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch is moved to.
        epoch_desc: Label shown on the tqdm progress bar.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc=epoch_desc)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits, _ = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the running value is a
        # per-sample mean.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.max(1).indices
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        running_loss = loss_sum / n_seen
        running_acc = 100. * n_correct / n_seen
        progress.set_postfix(loss=f'{running_loss:.3f}', acc=f'{running_acc:.2f}%')

    return loss_sum / n_seen, 100. * n_correct / n_seen


def evaluate(loader, model, criterion, device, desc):
    """Compute loss and accuracy of ``model`` over ``loader`` without gradients.

    The model is called as ``logits, _ = model(x)``: it must return a pair
    whose first element holds the class logits; the second is ignored.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate; switched to eval mode.
        criterion: Loss applied to ``(logits, targets)``.
        device: Device every batch is moved to.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc=desc)
    with torch.no_grad():
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits, _ = model(inputs)
            batch_loss = criterion(logits, targets)

            # Weight by batch size so the aggregate is a per-sample mean.
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.max(1).indices
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

            running_loss = loss_sum / n_seen
            running_acc = 100. * n_correct / n_seen
            progress.set_postfix(loss=f'{running_loss:.3f}', acc=f'{running_acc:.2f}%')

    return loss_sum / n_seen, 100. * n_correct / n_seen


def log_learnable_aggregation_params(model, pruned_nodes_dict, epoch, logger=None):
    """Print (and optionally log) the state of every SE-attention aggregator.

    For each ``(cell_idx, node_idx)`` key whose cell exposes a matching entry
    in ``learnable_aggregators``, reports the softmax-normalised
    ``raw_weights`` and the norms of ``excitation[0].weight`` and
    ``excitation[3].weight``. Keys without an aggregator are skipped.

    Args:
        model: Network exposing ``cells``.
        pruned_nodes_dict: Mapping keyed by ``(cell_idx, node_idx)`` tuples;
            only the keys are used.
        epoch: Zero-based epoch index; reported one-based.
        logger: Optional object with an ``info(msg)`` method that receives
            every printed line.
    """
    log_messages = [f"\n [Epoch {epoch+1}] SE-Attention Aggregation Parameters:"]

    for cell_idx, node_idx in pruned_nodes_dict:
        cell = model.cells[cell_idx]
        aggregators = getattr(cell, 'learnable_aggregators', None)
        node_key = str(node_idx)
        if aggregators is None or node_key not in aggregators:
            continue

        aggregator = aggregators[node_key]
        normalized_w = torch.softmax(aggregator.raw_weights.detach(), dim=-1).cpu().numpy()
        # Render every weight instead of assuming exactly two of them; for the
        # two-weight case the text is identical to the previous fixed format.
        weights_text = " ".join(f"{w:.4f}" for w in normalized_w.reshape(-1))

        fc1_norm = torch.norm(aggregator.excitation[0].weight.detach()).cpu().item()
        fc2_norm = torch.norm(aggregator.excitation[3].weight.detach()).cpu().item()

        log_messages.append(
            f"     Cell {cell_idx}, Node {node_idx}: "
            f"Weights: [{weights_text}], "
            f"FC1 Norm: {fc1_norm:.4f}, FC2 Norm: {fc2_norm:.4f}"
        )

    for msg in log_messages:
        print(msg)
        if logger is not None:
            logger.info(msg)


def get_pruned_nodes_dict(search_model_path, num_prune_nodes, C, layers, num_classes, device):
    """Select the nodes with the smallest absolute weight in a search model.

    A ``NetworkCIFAR`` with ``cell_type="normal_weighted"`` is built, the
    checkpoint at ``search_model_path`` is loaded into it, and every entry of
    ``node_weights`` in each non-reduction cell is ranked by absolute value.
    Ties are broken by cell index, then node index.

    Args:
        search_model_path: Path of the search-model state dict.
        num_prune_nodes: How many of the lowest-ranked nodes to select.
        C: Initial channel count passed to ``NetworkCIFAR``.
        layers: Layer count passed to ``NetworkCIFAR``.
        num_classes: Class count passed to ``NetworkCIFAR``.
        device: Device the search model and checkpoint are loaded onto.

    Returns:
        Dict mapping each selected ``(cell_idx, node_idx)`` to ``True``.
    """
    print(f"Loading search model from {search_model_path} to determine nodes to prune...")
    search_model = NetworkCIFAR(C, num_classes, layers, False, DARTS, cell_type="normal_weighted")
    search_model = search_model.to(device)
    checkpoint = torch.load(search_model_path, map_location=device)
    search_model.load_state_dict(checkpoint)
    search_model.eval()

    # Gather (|weight|, cell, node) triples from every non-reduction cell
    # that carries node weights.
    candidates = []
    for cell_idx, cell in enumerate(search_model.cells):
        if cell.reduction:
            continue
        if not hasattr(cell, 'node_weights') or cell.node_weights is None:
            continue
        candidates.extend(
            (abs(weight.item()), cell_idx, node_idx)
            for node_idx, weight in enumerate(cell.node_weights)
        )

    ranked = sorted(candidates)
    pruned_nodes_list = [(cell_idx, node_idx) for _, cell_idx, node_idx in ranked[:num_prune_nodes]]
    pruned_nodes_dict = dict.fromkeys(pruned_nodes_list, True)

    print(f"Identified {len(pruned_nodes_dict)} nodes to prune: {pruned_nodes_list}")
    return pruned_nodes_dict


def freeze_learnable_aggregators(model):
    """Disable gradients for every parameter under ``learnable_aggregators``.

    A parameter is selected when its qualified name contains the substring
    ``'learnable_aggregators'``. A summary line is printed only if at least
    one parameter matched.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
    """
    matched = [
        param
        for name, param in model.named_parameters()
        if 'learnable_aggregators' in name
    ]
    for param in matched:
        param.requires_grad = False

    frozen_params_count = len(matched)
    if frozen_params_count > 0:
        print(f"Frozen {frozen_params_count} parameters in learnable_aggregators.")


def unfreeze_learnable_aggregators(model):
    """Enable gradients for every parameter under ``learnable_aggregators``.

    A parameter is selected when its qualified name contains the substring
    ``'learnable_aggregators'``. A summary line is printed only if at least
    one parameter matched.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
    """
    matched = [
        param
        for name, param in model.named_parameters()
        if 'learnable_aggregators' in name
    ]
    for param in matched:
        param.requires_grad = True

    unfrozen_params_count = len(matched)
    if unfrozen_params_count > 0:
        print(f"Unfrozen {unfrozen_params_count} parameters in learnable_aggregators.")


def debug_print_pruned_operations(model, pruned_nodes_dict):
    """Print the raw and softmax-normalised weights of each SE-attention node.

    For every ``(cell_idx, node_idx)`` key, looks up ``str(node_idx)`` in the
    cell's ``learnable_aggregators``. Keys without such an aggregator are
    reported as not being SE-attention nodes.

    Args:
        model: Network exposing ``cells``.
        pruned_nodes_dict: Mapping keyed by ``(cell_idx, node_idx)`` tuples;
            only the keys are used.
    """
    print("\n Initial check of SE-Attention Aggregators:")
    for cell_idx, node_idx in pruned_nodes_dict:
        cell = model.cells[cell_idx]
        node_key = str(node_idx)

        has_aggregator = hasattr(cell, 'learnable_aggregators') and node_key in cell.learnable_aggregators
        if not has_aggregator:
            print(f"  Cell {cell_idx}, Node {node_idx}: Not an SE-Attention node.")
            continue

        aggregator = cell.learnable_aggregators[node_key]
        print(f"  Cell {cell_idx}, Node {node_idx}:")
        raw = aggregator.raw_weights.detach().cpu().numpy()
        print(f"    - Raw Weights: {raw}")
        normalised = torch.softmax(aggregator.raw_weights.detach(), dim=-1).cpu().numpy()
        print(f"    - Norm Weights: {normalised}")


def check_se_module_training_status(model, pruned_nodes_dict):
    """Report how many SE-aggregator parameters currently require gradients.

    For each ``(cell_idx, node_idx)`` key with a matching entry in the cell's
    ``learnable_aggregators``, prints whether the aggregator is "Training"
    (at least one trainable element) or "Frozen", followed by an overall
    total. Keys without an aggregator are ignored.

    Args:
        model: Network exposing ``cells``.
        pruned_nodes_dict: Mapping keyed by ``(cell_idx, node_idx)`` tuples;
            only the keys are used.

    Returns:
        Tuple ``(trainable_elements, total_elements)`` summed over all
        inspected aggregators.
    """
    print("\n SE Module Training Status Check:")
    total_se_params = 0
    trainable_se_params = 0

    for cell_idx, node_idx in pruned_nodes_dict:
        cell = model.cells[cell_idx]
        if not hasattr(cell, 'learnable_aggregators'):
            continue
        node_key = str(node_idx)
        if node_key not in cell.learnable_aggregators:
            continue

        aggregator = cell.learnable_aggregators[node_key]
        # (element count, requires_grad) for each parameter of this aggregator.
        stats = [(p.numel(), p.requires_grad) for _, p in aggregator.named_parameters()]
        se_params = sum(count for count, _ in stats)
        se_trainable = sum(count for count, live in stats if live)

        total_se_params += se_params
        trainable_se_params += se_trainable

        training_status = "Frozen" if se_trainable <= 0 else "Training"
        print(f"   Cell {cell_idx}, Node {node_idx}: {training_status} "
              f"({se_trainable}/{se_params} params trainable)")

    print(f"   Total SE params: {trainable_se_params}/{total_se_params} trainable")
    return trainable_se_params, total_se_params


def main(args):
    """Train the pruned NR-DARTS network on CIFAR-10 and report its metrics.

    Steps, in order:
      1. Pick the device, seed the RNGs and create the output directory.
      2. Resolve the search-model checkpoint (exits with status 1 if missing).
      3. Build CIFAR-10 loaders with a 75/25 train/validation split.
      4. Select the nodes to prune and build the ``SEFusion_train`` network.
      5. Warm up with the SE aggregators frozen, then unfreeze them and
         continue with a fresh two-group SGD optimizer and cosine schedule.
      6. Keep the weights of the best validation epoch, evaluate them on the
         test set, and measure FLOPs, latency and peak memory.
      7. Save the best weights and a JSON history to ``args.save``.

    Args:
        args: Parsed command-line namespace produced by the module ``parser``.
    """
    import copy  # local import: only needed to snapshot the best weights

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    set_seed(args.seed)

    os.makedirs(args.save, exist_ok=True)
    best_model_path = os.path.join(args.save, f'nrdarts-seed{args.seed}-best.pth')
    log_path = os.path.join(args.save, f'nrdarts-seed{args.seed}-log.json')

    search_model_path = args.search_model_path

    if search_model_path is None:
        search_save_path = f"exp/search_seed{args.seed}"
        search_model_path = os.path.join(search_save_path, f"search-seed{args.seed}-best.pth")
        print(f"INFO: --search_model_path not provided. Using default path: {search_model_path}")

    if not os.path.exists(search_model_path):
        print(f"ERROR: Search model not found at {search_model_path}")
        sys.exit(1)

    train_transform, valid_transform = _data_transforms_cifar10(args)
    full_train = torchvision.datasets.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    test_data = torchvision.datasets.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    # NOTE(review): the validation subset is carved out of `full_train`, so it
    # is served through `train_transform` as well - confirm this is intended
    # for model selection.
    train_len = int(0.75 * len(full_train))
    val_len = len(full_train) - train_len
    train_data, val_data = random_split(full_train, [train_len, val_len])

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)

    pruned_nodes_dict = get_pruned_nodes_dict(search_model_path, args.num_prune_nodes, args.init_channels, args.layers, 10, device)

    model = NetworkCIFAR(
        args.init_channels, 10, args.layers, False, DARTS,
        args.drop_path_prob, cell_type="SEFusion_train",
        pruned_nodes_info=pruned_nodes_dict, dropout_p=args.dropout_p
    ).to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    print("\n--- Starting Warmup Phase ---")
    # During warmup only the base network is optimised.
    freeze_learnable_aggregators(model)

    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    debug_print_pruned_operations(model, pruned_nodes_dict)
    check_se_module_training_status(model, pruned_nodes_dict)

    best_val_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_model_wts = None
    start_time = time.time()

    for epoch in range(args.epochs):
        if epoch == args.warmup_epochs:
            print(f"\n--- Warmup complete. Re-initializing optimizer for Main Training Phase ---")
            unfreeze_learnable_aggregators(model)
            check_se_module_training_status(model, pruned_nodes_dict)

            base_params = [p for n, p in model.named_parameters() if 'learnable_aggregators' not in n and p.requires_grad]
            se_params = [p for n, p in model.named_parameters() if 'learnable_aggregators' in n and p.requires_grad]

            # SE aggregators get their own learning rate and no weight decay.
            optimizer = optim.SGD([
                {"params": base_params, "lr": args.learning_rate, "weight_decay": args.weight_decay},
                {"params": se_params, "lr": args.se_learning_rate, "weight_decay": 0.0},
            ], momentum=args.momentum)

            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - args.warmup_epochs)

        phase = f"Warmup {epoch+1}/{args.warmup_epochs}" if epoch < args.warmup_epochs else f"Main {epoch+1-args.warmup_epochs}/{args.epochs-args.warmup_epochs}"

        train_loss, train_acc = train_one_epoch(train_loader, model, criterion, optimizer, device, f"[{phase}] Training")
        val_loss, val_acc = evaluate(val_loader, model, criterion, device, f"[{phase}] Validation")

        print(f"Epoch {epoch+1}: TrainAcc={train_acc:.2f}% | ValidAcc={val_acc:.2f}%")
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        if epoch >= args.warmup_epochs and (epoch + 1) % 10 == 0:
            log_learnable_aggregation_params(model, pruned_nodes_dict, epoch)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands back references to the live parameter
            # tensors. Without a deep copy, later epochs would silently
            # overwrite this "best" snapshot in place.
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, best_model_path)
            print(f"New best model saved with validation accuracy: {best_val_acc:.2f}%")

        scheduler.step()

    training_time = time.time() - start_time
    print(f"\nTraining finished! Best validation accuracy: {best_val_acc:.2f}%")

    if best_model_wts is None:
        # No epoch beat the initial 0.0 (e.g. --epochs 0): fall back to the
        # current weights instead of crashing in load_state_dict(None).
        print("WARNING: no epoch improved validation accuracy; using the current model weights.")
        best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)
    test_loss, test_acc = evaluate(test_loader, model, criterion, device, "Testing Best Model")
    flops = measure_flops(model, device=device)
    latency = measure_inference_latency(model, device)
    max_memory_MB = measure_peak_memory(model, device)

    print("\n" + "="*50)
    print("           TRAINING FINISHED - RESULTS")
    print("="*50)
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"FLOPs: {flops:.2f} M")
    print(f"Latency: {latency:.2f} ms")
    print(f"Max GPU Memory: {max_memory_MB:.2f} MB")
    print(f"Total Training Time: {training_time/3600:.2f} hours")
    print("="*50)

    history['test_acc'] = test_acc
    history['flops'] = flops
    history['training_time_sec'] = training_time
    history['inference_latency_ms'] = latency
    history['memory_usage_MB'] = max_memory_MB

    torch.save(best_model_wts, best_model_path)
    with open(log_path, 'w') as f:
        json.dump(history, f, indent=4)
    print(f"Best model and logs saved in '{args.save}'")

# Script entry point: parse the command-line flags declared by the module
# level `parser` and hand them to main(). Nothing runs on import.
if __name__ == '__main__':
    args = parser.parse_args()
    main(args)