import torch
import torch.nn as nn

import wandb
from tqdm import tqdm
from collections import defaultdict

from lib.sde_rjepa import RJEPA
from lib.losses import MSE_, GNLL_

from lib.utils import count_jepa_parameters
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

class RJEPAACSSM:
    """Training/evaluation wrapper around an ``RJEPA`` model.

    Builds the model (optionally wrapped in ``nn.DataParallel``), an AdamW
    optimizer, and an optional warmup + cosine learning-rate schedule, and
    exposes one-epoch ``train`` / ``evaluate`` loops that return summed
    metrics.

    The total training loss is ``lamb * KL`` plus, depending on ``task``
    ("full" or "interpolation"), ``tau *`` the latent (JEPA) Gaussian NLL
    and ``eta *`` the reconstruction Gaussian NLL. Any other ``task`` value
    leaves only the KL term in the loss.
    """

    def __init__(self, args):
        """Build model, optimizer and scheduler from ``args``.

        Args:
            args: configuration object supporting attribute access (and
                ``.get(key, default)`` when ``use_scheduler`` is set). Fields
                read here: ``task``, ``device``, ``tau``, ``eta``, ``lamb``,
                ``eps_sigma``, ``max_grad_norm``, ``joint_clipping``,
                ``normalize_loss``, ``gpus``, ``lr``, ``wd``. It is also
                forwarded unchanged to ``RJEPA`` and ``get_scheduler``.
        """
        super().__init__()

        self.task = args.task
        self.device = args.device

        # Loss weights: tau -> latent NLL, eta -> reconstruction NLL, lamb -> KL.
        self.tau = args.tau
        self.eta = args.eta
        self.lamb = args.lamb

        self.eps_sigma = args.eps_sigma
        self.max_grad_norm = args.max_grad_norm

        self.joint_clipping = args.joint_clipping
        self.normalize_loss = args.normalize_loss

        self.jepa = RJEPA(args)

        if torch.cuda.is_available() and len(args.gpus) > 1:
            self.jepa = nn.DataParallel(self.jepa, device_ids=args.gpus)
        self.jepa = self.jepa.to(self.device)

        if self.tau > 0:
            # The target encoder is kept out of the optimizer; it is refreshed
            # through update_target_encoder() after every optimizer step.
            # NOTE(review): presumably an EMA of the context encoder -- confirm in RJEPA.
            frozen_ids = {id(p) for p in self._unwrapped().target_encoder.parameters()}
            params = [p for p in self.jepa.parameters() if id(p) not in frozen_ids]
        else:
            params = list(self.jepa.parameters())

        self.optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
        self.scheduler = self.get_scheduler(args)

        self.encoder_params, self.predictor_params = count_jepa_parameters(self)

    def _unwrapped(self):
        """Return the underlying model, looking through ``nn.DataParallel``."""
        return self.jepa.module if isinstance(self.jepa, nn.DataParallel) else self.jepa

    def _to_device(self, data):
        """Move every tensor in ``data`` to ``self.device``; leave other items as-is."""
        return [x.to(self.device) if torch.is_tensor(x) else x for x in data]

    @staticmethod
    def _scalar(value):
        """Read a 0-d tensor back as a Python number (one host/device sync)."""
        return value.detach().item()

    def _recon_terms(self, full_obs, recon, target_mask, batch_len):
        """Compute batch-summed reconstruction losses.

        Args:
            full_obs: observation tensor for the batch.
            recon: ``(O_mean, O_var)`` pair as returned by the model; both are
                transposed on their first two axes before use.
            target_mask: mask selecting the positions scored by the "mask" variants.
            batch_len: batch size used to turn per-batch values into sums.

        Returns:
            Tuple ``(full_mse, mask_mse, full_nll, mask_nll)`` of tensors,
            each already multiplied by ``batch_len``.
        """
        O_mean, O_var = recon
        O_mean, O_var = O_mean.transpose(0, 1), O_var.transpose(0, 1)

        target_mask_O = target_mask.unsqueeze(-1).expand(-1, -1, O_mean.shape[-1]).flatten(start_dim=2)

        obs = full_obs.flatten(start_dim=2)
        mean = O_mean.flatten(start_dim=3)
        var = O_var.flatten(start_dim=3)

        full_nll = GNLL_(obs, mean, var, normalize_dim=self.normalize_loss) * batch_len
        mask_nll = GNLL_(obs, mean, var, mask=target_mask_O, normalize_dim=self.normalize_loss) * batch_len
        full_mse = MSE_(obs, mean) * batch_len
        mask_mse = MSE_(obs, mean, mask=target_mask_O) * batch_len
        return full_mse, mask_mse, full_nll, mask_nll

    def get_scheduler(self, args):
        """Create the per-step learning-rate scheduler, or ``None`` if disabled.

        Without warmup this is a ``CosineAnnealingLR`` over all steps. With
        warmup it is a ``SequentialLR`` of a ``LinearLR`` ramp (from
        ``start_lr`` to ``lr``) followed by cosine annealing over the
        remaining steps.

        Args:
            args: configuration; reads ``use_scheduler``, ``num_steps``,
                ``num_epochs``, ``min_lr`` and, via ``.get``,
                ``num_warmup_epochs`` (default 0) and ``use_warmup``
                (default False). With warmup, also ``start_lr`` and ``lr``.

        Returns:
            A scheduler bound to ``self.optimizer``, or ``None`` when
            ``args.use_scheduler`` is falsy.

        Raises:
            AssertionError: if ``num_steps`` is not a multiple of
                ``num_epochs``, or (with warmup) if ``start_lr >= lr`` or the
                warmup covers every step.
        """
        if not args.use_scheduler:
            return None

        assert args.num_steps % args.num_epochs == 0, "number of steps must be evenly divisible by epochs"

        num_warmup_epochs = args.get("num_warmup_epochs", 0)
        steps_per_epoch = args.num_steps // args.num_epochs
        total_steps = args.num_epochs * steps_per_epoch
        warmup_steps = num_warmup_epochs * steps_per_epoch

        # A zero-length warmup must not build a LinearLR/SequentialLR pair:
        # LinearLR's initial step would leave the optimizer at start_lr and the
        # cosine phase would then decay from there, never reaching args.lr.
        use_warmup = bool(args.get("use_warmup", False)) and warmup_steps > 0

        if not use_warmup:
            return CosineAnnealingLR(
                self.optimizer,
                T_max=total_steps,
                eta_min=args.min_lr,
            )

        assert args.start_lr < args.lr, "starting lr should be less than lr for a proper warmup phase"
        assert warmup_steps < total_steps, "warmup steps should be less than total steps"

        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=args.start_lr / args.lr,
            end_factor=1.0,
            total_iters=warmup_steps,
        )
        cosine_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=total_steps - warmup_steps,
            eta_min=args.min_lr,
        )
        return SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )

    def update_target_encoder(self):
        """Delegate to the model's ``update_target_encoder`` (DataParallel-aware)."""
        self._unwrapped().update_target_encoder()

    def train(self, train_loader, epoch=0):
        """Run one training epoch and log per-batch metrics to wandb.

        Args:
            train_loader: iterable of 5-item batches
                ``(sample_id, full_obs, full_times, temp_mask, feat_mask)``;
                must support ``len()``.
            epoch: zero-based epoch index, used for the progress-bar label
                and the logged global step.

        Returns:
            ``defaultdict(int)`` of metric sums over the epoch (each batch
            value weighted by its batch size) plus ``"train/num_samples"``.

        Raises:
            ValueError: if the combined loss is not a 0-d tensor.
        """
        epoch_dict = defaultdict(int)
        num_data = 0

        self.jepa.train()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch_idx, data in enumerate(progress_bar):

            batch_metrics = {}

            _sample_id, full_obs, full_times, temp_mask, feat_mask = self._to_device(data)

            self.optimizer.zero_grad()

            Z, Z_pred, KLs, *rest = self.jepa(full_obs, full_times, temp_mask, feat_mask)

            Z_pred_mean, Z_pred_var = Z_pred
            KLs = KLs.mean(0).sum()

            batch_len = full_obs.size(0)
            num_data += batch_len

            # Complement of temp_mask.
            # NOTE(review): presumably temp_mask flags context steps, so this
            # selects the prediction targets -- confirm with the data loader.
            target_mask = ~temp_mask

            loss = KLs * self.lamb
            if self.tau > 0:
                target_mask_Z = target_mask.unsqueeze(-1).expand(-1, -1, Z.shape[-1])

                # NOTE(review): squared before use as the GNLL_ variance, so
                # Z_pred_var presumably holds a standard deviation -- confirm.
                Z_pred_sq = Z_pred_var**2

                train_full_nll_jepa = GNLL_(Z, Z_pred_mean, Z_pred_sq, eps=self.eps_sigma, normalize_dim=self.normalize_loss) * batch_len
                train_mask_nll_jepa = GNLL_(Z, Z_pred_mean, Z_pred_sq, mask=target_mask_Z, eps=self.eps_sigma, normalize_dim=self.normalize_loss) * batch_len

                Z_flat = Z.flatten(start_dim=2)
                Z_mean_flat = Z_pred_mean.flatten(start_dim=2)
                train_full_mse_jepa = MSE_(Z_flat, Z_mean_flat) * batch_len
                train_mask_mse_jepa = MSE_(Z_flat, Z_mean_flat, mask=target_mask_Z) * batch_len

                if self.task == "full":
                    loss = loss + train_full_nll_jepa * self.tau
                elif self.task == "interpolation":
                    loss = loss + train_mask_nll_jepa * self.tau

                # Read each scalar once; every .item() is a host/device sync.
                full_mse = self._scalar(train_full_mse_jepa)
                mask_mse = self._scalar(train_mask_mse_jepa)
                full_nll = self._scalar(train_full_nll_jepa)
                mask_nll = self._scalar(train_mask_nll_jepa)

                batch_metrics.update({
                    "train/jepa_full_mse": full_mse / batch_len,
                    "train/jepa_mask_mse": mask_mse / batch_len,
                    "train/jepa_full_nll": full_nll / batch_len,
                    "train/jepa_mask_nll": mask_nll / batch_len,
                })

                epoch_dict["train_full_mse_jepa"] += full_mse
                epoch_dict["train_mask_mse_jepa"] += mask_mse
                epoch_dict["train_full_nll_jepa"] += full_nll
                epoch_dict["train_mask_nll_jepa"] += mask_nll

            if self.eta > 0:
                (
                    train_full_mse_recon,
                    train_mask_mse_recon,
                    train_full_nll_recon,
                    train_mask_nll_recon,
                ) = self._recon_terms(full_obs, rest[0], target_mask, batch_len)

                if self.task == "full":
                    loss = loss + train_full_nll_recon * self.eta
                elif self.task == "interpolation":
                    loss = loss + train_mask_nll_recon * self.eta

                full_mse = self._scalar(train_full_mse_recon)
                mask_mse = self._scalar(train_mask_mse_recon)
                full_nll = self._scalar(train_full_nll_recon)
                mask_nll = self._scalar(train_mask_nll_recon)

                batch_metrics.update({
                    "train/recon_full_mse": full_mse / batch_len,
                    "train/recon_mask_mse": mask_mse / batch_len,
                    "train/recon_full_nll": full_nll / batch_len,
                    "train/recon_mask_nll": mask_nll / batch_len,
                })

                epoch_dict["train_full_mse_recon"] += full_mse
                epoch_dict["train_mask_mse_recon"] += mask_mse
                epoch_dict["train_full_nll_recon"] += full_nll
                epoch_dict["train_mask_nll_recon"] += mask_nll

            if loss.ndim > 0:
                raise ValueError(f"Expected loss to be 0-D for DataParallel, but got shape {loss.shape}")

            loss.backward()
            jepa_model = self._unwrapped()

            # option 1 : joint clipping over every model parameter
            if self.joint_clipping:
                _joint_norm = torch.nn.utils.clip_grad_norm_(jepa_model.parameters(), self.max_grad_norm)
                batch_metrics['train/joint_grad_norm'] = _joint_norm.item()
            # option 2 : separate clipping per sub-module
            else:
                _enc_norm = torch.nn.utils.clip_grad_norm_(jepa_model.context_encoder.parameters(), self.max_grad_norm)
                _pred_norm = torch.nn.utils.clip_grad_norm_(jepa_model.predictor.parameters(), self.max_grad_norm)
                batch_metrics['train/context_encoder_grad_norm'] = _enc_norm.item()
                batch_metrics['train/predictor_grad_norm'] = _pred_norm.item()
                if self.eta > 0:
                    _dec_norm = torch.nn.utils.clip_grad_norm_(jepa_model.decoder.parameters(), self.max_grad_norm)
                    batch_metrics['train/decoder_grad_norm'] = _dec_norm.item()

            self.optimizer.step()
            if self.tau > 0:
                self.update_target_encoder()
            if self.scheduler is not None:
                self.scheduler.step()

            batch_metrics.update({
                'train/step': epoch * len(train_loader) + batch_idx,
                'train/KL': self._scalar(KLs),
                # Mean L2 norm along the last axis of the predicted mean / spread.
                'train/Z_mean_scale': self._scalar(torch.linalg.vector_norm(Z_pred_mean, dim=-1).mean()),
                'train/Z_std_scale': self._scalar(torch.linalg.vector_norm(Z_pred_var, dim=-1).mean()),
            })

            if self.scheduler is not None:
                # get_last_lr() returns one value per param group; only one group is used here.
                batch_metrics['train/lr'] = self.scheduler.get_last_lr()[0]

            wandb.log(batch_metrics)
            progress_bar.set_postfix(batch_metrics)

        epoch_dict.update({
            "train/num_samples": num_data
        })
        return epoch_dict

    def evaluate(self, test_loader, epoch=0):
        """Run one evaluation pass under ``torch.no_grad()``.

        Per-batch metrics are shown on the progress bar only; nothing is sent
        to wandb from this method.

        Args:
            test_loader: iterable of 5-item batches
                ``(sample_id, full_obs, full_times, temp_mask, feat_mask)``.
            epoch: zero-based epoch index, used for the progress-bar label.

        Returns:
            ``defaultdict(int)`` of metric sums over the pass (each batch
            value weighted by its batch size) plus ``"test/num_samples"``.
        """
        epoch_dict = defaultdict(int)
        num_data = 0

        self.jepa.eval()

        with torch.no_grad():
            progress_bar = tqdm(test_loader, desc=f"Evaluation epoch {epoch+1}")
            for batch_idx, data in enumerate(progress_bar):

                batch_metrics = {}

                _sample_id, full_obs, full_times, temp_mask, feat_mask = self._to_device(data)

                Z, Z_pred, KLs, *rest = self.jepa(full_obs, full_times, temp_mask, feat_mask)

                Z_pred_mean, Z_pred_var = Z_pred
                KLs = KLs.mean(0).sum()

                batch_len = full_obs.size(0)
                num_data += batch_len

                target_mask = ~temp_mask

                if self.tau > 0:
                    target_mask_Z = target_mask.unsqueeze(-1).expand(-1, -1, Z.shape[-1])

                    test_nll_jepa = GNLL_(Z, Z_pred_mean, Z_pred_var**2, mask=target_mask_Z, eps=self.eps_sigma, normalize_dim=self.normalize_loss) * batch_len
                    test_mse_jepa = MSE_(Z.flatten(start_dim=2), Z_pred_mean.flatten(start_dim=2), mask=target_mask_Z) * batch_len

                    mask_mse = self._scalar(test_mse_jepa)
                    mask_nll = self._scalar(test_nll_jepa)

                    batch_metrics.update({
                        "test/jepa_mask_mse": mask_mse / batch_len,
                        "test/jepa_mask_nll": mask_nll / batch_len,
                    })
                    epoch_dict["test_mse_jepa"] += mask_mse
                    epoch_dict["test_nll_jepa"] += mask_nll

                if self.eta > 0:
                    (
                        test_full_mse_recon,
                        test_mask_mse_recon,
                        test_full_nll_recon,
                        test_mask_nll_recon,
                    ) = self._recon_terms(full_obs, rest[0], target_mask, batch_len)

                    full_mse = self._scalar(test_full_mse_recon)
                    mask_mse = self._scalar(test_mask_mse_recon)
                    full_nll = self._scalar(test_full_nll_recon)
                    mask_nll = self._scalar(test_mask_nll_recon)

                    batch_metrics.update({
                        "test/recon_full_mse": full_mse / batch_len,
                        "test/recon_mask_mse": mask_mse / batch_len,
                        "test/recon_full_nll": full_nll / batch_len,
                        "test/recon_mask_nll": mask_nll / batch_len,
                    })

                    epoch_dict["test_full_mse_recon"] += full_mse
                    epoch_dict["test_mask_mse_recon"] += mask_mse
                    epoch_dict["test_full_nll_recon"] += full_nll
                    epoch_dict["test_mask_nll_recon"] += mask_nll

                batch_metrics.update({
                    'test/KL': self._scalar(KLs),
                })
                progress_bar.set_postfix(batch_metrics)

        epoch_dict.update({
            "test/num_samples": num_data
        })
        return epoch_dict
