from typing import Optional
from pathlib import Path
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean

from src.constants import atom_encoder, bond_encoder
from src.model.lightning import DrugFlow, set_default
from src.data.dataset import ProcessedLigandPocketDataset, DPODataset
from src.data.data_utils import AppendVirtualNodesInCoM, Residues, center_data

class DPO(DrugFlow):
    def __init__(self, dpo_mode, ref_checkpoint_p, **kwargs):
        """Set up DPO fine-tuning starting from a reference checkpoint.

        Args:
            dpo_mode: Name of the DPO objective variant; it is interpreted by
                ``compute_dpo_loss``.
            ref_checkpoint_p: Path to a checkpoint whose ``'state_dict'``
                holds the reference model under keys prefixed with
                ``'dynamics.'``.
            **kwargs: Forwarded unchanged to the parent constructor. Must
                contain ``'loss_params'`` and ``'predictor_params'``.
        """
        super(DPO, self).__init__(**kwargs)
        self.dpo_mode = dpo_mode

        loss_params = kwargs['loss_params']

        def _param(name, default):
            # Optional DPO settings fall back to a default when absent.
            return getattr(loss_params, name) if name in loss_params else default

        self.dpo_beta = _param('dpo_beta', 100.0)
        self.dpo_beta_schedule = _param('dpo_beta_schedule', 't')
        self.clamp_dpo = _param('clamp_dpo', True)
        self.dpo_lambda_dpo = _param('dpo_lambda_dpo', 1)
        self.dpo_lambda_w = _param('dpo_lambda_w', 1)
        self.dpo_lambda_l = _param('dpo_lambda_l', 0.2)
        # The component weights default to the regular training weights. The
        # fallback attribute is only read when the DPO-specific one is missing.
        self.dpo_lambda_h = loss_params.dpo_lambda_h if 'dpo_lambda_h' in loss_params else loss_params.lambda_h
        self.dpo_lambda_e = loss_params.dpo_lambda_e if 'dpo_lambda_e' in loss_params else loss_params.lambda_e

        self.ref_dynamics = self.init_model(kwargs['predictor_params'])

        # map_location='cpu' lets a checkpoint saved on a GPU be read on any
        # host; load_state_dict copies the values into the existing parameters.
        state_dict = torch.load(ref_checkpoint_p, map_location='cpu')['state_dict']

        # Strip only the leading prefix: str.replace would also remove any
        # later occurrence of 'dynamics.' inside a nested parameter name.
        prefix = 'dynamics.'
        self.ref_dynamics.load_state_dict(
            {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        )

        # The reference model is only evaluated under torch.no_grad() (see
        # compute_loss_single_pair), so its parameters never receive gradients.
        # NOTE(review): Lightning's train() call still switches this module to
        # train mode; confirm whether it should be forced into eval mode.
        self.ref_dynamics.requires_grad_(False)
        print(f'Loaded reference model from {ref_checkpoint_p}')

        # initializing model params with ref model params
        self.dynamics.load_state_dict(self.ref_dynamics.state_dict())

    def get_dataset(self, stage, pocket_transform=None):
        """Build the dataset for the given stage.

        Args:
            stage: Split name. ``'train'`` yields a ``DPODataset`` read from
                ``train.pt``; any other value yields a
                ``ProcessedLigandPocketDataset`` read from ``{stage}.pt``
                (or ``val.pt`` when ``self.debug`` is set).
            pocket_transform: Optional transform handed to the dataset.

        Returns:
            The dataset instance for ``stage``.

        Raises:
            NotImplementedError: If sharded datasets are enabled, or if
                cluster sampling is enabled for the training stage.
        """
        is_train = stage == 'train'

        # when sampling we don't append virtual nodes as we might need access to the ground truth size
        ligand_transform = None
        if self.virtual_nodes and is_train:
            ligand_transform = AppendVirtualNodesInCoM(
                atom_encoder,
                bond_encoder,
                add_min=self.add_virtual_min,
                add_max=self.add_virtual_max,
            )

        # we want to know if something goes wrong on the validation or test set
        catch_errors = is_train

        if self.sharded_dataset:
            raise NotImplementedError('Sharded dataset not implemented for DPO')

        # val/test should be deterministic
        if self.sample_from_clusters and is_train:
            raise NotImplementedError('Sampling from clusters not implemented for DPO')

        if not is_train:
            split = 'val' if self.debug else stage
            return ProcessedLigandPocketDataset(
                pt_path=Path(self.datadir, f'{split}.pt'),
                ligand_transform=ligand_transform,
                pocket_transform=pocket_transform,
                catch_errors=catch_errors,
            )

        # NOTE(review): the training dataset is deliberately given
        # ligand_transform=None, so the virtual-node transform built above is
        # never applied here - confirm this is intended.
        return DPODataset(
            Path(self.datadir, 'train.pt'),
            ligand_transform=None,
            pocket_transform=pocket_transform,
            catch_errors=True,
        )


    def training_step(self, data, *args):
        """Run one DPO training step on a batch of ligand pairs.

        Args:
            data: Batch mapping with the keys ``'ligand'`` (the ligand whose
                likelihood the DPO loss pushes up), ``'ligand_l'`` (the one it
                pushes down) and ``'pocket'``.
            *args: Ignored; accepted for compatibility with the trainer's
                calling convention.

        Returns:
            Dict holding the differentiable ``'loss'`` plus every diagnostic
            entry produced by ``compute_dpo_loss``.
        """
        ligand_w, ligand_l, pocket = data['ligand'], data['ligand_l'], data['pocket']
        loss, info = self.compute_dpo_loss(pocket, ligand_w=ligand_w, ligand_l=ligand_l, return_info=True)

        if torch.isnan(loss):
            print(f'DPO loss is NaN at epoch {self.current_epoch}. Info: {info}')

        # Only scalar diagnostics can be logged. Anything that is neither a
        # Python number nor a single-element tensor is skipped instead of
        # raising inside torch.numel.
        log_dict = {
            k: v for k, v in info.items()
            if isinstance(v, (int, float)) or (torch.is_tensor(v) and v.numel() <= 1)
        }
        self.log_metrics({'loss': loss, **log_dict}, 'train', batch_size=len(ligand_w['size']))

        # Keep a detached copy for later aggregation so the stored entries do
        # not hold on to the autograd graph; the returned loss stays attached
        # because the trainer back-propagates through it.
        # (training_step_outputs is presumably initialised by the parent class.)
        self.training_step_outputs.append({'loss': loss.detach(), **info})
        return {'loss': loss, **info}
    
    def validation_step(self, data, *args):
        """Validate with the parent implementation; no DPO-specific logic is added."""
        parent_result = super().validation_step(data, *args)
        return parent_result

    def compute_dpo_loss(self, pocket, ligand_w, ligand_l, return_info=False):
        """Compute the DPO objective for a batch of ligand pairs.

        Both ligands of a pair are evaluated at the same random time ``t``
        with the trainable and the reference model. The resulting loss
        differences are combined according to ``self.dpo_mode``:

        * ``'single_dpo_comp'``: one log-sigmoid term on the weighted sum of
          the coordinate (x), atom type (h) and bond type (e) losses.
        * ``'single_dpo_comp_v2'``: as above, with optional clamping of the
          sigmoid argument, a ``dpo_lambda_dpo`` scale and an additional
          regulariser built from the trainable model's own losses.
        * ``'single_dpo_comp_v3'``: same objective as v2, assembled per
          component so that per-component differences can be logged.
        * ``'sep_dpo_comp_nocoord'``: separate log-sigmoid terms for h and e;
          the coordinate loss only enters as a regulariser.

        Args:
            pocket: Pocket batch shared by both ligands.
            ligand_w: Ligand batch whose loss the objective pushes down
                relative to the reference model.
            ligand_l: Ligand batch whose loss is pushed up relative to it.
            return_info: If true, also return the diagnostics dict.

        Returns:
            The loss averaged over dimension 0, or ``(loss, info)`` when
            ``return_info`` is true. ``info`` maps metric names to floats.

        Raises:
            ValueError: If the beta schedule or the DPO mode is unknown.
        """
        # One time value per sample, kept as a column so it broadcasts
        # inside the loss modules.
        t = torch.rand(ligand_w['size'].size(0), device=ligand_w['x'].device).unsqueeze(-1)

        if self.dpo_beta_schedule == 't':
            # from https://arxiv.org/pdf/2407.13981
            beta_t = (self.dpo_beta * t).squeeze()
        elif self.dpo_beta_schedule == 'tcomplement':
            beta_t = self.dpo_beta * (1 - t).squeeze()
        elif self.dpo_beta_schedule == 'const':
            beta_t = self.dpo_beta
        else:
            raise ValueError(f'Unknown DPO beta schedule: {self.dpo_beta_schedule}')

        loss_dict_w = self.compute_loss_single_pair(ligand_w, pocket, t)
        loss_dict_l = self.compute_loss_single_pair(ligand_l, pocket, t)
        sides = (('w', loss_dict_w), ('l', loss_dict_l))
        components = ('x', 'h', 'e')

        info = {
            f'loss_{comp}_{side}': losses['theta'][comp].mean().item()
            for side, losses in sides
            for comp in components
        }

        lambda_h, lambda_e = self.dpo_lambda_h, self.dpo_lambda_e

        if self.dpo_mode == 'single_dpo_comp':
            diff_w = (
                self._weighted_sum(loss_dict_w['theta'], lambda_h, lambda_e)
                - self._weighted_sum(loss_dict_w['ref'], lambda_h, lambda_e)
            )
            diff_l = (
                self._weighted_sum(loss_dict_l['theta'], lambda_h, lambda_e)
                - self._weighted_sum(loss_dict_l['ref'], lambda_h, lambda_e)
            )
            info['diff_w'] = diff_w.mean().item()
            info['diff_l'] = diff_l.mean().item()
            dpo_arg = -1 * beta_t * (diff_w - diff_l)
            loss = -1 * F.logsigmoid(dpo_arg)

        elif self.dpo_mode == 'single_dpo_comp_v2':
            # works with careful tuning of beta_t, dpo_lambda_w and dpo_lambda_l
            totals = {}
            for side, losses in sides:
                for model in ('theta', 'ref'):
                    total = self._weighted_sum(losses[model], lambda_h, lambda_e)
                    info[f'loss_{side}_{model}'] = total.mean().item()
                    totals[side, model] = total
            diff_w = totals['w', 'theta'] - totals['w', 'ref']
            diff_l = totals['l', 'theta'] - totals['l', 'ref']
            info['diff_w'] = diff_w.mean().item()
            info['diff_l'] = diff_l.mean().item()

            dpo_loss = self._dpo_term(-1 * beta_t * (diff_w - diff_l), info)
            dpo_reg = self._dpo_regulariser(loss_dict_w, loss_dict_l, info)
            loss = dpo_loss + dpo_reg

        elif self.dpo_mode == 'single_dpo_comp_v3':
            # equivalent to v2, but the per-component differences are logged
            diffs = {
                side: {comp: losses['theta'][comp] - losses['ref'][comp] for comp in components}
                for side, losses in sides
            }
            for side, _ in sides:
                for comp in components:
                    info[f'diff_{side}_{comp}'] = diffs[side][comp].mean().item()

            # not used in the loss, just for logging
            info['diff_w'] = self._weighted_sum(diffs['w'], lambda_h, lambda_e).mean().item()
            info['diff_l'] = self._weighted_sum(diffs['l'], lambda_h, lambda_e).mean().item()

            pair_diff = {comp: diffs['w'][comp] - diffs['l'][comp] for comp in components}
            for comp in components:
                info[f'diff_{comp}'] = pair_diff[comp].mean().item()

            dpo_arg = -1 * beta_t * self._weighted_sum(pair_diff, lambda_h, lambda_e)
            dpo_loss = self._dpo_term(dpo_arg, info)
            dpo_reg = self._dpo_regulariser(loss_dict_w, loss_dict_l, info)
            loss = dpo_loss + dpo_reg

        elif self.dpo_mode == 'sep_dpo_comp_nocoord':
            loss = 0
            # dpo loss only on discrete components
            for component, weight in (('h', lambda_h), ('e', lambda_e)):
                diff_wc = loss_dict_w['theta'][component] - loss_dict_w['ref'][component]
                diff_lc = loss_dict_l['theta'][component] - loss_dict_l['ref'][component]
                info[f'diff_{component}_w'] = diff_wc.mean().item()
                info[f'diff_{component}_l'] = diff_lc.mean().item()
                dpo_arg = -1 * beta_t * weight * (diff_wc - diff_lc)
                if self.clamp_dpo:
                    dpo_arg = dpo_arg.clamp(-10, 10)
                info[f'dpo_arg_{component}_min'] = dpo_arg.min().item()
                info[f'dpo_arg_{component}_max'] = dpo_arg.max().item()
                info[f'dpo_arg_{component}_mean'] = dpo_arg.mean().item()
                info[f'dpo_arg_{component}_std'] = dpo_arg.std().item()
                dpo_loss = -1 * F.logsigmoid(dpo_arg)
                info[f'dpo_loss_{component}'] = dpo_loss.mean().item()
                # out-of-place on purpose: avoids mutating autograd tensors
                loss = loss + dpo_loss
            dpo_reg_coord = self.dpo_lambda_w * loss_dict_w['theta']['x'] + \
                            self.dpo_lambda_l * loss_dict_l['theta']['x']
            info['dpo_reg_coord'] = dpo_reg_coord.mean().item()
            loss = loss + dpo_reg_coord

        else:
            raise ValueError(f'Unknown DPO mode: {self.dpo_mode}')

        if self.timestep_weights is not None:
            w_t = self.timestep_weights(t).squeeze()
            loss = w_t * loss

        loss = loss.mean(0)

        return (loss, info) if return_info else loss

    @staticmethod
    def _weighted_sum(losses, lambda_h, lambda_e):
        """Return ``x + lambda_h * h + lambda_e * e`` for a dict with keys x, h, e."""
        return losses['x'] + lambda_h * losses['h'] + lambda_e * losses['e']

    def _dpo_term(self, dpo_arg, info):
        """Clamp the sigmoid argument if enabled, record its stats in ``info`` and return the scaled -logsigmoid."""
        if self.clamp_dpo:
            dpo_arg = dpo_arg.clamp(-10, 10)
        info['dpo_arg_min'] = dpo_arg.min().item()
        info['dpo_arg_max'] = dpo_arg.max().item()
        info['dpo_arg_mean'] = dpo_arg.mean().item()
        dpo_loss = -1 * self.dpo_lambda_dpo * F.logsigmoid(dpo_arg)
        info['dpo_loss'] = dpo_loss.mean().item()
        return dpo_loss

    def _dpo_regulariser(self, loss_dict_w, loss_dict_l, info):
        """Return the regulariser built from the trainable model's losses on both ligands, logging its parts in ``info``."""
        # Uses the regular training weights (lambda_h / lambda_e), not the
        # DPO-specific ones.
        loss_w_theta_reg = self._weighted_sum(loss_dict_w['theta'], self.lambda_h, self.lambda_e)
        info['loss_w_theta_reg'] = loss_w_theta_reg.mean().item()
        loss_l_theta_reg = self._weighted_sum(loss_dict_l['theta'], self.lambda_h, self.lambda_e)
        info['loss_l_theta_reg'] = loss_l_theta_reg.mean().item()
        dpo_reg = self.dpo_lambda_w * loss_w_theta_reg + \
                  self.dpo_lambda_l * loss_l_theta_reg
        info['dpo_reg'] = dpo_reg.mean().item()
        return dpo_reg

    def compute_loss_single_pair(self, ligand, pocket, t):
        """Evaluate the x/h/e losses of one ligand batch under both models.

        The ligand is noised once and the same noisy state is passed to the
        trainable model (``self.dynamics``) and, without gradient tracking,
        to the reference model (``self.ref_dynamics``).

        Args:
            ligand: Ligand batch; read via the keys 'x', 'one_hot', 'bonds',
                'bond_one_hot', 'mask' and 'bond_mask'.
            pocket: Mapping of keyword arguments used to build ``Residues``.
            t: Time values forwarded to the noising, model and loss calls.

        Returns:
            ``{'theta': {...}, 'ref': {...}}`` where each inner dict holds the
            losses under the keys 'x', 'h' and 'e'; 'theta' is the trainable
            model and 'ref' the reference model.
        """
        # TODO: move somewhere else (like collate_fn)
        pocket = Residues(**pocket)

        # Center sample
        ligand, pocket = center_data(ligand, pocket)
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)

        atom_mask = ligand['mask']
        bond_mask = ligand['bond_mask']

        # Noise: draw the start states first, then the states at time t.
        # The call order is kept fixed so random number consumption is stable.
        z0_x = self.module_x.sample_z0(pocket_com, atom_mask)
        z0_h = self.module_h.sample_z0(atom_mask)
        z0_e = self.module_e.sample_z0(bond_mask)
        zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, atom_mask)
        zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, atom_mask)
        zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, bond_mask)

        # Predict denoising; both models receive identical inputs
        sc_transform = self.get_sc_transform_fn(None, zt_x, t, None, atom_mask, pocket)
        model_args = (zt_x, zt_h, atom_mask, pocket, t)
        model_kwargs = {
            'bonds_ligand': (ligand['bonds'], zt_e),
            'sc_transform': sc_transform,
        }

        pred_ligand, _ = self.dynamics(*model_args, **model_kwargs)

        # Reference model
        with torch.no_grad():
            ref_pred_ligand, _ = self.ref_dynamics(*model_args, **model_kwargs)

        predictions = (('theta', pred_ligand), ('ref', ref_pred_ligand))
        losses = {'theta': {}, 'ref': {}}

        # Compute L2 loss
        for model_name, pred in predictions:
            losses[model_name]['x'] = self.module_x.compute_loss(
                pred['vel'], z0_x, ligand['x'], t, atom_mask, reduce=self.loss_reduce)

        t_next = torch.clamp(t + self.train_step_size, max=1.0)

        for model_name, pred in predictions:
            losses[model_name]['h'] = self.module_h.compute_loss(
                pred['logits_h'], zt_h, ligand['one_hot'], atom_mask, t, t_next, reduce=self.loss_reduce)

        for model_name, pred in predictions:
            losses[model_name]['e'] = self.module_e.compute_loss(
                pred['logits_e'], zt_e, ligand['bond_one_hot'], bond_mask, t, t_next, reduce=self.loss_reduce)

        return losses
