import os
import sys

import time
import argparse
import copy
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from easydict import EasyDict
import unfoldNd
import torch
from torch.utils.data import DataLoader, ConcatDataset, Subset, random_split
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToPILImage
from data.dataset_3d_lungs import LabelledDS, UnlabelledDS
from models import deeplabv3
# from utils.loss_functions import DSCLoss as DSCLoss
from monai.losses import DiceFocalLoss as DSCLoss
from utils.logger import logger as logging
from utils.utils import *
from utils.mask_generator import BoxMaskGenerator, AddMaskParamsToBatch, SegCollate
from utils.ramps import sigmoid_rampup
from utils.torch_utils import seed_torch
from utils.model_init import init_weight
from monai.metrics import DiceMetric
# from torch.optim import AdamW
from cadam import AdamWHCO as AdamW
from cadam import CAdamW
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.enabled = False

# Module-wide compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
# To force CPU execution for debugging, replace the expression with torch.device("cpu").
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {'true', 't', 'yes', 'y', '1'}:
        return True
    if lowered in {'false', 'f', 'no', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_args(known=False, argv=None):
    """Parse command-line options and derive the run name and output directory.

    Args:
        known (bool): When True, unknown arguments are ignored
            (``parse_known_args``) instead of raising an error.
        argv (list[str] | None): Argument list to parse. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace: Parsed options. ``run_name`` is extended with the
        labelled/unlabelled split tags and the optimizer flags, and ``project``
        becomes ``<project>/<run_name>`` (with a trailing path separator).
    """
    parser = argparse.ArgumentParser(description='PyTorch Implementation')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--project', type=str, default='/path/to/output/hco_outputs', help='project path for saving results')
    parser.add_argument('--run_name', type=str, default='experiment', help='name of this run for wandb tracking')
    parser.add_argument('--backbone', type=str, default='VNet', choices=['VNet'], help='segmentation backbone')
    parser.add_argument('--data_path', type=str, default='/path/to/data/labels/train_data_5_split_0.json', help='path to the data')
    parser.add_argument('--ckpt_continue', type=str, default=None, help='path to a checkpoint to continue training from')
    parser.add_argument('--unlabelled_json_path', type=str, default='/path/to/data/unlabelled_split_1.json', help='path to the unlabelled data json')
    parser.add_argument('--val_json_path', type=str, default='/path/to/data/labels/val_data.json', help='path to the validation data json')
    # nargs=3 keeps CLI overrides the same shape as the 3-element default.
    parser.add_argument('--image_size', type=int, nargs=3, default=[144, 144, 64], help='the size of images for training and testing')
    parser.add_argument('--labeled_percentage', type=float, default=0.2, help='the percentage of labeled data')
    # type=bool would turn "--is_mix False" into True; parse the text instead.
    parser.add_argument('--is_mix', type=_str2bool, default=True, help='cut mix')
    parser.add_argument('--topk', type=int, default=2, help='top k')
    parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=4, help='number of inputs per batch')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers to use for dataloader')
    parser.add_argument('--in_channels', type=int, default=1, help='input channels')
    parser.add_argument('--num_classes', type=int, default=3, help='number of target categories')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate')
    # type=list would split the raw string into characters; read floats instead.
    parser.add_argument('--intra_weights', type=float, nargs='+', default=[1., 1., 1.], help='inter classes weighted coefficients in the loss function')
    parser.add_argument('--inter_weight', type=float, default=0.8, help='inter losses weighted coefficients in the loss function')
    parser.add_argument('--log_freq', type=float, default=1, help='logging frequency of metrics accord to the current iteration')
    parser.add_argument('--save_freq', type=float, default=10, help='saving frequency of model weights accord to the current epoch')

    # Added AdamWHCO optimizer parameters
    parser.add_argument('--caution', action='store_true', help='Use caution in AdamWHCO optimizer')
    parser.add_argument('--hirearchical_caution', action='store_true', help='Use hierarchical caution in AdamWHCO optimizer')
    parser.add_argument('--unlabeled_momentum_update', action='store_true', help='Update momentum for unlabeled data')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay for AdamWHCO optimizer')
    parser.add_argument('--beta1', type=float, default=0.9, help='Beta1 for AdamWHCO optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='Beta2 for AdamWHCO optimizer')
    parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon for AdamWHCO optimizer')

    # WandB parameters
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging')
    parser.add_argument('--wandb_project', type=str, default='lung-segmentation', help='WandB project name')
    parser.add_argument('--wandb_entity', type=str, default=None, help='WandB entity name')

    args = parser.parse_known_args(argv)[0] if known else parser.parse_args(argv)

    # The split tag is whatever sits between the last '_' and the first
    # following '.' of each json path (e.g. ".../train_data_5_split_0.json" -> "0").
    labeled_tag = args.data_path.split('_')[-1].split('.')[0]
    unlabeled_tag = args.unlabelled_json_path.split('_')[-1].split('.')[0]

    # The trailing '/' is intentional: it makes args.project end with a path
    # separator, which code that concatenates file names onto it relies on.
    args.run_name = (
        f"{args.run_name}_labeled_{labeled_tag}_unlabeled_"
        f"{unlabeled_tag}_hirearchical_caution_{args.hirearchical_caution}"
        f"_caution_{args.caution}_unlabeled_momentum_update_{args.unlabeled_momentum_update}/"
    )

    args.project = os.path.join(args.project, args.run_name)
    return args


def load_pretrained_weights(model, pretrained_dict, strict=False):
    """
    Loads pre-trained weights into a model, skipping mismatched layers.

    Entries of ``pretrained_dict`` are used only when the key exists in the
    model's state dict and the tensor sizes are equal. Parameters that did not
    receive a pre-trained value are then re-initialised: weights with at least
    two dimensions get Xavier-uniform values and biases are set to zero.
    Weights with fewer than two dimensions keep their current values, because
    Xavier initialisation cannot compute fan-in/fan-out for them and raises
    ``ValueError``.

    Args:
        model (torch.nn.Module): The model to load weights into.
        pretrained_dict (dict): The dictionary containing pre-trained weights.
        strict (bool, optional): Forwarded to ``model.load_state_dict``; when True,
            keys missing from the filtered dictionary raise an error. Defaults to False.

    Returns:
        None
    """
    model_dict = model.state_dict()
    # Keep only entries whose name and shape both match the model.
    compatible = {
        name: tensor
        for name, tensor in pretrained_dict.items()
        if name in model_dict and model_dict[name].size() == tensor.size()
    }

    # Load weights for matching layers
    model.load_state_dict(compatible, strict=strict)

    # Re-initialise the parameters that were not covered by the checkpoint
    for name, param in model.named_parameters():
        if name in compatible:
            continue
        print(f"Initializing weights for layer: {name}")
        if 'weight' in name:
            # xavier_uniform_ needs fan-in and fan-out, i.e. >= 2 dimensions;
            # 1-D weights are left at the values the model already holds.
            if param.dim() >= 2:
                torch.nn.init.xavier_uniform_(param)
        elif 'bias' in name:
            torch.nn.init.constant_(param, 0.0)


def get_data(args):
    """Build the labelled, unlabelled, validation and auxiliary data loaders.

    The labelled training set is repeated (whole copies plus a leading slice)
    until it has exactly as many items as the unlabelled training set, so both
    training loaders yield the same number of batches per epoch.

    Args:
        args: Parsed options; reads ``val_json_path``, ``data_path``,
            ``unlabelled_json_path``, ``image_size``, ``num_workers`` and
            ``batch_size``.

    Returns:
        tuple: ``(labelled train loader, unlabelled train loader,
        validation loader, auxiliary unlabelled loader with mask collation)``.

    Raises:
        ValueError: If the labelled set is empty or larger than the unlabelled
            set, in which case it cannot be repeated to a matching length.
    """
    val_set = LabelledDS(json_file=args.val_json_path, image_size=args.image_size, stage='val', is_augmentation=False)
    labeled_train_set = LabelledDS(json_file=args.data_path, image_size=args.image_size, stage='train', is_augmentation=True)
    unlabeled_train_set = UnlabelledDS(json_file=args.unlabelled_json_path, image_size=args.image_size, stage='train', is_augmentation=True)

    num_labeled = len(labeled_train_set)
    num_unlabeled = len(unlabeled_train_set)
    print('before: ', num_unlabeled, num_labeled, len(val_set))

    # Fail with a clear message instead of a ZeroDivisionError or an opaque
    # assertion raised from inside ConcatDataset.
    if num_labeled == 0:
        raise ValueError(f'labelled training set is empty: {args.data_path}')
    if num_unlabeled < num_labeled:
        raise ValueError(
            'the unlabelled training set ({}) must be at least as large as the '
            'labelled training set ({})'.format(num_unlabeled, num_labeled))

    # Whole repetitions followed by the first `remainder` labelled items.
    repeats, remainder = divmod(num_unlabeled, num_labeled)
    parts = [labeled_train_set] * repeats
    if remainder:
        parts.append(Subset(labeled_train_set, range(remainder)))
    labeled_train_set = ConcatDataset(parts)
    print('after: ', len(unlabeled_train_set), len(labeled_train_set), len(val_set))

    loader_kwargs = dict(num_workers=args.num_workers, batch_size=args.batch_size, pin_memory=True)
    train_labeled_dataloader = DataLoader(dataset=labeled_train_set, shuffle=True, **loader_kwargs)
    train_unlabeled_dataloader = DataLoader(dataset=unlabeled_train_set, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(dataset=val_set, shuffle=False, **loader_kwargs)

    mask_generator = BoxMaskGenerator(prop_range=(0.25, 0.5),
                                      n_boxes=3,
                                      random_aspect_ratio=True,
                                      prop_by_area=True,
                                      within_bounds=True,
                                      invert=True)

    # The auxiliary loader iterates the unlabelled set with a collate function
    # that runs the mask-parameter augmentation on each batch.
    add_mask_params_to_batch = AddMaskParamsToBatch(mask_generator)
    mask_collate_fn = SegCollate(batch_aug_fn=add_mask_params_to_batch)
    aux_dataloader = DataLoader(dataset=unlabeled_train_set, shuffle=True, collate_fn=mask_collate_fn, **loader_kwargs)
    return train_labeled_dataloader, train_unlabeled_dataloader, val_dataloader, aux_dataloader


def _build_model(args):
    """Create one segmentation network on ``device`` and initialise its classifier head."""
    model = deeplabv3.__dict__[args.backbone](in_channels=args.in_channels, out_channels=args.num_classes).to(device)
    init_weight(model.net.classifier, nn.init.kaiming_normal_,
                nn.BatchNorm3d, 1e-5, 0.1,
                mode='fan_in', nonlinearity='relu')
    return model


def _build_optimizer(params, args):
    """Return CAdamW when ``--caution`` is set, otherwise the stock ``optim.AdamW``."""
    common = dict(lr=args.learning_rate,
                  betas=(args.beta1, args.beta2),
                  eps=args.eps,
                  weight_decay=args.weight_decay)
    if args.caution:
        return CAdamW(params, caution=args.caution, **common)
    return optim.AdamW(params, **common)


def _to_one_hot(class_map, num_classes):
    """Expand a (B, 1, ...) class-index tensor into a (B, num_classes, ...) float one-hot tensor."""
    class_ids = torch.arange(num_classes, device=class_map.device)
    class_ids = class_ids.view(1, num_classes, *([1] * (class_map.dim() - 2)))
    return (class_map == class_ids).float()


def main(is_debug=False):
    """Train two student networks together with a teacher and validate the teacher each epoch.

    Per training iteration the loss is the supervised Dice-focal loss of both
    students on labelled data, plus two ramped-up unsupervised terms on
    unlabelled data: cross supervision between the students' pseudo labels
    (``loss_x``) and supervision by the teacher's pseudo labels (``loss_u``).
    After the optimizer steps, the teacher is updated from both students through
    its ``weighted_update`` method. The teacher is validated every epoch and
    checkpoints, TensorBoard scalars, a CSV of per-epoch losses and a loss plot
    are written below ``args.project``.

    Args:
        is_debug (bool): When True, each epoch runs only 10 training iterations
            and progress-bar descriptions are not updated.
    """
    args = get_args()
    seed_torch(args.seed)

    # Project Saving Path
    project_path = args.project
    ensure_dir(project_path)
    # The trailing '' keeps a trailing separator on the directory path, exactly
    # as the previous string concatenation produced.
    save_path = os.path.join(project_path, 'weights', '')
    ensure_dir(save_path)

    # Tensorboard & Statistics Results & Logger
    tb_dir = os.path.join(project_path, 'tensorboard{}'.format(time.strftime("%b%d_%d-%H-%M", time.localtime())))
    writer = SummaryWriter(tb_dir)
    metrics = EasyDict()
    metrics.train_loss = []
    metrics.train_s_loss = []
    metrics.train_u_loss = []
    metrics.train_x_loss = []
    metrics.val_loss = []
    logger = logging(os.path.join(project_path, 'train_val.log'))
    logger.info('PyTorch Version {}\n Experiment{}'.format(torch.__version__, project_path))
    logger.info(f'Args: {args}')

    # Load Data
    train_labeled_dataloader, train_unlabeled_dataloader, val_dataloader, _ = get_data(args=args)
    iters = len(train_labeled_dataloader)
    val_iters = len(val_dataloader)

    # Load Model & EMA (construction order is kept so seeded initialisation is unchanged)
    student1 = _build_model(args)
    student2 = _build_model(args)
    teacher = _build_model(args)

    if args.ckpt_continue is not None:
        print("loading checkpoint")
        # Read the file once and map it onto the active device so a checkpoint
        # saved on a GPU can also be restored on a CPU-only machine.
        checkpoint = torch.load(args.ckpt_continue, map_location=device)
        student1.load_state_dict(checkpoint)
        student2.load_state_dict(checkpoint)
        teacher.load_state_dict(checkpoint)

    teacher.detach_model()
    best_epoch = 0
    best_loss = float('inf')

    # Criterion & Optimizer & LR Schedule
    criterion_dsc = DSCLoss(to_onehot_y=True, softmax=True, reduction='mean')
    optimizer1 = _build_optimizer(student1.parameters(), args)
    optimizer2 = _build_optimizer(student2.parameters(), args)

    # Log optimizer configuration
    # NOTE(review): hirearchical_caution and unlabeled_momentum_update are only
    # logged in this function; neither is passed to an optimizer here.
    logger.info("Optimizer configuration:")
    logger.info(f"  Learning rate: {args.learning_rate}")
    logger.info(f"  Betas: ({args.beta1}, {args.beta2})")
    logger.info(f"  Weight decay: {args.weight_decay}")
    logger.info(f"  Caution: {args.caution}")
    logger.info(f"  Hierarchical caution: {args.hirearchical_caution}")
    logger.info(f"  Unlabeled momentum update: {args.unlabeled_momentum_update}")

    # Train
    since = time.time()
    logger.info('start training')

    for epoch in range(1, args.num_epochs + 1):
        epoch_metrics = EasyDict()
        epoch_metrics.train_loss = []
        epoch_metrics.train_s_loss = []
        epoch_metrics.train_u_loss = []
        epoch_metrics.train_x_loss = []
        if is_debug:
            pbar = range(10)
        else:
            pbar = tqdm(range(iters), desc=f"Epoch {epoch}/{args.num_epochs}")
        iter_train_labeled_dataloader = iter(train_labeled_dataloader)
        iter_train_unlabeled_dataloader = iter(train_unlabeled_dataloader)

        ############################
        # Train
        ############################
        student1.train()
        student2.train()
        teacher.train()

        for idx in pbar:
            # Iteration counter that keeps increasing across epochs.
            global_step = idx + len(pbar) * (epoch - 1)

            # Labeled data
            image, label, imageA1, imageA2 = next(iter_train_labeled_dataloader)
            image, label = image.to(device), label.to(device)
            imageA1, imageA2 = imageA1.to(device), imageA2.to(device)
            # Unlabeled data
            uimage, _, uimageA1, uimageA2 = next(iter_train_unlabeled_dataloader)
            uimage, uimageA1, uimageA2 = uimage.to(device), uimageA1.to(device), uimageA2.to(device)

            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # ----- Supervised Path -----
            pred_s1_logits = student1(imageA1)['out']
            pred_s2_logits = student2(imageA2)['out']
            loss_s = (criterion_dsc(pred_s1_logits, label) + criterion_dsc(pred_s2_logits, label)) / 2.

            # ----- Unsupervised Path -----
            with torch.no_grad():
                pred_u_logits = teacher(uimage)['out']
                pred_u_probs = torch.softmax(pred_u_logits, dim=1)
                # Teacher pseudo labels (class index per voxel), no gradient
                pred_u_pseudo = torch.argmax(pred_u_probs, dim=1).detach()

            pred_u1A1_logits = student1(uimageA1)['out']
            pred_u1A1_probs = torch.softmax(pred_u1A1_logits, dim=1)
            pred_u1A1_pseudo = torch.argmax(pred_u1A1_probs, dim=1).detach()

            pred_u2A2_logits = student2(uimageA2)['out']
            pred_u2A2_probs = torch.softmax(pred_u2A2_logits, dim=1)
            pred_u2A2_pseudo = torch.argmax(pred_u2A2_probs, dim=1).detach()

            # Unsupervised losses; pseudo labels are detached so they act as fixed targets.
            # The ramp-up spans the first five epochs' worth of iterations.
            lambda_ = sigmoid_rampup(current=global_step, rampup_length=len(pbar) * 5)
            loss_x = (criterion_dsc(pred_u1A1_logits, pred_u2A2_pseudo.unsqueeze(1).detach()) +
                      criterion_dsc(pred_u2A2_logits, pred_u1A1_pseudo.unsqueeze(1).detach())) / 2.
            loss_u = (criterion_dsc(pred_u1A1_logits, pred_u_pseudo.unsqueeze(1).detach()) +
                      criterion_dsc(pred_u2A2_logits, pred_u_pseudo.unsqueeze(1).detach())) / 2.
            loss = loss_s + loss_x * 0.1 * lambda_ + loss_u * 0.1 * lambda_

            loss.backward()
            optimizer1.step()
            optimizer2.step()

            teacher.weighted_update(student1, student2, ema_decay=0.99,
                                    cur_step=global_step)
            train_loss = loss.item()

            # Update tensorboard
            writer.add_scalar('train_s_loss', loss_s.item(), global_step)
            writer.add_scalar('train_u_loss', loss_u.item(), global_step)
            writer.add_scalar('train_x_loss', loss_x.item(), global_step)
            writer.add_scalar('train_loss', train_loss, global_step)

            if idx % args.log_freq == 0:
                # "loss" is the weighted total that was actually optimised,
                # matching the 'train_loss' scalar and the progress bar.
                logger.info("Train: Epoch/Epochs {}/{}, "
                            "iter/iters {}/{}, "
                            "loss {:.3f}, loss_s {:.3f}, loss_u {:.3f}, loss_x {:.3f}, lambda {:.3f}".format(
                                epoch, args.num_epochs, idx, len(pbar),
                                train_loss, loss_s.item(), loss_u.item(), loss_x.item(), lambda_))

                # Update tqdm description
                if not is_debug:
                    pbar.set_description(f"Epoch {epoch}/{args.num_epochs} Loss: {train_loss:.3f}")

            epoch_metrics.train_loss.append(train_loss)
            epoch_metrics.train_s_loss.append(loss_s.item())
            epoch_metrics.train_u_loss.append(loss_u.item())
            epoch_metrics.train_x_loss.append(loss_x.item())

        ############################
        # Validation
        ############################
        epoch_metrics.val_loss = []
        iter_val_dataloader = iter(val_dataloader)
        val_pbar = tqdm(range(val_iters), desc=f"Validation Epoch {epoch}/{args.num_epochs}")
        teacher.eval()
        dice_metric = DiceMetric(include_background=False, reduction="mean")

        with torch.no_grad():
            for idx in val_pbar:
                val_step = idx + len(val_pbar) * (epoch - 1)
                image, label = next(iter_val_dataloader)
                image, label = image.to(device), label.to(device)
                pred = teacher(image)['out']
                loss = criterion_dsc(pred, label)
                epoch_metrics.val_loss.append(loss.item())

                # DiceMetric compares one-hot tensors and returns one row per
                # sample, so both inputs are expanded and the batch axis is
                # averaged to obtain a single score per foreground class.
                # Assumes label is a (B, 1, ...) class-index tensor, as the
                # loss's to_onehot_y=True usage suggests — TODO confirm.
                pred_seg = torch.argmax(pred, dim=1).unsqueeze(1)
                batch_dice = dice_metric(_to_one_hot(pred_seg, args.num_classes),
                                         _to_one_hot(label, args.num_classes))
                # nanmean skips samples in which a class is absent from the label.
                dice_scores = torch.nanmean(batch_dice, dim=0)
                per_class = dice_scores.tolist()
                mean_dice = torch.nanmean(dice_scores).item()

                # Add individual dice scores to tensorboard and other logging
                writer.add_scalar('val_loss', loss.item(), val_step)
                for class_idx, class_dice in enumerate(per_class, start=1):
                    writer.add_scalar(f'val_dice_class{class_idx}', class_dice, val_step)
                writer.add_scalar('val_dice_mean', mean_dice, val_step)

                if idx % args.log_freq == 0:
                    dice_text = ", ".join(
                        "dice_{} {:.3f}".format(class_idx, class_dice)
                        for class_idx, class_dice in enumerate(per_class, start=1))
                    logger.info("Val: Epoch/Epochs {}/{}, "
                                "iter/iters {}/{}, "
                                "loss {:.3f}, {}, mean_dice {:.3f}".format(
                                    epoch, args.num_epochs, idx, len(val_pbar),
                                    loss.item(), dice_text, mean_dice))

                if not is_debug:
                    dice_desc = " ".join(
                        "Dice{}: {:.3f}".format(class_idx, class_dice)
                        for class_idx, class_dice in enumerate(per_class, start=1))
                    val_pbar.set_description(
                        f"Val Epoch {epoch}/{args.num_epochs} Loss: {loss.item():.3f} "
                        f"{dice_desc}")

        # Calculate average metrics
        avg_val_loss = np.mean(epoch_metrics.val_loss)
        avg_train_loss = np.mean(epoch_metrics.train_loss)

        logger.info("Average: Epoch/Epochs {}/{}, "
                    "train epoch loss {:.3f}, "
                    "val epoch loss {:.3f}\n".format(epoch, args.num_epochs, avg_train_loss, avg_val_loss))

        # Record one value per epoch for every column of the final CSV/plot;
        # without the training entries the DataFrame below cannot be built.
        metrics.train_loss.append(avg_train_loss)
        metrics.train_s_loss.append(np.mean(epoch_metrics.train_s_loss))
        metrics.train_u_loss.append(np.mean(epoch_metrics.train_u_loss))
        metrics.train_x_loss.append(np.mean(epoch_metrics.train_x_loss))
        metrics.val_loss.append(avg_val_loss)

        # Save Model
        if avg_val_loss <= best_loss:
            best_epoch = epoch
            best_loss = avg_val_loss
            torch.save(teacher.state_dict(), os.path.join(save_path, 'best.pth'))
            logger.info(f"New best model saved! Epoch {best_epoch}, Loss: {best_loss:.4f}")

        torch.save(teacher.state_dict(), os.path.join(save_path, 'last.pth'))

        torch.save(teacher.state_dict(), os.path.join(save_path, 'teacher_last.pth'))
        torch.save(student1.state_dict(), os.path.join(save_path, 'student1_last.pth'))
        torch.save(student2.state_dict(), os.path.join(save_path, 'student2_last.pth'))

        if epoch % args.save_freq == 0:
            torch.save(teacher.state_dict(), os.path.join(save_path, f'teacher_{epoch}.pth'))

    # Flush pending TensorBoard events and release the event file.
    writer.close()

    ############################
    # Save Metrics
    ############################
    data_frame = pd.DataFrame(
        data={'loss': metrics.train_loss,
              'loss_s': metrics.train_s_loss,
              'loss_u': metrics.train_u_loss,
              'loss_x': metrics.train_x_loss,
              'val_loss': metrics.val_loss},
        index=range(1, args.num_epochs + 1))
    data_frame.to_csv(os.path.join(project_path, 'train_val_loss.csv'), index_label='Epoch')

    plt.figure(figsize=(10, 6))
    plt.title("Loss During Training and Validating")
    plt.plot(metrics.train_loss, label="Train")
    plt.plot(metrics.val_loss, label="Val")
    plt.xlabel("epochs")
    plt.ylabel("Loss")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.savefig(os.path.join(project_path, 'train_val_loss.png'))
    plt.close()

    print(f"Results saved to: {project_path}")
    time_elapsed = time.time() - since
    logger.info('Training completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('TRAINING FINISHED!')


# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
