import os
import gc
import json
import shutil

import torch
import torch.nn as nn
from tqdm import tqdm
import utils

from clipbufferloss import BiMixCoBufferLoss

# Metric keys produced by ``training_on_batch`` / ``valid_on_batch`` that are
# averaged per epoch (order is preserved in the logged / dumped dicts).
_TRAIN_METRIC_KEYS = (
    'tr/loss', 'tr/loss_img', 'tr/loss_txt',
    'tr/img_fwd', 'tr/img_bwd', 'tr/txt_fwd', 'tr/txt_bwd',
)
_VALID_METRIC_KEYS = (
    'vl/loss', 'vl/loss_img', 'vl/loss_txt',
    'vl/img_fwd', 'vl/img_bwd', 'vl/txt_fwd', 'vl/txt_bwd',
)


def _mean_metrics(sums, num_batches, split):
    """Return the per-batch average of every accumulated metric in ``sums``.

    Raises ValueError when the ``split`` loop saw no batches, instead of
    dividing by an undefined or stale batch counter.
    """
    if num_batches == 0:
        raise ValueError(f"{split} dataloader yielded no batches; cannot average metrics")
    return {key: total / num_batches for key, total in sums.items()}


def train(
    accelerator,
    distributed:bool,
    optimizer:torch.optim,
    train_dataloader:torch.utils.data.DataLoader,
    val_dataloader:torch.utils.data.DataLoader,
    num_epochs:int,
    start_epoch:int,
    clip_fmri,
    clip,
    scheduler:dict=None,
    processor=None,
    use_image_aug=False,
    img_augment=None,
    logger=None,
    checkpoint_dir=None,
    num_iterations_per_epoch:int=None,
    save_step:int=1,
    log_step:int=1,
    val_step:int=1,
    best_value=None,
    ## task
    text_scale: float=0.,
    mixup_pct: float=0.,
    norm_nii: bool=False,
    norm_dict: dict=None,
    padding_list: list=None,
    mixin=False,
    buffer_size=0,
    local_loss=False,
    local_loss2=False,
    gather_with_grad=False,
    ):
    """Train ``clip_fmri`` for epochs ``start_epoch .. num_epochs - 1`` with accelerate.

    Each epoch runs ``training_on_batch`` over ``train_dataloader``, stepping
    ``optimizer`` every ``accelerator.gradient_accumulation_steps`` batches.
    Every ``val_step`` epochs it then runs ``valid_on_batch`` over
    ``val_dataloader``. Per-batch metrics are averaged per epoch and reported
    via ``accelerator.log`` and ``logger`` (main process only). Checkpoints go
    to ``checkpoint_dir``:

    * ``best.pth`` / ``best.json`` whenever the epoch's ``'vl/loss'`` drops
      below ``best_value`` (999. when not given);
    * ``state_last`` (``accelerator.save_state``) and ``last.json`` every
      ``save_step`` epochs, except the final epoch;
    * ``last.pth`` / ``last.json`` after training, after which ``state_last``
      is removed.

    Args:
        scheduler: optional dict with keys ``'scheduler'`` (the LR scheduler),
            ``'interval'`` and ``'frequency'``. With ``'interval' == 'step'`` the
            scheduler steps every ``frequency`` optimizer steps; otherwise it
            steps every ``frequency`` epochs.
        logger: optional logger for the per-epoch summary; skipped when None.
        checkpoint_dir: required output directory for checkpoints.
        num_iterations_per_epoch: batches per epoch, used for the progress bar
            and the global log step. Defaults to ``len(train_dataloader)``.
        text_scale: weight of the fMRI-text loss term.
        mixup_pct: fraction of the epochs, counted from epoch 0, during which
            ``training_on_batch`` applies ``utils.mixco``.
        buffer_size, local_loss, local_loss2, gather_with_grad: forwarded to
            ``BiMixCoBufferLoss``.

    Raises:
        ValueError: if ``checkpoint_dir`` is None or a dataloader yields no batches.
    """
    if checkpoint_dir is None:
        # A checkpoint is always written at the end; fail now rather than
        # after every epoch has already been trained.
        raise ValueError("checkpoint_dir must be provided")

    lr_scheduler = None
    scheduler_step = False
    scheduler_frequency = 1
    if scheduler is not None:
        lr_scheduler = scheduler['scheduler']
        scheduler_step = scheduler['interval'] == 'step'
        scheduler_frequency = scheduler['frequency']

    if best_value is None:
        best_value = 999.

    steps_per_epoch = num_iterations_per_epoch
    if steps_per_epoch is None:
        # Needed for the global log step (``epoch * steps + batch``).
        steps_per_epoch = len(train_dataloader)

    soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

    loss_kwargs = dict(
        accelerator=accelerator,
        local_loss=local_loss,
        gather_with_grad=gather_with_grad,
        buffer_size=buffer_size,
        local_loss2=local_loss2,
    )
    loss_img = BiMixCoBufferLoss(**loss_kwargs)
    loss_txt = BiMixCoBufferLoss(**loss_kwargs)

    # Reported in the final last.json; stays start_epoch - 1 if no epoch runs.
    last_epoch = start_epoch - 1
    info = {}

    for epoch_idx in range(start_epoch, num_epochs):
        last_epoch = epoch_idx

        # ================= train ================= #
        clip_fmri.train()
        optimizer.zero_grad()
        tr_sums = dict.fromkeys(_TRAIN_METRIC_KEYS, 0.)
        num_tr_batches = 0

        loop = tqdm(enumerate(train_dataloader), total=steps_per_epoch, desc=f"Epoch Train {epoch_idx}/{num_epochs}: ", disable=not accelerator.is_main_process, dynamic_ncols=True)
        for batch_idx, batch in loop:
            logs = training_on_batch(accelerator=accelerator,
                                     distributed=distributed,
                                     clip_fmri=clip_fmri,
                                     clip=clip,
                                     processor=processor,
                                     use_image_aug=use_image_aug,
                                     img_augment=img_augment,
                                     batch=batch,
                                     ## task
                                     text_scale=text_scale,
                                     epoch_idx=epoch_idx,
                                     soft_loss_temps=soft_loss_temps,
                                     mixup_pct=mixup_pct,
                                     num_epochs=num_epochs,
                                     norm_nii=norm_nii,
                                     norm_dict=norm_dict,
                                     padding_list=padding_list,
                                     mixin=mixin,
                                     loss_img=loss_img,
                                     loss_txt=loss_txt,
                                     )
            # Gradients accumulate across batches; step once per accumulation window.
            if (batch_idx + 1) % accelerator.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                if lr_scheduler is not None and scheduler_step:
                    if (batch_idx + 1) % (scheduler_frequency * accelerator.gradient_accumulation_steps) == 0:
                        lr_scheduler.step()

            logs.update({'train/lr': optimizer.param_groups[0]['lr']})
            logs.update({'epoch': epoch_idx})
            postfix_info = "|".join([f"{k}: {v:.4f}" for k, v in logs.items()])
            loop.set_postfix_str(postfix_info)

            if (batch_idx + 1) % log_step == 0 or (batch_idx + 1) == steps_per_epoch:
                if accelerator.is_main_process:
                    accelerator.log(logs, step=epoch_idx * steps_per_epoch + batch_idx)

            for key in _TRAIN_METRIC_KEYS:
                tr_sums[key] += logs[key]
            num_tr_batches += 1

        accelerator.wait_for_everyone()
        tr_metrics = _mean_metrics(tr_sums, num_tr_batches, 'train')

        # Epoch-interval LR schedule.
        if lr_scheduler is not None and not scheduler_step:
            if (epoch_idx + 1) % scheduler_frequency == 0:
                lr_scheduler.step()

        # ================= valid ================= #
        run_validation = (epoch_idx + 1) % val_step == 0
        vl_metrics = None
        if run_validation:
            clip_fmri.eval()
            vl_sums = dict.fromkeys(_VALID_METRIC_KEYS, 0.)
            num_vl_batches = 0

            loop = tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f"Epoch Valid {epoch_idx}/{num_epochs}: ", disable=not accelerator.is_main_process, dynamic_ncols=True)
            for batch_idx, batch in loop:
                logs = valid_on_batch(accelerator,
                                      distributed,
                                      clip_fmri,
                                      clip,
                                      processor,
                                      batch,
                                      ## task
                                      text_scale=text_scale,
                                      epoch_idx=epoch_idx,
                                      soft_loss_temps=soft_loss_temps,
                                      mixup_pct=mixup_pct,
                                      num_epochs=num_epochs,
                                      norm_nii=norm_nii,
                                      norm_dict=norm_dict,
                                      padding_list=padding_list,
                                      )
                postfix_info = "|".join([f"{k}: {v:.4f}" for k, v in logs.items()])
                loop.set_postfix_str(postfix_info)
                for key in _VALID_METRIC_KEYS:
                    vl_sums[key] += logs[key]
                num_vl_batches += 1

            vl_metrics = _mean_metrics(vl_sums, num_vl_batches, 'validation')

            if accelerator.is_main_process:
                # NOTE(review): per-batch train logs use a global batch index as
                # the step, but validation is logged at step=epoch_idx. Trackers
                # that require monotonically increasing steps (e.g. wandb) may
                # drop these points. Confirm against the tracker in use.
                accelerator.log(dict(vl_metrics), step=epoch_idx)

        # Wait for other GPUs to catch up if needed.
        accelerator.wait_for_everyone()

        # ======= log info and checkpoint ======= #
        info = dict(tr_metrics)
        if run_validation:
            info.update(vl_metrics)

        if accelerator.is_main_process:
            if logger is not None:
                logger.info(f"Epoch[{epoch_idx}/{num_epochs}], Logs: {info}")
            if run_validation and vl_metrics['vl/loss'] < best_value:
                best_value = vl_metrics['vl/loss']
                unwrapped_model = accelerator.unwrap_model(clip_fmri)
                save_path = os.path.join(checkpoint_dir, 'best.pth')
                # Remove any previous best checkpoint before writing the new one.
                if os.path.exists(save_path):
                    os.remove(save_path)
                accelerator.save(unwrapped_model.state_dict(), save_path)
                print(f"\n---saved {checkpoint_dir}/best pth!---\n")
                json_save_path = os.path.join(checkpoint_dir, 'best.json')
                with open(json_save_path, 'w') as json_file:
                    json.dump({'epoch_idx': epoch_idx, **info}, json_file)

        if (epoch_idx + 1) % save_step == 0 and epoch_idx != num_epochs - 1:
            # save_state is called on every process; file writes are main-only.
            state_path = os.path.join(checkpoint_dir, 'state_last')
            accelerator.save_state(state_path)
            if accelerator.is_main_process:
                json_save_path = os.path.join(checkpoint_dir, 'last.json')
                with open(json_save_path, 'w') as json_file:
                    json.dump({'epoch_idx': epoch_idx, **info}, json_file)
                print(f"\n---saved {checkpoint_dir}/last pth!---\n")

        torch.cuda.empty_cache()
        gc.collect()

    accelerator.end_training()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(clip_fmri)
        save_path = os.path.join(checkpoint_dir, 'last.pth')
        accelerator.save(unwrapped_model.state_dict(), save_path)
        json_save_path = os.path.join(checkpoint_dir, 'last.json')
        with open(json_save_path, 'w') as json_file:
            json.dump({'epoch_idx': last_epoch, **info}, json_file)
        print(f"\n---saved {checkpoint_dir}/last pth!---\n")
        # The resumable state is no longer needed once the final weights exist.
        state_path = os.path.join(checkpoint_dir, 'state_last')
        if os.path.exists(state_path):
            shutil.rmtree(state_path)

def training_on_batch(
    accelerator,
    distributed:bool,
    clip_fmri,
    clip,
    processor,
    use_image_aug,
    img_augment,
    batch,
    ## task
    text_scale,
    epoch_idx,
    soft_loss_temps,
    mixup_pct,
    num_epochs,
    norm_nii,
    norm_dict,
    padding_list,
    mixin=False,
    loss_img=None,
    loss_txt=None,
    ):
    """Run the forward and backward pass for one training batch.

    ``batch`` is unpacked as ``(nii, image, input_ids, attention_mask, subj, _)``.
    The ``clip`` image/text embeddings are computed under ``torch.no_grad()``.
    ``clip_fmri`` runs under ``accelerator.autocast()``, and the summed
    contrastive loss is back-propagated with ``accelerator.backward``. The
    optimizer is not stepped here; the caller does that.

    ``distributed``, ``processor``, ``soft_loss_temps`` and ``mixin`` are
    accepted for interface compatibility but not used.

    Returns:
        dict of Python floats, each averaged over ``accelerator.gather``:
        ``'tr/loss'``, ``'tr/loss_img'``, ``'tr/loss_txt'`` (already multiplied
        by ``text_scale``), and ``'tr/img_fwd'``, ``'tr/img_bwd'``,
        ``'tr/txt_fwd'``, ``'tr/txt_bwd'`` from ``utils.topk(..., k=1)``
        (presumably top-1 retrieval accuracy; confirm in utils).
    """
    nii_batch, image_batch, input_ids_batch, attention_mask_batch, subj, _ = batch
    # NOTE(review): fixed 1/300 intensity scaling; confirm against the data pipeline.
    nii_batch = nii_batch.float() / 300.
    if norm_nii:
        for i, subj_id in enumerate(subj):
            nii_batch[i] = cal_norm_nii(nii_batch[i], subj_id, norm_dict)
    nii_batch = torch.nn.functional.pad(nii_batch, pad=padding_list, mode='constant', value=0)

    # MixCo is applied only during the first int(mixup_pct * num_epochs) epochs.
    if epoch_idx < int(mixup_pct * num_epochs):
        nii_batch, perm, betas, _select = utils.mixco(nii_batch)
    else:
        perm, betas = None, None
    image_batch = image_batch.to(nii_batch.dtype)

    if use_image_aug:
        image_batch = img_augment(image_batch.float())

    # No gradients flow into ``clip``; only ``clip_fmri`` is trained here.
    with torch.no_grad(), torch.cuda.amp.autocast():
        output = clip(input_ids=input_ids_batch, attention_mask=attention_mask_batch, pixel_values=image_batch)
        clip_feat_text = output.text_embeds
        clip_feat_img = output.image_embeds

    with accelerator.autocast():
        clip_feat_fmri = clip_fmri(nii_batch).image_embeds

        temp = 0.01

        image_features = torch.nn.functional.normalize(clip_feat_img, dim=-1)
        fmri_features = torch.nn.functional.normalize(clip_feat_fmri, dim=-1)
        text_features = torch.nn.functional.normalize(clip_feat_text, dim=-1)

        loss_clip_image = loss_img(
            fmri_features,
            image_features,
            temp=temp, perm=perm, betas=betas)
        if text_scale > 0:
            loss_clip_text = loss_txt(
                fmri_features,
                text_features,
                temp=temp, perm=perm, betas=betas)
        else:
            loss_clip_text = torch.tensor(0., device=fmri_features.device)

    utils.check_loss(loss_clip_image)
    utils.check_loss(loss_clip_text)

    # Out-of-place scaling: an in-place ``*=`` on a graph output fails in
    # backward if the loss module's final op saved that output for its gradient.
    loss_clip_text = loss_clip_text * text_scale
    loss = loss_clip_image + loss_clip_text

    accelerator.backward(loss)

    with torch.no_grad():
        labels = torch.arange(len(image_features)).to(image_features.device)
        metrics = {
            'tr/loss': loss,
            'tr/loss_img': loss_clip_image,
            'tr/loss_txt': loss_clip_text,
            'tr/img_fwd': utils.topk(utils.batchwise_cosine_similarity(fmri_features, image_features), labels, k=1),
            'tr/img_bwd': utils.topk(utils.batchwise_cosine_similarity(image_features, fmri_features), labels, k=1),
            'tr/txt_fwd': utils.topk(utils.batchwise_cosine_similarity(fmri_features, text_features), labels, k=1),
            'tr/txt_bwd': utils.topk(utils.batchwise_cosine_similarity(text_features, fmri_features), labels, k=1),
        }
        # Gathers are collective ops: every process iterates the same ordered
        # dict, so the calls stay matched across ranks.
        gathered = {key: accelerator.gather(value).mean() for key, value in metrics.items()}

    return {key: value.item() for key, value in gathered.items()}

def valid_on_batch(
    accelerator,
    distributed,
    clip_fmri,
    clip,
    processor,
    batch,
    # =============== task ============
    text_scale,
    epoch_idx,
    soft_loss_temps,
    mixup_pct,
    num_epochs,
    norm_nii,
    norm_dict,
    padding_list,
    ):
    """Evaluate one validation batch without gradients.

    Preprocesses ``batch`` the same way as ``training_on_batch`` but without
    MixCo or image augmentation. Losses come from ``utils.clip_loss`` with a
    fixed temperature of 0.01.

    ``distributed``, ``processor``, ``epoch_idx``, ``soft_loss_temps``,
    ``mixup_pct`` and ``num_epochs`` are accepted for interface compatibility
    but not used.

    Returns:
        dict of Python floats, each averaged over ``accelerator.gather``:
        ``'vl/loss'``, ``'vl/loss_img'``, ``'vl/loss_txt'`` (already multiplied
        by ``text_scale``), and ``'vl/img_fwd'``, ``'vl/img_bwd'``,
        ``'vl/txt_fwd'``, ``'vl/txt_bwd'`` from ``utils.topk(..., k=1)``.
    """
    with torch.no_grad():
        nii_batch, image_batch, input_ids_batch, attention_mask_batch, subj, _ = batch
        # NOTE(review): fixed 1/300 intensity scaling; confirm against the data pipeline.
        nii_batch = nii_batch.float() / 300.
        if norm_nii:
            for i, subj_id in enumerate(subj):
                nii_batch[i] = cal_norm_nii(nii_batch[i], subj_id, norm_dict)
        nii_batch = torch.nn.functional.pad(nii_batch, pad=padding_list, mode='constant', value=0)
        image_batch = image_batch.to(nii_batch.dtype)

        with torch.cuda.amp.autocast():
            output = clip(input_ids=input_ids_batch, attention_mask=attention_mask_batch, pixel_values=image_batch)
            clip_feat_text = output.text_embeds
            clip_feat_img = output.image_embeds

        with accelerator.autocast():
            clip_feat_fmri = clip_fmri(nii_batch).image_embeds
            temp = 0.01

            image_features = torch.nn.functional.normalize(clip_feat_img, dim=-1)
            fmri_features = torch.nn.functional.normalize(clip_feat_fmri, dim=-1)
            text_features = torch.nn.functional.normalize(clip_feat_text, dim=-1)

            loss_clip_image = utils.clip_loss(
                fmri_features,
                image_features,
                temp=temp)
            if text_scale > 0:
                loss_clip_text = utils.clip_loss(
                    fmri_features,
                    text_features,
                    temp=temp)
            else:
                loss_clip_text = torch.tensor(0., device=fmri_features.device)

        utils.check_loss(loss_clip_image)
        utils.check_loss(loss_clip_text)

        loss_clip_text = loss_clip_text * text_scale
        loss = loss_clip_image + loss_clip_text

        labels = torch.arange(len(image_features)).to(image_features.device)
        metrics = {
            # The total loss is gathered like its components (and like
            # 'tr/loss' in training), so 'vl/loss' reflects every process.
            # Before, it was a per-process value that still drove
            # best-checkpoint selection on the main process.
            'vl/loss': loss,
            'vl/loss_img': loss_clip_image,
            'vl/loss_txt': loss_clip_text,
            'vl/img_fwd': utils.topk(utils.batchwise_cosine_similarity(fmri_features, image_features), labels, k=1),
            'vl/img_bwd': utils.topk(utils.batchwise_cosine_similarity(image_features, fmri_features), labels, k=1),
            'vl/txt_fwd': utils.topk(utils.batchwise_cosine_similarity(fmri_features, text_features), labels, k=1),
            'vl/txt_bwd': utils.topk(utils.batchwise_cosine_similarity(text_features, fmri_features), labels, k=1),
        }
        # Same iteration order on every rank keeps the collective gathers matched.
        logs = {key: accelerator.gather(value).mean().item() for key, value in metrics.items()}
    return logs

def cal_norm_nii(nii, subj, norm_dict):
    """Standardize ``nii`` with the per-subject statistics held in ``norm_dict``.

    Looks up ``'subjXX_mean'`` and ``'subjXX_std'``, where XX is ``subj``
    zero-padded to two digits. Both are moved to ``nii``'s device, and the
    function returns ``(nii - mean) / (std + 1e-5)``. The epsilon guards
    against division by zero.

    Raises:
        KeyError: if ``norm_dict`` has no statistics for ``subj``.
    """
    key_prefix = f'subj{subj:02d}'
    target_device = nii.device
    subj_mean = norm_dict[key_prefix + '_mean'].to(target_device)
    subj_std = norm_dict[key_prefix + '_std'].to(target_device)
    return (nii - subj_mean) / (subj_std + 1e-5)

def clip_loss(feat1, feat2, logit_scale):
    """Symmetric CLIP-style contrastive (InfoNCE) loss between two aligned batches.

    Row ``i`` of ``feat1`` is the positive for row ``i`` of ``feat2``; all other
    rows in the batch act as negatives.

    Args:
        feat1: 2-D tensor of shape ``(N, D)``. Presumably already
            L2-normalized by the caller; this function does not normalize.
        feat2: 2-D tensor of shape ``(N, D)`` aligned row-wise with ``feat1``.
        logit_scale: multiplier applied to the similarity matrix
            (an inverse temperature).

    Returns:
        Scalar tensor: the mean of the ``feat1 -> feat2`` and
        ``feat2 -> feat1`` cross-entropy losses.
    """
    logits_1_2 = logit_scale * feat1 @ feat2.T
    # The reverse direction is the transpose of the same similarity matrix,
    # so reuse it instead of a second matmul.
    logits_2_1 = logits_1_2.T

    labels = torch.arange(feat1.size(0), device=feat1.device)
    loss_1_2 = nn.functional.cross_entropy(logits_1_2, labels)
    loss_2_1 = nn.functional.cross_entropy(logits_2_1, labels)

    return (loss_1_2 + loss_2_1) / 2

