import torch
from tqdm import tqdm
import os
import sys
from models.model_minimind import MiniMindConfig
from models.model_minimind_final import MiniMindFinal
# NOTE(review): '__file__' is a quoted string literal, so os.path.dirname('__file__')
# is '' and this appends the current working directory, not this script's directory.
# It also executes after the `models` imports above, so it cannot affect how those
# resolve; only the `utils` imports below can benefit from it.
sys.path.append(os.path.abspath(os.path.dirname('__file__')))
from utils.dataloader_Public import create_folder_if_not_exists
import torch.nn.functional as F
import numpy as np
import math
import pandas as pd
import random
import matplotlib.pyplot as plt
from utils.earth_computation import rad_to_deg, deg_to_rad, deg_to_vec, haversine_distance
from utils.metrics import frechet_distance, curvature_calculation
from utils.dataloader_Public import PublicDataset
from torch.utils.data import DataLoader

# Module-wide compute device: the first CUDA device when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

import numpy as np
from math import sqrt

def set_global_seed(seed: int = 42):
    """Seed every RNG this script touches and force deterministic kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU plus all
    CUDA devices). It then makes cuDNN deterministic, disables cuDNN autotuning,
    turns on torch's deterministic-algorithms mode and disables TF32 for CUDA
    matmuls and cuDNN, so repeated runs produce the same numbers.

    Args:
        seed: Value passed to every seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False  # autotuning may pick different kernels run to run
    torch.use_deterministic_algorithms(True)

    # TF32 trades matmul/conv precision for speed; disable it for reproducibility.
    torch.backends.cuda.matmul.allow_tf32 = False
    cudnn.allow_tf32 = False
# With torch.use_deterministic_algorithms(True), PyTorch requires this cuBLAS
# workspace setting (on CUDA >= 10.2) or deterministic matmuls raise. Set it
# without overriding a value the user already exported, then seed at import time.
if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
set_global_seed(42)

def _autoregressive_rollout(model, history, pred_len):
    """Append ``pred_len`` model-predicted points to ``history``, one step at a time.

    At every step the model receives the whole sequence built so far. The window
    grows by one point per step; it is not truncated back to the original
    history length.

    Each appended point is built as follows:
    - Start from the last position of the model's first output.
    - Overwrite its first two channels with the last position of the
      ``displacement`` output plus the previous point's first two channels.
    - Keep only its first four channels.

    Args:
        model: Callable returning a 5-tuple; its first and third items are
            indexed as ``[batch, position, channel]``.
        history: Tensor indexed as ``[batch, position, channel]``.
        pred_len: Number of points to append.

    Returns:
        ``history`` extended along dim 1 by ``pred_len`` predicted points.
    """
    current = history.clone()
    for _ in range(pred_len):
        # Hand the model a copy so any in-place edit inside it cannot alter ``current``.
        new_position, _, displacement, _, _ = model(current.clone())
        next_point = new_position[:, -1, :].clone()
        next_point[:, :2] = displacement[:, -1, :] + current[:, -1, :2]
        current = torch.cat((current, next_point[:, :4].unsqueeze(1)), dim=1)
    return current


def evaluate(data_path, model_sft_path, model_train_path, faiss_index_path, seq_len=128, pred_len=16,
             batch_size=128, stride=7):
    """Roll the model out autoregressively over ``PublicDataset`` and report metrics.

    For every batch, the first ``seq_len`` points of ``input_ids`` are the history.
    The model then predicts ``pred_len`` further points, one per call. The
    first two channels of those predicted points are compared with the matching
    slice of ``output_ids``.

    Args:
        data_path: Data source(s) forwarded unchanged to ``PublicDataset``.
        model_sft_path: Checkpoint path, passed as the third argument of ``MiniMindFinal``.
        model_train_path: Checkpoint path, passed as the second argument of ``MiniMindFinal``.
        faiss_index_path: Path forwarded to ``MiniMindFinal``.
        seq_len: Number of history points fed to the model.
        pred_len: Number of autoregressive prediction steps (must be >= 1).
        batch_size: DataLoader batch size.
        stride: Window stride forwarded to ``PublicDataset``.

    Returns:
        dict with these keys (averages are NaN when ``count`` is 0):
        - ``count``: number of evaluated trajectories.
        - ``avg_loss``: per-trajectory mean over the horizon of the squared
          point error, in the units of ``input_ids`` (presumably radians).
        - ``avg_frechet``: average ``frechet_distance`` computed on
          ``rad_to_deg``-converted tracks.
        - ``avg_curvature``: average mean squared difference of
          ``curvature_calculation`` outputs.

    Raises:
        ValueError: if ``pred_len < 1``, or if the predicted and ground-truth
            slices differ in shape (e.g. ``output_ids`` is shorter than
            ``seq_len - 1 + pred_len``).
    """
    if pred_len < 1:
        raise ValueError(f"pred_len must be >= 1, got {pred_len}")

    print('model initialization...')
    lm_config = MiniMindConfig(flash_attn=False)
    model = MiniMindFinal(lm_config, model_train_path, model_sft_path, faiss_index_path).to(device)
    model.eval()

    dataset = PublicDataset(data_path=data_path, seq_len=seq_len, pred_len=pred_len, stride=stride)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    loss_sum = 0.0       # sum over trajectories of per-trajectory MSE
    frechet_sum = 0.0    # sum over trajectories of Frechet distance
    curvature_sum = 0.0  # batch-size-weighted sum of mean squared curvature differences
    cnt = 0              # number of trajectories seen

    # Pure inference: build no autograd graph for the rollout or the metrics.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch["input_ids"].to(device)
            output_ids = batch["output_ids"].to(device)
            batch_len = input_ids.size(0)
            cnt += batch_len

            rollout = _autoregressive_rollout(model, input_ids[:, :seq_len, :], pred_len)

            # Do not .squeeze() here. A final batch holding a single trajectory, or
            # pred_len == 1, would lose a dimension and break both the shape check
            # and the per-trajectory reductions below.
            pred_input = rollout[:, -pred_len:, :2]
            # Assumes output_ids is input_ids shifted by one step, so index
            # seq_len - 1 aligns with the first predicted point. TODO confirm.
            true_target = output_ids[:, seq_len - 1:seq_len - 1 + pred_len, :2]
            if pred_input.shape != true_target.shape:
                raise ValueError(
                    f"prediction slice {tuple(pred_input.shape)} and ground truth "
                    f"{tuple(true_target.shape)} differ; output_ids may be shorter "
                    f"than seq_len - 1 + pred_len"
                )

            # Squared point error, averaged over the horizon, summed over the batch.
            loss_sum += ((pred_input - true_target) ** 2).sum(dim=2).mean(dim=1).sum().item()
            frechet_sum += frechet_distance(rad_to_deg(pred_input), rad_to_deg(true_target)).sum().item()
            curvature_sq_diff = (curvature_calculation(pred_input) - curvature_calculation(true_target)) ** 2
            # Weight the batch mean by the batch size so that dividing by ``cnt``
            # gives a per-trajectory average, like the two totals above.
            curvature_sum += torch.mean(curvature_sq_diff).item() * batch_len

    results = {
        "count": cnt,
        "avg_loss": loss_sum / cnt if cnt else float('nan'),
        "avg_frechet": frechet_sum / cnt if cnt else float('nan'),
        "avg_curvature": curvature_sum / cnt if cnt else float('nan'),
    }
    print(f"Total {cnt} trajectories")
    print(f"Average Loss: {results['avg_loss']}")
    print(f"Average Frechet Distance: {results['avg_frechet']}")
    print(f"Average Curvature: {results['avg_curvature']}")
    return results

if __name__ == '__main__':
    # CSV files that together form the evaluation set.
    csv_files = [
        'data/210238000.csv',
        'data/210279000.csv',
        'data/356285000.csv',
        'data/414062000.csv',
        'data/414066000.csv',
        'data/636015239.csv',
    ]
    # NOTE(review): evaluate's positional order is
    # (data_path, model_sft_path, model_train_path, faiss_index_path), so the
    # "weights_pretrain" checkpoint binds to model_sft_path and the "weights_sft"
    # checkpoint binds to model_train_path. Confirm this pairing is intended.
    evaluate(
        csv_files,
        "weights_pretrain/830_statedict_0.16575811230219328.pth",
        "weights_sft/28_statedict_light_96.86.pth",
        "enrolled_trajectory.npy",
        seq_len=288,
        pred_len=144,
    )
