import torch
import torch.nn as nn
import torch.nn.init as init 
import torch.nn.functional as F
import json
import os 
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.multiprocessing as mp
import transformers
from transformers import AutoTokenizer, AutoModelForMaskedLM
import random
from modelscope import AutoTokenizer
import gc
import ast
from torch.cuda.amp import autocast, GradScaler
import torch.optim as optim
from datasets import Dataset as finetune_dataset
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer

# CUDA devices used by the different stages of this script:
#   device0 -> stage-1 model (Model1) in the __main__ training loops
#   device1 -> stage-2 model (Mlp2) and the clip tensors it scores
#   device2 -> the Llama text-generation pipeline (moved there in generate_stage2_result)
#   device3 -> the surgicberta masked-LM encoder (loaded in generate_stage2_result)
device0 = torch.device("cuda:4")
device1 = torch.device("cuda:5")
device2 = torch.device("cuda:6")
device3 = torch.device("cuda:7")
# Local snapshot directory of the Llama-3.2-3B checkpoint.
model_id = "/hub/models--meta-llama--Llama-3.2-3B/snapshots/13afe5124825b4f3751f836b40dafda64c1ed062"
# NOTE: the pipeline is built at import time (module-level side effect), in float16 with
# device_map="cpu"; return_full_text=False is passed so outputs exclude the prompt.
pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.float16}, device_map="cpu", return_full_text=False)
class CrossAttentionWithFFN(nn.Module):
    """Cross-attention block followed by a two-layer feed-forward network.

    Both sub-layers use a residual connection, dropout on the sub-layer
    output, and a LayerNorm applied *after* the residual sum (post-norm).

    Args:
        d_model: Embedding width of queries, keys and values.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside the attention module and
            on both sub-layer outputs.
    """

    def __init__(self, d_model=768, num_heads=16, d_ff=2048, dropout=0.1):
        super().__init__()
        # batch_first=True: batched inputs are laid out as (batch, seq, feature).
        self.cross_attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every nn.Linear weight, zeros for its bias."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
            init.zeros_(layer.bias)

    def forward(self, x, encoder_output, src_mask=None):
        """Attend from ``x`` (queries) to ``encoder_output`` (keys and values).

        Args:
            x: Query tensor.
            encoder_output: Tensor used as both key and value.
            src_mask: Optional key padding mask forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Cross-attention sub-layer: residual add, then LayerNorm.
        attended, _weights = self.cross_attention(
            x, encoder_output, encoder_output, key_padding_mask=src_mask
        )
        x = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x
    
class Model1(nn.Module):
    """Stage-1 model: a stack of cross-attention blocks plus a linear scoring head.

    ``forward`` has two modes:

    * boundary mode (``train_w_text is False``): ``fv1`` is split along dim 1
      into two clips (the first 23 positions and the rest) and one boundary
      position is predicted for each clip;
    * text mode (any other ``train_w_text``): ``fv2`` (text features) is used
      as the query and ``fv1`` (vision features) as key/value for the block
      stack, and the refined text features are returned.

    Args:
        num_blocks: Number of ``CrossAttentionWithFFN`` blocks.
        input_shape: Flattened feature size consumed by the scoring head.
    """

    # Positions per clip; candidate boundaries are positions 1 .. 22.
    _FRAMES_PER_CLIP = 23
    # Slope and offset of the sigmoid applied to the raw candidate scores.
    _SIGMOID_SLOPE = 1
    _SIGMOID_OFFSET = 0.5

    def __init__(self, num_blocks=4, input_shape=197*768):
        super().__init__()
        self.mlp1 = nn.ModuleList(
            [CrossAttentionWithFFN() for _ in range(num_blocks)]
        )
        self.linear = nn.Linear(input_shape, 1)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every nn.Linear weight, zeros for its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                init.zeros_(m.bias)

    def _compute_boundary(self, fv, inference):
        """Score every candidate boundary position of one clip.

        For each position ``i`` in 1..22 the feature at ``i`` queries the mean
        of the features before it; the result is flattened and mapped to a
        scalar score by ``self.linear``.

        Args:
            fv: Clip features indexed as ``fv[:, position, :, :]``
                (assumes at least 23 positions on dim 1 - TODO confirm).
            inference: If True return a hard index, otherwise a soft
                (differentiable) expected position.

        Returns:
            ``(boundary, candidate_probs)`` where ``boundary`` is ``[B, 1]``
            and ``candidate_probs`` is ``[B, 22]``.
        """
        batch_size = fv.shape[0]
        n_frames = self._FRAMES_PER_CLIP
        candidate_scores = []
        for i in range(1, n_frames):
            aggregated = torch.mean(fv[:, :i, :, :], dim=1).squeeze(1)
            x = fv[:, i, :, :].squeeze(1)
            for layer in self.mlp1:
                x = layer(x, aggregated)
            candidate_scores.append(self.linear(x.view(batch_size, -1)))  # [B, 1]
        candidate_scores = torch.cat(candidate_scores, dim=1)  # [B, 22]
        candidate_probs = torch.sigmoid(
            self._SIGMOID_SLOPE * (candidate_scores - self._SIGMOID_OFFSET)
        )
        # Normalise across candidates; the epsilon guards against a zero sum.
        norm_probs = candidate_probs / (candidate_probs.sum(dim=1, keepdim=True) + 1e-6)

        if not inference:
            # Soft boundary: expected position under the normalised scores.
            indices = torch.arange(
                1, n_frames, device=fv.device, dtype=candidate_probs.dtype
            ).unsqueeze(0)
            boundary_cont = torch.sum(norm_probs * indices, dim=1, keepdim=True)  # [B, 1]
            return boundary_cont, candidate_probs

        # Hard boundary: the best candidate (1-based) when its normalised
        # probability exceeds 0.5, otherwise the clip end (23).
        max_probs, max_idx = torch.max(norm_probs, dim=1, keepdim=True)  # [B, 1]
        boundary_idx = torch.where(
            max_probs > 0.5,
            max_idx.float() + 1.0,
            torch.full_like(max_idx.float(), float(n_frames)),
        )
        return boundary_idx, candidate_probs

    def forward(self, fv1, fv2=None, inference=False, train_w_text=False):
        """Run boundary prediction or vision-text alignment.

        Args:
            fv1: Vision features. In boundary mode dim 1 holds the positions
                of two concatenated clips.
            fv2: Text features; only used in text mode.
            inference: Boundary mode only - return hard indices instead of
                soft expected positions.
            train_w_text: ``False`` selects boundary mode; anything else
                selects text mode.

        Returns:
            Boundary mode: ``(pred_end0, pred_end1)``, each ``[B, 1]``.
            Text mode: the output of the block stack, shaped like the query.
        """
        if train_w_text is False:
            n_frames = self._FRAMES_PER_CLIP
            pred_end0, _ = self._compute_boundary(fv1[:, :n_frames, :, :], inference)
            pred_end1, _ = self._compute_boundary(fv1[:, n_frames:, :, :], inference)
            return pred_end0, pred_end1

        text_feature = fv2
        # Fix: the original tested ``len(text_feature) == 4``, i.e. the batch
        # size, not the rank. A 4-D text tensor is reduced over its last axis.
        if text_feature.dim() == 4:
            text_feature = text_feature.mean(dim=3)
        x = text_feature
        for layer in self.mlp1:
            x = layer(x, fv1)
        return x
    
class Mlp2(nn.Module):
    """First definition of ``Mlp2`` (default ``input_shape=197*768``).

    NOTE(review): this class defines no ``forward`` and is shadowed by the
    second ``class Mlp2`` declared immediately after it in this file, so the
    name ``Mlp2`` refers to that later definition at runtime. This one looks
    like a leftover - confirm before removing.
    """
    def __init__(self, input_shape=197*768):
        super().__init__()
        # Chained assignment: `S_attention` and `mlp1` are both bound to this stack;
        # `mlp1` is rebound to a fresh stack a few lines below.
        self.S_attention = self.mlp1 = nn.ModuleList(
            [CrossAttentionWithFFN() for _ in range(4)]
        )
        
        self.middle_layer = CrossAttentionWithFFN()
        
        self.mlp1 = nn.ModuleList(
            [CrossAttentionWithFFN() for _ in range(4)]
        )
        # Scoring head: flattened features -> one logit, squashed by the sigmoid.
        self.linear = nn.Linear(input_shape, 1)
        self.sigmoid = nn.Sigmoid()
        self._initialize_weights()


    def _initialize_weights(self):
        # Kaiming-normal init for every nn.Linear weight, zeros for its bias.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                init.zeros_(m.bias)
                
class Mlp2(nn.Module):
    """Stage-2 classifier scoring consecutive clip pairs.

    Each clip is refined by a self-attention stack (``S_attention``), fused
    with the mean clip feature through ``middle_layer``, and each pair of
    neighbouring fused clips is turned into a sigmoid probability by the
    ``mlp1`` stack followed by ``linear``.

    Args:
        input_shape: Flattened feature size consumed by ``linear``.
    """

    def __init__(self, input_shape=1*768):
        super().__init__()
        shared_stack = nn.ModuleList(
            [CrossAttentionWithFFN() for _ in range(4)]
        )
        self.S_attention = shared_stack
        # `mlp1` is first bound to the same stack and replaced further down;
        # the early binding fixes its slot in the sub-module order ahead of
        # `middle_layer`, which is the order the weight init iterates in.
        self.mlp1 = shared_stack

        self.middle_layer = CrossAttentionWithFFN()

        self.mlp1 = nn.ModuleList(
            [CrossAttentionWithFFN() for _ in range(4)]
        )
        self.linear = nn.Linear(input_shape, 1)
        self.sigmoid = nn.Sigmoid()
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every nn.Linear weight, zeros for its bias."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
            init.zeros_(layer.bias)

    def _refine(self, clip):
        """Run ``clip`` through the self-attention stack (query == key == value)."""
        for layer in self.S_attention:
            clip = layer(clip, clip)
        return clip

    def _score_pair(self, reference, candidate, rows):
        """Return the sigmoid score of ``candidate`` attending to ``reference``.

        ``rows`` is the leading size used when flattening before ``linear``.
        """
        for layer in self.mlp1:
            candidate = layer(candidate, reference)
        flattened = candidate.view(rows, -1)
        return self.sigmoid(self.linear(flattened))

    def forward(self, fv_concat, train=True):
        """Score neighbouring clips.

        Args:
            fv_concat: With ``train=True`` the first five entries on dim 1 are
                the clips to compare; with ``train=False`` entries 0 and 1 on
                dim 0 are the two clips.
            train: Selects the five-clip training layout or the two-clip
                inference layout.

        Returns:
            ``train=True``: a 4-tuple of scores for the pairs (1,2), (2,3),
            (3,4), (4,5). ``train=False``: a single score tensor.
        """
        if train:
            batch_size = fv_concat.shape[0]
            clips = [fv_concat[:, k, :, :] for k in range(5)]
            fv_bar = fv_concat.mean(dim=1)

            # Same evaluation order as three separate passes: refine all
            # clips, then fuse all clips, then score the pairs.
            refined = [self._refine(clip) for clip in clips]
            fused = [self.middle_layer(fv_bar, clip) for clip in refined]
            cls_0, cls_1, cls_2, cls_3 = [
                self._score_pair(earlier, later, batch_size)
                for earlier, later in zip(fused, fused[1:])
            ]
            return cls_0, cls_1, cls_2, cls_3

        fv_bar = fv_concat.mean(dim=0).unsqueeze(0)
        first_clip = fv_concat[0, :].unsqueeze(0)
        second_clip = fv_concat[1, :].unsqueeze(0)

        refined_first = self._refine(first_clip)
        refined_second = self._refine(second_clip)

        fused_first = self.middle_layer(fv_bar, refined_first)
        fused_second = self.middle_layer(fv_bar, refined_second)

        return self._score_pair(fused_first, fused_second, 1)
            
class Stage1_dataset(Dataset):
    """Dataset of paired vision/text tensors listed in a JSON index file.

    Each index entry provides the keys ``'pooling_tensor'`` (path of the
    vision tensor) and ``'med_tensor'`` (path of the text tensor).

    Args:
        json_path: Path of the JSON index file.
    """

    def __init__(self, json_path):
        super().__init__()
        # Context manager so the index file is closed right after parsing.
        with open(json_path, encoding="utf-8") as fp:
            self.data = json.load(fp)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Load both tensors of entry ``index`` onto the CPU."""
        sample = self.data[index]

        vision_tensor = torch.load(sample['pooling_tensor'], map_location='cpu')
        text_tensor = torch.load(sample['med_tensor'], map_location='cpu')

        return {"vision_tensor": vision_tensor, "text_tensor": text_tensor}
    
class Stage2_dataset(Dataset):
    """Dataset of clip tensors with four binary transition labels.

    Each index entry provides ``'tensor_path'`` and a ``'label'`` sequence
    whose first four items are returned as ``label_0`` .. ``label_3``.

    Args:
        json_path: Path of the JSON index file.
    """

    def __init__(self, json_path):
        super().__init__()
        # Context manager so the index file is closed right after parsing.
        with open(json_path, encoding="utf-8") as fp:
            self.data = json.load(fp)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the detached clip tensor and its four labels."""
        sample = self.data[index]

        # map_location='cpu' matches Stage1_dataset and keeps DataLoader
        # workers from touching CUDA when the tensor was saved from a GPU.
        tensor = torch.load(sample['tensor_path'], map_location='cpu').detach()

        label_0, label_1, label_2, label_3 = (sample['label'][k] for k in range(4))
        return {"tensor": tensor, "label_0": label_0, 'label_1': label_1, 'label_2': label_2, 'label_3': label_3}
        
def compute_cosine_loss(tensor1, tensor2):
    """Return ``1 - mean(cosine_similarity)`` taken along the last dimension.

    The similarity is computed per position over ``dim=-1`` and averaged over
    all remaining dimensions, so the result is a scalar tensor.
    """
    similarity = F.cosine_similarity(tensor1, tensor2, dim=-1)
    return 1 - similarity.mean()

def compute_bce_loss(cls_output, target):
    """Return the mean binary cross-entropy between probabilities and targets.

    ``cls_output`` must already be probabilities (e.g. sigmoid outputs).
    """
    bce = F.binary_cross_entropy(input=cls_output, target=target)
    return bce



from tqdm import tqdm
def _mean_clip_feature(tensor_dir, clip_indices):
    """Load the listed clip tensors on CPU, concatenate on dim 0, average over dim 0."""
    tensor_stack = [
        torch.load(os.path.join(tensor_dir, "tensor_" + str(index) + ".pt")).to("cpu")
        for index in clip_indices
    ]
    return torch.concat(tensor_stack, dim=0).mean(dim=0)


def generate_stage2_result(model, pipeline):
    """Segment every video with the stage-2 model and build stage-3 training pairs.

    Steps:
      1. Score each pair of consecutive clip tensors with ``model``; a score
         above 0.5 records the clip index as a segment boundary.
      2. Average the clip tensors of every predicted segment (or of the whole
         video when no boundary was found).
      3. Ask the text-generation ``pipeline`` to summarise the transcript into
         one '#'-separated sentence per segment.
      4. Encode each sentence with the surgicberta masked-LM and keep the
         last hidden state.

    Args:
        model: Stage-2 classifier, called as ``model(fv_concat, train=False)``.
        pipeline: Text-generation pipeline; it is moved to ``device2``.

    Returns:
        ``(final_results, number_list)``: ``final_results`` is a list of dicts
        with keys ``'number'``, ``'tensor'`` and ``'sentence_tensor'``;
        ``number_list`` is the per-video result dict (it is the same object
        that accumulates the tensors and sentences, not just the boundaries).

    Fixes relative to the previous version:
      * segment aggregation loaded ``tensor_{i}`` (the segment counter)
        instead of ``tensor_{index}`` (the clip inside the segment);
      * over-long sentence lists kept one sentence too many;
      * stage-2 inference ran with autograd enabled;
      * JSON files were opened without being closed.
    """
    model.eval()
    pred_result = {}
    train_data_root = "/tensor_save_path"
    with open("/") as fp:
        train_data = json.load(fp)
    print(f'generate stage2 result step1--------')

    for name in tqdm(train_data):
        train_data_path = os.path.join(train_data_root, name)
        if not os.path.exists(train_data_path):
            print(name)
            continue
        boundaries = []
        pred_result[name] = {'number': boundaries}
        length = len(os.listdir(train_data_path))
        # The last clip has no successor, so only length - 1 pairs are scored.
        for i in range(length - 1):
            tensor_i_path = os.path.join(train_data_path, "tensor_" + str(i) + ".pt")
            tensor_i = torch.load(tensor_i_path).to(device1).mean(dim=0)
            tensor_i_1_path = os.path.join(train_data_path, "tensor_" + str(i+1) + ".pt")
            tensor_i_1 = torch.load(tensor_i_1_path).to(device1).mean(dim=0)

            fv_concat = torch.concat([tensor_i, tensor_i_1], dim=0).float()
            # Pure inference: no autograd graph is needed for the score.
            with torch.no_grad():
                cls = model(fv_concat, train=False)
            if cls.item() > 0.5:
                boundaries.append(i)

            del tensor_i, tensor_i_1, fv_concat, cls
        gc.collect()

    # Alias, not a copy: later additions to pred_result show up here too.
    number_list = pred_result

    print(f'generate stage2 results step 2------')
    for name in tqdm(pred_result):
        tensor_dir = os.path.join(train_data_root, name)
        boundaries = pred_result[name]['number']
        if len(boundaries) == 0:
            # No boundary predicted: the whole video forms one segment.
            clip_count = len(os.listdir(tensor_dir))
            pred_result[name]['tensor_list'] = [_mean_clip_feature(tensor_dir, range(clip_count))]
            gc.collect()
        else:
            # NOTE(review): clips after the last boundary are not aggregated
            # into a trailing segment - confirm this is intended.
            start = 0
            pred_result[name]['tensor_list'] = []
            for end in boundaries:
                segment_feature = _mean_clip_feature(tensor_dir, range(start, end + 1))
                pred_result[name]['tensor_list'].append(segment_feature)
                # The next segment starts at this boundary clip.
                start = end
                del segment_feature
                gc.collect()

    print('generate stage2 results step 3-----')
    pipeline.model = pipeline.model.to(device2)
    pipeline.device = device2
    norm_names = set()
    for name in tqdm(pred_result):
        len_tensor_list = len(pred_result[name]['tensor_list'])
        trans_path = os.path.join("/home/ubuntu/PROJECTS/", name + ".json")
        if not os.path.exists(trans_path):
            print(f"path not find {name}")
            continue
        norm_names.add(name)
        with open(trans_path) as fp:
            trans_data = json.load(fp)['segments']
        text = "The title of this video is ' " + name + " '.Please summarize the following text into exactly " + str(len_tensor_list) + " sentences and you must use # to segment each sentence:"
        text += "".join(segment['text'] for segment in trans_data)

        outputs = pipeline(text, max_new_tokens=64)
        generated_text = outputs[0]['generated_text']

        sentences = generated_text.split('#')
        if len(sentences) > len_tensor_list:
            # Keep exactly one sentence per segment.
            new_sentences = sentences[:len_tensor_list]
        else:
            # Pad with the last sentence so every segment has one.
            new_sentences = sentences + [sentences[-1]] * (len_tensor_list - len(sentences))
        pred_result[name]['pred_sentences'] = new_sentences

    torch.cuda.empty_cache()

    med_tokenizer = AutoTokenizer.from_pretrained("marcobombieri/surgicberta")
    med_model = AutoModelForMaskedLM.from_pretrained("marcobombieri/surgicberta", output_hidden_states=True).to(device3)
    print('generate stage2 results step 4 -----')
    for name in tqdm(pred_result):
        if name not in norm_names:
            continue
        pred_sentence_tensor = []
        for sentence in pred_result[name]['pred_sentences']:
            inputs = med_tokenizer(sentence, return_tensors="pt", max_length=197, truncation=True, padding="max_length").to(device3)
            with torch.no_grad():
                outputs = med_model(**inputs)
            last_hidden_state = outputs.hidden_states[-1].to("cpu").detach()
            pred_sentence_tensor.append(last_hidden_state)
        pred_result[name]['sentence_tensor'] = pred_sentence_tensor
    del med_model
    del med_tokenizer
    gc.collect()

    final_results = []
    all_count = 0
    print('transform the data')
    for name in tqdm(pred_result):
        if name not in norm_names:
            continue
        segment_pairs = zip(pred_result[name]['tensor_list'], pred_result[name]['sentence_tensor'])
        for segment_tensor, sentence_tensor in segment_pairs:
            final_results.append({'number': all_count, "tensor": segment_tensor, "sentence_tensor": sentence_tensor})
            all_count += 1

    del pred_result
    gc.collect()
    return final_results, number_list


class Stage3_dataset(Dataset):
    """In-memory dataset over a sequence of segment records.

    Every record provides the keys ``'tensor'`` and ``'sentence_tensor'``,
    which are exposed as ``'vision_tensor'`` and ``'text_tensor'``.
    """

    def __init__(self, json_data):
        super().__init__()
        self.data = json_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the vision/text pair stored at position ``index``."""
        record = self.data[index]
        vision, text = record['tensor'], record['sentence_tensor']
        return dict(vision_tensor=vision, text_tensor=text)
        
class Stage4_dataset(Dataset):
    """Dataset of clip-pair tensors with boundary labels, listed in a JSON index.

    Each index entry provides ``'tensor_path'`` and ``'label'``; the optional
    keys ``'i_label'`` and ``'i_1_label'`` default to ``[0.0, 0.0]``.

    Args:
        json_path: Path of the JSON index file.
    """

    def __init__(self, json_path):
        super().__init__()
        # Context manager so the index file is closed right after parsing.
        with open(json_path, encoding="utf-8") as fp:
            self.data = json.load(fp)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the CPU tensor and the labels of entry ``index``."""
        sample = self.data[index]
        label = sample['label']
        # Load directly onto the CPU instead of materialising on the saved
        # device first and copying afterwards.
        tensor = torch.load(sample['tensor_path'], map_location="cpu")
        i_label = sample.get('i_label', [0.0, 0.0])
        i_1_label = sample.get('i_1_label', [0.0, 0.0])
        return {'tensor': tensor, 'label': label, 'i_label': i_label, 'i_1_label': i_1_label}

def compute_loss(pred_end0, pred_end1, true_end0, true_end1, target):
    """Return the summed IoU regression loss of two predicted end positions.

    Each segment is treated as starting at position 1, so intersection and
    union reduce to ``min(pred, true) - 1`` and ``max(pred, true) - 1``.
    ``target`` is accepted for interface compatibility but not used.
    """

    def _iou_loss(pred_end, true_end):
        # Intersection is floored at 0; union at a small epsilon to avoid 0/0.
        overlap = torch.clamp(torch.minimum(pred_end, true_end) - 1, min=0)
        span = torch.clamp(torch.maximum(pred_end, true_end) - 1, min=1e-6)
        return (1.0 - overlap / span).mean()

    first_clip_loss = _iou_loss(pred_end0, true_end0)
    second_clip_loss = _iou_loss(pred_end1, true_end1)
    return first_clip_loss + second_clip_loss

def comput_GIoU(pred_start, pred_end, true_start, true_end):
    """Return the element-wise 1-D generalized IoU loss ``1 - GIoU``.

    Args:
        pred_start, pred_end: Predicted interval bounds (broadcastable tensors).
        true_start, true_end: Ground-truth interval bounds.

    Returns:
        Tensor of per-element losses in ``[0, 2]``: 0 for identical
        intervals, above 1 for disjoint ones. No reduction is applied.

    Fixes relative to the previous version:
      * the intersection subtracted ``max(pred_start, pred_end)`` instead of
        ``max(pred_start, true_start)``, so it was always clamped to zero;
      * the loss was ``1 - |GIoU|``, which *decreases* as disjoint intervals
        move further apart (GIoU -> -1) and so rewarded bad predictions.
    """
    # Intersection: overlap of the two intervals, zero when disjoint.
    intersection = torch.clamp(
        torch.min(pred_end, true_end) - torch.max(pred_start, true_start), min=0.0
    )

    # Union, clamped to avoid division by zero.
    pred_length = pred_end - pred_start
    gt_length = true_end - true_start
    union = torch.clamp(pred_length + gt_length - intersection, min=1e-6)

    iou = intersection / union

    # Smallest interval enclosing both, clamped for the same reason.
    enclosure = torch.max(pred_end, true_end) - torch.min(pred_start, true_start)
    enclosure = torch.clamp(enclosure, min=1e-6)

    # GIoU penalises the part of the enclosure covered by neither interval.
    giou = iou - (enclosure - union) / enclosure

    return 1.0 - giou

def judge_trans(trans_data, start, end):
    """Return the index of the entry overlapping ``[start, end]`` the most.

    Each entry's ``'start'`` and ``'end'`` are rounded before comparison.
    Ties keep the earliest entry; ``-1`` is returned when nothing overlaps.
    """
    best_id, best_overlap = -1, 0
    for idx, segment in enumerate(trans_data):
        lower = max(round(segment['start']), start)
        upper = min(end, round(segment['end']))
        # A non-positive span means no overlap and can never beat the best.
        span = upper - lower
        if span > best_overlap:
            best_id, best_overlap = idx, span
    return best_id

def rename(index):
    """Return the ``.jpg`` file name for ``index``.

    A single ``"0"`` is prepended when ``index`` is below 100 (so 5 becomes
    ``"05.jpg"`` and 50 becomes ``"050.jpg"``); larger values are used as is.
    """
    prefix = "0" if index < 100 else ""
    return prefix + str(index) + ".jpg"

def rewrite_json(json_data, output_path="/home/ubuntu/qwen_finetune.json"):
    """Convert segment records into conversation-style fine-tuning JSON.

    Every input record provides ``'image_path'``, ``'text'`` and ``'title'``.
    It becomes one entry with a user turn that wraps the frame list in
    vision markers and an assistant turn holding the description.

    Args:
        json_data: Iterable of input records.
        output_path: Destination file; defaults to the path that was
            previously hard-coded, so existing callers are unaffected.
    """
    conversations = [
        {
            "id": f"identity_{position}",
            "title": data['title'],
            "conversations": [
                {
                    "from": "user",
                    "value": f": <|vision_start|>{data['image_path']}<|vision_end|>"
                },
                {
                    "from": "assistant",
                    "value": data['text']
                }
            ]
        }
        for position, data in enumerate(json_data)
    ]

    with open(output_path, "w") as fp:
        json.dump(conversations, fp, indent=4)

from tqdm import tqdm
def get_pred(stage2_model, stage1_model):
    """Predict segments for every training video and attach a key-step label.

    For each pair of consecutive clips the stage-2 model yields a score and
    the stage-1 model yields a boundary position inside the second clip. A
    score below 0.5 closes the current segment at clip ``i``. Each segment is
    then labelled with the key step whose time span overlaps it the most
    (clip indices are scaled by 23 for that comparison).

    NOTE(review): the boundary test here is ``score < 0.5`` whereas
    ``generate_stage2_result`` uses ``score > 0.5`` - confirm which polarity
    is intended.
    NOTE(review): when no key step overlaps, ``judge_trans`` returns -1 and
    the *last* key step is selected through negative indexing.

    Args:
        stage2_model: Pairwise classifier, called with ``train=False``.
        stage1_model: Boundary model, called in its default (soft) mode.

    Returns:
        List of dicts with keys ``'name'``, ``'start'``, ``'end'``,
        ``'end_i_1'`` and ``'keystep'``; also written to a JSON file.
    """
    device = torch.device("cuda")
    stage2_model = stage2_model.to(device)
    stage1_model = stage1_model.to(device)
    stage1_model.eval()
    stage2_model.eval()

    train_dir = "/home/ubuntu//tensor_save_path"
    pred_result = []
    with open("/home/ubuntu") as fp:
        train_data = json.load(fp)
    for name in tqdm(train_data):
        train_data_path = os.path.join(train_dir, name)
        # List the directory once instead of on every loop iteration.
        last_index = len(os.listdir(train_data_path)) - 1
        start = 0
        # The last clip has no successor, so only pairs (i, i + 1) are scored.
        for i in range(last_index):
            tensor_i_path = os.path.join(train_data_path, "tensor_" + str(i) + ".pt")
            tensor_i_1_path = os.path.join(train_data_path, "tensor_" + str(i+1) + ".pt")

            # Load each clip once and derive both model inputs from it.
            clip_i = torch.load(tensor_i_path).to(device)
            clip_i_1 = torch.load(tensor_i_1_path).to(device)

            fv_concat_stage2 = torch.concat([clip_i.mean(dim=0), clip_i_1.mean(dim=0)], dim=0).float()
            fv_concat = torch.concat([clip_i.unsqueeze(0), clip_i_1.unsqueeze(0)], dim=1).float()
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                cls = stage2_model(fv_concat_stage2, train=False)
                _, pred_end_1 = stage1_model(fv_concat)

            pred_end_1_item = round(pred_end_1.item())

            if cls.item() < 0.5:
                pred_result.append({'name': name, 'start': start, 'end': i, 'end_i_1': pred_end_1_item})
                # The next segment starts at the boundary clip.
                start = i

        # Trailing segment running to the last clip of the video.
        pred_result.append({'name': name, 'start': start, 'end': last_index, 'end_i_1': 0})

    with open("/home/ubuntu/PROJECTS/surgical_workflow/iccv_train.json") as fp:
        train_keystep = json.load(fp)
    for data in pred_result:
        this_train_keystep = train_keystep[data['name']]
        start = data['start'] * 23
        end = data['end'] * 23

        max_overlap_id = judge_trans(this_train_keystep, start, end)
        data['keystep'] = this_train_keystep[max_overlap_id]['keystep']

    with open("/home/llama_finetune_pred.json", "w") as fp:
        json.dump(pred_result, fp, indent=4)

    return pred_result
        
def log_gradients(model, log_file, epoch):
    """Write one line per named parameter with the mean of its gradient.

    Parameters without a gradient are logged as ``grad is None``. Every
    record ends with a newline (the gradient line previously lacked one, so
    consecutive records ran together), and the file is flushed once at the
    end.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        log_file: Writable text file object.
        epoch: Epoch number included in every line.
    """
    for name, param in model.named_parameters():
        if param.grad is None:
            log_file.write(f'epoch {epoch}, layer {name}, grad is None\n')
        else:
            log_file.write(f'epoch {epoch}, layer {name}, grad {param.grad.mean().item()}\n')
    log_file.flush()


def train_llama(llama_model, llama_optim, llama_loss, data_json, epoch=0):
    """Run one pass of Llama fine-tuning over the predicted segments.

    For every record the clip tensors ``start .. end`` are loaded; when the
    segment is not the last one of the video, the first ``end_i_1`` entries
    of the following clip are appended. The concatenation is averaged over
    dim 0 and passed to the model with a fixed prompt and the key-step name
    as target text. The loss curve is re-plotted after every step.

    Args:
        llama_model: Called as ``llama_model(image_embedding, input_text,
            output_text)``; its first return value is the loss.
        llama_optim: Optimizer for ``llama_model``.
        llama_loss: List that receives one loss value per step (mutated).
        data_json: Iterable of records with keys ``'start'``, ``'end'``,
            ``'name'``, ``'end_i_1'`` and ``'keystep'``.
        epoch: Zero-based epoch number shown in the plot title. This used to
            be read from an undefined name, which only worked when a global
            ``epoch`` happened to exist.
    """
    device = torch.device("cuda")
    llama_model = llama_model.to(device)
    llama_model.train()
    train_root = "/home/ubuntu/PROJECTSe/tensor_save_path"
    for data in data_json:
        llama_optim.zero_grad()
        start_id = data['start']
        end_id = data['end']
        name = data['name']
        pred_end_1 = data['end_i_1']
        keystep_name = data['keystep']

        train_dir = os.path.join(train_root, name)
        tensor_stack = [
            torch.load(os.path.join(train_dir, "tensor_" + str(i) + ".pt"))
            for i in range(start_id, end_id + 1)
        ]
        if end_id != len(os.listdir(train_dir)) - 1:
            # Not the last clip: add the leading part of the next clip, up to
            # the boundary predicted by the stage-1 model.
            tensor_i_1_path = os.path.join(train_dir, "tensor_" + str(end_id + 1) + ".pt")
            tensor_i_1 = torch.load(tensor_i_1_path)
            tensor_stack.append(tensor_i_1[:pred_end_1, :, :, :])

        image_embedding = torch.concat(tensor_stack, dim=0).mean(dim=0).to(device)
        input_text = "You are an experienced surgeon with over 10 years of clinical training. Given a surgical video segment and its corresponding" +  name + " , provide a concise summary of the surgical step depicted in the segment. The summary should be no more than 10 words. "
        output_text = keystep_name
        loss, _ = llama_model(image_embedding, input_text, output_text)
        loss.backward()
        llama_loss.append(loss.item())
        llama_optim.step()
        print(f'llama loss {loss.item()}')

        # Refresh the loss-curve image after every optimisation step.
        plt.figure()
        plt.plot(llama_loss, 'r-', linewidth=2)
        plt.xlabel('Iterations', fontsize=12)
        plt.ylabel('Loss', fontsize=12)
        plt.title(f'Batch Loss Curve (Epoch {epoch+1})', fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.savefig(f'llama_loss.png', dpi=150, bbox_inches='tight')
        plt.close()
        
        


if __name__ == "__main__":
    import torch.optim as optim

    # Stage-1 model (Model1) lives on device0, the stage-2 classifier (Mlp2) on device1.
    stage1_model = Model1().to(torch.float32).to(device0)
    stage2_model = Mlp2().to(torch.float32).to(device1)
    stage1_optim = optim.Adam(stage1_model.parameters(), lr=1e-5, weight_decay=0.01)
    stage2_optim = optim.Adam(stage2_model.parameters(), lr=1e-5, weight_decay=0.01)

    stage1_dataset = Stage1_dataset("/home/ubuntu/.json")
    stage1_dataloader = DataLoader(stage1_dataset, batch_size=32, shuffle=True, num_workers=16)
    stage2_dataset = Stage2_dataset("/home/ubuntu/PROJECTS/.json")
    stage2_dataloader = DataLoader(stage2_dataset, batch_size=16, shuffle=True, num_workers=16)
    stage4_dataset = Stage4_dataset("/home/ubuntu/PROJECTS/.json")
    stage4_dataloader = DataLoader(stage4_dataset, shuffle=True, num_workers=16, batch_size=8)


    stage1_epoch = 14
    stage2_epoch = 150
    stage3_epoch = 14
    stage4_epoch = 14
    stage5_epoch = 14
    sum_epoch = 1
    stage1_loss_item = []
    stage2_loss_item = []
    stage3_loss_item = []
    stage4_loss_item = []
    stage4_giou_loss_item = []
    stage1_epoch_loss = []
    stage2_epoch_loss = []
    stage3_epoch_loss = []
    stage4_epoch_loss = []
    stage4_epoch_giou_loss = []
    llama_loss = []
    for sub_sum_epoch in range(sum_epoch):
        # high: align stage-1 text-mode outputs with the text features (cosine loss).
        stage1_model.train()
        for epoch in range(stage1_epoch):
            running_loss = 0.0
            batch_count = 0
            for batch in stage1_dataloader:
                stage1_optim.zero_grad()
                vision_tensor = batch['vision_tensor'].to(device0).float()
                text_tensor = batch['text_tensor'].to(device0).float()
                output = stage1_model(vision_tensor, text_tensor, inference=False, train_w_text=True)
                cosine_loss = compute_cosine_loss(output, text_tensor)
                cosine_loss.backward()
                stage1_optim.step()

                running_loss += cosine_loss.item()
                batch_count += 1
                print(f'all epoch {epoch} stage 1 loss {cosine_loss.item()}')


        # Stage 2: train the pairwise classifier with four BCE terms per batch.
        stage2_model.train()
        for epoch in range(stage2_epoch):
            running_loss = 0.0
            batch_count = 0
            for batch in stage2_dataloader:
                stage2_optim.zero_grad()
                tensor = batch['tensor'].float().to(device1)
                # as_tensor accepts the already-collated label tensors without
                # the copy-construct warning that torch.tensor(tensor) emits.
                target_0 = torch.as_tensor(batch['label_0'], dtype=torch.float32, device=device1).unsqueeze(1)
                target_1 = torch.as_tensor(batch['label_1'], dtype=torch.float32, device=device1).unsqueeze(1)
                target_2 = torch.as_tensor(batch['label_2'], dtype=torch.float32, device=device1).unsqueeze(1)
                target_3 = torch.as_tensor(batch['label_3'], dtype=torch.float32, device=device1).unsqueeze(1)
                cls_0, cls_1, cls_2, cls_3 = stage2_model(tensor)
                bce_loss_0 = compute_bce_loss(cls_0, target_0)
                bce_loss_1 = compute_bce_loss(cls_1, target_1)
                bce_loss_2 = compute_bce_loss(cls_2, target_2)
                bce_loss_3 = compute_bce_loss(cls_3, target_3)
                bce_loss = bce_loss_0 + bce_loss_1 + bce_loss_2 + bce_loss_3
                bce_loss.backward()

                running_loss += bce_loss.item()
                batch_count += 1
                stage2_optim.step()
                print(f'all epoch {epoch} stage 2 loss {bce_loss.item()}')
        # NOTE(review): both checkpoints are written to the same path, so the
        # second save replaces the first - confirm the intended file names.
        torch.save(stage1_model.state_dict(), "/home/ubuntu/PROJECTS/")
        torch.save(stage2_model.state_dict(), "/home/ubuntu/PROJECTS/")
        # mid->low: build segment-level vision/text pairs from the stage-2
        # predictions and keep aligning the stage-1 model on them.
        stage3_data, number_list = generate_stage2_result(stage2_model, pipeline)
        stage3_dataset = Stage3_dataset(stage3_data)
        stage3_dataloader = DataLoader(stage3_dataset, shuffle=True, num_workers=12, batch_size=8)
        stage1_model.train()
        with open("/home/ubuntu/PROJECTS/", "w") as fp:
            for epoch in range(stage3_epoch):
                running_loss = 0.0
                batch_count = 0
                for batch in stage3_dataloader:
                    stage1_optim.zero_grad()
                    vision_tensor = batch['vision_tensor'].float().to(device0)
                    text_tensor = batch['text_tensor'].float().to(device0)
                    output = stage1_model(vision_tensor, text_tensor, inference=False, train_w_text=True)
                    cosine_loss = compute_cosine_loss(output, text_tensor)
                    cosine_loss.backward()
                    log_gradients(stage1_model, fp, epoch)
                    stage1_optim.step()

                    running_loss += cosine_loss.item()
                    batch_count += 1
                    print(f'all epoch {epoch} stage 3 loss {cosine_loss.item()}')
        del stage3_data
        del stage3_dataloader
        del stage3_dataset
        gc.collect()

        # GIoU refinement of the stage-1 boundary against ground-truth segments.
        with open("/home/ubuntu/PROJECTS/", "w") as fp:
            # The ground truth does not change between epochs: load it once
            # and close the file instead of re-opening it every epoch.
            with open("/home/ubuntu/PROJECTS/") as gt_fp:
                ground_truth_data = json.load(gt_fp)
            for epoch in range(stage5_epoch):
                running_loss = 0.0
                batch_count = 0
                for name in number_list:
                    if len(number_list[name]['number']) == 0:
                        continue
                    gd_data = ground_truth_data[name]
                    start = gd_data[0]['start']
                    pred_start = torch.tensor([[1]]).float().to(device0)
                    for i in range(min(len(gd_data), len(number_list[name]['number']))):
                        stage1_optim.zero_grad()
                        # Ground-truth bounds relative to the first segment, 1-based.
                        gd_start = gd_data[i]['start'] - start + 1
                        gd_end = gd_data[i]['end'] - start + 1
                        gd_start_tensor = torch.tensor([[gd_start]]).to(device0)
                        gd_end_tensor = torch.tensor([[gd_end]]).to(device0)
                        number_seg = number_list[name]['number'][i]
                        tensor_dir = os.path.join("/home/ubuntu/PROJECTS/Datasets/ICLR/all_file/tensor_save_path", name)
                        tensor_i_path = os.path.join(tensor_dir, "tensor_" + str(number_seg) + ".pt")
                        tensor_i_1_path = os.path.join(tensor_dir, "tensor_" + str(number_seg + 1) + ".pt")
                        tensor_i = torch.load(tensor_i_path).float().to(device0).unsqueeze(0)
                        tensor_i_1 = torch.load(tensor_i_1_path).float().to(device0).unsqueeze(0)
                        tensor = torch.concat([tensor_i, tensor_i_1], dim=1).float()
                        pred_end0, pred_end1 = stage1_model(tensor)
                        giou_loss = comput_GIoU(pred_end=pred_end0, pred_start=pred_start, true_end=gd_end_tensor, true_start=gd_start_tensor)
                        giou_loss.backward()
                        log_gradients(stage1_model, fp, epoch)
                        stage1_optim.step()
                        running_loss += giou_loss.item()
                        batch_count += 1

                        print(f'all epoch {epoch} stage 4 giou loss {giou_loss.item()}')
        # low: IoU regression of both clip boundaries on the stage-4 dataset.
        for epoch in range(stage4_epoch):
            running_loss = 0.0
            batch_count = 0
            for batch in stage4_dataloader:
                stage1_optim.zero_grad()
                fv_concat = batch['tensor'].float().to(device0)
                target = torch.as_tensor(batch['label'], dtype=torch.float32, device=device0).unsqueeze(1)
                # Predictions are [B, 1]; reshape the collated [B] targets to
                # match. Otherwise min/max inside compute_loss broadcast to a
                # [B, B] matrix and every prediction is compared with every
                # target in the batch.
                true_end0 = batch['i_label'][1].to(device0, dtype=torch.float32).reshape(-1, 1)
                true_end1 = batch['i_1_label'][1].to(device0, dtype=torch.float32).reshape(-1, 1)

                pred_end0, pred_end1 = stage1_model(fv_concat)
                iou_loss = compute_loss(pred_end0, pred_end1, true_end0, true_end1, target)

                running_loss += iou_loss.item()
                batch_count += 1
                iou_loss.backward()
                stage1_optim.step()
                print(f'all epoch {epoch} stage 4 loss {iou_loss.item()}')




    torch.save(stage1_model.state_dict(), "/home/ubuntu/PROJECTS/1.pt")
    torch.save(stage2_model.state_dict(), "/home/ubuntu/PROJECTS/surgical_workflow/2.pt")
