import wandb
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

import os
import utils, losses
import glob
import datasets
import sys
import numpy as np
import matplotlib.pyplot as plt
import ours_mamba as MODELS
from tqdm import tqdm

import multiprocessing
import ml_collections
import time


def _match_ddp_prefix(state_dict, model):
    """Align the leading ``module.`` key prefix of ``state_dict`` with ``model``.

    Only a *leading* ``module.`` is added or stripped; a key that merely
    contains the substring (e.g. ``encoder.some_module.weight``) is left
    untouched. Empty state dicts are returned unchanged.

    Args:
        state_dict: Mapping of parameter names to values read from a checkpoint.
        model: Module whose ``state_dict()`` keys define the target naming.

    Returns:
        A mapping whose keys follow the same prefix convention as ``model``.
    """
    prefix = "module."
    ckpt_keys = list(state_dict.keys())
    model_keys = list(model.state_dict().keys())
    if not ckpt_keys or not model_keys:
        return state_dict

    ckpt_wrapped = ckpt_keys[0].startswith(prefix)
    model_wrapped = model_keys[0].startswith(prefix)
    if ckpt_wrapped and not model_wrapped:
        return {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()}
    if model_wrapped and not ckpt_wrapped:
        return {f"{prefix}{k}": v for k, v in state_dict.items()}
    return state_dict


def _apply_freeze_policy(model, froze_modules, unfroze_modules):
    """Set ``requires_grad`` on ``model`` parameters by substring match on their names.

    A parameter whose name contains any entry of ``unfroze_modules`` is made
    trainable; otherwise, if its name contains any entry of ``froze_modules``
    it is frozen. Parameters matching neither list are left as they are. The
    result does not depend on the order of the entries in either list.
    """
    for name, param in model.named_parameters():
        if any(tag in name for tag in unfroze_modules):
            param.requires_grad = True
        elif any(tag in name for tag in froze_modules):
            param.requires_grad = False


def _save_validation_figure(panels, save_path):
    """Render up to eight 2-D panels on a 2x4 grid and save the figure.

    Args:
        panels: Sequence of ``(title, image, cmap)`` where ``image`` is a 2-D
            tensor that is moved to CPU for plotting.
        save_path: Destination file; its parent directory is created if needed.
    """
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig = plt.figure(figsize=(20, 16))
    try:
        for position, (title, image, cmap) in enumerate(panels, start=1):
            plt.subplot(2, 4, position)
            plt.imshow(image.detach().cpu().numpy(), cmap=cmap)
            plt.title(title)
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def train(rank, world_size, gpu_ids, config, port):
    """Per-process training entry point (one process per GPU).

    Builds the model, datasets, optimizer and scheduler, optionally restores a
    checkpoint (full resume when ``config.if_resume`` is set, otherwise a
    partial load of the modules listed in ``config.load_modules``), wraps the
    model in DDP and runs the train loop. Rank 0 additionally handles logging,
    validation every ``config.save_steps`` epochs, visualizations and
    checkpoint saving.

    Args:
        rank: Index of this process; also used to index ``gpu_ids``.
        world_size: Total number of processes, forwarded to ``utils.setup``.
        gpu_ids: CUDA device indices, one per rank.
        config: Configuration object (see ``get_orochi_B_config``).
        port: Port forwarded to ``utils.setup`` for process-group rendezvous.
    """
    import traceback  # local import: only used to report swallowed batch errors

    utils.setup(rank, world_size, port)
    torch.cuda.set_device(gpu_ids[rank])
    device = torch.device(f"cuda:{gpu_ids[rank]}")
    save_dir = config.save_dir
    model = MODELS.Orochi_Finetune(config).to(device)
    logger = None
    if rank == 0:
        os.makedirs(save_dir+'experiments/logs/', exist_ok=True)
        os.makedirs(save_dir+'experiments/visualizations/', exist_ok=True)
        # Route everything rank 0 prints (including errors) through the logger.
        logger = utils.Logger(save_dir+'experiments/logs/')
        sys.stdout = logger
        sys.stderr = logger

    start_epoch = 0
    best_dsc = 0
    global_step = 0

    train_composed = transforms.Compose([
        datasets.NumpyType((np.float32, np.float32)),
    ])

    val_composed = transforms.Compose([
        datasets.Seg_norm(),
        datasets.NumpyType((np.float32, np.int16)),
    ])

    train_set = datasets.IXIBrainDataset(glob.glob(config.train_dir + '*.pkl'),
                                         config.atlas_dir,
                                         transforms=train_composed)
    val_set = datasets.IXIBrainInferDataset(glob.glob(config.val_dir + '*.pkl'),
                                            config.atlas_dir,
                                            transforms=val_composed)

    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set,
                              batch_size=config.batch_size,
                              sampler=train_sampler,
                              num_workers=config.num_workers,
                              pin_memory=True,
                              prefetch_factor=2)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config.num_workers,
                            pin_memory=True)

    optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    # The scheduler is stepped once per batch, so its horizon is in batches.
    total_steps = len(train_loader) * config.max_epoch
    warmup_steps = int(total_steps * config.warmup_ratio)
    scheduler = utils.WarmupCosineSchedule(optimizer,
                                           warmup_steps=warmup_steps,
                                           t_total=total_steps,
                                           warmup_start_factor=config.warmup_start_factor)

    if config.if_resume:
        # Full resume: model, optimizer, scheduler and bookkeeping counters.
        if os.path.isfile(config.checkpoint_dir):
            print("##################Checkpoint found##################")
            checkpoint = torch.load(config.checkpoint_dir, map_location=device)
            start_epoch = checkpoint['epoch']
            best_dsc = checkpoint['best_DSC']
            global_step = start_epoch * len(train_loader)

            model.load_state_dict(_match_ddp_prefix(checkpoint['state_dict'], model))
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"=> loaded checkpoint '{config.checkpoint_dir}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at '{config.checkpoint_dir}'")
    elif config.checkpoint_dir:
        # Fine-tuning: load only the selected sub-modules, then freeze/unfreeze.
        print("##################Checkpoint found##################")
        checkpoint = torch.load(config.checkpoint_dir, map_location=device)
        checkpoint_state_dict = _match_ddp_prefix(checkpoint['state_dict'], model)

        selected_state_dict = {k: v for k, v in checkpoint_state_dict.items()
                               if any(module in k for module in config.load_modules['load'])}

        model.load_state_dict(selected_state_dict, strict=False)
        print(f"Loaded {config.load_modules['load']} from checkpoint: {config.checkpoint_dir}")

        _apply_freeze_policy(model,
                             config.load_modules['froze'],
                             config.load_modules['unfroze_from_froze'])
        print(f"froze {config.load_modules['froze']}")
        print(f"unfroze from froze {config.load_modules['unfroze_from_froze']}")
        del checkpoint
        print("##################Checkpoint loaded##################")

    model = DDP(model, device_ids=[gpu_ids[rank]], output_device=gpu_ids[rank], find_unused_parameters=True)
    if rank == 0:
        print(f'Configuration: {config}')
        MODELS.print_model_details(model)
        # Initialize wandb (only when an API key is configured).
        if config.wandb_key:
            wandb.login(key=config.wandb_key)
            wandb.init(project=config.wandb_project,
                       config=config,
                       save_code=True,
                       group="finetune",
                       job_type="train",
                       name=f"IXI_reg{time.strftime('%m_%d_%H_%M', time.localtime())}",
                       )
            wandb.watch(model, log="all", log_freq=100, log_graph=True)
            print(f"Wandb initialized with run name: {wandb.run.name}")

    reg_model = utils.register_model(config.img_size, 'nearest').to(device)
    reg_model_bilin = utils.register_model(config.img_size, 'bilinear').to(device)

    # Only rank 0 writes TensorBoard events, so only rank 0 opens a writer.
    writer = SummaryWriter(log_dir=save_dir+'experiments/logs/') if rank == 0 else None
    save_dir = save_dir+'experiments/'

    for epoch in range(start_epoch, config.max_epoch):
        model.train()
        # Reseed the distributed sampler so each epoch uses a different shuffle.
        train_loader.sampler.set_epoch(epoch)
        loss_all = utils.AverageMeter()

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}/{config.max_epoch}', disable=rank != 0)

        for data in progress_bar:
            try:
                data = [t.to(device) for t in data]
                source, target, source_seg, target_seg = data

                logits, aux_loss = model(source, target)
                optimizer.zero_grad()
                flat_aux_loss = utils.flatten_loss_dict(aux_loss)
                loss = sum(flat_aux_loss.values())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                scheduler.step()
                loss_all.update(loss.item(), source.numel())

                if rank == 0:
                    current_lr = optimizer.param_groups[0]['lr']
                    postfix_dict = {
                        'loss': f'{loss.item():.4f}',
                        'lr': f'{current_lr:.6f}'
                    }
                    for key, value in flat_aux_loss.items():
                        postfix_dict[key] = f'{value.item():.4f}'
                    progress_bar.set_postfix(postfix_dict)

                    if config.wandb_key:
                        wandb.log({
                            'Loss/train': loss.item(),
                            'LearningRate': current_lr,
                            **{f'Loss/{key}': value.item() for key, value in flat_aux_loss.items()}
                        }, step=global_step)

                    writer.add_scalar('Loss/train', loss.item(), global_step)
                    for key, value in flat_aux_loss.items():
                        writer.add_scalar(f'Loss/{key}', value.item(), global_step)
                    writer.add_scalar('LearningRate', current_lr, global_step)

                global_step += 1

            except Exception as e:
                # Best-effort: skip the failing batch, but keep the full
                # traceback so the failure can be diagnosed afterwards.
                # NOTE(review): under DDP a batch skipped on a single rank may
                # desynchronize collectives across ranks -- confirm this
                # skip-and-continue policy is intended.
                print(f"Error in training loop: {e}")
                traceback.print_exc()
                continue

        if rank == 0 and epoch % config.save_steps == 0:
            model.eval()
            # Only rank 0 executes this branch, so validate through the
            # unwrapped module: going through the DDP wrapper could issue
            # forward-time collectives that no other rank would match.
            eval_model = model.module
            eval_dsc = utils.AverageMeter()
            with torch.no_grad():
                val_progress = tqdm(val_loader, desc=f'Validation Epoch {epoch}', leave=False)
                for idx, data in enumerate(val_progress):
                    data = [t.to(device) for t in data]
                    source, target, source_seg, target_seg = data
                    logits, aux_loss = eval_model(source, target)
                    deformed_move, flow = logits['registered'], logits['flow']
                    grid_img = utils.mk_grid_img(8, 1, config.img_size).to(device)
                    deformed_seg = reg_model([source_seg.float(), flow])
                    deformed_grid = reg_model_bilin([grid_img.float(), flow])
                    dsc = utils.dice_val_VOI(deformed_seg.long(), target_seg.long())
                    eval_dsc.update(dsc.item(), target.size(0))
                    val_progress.set_postfix({'DSC': f'{dsc.item():.4f}', 'Avg DSC': f'{eval_dsc.avg:.4f}'})

                    # Top row: segmentations and warped grid; bottom row: flow
                    # and intensity images. All panels show slice index 48.
                    panels = [
                        ('Move seg', source_seg[0, 0, 48], 'gray'),
                        ('Fix seg', target_seg[0, 0, 48], 'gray'),
                        ('Deformed seg', deformed_seg[0, 0, 48], 'gray'),
                        ('Deformed Grid', deformed_grid[0, 0, 48], 'gray'),
                        ('flow', flow[0, 0, 48], 'gray'),
                        ('Move', source[0, 0, 48], 'hot'),
                        ('Fix', target[0, 0, 48], 'hot'),
                        ('Deformed Move', deformed_move[0, 0, 48], 'hot'),
                    ]
                    vis_save_path = os.path.join(save_dir,
                                                 'visualizations',
                                                 f'epoch_{epoch}_iter_{global_step}',
                                                 f'DSC_{dsc.item():.4f}_idx_{idx}.png')
                    _save_validation_figure(panels, vis_save_path)

            # Update the running best first so the values reported below
            # already include this validation round.
            is_best = eval_dsc.avg > best_dsc
            best_dsc = max(eval_dsc.avg, best_dsc)
            print(f'Epoch {epoch}, Best DSC {best_dsc}, Avg DSC {eval_dsc.avg}')

            if config.wandb_key:
                wandb.log({
                    'Validation/Avg_DSC': eval_dsc.avg,
                    'Validation/Best_DSC': best_dsc
                }, step=global_step)

            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_DSC': best_dsc,
            }, config, is_best, save_dir=save_dir, filename=f'Orochi_epoch_{epoch+1}_DSC_{eval_dsc.avg:.4f}.pth.tar')

    if rank == 0:
        writer.close()
        # Finish wandb while the redirected streams are still open; it may
        # still write to stdout/stderr while shutting down.
        if config.wandb_key:
            wandb.finish()
        logger.close()
        # Restore the real streams so later output does not hit a closed logger.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

    utils.cleanup()

def get_orochi_B_config():
    """Build and return the default Orochi-B fine-tuning configuration.

    The result is an ``ml_collections.ConfigDict`` populated section by
    section (encoder, mamba, decoder, training, paths, wandb). Fields are
    inserted one at a time in declaration order. ``save_dir`` embeds the
    current local time, and ``num_workers`` is derived from the CPU count
    (capped at 16).
    """
    cfg = ml_collections.ConfigDict()

    def assign(section):
        # Insert each field individually, preserving declaration order.
        for field_name, field_value in section.items():
            cfg[field_name] = field_value

    # --- encoder ---
    assign({
        'img_size': (160, 192, 224),
        'patch_size': 4,
        'pat_merg_rf': 2,
        'in_chans': 2,
        'embed_dim': 128,
        'depths': (4, 4, 4, 4),
        'drop_rate': 0,
        'drop_path_rate': 0.2,
        'if_convskip': True,
        'out_indices': (0, 1, 2, 3),
    })

    # --- mamba ---
    assign({
        'ssm_cfg': None,
        'norm_epsilon': 1e-5,
        'initializer_cfg': None,
        'fused_add_norm': True,
        'rms_norm': True,
        'residual_in_fp32': True,
        'patch_norm': True,
        'use_checkpoint': False,
    })

    # --- decoder ---
    assign({
        'decoder_bn': False,
        'decoder_depthseparable': False,
        'decoder_mode': '3d',
        'decoder_head_chan': 64,
    })

    # --- training ---
    assign({
        'if_resume': False,
        'finetune_mode': 'reg',  # one of 'reg', 'fus', 'SR', 'IR'
        # Name fragments selecting which checkpoint weights to load, which
        # parameters to freeze, and which of those to keep trainable.
        'load_modules': {
            'load': ['encoder'],
            'froze': ['encoder'],
            'unfroze_from_froze': ['norm', 'bias'],
        },
        'batch_size': 2,
        'lr': 0.0001,
        'weight_decay': 0.01,
        'warmup_ratio': 0.1,
        'warmup_start_factor': 0.01,
        'max_epoch': 301,
        'gpu_ids': [0, 1, 2, 3, 4, 5, 6],
    })
    cpu_total = multiprocessing.cpu_count()
    assign({
        'num_workers': min(cpu_total * 2, 16),
        'save_steps': 5,
        # Each entry pairs a loss module with a float
        # (presumably its weight -- TODO confirm against the consumer).
        'losses': {
            "mse": (nn.MSELoss(), 1.0),
            "ssim": (losses.SSIM3D(), 0.0),
            "ncc": (losses.NCC_vxm(), 1.0),
            "grad": (losses.Grad3d(penalty='l2'), 1.0),
            "dice": (losses.DiceLoss(), 1.0),
        },
    })

    # --- paths ---
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    assign({
        'checkpoint_dir': './checkpoint.pth.tar',
        'train_dir': './IXI_data/Train/',
        'atlas_dir': '/IXI_data/atlas.pkl',
        'val_dir': './IXI_data/Val/',
        'test_dir': './IXI_data/Test/',
        'save_dir': f'./Experiment/IXIREG/{timestamp}/',
    })

    # --- wandb ---
    assign({
        'wandb_key': None,
        'wandb_project': "Orochi",
    })

    return cfg

def main():
    """Launch one training process per configured GPU.

    Restricts CUDA to ``config.gpu_ids`` via ``CUDA_VISIBLE_DEVICES`` and then
    hands each worker a re-indexed device list (``0..num_gpus-1``), since the
    visible devices are renumbered from zero inside the restricted view.

    Raises:
        ValueError: If ``config.gpu_ids`` is empty.
    """
    config = get_orochi_B_config()
    num_gpus = len(config.gpu_ids)
    if num_gpus == 0:
        # Fail loudly instead of silently spawning zero workers.
        raise ValueError("config.gpu_ids must list at least one GPU")

    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, config.gpu_ids))
    # Respect an interface already chosen by the environment; fall back to eth0.
    os.environ.setdefault('GLOO_SOCKET_IFNAME', 'eth0')
    port = utils.get_free_port()

    mp.spawn(train, args=(num_gpus, list(range(num_gpus)), config, port), nprocs=num_gpus, join=True)

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()