import os
import sys
import json
import math
import random
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import AverageMeter, accuracy, get_score, set_optimizer, adjust_learning_rate, warmup_learning_rate
from utils import save_model, update_json, SpecAugment, get_backbone_class, update_moving_average
from datasets import ICBHIDataset
from losses import PAFALoss

class ProjectionHead(nn.Module):
    """Projection head mapping backbone features to the embedding fed to the PAFA loss.

    Either a single ``nn.Linear`` (``use_linear=True`` or ``proj_type='linear'``) or an
    MLP of ``num_layers`` hidden blocks (Linear -> norm -> ReLU) followed by a final
    Linear to ``out_dim``. All layers act on the last axis of the input.

    Args:
        in_dim: Size of the last axis of the input features.
        hidden_dim: Width of the hidden MLP layers.
        out_dim: Size of the projected output.
        num_layers: Number of hidden (Linear, norm, ReLU) blocks; 0 gives a single
            Linear from ``in_dim`` to ``out_dim`` inside the Sequential.
        use_linear: If True, use a single Linear layer instead of the MLP.
        scale: Initial value of the learnable ``scale`` parameter. It is registered
            (so it is saved and handed to the optimizer) but not used in ``forward``.
        max_s: Stored as ``self.max_s``; not used in ``forward``.
        attention: Accepted because ``set_model`` passes it; stored as
            ``self.attention`` but it does not change the architecture built here.
        norm_type: Normalization inside hidden blocks: ``'ln'`` (LayerNorm, default),
            ``'bn'`` (BatchNorm1d; expects 2-D ``(batch, features)`` input) or
            ``'none'``/``None`` (no normalization).
        proj_type: ``'mlp'`` (default) or ``'linear'``; ``'linear'`` is equivalent to
            ``use_linear=True``.

    Raises:
        ValueError: for an unknown ``proj_type`` or (MLP only) ``norm_type``.
    """

    def __init__(self,
                 in_dim,
                 hidden_dim=2048,
                 out_dim=128,
                 num_layers=2,
                 use_linear=False,
                 scale=16.0,
                 max_s=40.0,
                 attention=False,
                 norm_type='ln',
                 proj_type='mlp'):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.scale = nn.Parameter(torch.tensor(scale))
        self.max_s = max_s
        self.attention = attention
        self.norm_type = norm_type
        self.proj_type = proj_type

        if proj_type not in ('mlp', 'linear'):
            raise ValueError(f"Unsupported proj_type: {proj_type!r} (expected 'mlp' or 'linear')")

        if use_linear or proj_type == 'linear':
            self.layers = nn.Linear(in_dim, out_dim)
            # Same init nn.Linear uses by default, applied explicitly.
            nn.init.kaiming_uniform_(self.layers.weight, a=math.sqrt(5))
        else:
            layers = []
            for i in range(num_layers):
                layers.append(nn.Linear(in_dim if i == 0 else hidden_dim, hidden_dim))
                layers.append(self._make_norm(norm_type, hidden_dim))
                layers.append(nn.ReLU(inplace=True))
            # With no hidden blocks the output layer must read in_dim, not hidden_dim.
            layers.append(nn.Linear(hidden_dim if num_layers > 0 else in_dim, out_dim))
            self.layers = nn.Sequential(*layers)

    @staticmethod
    def _make_norm(norm_type, dim):
        """Return the normalization layer placed after each hidden Linear."""
        if norm_type == 'ln':
            return nn.LayerNorm(dim)
        if norm_type == 'bn':
            return nn.BatchNorm1d(dim)
        if norm_type in (None, 'none'):
            return nn.Identity()
        raise ValueError(f"Unsupported norm_type: {norm_type!r} (expected 'ln', 'bn' or 'none')")

    def forward(self, x):
        """Project ``x`` along its last axis to ``out_dim`` features."""
        x = self.layers(x)
        return x

def calculate_ece(logits, labels, n_bins=10):
    """Return the expected calibration error (ECE) of ``logits`` against ``labels``.

    Predictions are the arg-max of the softmax and confidences the matching maximum
    probabilities. Confidences are split into ``n_bins`` equal-width bins over (0, 1]
    (lower edge exclusive, upper edge inclusive). ECE is the sum over non-empty bins of
    ``|mean confidence - accuracy|`` weighted by the fraction of samples in the bin.

    Args:
        logits: ``(N, C)`` tensor of unnormalized class scores. It may be half precision
            (e.g. produced under autocast); it is upcast to float32 before the softmax.
        labels: ``(N,)`` tensor of integer class indices on the same device as ``logits``.
        n_bins: Number of equal-width confidence bins.

    Returns:
        The ECE as a Python float; 0.0 when there are no samples.
    """
    if labels.numel() == 0:
        return 0.0

    # Upcast so the softmax and bin statistics are not computed in fp16.
    probs = torch.softmax(logits.detach().float(), dim=1)
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels).float()

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(1, device=confidences.device)
    for bin_lower, bin_upper in zip(bin_boundaries[:-1].tolist(), bin_boundaries[1:].tolist()):
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            gap = confidences[in_bin].mean() - accuracies[in_bin].mean()
            ece += torch.abs(gap) * prop_in_bin

    return ece.item()


def parse_args():
    """Parse command-line options for PAFA finetuning and create ``save_folder``.

    Boolean options that default to True (``--cosine``, ``--ma_update``,
    ``--from_sl_official``, ``--audioset_pretrained``) each get a matching
    ``--no_<name>`` flag that turns them off. The positive flags are still accepted
    for backward compatibility; they simply keep the default.

    Returns:
        argparse.Namespace with the parsed options.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Finetuning with PAFA + ECE')

    def add_on_by_default(name):
        # ``store_true`` with ``default=True`` can never be switched off, so pair it
        # with a ``--no_<name>`` flag that writes False to the same destination.
        parser.add_argument(f'--{name}', dest=name, action='store_true', default=True)
        parser.add_argument(f'--no_{name}', dest=name, action='store_false', default=True)

    parser.add_argument('--tag', type=str, default='seed1_best_param')
    parser.add_argument('--dataset', type=str, default='icbhi')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--class_split', type=str, default='lungsound')
    parser.add_argument('--n_cls', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--desired_length', type=float, default=5.0)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    add_on_by_default('cosine')
    parser.add_argument('--model', type=str, default='cnn6')
    parser.add_argument('--test_fold', type=str, default='official')
    parser.add_argument('--pad_types', type=str, default='repeat')
    parser.add_argument('--resz', type=float, default=1.0)
    parser.add_argument('--n_mels', type=int, default=128)
    add_on_by_default('ma_update')
    parser.add_argument('--ma_beta', type=float, default=0.5)
    add_on_by_default('from_sl_official')
    add_on_by_default('audioset_pretrained')
    parser.add_argument('--method', type=str, default='pafa')
    parser.add_argument('--w_ce', type=float, default=1.0)
    parser.add_argument('--w_pafa', type=float, default=1.0)
    parser.add_argument('--lambda_pcsl', type=float, default=5.0)
    parser.add_argument('--lambda_gpal', type=float, default=0.005)
    parser.add_argument('--norm_type', type=str, default='ln')
    parser.add_argument('--output_dim', type=int, default=768)
    parser.add_argument('--hidden_dim', type=int, default=2048)
    parser.add_argument('--proj_type', type=str, default='mlp')
    parser.add_argument('--nospec', action='store_true', default=False)
    parser.add_argument('--h', type=int, default=128)
    parser.add_argument('--w', type=int, default=1000)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--print_freq', type=int, default=10)
    parser.add_argument('--save_freq', type=int, default=50)
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--save_folder', type=str, default='./results/cnn6_pafa')
    parser.add_argument('--model_name', type=str, default='cnn6_pafa')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--eval', action='store_true', default=False)
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--pretrained_ckpt', type=str, default=None)
    parser.add_argument('--two_cls_eval', action='store_true', default=False)
    args = parser.parse_args()
    os.makedirs(args.save_folder, exist_ok=True)
    return args


def set_loader(args):
    """Create the ICBHI training and validation DataLoaders.

    Both splits are resized to ``(int(h * resz), int(w * resz))``. The training loader
    shuffles and drops the last incomplete batch; the validation loader does neither.

    Returns:
        ``(train_loader, val_loader, args)``; ``args`` is passed through unchanged.
    """
    target_size = (int(args.h * args.resz), int(args.w * args.resz))
    train_transform = transforms.Compose([transforms.Resize(size=target_size)])
    val_transform = transforms.Compose([transforms.Resize(size=target_size)])

    train_dataset = ICBHIDataset(train_flag=True, transform=train_transform, args=args, print_flag=True)
    val_dataset = ICBHIDataset(train_flag=False, transform=val_transform, args=args, print_flag=True)

    common = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **common)
    val_loader = DataLoader(val_dataset, shuffle=False, **common)
    return train_loader, val_loader, args


def set_model(args):
    """Build the backbone, classifier, projector, losses and optimizer.

    Backbone constructor kwargs depend on ``args.model`` (``'ast'``, ``'beats'`` or
    ``'cnn6'``). For ``args.method == 'pafa'`` with one of those backbones, the
    classifier is a fresh ``nn.Linear`` and the projector a ``ProjectionHead`` (with
    ``attention`` enabled only for BEATs). Otherwise the projector is ``nn.Identity`` and
    AST reuses a copy of its own ``mlp_head`` as the classifier.

    If ``--from_sl_official`` is set for a non-AST/BEATs backbone,
    ``model.load_sl_official_weights()`` is called. With ``--pretrained`` and
    ``--pretrained_ckpt``, backbone weights (minus ``mlp_head``) are loaded
    non-strictly, plus the classifier when the checkpoint has one.

    Returns:
        ``(model, classifier, projector, [ce_loss, pafa_loss], optimizer)``.
    """
    kwargs = {}
    if args.model == 'ast':
        kwargs['input_fdim'] = int(args.h * args.resz)
        kwargs['input_tdim'] = int(args.w * args.resz)
        kwargs['label_dim'] = args.n_cls
        kwargs['imagenet_pretrain'] = args.from_sl_official
        kwargs['audioset_pretrain'] = args.audioset_pretrained
    elif args.model == 'beats':
        kwargs['spec_transform'] = None if args.nospec else SpecAugment(args)
    elif args.model == 'cnn6':
        kwargs['n_mels'] = args.n_mels
        kwargs['pretrained'] = args.from_sl_official

    model = get_backbone_class(args.model)(**kwargs)

    if args.method == 'pafa' and args.model in ('beats', 'ast', 'cnn6'):
        # The per-backbone PAFA branches differed only in the attention flag.
        classifier = nn.Linear(model.final_feat_dim, args.n_cls).cuda()
        projector = ProjectionHead(model.final_feat_dim, args.hidden_dim, args.output_dim,
                                   attention=(args.model == 'beats'),
                                   norm_type=args.norm_type, proj_type=args.proj_type).cuda()
    else:
        if args.model == 'ast':
            # AST carries its own classification head; reuse a copy of it.
            classifier = deepcopy(model.mlp_head).cuda()
        else:
            classifier = nn.Linear(model.final_feat_dim, args.n_cls).cuda()
        projector = nn.Identity()

    if args.model not in ('ast', 'beats') and args.from_sl_official:
        model.load_sl_official_weights()
        print('CNN6 pretrained weights loaded from PyTorch ImageNet-pretrained')

    if args.pretrained and args.pretrained_ckpt is not None:
        ckpt = torch.load(args.pretrained_ckpt, map_location='cpu')
        # Strip "module." (e.g. from DataParallel) and "backbone." prefixes, then drop the
        # checkpoint's mlp_head so only backbone weights are transferred.
        new_state_dict = {}
        for key, value in ckpt['model'].items():
            key = key.replace("module.", "").replace("backbone.", "")
            if 'mlp_head' not in key:
                new_state_dict[key] = value
        model.load_state_dict(new_state_dict, strict=False)
        if ckpt.get('classifier', None) is not None:
            classifier.load_state_dict(ckpt['classifier'], strict=True)
        print('Pretrained model loaded from: {}'.format(args.pretrained_ckpt))

    criterion_ce = nn.CrossEntropyLoss().cuda()
    criterion_pafa = PAFALoss().cuda()
    criterion = [criterion_ce, criterion_pafa]

    model.cuda()
    optim_params = list(model.parameters()) + list(classifier.parameters()) + list(projector.parameters())
    optimizer = set_optimizer(args, optim_params)

    return model, classifier, projector, criterion, optimizer


def _train_forward(model, classifier, projector, images, args):
    """Return ``(logits, projector_output)`` for one training batch.

    BEATs receives the raw batch (``set_model`` gives it a ``spec_transform`` instead);
    CNN6 and AST receive ``args.transforms(images)`` unless ``--nospec`` is set. Any
    backbone other than BEATs/CNN6 is called with the AST signature.
    """
    if args.model == 'beats':
        features = model(images, training=True)
    else:
        inputs = images if args.nospec else args.transforms(images)
        if args.model == 'cnn6':
            features = model(inputs, training=True)
        else:  # AST
            features = model(inputs, args=args, training=True)
    output = classifier(features)
    output_projector = projector(features)
    if args.model == 'beats':
        # BEATs classifier output is averaged over dim 1 before the loss.
        output = output.mean(dim=1)
    return output, output_projector


def _backward_step(loss, optimizer, scaler):
    """Backpropagate ``loss`` and step ``optimizer``, via the AMP ``scaler`` if given."""
    optimizer.zero_grad()
    if scaler is None:
        loss.backward()
        optimizer.step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


def train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args, scaler=None):
    """Run one training epoch.

    The loss is ``w_ce * CE`` plus, for ``--method pafa``, ``w_pafa * PAFA`` on the
    projector output with patient labels. Each batch is unpacked as
    ``(images, labels)`` with ``labels[0]`` the class labels and ``labels[2]`` the
    patient labels. With ``--ma_update`` the weights are snapshotted before each step
    and blended back afterwards via ``update_moving_average``.

    NOTE(review): ``update_moving_average`` results are bound to the local names only;
    this presumably relies on it updating the modules in place — confirm.

    Args:
        scaler: optional ``torch.cuda.amp.GradScaler``; without one a plain
            ``backward()`` / ``optimizer.step()`` is used.

    Returns:
        ``(avg_loss, avg_top1, train_ece)``; ECE is 0.0 if the loader yielded no batch.
    """
    model.train()
    classifier.train()
    projector.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    epoch_logits = []
    epoch_labels = []
    end = time.time()

    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        class_labels = labels[0].cuda(non_blocking=True)
        patient_labels = labels[2].cuda(non_blocking=True)
        bsz = class_labels.shape[0]

        if args.ma_update:
            # Snapshot pre-step weights so the update can be blended with them afterwards.
            with torch.no_grad():
                ma_ckpt = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict()),
                           deepcopy(projector.state_dict())]

        warmup_learning_rate(args, epoch, idx, len(train_loader), optimizer)

        with torch.cuda.amp.autocast():
            output, output_projector = _train_forward(model, classifier, projector, images, args)
            loss_class = criterion[0](output, class_labels)
            if args.method == 'pafa':
                loss_pafa = criterion[1](output_projector, patient_labels, lambda_pcsl=args.lambda_pcsl,
                                         lambda_gpal=args.lambda_gpal)
                loss = args.w_ce * loss_class + args.w_pafa * loss_pafa
            else:
                # Non-PAFA methods train on the weighted cross-entropy term only.
                loss = args.w_ce * loss_class

        losses.update(loss.item(), bsz)
        epoch_logits.append(output.detach())
        epoch_labels.append(class_labels.detach())

        [acc1], _ = accuracy(output[:bsz], class_labels, topk=(1,))
        top1.update(acc1[0], bsz)

        _backward_step(loss, optimizer, scaler)

        batch_time.update(time.time() - end)
        end = time.time()

        if args.ma_update:
            with torch.no_grad():
                model = update_moving_average(args.ma_beta, model, ma_ckpt[0])
                classifier = update_moving_average(args.ma_beta, classifier, ma_ckpt[1])
                projector = update_moving_average(args.ma_beta, projector, ma_ckpt[2])

        if (idx + 1) % args.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, idx + 1, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    if epoch_logits:
        ece = calculate_ece(torch.cat(epoch_logits, dim=0), torch.cat(epoch_labels, dim=0))
    else:
        # No batches at all (e.g. fewer samples than one batch with drop_last=True).
        ece = 0.0
    print(f'Epoch {epoch} Train ECE: {ece:.4f}')

    return losses.avg, top1.avg, ece


def validate(val_loader, model, classifier, criterion, args, best_acc, best_model=None, projector=None,
             epoch=None):
    """Evaluate on ``val_loader`` and keep track of the best checkpoint by ICBHI score.

    Per-class hits/counts feed ``get_score``, which yields ``(S_p, S_e, Score, F1)``.
    With ``--two_cls_eval`` a hit means class 0 predicted as 0, or a non-zero class
    predicted as any non-zero class.

    Args:
        best_acc: ``[sp, se, score, f1, ece]`` of the best model so far.
        best_model: state-dict list ``[model, classifier, projector]`` of the best model
            so far, or None.
        projector: optional projector module; its state is stored with the best model
            when given (None is stored otherwise).
        epoch: optional epoch number, used only in the ECE log line.

    Returns:
        ``(best_acc, best_model, save_bool)``; ``save_bool`` is True when this evaluation
        beat the best Score so far with S_e > 0.1.
    """
    save_bool = False
    model.eval()
    classifier.eval()
    if projector is not None:
        projector.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    hits, counts = [0.0] * args.n_cls, [0.0] * args.n_cls

    epoch_logits = []
    epoch_labels = []

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            labels = labels[0].cuda(non_blocking=True)
            bsz = labels.shape[0]

            with torch.cuda.amp.autocast():
                if args.model == 'beats':
                    features = model(images, training=False)
                    output = classifier(features).mean(dim=1)
                elif args.model == 'cnn6':
                    output = classifier(model(images, training=False))
                else:  # AST
                    output = classifier(model(images, args=args, training=False))
                loss = criterion[0](output, labels)

            losses.update(loss.item(), bsz)
            [acc1], _ = accuracy(output, labels, topk=(1,))
            top1.update(acc1[0], bsz)

            epoch_logits.append(output.detach())
            epoch_labels.append(labels.detach())

            _, preds = torch.max(output, 1)
            # Separate loop variables so the batch index ``idx`` stays intact for printing.
            for label, pred in zip(labels.tolist(), preds.tolist()):
                counts[label] += 1.0
                if args.two_cls_eval:
                    hit = (pred == 0) if label == 0 else (pred > 0)
                else:
                    hit = pred == label
                if hit:
                    hits[label] += 1.0

            batch_time.update(time.time() - end)
            end = time.time()

            if (idx + 1) % args.print_freq == 0:
                sp, se, sc, f1_normal = get_score(hits, counts)
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'S_p {sp:.3f}\t'
                      'S_e {se:.3f}\t'
                      'Score {sc:.3f}\t'
                      'F1 Score {f1:.3f}'.format(
                    idx + 1, len(val_loader), batch_time=batch_time,
                    loss=losses, top1=top1, sp=sp, se=se, sc=sc,
                    f1=f1_normal))

    sp, se, sc, f1_normal = get_score(hits, counts)

    # ECE over the whole validation set.
    if epoch_logits:
        ece = calculate_ece(torch.cat(epoch_logits, dim=0), torch.cat(epoch_labels, dim=0))
    else:
        ece = 0.0
    print(f'Epoch {epoch if epoch is not None else "Val"} Val ECE: {ece:.4f}')

    # best_acc is [sp, se, score, f1, ece]: compare against the best Score (index 2).
    if sc > best_acc[2] and se > 0.1:
        save_bool = True
        best_acc = [sp, se, sc, f1_normal, ece]
        best_model = [deepcopy(model.state_dict()), deepcopy(classifier.state_dict()),
                      deepcopy(projector.state_dict()) if projector is not None else None]

    print(' * S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f} (Best S_p: {:.2f}, S_e: {:.2f}, Score: {:.2f})'
          .format(sp, se, sc, best_acc[0], best_acc[1], best_acc[2]))
    print(' * F1 Score: {:.2f} (F1 Score: {:.2f})'.format(f1_normal, best_acc[3]))
    print(' * Acc@1 {top1.avg:.2f}'.format(top1=top1))
    print(' * Val ECE: {:.4f} (Best Val ECE: {:.4f})'.format(ece, best_acc[4]))

    return best_acc, best_model, save_bool


def main():
    """Parse args, seed RNGs, build data and model, then train (or only evaluate).

    Training keeps the best checkpoint by validation Score, writes ``ece_log.json``
    every epoch and finally saves ``best.pth``. Results are appended via
    ``update_json`` into ``<save_dir>/results.json``.
    """
    args = parse_args()
    with open(os.path.join(args.save_folder, 'train_args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different algorithms between runs, defeating determinism.
    torch.backends.cudnn.benchmark = False

    best_model = None
    best_acc = [0, 0, 0, 0, 0]

    if not args.nospec:
        print("Enable SpecAugment for CNN6 (spectrogram input)")
        args.transforms = SpecAugment(args)

    train_loader, val_loader, args = set_loader(args)
    model, classifier, projector, criterion, optimizer = set_model(args)

    # Default first; previously a missing --resume file left start_epoch unset.
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Restore the heads as well when the checkpoint carries them.
            if checkpoint.get('classifier') is not None:
                classifier.load_state_dict(checkpoint['classifier'])
            if checkpoint.get('projector') is not None:
                projector.load_state_dict(checkpoint['projector'])
            print(f"Loaded checkpoint {args.resume} (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at {args.resume}")

    scaler = torch.cuda.amp.GradScaler()

    print('*' * 20)
    print('Checkpoint Name: {}'.format(args.model_name))

    if not args.eval:
        print(f'Training for {args.epochs} epochs on {args.dataset} (CNN6 + PAFA + ECE)')
        ece_log = []
        for epoch in range(args.start_epoch, args.epochs + 1):
            adjust_learning_rate(args, optimizer, epoch)
            time1 = time.time()
            loss, acc, train_ece = train(train_loader, model, classifier, projector, criterion, optimizer, epoch, args,
                                         scaler)
            time2 = time.time()
            print(f'Train epoch {epoch}, total time {time2 - time1:.2f}, accuracy:{acc:.2f}, ECE:{train_ece:.4f}')
            best_acc, best_model, save_bool = validate(val_loader, model, classifier, criterion, args, best_acc,
                                                       best_model, projector)
            # NOTE: validate() does not return this epoch's ECE, so 'val_ece' records the
            # ECE of the best checkpoint so far.
            ece_log.append({'epoch': epoch, 'train_ece': train_ece, 'val_ece': best_acc[4]})
            with open(os.path.join(args.save_folder, 'ece_log.json'), 'w') as f:
                json.dump(ece_log, f, indent=4)
            if save_bool:
                save_file = os.path.join(args.save_folder, f'best_epoch_{epoch}.pth')
                print(f'Best ckpt saved (Score={best_acc[2]:.2f}, ECE={best_acc[4]:.4f}) at epoch {epoch}')
                save_model(model, optimizer, args, epoch, save_file, classifier, projector)
            if epoch % args.save_freq == 0:
                save_file = os.path.join(args.save_folder, f'epoch_{epoch}.pth')
                save_model(model, optimizer, args, epoch, save_file, classifier, projector)

        if best_model is None:
            # No epoch passed the selection criterion (or no epoch ran).
            print('No best checkpoint found; best.pth was not written')
        else:
            save_file = os.path.join(args.save_folder, 'best.pth')
            model.load_state_dict(best_model[0])
            classifier.load_state_dict(best_model[1])
            projector.load_state_dict(best_model[2])
            save_model(model, optimizer, args, args.epochs, save_file, classifier, projector)
    else:
        print(f'Testing pretrained checkpoint on {args.dataset}')
        best_acc, _, _ = validate(val_loader, model, classifier, criterion, args, best_acc, best_model, projector)

    update_json(args.model_name, best_acc, path=os.path.join(args.save_dir, 'results.json'))
    print(f'Checkpoint {args.model_name} finished (Best ECE: {best_acc[4]:.4f})')


# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()