import torch
import torch.nn.functional as F
from torch.linalg import matrix_exp, eigvalsh
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

import pandas as pd
import numpy as np

def data_to_device(esm_if_batch, gt_protein, device):
    """Move the tensors of one batch onto ``device``.

    Args:
        esm_if_batch: 5-item batch laid out as
            ``(coords, confidence, strs, tokens, padding_mask)``; the two
            middle entries are ignored.
        gt_protein: mapping whose values expose ``.to(device)``. It is
            updated in place and the same object is returned.
        device: target device handed to every ``.to`` call.

    Returns:
        Tuple ``(coords, confidence, padding_mask, gt_protein)`` with every
        tensor on ``device``.
    """
    coords, confidence, _strs, _tokens, padding_mask = esm_if_batch

    coords = coords.to(device)
    confidence = confidence.to(device)
    padding_mask = padding_mask.to(device)

    # Only existing keys are re-assigned, so the mapping never changes size.
    for name in tuple(gt_protein):
        gt_protein[name] = gt_protein[name].to(device)

    return coords, confidence, padding_mask, gt_protein


def get_scheduler(args, optimizer):
    """Build the learning-rate scheduler selected by ``args``.

    Args:
        args: namespace providing ``opt_constant_only``, ``warmup_epochs``
            and ``total_epochs``.
        optimizer: optimizer whose learning rate is scheduled.

    Returns:
        A ``ConstantLR`` with factor 1 when ``args.opt_constant_only`` is
        truthy; otherwise a ``SequentialLR`` that ramps the rate linearly
        from 0.1x to 1x over ``warmup_epochs`` steps and then linearly back
        down to 0.1x over the remaining ``total_epochs - warmup_epochs``.
    """
    if args.opt_constant_only:
        # set constant scheduler
        return torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=1., total_iters=args.total_epochs
        )

    # The warm-up scheduler is created before the decay one on purpose:
    # each constructor touches the optimizer's learning rate.
    warmup = LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=args.warmup_epochs,
    )
    decay = LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.1,
        total_iters=(args.total_epochs - args.warmup_epochs),
    )
    return SequentialLR(
        optimizer,
        schedulers=[warmup, decay],
        milestones=[args.warmup_epochs],
    )

def _matrix_sqrt(x: torch.Tensor) -> torch.Tensor:
    """
    Compute ``U @ diag(sqrt(S)) @ Vh`` from the SVD ``x = U @ diag(S) @ Vh``.

    For a symmetric positive semi-definite ``x`` the left and right singular
    vectors coincide, so the result is the principal matrix square root of
    ``x``. For a non-symmetric ``x`` the result is generally NOT a square
    root (squaring it does not give back ``x``), so callers must only pass
    symmetric PSD matrices when they need a true square root.

    Args:
        x: square matrix, or a batch of square matrices of shape
            ``(..., n, n)``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # torch.linalg.svd returns the already-transposed right factor (Vh),
    # so it is used as-is, without a further transpose.
    U, S, Vh = torch.linalg.svd(x)
    # Scaling the columns of U by sqrt(S) equals U @ diag(sqrt(S)) and, unlike
    # torch.diag, broadcasts over leading batch dimensions.
    return (U * torch.sqrt(S).unsqueeze(-2)) @ Vh

def get_fid(generated_latents, real_latents):
    """Return the Frechet distance between sequence-pooled latent sets.

    Each latent is averaged over the sequence axis, a Gaussian (mean and
    covariance over the batch axis) is fitted to each set, and the distance

        ||mu_g - mu_r||^2 + Tr(S_g) + Tr(S_r) - 2 * Tr((S_g S_r)^(1/2))

    is returned.

    Args:
        generated_latents: tensor of shape ``[B, seq_len, D]``.
        real_latents: tensor of shape ``[B', seq_len', D]``.

    Returns:
        The distance as a Python float. If a covariance is not finite (for
        example when a set holds a single sample) or the result is NaN, a
        message is printed and ``0.0`` is returned.
    """
    # [B, seq_len, D] -> [B, D]
    gen = generated_latents.mean(dim=1)        # [B, D]
    real = real_latents.mean(dim=1)            # [B, D]
    mu_gen = torch.mean(gen, dim=0)            # [D]
    mu_real = torch.mean(real, dim=0)          # [D]
    # atleast_2d keeps the D == 1 case (where torch.cov yields a scalar)
    # usable by the matrix routines below.
    sigma_gen = torch.atleast_2d(torch.cov(gen.T))     # [D, D]
    sigma_real = torch.atleast_2d(torch.cov(real.T))   # [D, D]

    # Decompositions may raise on non-finite input, so take the documented
    # fallback before reaching them.
    if not (torch.isfinite(sigma_gen).all() and torch.isfinite(sigma_real).all()):
        print("FID is nan, returning 0.")
        return 0.0

    diff = mu_gen - mu_real
    mean_term = torch.sum(diff ** 2)

    # sigma_gen @ sigma_real is not symmetric in general, so an SVD-based
    # "square root" of the product is not its matrix square root. Its
    # eigenvalues equal those of the symmetric PSD matrix
    # sqrt(S_g) @ S_r @ sqrt(S_g), so the trace of the square root is the sum
    # of the square roots of that matrix's eigenvalues.
    eigvals_gen, eigvecs_gen = torch.linalg.eigh(sigma_gen)
    sqrt_sigma_gen = (eigvecs_gen * eigvals_gen.clamp(min=0).sqrt()) @ eigvecs_gen.T
    inner = sqrt_sigma_gen @ sigma_real @ sqrt_sigma_gen
    inner = 0.5 * (inner + inner.T)  # remove round-off asymmetry
    # Clamp tiny negative eigenvalues caused by round-off before the sqrt.
    trace_sqrt = torch.linalg.eigvalsh(inner).clamp(min=0).sqrt().sum()

    fid = mean_term + torch.trace(sigma_gen) + torch.trace(sigma_real) - 2 * trace_sqrt
    if torch.isnan(fid):
        print("FID is nan, returning 0.")
        return 0.0
    return fid.item()

def get_pos_wise_fid(generated_latents: torch.Tensor, real_latents: torch.Tensor) -> float:
    """Return the Frechet distance averaged over sequence positions.

    For every position a Gaussian (mean and covariance over the batch axis)
    is fitted to the generated and to the real latents at that position, and

        ||mu_g - mu_r||^2 + Tr(S_g) + Tr(S_r) - 2 * Tr((S_g S_r)^(1/2))

    is evaluated. Positions whose covariance is not finite or whose distance
    is NaN are reported and left out of the average.

    Args:
        generated_latents: tensor of shape ``[B, seq_len, D]``.
        real_latents: tensor of shape ``[B', seq_len, D]``.

    Returns:
        Mean distance over the positions that could be evaluated, or ``0.0``
        when no position could be evaluated (including ``seq_len == 0``).
    """
    _, seq_len, _ = generated_latents.shape

    fid_total = 0.0
    num_valid = 0
    for position in range(seq_len):
        gen = generated_latents[:, position, :]
        real = real_latents[:, position, :]

        mu_gen = torch.mean(gen, dim=0)                             # [D]
        mu_real = torch.mean(real, dim=0)                           # [D]
        # atleast_2d keeps the D == 1 case (scalar covariance) usable below.
        sigma_gen = torch.atleast_2d(torch.cov(gen.T))              # [D, D]
        sigma_real = torch.atleast_2d(torch.cov(real.T))            # [D, D]

        # Decompositions may raise on non-finite input; skip such positions.
        if not (torch.isfinite(sigma_gen).all() and torch.isfinite(sigma_real).all()):
            print(f"FID is nan at position: {position}, Skipping this position.")
            continue

        diff = mu_gen - mu_real
        mean_term = torch.sum(diff ** 2)

        # sigma_gen @ sigma_real is not symmetric in general, so its matrix
        # square root cannot be obtained from an SVD. Use the symmetric PSD
        # matrix sqrt(S_g) @ S_r @ sqrt(S_g), which has the same eigenvalues.
        eigvals_gen, eigvecs_gen = torch.linalg.eigh(sigma_gen)
        sqrt_sigma_gen = (eigvecs_gen * eigvals_gen.clamp(min=0).sqrt()) @ eigvecs_gen.T
        inner = sqrt_sigma_gen @ sigma_real @ sqrt_sigma_gen
        inner = 0.5 * (inner + inner.T)  # remove round-off asymmetry
        trace_sqrt = torch.linalg.eigvalsh(inner).clamp(min=0).sqrt().sum()

        fid = mean_term + torch.trace(sigma_gen) + torch.trace(sigma_real) - 2 * trace_sqrt

        if torch.isnan(fid):
            print(f"FID is nan at position: {position}, Skipping this position.")
            continue
        fid_total += fid.item()
        num_valid += 1

    # Average over evaluated positions only, so skipped positions do not
    # silently count as a perfect score of zero.
    if num_valid == 0:
        return 0.0
    return fid_total / num_valid

def eval_struct_embeds(args, struct_embeds):
    """Score structure embeddings against stored reference latents.

    Reference proteins are selected from the CSV at
    ``args.recon_quality_file``: rows whose ``seq_len`` equals the sequence
    length of ``struct_embeds`` and that satisfy ``rmsd < 1`` and
    ``tm > 0.9``. For each selected ``file_name`` the ``latent`` array is
    read from ``{args.latent_dir}/{file_name}.npz``.

    Args:
        args: namespace providing ``recon_quality_file`` and ``latent_dir``.
        struct_embeds: 3-D tensor; its second dimension is used as the
            sequence length for selecting references.

    Returns:
        Tuple ``(fid, pos_wise_fid)`` as computed by ``get_fid`` and
        ``get_pos_wise_fid``.

    Raises:
        RuntimeError: if no reference protein passes the filter.
    """
    _, prot_len, _ = struct_embeds.shape

    prot_df = pd.read_csv(args.recon_quality_file)
    prot_df = prot_df[(prot_df['seq_len'] == prot_len) & (prot_df['rmsd'] < 1) & (prot_df['tm'] > 0.9)]
    prots = prot_df['file_name'].values

    if len(prots) == 0:
        raise RuntimeError(
            f"No reference proteins of length {prot_len} with rmsd < 1 and tm > 0.9 "
            f"found in {args.recon_quality_file}"
        )

    real_struct_embeds = []
    for prot in prots:
        # The context manager closes the .npz archive once the array is read.
        with np.load(f'{args.latent_dir}/{prot}.npz') as prot_chain_dict:
            # torch.tensor copies the array, so the in-place scaling is safe.
            # NOTE(review): 0.1875 is presumably the latent scale factor that
            # struct_embeds already carries - confirm against the model code.
            real_struct_embeds.append(torch.tensor(prot_chain_dict['latent']).mul_(0.1875))
    # Both sets must share a device for the statistics below; following
    # struct_embeds also works on CPU-only runs.
    real_struct_embeds = torch.stack(real_struct_embeds).to(struct_embeds.device)

    fid = get_fid(struct_embeds, real_struct_embeds)
    pos_wise_fid = get_pos_wise_fid(struct_embeds, real_struct_embeds)

    return fid, pos_wise_fid
