import sys
sys.path.insert(0, "/root/RetNet-main/RetNet-main/")
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import random
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from datetime import datetime
from src.utils import MetricsTop
from src.utils.functions import dict_to_str
import pickle
# import retnet
# from mamba_ssm import Mamba
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from  mamba_ssm.modules.mamba_simple import MutiMamba,Mamba
# Module logger named 'MMSA'; mamba_train uses it for progress and metric logging.
logger = logging.getLogger('MMSA')
class MultiModalDataset(Dataset):
    """Map-style dataset returning one ``(audio, text, vision, label, id)`` sample.

    Every modality sequence is passed through ``sampling(seq, 500, seq.shape[1])``
    before being returned. ``sampling`` is defined outside this view; presumably
    it resamples/pads the sequence to 500 steps -- TODO confirm.

    Args:
        audio_data, text_data, vision_data: indexable per-sample modality
            sequences (each sample must expose ``.shape[1]``).
        labels: indexable per-sample targets.
        id: indexable per-sample identifiers.
    """

    def __init__(self, audio_data, text_data, vision_data, labels, id):
        self.audio_data = audio_data
        self.text_data = text_data
        self.vision_data = vision_data
        self.labels = labels
        self.id = id

    def __len__(self):
        # Length is taken from the audio modality only; all inputs are assumed
        # to hold the same number of samples (this is not checked).
        return len(self.audio_data)

    def __getitem__(self, idx):
        def _resample(seq):
            return sampling(seq, 500, seq.shape[1])

        # Fetch and resample each modality in turn: audio, then text, then vision.
        audio, text, vision = [
            _resample(modality[idx])
            for modality in (self.audio_data, self.text_data, self.vision_data)
        ]
        return audio, text, vision, self.labels[idx], self.id[idx]
    
# Path to the pickled MOSI feature file; it is opened with pickle.load and its
# 'train' / 'valid' / 'test' splits are wrapped in MultiModalDataset.
file_path = '/root/code/mm/datasets/mosi/unaligned_50.pkl'
# Seed every random number generator in use so that results are reproducible.
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with *seed*.

    Also forces cuDNN into deterministic mode and disables its autotuner
    (``benchmark``), which can otherwise select non-deterministic kernels.
    """
    print("随机种子为：", seed)
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
# Load the pickled feature file. pickle is only acceptable here because the file
# is local and trusted -- never unpickle untrusted data.
# NOTE: despite the variable name, file_path points at the *unaligned* features.
with open(file_path, 'rb') as f:
    aligned_50_data = pickle.load(f)


def _build_split_dataset(split):
    """Wrap one split ('train' / 'valid' / 'test') of the loaded pickle in a MultiModalDataset."""
    data = aligned_50_data[split]
    return MultiModalDataset(
        data['audio'],
        data['text'],
        data['vision'],
        data['regression_labels'],
        data['id'],
    )


train_dataset = _build_split_dataset('train')
valid_dataset = _build_split_dataset('valid')
test_dataset = _build_split_dataset('test')

batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Evaluation loaders must not shuffle. The evaluation loop skips any batch
# smaller than batch_size. With shuffling, a different random subset of samples
# would be dropped on every evaluation, making VAL/TEST metrics (and hence
# checkpoint selection) non-comparable across epochs.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class mamba_train():
    """Train and evaluate a multimodal Mamba regressor on MOSI.

    The training and evaluation loops read module-level globals defined
    elsewhere in this file: ``train_loader``, ``valid_loader``, ``test_loader``,
    ``batch_size`` and ``logger``.
    """

    def __init__(self):
        # Validation metric used to select the best checkpoint.
        self.KeyEval = "MAE"
        self.aligned = True
        self.train_mode = 'regression'
        # NOTE(review): currently unused; do_train always runs a fixed epoch budget.
        self.early_stop = 10
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Number of batches whose gradients are accumulated per optimizer step.
        self.update_epochs = 10
        self.metrics = MetricsTop(self.train_mode).getMetics('MOSI')
        # Gives the saved checkpoint a per-run file name.
        self.start_time = datetime.now()

    def do_train(self, mutimamba, return_epoch_results=True):
        """Train ``mutimamba`` for 151 epochs, keeping the best checkpoint.

        Each epoch:
        1. Trains on ``train_loader`` with gradient accumulation over
           ``self.update_epochs`` batches.
        2. Evaluates on ``valid_loader``.
        3. Saves the weights whenever ``self.KeyEval`` improves by more
           than 1e-6. The target is a minimum for MAE and a maximum otherwise.

        The learning rate is halved every 50 epochs.

        Args:
            mutimamba: model called as ``mutimamba(batch_size, audio, vision, text)``.
            return_epoch_results: if True, also evaluate on ``test_loader`` every
                epoch and collect per-epoch metric dicts.

        Returns:
            ``{'train': [...], 'valid': [...], 'test': [...]}`` when
            ``return_epoch_results`` is True, otherwise None.
        """
        optimizer = optim.Adam(mutimamba.parameters(), weight_decay=0.0005, lr=0.0005)
        # StepLR counts scheduler.step() calls, so it is stepped once per epoch to
        # halve the LR every 50 epochs. The old code stepped it only every 50th
        # epoch, which postponed the first decay to epoch 2500 (i.e. never).
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
        logger.info("Start training")
        epochs, best_epoch = 0, 0
        epoch_results = {'train': [], 'valid': [], 'test': []} if return_epoch_results else None
        min_or_max = 'min' if self.KeyEval in ['MAE'] else 'max'
        best_valid = 1e8 if min_or_max == 'min' else 0
        while True:
            epochs += 1
            train_results, train_loss = self._train_one_epoch(mutimamba, optimizer)

            # Validation + model selection.
            val_results = self.do_test(mutimamba, valid_loader, mode="VAL")
            cur_valid = val_results[self.KeyEval]
            if min_or_max == 'min':
                is_better = cur_valid <= (best_valid - 1e-6)
            else:
                is_better = cur_valid >= (best_valid + 1e-6)
            if is_better:
                best_valid, best_epoch = cur_valid, epochs
                self._save_checkpoint(mutimamba)
                print("had been save model")

            test_results = None
            if return_epoch_results:
                train_results["Loss"] = train_loss
                epoch_results['train'].append(train_results)
                epoch_results['valid'].append(val_results)
                test_results = self.do_test(mutimamba, test_loader, mode="TEST")
                epoch_results['test'].append(test_results)
            print(epochs)
            # test_results only exists when return_epoch_results is True; printing
            # it unconditionally used to raise NameError otherwise.
            if test_results is not None:
                print(test_results)
            scheduler.step()
            # Fixed budget: stop after epoch 151 (early stopping is disabled).
            if epochs > 150:
                return epoch_results

    def _train_one_epoch(self, mutimamba, optimizer):
        """Run one pass over ``train_loader``; return (metrics dict, mean batch L1 loss)."""
        mutimamba.train()
        y_pred, y_true = [], []
        train_loss, n_batches = 0.0, 0
        left_epochs = self.update_epochs
        with tqdm(train_loader) as td:
            for audio, text, vision, label_m, _ids in td:
                audio, text, vision, label_m = [
                    x.float().to(self.device) for x in (audio, text, vision, label_m)
                ]
                # The model is called with the global fixed batch size, so a trailing
                # partial batch is skipped. The skip happens before the accumulation
                # counter is touched, so a skipped batch never desynchronises it.
                if audio.size(0) != batch_size:
                    continue

                # Accumulate gradients over several batches per optimizer step.
                if left_epochs == self.update_epochs:
                    optimizer.zero_grad()
                left_epochs -= 1

                predict = mutimamba(batch_size, audio, vision, text).to(self.device).float()
                label = label_m.view(-1, 1)
                loss = nn.L1Loss()(predict, label)
                loss.backward()

                train_loss += loss.item()
                n_batches += 1
                # Detach so the stored predictions do not keep autograd graphs alive.
                y_pred.append(predict.detach())
                y_true.append(label)

                if not left_epochs:
                    optimizer.step()
                    left_epochs = self.update_epochs
        # Apply gradients from a final, incomplete accumulation group instead of
        # silently discarding them at the next epoch's zero_grad().
        if left_epochs != self.update_epochs:
            optimizer.step()
        if not y_pred:
            raise RuntimeError("training loader produced no full batch of size %d" % batch_size)
        pred, true = torch.cat(y_pred), torch.cat(y_true)
        train_results = self.metrics(pred, true)
        logger.info('%s: >> ' % ("fusion") + dict_to_str(train_results))
        # Average over batches actually processed (skipped partial batches excluded).
        return train_results, train_loss / n_batches

    def _save_checkpoint(self, mutimamba):
        """Save ``mutimamba``'s state_dict to a file named after the run start time."""
        path = ("/root/code/mm/mamba/mamba_ssm/result/mosei_regression/mutimamba_moseireg_"
                + self.start_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pth")
        # Save CPU tensors, then move the model back to the training device.
        torch.save(mutimamba.cpu().state_dict(), path)
        mutimamba.to(self.device)

    def do_test(self, mutimamba, dataloader, mode="VAL", return_sample_results=False):
        """Evaluate ``mutimamba`` on ``dataloader`` without gradient tracking.

        Batches whose size differs from the global ``batch_size`` are skipped,
        because the model is called with that fixed batch size.

        Args:
            mutimamba: model called as ``mutimamba(batch_size, audio, vision, text)``.
            dataloader: yields ``(audio, text, vision, label, id)`` batches.
            mode: tag used in the log line (e.g. "VAL" / "TEST").
            return_sample_results: if True, also return per-sample ids,
                predictions and labels.

        Returns:
            The metric dict from ``self.metrics`` plus 'Loss', the mean L1 loss
            rounded to 4 decimals. When ``return_sample_results`` is True it also
            holds 'Ids', 'SResults', 'Features' and 'Labels'.

        Raises:
            RuntimeError: if ``dataloader`` yields no batch of size ``batch_size``.
        """
        mutimamba.eval()
        y_pred, y_true = [], []
        eval_loss, n_batches = 0.0, 0
        criterion = nn.L1Loss()
        if return_sample_results:
            ids, sample_results, all_labels = [], [], []
            # Feature collection is currently disabled, so these lists stay empty.
            features = {"Feature_t": [], "Feature_a": [], "Feature_v": [], "Feature_f": []}
        with torch.no_grad():
            with tqdm(dataloader) as td:
                for audio, text, vision, label_m, batch_ids in td:
                    audio, text, vision, label_m = [
                        x.float().to(self.device) for x in (audio, text, vision, label_m)
                    ]
                    if audio.size(0) != batch_size:
                        continue

                    label = label_m.view(-1, 1)
                    mutimamba_out = mutimamba(batch_size, audio, vision, text)
                    predict = mutimamba_out.float()

                    if return_sample_results:
                        ids.extend(batch_ids)
                        all_labels.extend(label.cpu().tolist())
                        # atleast_1d: squeeze() of a single-sample batch is 0-d and
                        # cannot be passed to list.extend.
                        preds = mutimamba_out.cpu().numpy()
                        sample_results.extend(np.atleast_1d(preds.squeeze()))

                    eval_loss += criterion(predict, label).item()
                    n_batches += 1
                    y_pred.append(predict)
                    y_true.append(label)
        if not y_pred:
            raise RuntimeError("%s: loader produced no full batch of size %d" % (mode, batch_size))
        # Average over batches actually evaluated (skipped partial batches excluded).
        eval_loss = eval_loss / n_batches
        logger.info(mode + "-(%s)" % "retnet" + " >> loss: %.4f " % eval_loss)
        pred, true = torch.cat(y_pred), torch.cat(y_true)
        eval_results = self.metrics(pred, true)
        logger.info('M: >> ' + dict_to_str(eval_results))
        eval_results['Loss'] = round(eval_loss, 4)

        if return_sample_results:
            eval_results["Ids"] = ids
            eval_results["SResults"] = sample_results
            # np.concatenate raises on an empty list, so empty features map to empty arrays.
            eval_results['Features'] = {
                k: np.concatenate(v, axis=0) if v else np.empty((0,))
                for k, v in features.items()
            }
            eval_results['Labels'] = all_labels

        return eval_results