# -*- coding: utf-8 -*-
import sys

sys.path.append('.')
sys.path.append('..')
sys.path.append('../..')

import time
import torch
import torch.nn.functional as F
from model import *

import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from util import Logger, Bar, AverageMeter, accuracy, load_dataset, warp_decay, split_params, init_config, bptt_model_setting
from spikingjelly.activation_based import functional
from model.layer import*

from model import ResNet_ANN
from datetime import datetime
import numpy as np

def tpd(outputs, temperature=3.0):
    """
    Progressive temporal self-distillation loss.

    For every k in 1..T-1 the mean of the first k timesteps is pulled towards
    the mean of the first k+1 timesteps (1 vs 1-2 avg, 1-2 avg vs 1-3 avg, ...)
    using a temperature-softened cross-entropy. The returned value is the
    mean over those T-1 terms.

    Args:
        outputs: [T, batch_size, num_classes] student outputs at each timestep
        temperature: softmax temperature applied to both averages

    Returns:
        Scalar tensor; zero (on ``outputs.device``) when fewer than two
        timesteps are available.
    """
    num_steps = outputs.shape[0]
    if num_steps < 2:
        # A single timestep has nothing later to learn from.
        return torch.tensor(0.0, device=outputs.device)

    def _prefix_pair_loss(k):
        # Soft cross-entropy between the k-step average (prediction) and the
        # (k+1)-step average (target), both of shape [batch_size, num_classes].
        shorter_avg = outputs[:k].mean(dim=0)
        longer_avg = outputs[:k + 1].mean(dim=0)
        target_probs = F.softmax(longer_avg / temperature, dim=1)
        pred_log_probs = F.log_softmax(shorter_avg / temperature, dim=1)
        return -(target_probs * pred_log_probs).sum(dim=1).mean()

    pair_losses = [_prefix_pair_loss(k) for k in range(1, num_steps)]

    # Normalise by the number of terms so the magnitude does not grow with T.
    return sum(pair_losses) / len(pair_losses)

def train(train_ldr, optimizer, model, t_model, evaluator, args, num_classes=None):
    """Run one training epoch of the student, optionally with distillation terms.

    Per batch the optimised loss is
    ``hard_loss + args.alpha * kd_loss + args.tpd_weight * tpd_loss``.
    The KD term is only computed when ``args.alpha > 0`` and the TPD term only
    when ``args.use_tpd`` is truthy; otherwise each contributes ``0.0``.

    Args:
        train_ldr: iterable of ``(inputs, labels)`` batches; must support ``len()``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        model: student network. ``model.step_mode == 's'`` selects the
            per-timestep forward loop; any other value feeds the batch
            repeated ``args.T`` times in a single call.
        t_model: teacher network (put in eval mode, queried under ``no_grad``).
        evaluator: criterion forwarded to ``cal_loss``.
        args: configuration object; reads ``T``, ``alpha``, ``use_tpd``,
            ``use_tmpd``, ``mask_prob``, ``mask_lambda``, ``tpd_temp`` and
            ``tpd_weight``.
        num_classes: not used by this function; kept for caller compatibility.

    Returns:
        Tuple ``(top1_avg, loss_avg, loss_stats)`` where ``loss_stats`` is a
        dict with the running averages under ``'hard_loss'``, ``'kd_loss'``
        and ``'tpd_loss'``.
    """
    model.train()
    t_model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Statistics for the individual loss terms
    hard_losses = AverageMeter()
    kd_losses = AverageMeter()
    tpd_losses = AverageMeter()

    # Loop-invariant flags and the target device are resolved once up front.
    use_tpd = args.use_tpd
    use_timestep_mask = args.use_tmpd
    use_kd = args.alpha > 0  # teacher is only queried when its loss has weight
    device = next(model.parameters()).device

    end = time.time()

    bar = Bar('Processing', max=len(train_ldr))

    for idx, (ptns, labels) in enumerate(train_ldr):
        ptns, labels = ptns.to(device), labels.to(device)
        batch_size = ptns.size(0)

        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()
        functional.reset_net(model)

        # Student network forward propagation
        if model.step_mode == 's':
            # One forward call per timestep, stacked along a new leading axis.
            output = torch.stack([model(ptns) for _ in range(args.T)], dim=0)
        else:
            # Repeat the batch T times along a new leading axis (a view, no
            # scratch tensor needed) and fold that axis into the batch axis.
            in_data = ptns.unsqueeze(0).expand(args.T, *ptns.shape)
            in_data = in_data.reshape(-1, *in_data.shape[2:])
            output = model(in_data)
        avg_fr = output.mean(dim=0)

        # Calculate teacher network outputs (if distillation needed)
        teacher_outputs = None
        if use_kd:
            # NOTE(review): the return value is not used anywhere in this
            # function; the call is kept in case make_teacher has side
            # effects -- confirm and remove if it does not.
            teacher_labels = make_teacher(output, labels)

            with torch.no_grad():
                if use_timestep_mask:
                    # A separate teacher forward for each timestep
                    teacher_outputs = [
                        t_model.forward_with_timestep_mask(
                            ptns, time_step=i, total_steps=args.T,
                            mask_prob=args.mask_prob, mask_lambda=args.mask_lambda
                        ).detach()
                        for i in range(args.T)
                    ]
                else:
                    # All timesteps share the same teacher output
                    shared_teacher_output = t_model.forward_with_mask(
                        ptns, mask_prob=args.mask_prob, mask_lambda=args.mask_lambda
                    ).detach()
                    teacher_outputs = [shared_teacher_output] * args.T

        # Hard-label loss, delegated to cal_loss
        hard_loss = cal_loss(output, labels, evaluator)

        # Distillation loss: compare student and teacher per timestep, then
        # average over the T timesteps. The KD temperature is fixed at 3.0.
        kd_loss_value = 0.0
        if use_kd and teacher_outputs is not None:
            for i in range(args.T):
                kd_loss_value += kd_loss(output[i], teacher_outputs[i], 3.0)
            kd_loss_value = kd_loss_value / args.T

        # Temporal consistency loss
        tpd_loss = 0.0
        if use_tpd:
            tpd_loss = tpd(output, args.tpd_temp)

        # Total loss
        loss = (hard_loss +
                kd_loss_value * args.alpha +
                tpd_loss * args.tpd_weight)

        loss.backward()
        optimizer.step()

        # Measure accuracy and record loss
        prec1, prec5 = accuracy(avg_fr.data, labels.data, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # Update statistics for the individual loss terms. float() accepts
        # both the scalar tensor of an active term and the plain 0.0 of a
        # disabled one, so no branching is needed here.
        hard_losses.update(hard_loss.item(), batch_size)
        kd_losses.update(float(kd_loss_value), batch_size)
        tpd_losses.update(float(tpd_loss), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
            batch=idx + 1,
            size=len(train_ldr),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg,
            top5=top5.avg,
        )
        bar.next()
    bar.finish()

    # Return statistics for the individual loss terms
    loss_stats = {
        'hard_loss': hard_losses.avg,
        'kd_loss': kd_losses.avg,
        'tpd_loss': tpd_losses.avg
    }
    return top1.avg, losses.avg, loss_stats

def test(val_ldr, model, t_model, evaluator, args):
    """Evaluate the student on a validation loader.

    Besides the accuracy/loss of the average over all ``args.T`` timesteps,
    the accuracy of every prefix average (first 1, first 2, ..., first T
    timesteps) is tracked and the top-1 values are printed at the end.

    Args:
        val_ldr: iterable of ``(inputs, labels)`` batches; must support ``len()``.
        model: student network; ``model.step_mode == 's'`` selects the
            per-timestep forward loop, any other value a single call on the
            batch repeated ``args.T`` times.
        t_model: teacher network; only switched to eval mode here, it is not
            run during evaluation.
        evaluator: criterion called as ``evaluator(avg_output, labels)``.
        args: configuration object; reads ``T``.

    Returns:
        Tuple ``(top1_avg, loss_avg)`` for the full-T averaged output.
    """
    model.eval()
    t_model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # One pair of accuracy meters per inference length (1..T timesteps)
    timestep_top1 = [AverageMeter() for _ in range(args.T)]
    timestep_top5 = [AverageMeter() for _ in range(args.T)]

    device = next(model.parameters()).device
    end = time.time()
    bar = Bar('Processing', max=len(val_ldr))

    with torch.no_grad():
        for idx, (ptns, labels_batch) in enumerate(val_ldr):
            ptns, labels_batch = ptns.to(device), labels_batch.to(device)
            batch_size = ptns.size(0)

            # measure data loading time (previously never recorded, so the
            # progress bar always reported 0)
            data_time.update(time.time() - end)

            functional.reset_net(model)
            if model.step_mode == 's':
                # One forward call per timestep -> [T, batch_size, num_classes]
                output = torch.stack([model(ptns) for _ in range(args.T)], dim=0)
            else:
                # Repeat the batch T times along a new leading axis and fold
                # that axis into the batch axis for a single forward call.
                in_data = ptns.unsqueeze(0).expand(args.T, *ptns.shape)
                in_data = in_data.reshape(-1, *in_data.shape[2:])
                # Bring the result to [T, batch_size, num_classes] so that the
                # per-timestep statistics and the final average below are
                # computed from the same layout.
                output = model(in_data).reshape(args.T, batch_size, -1)

            # Accuracy when inference stops after the first t+1 timesteps
            for t in range(args.T):
                prefix_avg_fr = output[:t + 1].mean(dim=0)  # [batch_size, num_classes]
                prec1, prec5 = accuracy(prefix_avg_fr.data, labels_batch.data, topk=(1, 5))
                timestep_top1[t].update(prec1.item(), batch_size)
                timestep_top5[t].update(prec5.item(), batch_size)

            avg_fr = output.mean(dim=0)
            loss = evaluator(avg_fr, labels_batch)

            prec1, prec5 = accuracy(avg_fr.data, labels_batch.data, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)
            batch_time.update(time.time() - end)
            end = time.time()
            bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=idx + 1,
                size=len(val_ldr),
                data=data_time.avg,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            )
            bar.next()
        bar.finish()

        # Output accuracy for different timestep lengths
        print("\n=== Inference Accuracy at Different Timestep Lengths ===")
        top1_results = [f"T{t+1}={timestep_top1[t].avg:.2f}" for t in range(args.T)]
        print(f"Top1: {' | '.join(top1_results)}")

        return top1.avg, losses.avg

def main():
    """Build data loaders, student/teacher models and optimizer, then train.

    Reads its configuration from the module-level ``args`` object. After every
    epoch the student is evaluated; whenever validation accuracy improves, a
    checkpoint is written to ``<args.log_path>/model_weights.pth``. Once
    training finishes, that checkpoint (if present) is reloaded and evaluated
    one final time.

    Raises:
        NotImplementedError: if ``args.optim`` is not ``'sgdm'``/``'adam'`` or
            ``args.scheduler`` is not ``'cosine'`` (case-insensitive).
    """
    # Function-scope import: the file-level import only provides `datetime`.
    from datetime import datetime, timedelta

    # set device, data type
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    dtype = torch.float
    log = Logger(args, args.log_path)
    log.info_args(args)
    writer = SummaryWriter(args.log_path)

    train_data, val_data, num_class = load_dataset(args.dataset, args.data_path, cutout=args.cutout,
                                                   auto_aug=args.auto_aug)

    train_ldr = DataLoader(dataset=train_data, batch_size=args.train_batch_size, shuffle=True,
                           pin_memory=True, num_workers=args.num_workers)
    val_ldr = DataLoader(dataset=val_data, batch_size=args.val_batch_size, shuffle=False,
                         pin_memory=True, num_workers=args.num_workers)

    kwargs_spikes = {'v_reset': args.v_reset, 'thresh': args.thresh, 'decay': warp_decay(args.decay),
                     'detach_reset': args.detach_reset}

    # SECURITY NOTE: eval() executes whatever expression args.stu_arch holds.
    # Acceptable only while the configuration comes from a trusted source;
    # consider an explicit name -> constructor lookup instead.
    model = eval(args.stu_arch + f'(num_classes={num_class}, **kwargs_spikes)')
    model.to(device, dtype)
    t_model = ResNet_ANN.__dict__[args.tea_arch](num_classes=num_class, in_channels=3)
    t_model.to(device, dtype)

    bptt_model_setting(model, time_step=args.T, step_mode=args.step_mode)

    # Two parameter groups: one with weight decay, one without.
    params = split_params(model)
    params = [
        {'params': params[1], 'weight_decay': args.wd},
        {'params': params[2], 'weight_decay': 0}
    ]

    if args.optim.lower() == 'sgdm':
        optimizer = optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optim.lower() == 'adam':
        optimizer = optim.Adam(params, lr=args.lr, amsgrad=False)
    else:
        raise NotImplementedError()

    evaluator = torch.nn.CrossEntropyLoss()
    start_epoch = 0
    best_epoch = 0
    best_acc = 0.0

    if args.tea_path is not None:
        state = torch.load(args.tea_path, map_location=device, weights_only=True)
        t_model.load_state_dict(state['best_net'])
        log.info('Load checkpoint from epoch {}'.format(start_epoch))
        # NOTE(review): test() runs the forward pass on `model` (the student),
        # so this reports the student's accuracy, not the loaded teacher's.
        log.info('Test the checkpoint: {}'.format(test(val_ldr, model,t_model, evaluator, args=args)))

    args.start_epoch = start_epoch
    if args.scheduler.lower() == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0, T_max=args.num_epoch)
    else:
        raise NotImplementedError()

    training_start_time = time.time()

    # Output current configuration information
    print("="*60)
    print("Training Configuration")
    print("="*60)

    # TMPD functionality status
    if args.use_tmpd:
        print("✓ TMPD: Enabled")
        print("  - Use random mask areas for each timestep")
        print(f"  - Mask probability: {args.mask_prob}")
        print(f"  - Mixing weight: {args.mask_lambda}")
    else:
        print("✗ TMPD: Disabled")
        print("  - All timesteps use same teacher output")

    # Training method status
    print("🎯 Training Method: Traditional method (temporal-wise comparison)")

    # TPD status
    if args.use_tpd:
        print("🔄 TPD: Enabled")
        print(f"  - Loss weight: {args.tpd_weight}")
    else:
        print("🔄 TPD: Disabled")

    # Other important parameters
    print(f"📊 Dataset: {args.dataset}")
    print(f"⏱️ Timesteps: {args.T}")
    print(f"🎓 Student Network: {args.stu_arch}")
    print(f"👨‍🏫 Teacher Network: {args.tea_arch}")
    print(f"📈 Distillation Weight: α={args.alpha}")

    print("="*60)

    # try/finally guarantees the TensorBoard writer is closed (and its
    # pending events flushed) even if training is interrupted by an error.
    try:
        for epoch in range(start_epoch, args.num_epoch):
            train_acc, train_loss, loss_stats = train(train_ldr, optimizer, model, t_model, evaluator, args=args, num_classes=num_class)
            if args.scheduler != 'None':
                scheduler.step()
            val_acc, val_loss = test(val_ldr, model, t_model,evaluator, args=args)

            # Print raw and weighted loss values for each component on one line
            loss_str = f"Hard:{loss_stats['hard_loss']:.4f} | KD:{loss_stats['kd_loss']:.4f}→{loss_stats['kd_loss'] * args.alpha:.4f}"
            if args.use_tpd:
                loss_str += f" | TPD:{loss_stats['tpd_loss']:.4f}→{loss_stats['tpd_loss'] * args.tpd_weight:.4f}"
            loss_str += f" | Total:{train_loss:.4f}"
            print(f"Epoch {epoch:03d} - Loss: {loss_str}")
            if val_acc > best_acc:  # saving checkpoint
                best_acc = val_acc
                best_epoch = epoch
                state = {
                    'best_acc': best_acc,
                    'best_epoch': epoch,
                    'best_net': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, os.path.join(args.log_path, 'model_weights.pth'))

            # Estimate remaining time from the average epoch duration so far
            elapsed_time = time.time() - training_start_time
            avg_time_per_epoch = elapsed_time / (epoch - start_epoch + 1)
            remaining_epochs = args.num_epoch - epoch - 1
            eta_seconds = remaining_epochs * avg_time_per_epoch
            eta_hours = int(eta_seconds // 3600)
            eta_minutes = int((eta_seconds % 3600) // 60)
            eta_str = f"{eta_hours}h{eta_minutes:02d}m"

            # Estimated wall-clock completion time
            finish_time = datetime.now() + timedelta(seconds=eta_seconds)
            finish_str = finish_time.strftime("%H:%M")

            log.info(
                'Epoch %03d: train loss %.5f, test loss %.5f, train acc %.5f, test acc %.5f, Saved custom_model..  with acc %.5f in the epoch %03d | ETA: %s (finish ~%s)' % (
                    epoch, train_loss, val_loss, train_acc, val_acc, best_acc, best_epoch, eta_str, finish_str))

            # record in tensorboard
            writer.add_scalars('Loss', {'val': val_loss, 'train': train_loss}, epoch + 1)
            writer.add_scalars('Acc', {'val': val_acc, 'train': train_acc}, epoch + 1)
    finally:
        writer.close()

    training_end_time = time.time()
    total_training_time = training_end_time - training_start_time
    total_hours = int(total_training_time // 3600)
    total_minutes = int((total_training_time % 3600) // 60)

    log.info('Training completed!')
    log.info(f'Best accuracy: {best_acc:.4f} at epoch {best_epoch}')
    log.info(f'Total training time: {total_hours}h {total_minutes}m')

    # Load best model and run final test
    best_ckpt_path = os.path.join(args.log_path, 'model_weights.pth')
    if os.path.exists(best_ckpt_path):
        print("\n=== Loading Best Model for Final Testing ===")
        state = torch.load(best_ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(state['best_net'])

        # Run standard test
        final_acc, final_loss = test(val_ldr, model, t_model, evaluator, args=args)
        print(f"Final Standard Test Accuracy: {final_acc:.2f}%")

# Script entry point. `args` is imported here at module scope (rather than
# inside main) because main() reads it as a module-level global.
if __name__ == '__main__':
    from config.config import args

    # init_config is a project helper from `util`; it receives the config
    # object before training starts.
    init_config(args)
    main()

