import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

from libs.class_id_map import get_id2class_map
from libs.metric import AverageMeter, BoundaryScoreMeter, ScoreMeter, IOUMeter
from libs.postprocess import PostProcessor
from tqdm import tqdm
from prompt.tools import (segment_video_labels, gen_label, gen_label_split,
                          generate_segment_features,generate_split_features,
                          create_logits, split_feature, split_gt, split_gt_feature, split_mixed_class)
from prompt.text_prompt import text_prompt_for_clip
from libs.loss_fn.curve import CurvatureLoss
from libs.loss_fn.cross import TSConstraintLoss, RobustTSConstraint, KLTSConstraint

from libs.tsne import plot_tsne
import copy
import pandas as pd
import json

# Shared curvature estimator; `train` calls its `curvature_estimation` method on
# the model's output features when building the cross-constraint loss.
curve_generator = CurvatureLoss(q=10)

# NOTE(review): `perturb_eval`, `examplers` and `examplers_batch` are not read
# or assigned by the code visible in this section — confirm whether other
# modules still rely on them.
perturb_eval = False
examplers = None
examplers_batch = 400
# `train` only adds the cross-constraint loss when its `batch_cnt` argument is
# greater than or equal to this threshold (0 means the loss is always added).
curvature_batch = 0


def train(
    train_loader: DataLoader,
    model: nn.Module,
    model_text: nn.Module,
    class_text_list,
    joint_text_list,
    criterion_cls: nn.Module,
    criterion_bound: nn.Module,
    criterion_contrast: nn.Module,
    lambda_bound_loss: float,
    optimizer: optim.Optimizer,
    dataset_name,
    device, output_device,
    criterion_curve = None,
    batch_cnt = 0,
    alpha = 1.0,
    beta = 0.5,
    cross_criterion = TSConstraintLoss(),
    scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
) -> float:
    """Run one pass over ``train_loader`` and return the average total loss.

    For every batch the total loss is the sum of:

    * the classification loss (averaged over stages when ``output_cls`` is a
      list),
    * ``lambda_bound_loss`` times the boundary loss (same averaging),
    * ``0.8`` times the action-text contrastive loss per feature stage,
    * ``0.5`` times the clip-text contrastive loss,
    * ``model.div_loss`` when the model exposes a non-``None`` value,
    * the cross-constraint loss between estimated curvature and predicted
      boundaries, when ``batch_cnt >= curvature_batch``.

    Args:
        train_loader: Yields dicts with the keys ``"feature"``, ``"label"``,
            ``"boundary"`` and ``"mask"``.
        model: Segmentation model.  When it has no ``text_token`` attribute
            (or it is ``None``) the model is called with the encoded joint
            text and must return five values; otherwise it is called with
            ``None`` and must return six, the last being an extra loss term.
        model_text: Text encoder applied to tokenised prompts.
        class_text_list: Indexable by class id; each entry is a token tensor.
        joint_text_list: Token tensor for the joint prompts.
        criterion_cls: Called as ``criterion_cls(pred, label, feature)``.
        criterion_bound: Called as ``criterion_bound(pred, boundary, mask)``.
        criterion_contrast: Called as ``criterion_contrast(logits, target)``.
        lambda_bound_loss: Weight of the boundary loss.
        optimizer: Stepped once per batch.
        dataset_name: Forwarded to ``text_prompt_for_clip``.
        device: Unused; kept for interface compatibility.
        output_device: Device every tensor is moved to.
        criterion_curve: Unused; kept for interface compatibility.
        batch_cnt: Compared against the module-level ``curvature_batch`` to
            decide whether the cross-constraint loss is added.
        alpha: Unused; kept for interface compatibility.
        beta: Unused; kept for interface compatibility.
        cross_criterion: Called as ``cross_criterion(curve, bound, mask, b)``.
            The default instance is created once, when this function is
            defined, and is therefore shared by all calls that rely on it.
        scheduler: Optional LR scheduler, stepped once after the whole pass.

    Returns:
        ``losses.avg`` — the running average of the total loss.
    """
    losses = AverageMeter("Loss", ":.4e")
    cross_losses = AverageMeter("Cross Loss", ":.4e")
    div_loss = AverageMeter("Div Loss", ":.4e")

    # switch training mode
    model.train()

    # The prompt tokens do not change between batches; move them once.
    joint_text_list = joint_text_list.to(output_device)

    for sample in tqdm(train_loader):
        loss = 0.0
        x = sample["feature"].to(output_device)
        t = sample["label"].to(output_device)
        b = sample["boundary"].to(output_device)
        mask = sample["mask"].to(output_device)

        optimizer.zero_grad()

        batch_size = x.shape[0]
        joint_text_embedding = None
        # Checked on every step because the attribute lives on the model and
        # may change between forward passes.
        if getattr(model, "text_token", None) is None:
            joint_text_embedding = model_text(joint_text_list).float()
            (output_cls, output_bound, output_feature,
             output_feature_split, logit_scale) = model(x, mask, joint_text_embedding)
        else:
            (output_cls, output_bound, output_feature,
             output_feature_split, logit_scale, dloss) = model(x, mask, joint_text_embedding)
            # NOTE(review): `model.div_loss` is also added further down;
            # confirm the two terms are not the same quantity counted twice.
            loss += dloss

        # Action-text pairs: one text prompt per ground-truth segment.
        t_segment = segment_video_labels(t)
        label = [i[0] for seg in t_segment for i in seg]
        label_g = gen_label(label)

        texts = torch.cat(
            [class_text_list[single_label].unsqueeze(dim=0) for single_label in label]
        ).to(output_device)
        text_embedding = model_text(texts).float()

        # Segment-level features are only built when the model returns one
        # feature map per stage (a list); otherwise this stays empty and the
        # action-text contrastive loss below is skipped.
        action_embeddings = []
        if isinstance(output_feature, list):
            action_embeddings = [
                generate_segment_features(stage_feature, t_segment, output_device)
                for stage_feature in output_feature
            ]

        # Clip-text pairs.
        gt_split, feature_split = split_mixed_class(t_segment, 2)
        feature_split_embedding = generate_split_features(
            output_feature_split, feature_split, output_device
        )
        text_split = text_prompt_for_clip(gt_split, dataset_name, "simple").to(output_device)
        text_split_embedding = model_text(text_split).float()
        label_split_g = gen_label_split(gt_split)

        # Action segmentation loss
        if isinstance(output_cls, list):
            n = len(output_cls)
            for out in output_cls:
                loss += criterion_cls(out, t, x) / n
        else:
            loss += criterion_cls(output_cls, t, x)

        # boundary regression loss
        if isinstance(output_bound, list):
            n = len(output_bound)
            for out in output_bound:
                loss += lambda_bound_loss * criterion_bound(out, b, mask) / n
        else:
            loss += lambda_bound_loss * criterion_bound(output_bound, b, mask)

        # action-text contrastive loss
        for action_embedding in action_embeddings:
            logits_per_image, logits_per_text = create_logits(
                action_embedding, text_embedding, logit_scale[0]
            )
            # Use the dtype of the embedding actually being compared.
            ground_truth = torch.tensor(
                label_g, dtype=action_embedding.dtype, device=output_device
            )

            loss_imgs = criterion_contrast(logits_per_image, ground_truth)
            loss_texts = criterion_contrast(logits_per_text, ground_truth)

            loss += 0.8 * ((loss_imgs + loss_texts) / 2)

        # clip-text contrastive loss
        logits_per_image, logits_per_text = create_logits(
            feature_split_embedding, text_split_embedding, logit_scale[1]
        )
        ground_truth = torch.tensor(
            label_split_g, dtype=feature_split_embedding.dtype, device=output_device
        )

        loss_imgs = criterion_contrast(logits_per_image, ground_truth)
        loss_texts = criterion_contrast(logits_per_text, ground_truth)

        loss += 0.5 * ((loss_imgs + loss_texts) / 2)

        # div loss (optional model attribute)
        model_div_loss = getattr(model, "div_loss", None)
        if model_div_loss is not None:
            loss += model_div_loss
            # Record a plain float so the meter does not hold on to the
            # tensor (and whatever autograd history is attached to it).
            div_loss.update(float(model_div_loss), batch_size)

        # CROSS CONSTRAINT LOSS (uses the ground-truth boundary `b`)
        if batch_cnt >= curvature_batch:
            loss_before_cross = loss.item()
            stage_features = output_feature if isinstance(output_feature, list) else [output_feature]
            stage_bounds = output_bound if isinstance(output_bound, list) else [output_bound]
            # Average over every (feature stage, boundary stage) pair.
            n_pairs = len(stage_features) * len(stage_bounds)
            for stage_feature in stage_features:
                _, curve, _ = curve_generator.curvature_estimation(stage_feature)
                for stage_bound in stage_bounds:
                    loss += cross_criterion(curve.unsqueeze(1), stage_bound, mask, b) / n_pairs
            cross_losses.update(loss.item() - loss_before_cross, batch_size)

        losses.update(loss.item(), batch_size)

        loss.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step()
        print("Learning rate: ", optimizer.param_groups[0]['lr'])
    print("cross loss: ", cross_losses.avg)
    print("div loss: ", div_loss.avg)
    return losses.avg


def validate(
    val_loader: DataLoader,
    model: nn.Module,
    model_text: nn.Module,
    joint_text_list,
    criterion_cls: nn.Module,
    criterion_bound: nn.Module,
    lambda_bound_loss: float,
    device,output_device,
    dataset: str,
    dataset_dir: str,
    iou_thresholds: Tuple[float, ...],
    boundary_th: float,
    tolerance: int,
    refinement_method: Optional[str] = None
) -> Tuple[float, float, float, float, float, float, float, float]:
    """Evaluate ``model`` on ``val_loader`` and return the loss and scores.

    Args:
        val_loader: Yields dicts with the keys ``"feature"``, ``"label"``,
            ``"boundary"`` and ``"mask"``.
        model: Segmentation model; in eval mode it is expected to return four
            values, of which only the first two (class and boundary
            predictions) are used here.
        model_text: Text encoder applied to ``joint_text_list``.
        joint_text_list: Token tensor for the joint prompts.
        criterion_cls: Called as ``criterion_cls(pred, label, feature)``.
        criterion_bound: Called as ``criterion_bound(pred, boundary, mask)``.
        lambda_bound_loss: Weight of the boundary loss.
        device: Unused; kept for interface compatibility.
        output_device: Device every tensor is moved to.
        dataset: Dataset name, forwarded to ``get_id2class_map``.
        dataset_dir: Dataset root, forwarded to ``get_id2class_map``.
        iou_thresholds: Forwarded to the ``ScoreMeter`` instances.
        boundary_th: Boundary threshold for post-processing and scoring.
        tolerance: Forwarded to ``BoundaryScoreMeter``.
        refinement_method: Forwarded to ``PostProcessor``.

    Returns:
        An 8-tuple ``(loss, cls_acc, edit_score, segment_f1s, bound_acc,
        precision, recall, bound_f1s)``.  The class scores are the ones
        obtained *after* boundary refinement.  The element types are whatever
        the meters' ``get_scores`` return — presumably floats, with the F1
        entries possibly being sequences; verify against ``libs.metric``.
    """
    losses = AverageMeter("Loss", ":.4e")
    postprocessor = PostProcessor(refinement_method, boundary_th)

    # Load the class map once and share it between both class-score meters.
    id2class_map = get_id2class_map(dataset, dataset_dir=dataset_dir)
    scores_cls = ScoreMeter(
        id2class_map=id2class_map,
        iou_thresholds=iou_thresholds,
    )
    scores_bound = BoundaryScoreMeter(
        tolerance=tolerance, boundary_threshold=boundary_th
    )
    scores_after_refinement = ScoreMeter(
        id2class_map=id2class_map,
        iou_thresholds=iou_thresholds,
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        # The prompt tokens do not change between batches; move them once.
        joint_text_list = joint_text_list.to(output_device)

        for sample in tqdm(val_loader):
            x = sample["feature"].to(output_device)
            t = sample["label"].to(output_device)
            b = sample["boundary"].to(output_device)
            mask = sample["mask"].to(output_device)

            batch_size = x.shape[0]

            # NOTE(review): if `model_text` is deterministic and `model` does
            # not modify the embedding in place, this encoding could be
            # computed once before the loop.
            joint_text_embedding = model_text(joint_text_list).float()

            # compute output and loss
            output_cls, output_bound, _, _ = model(x, mask, joint_text_embedding)

            loss = criterion_cls(output_cls, t, x)
            loss += lambda_bound_loss * criterion_bound(output_bound, b, mask)

            # record loss
            losses.update(loss.item(), batch_size)

            # move everything to NumPy for post-processing and scoring
            output_cls = output_cls.detach().cpu().numpy()
            output_bound = output_bound.detach().cpu().numpy()
            t = t.detach().cpu().numpy()
            b = b.detach().cpu().numpy()
            mask = mask.detach().cpu().numpy()

            refined_output_cls = postprocessor(
                output_cls, boundaries=output_bound, masks=mask
            )

            # update scores
            scores_cls.update(output_cls, t, output_bound, mask)  # without the boundary branch
            scores_bound.update(output_bound, b, mask)
            scores_after_refinement.update(refined_output_cls, t)  # with the boundary branch

    # NOTE(review): the before-refinement scores are accumulated and queried
    # but never returned; the call is kept to preserve existing behaviour.
    scores_cls.get_scores()
    cls_acc, edit_score, segment_f1s = scores_after_refinement.get_scores()
    bound_acc, precision, recall, bound_f1s = scores_bound.get_scores()

    return (
        losses.avg,
        cls_acc,
        edit_score,
        segment_f1s,
        bound_acc,
        precision,
        recall,
        bound_f1s,
    )


from libs.measures import curvature_estimation
import matplotlib.pyplot as plt
import os
import numpy as np

def evaluate(
    val_loader: DataLoader,
    model: nn.Module,
    model_text,
    joint_text_list,
    device: str,
    boundary_th: float,
    dataset: str,
    dataset_dir: str,
    iou_thresholds: Tuple[float, ...],
    tolerance: float,
    result_path: str,
    config: str,
    refinement_method: Optional[str] = None,
    save_curvature: bool = False,
    save_stne: bool = False,
    action_text_list = None,
) -> None:
    """Score ``model`` on ``val_loader`` and write the results to ``result_path``.

    The function prints the scores before refinement, the boundary scores and
    the scores after refinement, and writes these files under ``result_path``:
    ``evaluation_config.txt``, ``test_as_before_refine.csv``,
    ``test_c_matrix_before_refinement.csv``, ``test_br.csv``,
    ``test_as_after_majority_vote.csv`` and
    ``test_c_matrix_after_majority_vote.csv``.  A ``tsne`` sub-directory is
    always created; it is only populated when ``save_stne`` is true.

    Args:
        val_loader: Yields dicts with the keys ``"feature"``, ``"label"``,
            ``"boundary"`` and ``"mask"``.
        model: Segmentation model; in eval mode it is expected to return
            ``(output_cls, output_bound, output_feature, bound_feature)``.
        model_text: Text encoder applied to ``joint_text_list``.
        joint_text_list: Token tensor for the joint prompts.
        device: Device every tensor is moved to.
        boundary_th: Boundary threshold for post-processing and scoring.
        dataset: Dataset name, forwarded to ``get_id2class_map``.
        dataset_dir: Dataset root, forwarded to ``get_id2class_map``.
        iou_thresholds: Forwarded to the ``ScoreMeter`` instances.
        tolerance: Forwarded to ``BoundaryScoreMeter``.
        result_path: Output directory.
        config: Written to ``evaluation_config.txt`` via ``str(config)``.
        refinement_method: Forwarded to ``PostProcessor``.
        save_curvature: Unused; kept for interface compatibility.
        save_stne: When true, a t-SNE plot of the output features is saved
            for every batch.
        action_text_list: Unused; kept for interface compatibility.
    """
    plt.switch_backend('Agg')  # Use Agg backend for headless environments
    id2class_map = get_id2class_map(dataset, dataset_dir=dataset_dir)
    postprocessor = PostProcessor(refinement_method, boundary_th)
    scores_before_refinement = ScoreMeter(
        id2class_map=id2class_map,
        iou_thresholds=iou_thresholds,
    )
    scores_bound = BoundaryScoreMeter(
        tolerance=tolerance, boundary_threshold=boundary_th
    )
    scores_after_refinement = ScoreMeter(
        id2class_map=id2class_map,
        iou_thresholds=iou_thresholds,
    )
    IOU_before_refinement = IOUMeter(
        id2class_map=id2class_map,
        max_len=9000,
    )

    model.eval()

    # Creating the sub-directory also guarantees `result_path` itself exists
    # before the result files are written below.
    save_tsne_path = os.path.join(result_path, "tsne")
    os.makedirs(save_tsne_path, exist_ok=True)

    # Only used for its `_remap` helper; one instance serves every batch.
    curvature_remapper = TSConstraintLoss()

    with torch.no_grad():
        # The prompt tokens do not change between batches; move them once.
        joint_text_list = joint_text_list.to(device)

        for batch_idx, sample in enumerate(tqdm(val_loader)):
            x = sample["feature"].to(device)
            t = sample["label"].to(device)
            b = sample["boundary"].to(device)
            mask = sample["mask"].to(device)

            joint_text_embedding = model_text(joint_text_list).float()
            output_cls, output_bound, output_feature, _ = model(x, mask, joint_text_embedding)

            if save_stne:
                # Flatten (N, C, T) features to one row per frame.
                n_batch, n_channels, n_frames = output_feature.shape
                frame_features = output_feature.transpose(1, 2).reshape(
                    n_batch * n_frames, n_channels
                )
                frame_labels = t.reshape(n_batch * n_frames)
                plot_tsne(
                    frame_features,
                    frame_labels,
                    os.path.join(save_tsne_path, f"tsne_action_{batch_idx}.png"),
                )

            output_cls = output_cls.detach().cpu().numpy()
            output_bound = output_bound.detach().cpu().numpy()
            t_np = t.detach().cpu().numpy()
            b_np = b.detach().cpu().numpy()
            mask_np = mask.detach().cpu().numpy()
            feature_np = output_feature.detach().cpu().numpy().transpose(0, 2, 1)

            # NOTE(review): curvature is estimated for the first sample of the
            # batch only, exactly as before; confirm `IOUMeter.update` copes
            # with batches larger than one.
            curv_np = []
            if x.size(0) > 0:
                seq_len = int(mask_np[0].sum())  # number of valid frames
                cur_feature = torch.from_numpy(feature_np[0][:seq_len])
                _, curv_reciprocal, _ = curvature_estimation(cur_feature, 10, 0)
                curv_reciprocal = (
                    curvature_remapper._remap(torch.from_numpy(curv_reciprocal))
                    .detach()
                    .cpu()
                    .numpy()
                )
                curv_np.append(curv_reciprocal)
            curv_np = np.array(curv_np)

            refined_output_cls = postprocessor(output_cls, boundaries=output_bound, masks=mask_np)
            scores_before_refinement.update(output_cls, t_np)
            scores_bound.update(output_bound, b_np, mask_np)
            scores_after_refinement.update(refined_output_cls, t_np)
            IOU_before_refinement.update(output_cls, t_np, output_bound, curv_np)

    with open(os.path.join(result_path, "evaluation_config.txt"), "w", encoding="utf-8") as f:
        f.write(str(config))
    print("Before refinement:", scores_before_refinement.get_scores())
    print("Boundary scores:", scores_bound.get_scores())
    print("After refinement:", scores_after_refinement.get_scores())

    scores_before_refinement.save_scores(os.path.join(result_path, "test_as_before_refine.csv"))
    scores_before_refinement.save_confusion_matrix(os.path.join(result_path, "test_c_matrix_before_refinement.csv"))
    scores_bound.save_scores(os.path.join(result_path, "test_br.csv"))
    scores_after_refinement.save_scores(os.path.join(result_path, "test_as_after_majority_vote.csv"))
    scores_after_refinement.save_confusion_matrix(os.path.join(result_path, "test_c_matrix_after_majority_vote.csv"))
    
