import time
import torch
import torch.optim as optim
# import tensorflow as tf
from torch.utils.tensorboard import SummaryWriter
from sparsemax import Sparsemax

import numpy as np
from collections import defaultdict
from collections import OrderedDict
import os
from os.path import join as pjoin

from data.utils import MotionNormalizerTorch, face_joint_indx, fid_l, fid_r
from data.quaternion import *
from utils.utils import print_current_loss
from eval import evaluation_during_training
from models_interhuman_selfattn.mask_transformer.tools import *
from timm.utils import ApexScaler, NativeScaler
from einops import rearrange, repeat
from spikingjelly.clock_driven import functional
from einops import rearrange, repeat

from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt
import math
import os
def def_value():
    """Return 0.0, the starting value for a missing key of the running-log defaultdict."""
    return float()
def plot_result_figure(path):
    """Plot the scalar curves of every TensorBoard event file found in ``path``.

    Every directory entry that does not end in ``png`` is loaded with
    ``EventAccumulator``. Each scalar tag gets its own subplot (three per row)
    and the figure is saved into the same directory as
    ``tensorboard_<entry>.png``. A failure on one entry is printed and does not
    stop the remaining entries.

    Each figure is closed once it has been saved, so calling this function
    repeatedly (it is invoked every epoch) does not accumulate open figures.

    Args:
        path: Directory that contains the TensorBoard event files.
    """
    cols = 3  # subplots per row
    for file in os.listdir(path):
        if file.endswith("png"):
            # Images written by earlier calls live in the same directory.
            continue
        fig = None
        try:
            event_path = path + "/" + file
            # Load the event file.
            ea = event_accumulator.EventAccumulator(event_path)
            ea.Reload()

            # Show which kinds of tags the file contains.
            print(ea.Tags())

            # Names of all logged scalars.
            scalar_tags = ea.Tags().get('scalars', [])
            print("找到的所有标量：", scalar_tags)
            if not scalar_tags:
                # Nothing to draw; plt.subplots cannot build a grid with zero rows.
                continue

            # One subplot per scalar, laid out on a rows x cols grid.
            rows = math.ceil(len(scalar_tags) / cols)
            fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows))
            axes = axes.flatten()  # flat indexing regardless of the grid shape

            for ax, tag in zip(axes, scalar_tags):
                events = ea.Scalars(tag)
                steps = [e.step for e in events]
                values = [e.value for e in events]

                ax.plot(steps, values, label=tag)
                ax.set_title(tag)
                ax.set_xlabel('Step')
                ax.set_ylabel('Value')
                ax.grid(True)
                ax.legend()

            # Remove the unused axes of the last row.
            for ax in axes[len(scalar_tags):]:
                fig.delaxes(ax)

            plt.tight_layout()
            plt.savefig(f"{path}/tensorboard_{file}.png")
            # NOTE(review): blocks under interactive backends; kept from the
            # original behaviour.
            plt.show()
        except Exception as e:
            print(f"Error processing {file}: {e}")
        finally:
            if fig is not None:
                # Release the figure so repeated calls do not leak memory.
                plt.close(fig)
def save_result_txt(path,output_txt_path,epoch):
    """Append the latest value of every logged scalar to a text summary file.

    Every entry of ``path`` that does not end in ``png`` is loaded with
    ``EventAccumulator``. For each entry one line of the form
    ``<entry>_<epoch>: <tag> <value> <tag> <value> ...`` is built from the last
    recorded value of each scalar tag (tags without events are skipped). A
    failure on one entry is printed and does not stop the remaining entries.

    Args:
        path: Directory that contains the TensorBoard event files.
        output_txt_path: Text file the summary lines are appended to; it is
            created if it does not exist.
        epoch: Value embedded in the line prefix after the entry name.
    """
    results = []
    for file in os.listdir(path):
        if file.endswith("png"):
            # Rendered plots share the directory with the event files.
            continue
        try:
            event_path = path + "/" + file
            # Load the event file.
            ea = event_accumulator.EventAccumulator(event_path)
            ea.Reload()

            # Show which kinds of tags the file contains.
            print(ea.Tags())

            # Names of all logged scalars.
            scalar_tags = ea.Tags().get('scalars', [])
            print("找到的所有标量：", scalar_tags)

            exp_name = os.path.basename(file)
            fields = [f"{exp_name}_{epoch}:"]
            for tag in scalar_tags:
                events = ea.Scalars(tag)
                if not events:
                    continue
                # Only the most recent value of each scalar is reported.
                fields.append(f"{tag} {events[-1].value:.6f}")
            line = " ".join(fields)
            results.append(line)

            print(f"{line}")
        except Exception as e:
            print(f"Error processing {file}: {e}")

    with open(output_txt_path, "a+", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in results)


class MaskTransformerTrainerDDP:
    def __init__(self, args, t2m_transformer, vq_model):
        """Store the networks and build the helpers used by the loss functions.

        Args:
            args: Option object; ``device``, ``accumulation_steps``,
                ``is_train`` and (when training) ``log_dir`` are read here.
            t2m_transformer: Network that is trained by this class.
            vq_model: VQ model; it is switched to eval mode immediately.
        """
        self.opt, self.t2m_transformer, self.vq_model = args, t2m_transformer, vq_model
        self.device = args.device

        # The VQ model is only used for encoding/decoding, never trained here.
        vq_model.eval()

        # Helpers shared by the interaction losses.
        self.normalizer = MotionNormalizerTorch(args.device)
        self.InteractionLoss = torch.nn.SmoothL1Loss(reduction='none')
        # Despite the attribute name this is a Sparsemax over the last axis.
        self.softmax = Sparsemax(dim=-1)

        self.accumulation_steps = args.accumulation_steps

        # A TensorBoard writer is only created for training runs.
        if not args.is_train:
            return
        self.logger = SummaryWriter(args.log_dir)


    def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
        """Set the linear warm-up learning rate on every optimizer param group.

        Args:
            nb_iter: Current iteration count.
            warm_up_iter: Number of warm-up iterations.
            lr: Target learning rate reached at the end of the warm-up.

        Returns:
            The learning rate ``lr * (nb_iter + 1) / (warm_up_iter + 1)`` that
            was written to the param groups.
        """
        warm_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
        groups = self.opt_t2m_transformer.param_groups
        for group in groups:
            group["lr"] = warm_lr
        return warm_lr
    

    def calc_dm_loss(self, motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints, thresh_pred=1, thresh_tgt=0.1):
        """Masked loss between predicted and target inter-person distance matrices.

        Pairwise joint distances between the two persons are computed with
        ``torch.cdist`` for the prediction and for the target. Two terms are
        summed:

        * the element-wise ``self.InteractionLoss`` between predicted and
          target distances, averaged over entries whose predicted distance is
          below ``thresh_pred``;
        * the element-wise ``self.InteractionLoss`` between predicted distances
          and zero, averaged over entries whose target distance is below
          ``thresh_tgt``.

        Args:
            motion1_joints: Target joints of the first person (assumed
                ``(T, njoints, 3)`` — TODO confirm against the caller).
            motion2_joints: Target joints of the second person.
            pred_motion1_joints: Predicted joints of the first person.
            pred_motion2_joints: Predicted joints of the second person.
            thresh_pred: Predicted-distance threshold for the first term.
            thresh_tgt: Target-distance threshold for the second term.

        Returns:
            Scalar tensor holding the sum of both terms.
        """
        def flat_distances(first, second):
            # One row per leading index, all joint pairs flattened behind it.
            dist = torch.cdist(first.contiguous(), second)
            return dist.reshape(dist.shape[0], -1)

        pred_dist = flat_distances(pred_motion1_joints, pred_motion2_joints)
        tgt_dist = flat_distances(motion1_joints, motion2_joints)

        near_in_pred = (pred_dist < thresh_pred).float()
        near_in_tgt = (tgt_dist < thresh_tgt).float()

        eps = 1.e-7  # keeps the averages finite when a mask is empty

        match_err = self.InteractionLoss(pred_dist, tgt_dist) * near_in_pred
        match_term = match_err.sum() / (near_in_pred.sum() + eps)

        contact_err = self.InteractionLoss(pred_dist, torch.zeros_like(tgt_dist)) * near_in_tgt
        contact_term = contact_err.sum() / (near_in_tgt.sum() + eps)

        return match_term + contact_term
    
    def calc_ro_loss(self, motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints):
        """Loss on the rotation between the two persons' facing directions.

        For prediction and target alike, a facing direction per person is built
        as the normalised cross product of the +Y axis with the normalised
        right-hip-minus-left-hip vector. ``qbetween`` is evaluated on the two
        persons' facing directions, and ``self.InteractionLoss`` compares
        components 0 and 2 of the predicted and target results.

        Args:
            motion1_joints: Target joints of the first person (assumed
                ``(T, njoints, 3)`` — TODO confirm against the caller).
            motion2_joints: Target joints of the second person.
            pred_motion1_joints: Predicted joints of the first person.
            pred_motion2_joints: Predicted joints of the second person.

        Returns:
            Scalar tensor with the mean loss.
        """
        def unit(vec):
            return vec / vec.norm(dim=-1, keepdim=True)

        # Put both persons on a new axis 1 so they are processed together.
        tgt_pair = torch.cat([motion1_joints[:, None], motion2_joints[:, None]], dim=1)
        pred_pair = torch.cat([pred_motion1_joints[:, None], pred_motion2_joints[:, None]], dim=1)

        r_hip, l_hip, _, _ = face_joint_indx

        hip_axis_pred = unit(pred_pair[..., r_hip, :] - pred_pair[..., l_hip, :])
        hip_axis_tgt = unit(tgt_pair[..., r_hip, :] - tgt_pair[..., l_hip, :])

        # Unit +Y vector, shaped like the predicted hip axes.
        up = torch.zeros_like(hip_axis_pred)
        up[..., 1] = 1

        facing_pred = unit(torch.cross(up, hip_axis_pred, dim=-1))
        facing_tgt = unit(torch.cross(up, hip_axis_tgt, dim=-1))

        # Index 0 / 1 on the second-to-last axis selects person 1 / person 2.
        rot_pred = qbetween(facing_pred[..., 0, :], facing_pred[..., 1, :])
        rot_tgt = qbetween(facing_tgt[..., 0, :], facing_tgt[..., 1, :])

        compared = [0, 2]  # only these two components enter the loss
        return self.InteractionLoss(rot_pred[..., compared], rot_tgt[..., compared]).mean()

    def calc_interaction_loss(self, motion1, motion2, logits, id_lens):
        """Compute interaction losses between soft-decoded predictions and targets.

        The logits are turned into token probabilities with ``self.softmax``,
        decoded sample by sample through ``self.vq_model.forward_decoder`` with
        ``soft_lookup=True`` and compared with the input motions.

        Args:
            motion1: Motion of the first person, in the space expected by
                ``self.normalizer.backward`` (assumed ``(B, T, D)`` — TODO confirm).
            motion2: Motion of the second person, same layout as ``motion1``.
            logits: Token logits; dim 1 is split in half, first half for
                person 1 and second half for person 2.
            id_lens: Per-sample token lengths.

        Returns:
            Tuple ``(dm_loss, ro_loss, j_loss)``, each averaged over the samples.
        """
        # presumably the number of body-part token streams per person — TODO confirm
        nbp = 5
        nt = self.opt.num_tokens
        # Frame lengths derived from token lengths (factor 4; presumably the VQ
        # temporal downsampling rate — TODO confirm).
        m_lens = id_lens * 4

        # Stack both persons on a new second-to-last axis, run them through
        # normalizer.backward together, then split them again.
        motions = torch.cat([motion1.unsqueeze(-2), motion2.unsqueeze(-2)], dim=-2)
        motions = self.normalizer.backward(motions)
        motion1_denorm, motion2_denorm = motions.chunk(2,dim=-2)

        # Token probabilities from logits (self.softmax is set up in __init__).
        probs = self.softmax(logits)
        probs1, probs2 = probs.chunk(2, dim =1)

        dm_loss = 0
        ro_loss = 0
        j_loss = 0
        for i in range(len(id_lens)):
            # Decode sample i: view the probabilities as (1, nbp, tokens, nt),
            # keep the first id_lens[i] tokens of each of the nbp groups, then
            # flatten back before the soft codebook lookup.
            pred_motion1 = self.vq_model.forward_decoder(probs1[[i]].reshape(1,nbp,-1,nt)[:,:,:id_lens[i],:].reshape(1,-1,nt).unsqueeze(-2), soft_lookup=True)
            pred_motion2 = self.vq_model.forward_decoder(probs2[[i]].reshape(1,nbp,-1,nt)[:,:,:id_lens[i],:].reshape(1,-1,nt).unsqueeze(-2), soft_lookup=True)

            # Same stack / normalizer.backward / split as for the input motions.
            pred_motion = torch.cat([pred_motion1.unsqueeze(-2), pred_motion2.unsqueeze(-2)], dim=-2)
            pred_motion = self.normalizer.backward(pred_motion)
            pred_motion1_denorm, pred_motion2_denorm = pred_motion.chunk(2, dim=-2)

            # The first joints_num * 3 features of the first m_lens[i] frames,
            # reshaped to (-1, joints_num, 3).
            motion1_joints = motion1_denorm[i, :m_lens[i], :][..., :self.opt.joints_num *3].reshape(-1, self.opt.joints_num, 3)
            motion2_joints = motion2_denorm[i, :m_lens[i], :][..., :self.opt.joints_num *3].reshape(-1, self.opt.joints_num, 3)
            pred_motion1_joints = pred_motion1_denorm[0, :m_lens[i], :][..., :self.opt.joints_num *3].reshape(-1, self.opt.joints_num, 3)
            pred_motion2_joints = pred_motion2_denorm[0, :m_lens[i], :][..., :self.opt.joints_num *3].reshape(-1, self.opt.joints_num, 3)

            # Accumulate the per-sample losses.
            dm_loss += self.calc_dm_loss(motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints)
            ro_loss += self.calc_ro_loss(motion1_joints, motion2_joints, pred_motion1_joints, pred_motion2_joints)
            # j_loss compares the full feature vectors as passed in / decoded
            # (before normalizer.backward), not only the joint positions.
            j_loss += self.InteractionLoss(pred_motion1[0, :m_lens[i], :], motion1[i, :m_lens[i], :]).mean() + self.InteractionLoss(pred_motion2[0, :m_lens[i], :], motion2[i, :m_lens[i], :]).mean()

        # Average over the samples of the batch.
        dm_loss = dm_loss / len(m_lens)
        ro_loss = ro_loss / len(m_lens)
        j_loss = j_loss / len(m_lens)
        return dm_loss, ro_loss, j_loss

    def get_vq_codes(self, motion1, motion2, m_lens):
        """Encode each sample up to its own length and pad the codes to a fixed size.

        For every sample both motions are cut to ``m_lens[i]`` frames and
        encoded with ``self.vq_model.encode``. Each code sequence is viewed as
        5 groups, every group is padded along its token axis up to
        ``motion1.shape[1] // 4`` entries, and the two persons are concatenated.

        Args:
            motion1: Motion of the first person (indexed as ``(B, T, ...)``).
            motion2: Motion of the second person, same layout.
            m_lens: Tensor of per-sample frame lengths.

        Returns:
            ``int64`` tensor of shape ``(B, motion1.shape[1] // 4 * 5 * 2, 1)``
            on ``self.device``.
        """
        batch_size = motion1.shape[0]
        total_tokens = motion1.shape[1] // 4 * 5 * 2
        tokens_per_group = total_tokens // 10
        n_quant = 1

        all_codes = torch.zeros(batch_size, total_tokens, n_quant, dtype=torch.int64).to(self.device)

        def encode_padded(motion, length):
            codes, _ = self.vq_model.encode(motion.unsqueeze(0)[:, :length])
            codes = codes.reshape(1, 5, -1, n_quant)
            # NOTE(review): ``-1 * zeros`` is still all zeros, so the padding id
            # is 0; if -1 was intended this needs ``ones`` — confirm before use.
            filler = -1 * torch.zeros(1, 5, tokens_per_group - codes.shape[2], n_quant, dtype=torch.int64).to(self.device)
            return torch.cat([codes, filler], dim=2).reshape(1, -1, n_quant)

        for idx in range(batch_size):
            length = m_lens[idx].item()
            first = encode_padded(motion1[idx], length)
            second = encode_padded(motion2[idx], length)
            # Person 1 codes followed by person 2 codes along the token axis.
            all_codes[idx] = torch.cat([first, second], dim=1)

        return all_codes
    
    def forward(self, batch_data):
        """Encode a batch with the VQ model and run the transformer on the codes.

        Args:
            batch_data: Batch tuple whose layout depends on
                ``self.opt.dataset_name``:

                * ``"interhuman"``: ``(name, conds, motion1, motion2, m_lens)``
                * ``"interx"``: seven items, of which index 2 is ``conds``,
                  index 4 holds both motions (split into two chunks of size 6
                  along the last axis) and index 5 is ``m_lens``.

        Returns:
            Tuple ``(loss, acc)`` taken from the transformer output.

        Raises:
            ValueError: If ``self.opt.dataset_name`` is not one of the two
                supported names.
        """
        dataset_name = self.opt.dataset_name
        if dataset_name == "interhuman":
            _, conds, motion1, motion2, m_lens = batch_data
        elif dataset_name == "interx":
            _, _, conds, _, motions, m_lens, _ = batch_data
            motion1, motion2 = motions.split(6, dim=-1)
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

        motion1 = motion1.detach().float().to(self.device)
        motion2 = motion2.detach().float().to(self.device)
        m_lens = m_lens.detach().long().to(self.device)

        # Encode both persons and concatenate their codes along dim 1.
        code_idx1, _ = self.vq_model.encode(motion1)
        code_idx2, _ = self.vq_model.encode(motion2)
        code_idx = torch.cat([code_idx1, code_idx2], dim=1)

        # Conditions may be tensors or something else (left untouched).
        conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds

        # Frame lengths -> token lengths (factor 4; presumably the VQ temporal
        # downsampling rate — TODO confirm).
        m_lens = m_lens // 4

        # Only index 0 of the last code axis is passed to the transformer.
        _loss, _acc, _, _, _ = self.t2m_transformer(code_idx[..., 0], conds, m_lens)
        return _loss, _acc
        
       

    def update_old(self, batch_data):
        """Run one plain optimisation step (no gradient accumulation).

        Computes the loss with ``forward``, back-propagates it, steps the
        optimizer and then the scheduler.

        Returns:
            Tuple ``(loss as a Python float, acc)``.
        """
        step_loss, step_acc = self.forward(batch_data)

        optimizer = self.opt_t2m_transformer
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        self.scheduler.step()

        return step_loss.item(), step_acc
    def update(self, batch_data):
        """Compute loss and accuracy for a batch without touching the optimizer.

        Backward pass, optimizer step and scheduler step are left to the
        caller, which handles gradient accumulation.

        Returns:
            Tuple ``(loss, acc)`` as returned by ``forward``; the loss is not
            converted with ``.item()`` so the caller can back-propagate it.
        """
        batch_loss, batch_acc = self.forward(batch_data)
        return batch_loss, batch_acc

    def save(self, file_name, ep, total_it):
        """Write a training checkpoint to ``file_name``.

        The checkpoint holds the transformer weights without any entry whose
        key starts with ``clip_``, the optimizer state, the epoch and the
        iteration counter. When a scheduler exists, the two counters that
        ``resume`` restores (``last_epoch`` and ``_step_count``) are stored
        under ``'scheduler'`` so the learning-rate schedule keeps its position
        after a restart.

        Args:
            file_name: Destination path passed to ``torch.save``.
            ep: Epoch number to record.
            total_it: Iteration counter to record.
        """
        t2m_trans_state_dict = self.t2m_transformer.state_dict()
        # Entries prefixed with 'clip_' are left out of the checkpoint;
        # resume() accepts exactly these keys as missing.
        clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_')]
        for e in clip_weights:
            del t2m_trans_state_dict[e]
        state = {
            't2m_transformer': t2m_trans_state_dict,
            'opt_t2m_transformer': self.opt_t2m_transformer.state_dict(),
            'ep': ep,
            'total_it': total_it,
        }
        scheduler = getattr(self, 'scheduler', None)
        if scheduler is not None:
            scheduler_state = scheduler.state_dict()
            # Plain ints only: this is all resume() reads back.
            state['scheduler'] = {key: scheduler_state[key] for key in ("last_epoch", "_step_count")}
        torch.save(state, file_name)

    def resume(self, model_dir):
        """Restore transformer, optimizer and (if available) scheduler state.

        Args:
            model_dir: Path of the checkpoint file (despite the name, this is
                passed straight to ``torch.load``).

        Returns:
            Tuple ``(epoch, total_iterations)`` stored in the checkpoint.

        Raises:
            AssertionError: If the checkpoint has keys the model does not know,
                or the model misses keys other than ``clip_``-prefixed ones.
        """
        checkpoint = torch.load(model_dir, map_location=self.device)
        # Non-strict because 'clip_' entries are stripped when saving.
        missing_keys, unexpected_keys = self.t2m_transformer.load_state_dict(checkpoint['t2m_transformer'], strict=False)
        assert len(unexpected_keys) == 0, f"Unexpected keys in checkpoint: {unexpected_keys}"
        assert all(k.startswith('clip_') for k in missing_keys), f"Missing keys in checkpoint: {missing_keys}"

        self.opt_t2m_transformer.load_state_dict(checkpoint['opt_t2m_transformer'])
        try:
            scheduler_state = checkpoint['scheduler']
            # Only the position counters are restored; milestones come from
            # the scheduler built for the current run.
            self.scheduler.load_state_dict({key: scheduler_state[key] for key in ["last_epoch", "_step_count"]})
        except (KeyError, TypeError, AttributeError):
            # Checkpoint without usable scheduler state: keep the fresh one.
            print('Resume wo scheduler')
        return checkpoint['ep'], checkpoint['total_it']

    def train(self, rank,train_sampler,train_loader, val_loader, test_loader, eval_wrapper):
        """Train the transformer with gradient accumulation, validation and evaluation.

        Per epoch: every batch is forwarded and back-propagated; the optimizer
        and the scheduler step once every ``accumulation_steps`` batches (and
        on the last batch of the epoch). On rank 0 the latest checkpoint is
        saved, the validation set is evaluated, the best-accuracy and
        lowest-loss checkpoints are updated, and — when ``opt.do_eval`` is set —
        ``evaluation_during_training`` runs every ``opt.eval_every_e`` epochs.

        The learning-rate milestones are expressed in optimizer steps (the unit
        in which the scheduler is stepped), and the logged training loss is the
        unscaled per-batch loss.

        Args:
            rank: Process rank; printing, checkpointing, validation and
                evaluation only happen on rank 0.
            train_sampler: Currently unused.
            train_loader: Loader for the training batches.
            val_loader: Loader for the validation batches.
            test_loader: Loader handed to ``evaluation_during_training``.
            eval_wrapper: Evaluator handed to ``evaluation_during_training``.

        Raises:
            ValueError: If ``self.accumulation_steps`` is smaller than 1.
        """
        accumulation_steps = self.accumulation_steps
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be greater than 0.")

        self.t2m_transformer.to(self.device)
        self.vq_model.to(self.device)

        iters_per_epoch = len(train_loader)
        total_iters = self.opt.max_epoch * iters_per_epoch
        # The scheduler advances once per optimizer step, so the milestones
        # must be counted in optimizer steps rather than in batches.
        optimizer_steps = self.opt.max_epoch * math.ceil(iters_per_epoch / accumulation_steps)
        self.opt.milestones = [int(optimizer_steps * 0.5), int(optimizer_steps * 0.7), int(optimizer_steps * 0.85)]
        self.opt.warm_up_iter = iters_per_epoch // 4
        # Clamp to 1 so short loaders do not produce a modulo-by-zero below.
        self.opt.log_every = max(1, iters_per_epoch // 10)
        self.opt.save_latest = max(1, iters_per_epoch // 2)
        if rank==0:
            print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
            print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
            print(f'Milestones: {self.opt.milestones}')
            print('Warm Up Iterations: %04d, Log Every: %04d, Save Latest: %04d' % (self.opt.warm_up_iter, self.opt.log_every, self.opt.save_latest))

        self.opt_t2m_transformer = optim.AdamW(self.t2m_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_t2m_transformer, milestones=self.opt.milestones, gamma=self.opt.gamma)
        self.scaler=NativeScaler()

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)
            # Round down to a logging boundary so log windows stay aligned.
            it = it // self.opt.log_every * self.opt.log_every
            print("Load model epoch:%d iterations:%d" % (epoch, it))

        start_time = time.time()
        logs = defaultdict(def_value, OrderedDict())

        max_acc = -np.inf
        min_loss = np.inf
        min_fid = np.inf
        max_top1 = -np.inf

        if self.opt.do_eval:
            eval_file = pjoin(self.opt.eval_dir, 'evaluation_training.log')

        while epoch < self.opt.max_epoch:
            epoch += 1
            # NOTE(review): train_sampler.set_epoch(epoch) is disabled; a
            # DistributedSampler needs it to reshuffle per epoch — confirm.
            self.t2m_transformer.train()
            self.vq_model.eval()

            for i, batch in enumerate(train_loader):
                # Clear the spiking-network state before every forward pass.
                functional.reset_net(self.t2m_transformer)
                it += 1
                if it < self.opt.warm_up_iter:
                    self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)

                # Start of an accumulation window: drop the old gradients.
                if (i % accumulation_steps) == 0:
                    self.opt_t2m_transformer.zero_grad()

                loss, acc = self.update(batch_data=batch)
                # Keep the unscaled value for logging; only the gradient
                # contribution is divided by the window size.
                loss_value = loss.item()
                self.scaler._scaler.scale(loss / accumulation_steps).backward()

                # End of the window (or of the epoch): apply the update.
                if (i + 1) % accumulation_steps == 0 or (i + 1) == iters_per_epoch:
                    self.scaler._scaler.step(self.opt_t2m_transformer)
                    self.scaler._scaler.update()
                    self.scheduler.step()

                logs['loss'] += loss_value
                logs['acc'] += acc
                logs['lr'] += self.opt_t2m_transformer.param_groups[0]['lr']

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = defaultdict(def_value, OrderedDict())
                    if rank==0:
                        print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    if rank==0:
                        self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            if rank==0:
                self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
                print('Validation time:')
                self.vq_model.eval()
                self.t2m_transformer.eval()
                val_loss = []
                val_acc = []
                with torch.no_grad():
                    for i, batch_data in enumerate(val_loader):
                        # Same reset as in training, so no state leaks from
                        # the previous batch into this one.
                        functional.reset_net(self.t2m_transformer)
                        loss, acc = self.forward(batch_data)
                        val_loss.append(loss.item())
                        val_acc.append(acc)

                mean_val_loss = np.mean(val_loss)
                mean_val_acc = np.mean(val_acc)
                print(f"Validation loss:{mean_val_loss:.3f}, accuracy:{mean_val_acc:.3f}")
                self.logger.add_scalar('Val/loss', mean_val_loss, epoch)
                self.logger.add_scalar('Val/acc', mean_val_acc, epoch)

                if mean_val_acc > max_acc:
                    print(f"Improved accuracy from {max_acc:.02f} to {mean_val_acc}!!!")
                    self.save(pjoin(self.opt.model_dir, 'best_acc.tar'), epoch, it)
                    max_acc = mean_val_acc

                if mean_val_loss < min_loss:
                    print(f"Improved Loss from {min_loss:.02f} to {mean_val_loss}!!!")
                    self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
                    min_loss = mean_val_loss
                plot_result_figure(self.opt.log_dir)
            if rank==0:
                if self.opt.do_eval:
                    if epoch % self.opt.eval_every_e == 0:
                        self.vq_model.eval()
                        self.t2m_transformer.eval()

                        fid, mat, top1 = evaluation_during_training(self.opt, self.vq_model, test_loader,
                                                                    eval_wrapper, epoch, eval_file, trans=self.t2m_transformer)
                        self.logger.add_scalar('Test/FID', fid, epoch)
                        self.logger.add_scalar('Test/Matching', mat, epoch)
                        self.logger.add_scalar('Test/Top1', top1, epoch)
                        if fid < min_fid:
                            min_fid = fid
                            self.save(pjoin(self.opt.model_dir, 'best_fid.tar'), epoch, it)
                            print('Best FID Model So Far!~')
                        if top1 > max_top1:
                            max_top1 = top1
                            self.save(pjoin(self.opt.model_dir, 'best_top1.tar'), epoch, it)
                            print('Best Top1 Model So Far!~')
                    # Refresh the plots and the text summary after evaluation.
                    plot_result_figure(self.opt.log_dir)
                    save_result_txt(self.opt.log_dir, self.opt.save_root+"/log.txt",epoch)
                print('\n')
