import json
import os
import random
import sys
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import AverageMeter, accuracy, get_score, set_optimizer, adjust_learning_rate, warmup_learning_rate
from utils import save_model, update_json, SpecAugment, get_backbone_class, update_moving_average
from datasets import ICBHIDataset
from losses import PAFALoss

def parse_args():
    """Parse command-line options and create the output folder.

    Four switches default to ``True`` (``--cosine``, ``--ma_update``,
    ``--from_sl_official``, ``--audioset_pretrained``). Because they are
    ``store_true`` actions, passing them changes nothing; each one therefore
    also gets a ``--no_<name>`` counterpart that stores ``False`` into the
    same destination so the feature can be turned off from the CLI.

    Returns:
        argparse.Namespace: the parsed options.

    Side effects:
        Creates ``args.save_folder`` (and parents) if it does not exist.
    """
    import argparse
    parser = argparse.ArgumentParser(description=' Finetuning with CE + ECE')

    def add_default_on_flag(name):
        # The positive flag is kept for backward compatibility; the first
        # registered default (True) is the one argparse applies.
        parser.add_argument(f'--{name}', action='store_true', default=True)
        parser.add_argument(f'--no_{name}', dest=name, action='store_false', default=True,
                            help=f'disable {name} (enabled by default)')

    parser.add_argument('--tag', type=str, default='seed1_best_param')
    parser.add_argument('--dataset', type=str, default='icbhi')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--class_split', type=str, default='lungsound')
    parser.add_argument('--n_cls', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--desired_length', type=float, default=5.0)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    add_default_on_flag('cosine')
    parser.add_argument('--model', type=str, default='cnn6')
    parser.add_argument('--test_fold', type=str, default='official')
    parser.add_argument('--pad_types', type=str, default='repeat')
    parser.add_argument('--resz', type=float, default=1.0)
    parser.add_argument('--n_mels', type=int, default=128)
    add_default_on_flag('ma_update')
    parser.add_argument('--ma_beta', type=float, default=0.5)
    add_default_on_flag('from_sl_official')
    add_default_on_flag('audioset_pretrained')
    parser.add_argument('--method', type=str, default='ce')
    parser.add_argument('--nospec', action='store_true', default=False)
    parser.add_argument('--h', type=int, default=128)
    parser.add_argument('--w', type=int, default=1000)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--print_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=50)
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--save_folder', type=str, default='./results/cnn6_ce')
    parser.add_argument('--model_name', type=str, default='cnn6_ce')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--eval', action='store_true', default=False)
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--pretrained_ckpt', type=str, default=None)
    parser.add_argument('--two_cls_eval', action='store_true', default=False)
    args = parser.parse_args()

    os.makedirs(args.save_folder, exist_ok=True)
    return args


def set_loader(args):
    """Create the ICBHI training and validation data loaders.

    Both splits use a single resize transform whose target size is
    ``(int(args.h * args.resz), int(args.w * args.resz))``. The training
    loader shuffles and drops the last incomplete batch; the validation
    loader does neither.

    Args:
        args: parsed options; reads ``h``, ``w``, ``resz``, ``batch_size`` and
            ``num_workers`` here and is also forwarded to ``ICBHIDataset``.

    Returns:
        tuple: ``(train_loader, val_loader, args)``.
    """
    target_size = (int(args.h * args.resz), int(args.w * args.resz))

    def build_resize_pipeline():
        # Each split gets its own Compose/Resize instance.
        return transforms.Compose([transforms.Resize(size=target_size)])

    pipelines = {True: build_resize_pipeline(), False: build_resize_pipeline()}
    splits = {
        is_train: ICBHIDataset(train_flag=is_train, transform=pipelines[is_train], args=args, print_flag=True)
        for is_train in (True, False)
    }

    shared_opts = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    train_loader = DataLoader(splits[True], shuffle=True, drop_last=True, **shared_opts)
    val_loader = DataLoader(splits[False], shuffle=False, **shared_opts)
    return train_loader, val_loader, args


def set_model(args):
    """Build the backbone, classifier head, projector, loss and optimizer.

    Args:
        args: parsed options. Reads ``model``, ``method``, ``n_cls``, ``h``,
            ``w``, ``resz``, ``n_mels``, ``nospec``, ``from_sl_official``,
            ``audioset_pretrained``, ``pretrained`` and ``pretrained_ckpt``.

    Returns:
        tuple: ``(model, classifier, projector, criterion, optimizer)`` where
        ``criterion`` is a one-element list holding a CUDA
        ``nn.CrossEntropyLoss`` and the optimizer covers the parameters of
        the model, the classifier and the projector.
    """
    # Backbone-specific constructor arguments.
    kwargs = {}
    if args.model == 'ast':
        kwargs['input_fdim'] = int(args.h * args.resz)
        kwargs['input_tdim'] = int(args.w * args.resz)
        kwargs['label_dim'] = args.n_cls
        kwargs['imagenet_pretrain'] = args.from_sl_official
        kwargs['audioset_pretrain'] = args.audioset_pretrained
    elif args.model == 'beats':
        kwargs['spec_transform'] = None if args.nospec else SpecAugment(args)
    elif args.model == 'cnn6':
        kwargs['n_mels'] = args.n_mels
        kwargs['pretrained'] = args.from_sl_official

    model = get_backbone_class(args.model)(**kwargs)

    if args.method == 'pafa' and args.model in ('beats', 'ast', 'cnn6'):
        # NOTE(review): `ProjectionHead` and args.hidden_dim / output_dim /
        # norm_type / proj_type are neither imported nor declared in this
        # file, so this branch fails until they are provided.
        classifier = nn.Linear(model.final_feat_dim, args.n_cls).cuda()
        # Only the BEATs variant enables attention in the projection head.
        projector = ProjectionHead(model.final_feat_dim, args.hidden_dim, args.output_dim,
                                   attention=(args.model == 'beats'),
                                   norm_type=args.norm_type, proj_type=args.proj_type).cuda()
    else:
        if args.model == 'ast':
            # AST reuses a copy of the backbone's own head as the classifier.
            classifier = deepcopy(model.mlp_head).cuda()
        else:
            classifier = nn.Linear(model.final_feat_dim, args.n_cls).cuda()
        projector = nn.Identity()

    if args.model not in ['ast', 'beats'] and args.from_sl_official:
        model.load_sl_official_weights()
        print('CNN6 pretrained weights loaded from PyTorch ImageNet-pretrained')

    if args.pretrained and args.pretrained_ckpt is not None:
        # torch.load unpickles the file: only point this at trusted checkpoints.
        ckpt = torch.load(args.pretrained_ckpt, map_location='cpu')
        state_dict = ckpt['model']
        new_state_dict = {}
        for k, v in state_dict.items():
            # Drop wrapper prefixes so keys match the bare backbone.
            if "module." in k:
                k = k.replace("module.", "")
            if "backbone." in k:
                k = k.replace("backbone.", "")
            # The checkpoint's mlp_head entries are never loaded into the backbone.
            if 'mlp_head' not in k:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict, strict=False)
        if ckpt.get('classifier', None) is not None:
            classifier.load_state_dict(ckpt['classifier'], strict=True)
        print('Pretrained model loaded from: {}'.format(args.pretrained_ckpt))

    criterion = [nn.CrossEntropyLoss().cuda()]

    model.cuda()
    optim_params = list(model.parameters()) + list(classifier.parameters()) + list(projector.parameters())
    optimizer = set_optimizer(args, optim_params)

    return model, classifier, projector, criterion, optimizer


def calculate_ece(logits, labels, n_bins=15):
    """Return the Expected Calibration Error of ``logits`` against ``labels``.

    Confidence is the maximum softmax probability of each sample. Samples are
    grouped into ``n_bins`` equal-width confidence bins over (0, 1]; the
    result is the sample-weighted mean of ``|accuracy - confidence|`` over the
    non-empty bins.

    Args:
        logits: ``(N, C)`` tensor of unnormalised class scores.
        labels: ``(N,)`` tensor of integer class indices on the same device.
        n_bins: number of equal-width confidence bins.

    Returns:
        float: the ECE in [0, 1]; ``0.0`` when there are no samples.
    """
    total = logits.shape[0]
    if total == 0:
        return 0.0

    with torch.no_grad():
        # Cast up front: logits produced under autocast may be half precision.
        probs = F.softmax(logits.float(), dim=1)
        confidences, predictions = probs.max(dim=1)
        correct = predictions.eq(labels).float()
        edges = torch.linspace(0.0, 1.0, n_bins + 1, device=confidences.device)

        ece = 0.0
        for lower, upper in zip(edges[:-1], edges[1:]):
            in_bin = (confidences > lower) & (confidences <= upper)
            n_in_bin = int(in_bin.sum().item())
            if n_in_bin == 0:
                continue
            gap = (confidences[in_bin].mean() - correct[in_bin].mean()).abs().item()
            ece += gap * n_in_bin / total
    return ece


def train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler=None):
    """Run one cross-entropy training epoch and report loss, accuracy and ECE.

    Args:
        train_loader: yields ``(images, labels)``; ``labels[0]`` holds the
            class labels used here.
        model, classifier, projector: modules put into train mode. Only the
            model and classifier take part in the forward pass.
        criterion: sequence whose first element is the classification loss.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number (used for warm-up and logging).
        args: parsed options; reads ``method``, ``model``, ``nospec``,
            ``transforms`` (unless ``nospec``), ``ma_update``, ``ma_beta`` and
            ``print_freq``.
        scaler: optional ``GradScaler``. When ``None``, the epoch runs in full
            precision without loss scaling.

    Returns:
        tuple: ``(losses.avg, top1.avg, ece)`` for the epoch.

    Raises:
        NotImplementedError: if ``args.method`` is not ``'ce'``.
    """
    if args.method != 'ce':
        # Only the CE forward pass exists below; fail with a clear message
        # instead of hitting an unbound `loss` later.
        raise NotImplementedError(f"train() only supports method 'ce', got '{args.method}'")

    model.train()
    classifier.train()
    projector.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    # Mixed precision is only used when a scaler is available to protect
    # the half-precision gradients.
    use_amp = scaler is not None

    epoch_logits = []
    epoch_labels = []

    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        class_labels = labels[0].cuda(non_blocking=True)
        bsz = class_labels.shape[0]

        if args.ma_update:
            # Snapshot the pre-step weights for the moving-average update below.
            with torch.no_grad():
                ma_ckpt = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict()),
                           deepcopy(projector.state_dict())]

        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)

        with torch.cuda.amp.autocast(enabled=use_amp):
            if args.model == 'beats':
                # BEATs receives its spec transform at construction time, so
                # raw inputs are passed; classifier output is averaged over dim 1.
                features = model(images, training=True)
                output = classifier(features).mean(dim=1)
            else:
                inputs = images if args.nospec else args.transforms(images)
                if args.model == 'cnn6':
                    features = model(inputs, training=True)
                else:  # AST
                    features = model(inputs, args=args, training=True)
                output = classifier(features)
            loss = criterion[0](output, class_labels)

        losses.update(loss.item(), bsz)
        epoch_logits.append(output.detach())
        epoch_labels.append(class_labels.detach())

        [acc1], _ = accuracy(output[:bsz], class_labels, topk=(1,))
        top1.update(acc1[0], bsz)

        optimizer.zero_grad()
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if args.ma_update:
            # Presumably blends the stepped weights with the snapshot using
            # ma_beta; see utils.update_moving_average for the exact rule.
            with torch.no_grad():
                model = update_moving_average(args.ma_beta, model, ma_ckpt[0])
                classifier = update_moving_average(args.ma_beta, classifier, ma_ckpt[1])
                projector = update_moving_average(args.ma_beta, projector, ma_ckpt[2])

        if (idx + 1) % args.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, idx + 1, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    # Calibration is measured over every sample seen during the epoch.
    epoch_logits = torch.cat(epoch_logits, dim=0)
    epoch_labels = torch.cat(epoch_labels, dim=0)
    ece = calculate_ece(epoch_logits, epoch_labels)
    print(f'Epoch {epoch} Train ECE: {ece:.4f}')

    return losses.avg, top1.avg, ece


def validate(val_loader, model, classifier, criterion, args, best_acc, best_model=None, projector=None):
    """Evaluate on the validation set and track the best checkpoint.

    Args:
        val_loader: yields ``(images, labels)``; ``labels[0]`` holds the class
            labels.
        model, classifier: modules evaluated in eval mode.
        criterion: sequence whose first element is the classification loss.
        args: parsed options; reads ``model``, ``n_cls``, ``two_cls_eval`` and
            ``print_freq``.
        best_acc: ``[S_p, S_e, Score, F1, ECE]`` of the best run so far.
        best_model: state dicts of the best run so far, or ``None``.
        projector: optional projector module; only its state dict is captured.

    Returns:
        tuple: ``(best_acc, best_model, save_bool)`` where ``save_bool`` is
        ``True`` when this evaluation became the new best.

    Relies on a module-level ``calculate_ece(logits, labels)``.
    """
    if projector is None:
        # The signature allows omitting the projector; a parameter-free
        # stand-in keeps the eval()/state_dict() calls below valid.
        projector = nn.Identity()

    save_bool = False
    model.eval()
    classifier.eval()
    projector.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # Per-class correct predictions and sample counts.
    hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls

    epoch_logits = []
    epoch_labels = []

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            labels = labels[0].cuda(non_blocking=True)
            bsz = labels.shape[0]

            with torch.cuda.amp.autocast():
                if args.model == 'beats':
                    features = model(images, training=False)
                    output = classifier(features)
                    output = output.mean(dim=1)
                elif args.model == 'cnn6':
                    features = model(images, training=False)
                    output = classifier(features)
                else:  # AST
                    features = model(images, args=args, training=False)
                    output = classifier(features)
                loss = criterion[0](output, labels)

            losses.update(loss.item(), bsz)
            [acc1], _ = accuracy(output, labels, topk=(1,))
            top1.update(acc1[0], bsz)

            # Keep logits and labels for the epoch-level ECE.
            epoch_logits.append(output.detach())
            epoch_labels.append(labels.detach())

            _, preds = torch.max(output, 1)
            # One host transfer per batch; the loop variables must not reuse
            # `idx`, which is the batch index used for progress printing.
            for target, pred in zip(labels.tolist(), preds.tolist()):
                counts[target] += 1.0
                if args.two_cls_eval:
                    # Two-class view: class 0 must be predicted exactly, any
                    # non-zero class counts as a hit for a non-zero target.
                    is_hit = (pred == 0) if target == 0 else (pred > 0)
                else:
                    is_hit = pred == target
                if is_hit:
                    hits[target] += 1.0

            sp, se, sc, f1_normal = get_score(hits, counts)
            batch_time.update(time.time() - end)
            end = time.time()

            if (idx + 1) % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'S_p {sp:.3f}\t'
                      'S_e {se:.3f}\t'
                      'Score {sc:.3f}\t'
                      'F1 Score {f1:.3f}'.format(
                    idx + 1, len(val_loader), batch_time=batch_time,
                    loss=losses, top1=top1, sp=sp, se=se, sc=sc,
                    f1=f1_normal))

    # ECE over the whole validation pass.
    epoch_logits = torch.cat(epoch_logits, dim=0)
    epoch_labels = torch.cat(epoch_labels, dim=0)
    ece = calculate_ece(epoch_logits, epoch_labels)
    # No epoch number is available here, so the label is the fixed word "Val".
    print(f'Epoch Val Val ECE: {ece:.4f}')

    # Index 2 is the Score slot of best_acc (index 3 is F1, index 4 is ECE).
    # NOTE(review): the 0.1 sensitivity floor depends on get_score's units.
    if sc > best_acc[2] and se > 0.1:
        save_bool = True
        best_acc = [sp, se, sc, f1_normal, ece]
        best_model = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict()), deepcopy(projector.state_dict())]

    print(' * S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f} (Best S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f})'
          .format(sp, se, sc, best_acc[0], best_acc[1], best_acc[2]))
    print(' * F1 Score: {:.2f} (F1 Score: {:.2f})'.format(f1_normal, best_acc[3]))
    print(' * Acc@1 {top1.avg:.2f}'.format(top1=top1))
    print(' * Val ECE: {:.4f} (Best Val ECE: {:.4f})'.format(ece, best_acc[4]))

    return best_acc, best_model, save_bool


def main():
    """Entry point: train (or only evaluate) the model and record results.

    Writes ``train_args.json`` and, while training, ``ece_log.json`` plus
    checkpoints into ``args.save_folder``; finally updates ``results.json``
    under ``args.save_dir`` with the best metrics.
    """
    args = parse_args()
    # Dump the options before non-serialisable objects are attached to args.
    with open(os.path.join(args.save_folder, 'train_args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN auto-tune its algorithm choice,
    # which may differ between runs; confirm this is wanted alongside the
    # deterministic flag above.
    torch.backends.cudnn.benchmark = True

    best_model = None
    best_acc = [0, 0, 0, 0, 0]  # [S_p, S_e, Score, F1, ECE]

    if not args.nospec:
        print("Enable SpecAugment for CNN6 (spectrogram input)")
        args.transforms = SpecAugment(args)

    train_loader, val_loader, args = set_loader(args)
    model, classifier, projector, criterion, optimizer = set_model(args)

    # Default first, so a missing resume file falls back to a fresh run
    # instead of leaving start_epoch undefined.
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            # Mirror set_model's checkpoint handling: restore the head when
            # the checkpoint carries one, so it is not left re-initialised.
            if checkpoint.get('classifier', None) is not None:
                classifier.load_state_dict(checkpoint['classifier'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"Loaded checkpoint {args.resume} (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at {args.resume}")

    scaler = torch.cuda.amp.GradScaler()

    print('*' * 20)
    print('Checkpoint Name: {}'.format(args.model_name))

    if not args.eval:
        print(f'Training for {args.epochs} epochs on {args.dataset} (CNN6 + CE + ECE)')
        ece_log = []
        for epoch in range(args.start_epoch, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            time1 = time.time()
            loss, acc, train_ece = train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args,
                                         scaler)
            time2 = time.time()
            print(f'Train epoch {epoch}, total time {time2 - time1:.2f}, accuracy:{acc:.2f}, ECE:{train_ece:.4f}')
            best_acc, best_model, save_bool = validate(val_loader, model, classifier, criterion, args, best_acc,
                                                       best_model, projector)
            # NOTE(review): best_acc[4] is the ECE of the best epoch so far,
            # not necessarily this epoch's validation ECE.
            ece_log.append({'epoch': epoch, 'train_ece': train_ece, 'val_ece': best_acc[4]})
            # The log is rewritten every epoch so it survives interruptions.
            with open(os.path.join(args.save_folder, 'ece_log.json'), 'w') as f:
                json.dump(ece_log, f, indent=4)
            if save_bool:
                save_file = os.path.join(args.save_folder, f'best_epoch_{epoch}.pth')
                print(f'Best ckpt saved (Score={best_acc[2]:.2f}, ECE={best_acc[4]:.4f}) at epoch {epoch}')
                save_model(model, optimizer, args, epoch, save_file, classifier, projector)
            if epoch % args.save_freq == 0:
                save_file = os.path.join(args.save_folder, f'epoch_{epoch}.pth')
                save_model(model, optimizer, args, epoch, save_file, classifier, projector)

        if best_model is None:
            # No epoch ran or none met validate()'s criterion: nothing to restore.
            print('No epoch satisfied the best-checkpoint criterion; best.pth not written')
        else:
            save_file = os.path.join(args.save_folder, 'best.pth')
            model.load_state_dict(best_model[0])
            classifier.load_state_dict(best_model[1])
            projector.load_state_dict(best_model[2])
            save_model(model, optimizer, args, args.epochs, save_file, classifier, projector)
    else:
        print(f'Testing pretrained checkpoint on {args.dataset}')
        best_acc, _, _ = validate(val_loader, model, classifier, criterion, args, best_acc, best_model, projector)

    update_json(args.model_name, best_acc, path=os.path.join(args.save_dir, 'results.json'))
    print(f'Checkpoint {args.model_name} finished (Best ECE: {best_acc[4]:.4f})')

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()