import pytorch_lightning as pl
import torch
import time
import logging
from torch import optim
from timm.utils import ModelEmaV2

from flow.transport.transport import create_transport, Sampler
from context_general_bci.config import (
    Metric,
    Output,
    DataKey,
)

logger = logging.getLogger(__name__)

class SiT_falcon(pl.LightningModule):
    """Lightning module that trains a SiT flow model to generate latent features.

    The SiT model (``dict_cfg['sit']``) is trained with a linear-path,
    velocity-prediction transport objective (see ``get_transport_and_sampler``)
    whose targets are the latents produced by ``target_ndt``. The factors
    returned by ``cond_model.get_batch_factors`` are passed to the SiT model
    as ``y``. Samples drawn from Gaussian noise are scored with
    ``cond_model.get_metrics_from_features``.

    Args:
        dict_cfg: mapping with keys
            - ``'cond_model'`` (required): conditioning model. Its ``.cfg``
              supplies ``task.metrics`` and the optimizer/scheduler settings.
            - ``'sit'`` (required): the SiT model being trained.
            - ``'target_ndt'`` (optional, default ``None``): callable that maps
              ``in_batch`` to target latents. It is required by ``training_step``.
            - ``'phase'`` (optional, default ``'pre_train'``).
            - ``'latent_dim'`` (optional, default ``128``): size of each latent
              vector.

    Raises:
        ValueError: if ``'cond_model'`` or ``'sit'`` is missing from ``dict_cfg``.
    """

    def __init__(self, dict_cfg):
        super().__init__()
        # Fail here, with a clear message, instead of later with an unrelated
        # AttributeError on a half-initialised module.
        if 'cond_model' not in dict_cfg:
            raise ValueError("Conditional model not provided in dict_cfg")
        if 'sit' not in dict_cfg:
            raise ValueError("SiT model not provided in dict_cfg")

        self.cfg = dict_cfg
        # Keep the assignment order cond_model -> target_ndt -> sit_model.
        # Assuming these are nn.Modules, registration order fixes the
        # self.parameters() order, and saved optimizer state depends on it.
        self.cond_model = dict_cfg['cond_model']
        self.ndt_cfg = self.cond_model.cfg
        self.phase = dict_cfg.get('phase', 'pre_train')

        self.target_ndt = dict_cfg.get('target_ndt')
        self.sit_model = dict_cfg['sit']
        self.latent_dim = dict_cfg.get('latent_dim', 128)

    @torch.inference_mode()
    def predict(self):
        """Placeholder; not implemented."""
        pass

    def forward(self):
        """Placeholder; not implemented (sampling goes through the *_step methods)."""
        pass

    # ==================== Optimization ====================
    def common_log(self, metrics, prefix='', **kwargs):
        """Log a metrics dict through ``self.log`` as ``{prefix}_{name}``.

        Keys that are neither ``Metric`` nor ``Output`` members are logged
        as-is, except those whose name contains ``'update'``. Every metric
        listed in ``self.ndt_cfg.task.metrics`` is then logged under its
        ``.value``. The two kinematic R2 variants are averaged first.
        ``kwargs`` are forwarded to ``self.log``.
        """
        for m in metrics:
            if not isinstance(m, (Metric, Output)) and 'update' not in m:  # log misc, mostly task losses
                self.log(f'{prefix}_{m}', metrics[m], **kwargs)
        for m in self.ndt_cfg.task.metrics:
            if m == Metric.kinematic_r2 or m == Metric.kinematic_r2_thresh:
                self.log(f'{prefix}_{m.value}', metrics[m].mean(), **kwargs)
            else:
                self.log(f'{prefix}_{m.value}', metrics[m], **kwargs)

    def get_transport_and_sampler(self):
        """Build the transport objective and its ODE sampler.

        Returns:
            ``(transport, sample_fn)``. ``transport`` uses a linear path with
            velocity prediction. ``sample_fn`` integrates the learned ODE with
            2 Euler steps.
        """
        transport = create_transport(
            path_type='Linear',
            prediction='velocity',
            loss_weight=None,
            train_eps=None,
            sample_eps=1e-2,
        )
        transport_sampler = Sampler(transport)
        sample_fn = transport_sampler.sample_ode(num_steps=2, sampling_method='euler')
        return transport, sample_fn

    def _sample_latents(self, features, sample_fn):
        """Integrate Gaussian noise through ``sample_fn`` with the SiT model, using ``features`` as ``y``.

        One noise vector of size ``self.latent_dim`` is drawn for each element
        of the first two dims of ``features``. The final solver state is
        reshaped to ``(features.size(0), -1, self.latent_dim)``. Gradient mode
        is left to the caller.
        """
        num = int(features.shape[0] * features.shape[1])
        z_0 = torch.randn(num, self.latent_dim, device=features.device).unsqueeze(1)
        samples = sample_fn(z_0, self.sit_model.forward, y=features, is_cond=False)[-1]
        return torch.reshape(samples, (features.size(0), -1, self.latent_dim))

    def training_step(self, batch, batch_idx):
        """Compute the transport loss against ``target_ndt`` latents and log sampling metrics.

        Returns:
            The mean transport loss; this is the only tensor back-propagated.

        Raises:
            RuntimeError: if no ``target_ndt`` was supplied in ``dict_cfg``.
        """
        if self.target_ndt is None:
            raise RuntimeError("training_step requires 'target_ndt' in dict_cfg")
        batch_out = self.cond_model.get_batch_factors(batch, eval_mode=False)
        features = batch_out['factors']
        in_batch = batch_out['in_batch']
        with torch.no_grad():
            tgt_features = self.target_ndt(in_batch)
            exp_z_manifold = torch.reshape(tgt_features, (-1, self.latent_dim)).unsqueeze(1)

        transport, sample_fn = self.get_transport_and_sampler()
        model_kwargs = dict(y=features, is_cond=False)
        loss_dict = transport.training_losses(self.sit_model, exp_z_manifold, model_kwargs=model_kwargs)
        loss = loss_dict['loss'].mean()

        # Sampling here is only for monitoring; 'loss' above is what gets
        # optimized. Running it without autograd avoids keeping a graph through
        # every ODE solver step.
        with torch.no_grad():
            samples = self._sample_latents(features, sample_fn)
            # NOTE(review): training metrics are computed with phase='val';
            # confirm this is intended.
            metrics = self.cond_model.get_metrics_from_features(in_batch, samples, eval_mode=False, phase='val')
        metrics['loss'] = loss

        self.common_log(metrics, prefix='train', sync_dist=True, add_dataloader_idx=False)

        return metrics['loss']

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Sample latents and log metrics.

        Metrics are logged with prefix ``'val'`` for dataloader 0 and ``'eval'``
        otherwise. Returns the mean kinematic R2.
        """
        batch_out = self.cond_model.get_batch_factors(batch, eval_mode=False)
        features = batch_out['factors']
        in_batch = batch_out['in_batch']

        _, sample_fn = self.get_transport_and_sampler()
        with torch.no_grad():
            samples = self._sample_latents(features, sample_fn)
            metrics = self.cond_model.get_metrics_from_features(in_batch, samples, eval_mode=False, phase='val')
        self.common_log(metrics, prefix='val' if dataloader_idx == 0 else 'eval', sync_dist=True, add_dataloader_idx=False)

        return metrics[Metric.kinematic_r2].mean()

    @torch.inference_mode()
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Sample latents, log ``test_*`` metrics and return the mean kinematic R2.

        ``dataloader_idx`` is accepted so multiple test dataloaders work; it is
        otherwise unused.
        """
        batch_out = self.cond_model.get_batch_factors(batch, eval_mode=False)
        features = batch_out['factors']
        in_batch = batch_out['in_batch']

        _, sample_fn = self.get_transport_and_sampler()
        samples = self._sample_latents(features, sample_fn)
        metrics = self.cond_model.get_metrics_from_features(in_batch, samples, eval_mode=False, phase='eval')
        self.common_log(metrics, prefix='test')

        return metrics[Metric.kinematic_r2].mean()

    @torch.inference_mode()
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Sample latents and return the metrics dict, with samples under ``'pred_feature'``.

        Prints the wall-clock time spent sampling and computing metrics.
        """
        batch_out = self.cond_model.get_batch_factors(batch, eval_mode=False)
        features = batch_out['factors']
        in_batch = batch_out['in_batch']
        # CUDA work is asynchronous. Synchronize on both sides of the timed
        # region so the printed time covers the actual GPU work, not just
        # kernel launches.
        if features.is_cuda:
            torch.cuda.synchronize(features.device)
        start_time = time.perf_counter()

        _, sample_fn = self.get_transport_and_sampler()
        samples = self._sample_latents(features, sample_fn)
        metrics = self.cond_model.get_metrics_from_features(in_batch, samples, eval_mode=False, phase='eval')
        metrics['pred_feature'] = samples

        if samples.is_cuda:
            torch.cuda.synchronize(samples.device)
        elapsed = time.perf_counter() - start_time
        print(f"Predict step time: {elapsed:.4f} seconds")
        return metrics

    def freeze_sit_and_readout(self):
        """Disable gradients for every parameter of ``self.sit_model``.

        NOTE(review): despite the name, only the SiT model is frozen here;
        confirm whether a readout should also be included.
        """
        logger.info("Freezing sit and readout.")
        self.sit_model.requires_grad_(False)

    def configure_optimizers(self):
        """Build AdamW over trainable parameters plus the scheduler from ``self.ndt_cfg.lr_schedule``.

        Supported schedules: ``'linear_warmup'``, ``'cosine_warmup'``,
        ``'cosine_timm'`` and ``'fixed'`` (no scheduler).

        Raises:
            AssertionError: for any other ``lr_schedule`` value.
        """
        scheduler = None
        grouped_params = [p for p in self.parameters() if p.requires_grad]
        # Plain AdamW: apex FusedAdam was tried (≈5% faster on a 3090 with
        # PT 1.13) but once spontaneously stopped training, so it is not used.
        optimizer = optim.AdamW(
            grouped_params,
            lr=self.ndt_cfg.lr_init,
            weight_decay=self.ndt_cfg.weight_decay
        )
        if self.ndt_cfg.lr_schedule == 'linear_warmup':
            scheduler = optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=self.ndt_cfg.lr_ramp_init_factor,
                total_iters=self.ndt_cfg.lr_ramp_steps
            )
        elif self.ndt_cfg.lr_schedule == 'cosine_warmup':
            scheduler = optim.lr_scheduler.ChainedScheduler([
                optim.lr_scheduler.LinearLR(
                    optimizer,
                    start_factor=self.ndt_cfg.lr_ramp_init_factor,
                    total_iters=self.ndt_cfg.lr_ramp_steps
                ),
                optim.lr_scheduler.CosineAnnealingLR(
                    optimizer,
                    T_max=self.ndt_cfg.lr_decay_steps,
                    eta_min=self.ndt_cfg.lr_min
                ),
            ])
        elif self.ndt_cfg.lr_schedule == 'cosine_timm':
            from timm.scheduler import CosineLRScheduler
            scheduler = CosineLRScheduler(
                optimizer,
                t_initial=self.ndt_cfg.lr_decay_steps,  # 1 cycle
                lr_min=self.ndt_cfg.lr_min,
                warmup_lr_init=self.ndt_cfg.lr_ramp_init_factor * self.ndt_cfg.lr_init,
                warmup_t=self.ndt_cfg.lr_ramp_ratio * self.ndt_cfg.lr_decay_steps if self.ndt_cfg.lr_ramp_ratio > 0 else self.ndt_cfg.lr_ramp_steps,
                cycle_limit=1,
                t_in_epochs=True,
            )
        else:
            assert self.ndt_cfg.lr_schedule == 'fixed', f"Unknown lr_schedule {self.ndt_cfg.lr_schedule}"
        # NOTE(review): this module never logs 'val_loss' under that literal
        # name itself; confirm cond_model's metrics produce it.
        out = {
            'optimizer': optimizer,
            'monitor': 'val_loss'
        }
        if scheduler is not None:
            out['lr_scheduler'] = scheduler
        return out

def filter_state_dict(ckpt_path, prefixes=("sit_model.", "cond_model.backbone"), map_location=None, **load_kwargs):
    """Load a checkpoint and keep only the ``state_dict`` entries with selected key prefixes.

    Args:
        ckpt_path: anything ``torch.load`` accepts (path or file-like object).
        prefixes: a key prefix or an iterable of prefixes to keep. The default
            keeps the SiT model and the conditioning model's backbone.
        map_location: forwarded to ``torch.load``. ``None`` (torch's default)
            restores tensors to the device they were saved from; pass
            ``'cpu'`` to load on a machine without that device.
        **load_kwargs: extra keyword arguments forwarded to ``torch.load``
            (e.g. ``weights_only``).

    Returns:
        A new dict holding the kept ``state_dict`` entries, in checkpoint order.

    Raises:
        KeyError: if the loaded checkpoint has no ``'state_dict'`` entry.
    """
    ckpt = torch.load(ckpt_path, map_location=map_location, **load_kwargs)
    # str.startswith takes a tuple; wrap a bare string so it is not split
    # into single characters.
    prefixes = (prefixes,) if isinstance(prefixes, str) else tuple(prefixes)
    return {k: v for k, v in ckpt['state_dict'].items() if k.startswith(prefixes)}