import rich.progress
import torch.utils
import torch.utils.data
import torch
import torchvision
import torch.utils.data.distributed
import model.networks as nets
from dataset import classify_data
import os
import rich
import argparse
import torch.optim.lr_scheduler as lr_scheduler
import torch.distributed as dist
import math
import logging
from rich.console import Console


def get_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: one attribute per option declared below.
        ``--continue_train`` and ``--checkpoints_dir`` are mandatory; every
        other option falls back to its default (``None`` unless stated).
    """
    parser = argparse.ArgumentParser(description='training parameters')

    # Model selection and geometry.
    parser.add_argument('--model', type=str)
    parser.add_argument('--in_channels', type=int)
    parser.add_argument('--in_size', type=int)
    parser.add_argument('--classify_num', type=int)

    # Data sources.
    parser.add_argument('--train_dataset', type=str)
    parser.add_argument('--test_dataset', type=str)
    parser.add_argument('--dataset_name', type=str)

    # Optimisation settings.
    parser.add_argument('--device', type=str, default='cuda')
    # Parsed as an integer so a malformed value is rejected by argparse
    # instead of surfacing later as a ValueError inside the trainer.
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--local_rank', type=int, default=0)

    # Checkpointing. ``--continue_train`` stays a string because the trainer
    # compares its lower-cased value against 'true'.
    parser.add_argument('--continue_train', type=str, required=True)
    parser.add_argument('--checkpoint', type=str, help='RacoNet parameters')
    parser.add_argument('--parameter', type=str, help='GIAO parameters')
    parser.add_argument('--checkpoints_dir', type=str, required=True)

    return parser.parse_args()


def get_model(num_class, model_name, in_size, in_channels, parameter):
    """Build the network selected by ``model_name``.

    Args:
        num_class: forwarded as ``classify_num``.
        model_name: architecture identifier; only ``'RacoNetClassify'`` is
            supported.
        in_size: forwarded as ``coder_in_size``.
        in_channels: forwarded as both ``coder_in_channels`` and
            ``classify_in_channels``.
        parameter: forwarded as ``coder_parameters``.

    Returns:
        The constructed ``nets.RacoNetClassify`` instance.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    if model_name != 'RacoNetClassify':
        # Raising a plain string is itself a TypeError in Python 3, which
        # would mask the real problem; raise a proper exception instead.
        raise ValueError(f'ERROR!!! NO {model_name} !!!')

    return nets.RacoNetClassify(
        coder_in_channels=in_channels,
        coder_in_size=in_size,
        coder_out_channels=8,
        coder_out_size=64,
        coder_parameters=parameter,

        classify_in_channels=in_channels,
        classify_num=num_class,
        classify_out_channels=8,
        classify_out_size=64
    )

def _remap_checkpoint_keys(state_dict):
    """Translate checkpoint parameter names to the current module layout.

    ``.freq_conv.`` / ``.freq_norm.`` segments are renamed to
    ``.freq_module.conv.`` / ``.freq_module.norm.`` and a leading ``module.``
    prefix (present when the state dict of a DistributedDataParallel wrapper
    was saved) is dropped. Only the leading prefix is stripped: removing every
    ``module.`` substring would turn ``freq_module.conv`` back into
    ``freq_conv`` and undo the rename.

    Args:
        state_dict: mapping of parameter name to value.

    Returns:
        dict: a new mapping with translated names and the same values.
    """
    ddp_prefix = 'module.'
    remapped = {}
    for key, value in state_dict.items():
        if '.freq_conv.' in key:
            new_key = key.replace('.freq_conv.', '.freq_module.conv.')
        elif '.freq_norm.' in key:
            new_key = key.replace('.freq_norm.', '.freq_module.norm.')
        else:
            new_key = key
        if new_key.startswith(ddp_prefix):
            new_key = new_key[len(ddp_prefix):]
        remapped[new_key] = value
    return remapped


def _gather_tensor(tensor, device):
    """All-gather a tensor whose first dimension may differ between ranks.

    Every rank pads its tensor with zero rows up to the largest first
    dimension, the padded tensors are exchanged with ``dist.all_gather``, and
    the padding is cut off again before concatenating in rank order.

    Args:
        tensor: local tensor (moved to ``device`` before the exchange).
        device: device on which the collective is executed.

    Returns:
        torch.Tensor: concatenation of all ranks' tensors, on ``device``.
    """
    world_size = dist.get_world_size()

    # Exchange the per-rank lengths first so padding can be removed later.
    local_size = torch.tensor([tensor.size(0)], device=device)
    size_list = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    sizes = [int(size.item()) for size in size_list]

    tensor = tensor.to(device)
    pad_rows = max(sizes) - tensor.size(0)
    if pad_rows > 0:
        padding = torch.zeros([pad_rows] + list(tensor.shape[1:]), dtype=tensor.dtype, device=device)
        tensor = torch.cat([tensor, padding], dim=0)

    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    return torch.cat([chunk[:size] for chunk, size in zip(gathered, sizes)], dim=0)


def _evaluation_metrics(labels, preds, probs, num_classes):
    """Compute the classification metrics reported after each evaluation.

    Args:
        labels: 1-D numpy array of ground-truth class indices.
        preds: 1-D numpy array of predicted class indices.
        probs: 2-D numpy array of per-class probabilities; column ``i`` is
            used as the score of class ``i``.
        num_classes: number of classes iterated for the one-vs-rest AUC.

    Returns:
        dict: ``'acc'`` (percent) and ``'kappa'`` as scalars, and
        ``'precision'``, ``'recall'``, ``'f1'``, ``'specificity'``,
        ``'g_mean'``, ``'dice'`` and ``'auc'`` as ``(mean, std)`` tuples taken
        over the per-class values.
    """
    import numpy as np
    from sklearn.metrics import (
        accuracy_score, f1_score, recall_score, precision_score,
        roc_auc_score, confusion_matrix, cohen_kappa_score
    )

    def mean_std(values):
        return np.mean(values), np.std(values)

    precision_per_class = precision_score(labels, preds, average=None, zero_division=0)
    recall_per_class = recall_score(labels, preds, average=None, zero_division=0)
    f1_per_class = f1_score(labels, preds, average=None, zero_division=0)

    # Per-class TN / FP / FN / TP derived from the confusion matrix.
    cm = confusion_matrix(labels, preds)
    TP = np.diag(cm)
    FP = np.sum(cm, axis=0) - TP
    FN = np.sum(cm, axis=1) - TP
    TN = np.sum(cm) - (TP + FP + FN)

    # Specificity is defined as 0 where a class has no negatives at all.
    specificity_per_class = np.divide(TN, TN + FP, out=np.zeros_like(TN, dtype=float), where=(TN + FP) != 0)
    g_mean_per_class = np.sqrt(recall_per_class * specificity_per_class)

    # One-vs-rest AUC. The indicator is built directly rather than with
    # label_binarize, which yields a single column for two-class problems and
    # therefore cannot be indexed per class. AUC is undefined (NaN) when a
    # class covers none or all of the samples.
    auc_per_class = []
    for class_index in range(num_classes):
        is_positive = labels == class_index
        positives = int(np.sum(is_positive))
        if positives == 0 or positives == is_positive.shape[0]:
            auc_per_class.append(np.nan)
        else:
            auc_per_class.append(roc_auc_score(is_positive.astype(int), probs[:, class_index]))
    auc_per_class = np.array(auc_per_class)

    return {
        'acc': accuracy_score(labels, preds) * 100,
        'kappa': cohen_kappa_score(labels, preds),
        'precision': mean_std(precision_per_class),
        'recall': mean_std(recall_per_class),
        'f1': mean_std(f1_per_class),
        'specificity': mean_std(specificity_per_class),
        'g_mean': mean_std(g_mean_per_class),
        # Dice coefficient equals the F1 score for hard class predictions.
        'dice': mean_std(f1_per_class),
        'auc': (np.nanmean(auc_per_class), np.nanstd(auc_per_class)),
    }


def train_model(
        num_class,
        model_name,
        device,
        batch_size,
        lr,
        train_data_csv,
        test_data_csv,
        epochs,
        image_size,
        in_channels,
        continue_train,
        parameter,
        checkpoint,
        dataset_name,
        checkpoints_dir
):
    """Train and evaluate a classifier with DistributedDataParallel.

    One process per GPU is expected: the ``LOCAL_RANK`` environment variable
    selects the CUDA device and the NCCL process group is initialised here.
    After every epoch the test set is evaluated, predictions from all ranks
    are gathered, metrics are printed, and whenever accuracy improves rank 0
    appends the metrics to ``training.log`` and saves the model weights.

    Args:
        num_class: number of classes; forwarded to ``get_model`` and used for
            the per-class AUC.
        model_name: architecture name; ``'RacoNetClassify'`` uses the
            ``model(imgs, truths)`` forward path and the combined loss.
        device: ignored; the device is always ``cuda:<LOCAL_RANK>``.
        batch_size: per-process training batch size (int or numeric string);
            evaluation uses half of it, at least 1.
        lr: base learning rate for Adam.
        train_data_csv: passed as ``data_csv`` to the training dataset.
        test_data_csv: passed as ``data_csv`` to the test dataset.
        epochs: number of epochs; the first 10% are linear warm-up, followed
            by cosine decay.
        image_size: forwarded to ``get_model``.
        in_channels: forwarded to ``get_model``.
        continue_train: string flag; ``'true'`` (case-insensitive) loads
            ``checkpoint`` into the model before training.
        parameter: forwarded to ``get_model``.
        checkpoint: path of the checkpoint read when ``continue_train`` is on.
        dataset_name: used in log lines and in the checkpoint path.
        checkpoints_dir: root directory for saved weights.
    """
    # Logger setup: plain-message file log, attached once per process.
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    if not any(isinstance(handler, logging.FileHandler) for handler in logger.handlers):
        file_handler = logging.FileHandler("training.log", mode='a')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(file_handler)
    console = Console()

    transform_img = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((256, 256))
    ])

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend='nccl')

    train_batch_size = int(batch_size)
    # Half the training batch for evaluation, but never zero.
    test_batch_size = max(1, train_batch_size // 2)

    # Load training dataset
    train_dataset = classify_data.ClassifyDataSet(
        data_csv=train_data_csv,
        transform_img=transform_img
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank()
    )
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        sampler=train_sampler,
        drop_last=True,
        num_workers=12,
        prefetch_factor=10
    )

    # Load testing dataset
    test_dataset = classify_data.ClassifyDataSet(
        data_csv=test_data_csv,
        transform_img=transform_img
    )
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=False,
        drop_last=True
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=test_batch_size,
        sampler=test_sampler,
        drop_last=False,
        num_workers=12,
        prefetch_factor=10
    )

    model = get_model(num_class, model_name, image_size, in_channels, parameter)

    if continue_train.lower() == 'true':
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        old_weights = torch.load(checkpoint, map_location='cpu')
        # The remapped weights must actually be loaded, otherwise resuming
        # silently trains from scratch. strict=False tolerates checkpoints
        # that cover only part of the network; mismatches are reported.
        incompatible = model.load_state_dict(_remap_checkpoint_keys(old_weights), strict=False)
        if local_rank == 0 and (incompatible.missing_keys or incompatible.unexpected_keys):
            console.log(
                f'checkpoint {checkpoint}: missing keys {incompatible.missing_keys}, '
                f'unexpected keys {incompatible.unexpected_keys}'
            )

    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )

    CrossEntropyloss = torch.nn.CrossEntropyLoss().to(device)
    MSEloss = torch.nn.MSELoss().to(device)
    L1loss = torch.nn.L1Loss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    warmup_epochs = int(0.1 * epochs)
    # Guard the cosine denominator for the degenerate epochs == 0 case.
    decay_epochs = max(1, epochs - warmup_epochs)

    def lr_factor(epoch):
        """Linear warm-up followed by cosine decay, as a multiplier of ``lr``."""
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epochs) / decay_epochs))

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
    acc_max = 0

    with rich.progress.Progress(console=console) as ps:
        task = ps.add_task('[cyan] training... ', total=epochs * len(train_data_loader))
        for epoch in range(epochs):
            train_sampler.set_epoch(epoch)
            model.train()
            loss_add = 0
            cross = 0
            total_samples = 0

            for imgs, labs, truths in train_data_loader:
                imgs = imgs.to(device, dtype=torch.float32)
                labs = labs.to(device, dtype=torch.long)
                truths = truths.to(device, dtype=torch.float32)

                if model_name == 'RacoNetClassify':
                    encoder, out_img, logit = model(imgs, truths)
                    ce_loss = CrossEntropyloss(logit, labs)
                    loss = ce_loss * 0.1 + (L1loss(out_img, encoder) + MSEloss(out_img, encoder)) * 100
                else:
                    prediction = model(imgs)
                    ce_loss = CrossEntropyloss(prediction, labs)
                    loss = ce_loss

                # Track the total loss and its cross-entropy term separately
                # so the two logged averages are not the same number.
                loss_value = loss.item()
                ce_value = ce_loss.item()
                batch_size_current = imgs.size(0)
                loss_add += loss_value * batch_size_current
                cross += ce_value * batch_size_current
                total_samples += batch_size_current

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ps.update(task, advance=1, description=f'epoch:{epoch+1}/{epochs}  Cross:{ce_value:.5f}  BestAcc:{acc_max:.3f}')
                torch.cuda.synchronize()

            avg_loss = loss_add / total_samples
            avg_cross = cross / total_samples

            # Evaluation interval is one epoch, so this is always true.
            if (epoch + 1) % 1 == 0:
                console.log(f'NO. {epoch + 1}  average loss:{avg_loss:.5f}  CrossEntropyloss:{avg_cross:.5f}')
                if local_rank == 0:
                    logger.info(f'NO. {epoch + 1}  average loss:{avg_loss:.5f}  CrossEntropyloss:{avg_cross:.5f}')

                model.eval()
                all_preds = []
                all_labels = []
                all_probs = []
                total_test_loss = 0.0
                total_test_samples = 0

                test_sampler.set_epoch(epoch)
                eval_task = ps.add_task('[green] eval... ', total=len(test_data_loader))
                with torch.no_grad():
                    for imgs, labs, truths in test_data_loader:
                        imgs = imgs.to(device, dtype=torch.float32)
                        labs = labs.to(device, dtype=torch.long)
                        truths = truths.to(device, dtype=torch.float32)

                        # Handle models that return multiple outputs
                        if model_name == 'RacoNetClassify':
                            _, _, prediction = model(imgs, truths)
                        else:
                            prediction = model(imgs)

                        # Ensure prediction is a tensor
                        if isinstance(prediction, tuple):
                            prediction = prediction[0]

                        loss_test = CrossEntropyloss(prediction, labs)

                        batch_size_current = imgs.size(0)
                        total_test_loss += loss_test.item() * batch_size_current
                        total_test_samples += batch_size_current

                        all_probs.append(torch.softmax(prediction, dim=1).cpu())
                        all_preds.append(torch.argmax(prediction, dim=1).cpu())
                        all_labels.append(labs.cpu())
                        ps.update(eval_task, advance=1)

                ps.remove_task(eval_task)

                # Sum losses and sample counts over all ranks.
                total_test_loss_tensor = torch.tensor([total_test_loss], dtype=torch.float32).to(device)
                total_test_samples_tensor = torch.tensor([total_test_samples], dtype=torch.float32).to(device)
                dist.all_reduce(total_test_loss_tensor, op=dist.ReduceOp.SUM)
                dist.all_reduce(total_test_samples_tensor, op=dist.ReduceOp.SUM)
                average_test_loss = total_test_loss_tensor.item() / total_test_samples_tensor.item()

                # Collect every rank's predictions so all ranks compute the
                # same metrics (and therefore agree on acc_max).
                all_preds_np = _gather_tensor(torch.cat(all_preds), device).cpu().numpy()
                all_labels_np = _gather_tensor(torch.cat(all_labels), device).cpu().numpy()
                all_probs_np = _gather_tensor(torch.cat(all_probs), device).cpu().numpy()

                metrics = _evaluation_metrics(all_labels_np, all_preds_np, all_probs_np, num_class)
                acc = metrics['acc']
                kappa = metrics['kappa']
                precision_mean, precision_std = metrics['precision']
                recall_mean, recall_std = metrics['recall']
                f1_mean, f1_std = metrics['f1']
                specificity_mean, specificity_std = metrics['specificity']
                G_mean_mean, G_mean_std = metrics['g_mean']
                Dice_mean, Dice_std = metrics['dice']
                auc_mean, auc_std = metrics['auc']

                if local_rank == 0:
                    # Output evaluation results
                    console.log(f'{model_name} NO. {epoch + 1} result of test:')
                    console.log(f'test loss: {average_test_loss:.5f}')
                    console.log(f'accuracy: {acc:.3f}%')
                    console.log(f'F1 Score: Mean={f1_mean:.4f}, Std={f1_std:.4f}')
                    console.log(f'Precision (PPV): Mean={precision_mean:.4f}, Std={precision_std:.4f}')
                    console.log(f'Recall (Sensitivity): Mean={recall_mean:.4f}, Std={recall_std:.4f}')
                    console.log(f'AUC Score: Mean={auc_mean:.4f}, Std={auc_std:.4f}')
                    console.log(f'G-Mean: Mean={G_mean_mean:.4f}, Std={G_mean_std:.4f}')
                    console.log(f'Specificity: Mean={specificity_mean:.4f}, Std={specificity_std:.4f}')
                    console.log(f'Kappa: {kappa:.4f}')
                    console.log(f'Dice Coefficient: Mean={Dice_mean:.4f}, Std={Dice_std:.4f}')

                if acc > acc_max:
                    acc_max = acc
                    if local_rank == 0:
                        logger.info(f'NO. {epoch + 1} Test-{dataset_name}')
                        logger.info(f'accuracy: {acc:.3f}%')
                        logger.info(f'F1 Score: Mean={f1_mean:.4f}, Std={f1_std:.4f}')
                        logger.info(f'Precision (PPV): Mean={precision_mean:.4f}, Std={precision_std:.4f}')
                        logger.info(f'Recall (Sensitivity): Mean={recall_mean:.4f}, Std={recall_std:.4f}')
                        logger.info(f'Specificity: Mean={specificity_mean:.4f}, Std={specificity_std:.4f}')
                        logger.info(f'G-Mean: Mean={G_mean_mean:.4f}, Std={G_mean_std:.4f}')
                        logger.info(f'Kappa: {kappa:.4f}')
                        logger.info(f'Dice Coefficient: Mean={Dice_mean:.4f}, Std={Dice_std:.4f}')

                        # Save model (state dict of the DDP wrapper, so keys
                        # carry the 'module.' prefix).
                        save_dir = f'{checkpoints_dir}/{model_name}/{dataset_name}'
                        os.makedirs(save_dir, exist_ok=True)
                        torch.save(model.state_dict(),
                                f'{save_dir}/model_weights_{acc_max:.5f}.pth')
                        torch.cuda.synchronize()
            scheduler.step()


if __name__ == '__main__':
    # Script entry point: read the CLI options and pass them to the trainer
    # under the keyword names train_model expects.
    args = get_args()
    console = Console()

    train_kwargs = {
        # model definition
        'num_class': args.classify_num,
        'model_name': args.model,
        'image_size': args.in_size,
        'in_channels': args.in_channels,
        'parameter': args.parameter,
        # optimisation
        'device': args.device,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'epochs': args.epochs,
        # data
        'train_data_csv': args.train_dataset,
        'test_data_csv': args.test_dataset,
        'dataset_name': args.dataset_name,
        # checkpointing
        'continue_train': args.continue_train,
        'checkpoint': args.checkpoint,
        'checkpoints_dir': args.checkpoints_dir,
    }
    train_model(**train_kwargs)

