import torch
import torch.nn.functional as F
import logging
from torch import nn
from utils.functions import restore_model, save_model, EarlyStopping
from tqdm import trange, tqdm
from utils.metrics import AverageMeter, Metrics
from .pareto import ImprovedParetoOptimizer
from .schedular import AdaptiveLearningRateScheduler

__all__ = ['ConMR_Manager']

def KLDIV(p, q):
    """Return the batch-mean KL divergence KL(softmax(q) || softmax(p)).

    Both arguments are unnormalised scores; they are normalised over the last
    dimension here. ``p`` plays the role of the prediction (converted to
    log-probabilities, as ``F.kl_div`` expects for its input) and ``q`` the
    target distribution.

    Args:
        p: Predicted scores (logits).
        q: Target scores (logits).

    Returns:
        A scalar tensor: the summed divergence divided by the batch size.
    """
    log_predicted = F.log_softmax(p, dim=-1)
    target_probs = F.softmax(q, dim=-1)
    return F.kl_div(log_predicted, target_probs, reduction='batchmean')

class ConMR_Manager:

    def __init__(self, args, data, model):
        """Set up optimizers, schedulers, data loaders and metrics for one run.

        Args:
            args: Run configuration. Read here: ``logger_name``, ``lr_main``,
                ``lr_concept``, ``lr_t_weight``, ``lr_a_weight``,
                ``lr_v_weight``, ``train`` and, when not training,
                ``model_output_path``. Also passed on to the learning-rate
                scheduler and to ``Metrics``.
            data: Object whose ``mm_dataloader`` mapping holds the 'train',
                'dev' and 'test' loaders.
            model: Wrapper exposing ``device`` and ``model`` attributes.
        """
        self.args = args
        self.logger = logging.getLogger(args.logger_name)

        self.device = model.device
        self.model = model.model

        # Every parameter group is driven by its own Adam optimizer so that
        # each can run at an independent learning rate.
        backbone = model.model.model
        self.optimizer_main = torch.optim.Adam(
            backbone._get_main_params(), lr=args.lr_main
        )
        self.optimizer_concept = torch.optim.Adam(
            backbone._get_concept_params(), lr=args.lr_concept
        )
        self.optimizer_text_weight = torch.optim.Adam(
            backbone.text_weight_predictor.parameters(), lr=args.lr_t_weight
        )
        self.optimizer_audio_weight = torch.optim.Adam(
            backbone.audio_weight_predictor.parameters(), lr=args.lr_a_weight
        )
        self.optimizer_video_weight = torch.optim.Adam(
            backbone.video_weight_predictor.parameters(), lr=args.lr_v_weight
        )

        # The scheduler receives the optimizers and their starting learning
        # rates under matching keys.
        optimizers_by_group = dict(
            main=self.optimizer_main,
            concept=self.optimizer_concept,
            text_weight=self.optimizer_text_weight,
            audio_weight=self.optimizer_audio_weight,
            video_weight=self.optimizer_video_weight,
        )
        initial_lr_by_group = dict(
            main=args.lr_main,
            concept=args.lr_concept,
            text_weight=args.lr_t_weight,
            audio_weight=args.lr_a_weight,
            video_weight=args.lr_v_weight,
        )
        self.adaptive_scheduler = AdaptiveLearningRateScheduler(
            optimizers_by_group, initial_lr_by_group, args
        )

        # Three losses are balanced during training: classification,
        # concept-weight and concept-score.
        self.pareto_optimizer = ImprovedParetoOptimizer(num_losses=3)

        self.train_dataloader = data.mm_dataloader['train']
        self.eval_dataloader = data.mm_dataloader['dev']
        self.test_dataloader = data.mm_dataloader['test']

        self.criterion = nn.CrossEntropyLoss()
        self.metrics = Metrics(args)

        if not args.train:
            # Inference-only run: load previously saved weights.
            self.model = restore_model(self.model, args.model_output_path)
        else:
            self.best_eval_score = 0

    def _train(self, args):
        """Train the model, evaluating on the dev split after every epoch.

        For each batch the classification loss (cross-entropy), the
        concept-weight loss (MSE against per-label target weights plus an L2
        penalty) and the concept-score loss (KL divergence) are combined with
        the weights returned by ``self.pareto_optimizer`` and back-propagated
        through all five optimizers. After each epoch the dev score is fed to
        the adaptive learning-rate scheduler and, once ``args.num_warmup``
        epochs have elapsed, to early stopping.

        Args:
            args: Run configuration. Read here: ``num_train_epochs``,
                ``l2reg_t``, ``l2reg_a``, ``l2reg_v``, ``grad_clip``,
                ``num_warmup``, ``save_model`` and ``model_output_path``.
                Also passed to ``EarlyStopping`` and ``_get_outputs``.
        """
        early_stopping = EarlyStopping(args)

        optimizers = (
            self.optimizer_main,
            self.optimizer_concept,
            self.optimizer_text_weight,
            self.optimizer_audio_weight,
            self.optimizer_video_weight,
        )

        # Progress is printed twice per epoch. ``len // 2`` is 0 for a
        # single-batch loader, which would make the modulo below raise
        # ZeroDivisionError, so the interval is clamped to at least 1.
        log_interval = max(1, len(self.train_dataloader) // 2)

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            self.model.train()
            total_cls_loss = AverageMeter()
            total_weight_loss = AverageMeter()
            total_concept_loss = AverageMeter()

            for step, batch in enumerate(tqdm(self.train_dataloader, desc="Iteration")):

                text_feats = batch['text_feats'].to(self.device)
                video_feats = batch['video_feats'].to(self.device).to(torch.float32)
                audio_feats = batch['audio_feats'].to(self.device).to(torch.float32)
                text_scores = batch['text_concepts'].to(self.device).to(torch.float32)
                audio_scores = batch['audio_concepts'].to(self.device).to(torch.float32)
                video_scores = batch['video_concepts'].to(self.device).to(torch.float32)
                labels = batch['label_ids'].to(self.device)

                with torch.set_grad_enabled(True):

                    (logits,
                     concept_weights_t, concept_weights_a, concept_weights_v,
                     score_t, score_a, score_v) = self.model(
                        text_feats, audio_feats, video_feats, labels=None, return_weights=True
                    )

                    cls_loss = self.criterion(logits, labels)

                    # Target concept weights are looked up per sample by label.
                    net = self.model.model
                    target_weights_t = net.target_text_weights[labels]
                    target_weights_a = net.target_audio_weights[labels]
                    target_weights_v = net.target_video_weights[labels]

                    l2_reg = (
                        args.l2reg_t * torch.norm(concept_weights_t, p=2)
                        + args.l2reg_a * torch.norm(concept_weights_a, p=2)
                        + args.l2reg_v * torch.norm(concept_weights_v, p=2)
                    )
                    weight_loss = (
                        F.mse_loss(concept_weights_t, target_weights_t)
                        + F.mse_loss(concept_weights_a, target_weights_a)
                        + F.mse_loss(concept_weights_v, target_weights_v)
                        + l2_reg
                    )

                    concept_loss = (
                        KLDIV(score_t, text_scores)
                        + KLDIV(score_a, audio_scores)
                        + KLDIV(score_v, video_scores)
                    )

                    # The balancing weights are computed from detached loss
                    # values, so no gradient flows through the weighting.
                    losses_tensor = torch.tensor(
                        [cls_loss.item(), weight_loss.item(), concept_loss.item()],
                        device=self.device,
                    )
                    pareto_weights = self.pareto_optimizer.get_weights(losses_tensor)

                    total_cls_loss.update(cls_loss.item(), labels.size(0))
                    total_concept_loss.update(concept_loss.item(), labels.size(0))
                    total_weight_loss.update(weight_loss.item(), labels.size(0))

                    total_loss = (pareto_weights[0] * cls_loss +
                                  pareto_weights[1] * weight_loss +
                                  pareto_weights[2] * concept_loss)

                    for optimizer in optimizers:
                        optimizer.zero_grad()

                    total_loss.backward()

                    # Gradients are clipped separately for each parameter group.
                    for params in (
                        net._get_main_params(),
                        net._get_concept_params(),
                        net.text_weight_predictor.parameters(),
                        net.audio_weight_predictor.parameters(),
                        net.video_weight_predictor.parameters(),
                    ):
                        torch.nn.utils.clip_grad_norm_(params, max_norm=args.grad_clip)

                    for optimizer in optimizers:
                        optimizer.step()

                    if step % log_interval == 0:
                        print(f"  Batch {step}: Pareto Weight = [CLS:{pareto_weights[0]:.3f}, Weight:{pareto_weights[1]:.3f}, Concept:{pareto_weights[2]:.3f}]")
                        print(f"    Current Loss = [CLS:{cls_loss.item():.4f}, Weight:{weight_loss.item():.4f}, Concept:{concept_loss.item():.4f}]")

            _, eval_score, weight_analysis = self._get_outputs(args, mode='eval')
            self.adaptive_scheduler.step(eval_score, weight_analysis, total_concept_loss.avg)

            # 'best_eval_score' is the best score recorded before this
            # epoch's result is handed to early stopping below.
            eval_results = {
                'train_cls_loss': round(total_cls_loss.avg, 4),
                'train_weight_loss': round(total_weight_loss.avg, 4),
                'train_concept_loss': round(total_concept_loss.avg, 4),
                'best_eval_score': round(early_stopping.best_score, 4),
                'eval_score': round(eval_score, 4),
            }

            self.logger.info("***** Epoch: %s: Eval results *****", str(epoch + 1))
            for key in sorted(eval_results.keys()):
                self.logger.info("  %s = %s", key, str(eval_results[key]))

            # Early stopping only starts tracking after the warm-up epochs
            # (epochs are counted from 1 here).
            if epoch + 1 >= args.num_warmup:
                early_stopping(eval_score, self.model)

            if early_stopping.early_stop:
                self.logger.info(f'EarlyStopping at epoch {epoch + 1}')
                break

        self.best_eval_score = early_stopping.best_score
        # NOTE(review): EarlyStopping is defined elsewhere. If it was never
        # invoked (e.g. warm-up longer than the run) its best_model is
        # presumably still None; keep the current model in that case rather
        # than replacing it with None - confirm against EarlyStopping.
        if early_stopping.best_model is not None:
            self.model = early_stopping.best_model

        if args.save_model:
            self.logger.info('Trained models are saved in %s', args.model_output_path)
            save_model(self.model, args.model_output_path)

    def _get_outputs(self, args, mode='eval', return_sample_results=False, show_results=False):
        """Evaluate the model on one data split.

        Args:
            args: Run configuration; ``num_labels`` is read for the 'test'
                and 'train' splits.
            mode: 'eval' (dev split), 'test' or 'train'.
            return_sample_results: For 'test'/'train', also include the
                ``y_true`` and ``y_pred`` arrays in the returned dict.
            show_results: Forwarded to ``self.metrics`` for 'test'/'train'.

        Returns:
            For ``mode='eval'``: a tuple ``(loss, accuracy, weight_analysis)``
            where ``loss`` is the mean of the per-batch cross-entropy losses
            and ``weight_analysis`` is a dict of per-modality mean absolute
            differences and mean cosine similarities between predicted and
            target concept weights, plus their averages.
            For 'test'/'train': the dict returned by ``self.metrics`` with a
            'loss' entry added (and 'y_true'/'y_pred' when requested).

        Raises:
            ValueError: If ``mode`` is not one of the three split names.
        """
        if mode == 'eval':
            return self._evaluate_weight_alignment(self.eval_dataloader)

        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'train':
            dataloader = self.train_dataloader
        else:
            # Previously an unknown mode surfaced later as an unrelated
            # UnboundLocalError; fail early with a clear message instead.
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'eval', 'test' or 'train'"
            )

        return self._evaluate_metrics(args, dataloader, return_sample_results, show_results)

    @staticmethod
    def _mean_cosine(pred, target):
        """Return the mean cosine similarity of ``pred`` and ``target`` along the last dimension."""
        pred_unit = F.normalize(pred, dim=-1)
        target_unit = F.normalize(target, dim=-1)
        return torch.sum(pred_unit * target_unit, dim=-1).mean().item()

    def _evaluate_weight_alignment(self, dataloader):
        """Compute loss, accuracy and concept-weight agreement over ``dataloader``."""
        self.model.eval()

        loss_sum = 0.0
        num_correct = 0
        num_samples = 0
        modalities = ('text', 'audio', 'video')
        pred_chunks = {name: [] for name in modalities}
        target_chunks = {name: [] for name in modalities}

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Iteration"):
                text = batch['text_feats'].to(self.device)
                audio = batch['audio_feats'].to(self.device).to(torch.float32)
                video = batch['video_feats'].to(self.device).to(torch.float32)
                labels = batch['label_ids'].to(self.device)

                logits, pred_weights_t, pred_weights_a, pred_weights_v = self.model(
                    text, audio, video, labels=None, return_weights=False
                )
                loss_sum += F.cross_entropy(logits, labels).item()

                _, predicted = torch.max(logits, dim=1)
                num_samples += labels.size(0)
                num_correct += (predicted == labels).sum().item()

                # Target concept weights are looked up per sample by label.
                net = self.model.model
                batch_preds = {
                    'text': pred_weights_t,
                    'audio': pred_weights_a,
                    'video': pred_weights_v,
                }
                batch_targets = {
                    'text': net.target_text_weights[labels],
                    'audio': net.target_audio_weights[labels],
                    'video': net.target_video_weights[labels],
                }
                # Move to CPU per batch so the accumulated history does not
                # stay on the compute device.
                for name in modalities:
                    pred_chunks[name].append(batch_preds[name].cpu())
                    target_chunks[name].append(batch_targets[name].cpu())

        accuracy = num_correct / num_samples

        diffs = {}
        cosines = {}
        for name in modalities:
            all_pred = torch.cat(pred_chunks[name], dim=0)
            all_target = torch.cat(target_chunks[name], dim=0)
            diffs[name] = torch.abs(all_pred - all_target).mean().item()
            cosines[name] = self._mean_cosine(all_pred, all_target)

        weight_analysis = {
            'text_diff': diffs['text'],
            'audio_diff': diffs['audio'],
            'video_diff': diffs['video'],
            'avg_diff': (diffs['text'] + diffs['audio'] + diffs['video']) / 3,
            'text_cosine': cosines['text'],
            'audio_cosine': cosines['audio'],
            'video_cosine': cosines['video'],
            'avg_cosine': (cosines['text'] + cosines['audio'] + cosines['video']) / 3,
        }

        return loss_sum / len(dataloader), accuracy, weight_analysis

    def _evaluate_metrics(self, args, dataloader, return_sample_results, show_results):
        """Run the model over ``dataloader`` and score predictions with ``self.metrics``."""
        self.model.eval()

        # Batches are collected in lists and concatenated once, instead of
        # re-copying the growing result on every step. The lists are seeded
        # with empty tensors so the concatenated dtype/device match the
        # per-batch concatenation this replaces, including for an empty loader.
        logit_chunks = [torch.empty((0, args.num_labels)).to(self.device)]
        label_chunks = [torch.empty(0, dtype=torch.long).to(self.device)]

        loss_record = AverageMeter()

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Iteration"):
                text_feats = batch['text_feats'].to(self.device)
                video_feats = batch['video_feats'].to(self.device).to(torch.float32)
                audio_feats = batch['audio_feats'].to(self.device).to(torch.float32)
                label_ids = batch['label_ids'].to(self.device)

                logits, _, _, _ = self.model(text_feats, audio_feats, video_feats, return_weights=False)

                logit_chunks.append(logits)
                label_chunks.append(label_ids)

                loss = self.criterion(logits, label_ids)
                loss_record.update(loss.item(), label_ids.size(0))

        total_logits = torch.cat(logit_chunks)
        total_labels = torch.cat(label_chunks)

        total_probs = F.softmax(total_logits.detach(), dim=1)
        _, total_preds = total_probs.max(dim=1)

        y_pred = total_preds.cpu().numpy()
        y_true = total_labels.cpu().numpy()

        outputs = self.metrics(y_true, y_pred, show_results=show_results)
        outputs.update({'loss': loss_record.avg})

        if return_sample_results:
            outputs.update({'y_true': y_true, 'y_pred': y_pred})

        return outputs

    
    def _test(self, args):
        """Evaluate on the test split and attach the best dev score.

        Args:
            args: Run configuration, forwarded to ``_get_outputs``.

        Returns:
            The test-split result dict from ``_get_outputs`` (including the
            per-sample ``y_true``/``y_pred``) with a 'best_eval_score' entry
            rounded to four decimals.
        """
        test_results = self._get_outputs(
            args, mode='test', return_sample_results=True, show_results=True
        )

        # ``best_eval_score`` is only assigned when training is enabled (in
        # ``__init__``) or after ``_train`` has run. A test-only run on a
        # restored model has no such attribute, so fall back to 0 - the same
        # initial value used for training runs - instead of raising
        # AttributeError.
        best_eval_score = getattr(self, 'best_eval_score', 0)
        test_results['best_eval_score'] = round(best_eval_score, 4)

        return test_results
    