import logging
import numpy as np
import torch
import torch.nn as nn
import json
import os
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from .utils import MetricsTop, dict_to_str
import torch.nn.functional as F
# Module-level logger. It uses the fixed name 'MMSA' (not __name__), so its
# records go to whatever handlers are configured on the 'MMSA' logger.
logger = logging.getLogger('MMSA')


class ModalityConflictLoss(nn.Module):
    """Regularizer that rewards agreement between per-modality gradients.

    The gradient of ``main_loss`` is taken with respect to the text, audio and
    vision features. The three pairwise cosine similarities of the flattened
    gradients are summed and negated, so minimising the returned value pushes
    the three gradient directions toward each other.
    """

    def __init__(self, weight=0.1):
        super().__init__()
        # Scale applied to the summed cosine similarities.
        self.weight = weight

    @staticmethod
    def _flat_cosine(x, y):
        """Cosine similarity of two tensors, each treated as one flat vector (0-d result)."""
        return F.cosine_similarity(x.reshape(1, -1), y.reshape(1, -1), dim=1).squeeze()

    def forward(self, feature_t, feature_a, feature_v, main_loss):
        """Compute the gradient-alignment penalty.

        Args:
            feature_t, feature_a, feature_v: feature tensors that require grad
                and contributed to ``main_loss``. Otherwise
                ``torch.autograd.grad`` raises a RuntimeError, since
                ``allow_unused`` is not set.
            main_loss: scalar loss to differentiate.

        Returns:
            (loss, cos_similarities): ``loss`` is
            ``-weight * (cos_l_a + cos_l_v + cos_a_v)`` as a tensor. The
            gradients are built with ``create_graph=True`` so this penalty
            can itself be backpropagated. ``cos_similarities`` maps
            "cos_sim_l_a", "cos_sim_l_v" and "cos_sim_a_v" to Python floats
            for logging.
        """
        # A single backward traversal yields all three gradients; this is
        # equivalent to three separate autograd.grad calls, at a fraction of
        # the cost.
        grad_t, grad_a, grad_v = torch.autograd.grad(
            main_loss, (feature_t, feature_a, feature_v), create_graph=True
        )

        cos_sim_l_a = self._flat_cosine(grad_t, grad_a)
        cos_sim_l_v = self._flat_cosine(grad_t, grad_v)
        cos_sim_a_v = self._flat_cosine(grad_a, grad_v)
        cos_similarities = {
            "cos_sim_l_a": cos_sim_l_a.item(),
            "cos_sim_l_v": cos_sim_l_v.item(),
            "cos_sim_a_v": cos_sim_a_v.item(),
        }

        # TODO: decide whether some complementarity/orthogonality between
        # modalities should be allowed instead of always rewarding full alignment.
        return -self.weight * (cos_sim_l_a + cos_sim_l_v + cos_sim_a_v), cos_similarities


class Trainer:
    """Training / evaluation loop for the multimodal models in this package.

    The model is called as
    ``model([text, audio, vision], labels, num_modal=..., ava_modal_idx=...)``
    and must return a dict containing at least ``'M'`` (the prediction).

    Depending on flags, the dict must also provide:
      * ``'loss_score_{l,v,a}'`` (when ``args.model_name == 'imder'``) or
        ``'log_p_{l,v,a}'``, if ``args.generate`` is set;
      * ``'Feature_t'``, ``'Feature_a'``, ``'Feature_v'``, if
        ``args.use_modal_conflict`` is set;
      * ``'Feature_f'`` as well, when ``do_test`` is asked for sample results.
    """

    def __init__(self, args):
        self.args = args
        self.criterion = nn.L1Loss() if args.train_mode == 'regression' else nn.CrossEntropyLoss()
        self.metrics = MetricsTop(args.train_mode).getMetics(args.dataset_name)
        # Gradient-alignment regularizer; only applied when
        # args.use_modal_conflict is truthy.
        self.modality_conflict_loss = ModalityConflictLoss(weight=self.args.modality_conflict_loss_weight)

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #
    def _unpack_batch(self, batch_data):
        """Move one batch to ``args.device`` and shape labels for the criterion.

        Returns (text, audio, vision, labels). Labels are flattened to a long
        vector for classification (CrossEntropyLoss) and reshaped to (N, 1)
        for regression (L1Loss).
        """
        device = self.args.device
        vision = batch_data['vision'].to(device)
        audio = batch_data['audio'].to(device)
        text = batch_data['text'].to(device)
        labels = batch_data['labels']['M'].to(device)
        if self.args.train_mode == 'classification':
            labels = labels.view(-1).long()
        else:
            labels = labels.view(-1, 1)
        return text, audio, vision, labels

    def _load_pretrained_weights(self, model):
        """Initialise ``model`` from the complete-modality pretraining checkpoint.

        Every 'fuse.' substring is removed from the checkpoint keys before they
        are merged over the model's own state dict. Keys the checkpoint lacks
        keep their current values. Returns the model moved to ``args.device``.
        """
        ckpt_path = 'results/pretrained/pretrained-{}-{}.pt'.format(
            self.args.dataset_name, self.args.available_size)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. The file may hold a whole module or a plain state_dict.
        origin_model = torch.load(ckpt_path)
        if hasattr(origin_model, 'state_dict'):
            origin_model = origin_model.state_dict()

        net_dict = model.state_dict()
        net_dict.update({k.replace('fuse.', ''): v for k, v in origin_model.items()})
        model.load_state_dict(net_dict)
        print("Weight loaded successfully!")
        return model.to(self.args.device)

    def _train_epoch(self, model, loader, optimizer, epoch, cos_similarities_history):
        """Run one training epoch.

        Gradients are accumulated over ``args.update_epochs`` batches before
        each optimizer step (``update_epochs == 1`` disables accumulation).
        Cosine-similarity records from the conflict regularizer are appended
        to ``cos_similarities_history`` in place.

        Returns:
            (mean_loss, predictions, labels) for the epoch, with predictions
            and labels concatenated on the CPU.
        """
        args = self.args
        use_conflict = getattr(args, 'use_modal_conflict', False)
        model.train()
        y_pred, y_true = [], []
        train_loss = 0.0
        left_epochs = args.update_epochs

        with tqdm(loader) as td:
            for batch_data in td:
                if left_epochs == args.update_epochs:
                    optimizer.zero_grad()
                left_epochs -= 1
                text, audio, vision, labels = self._unpack_batch(batch_data)

                # NOTE(review): an earlier variant called the model without
                # `labels` (marked as the IMDer signature); this call matches
                # DiCMoR's.
                outputs = model([text, audio, vision], labels,
                                num_modal=args.num_modal, ava_modal_idx=args.ava_modal_idx)

                task_loss = self.criterion(outputs['M'], labels)
                # Out-of-place additions keep task_loss itself untouched; the
                # conflict regularizer below differentiates it on its own.
                combined_loss = task_loss

                if args.generate:
                    prefix = 'loss_score_' if args.model_name == 'imder' else 'log_p_'
                    generation_loss = outputs[prefix + 'l'] + outputs[prefix + 'v'] + outputs[prefix + 'a']
                    combined_loss = combined_loss + args.cross_modal_generation_loss_weight * generation_loss

                if use_conflict:
                    feature_t = outputs['Feature_t']
                    feature_a = outputs['Feature_a']
                    feature_v = outputs['Feature_v']
                    # The regularizer differentiates task_loss w.r.t. the
                    # features, which is only possible if all three are in the graph.
                    if feature_t.requires_grad and feature_a.requires_grad and feature_v.requires_grad:
                        conflict_loss, cos_similarities = self.modality_conflict_loss(
                            feature_t, feature_a, feature_v, task_loss)
                        cos_similarities["epoch"] = epoch
                        cos_similarities_history.append(cos_similarities)
                        combined_loss = combined_loss + conflict_loss

                combined_loss.backward()
                if args.grad_clip != -1.0:
                    nn.utils.clip_grad_value_([p for p in model.parameters() if p.requires_grad],
                                              args.grad_clip)

                train_loss += combined_loss.item()
                # Detach so the per-epoch list does not keep every batch's
                # autograd graph alive until the epoch ends.
                y_pred.append(outputs['M'].detach().cpu())
                y_true.append(labels.cpu())

                if not left_epochs:
                    optimizer.step()
                    left_epochs = args.update_epochs

        # Apply gradients from a trailing partial accumulation group;
        # otherwise the next epoch's zero_grad() would silently discard them.
        if left_epochs != args.update_epochs:
            optimizer.step()

        if use_conflict:
            # The file holds the full history, so it is written once per epoch
            # rather than per batch (per-batch rewrites cost O(batches^2) I/O).
            cos_sim_path = os.path.join(args.save_dir, 'cos_similarities.json')
            with open(cos_sim_path, 'w', encoding='utf-8') as f:
                json.dump(cos_similarities_history, f, indent=4)

        return train_loss / len(loader), torch.cat(y_pred), torch.cat(y_true)

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #
    def do_train(self, model, dataloader, return_epoch_results=False):
        """Train ``model`` until validation stops improving.

        Each epoch runs one pass over ``dataloader['train']`` and then
        evaluates on ``'valid'`` and ``'test'``.

        * The LR scheduler always tracks validation loss.
        * The best checkpoint is chosen by ``args.KeyEval`` ('Loss' means
          lower is better; any other key means higher is better) and saved to
          ``args.model_save_path``.
        * Every epoch's weights are also written to
          ``<args.save_dir>/<epoch>.pth``.
        * Training stops after ``args.early_stop`` epochs without a new best.

        Args:
            model: the network to train.
            dataloader: dict with 'train', 'valid' and 'test' loaders.
            return_epoch_results: when True, collect per-epoch metric dicts.

        Returns:
            ``{'train': [...], 'valid': [...], 'test': [...]}`` if
            ``return_epoch_results`` is set, else None.
        """
        args = self.args
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
        # KeyEval is one of ['Loss', 'Acc_2', 'F-1', 'Acc_7'].
        min_or_max = 'min' if args.KeyEval in ['Loss'] else 'max'
        best_valid = 1e8 if min_or_max == 'min' else 0
        # The scheduler is always stepped with validation *loss*, so its mode
        # must be 'min' regardless of KeyEval; 'max' would cut the LR while
        # the loss is improving.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, min_lr=1e-6,
                                      verbose=True, patience=args.patience)
        epochs, best_epoch = 0, 0
        epoch_results = {'train': [], 'valid': [], 'test': []}
        cos_similarities_history = []

        if args.generate:
            # Start from weights pretrained on the available complete multimodal data.
            model = self._load_pretrained_weights(model)

        while True:
            epochs += 1
            train_loss, pred, true = self._train_epoch(
                model, dataloader['train'], optimizer, epochs, cos_similarities_history)
            train_results = self.metrics(pred, true)
            logger.info(
                f"TRAIN-({args.model_name}) [{epochs - best_epoch}/{epochs}/{args.cur_seed}] "
                f">> loss: {round(train_loss, 4)} "
                f"{dict_to_str(train_results)}"
            )

            val_results = self.do_test(model, dataloader['valid'], mode="VAL")
            test_results = self.do_test(model, dataloader['test'], mode="TEST")
            # Best-model selection may use Acc_2 or val loss (via KeyEval) ...
            cur_valid = val_results[args.KeyEval]
            # ... but the LR schedule always follows val loss.
            scheduler.step(val_results['Loss'])

            # Save every epoch's weights; save_dir is treated as a directory.
            torch.save(model.state_dict(), os.path.join(args.save_dir, f'{epochs}.pth'))

            if min_or_max == 'min':
                is_better = cur_valid <= (best_valid - 1e-6)
            else:
                is_better = cur_valid >= (best_valid + 1e-6)
            if is_better:
                best_valid, best_epoch = cur_valid, epochs
                torch.save(model.cpu().state_dict(), args.model_save_path)
                model.to(args.device)

            if return_epoch_results:
                train_results["Loss"] = train_loss
                epoch_results['train'].append(train_results)
                epoch_results['valid'].append(val_results)
                # Reuse this epoch's test evaluation instead of running it again.
                epoch_results['test'].append(test_results)

            if epochs - best_epoch >= args.early_stop:
                return epoch_results if return_epoch_results else None

    def do_test(self, model, dataloader, mode="VAL", return_sample_results=False):
        """Evaluate ``model`` on ``dataloader`` with all three modalities present.

        Args:
            model: the network to evaluate (switched to eval mode).
            dataloader: iterable of batches in the training-batch format.
            mode: tag used in the log line ("VAL" / "TEST").
            return_sample_results: when True, also collect per-sample ids,
                predictions, labels and the Feature_t/a/v/f outputs.

        Returns:
            The metrics dict from ``self.metrics`` plus "Loss" (mean batch
            loss, rounded to 4 places). With ``return_sample_results`` it also
            contains "Ids", "SResults", "Features" (each entry concatenated
            along axis 0) and "Labels".
        """
        model.eval()
        y_pred, y_true = [], []
        eval_loss = 0.0
        if return_sample_results:
            ids, sample_results, all_labels = [], [], []
            features = {name: [] for name in ("Feature_t", "Feature_a", "Feature_v", "Feature_f")}

        with torch.no_grad():
            with tqdm(dataloader) as td:
                for batch_data in td:
                    text, audio, vision, labels = self._unpack_batch(batch_data)
                    # Inference always uses the fixed full-modality setting.
                    outputs = model([text, audio, vision], labels, num_modal=3, ava_modal_idx=[0, 1, 2])

                    if return_sample_results:
                        ids.extend(batch_data['id'])
                        for name, store in features.items():
                            store.append(outputs[name].cpu().detach().numpy())
                        all_labels.extend(labels.cpu().detach().tolist())
                        preds = outputs["M"].cpu().detach().numpy()
                        # Drop only a trailing singleton (regression head) so a
                        # batch of size 1 still contributes exactly one entry
                        # per sample.
                        if preds.ndim > 1 and preds.shape[-1] == 1:
                            preds = preds.squeeze(-1)
                        sample_results.extend(preds)

                    loss = self.criterion(outputs['M'], labels)
                    eval_loss += loss.item()
                    y_pred.append(outputs['M'].cpu())
                    y_true.append(labels.cpu())

        eval_loss = eval_loss / len(dataloader)
        pred, true = torch.cat(y_pred), torch.cat(y_true)

        eval_results = self.metrics(pred, true)
        eval_results["Loss"] = round(eval_loss, 4)
        logger.info(f"{mode}-({self.args.model_name}) >> {dict_to_str(eval_results)}")

        if return_sample_results:
            eval_results["Ids"] = ids
            eval_results["SResults"] = sample_results
            eval_results['Features'] = {k: np.concatenate(v, axis=0) for k, v in features.items()}
            eval_results['Labels'] = all_labels

        return eval_results

        
                        

        

