import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm 
import numpy as np
import random
import wandb

from SDE_model.dataset import MDGenDataset
import torch.nn.functional as F
from SDE_model.rigid_utils import Rigid, Rotation
from SDE_model.utils import get_offsets
from SDE_model.parsing import parse_train_args
from SDE_model.transport.protein_sde import ProteinSDE
from SDE_model.model.config import ModelConfig
from auxiliary import calc_violation_loss
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from SDE_model.geometry import frames_torsions_to_atom14


def setup():
    """Join the NCCL process group and pin this process to its local GPU.

    The GPU index is read from the ``LOCAL_RANK`` environment variable;
    a ``KeyError`` is raised if that variable is not set.
    """
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists (for example a second time, or after
    initialisation failed) is a no-op instead of an error, so it is safe to
    use from a ``finally`` block.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    

def set_seed(seed=137):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    auto-tuner.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def prep_batch(args, batch):
    """Convert a raw dataset batch into latents, loss masks and model inputs.

    Note that ``batch`` is modified in place: ``batch['torsions']`` and
    ``batch['torsion_mask']`` are cut down to their first three torsion
    entries and ``batch['arr']`` to its first four entries along dim -2.

    Args:
        args: Parsed options; reads ``no_offsets``, ``supervise_all_torsions``
            and ``supervise_no_torsions``.
        batch: Dict with the keys ``trans``, ``rots``, ``mask``, ``torsions``,
            ``torsion_mask``, ``seqres``, ``arr`` and ``name``.

    Returns:
        Dict with ``arr``, ``rigids``, ``latents`` (B, T, L, 13),
        ``loss_mask`` (B, T, L, 13), ``name``, ``seqres`` and a nested
        ``model_kwargs`` dict holding the conditioning inputs.
    """
    rigids = Rigid(trans=batch['trans'], rots=Rotation(rot_mats=batch['rots']))
    B, T, L = rigids.shape  # batch, time, residues

    # 7-channel frame representation: either the raw tensor_7 encoding or the
    # offsets of every frame relative to the first frame of the trajectory.
    if not args.no_offsets:
        offsets = get_offsets(rigids[:, 0:1], rigids)
    else:
        offsets = rigids.to_tensor_7()

    # Flip the sign of the quaternion part (first four channels) wherever its
    # first component is negative, so that component ends up non-negative.
    quat_sign = torch.where(offsets[:, :, :, 0:1] < 0, -1, 1)
    offsets[..., :4] *= quat_sign

    # Only the first three torsions (sin/cos pairs) are modelled.
    batch['torsions'] = batch['torsions'][..., :3, :]
    batch['torsion_mask'] = batch['torsion_mask'][..., :3]

    latents = torch.cat([offsets, batch['torsions'].view(B, T, L, 6)], -1)

    # Per-residue loss masks: 7 frame channels followed by 6 torsion channels.
    frame_loss_mask = batch['mask'].unsqueeze(-1).expand(-1, -1, 7)  # B, L, 7
    torsion_loss_mask = (
        batch['torsion_mask'].unsqueeze(-1).expand(-1, -1, -1, 2).reshape(B, L, 6)
    )
    if args.supervise_all_torsions:
        torsion_loss_mask = torch.ones_like(torsion_loss_mask)
    elif args.supervise_no_torsions:
        torsion_loss_mask = torch.zeros_like(torsion_loss_mask)

    # Repeat the (B, L, 13) mask along a new time axis.
    loss_mask = torch.cat([frame_loss_mask, torsion_loss_mask], -1)
    loss_mask = loss_mask.unsqueeze(1).expand(-1, T, -1, -1)

    # Only the first frame is flagged as conditioning.
    cond_mask = torch.zeros(B, T, L, dtype=int, device=offsets.device)
    cond_mask[:, 0] = 1

    # Keep the first four entries along dim -2 of the coordinate array.
    keep_idx = torch.tensor([0, 1, 2, 3], device=batch['arr'].device)
    batch['arr'] = torch.index_select(batch['arr'], dim=-2, index=keep_idx)

    # The mask is all ones, so every position keeps its sequence token; 20 is
    # the fill value for positions where the mask would be zero.
    aatype_mask = torch.ones_like(batch['seqres'])
    aatype = torch.where(aatype_mask.bool(), batch['seqres'], 20)

    model_kwargs = {
        'start_frames': rigids[:, 0],
        'frames': rigids,
        'end_frames': rigids[:, -1],
        'mask': batch['mask'].unsqueeze(1).expand(-1, T, -1),
        'aatype': aatype,
        'x_cond': torch.where(cond_mask.unsqueeze(-1).bool(), latents, 0.0),
        'x_cond_mask': cond_mask,
    }
    return {
        'arr': batch['arr'],
        'rigids': rigids,
        'latents': latents,
        'loss_mask': loss_mask,
        'name': batch['name'],
        'seqres': batch['seqres'],
        'model_kwargs': model_kwargs,
    }
    

def train_ddp(args):
    """Run the distributed training loop on the calling process.

    Builds the datasets, samplers and loaders, wraps the model in DDP, then
    trains for ``args.epochs`` epochs with periodic validation and
    checkpointing. Rank 0 alone writes the config file, checkpoints and
    Weights & Biases logs.

    Args:
        args: Parsed training options (see ``parse_train_args``).

    Raises:
        Exception: Any error raised during training is printed with the rank
            and re-raised after the process group has been torn down.
    """
    setup()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # CUDA device indices are per node, so the device must come from the
    # local rank; the global rank only matches it on a single-node run.
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f'cuda:{local_rank}')

    try:
        if args.wandb and rank == 0:
            # NOTE(review): the literal below is a placeholder credential.
            # A real key should be supplied through WANDB_API_KEY rather than
            # committed to the source.
            wandb.login(key=os.environ.get("WANDB_API_KEY", "xxxxx"))
            wandb.init(
                project="TEMPO",
                name=args.run_name,
                config=args,)

        trainset = MDGenDataset(args, split=args.train_split)

        if args.overfit:
            # Overfitting runs validate on the training data itself.
            valset = trainset
        else:
            valset = MDGenDataset(args, split=args.val_split, repeat=args.val_repeat)

        train_sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank)
        val_sampler = DistributedSampler(valset, num_replicas=world_size, rank=rank)

        train_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            num_workers=0,
            sampler=train_sampler
        )

        val_loader = torch.utils.data.DataLoader(
            valset,
            batch_size=args.batch_size,
            num_workers=0,
            sampler=val_sampler
        )

        model_config = ModelConfig(args)
        model = ProteinSDE(model_config).to(device)
        if args.load_from_ckpt:
            # map_location keeps every rank from materialising the checkpoint
            # on the GPU it was saved from.
            # NOTE: torch.load unpickles its input; only load trusted files.
            pretrained_model = torch.load(args.pretrained_ckpt, map_location=device)
            model.load_state_dict(pretrained_model)
        model = DDP(model, device_ids=[local_rank])

        if rank == 0:
            os.makedirs(args.run_name, exist_ok=True)
            config_filename = os.path.join(args.run_name, 'config.json')
            model_config.save_config(config_filename)

        if args.wandb and rank == 0:
            wandb.watch(model, log_graph=True, log="parameters")

        optimizer = optim.Adam(model.parameters(), lr=args.lr)

        # Tracked as plain floats so comparisons do not keep CUDA tensors alive.
        best_loss = float('inf')
        best_val_loss = float('inf')

        for epoch in range(args.epochs):
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

            noise_scale = get_noise_scale(epoch, args.epochs, initial_scale=0.1, final_scale=1.0, beta=5.0)
            epoch_loss = train_step(train_loader, model, args, epoch, optimizer, noise_scale, device)
            # train_step returns the summed per-batch loss; divide on the
            # tensor so an empty loader yields NaN rather than raising.
            avg_train_loss = float(epoch_loss / len(train_loader))

            # Stays None on epochs without validation so a stale validation
            # loss can never trigger a "best val" checkpoint.
            avg_val_loss = None
            if (epoch + 1) % args.val_freq == 0:
                val_loss = valid_step(val_loader, model, args, epoch, noise_scale, device)
                avg_val_loss = float(val_loss / len(val_loader))

            if rank == 0:
                if (epoch + 1) % args.ckpt_freq == 0:
                    print(f"Saving model checkpoint at epoch {epoch + 1}.")
                    torch.save(model.module.state_dict(), os.path.join(args.run_name, f"checkpoint_epoch_{epoch + 1}.pth"))

                if avg_train_loss <= best_loss:
                    best_loss = avg_train_loss
                    torch.save(model.module.state_dict(), os.path.join(args.run_name, "checkpoint_epoch_best.pth"))

                if avg_val_loss is not None and avg_val_loss <= best_val_loss:
                    best_val_loss = avg_val_loss
                    torch.save(model.module.state_dict(), os.path.join(args.run_name, "checkpoint_epoch_best_val.pth"))

            # Keep all ranks in step while rank 0 writes checkpoints.
            dist.barrier()

    except Exception as e:
        print(f"Error in rank {rank}: {str(e)}")
        raise

    finally:
        cleanup()
        if rank == 0 and args.wandb:
            wandb.finish()


def mean_flat(x, mask):
    """
    Masked mean of ``x`` over all non-batch dimensions.

    Computes ``sum(x * mask) / sum(mask)`` per sample, reducing every
    dimension except the first. A sample whose mask sums to zero yields 0
    instead of NaN (0 / 0), so one fully masked sample cannot poison the
    batch loss.

    Args:
        x: Tensor whose first dimension is the batch dimension.
        mask: Weights multiplied element-wise with ``x``; the code sums it
            over the same dimensions as ``x``, so it is expected to have
            the same shape.

    Returns:
        Tensor with one value per sample.
    """
    reduce_dims = list(range(1, x.dim()))
    total = torch.sum(x * mask, dim=reduce_dims)
    count = torch.sum(mask, dim=reduce_dims)
    # Replace only exactly-zero denominators; the numerator is zero there too.
    safe_count = torch.where(count == 0, torch.ones_like(count), count)
    return total / safe_count




def get_noise_scale(epoch, total_epochs, initial_scale=0.05, final_scale=0.01, beta=5.0):
    """Draw a random noise scale in [0.01, 0.05] for the given epoch.

    A value is sampled from ``Beta(1, beta * (1 - epoch / total_epochs) + 1)``
    using NumPy's global RNG and mapped linearly onto [0.01, 0.05]. The second
    shape parameter shrinks from ``beta + 1`` towards 1 as training
    progresses.

    ``initial_scale`` and ``final_scale`` are accepted but never read: the
    output range is fixed at [0.01, 0.05] whatever values are passed.
    """
    lower, upper = 0.01, 0.05

    fraction_done = epoch / total_epochs
    second_shape = beta * (1.0 - fraction_done) + 1.0
    draw = np.random.beta(1.0, second_shape)

    return lower + (upper - lower) * draw



def _average_gradients(model, world_size):
    """Average the gradients of all trainable parameters across ranks.

    Every rank issues exactly one all-reduce per trainable parameter, even
    when its local gradient is ``None``, so ranks cannot disagree on the
    number of collectives and deadlock.
    """
    if world_size <= 1:
        return
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            # Contribute zeros to keep the collective call sequence aligned.
            dist.all_reduce(torch.zeros_like(param), op=dist.ReduceOp.SUM)
            continue
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        param.grad.div_(world_size)


def train_step(data_loader, model, args, epoch, optimizer, noise_scale, device):
    """Train for one epoch with an autoregressive roll-out over frame windows.

    For every batch the model is given a window of ``args.threshold`` frames
    and asked to predict the next window. The masked squared error against
    the ground truth (plus an optional violation loss) is accumulated over
    all windows. The next input is the ground-truth window with probability
    ``args.sample_ratio`` and the model's own prediction otherwise.

    Args:
        data_loader: Loader yielding dict (or list) batches.
        model: DDP-wrapped model exposing ``module.sde_forward``.
        args: Parsed options; reads ``threshold``, ``num_frames``, ``epochs``,
            ``mode``, ``auxiliary``, ``sample_ratio`` and ``wandb``.
        epoch: Zero-based epoch index (used for logging only).
        optimizer: Optimizer stepped once per batch.
        noise_scale: Scalar passed to the model and used as the time value.
        device: Device that batches and accumulators live on.

    Returns:
        0-dim tensor: the sum over batches of the per-batch loss, averaged
        across ranks. It is not divided by the number of batches.

    Raises:
        ValueError: If a batch is neither a list nor a dict, or if
            ``args.num_frames`` is too small for one input/target window pair.
    """
    model.train()
    epoch_loss = torch.tensor(0.0, device=device)
    aux_loss = torch.tensor(0.0, device=device)
    threshold = args.threshold
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Start index of each input window; the target is the window after it.
    # This depends only on args, so it is computed once per epoch.
    window_starts = range(0, min(args.num_frames - threshold * 2 + 1, args.num_frames + 1), threshold)
    if len(window_starts) == 0:
        raise ValueError("args.num_frames must be at least 2 * args.threshold.")

    if rank == 0:
        iterator = tqdm(data_loader, desc=f'Epoch {epoch + 1}/{args.epochs}', unit='batch')
    else:
        iterator = data_loader

    for batch in iterator:
        if isinstance(batch, list):
            batch = dict(enumerate(batch))
        elif not isinstance(batch, dict):
            raise ValueError("Batch must be a list or a dictionary.")

        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        prep = prep_batch(args, batch)
        model_kwargs = prep['model_kwargs']

        optimizer.zero_grad()

        x = prep['latents'][:, 0:threshold, ...]
        mask = model_kwargs['mask'][:, 0:threshold, ...]
        batch_loss = torch.tensor(0.0, device=device)
        aux_batch_loss = torch.tensor(0.0, device=device)
        start_frames = model_kwargs['start_frames']

        for step in window_starts:
            target = slice(step + threshold, step + 2 * threshold)

            t = torch.full((x.size(0),), noise_scale, device=device)
            # NOTE: calling a method on model.module skips DDP.forward, so
            # DDP's own gradient synchronisation is not armed for this pass;
            # gradients are averaged explicitly after backward() instead.
            pred = model.module.sde_forward(
                    x=x,
                    noise_scale=noise_scale,
                    t=t,
                    mask=mask,
                    start_frames=start_frames,
                    end_frames=None,
                    x_cond=None,
                    x_cond_mask=None,
                    aatype=model_kwargs['aatype'],
                    mode=args.mode,
            )

            gt = prep['latents'][:, target, ...]
            loss_mask = prep['loss_mask'][:, target, ...]

            step_loss = mean_flat(((pred - gt) ** 2), loss_mask)
            batch_loss = batch_loss + step_loss

            if args.auxiliary:
                offsets_pred = pred[..., :7]
                torsions_pred = pred[..., 7:]
                frames_pred = Rigid.from_tensor_7(offsets_pred, normalize_quats=True)
                B, T, L, _ = pred.shape

                # Pad the torsion channels with 8 zeros so they can be viewed
                # as 7 (sin, cos) pairs.
                torsions_pred = F.pad(torsions_pred, (0, 8), mode='constant', value=0)

                atom14_pred = frames_torsions_to_atom14(frames_pred, torsions_pred.view(B, T, L, 7, 2),
                                                model_kwargs['aatype'][:, None].expand(B, T, L))

                atom4_pred = atom14_pred[..., :4, :]
                # NOTE(review): this rebinds `mask`, which is also the model
                # input for the next window. The values presumably match
                # because the mask is repeated over time; confirm.
                mask = model_kwargs['mask'][:, target, ...]
                aux_step_loss = 0.2 * calc_violation_loss(atom4_pred, mask)
                aux_batch_loss = aux_batch_loss + aux_step_loss
                batch_loss = batch_loss + aux_step_loss

            # Scheduled sampling: feed ground truth or the model's own output.
            use_teacher_forcing = random.random() < args.sample_ratio
            if use_teacher_forcing:
                x = gt
                frames_pred = model_kwargs['frames'][:, target]
            else:
                x = pred
                frames_pred = Rigid.from_tensor_7(pred[..., :7], normalize_quats=True)
            # The first frame of the new input window becomes the start frame.
            frames_pred._rots = frames_pred._rots[:, 0, ...]
            frames_pred._trans = frames_pred._trans[:, 0, ...]
            start_frames = frames_pred

        batch_loss = (batch_loss / len(window_starts)).mean()

        batch_loss.backward()
        _average_gradients(model, world_size)
        optimizer.step()

        # Detach so the running total does not hold on to autograd graphs.
        epoch_loss += batch_loss.detach()

        if args.auxiliary:
            aux_batch_loss = (aux_batch_loss / len(window_starts)).mean()
            aux_loss += aux_batch_loss.detach()

    dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM)
    epoch_loss = epoch_loss / world_size

    if args.auxiliary:
        dist.all_reduce(aux_loss, op=dist.ReduceOp.SUM)
        aux_loss = aux_loss / world_size

    if rank == 0:
        print(f"Epoch {epoch + 1}/{args.epochs} completed. Average train Loss: {epoch_loss / len(data_loader):.4f}")
        if args.auxiliary:
            print(f"Epoch {epoch + 1}/{args.epochs} completed. Average auxi Loss: {aux_loss / len(data_loader):.4f}")
        if args.wandb:
            wandb.log({"Training Loss": epoch_loss / len(data_loader), "Epoch": epoch + 1})
            if args.auxiliary:
                wandb.log({"aux": aux_loss / len(data_loader), "Epoch": epoch + 1})

    return epoch_loss

def valid_step(data_loader, model, args, epoch, noise_scale, device):
    """Evaluate the model for one pass over ``data_loader`` without gradients.

    Uses the same windowed roll-out as training: the model sees a window of
    ``args.threshold`` frames and predicts the next one. The following input
    is the ground-truth window with probability ``args.sample_ratio`` and
    the model's prediction otherwise.

    Returns:
        0-dim tensor: the sum over batches of the per-batch loss, averaged
        across ranks. It is not divided by the number of batches.

    Raises:
        ValueError: If a batch is neither a list nor a dict.
    """
    model.eval()
    val_epoch_loss = torch.tensor(0.0, device=device)
    val_aux_loss = torch.tensor(0.0, device=device)
    threshold = args.threshold
    rank = dist.get_rank()

    with torch.no_grad():
        # Only rank 0 shows a progress bar.
        iterator = data_loader
        if rank == 0:
            iterator = tqdm(data_loader, desc=f'Validation Epoch {epoch + 1}', unit='batch')

        for batch in iterator:
            if isinstance(batch, list):
                batch = dict(enumerate(batch))
            elif not isinstance(batch, dict):
                raise ValueError("Batch must be a list or a dictionary.")

            batch = {
                key: (value.to(device) if isinstance(value, torch.Tensor) else value)
                for key, value in batch.items()
            }
            prep = prep_batch(args, batch)
            model_kwargs = prep['model_kwargs']

            # Start index of every input window; its target is the next window.
            window_starts = torch.arange(
                0, min(args.num_frames - threshold * 2 + 1, args.num_frames + 1), threshold
            ).tolist()
            n_windows = len(window_starts)

            batch_loss = torch.tensor(0.0, device=device)
            aux_batch_loss = torch.tensor(0.0, device=device)

            x = prep['latents'][:, 0:threshold, ...]
            mask = model_kwargs['mask'][:, 0:threshold, ...]
            start_frames = model_kwargs['start_frames']

            for step in window_starts:
                nxt = slice(step + threshold, step + 2 * threshold)

                t = torch.full((x.size(0),), noise_scale).to(device)
                pred = model.module.sde_forward(
                    x=x,
                    noise_scale=noise_scale,
                    t=t,
                    mask=mask,
                    start_frames=start_frames,
                    end_frames=None,
                    x_cond=None,
                    x_cond_mask=None,
                    aatype=model_kwargs['aatype'],
                    mode=args.mode,
                )

                gt = prep['latents'][:, nxt, ...]
                loss_mask = prep['loss_mask'][:, nxt, ...]
                batch_loss = batch_loss + mean_flat((pred - gt) ** 2, loss_mask)

                if args.auxiliary:
                    B, T, L, _ = pred.shape
                    frames_pred = Rigid.from_tensor_7(pred[..., :7], normalize_quats=True)
                    # Pad the torsion channels with 8 zeros so they can be
                    # viewed as 7 (sin, cos) pairs.
                    torsions_pred = F.pad(pred[..., 7:], (0, 8), mode='constant', value=0)

                    atom14_pred = frames_torsions_to_atom14(
                        frames_pred,
                        torsions_pred.view(B, T, L, 7, 2),
                        model_kwargs['aatype'][:, None].expand(B, T, L),
                    )
                    atom4_pred = atom14_pred[..., :4, :]

                    # `mask` is rebound here and is therefore also the mask
                    # passed to the model for the next window.
                    mask = model_kwargs['mask'][:, nxt, ...]
                    aux_step_loss = 0.2 * calc_violation_loss(atom4_pred, mask)
                    aux_batch_loss = aux_batch_loss + aux_step_loss
                    batch_loss = batch_loss + aux_step_loss

                # Choose the next input window and the frames it starts from.
                if random.random() < args.sample_ratio:
                    x = gt
                    next_frames = model_kwargs['frames'][:, nxt]
                else:
                    x = pred
                    next_frames = Rigid.from_tensor_7(pred[..., :7], normalize_quats=True)
                # Keep only the first frame of the window as the start frame.
                next_frames._rots = next_frames._rots[:, 0, ...]
                next_frames._trans = next_frames._trans[:, 0, ...]
                start_frames = next_frames

            val_epoch_loss += (batch_loss / n_windows).mean()

            if args.auxiliary:
                val_aux_loss += (aux_batch_loss / n_windows).mean()

        # Sum the per-rank totals, then average over ranks.
        dist.all_reduce(val_epoch_loss, op=dist.ReduceOp.SUM)
        val_epoch_loss = val_epoch_loss / dist.get_world_size()

        if args.auxiliary:
            dist.all_reduce(val_aux_loss, op=dist.ReduceOp.SUM)
            val_aux_loss = val_aux_loss / dist.get_world_size()

        if rank == 0:
            avg_val_loss = val_epoch_loss / len(data_loader)
            print(f"Epoch {epoch+1}/{args.epochs}, Validation Loss: {avg_val_loss:.4f}")
            if args.auxiliary:
                print(f"Epoch {epoch+1}/{args.epochs}, Validation aux Loss: {val_aux_loss / len(data_loader):.4f}")

            if args.wandb:
                wandb.log({"Validation Loss": avg_val_loss, "Epoch": epoch + 1})
                if args.auxiliary:
                    wandb.log({"Validation aux Loss": val_aux_loss / len(data_loader), "Epoch": epoch + 1})

    return val_epoch_loss

                
        
if __name__ == "__main__":
    # Script entry point: parse the CLI options, record this process's local
    # rank (taken from the LOCAL_RANK environment variable) on them, and
    # start distributed training.
    args = parse_train_args()
    setattr(args, "local_rank", int(os.environ["LOCAL_RANK"]))
    train_ddp(args)