import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import timm
import random
from dataset.isic_dataset import SkinDataset
from model.ivq import Model
from torchvision import transforms, models
from sklearn.metrics import balanced_accuracy_score
import copy
from torch.utils.data import DataLoader
from optparse import OptionParser
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
import utils
import matplotlib.pyplot as plt
import os
import sys
import time
import math

# Fix every RNG used by the pipeline so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Deliberately NOT guarded by torch.cuda.is_available(): that query touches
# the CUDA runtime at import time, i.e. before the __main__ section assigns
# CUDA_VISIBLE_DEVICES from --gpu, which can make the GPU selection
# ineffective. manual_seed_all() is documented as safe to call when CUDA is
# unavailable (it is silently ignored in that case).
torch.cuda.manual_seed_all(seed)

# Registry used to look up the dataset class from ``config.dataset``.
dataset_dict = dict(isic2018=SkinDataset)

def _build_optimizer(model, config):
    """Return the optimizer named by ``config.optimizer``.

    Supported names are ``'sgd'``, ``'adam'`` and ``'adamw'``.

    Raises:
        ValueError: if ``config.optimizer`` is not one of the supported names.
    """
    name = config.optimizer
    if name == 'sgd':
        return optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=0.0005)
    if name == 'adam':
        return optim.Adam(model.parameters(), lr=config.lr)
    if name == 'adamw':
        # Two parameter groups: the backbone group starts at a tenth of the base lr.
        return optim.AdamW([
            {'params': model.get_backbone_params(), 'lr': config.lr * 0.1},
            {'params': model.get_bridge_params(), 'lr': config.lr},
        ])
    raise ValueError(
        f"Unsupported optimizer {name!r}; expected 'sgd', 'adam' or 'adamw'"
    )


def _mean_concept_loss(model, image_logits_dict, concept_label):
    """Average the per-concept BCE-with-logits losses over ``model.concept_token_dict``.

    Column ``idx`` of ``concept_label`` is paired with the ``idx``-th key of
    ``model.concept_token_dict``; each target column is one-hot encoded to the
    width of the matching logits before the loss is computed.
    """
    per_concept = []
    for idx, key in enumerate(model.concept_token_dict.keys()):
        concept_logits = image_logits_dict[key]
        one_hot_target = F.one_hot(
            concept_label[:, idx], num_classes=concept_logits.shape[1]
        ).float()
        per_concept.append(F.binary_cross_entropy_with_logits(concept_logits, one_hot_target))
    return sum(per_concept) / len(per_concept)


def _run_training(model, config, trainLoader, valLoader, testLoader, criterion, optimizer, writer):
    """Run the epoch loop: train, evaluate on val/test, log and checkpoint.

    Returns:
        ``(train_ranks, val_ranks)``: one averaged rank value per epoch for the
        training batches and for the validation pass respectively.
    """
    BMAC, acc, _, _, _, _, _ = validation(model, valLoader, criterion)
    print(f'Initial Validation -> BMAC: {BMAC:.5f}, Acc: {acc:.5f}')

    checkpoint_dir = os.path.join(config.cp_path, config.unique_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    num_batches = len(trainLoader)
    best_bmac_val = 0
    best_bmac_test = 0
    best_acc_val = 0
    best_acc_test = 0
    train_ranks, val_ranks = [], []

    for epoch in range(config.epochs):
        print(f'Starting epoch {epoch+1}/{config.epochs}')
        epoch_loss_cls = 0
        epoch_loss_concept = 0
        model.train()  # validation() leaves the model in eval mode
        end = time.time()
        rank_list = []
        # Called once per epoch for its effect on `optimizer`; the return value
        # was never used. NOTE(review): presumably it rewrites the lr of the
        # param groups in place -- confirm against utils.
        utils.exp_lr_scheduler_with_warmup(optimizer, init_lr=config.lr, epoch=epoch, warmup_epoch=config.warmup_epoch, max_epoch=config.epochs)

        for i, (data, label, concept_label) in enumerate(trainLoader, 0):
            x, target = data.float().cuda(), label.long().cuda()
            concept_label = concept_label.long().cuda()
            optimizer.zero_grad()

            cls_logits, image_logits_dict, vq_loss, avg_rank = model(x)
            rank_list.append(avg_rank)

            loss_cls = criterion(cls_logits, target)
            # Use the `model` argument (not a module-level global) so the
            # function also works when imported from another module.
            loss_concept = _mean_concept_loss(model, image_logits_dict, concept_label)

            loss = loss_cls + loss_concept + vq_loss
            loss.backward()
            optimizer.step()

            epoch_loss_cls += loss_cls.item()
            # Accumulate the concept-averaged value, i.e. the term that enters
            # the objective and the same normalisation validation() reports.
            epoch_loss_concept += loss_concept.item()
            batch_time = time.time() - end
            end = time.time()

            print(f'{i}, loss: {loss.item():.5f}, loss_cls: {loss_cls.item():.5f}, loss_concept: {loss_concept.item():.5f}, batch_time: {batch_time:.5f}')

        print(f'[epoch {epoch+1}] epoch loss_cls: {epoch_loss_cls/num_batches:.5f}, epoch_loss_concept: {epoch_loss_concept/num_batches:.5f}')
        train_ranks.append(np.mean(rank_list))

        writer.add_scalar('Train/Loss_cls', epoch_loss_cls/num_batches, epoch+1)
        writer.add_scalar('Train/Loss_concept', epoch_loss_concept/num_batches, epoch+1)

        val_BMAC, val_acc, val_loss_cls, val_loss_concept, val_BMAC_concept, val_acc_concept, val_avg_rank = validation(model, valLoader, criterion)
        val_ranks.append(val_avg_rank)
        writer.add_scalar('Val/BMAC', val_BMAC, epoch+1)
        writer.add_scalar('Val/Acc', val_acc, epoch+1)
        writer.add_scalar('Val/val_loss_cls', val_loss_cls, epoch+1)
        writer.add_scalar('Val/val_loss_concept', val_loss_concept, epoch+1)
        writer.add_scalar('Val/BMAC_concept', val_BMAC_concept, epoch+1)
        writer.add_scalar('Val/acc_concept', val_acc_concept, epoch+1)

        test_BMAC, test_acc, test_loss_cls, test_loss_concept, test_BMAC_concept, test_acc_concept, _ = validation(model, testLoader, criterion)
        writer.add_scalar('Test/BMAC', test_BMAC, epoch+1)
        writer.add_scalar('Test/Acc', test_acc, epoch+1)
        writer.add_scalar('Test/test_loss_cls', test_loss_cls, epoch+1)
        writer.add_scalar('Test/test_loss_concept', test_loss_concept, epoch+1)
        writer.add_scalar('Test/BMAC_concept', test_BMAC_concept, epoch+1)
        writer.add_scalar('Test/acc_concept', test_acc_concept, epoch+1)

        # Only the first param group's lr is logged.
        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('LR/lr', lr, epoch+1)

        # The "best" checkpoint is selected on validation BMAC only; the other
        # running maxima are tracked purely for the printed summary.
        if val_BMAC >= best_bmac_val:
            best_bmac_val = val_BMAC
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_ckpt.pth'))
        if test_BMAC >= best_bmac_test:
            best_bmac_test = test_BMAC
        if test_acc >= best_acc_test:
            best_acc_test = test_acc
        if val_acc >= best_acc_val:
            best_acc_val = val_acc

        # Periodic snapshot, named after the zero-based epoch index.
        if epoch % 10 == 0:
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'{epoch}_ckpt.pth'))

        output = f"""
        --- METRICS SUMMARY ---
        [Concept]
        Validation  |  BMAC: {val_BMAC_concept:.5f}  |  Accuracy: {val_acc_concept:.5f}
        Test        |  BMAC: {test_BMAC_concept:.5f}  |  Accuracy: {test_acc_concept:.5f}

        [Overall Performance (Current / Best)]
        Validation  |  BMAC: {val_BMAC:.5f} / {best_bmac_val:.5f}  |  Accuracy: {val_acc:.5f} / {best_acc_val:.5f}
        Test        |  BMAC: {test_BMAC:.5f} / {best_bmac_test:.5f}  |  Accuracy: {test_acc:.5f} / {best_acc_test:.5f}
        """
        print(output)

    return train_ranks, val_ranks


def train_net(model, config):
    """Train ``model`` and evaluate it on the val and test splits every epoch.

    Builds the train/val/test loaders from ``dataset_dict[config.dataset]``,
    optimises classification loss + mean concept loss + the model's VQ loss,
    logs scalars to TensorBoard under ``config.log_path/config.unique_name``,
    writes checkpoints under ``config.cp_path/config.unique_name`` and finally
    saves the per-epoch rank statistics to ``rank.npy`` in the working
    directory.

    Args:
        model: network whose forward pass returns
            ``(cls_logits, image_logits_dict, vq_loss, avg_rank)`` and which
            exposes ``model_name`` and ``concept_token_dict`` (plus
            ``get_backbone_params``/``get_bridge_params`` for ``'adamw'``).
        config: options object; ``config.preprocess`` must be a compose-like
            transform exposing a mutable ``.transforms`` list.

    Raises:
        ValueError: if ``config.optimizer`` is unsupported, or if the training
            loader would yield no batches.
    """
    # Training pipeline: drop the leading preprocess step(s) and put random
    # augmentation in their place.
    # NOTE(review): presumably the dropped entries are the deterministic
    # resize/crop steps -- confirm against how config.preprocess is built.
    train_transforms = copy.deepcopy(config.preprocess)
    train_transforms.transforms.pop(0)
    if model.model_name != 'clip':
        train_transforms.transforms.pop(0)
    train_transforms.transforms[:0] = [
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(size=(224, 224), scale=(0.75, 1.0), ratio=(0.75, 1.33), interpolation=utils.get_interpolation_mode('bicubic')),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
    ]

    val_transforms = copy.deepcopy(config.preprocess)
    val_transforms.transforms.insert(0, transforms.ToPILImage())

    dataset_cls = dataset_dict[config.dataset]

    trainset = dataset_cls(config.data_path, mode='train', transforms=train_transforms, flag=config.flag, config=config, return_concept_label=True)
    trainLoader = DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=8, drop_last=True)

    valset = dataset_cls(config.data_path, mode='val', transforms=val_transforms, flag=config.flag, config=config, return_concept_label=True)
    valLoader = DataLoader(valset, batch_size=config.batch_size, shuffle=False, num_workers=2, drop_last=False)

    testset = dataset_cls(config.data_path, mode='test', transforms=val_transforms, flag=config.flag, config=config, return_concept_label=True)
    testLoader = DataLoader(testset, batch_size=config.batch_size, shuffle=False, num_workers=2, drop_last=False)

    # drop_last=True means a training set smaller than one batch yields nothing;
    # fail early with a clear message instead of crashing on the epoch averages.
    if len(trainLoader) == 0:
        raise ValueError(
            'Training loader yields no batches: the training set is smaller '
            f'than batch_size={config.batch_size} (drop_last=True).'
        )

    if config.cls_weight is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        lesion_weight = torch.FloatTensor(config.cls_weight).cuda()
        criterion = nn.CrossEntropyLoss(weight=lesion_weight).cuda()

    optimizer = _build_optimizer(model, config)

    # NOTE(review): config.amp is not acted upon. The original code created a
    # GradScaler here but never used it, so training always runs in full
    # precision regardless of --amp.

    writer = SummaryWriter(os.path.join(config.log_path, config.unique_name))
    try:
        train_ranks, val_ranks = _run_training(
            model, config, trainLoader, valLoader, testLoader, criterion, optimizer, writer
        )
    finally:
        # Flush pending events even if training aborts.
        writer.close()

    rank_epochs_data = {'train_ranks': train_ranks, 'val_ranks': val_ranks}
    np.save('rank.npy', rank_epochs_data)

def validation(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: network whose forward pass returns
            ``(cls_logits, image_logits_dict, vq_loss, avg_rank)`` and which
            exposes ``concept_token_dict``.
        dataloader: iterable of ``(data, label, concept_label)`` batches.
        criterion: classification loss applied to ``(cls_logits, label)``.

    Returns:
        A 7-tuple ``(BMAC, acc, mean_cls_loss, mean_concept_loss,
        BMAC_concept, acc_concept, mean_rank)``. Accuracies are percentages.
        Losses are averaged over batches (the concept loss is first averaged
        over concepts within each batch). The concept metrics are computed on
        the predictions of all concepts pooled into a single array.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    net = model
    net.eval()

    num_batches = 0
    losses_cls = 0
    losses_concepts = 0
    # Per-batch chunks are collected and concatenated once at the end instead
    # of re-allocating a growing array every batch.
    pred_chunks, gt_chunks = [], []
    concept_pred_chunks, concept_gt_chunks = [], []
    rank_list = []

    with torch.no_grad():
        for data, label, concept_label in dataloader:
            # .float() mirrors the input handling of the training loop.
            data, label = data.float().cuda(), label.long().cuda()
            concept_label = concept_label.long().cuda()
            cls_logits, image_logits_dict, _, avg_rank = net(data)
            rank_list.append(avg_rank)
            num_batches += 1

            losses_cls += criterion(cls_logits, label).item()

            batch_concept_loss = 0
            concept_keys = list(net.concept_token_dict.keys())
            for idx, key in enumerate(concept_keys):
                concept_logits = image_logits_dict[key]
                concept_target = concept_label[:, idx]

                one_hot_target = F.one_hot(
                    concept_target, num_classes=concept_logits.shape[1]
                ).float()
                batch_concept_loss += F.binary_cross_entropy_with_logits(
                    concept_logits, one_hot_target
                ).item()

                concept_pred = torch.max(concept_logits, dim=1).indices
                # Keep the native integer dtype: casting to uint8 would wrap
                # class ids >= 256.
                concept_pred_chunks.append(concept_pred.cpu().numpy())
                concept_gt_chunks.append(concept_target.cpu().numpy())

            losses_concepts += batch_concept_loss / len(concept_keys)

            label_pred = torch.max(cls_logits, dim=1).indices
            pred_chunks.append(label_pred.cpu().numpy())
            gt_chunks.append(label.cpu().numpy())

    if num_batches == 0:
        raise ValueError('validation() received a dataloader that yields no batches')

    pred_list = np.concatenate(pred_chunks, axis=0)
    gt_list = np.concatenate(gt_chunks, axis=0)
    pred_list_concept = np.concatenate(concept_pred_chunks, axis=0)
    gt_list_concept = np.concatenate(concept_gt_chunks, axis=0)

    BMAC = balanced_accuracy_score(gt_list, pred_list)
    correct = np.sum(gt_list == pred_list)
    acc = 100 * correct / len(pred_list)

    BMAC_concept = balanced_accuracy_score(gt_list_concept, pred_list_concept)
    acc_concept = 100 * np.sum(gt_list_concept == pred_list_concept) / len(pred_list_concept)
    return (
        BMAC,
        acc,
        losses_cls / num_batches,
        losses_concepts / num_batches,
        BMAC_concept,
        acc_concept,
        np.mean(rank_list),
    )

if __name__ == '__main__':
    # Command-line entry point: parse options, build the model, optionally
    # restore weights, then hand over to train_net().
    parser = OptionParser()
    parser.add_option('-e', '--epochs', dest='epochs', default=150, type='int')
    parser.add_option('-b', '--batch_size', dest='batch_size', default=32, type='int')
    parser.add_option('--warmup_epoch', dest='warmup_epoch', default=5, type='int')
    parser.add_option('--optimizer', dest='optimizer', default='adamw', type='str')
    parser.add_option('-l', '--lr', dest='lr', default=0.0001, type='float')
    parser.add_option('-c', '--resume', type='str', dest='load', default=False)
    parser.add_option('-p', '--checkpoint-path', type='str', dest='cp_path', default='./checkpoint/')
    parser.add_option('-o', '--log-path', type='str', dest='log_path', default='./log/')
    parser.add_option('-m', '--model', type='str', dest='model', default='Model')
    parser.add_option('--linear-probe', dest='linear_probe', action='store_true')
    parser.add_option('-d', '--dataset', type='str', dest='dataset', default='isic2018')
    parser.add_option('--data-path', type='str', dest='data_path', default='./dataset/isic2018/')
    parser.add_option('-u', '--unique_name', type='str', dest='unique_name', default='test')
    parser.add_option('--flag', type='int', dest='flag', default=2)
    parser.add_option('--preprocess', default=None)
    parser.add_option('--gpu', type='str', dest='gpu', default='0')
    parser.add_option('--amp', action='store_true')

    (config, args) = parser.parse_args()

    # NOTE(review): this only restricts the visible GPUs if the CUDA runtime
    # has not been initialised earlier in this process -- verify that nothing
    # executed at import time queries CUDA before this line.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu
    # Logs and checkpoints are grouped per dataset.
    config.log_path = os.path.join(config.log_path, config.dataset)
    config.cp_path = os.path.join(config.cp_path, config.dataset)

    num_class_dict = {'isic2018': 7}
    cls_weight_dict = {'isic2018': [1, 1, 1, 1, 1, 1, 1]}

    config.cls_weight = cls_weight_dict[config.dataset]
    config.num_class = num_class_dict[config.dataset]

    from concept_dataset import isic_dict
    concept_list = isic_dict
    # NOTE(review): train_net() reads config.preprocess, whose CLI default is
    # None; presumably Model.__init__ fills it in -- confirm.
    net = Model(concept_list=concept_list, model_name='biomedclip', config=config)

    if config.load:
        # Deserialise onto the CPU: the model has not been moved to the GPU
        # yet, and a checkpoint saved from a CUDA device index that is not
        # visible in this process would otherwise fail to load.
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        net.load_state_dict(torch.load(config.load, map_location='cpu'))
        print(f'Model loaded from {config.load}')

    net.cuda()
    train_net(net, config)