# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import time

import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from utils.utils import AverageMeter, distributed_all_gather
from torchvision.utils import save_image, make_grid

from monai import data, transforms

from monai.data import decollate_batch
from utils.ops import aug_rand, rot_rand
# from losses import DiceCeLoss, softmax_kl_loss, softmax_weighted_loss
from losses import softmax_kl_loss
from losses import softmax_weighted_loss
from monai.losses import DiceCELoss


# Modality order used when stacking input channels; every mask in
# ``all_mod_masks`` is zipped position-by-position against this list.
mod_keys = ["t1", "flair", "t2", "t1ce"]


# Modality-availability settings (1 = modality present), ordered from all four
# modalities present down to a single one.  Columns follow ``mod_keys``:
#   t1 flair t2 t1ce
all_mod_masks = [
    [1, 1, 1, 1],  # all modalities
    [1, 1, 1, 0],  # missing t1ce
    [1, 1, 0, 1],  # missing t2
    [1, 0, 1, 1],  # missing flair
    [0, 1, 1, 1],  # missing t1
    [1, 1, 0, 0],  # t1 + flair
    [1, 0, 1, 0],  # t1 + t2
    [1, 0, 0, 1],  # t1 + t1ce
    [0, 1, 1, 0],  # flair + t2
    [0, 1, 0, 1],  # flair + t1ce
    [0, 0, 1, 1],  # t2 + t1ce
    [1, 0, 0, 0],  # t1 only
    [0, 1, 0, 0],  # flair only
    [0, 0, 1, 0],  # t2 only
    [0, 0, 0, 1],  # t1ce only
]


def _clear_grads(model):
    """Reset gradients by setting ``.grad`` of every model parameter to ``None``."""
    for param in model.parameters():
        param.grad = None


def _stack_available_modalities(batch_data, setting):
    """Stack the modalities enabled in ``setting`` along dim 1.

    Entries of ``setting`` are aligned with ``mod_keys``. Missing slots are
    filled by repeating the last available modality, so the result always has
    ``len(mod_keys)`` channels.
    """
    present = [batch_data[key].squeeze(1) for key, use in zip(mod_keys, setting) if use]
    present.extend([present[-1]] * (len(mod_keys) - len(present)))
    return torch.stack(present, dim=1)


def _backward_and_step(step_loss, optimizer, scaler, use_amp):
    """Back-propagate ``step_loss`` and take one optimizer step (through ``scaler`` when AMP is on)."""
    if use_amp:
        scaler.scale(step_loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        step_loss.backward()
        optimizer.step()


def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args, teacher=None):
    """Run one training epoch over every modality-availability setting.

    For each batch, every setting in ``all_mod_masks`` is visited once in random
    order, with one optimizer step per setting. The model is run on the inputs
    restricted to the available modalities and on the full four-modality input.
    The step loss is the sum of:

    * Dice-CE and weighted cross-entropy on the missing-modality output,
    * Dice-CE and weighted cross-entropy on the full-modality output,
    * ``loss_func`` between the missing-modality ``semantic`` output and the
      detached full-modality one.

    Args:
        model: network returning a 4-tuple whose last two items are
            ``(semantic, logits)``.
        loader: iterable of dicts holding one tensor per key in ``mod_keys`` plus ``'seg'``.
        optimizer: optimizer updating ``model``.
        scaler: ``GradScaler`` used when ``args.amp`` is set.
        epoch: current epoch; the missing-modality schedule runs for ``epoch >= 0``,
            a rotation-based branch otherwise.
        loss_func: consistency loss applied to the ``semantic`` outputs.
        args: namespace providing ``amp``, ``rank``, ``distributed``, ``batch_size``,
            ``world_size`` and ``max_epochs``.
        teacher: unused; kept for interface compatibility.

    Returns:
        Running average (``AverageMeter.avg``) of the per-setting loss.

    Raises:
        RuntimeError: if a step loss of the missing-modality schedule is NaN.
    """
    model.train()
    start_time = time.time()
    run_loss = AverageMeter()
    dice_ce_loss = DiceCELoss(
        to_onehot_y=True, softmax=True, squared_pred=True, smooth_nr=0.0, smooth_dr=1e-6
    )
    num_settings = len(all_mod_masks)

    for idx, batch_data in enumerate(loader):
        all_data = torch.stack([batch_data[key].squeeze(1) for key in mod_keys], dim=1)
        target = batch_data["seg"]
        target[target == 4] = 3  # label 4 -> 3 so labels fit the num_cls=4 outputs
        all_data = all_data.cuda(args.rank)
        target = target.cuda(args.rank).int()

        # Only detached values are accumulated: summing the graph-carrying step
        # losses would keep their autograd history alive for the whole batch.
        loss = torch.zeros((), device=all_data.device)
        last_losses = None

        if epoch >= 0:
            for setting_idx in np.random.permutation(num_settings):
                setting = all_mod_masks[setting_idx]
                _clear_grads(model)
                # Autocast wraps only the forward passes and loss computation;
                # backward runs outside it, as PyTorch AMP recommends.
                with autocast(enabled=args.amp):
                    missing_data = _stack_available_modalities(batch_data, setting).cuda(args.rank)
                    _, _, semantic, out1 = model(missing_data)
                    seg_loss = dice_ce_loss(out1, target)
                    weighted_loss = softmax_weighted_loss(out1, target, num_cls=4)

                    _, _, semantic2, out2 = model(all_data)
                    seg_loss2 = dice_ce_loss(out2, target)
                    # Full-modality weighted CE must use the full-modality output.
                    weighted_loss2 = softmax_weighted_loss(out2, target, num_cls=4)

                    # Full-modality features act as a fixed target (no gradient).
                    semantic_loss = loss_func(semantic, semantic2.detach())
                    step_loss = seg_loss + weighted_loss + seg_loss2 + weighted_loss2 + semantic_loss

                if torch.isnan(step_loss).any():
                    print("has nan")
                    print(out1.sum(1), target.sum())
                    print(
                        seg_loss.item(), weighted_loss.item(), seg_loss2.item(),
                        weighted_loss2.item(), semantic_loss.item(),
                    )
                    print(torch.isnan(missing_data).any(), torch.isnan(target).any())
                    print(torch.isnan(out1).any(), torch.isnan(semantic).any())
                    raise RuntimeError(
                        "NaN training loss at epoch {} batch {} (modality setting {})".format(epoch, idx, setting)
                    )

                _backward_and_step(step_loss, optimizer, scaler, args.amp)
                loss += step_loss.detach()
                last_losses = (seg_loss.detach(), semantic_loss.detach())
        else:
            # Rotation branch; only reached for negative ``epoch`` values.
            _clear_grads(model)
            with autocast(enabled=args.amp):
                x2, rot2 = rot_rand(args, all_data)
                _, _, _, out2 = model(x2)
                for i in range(len(out2)):
                    # presumably undoes the rotation applied by rot_rand — confirm
                    out2[i] = out2[i].rot90(4 - rot2[i], (2, 3))
                step_loss = dice_ce_loss(out2, target)
            if not torch.isnan(step_loss).any():
                _backward_and_step(step_loss, optimizer, scaler, args.amp)
            loss += step_loss.detach()

        if args.distributed:
            loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)
            run_loss.update(
                np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0) / num_settings,
                n=args.batch_size * args.world_size,
            )
        else:
            run_loss.update(loss.item() / num_settings, n=4 * args.batch_size)

        if args.rank == 0:
            try:
                total = len(loader)
            except TypeError:  # loaders without __len__ still get a progress line
                total = "?"
            fields = [
                "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, total),
                "loss: {:.4f}".format(run_loss.avg),
                "time {:.2f}s".format(time.time() - start_time),
            ]
            if last_losses is not None:
                fields += ["loss1", last_losses[0].item(), last_losses[1].item()]
            print(*fields)
        start_time = time.time()

    return run_loss.avg

from monai.transforms import (
    Activations,
    EnsureChannelFirstd,
    AsDiscrete,
    EnsureType,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
)
from utils.utils import dice, resample_3d
from eval import *

def val_epoch(model, loader, epoch, acc_func, args, model_inferer=None, post_label=None, post_pred=None):
    """Evaluate ``model`` on ``loader`` under every setting in ``all_mod_masks``.

    Each batch is evaluated once per setting. The available modalities are
    stacked along dim 1, with missing slots repeating the last available one.
    The stack goes through ``model_inferer`` (or ``model`` when no inferer is
    given), is arg-maxed over dim 1 after a softmax, and is scored with
    ``eval_one_dice``.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: iterable of dicts holding one tensor per key in ``mod_keys`` plus ``'seg'``.
        args: namespace providing ``rank`` (CUDA device index).
        model_inferer: callable mapping an input batch to logits; defaults to ``model``.
        epoch, acc_func, post_label, post_pred: unused; kept for interface compatibility.

    Returns:
        ``(res, val_outputs, target)``:

        * ``res[i]`` holds ``'setting'`` (``all_mod_masks[i]``) and per-batch score
          lists under ``'all_dice_all_wt'``, ``'all_dice_all_co'``,
          ``'all_dice_all_ec'`` and ``'all_dice_all_mean'``.
        * ``val_outputs`` is the list of per-setting predictions (NumPy) for the last batch.
        * ``target`` is the last batch's label array (NumPy).

        For an empty loader, ``val_outputs`` is ``[]`` and ``target`` is ``None``.
    """
    model.eval()
    inferer = model_inferer if model_inferer is not None else model

    res = {
        idx: {
            "setting": setting,
            "all_dice_all_wt": [],
            "all_dice_all_co": [],
            "all_dice_all_ec": [],
            "all_dice_all_mean": [],
        }
        for idx, setting in enumerate(all_mod_masks)
    }
    val_outputs = []
    target = None

    with torch.no_grad():
        # Single pass over the loader; the labels are shared by every setting.
        for batch_data in loader:
            val_outputs = []
            target = batch_data["seg"]
            target[target == 4] = 3  # label 4 -> 3, matching the training remap
            target = target.cpu().numpy()

            for mod_idx, setting in enumerate(all_mod_masks):
                present = [batch_data[key].squeeze(1) for key, use in zip(mod_keys, setting) if use]
                present.extend([present[-1]] * (4 - len(present)))
                inputs = torch.stack(present, dim=1).cuda(args.rank)

                probs = torch.softmax(inferer(inputs), dim=1)
                val_output = torch.argmax(probs, dim=1).cpu().numpy()

                dice_wt, dice_co, dice_ec, dice_mean = eval_one_dice(val_output, target)
                print(dice_wt, dice_co, dice_ec, dice_mean)
                scores = res[mod_idx]
                scores["all_dice_all_wt"].append(dice_wt)
                scores["all_dice_all_co"].append(dice_co)
                scores["all_dice_all_ec"].append(dice_ec)
                scores["all_dice_all_mean"].append(dice_mean)
                val_outputs.append(val_output)

    return res, val_outputs, target


def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None):
    """Save model weights (plus optional optimizer/scheduler state) into ``args.logdir``.

    Args:
        model: model to save. When ``args.distributed`` is set, the weights are
            read from ``model.module``.
        epoch: epoch index stored under ``"epoch"``.
        args: namespace providing ``distributed`` and ``logdir``.
        filename: file name created inside ``args.logdir`` (created if missing).
        best_acc: validation score stored under ``"best_acc"``.
        optimizer: optional optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        scheduler: optional scheduler whose ``state_dict()`` is stored under ``"scheduler"``.
            If that call fails, the scheduler state is skipped with a printed
            warning instead of aborting the save.
    """
    net = model.module if args.distributed else model
    save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": net.state_dict()}
    if optimizer is not None:
        save_dict["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        try:
            save_dict["scheduler"] = scheduler.state_dict()
        except Exception as err:  # best effort: never lose the weights over scheduler state
            print("Skipping scheduler state in checkpoint:", err)
    os.makedirs(args.logdir, exist_ok=True)
    filename = os.path.join(args.logdir, filename)
    torch.save(save_dict, filename)
    print("Saving checkpoint", filename)


def _summarize_validation(res, writer, epoch):
    """Print per-setting dice means, log them to ``writer`` if given, and return the per-setting mean dice list."""
    setting_means = []
    for mod_idx in sorted(res):
        entry = res[mod_idx]
        print(entry["setting"])
        for metric in ("all_dice_all_wt", "all_dice_all_co", "all_dice_all_ec", "all_dice_all_mean"):
            value = np.array(entry[metric]).mean()
            print(metric, value)
            if writer is not None:
                writer.add_scalar(f"mod_keys_{entry['setting']}/{metric}", value, epoch)
        setting_means.append(np.array(entry["all_dice_all_mean"]).mean())
        print()
    return setting_means


def run_training(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_func,
    acc_func,
    args,
    model_inferer=None,
    scheduler=None,
    start_epoch=0,
    post_label=None,
    post_pred=None,
    teacher=None
):
    """Train for epochs ``start_epoch .. args.max_epochs - 1`` with periodic validation.

    * Every ``args.save_every`` epochs, rank 0 writes ``model_final_<epoch>.pt``.
    * Every ``args.val_every`` epochs, the model is validated under all modality
      settings. Rank 0 averages the per-setting mean dice. When that average
      improves and ``args.save_checkpoint`` is set, it saves ``model.pt``.
    * ``scheduler.step()`` is called once per epoch; failures are reported, not raised.

    Returns:
        The best average validation dice seen (0.0 if validation never improved on it).
    """
    writer = None
    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        print("Writing Tensorboard logs to ", args.logdir)
    scaler = GradScaler() if args.amp else None
    best_val_dice = 0.0

    for epoch in range(start_epoch, args.max_epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            torch.distributed.barrier()
        print(args.rank, time.ctime(), "Epoch:", epoch)
        epoch_time = time.time()

        train_loss = train_epoch(
            model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args, teacher=teacher
        )

        if args.rank == 0:
            print(
                "Final training  {}/{}".format(epoch, args.max_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "time {:.2f}s".format(time.time() - epoch_time),
            )
            if writer is not None:
                writer.add_scalar("train_loss", train_loss, epoch)

        if args.rank == 0 and args.logdir is not None and epoch % args.save_every == 0:
            save_checkpoint(model, epoch, args, best_acc=best_val_dice, filename="model_final_{}.pt".format(epoch))

        if epoch % args.val_every == 0:
            print('start to validate')
            if args.distributed:
                torch.distributed.barrier()
            epoch_time = time.time()
            res, val_outputs, target = val_epoch(
                model,
                val_loader,
                epoch=epoch,
                acc_func=acc_func,
                model_inferer=model_inferer,
                args=args,
                post_label=post_label,
                post_pred=post_pred,
            )

            if args.rank == 0:
                # Computed even without a writer so best-model tracking still
                # works when no log directory is configured.
                setting_means = _summarize_validation(res, writer, epoch)
                if writer is not None and target is not None:
                    write_fig(val_outputs, target, writer, epoch)

                val_dice_avg = np.array(setting_means).mean()
                print(
                    "Final validation  {}/{}".format(epoch, args.max_epochs - 1),
                    "acc",
                    val_dice_avg,
                    "time {:.2f}s".format(time.time() - epoch_time),
                )
                if writer is not None:
                    writer.add_scalar("val_acc", val_dice_avg, epoch)
                if val_dice_avg > best_val_dice:
                    print("new best ({:.6f} --> {:.6f}). ".format(best_val_dice, val_dice_avg))
                    best_val_dice = val_dice_avg
                    if args.logdir is not None and args.save_checkpoint:
                        save_checkpoint(
                            model, epoch, args, best_acc=best_val_dice, optimizer=optimizer, scheduler=scheduler
                        )

        if scheduler is not None:
            try:
                scheduler.step()
            except Exception as err:  # keep training going, but do not hide the failure
                print("scheduler.step() failed:", err)

    print("Training Finished !, Best Dice: ", best_val_dice)
    return best_val_dice



def _slice_grid(slice_2d, reference, nrow):
    """Min-max normalise ``slice_2d`` by ``reference``'s range and wrap it with ``make_grid``.

    A constant ``reference`` (zero range) yields an all-zero image instead of NaNs.
    """
    low, high = reference.min(), reference.max()
    span = high - low
    if span == 0:
        normalized = np.zeros_like(slice_2d, dtype=np.float64)
    else:
        normalized = (slice_2d - low) / span
    batch = torch.from_numpy(np.asarray([normalized])).unsqueeze(1)
    return make_grid(batch, nrow=nrow, value_range=(0, 1), normalize=True)


def write_fig(val_outputs_1, target_image, writer, epoch):
    """Log 2-D slices of up to three predictions and of the target to ``writer``.

    The slice position comes from the first flat position of the maximum of
    ``target_image``:

    * Its second-to-last coordinate picks the slice shown for predictions
      (``GenImages_<setting>``) and for ``Target``.
    * Its last coordinate picks the ``Target_top`` slice.

    Each slice is min-max normalised with the range of element ``[0]`` of its
    source array. Inputs are assumed to be NumPy arrays, as returned by
    ``val_epoch`` — TODO confirm for other callers.
    """
    # Same indices as the former hard-coded 240x155 arithmetic, for any shape.
    max_coords = np.unravel_index(np.argmax(target_image), np.shape(target_image))
    num1, num2 = int(max_coords[-2]), int(max_coords[-1])

    for i in range(min(3, len(val_outputs_1))):
        key = all_mod_masks[i]
        pred = val_outputs_1[i]
        writer.add_image(f'GenImages_{key}', _slice_grid(pred[0, :, num1, :], pred[0], nrow=2), epoch)

    writer.add_image('Target', _slice_grid(target_image[0, :, num1, :], target_image[0], nrow=1), epoch)
    writer.add_image('Target_top', _slice_grid(target_image[0, :, :, num2], target_image[0], nrow=1), epoch)