import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Sampler,SequentialSampler,RandomSampler
from transformers.models.t5.modeling_t5 import T5Block
import re
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap, size_based_auto_wrap_policy,
    wrap,
)

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime
import pickle
import transformers
tokenizer = None

from datasets import load_dataset

import torch

from pathlib import Path
import random
import pandas as pd
pd.set_option('display.max_columns', 500)


from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
   apply_activation_checkpointing,
)

import sys
from peft import (
    LoraConfig,
    PeftConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)









# Maps each clinical variable name to a 1-based integer id (131 entries).
# MimicDataset.__getitem__ subtracts 1 from the id to obtain the column of the
# variable in its per-hour tensor and uses `column + 131` for the matching
# "was observed" indicator column.
# NOTE: the keys are matched against the `variable` column of the dataframe, so
# the spellings (including 'Eoisinophils' and 'Insulin largine') must stay as-is.
var_to_ind= {'ALP': 1, 'ALT': 2, 'AST': 3, 'Albumin': 4, 'Albumin 25%': 5, 'Albumin 5%': 6,
              'Amiodarone': 7, 'Anion Gap': 8, 'Antibiotics': 9, 'BUN': 10, 'Base Excess': 11,
              'Basophils': 12, 'Bicarbonate': 13, 'Bilirubin (Direct)': 14, 'Bilirubin (Indirect)': 15,
              'Bilirubin (Total)': 16, 'CRR': 17, 'Calcium Free': 18, 'Calcium Gluconate': 19,
              'Calcium Total': 20, 'Cefazolin': 21, 'Chest Tube': 22, 'Chloride': 23, 'Colloid': 24,
              'Creatinine Blood': 25, 'Creatinine Urine': 26, 'D5W': 27, 'DBP': 28, 'Dextrose Other': 29,
              'Dobutamine': 30, 'Dopamine': 31, 'EBL': 32, 'Emesis': 33, 'Eoisinophils': 34,
              'Epinephrine': 35, 'Famotidine': 36, 'Fentanyl': 37, 'FiO2': 38, 'Fiber': 39,
              'Free Water': 40, 'Fresh Frozen Plasma': 41, 'Furosemide': 42, 'GCS_eye': 43,
              'GCS_motor': 44, 'GCS_verbal': 45, 'GT Flush': 46, 'Gastric': 47, 'Gastric Meds': 48,
              'Glucose (Blood)': 49, 'Glucose (Serum)': 50, 'Glucose (Whole Blood)': 51,
              'HR': 52, 'Half Normal Saline': 53, 'Hct': 54, 'Heparin': 55, 'Hgb': 56,
              'Hydralazine': 57, 'Hydromorphone': 58, 'INR': 59, 'Insulin Humalog': 60,
              'Insulin NPH': 61, 'Insulin Regular': 62, 'Insulin largine': 63,
              'Intubated': 64, 'Jackson-Pratt': 65, 'KCl': 66, 'KCl (Bolus)': 67,
              'LDH': 68, 'Lactate': 69, 'Lactated Ringers': 70, 'Levofloxacin': 71,
              'Lorazepam': 72, 'Lymphocytes': 73, 'Lymphocytes (Absolute)': 74,
              'MBP': 75, 'MCH': 76, 'MCHC': 77, 'MCV': 78, 'Magnesium': 79,
              'Magnesium Sulfate (Bolus)': 80,  'Magnesium Sulphate': 81,
              'Mechanically ventilated': 82, 'Metoprolol': 83, 'Midazolam': 84,
              'Milrinone': 85, 'Monocytes': 86, 'Morphine Sulfate': 87,
              'Neosynephrine': 88, 'Neutrophils': 89, 'Nitroglycerine': 90,
              'Nitroprusside': 91, 'Norepinephrine': 92, 'Normal Saline': 93,
              'O2 Saturation': 94, 'OR/PACU Crystalloid': 95, 'PCO2': 96,
              'PO intake': 97, 'PO2': 98, 'PT': 99, 'PTT': 100, 'Packed RBC': 101,
              'Pantoprazole': 102, 'Phosphate': 103, 'Piggyback': 104, 'Piperacillin': 105,
              'Platelet Count': 106, 'Potassium': 107, 'Pre-admission Intake': 108,
              'Pre-admission Output': 109, 'Propofol': 110, 'RBC': 111, 'RDW': 112,
              'RR': 113, 'Residual': 114, 'SBP': 115, 'SG Urine': 116, 'Sodium': 117,
              'Solution': 118, 'Sterile Water': 119, 'Stool': 120, 'TPN': 121,
              'Temperature': 122, 'Total CO2': 123, 'Ultrafiltrate': 124, 'Urine': 125,
              'Vancomycin': 126, 'Vasopressin': 127, 'WBC': 128, 'Weight': 129,
              'pH Blood': 130, 'pH Urine': 131}

# Subset of the var_to_ind keys that MimicDataset.__getitem__ keeps: only rows
# whose variable is in this set are written into the prompt text and the
# per-hour tensor. These are the variables named in the SOFA example prompt.
stuff_to_use = {"GCS_eye", "GCS_motor", "GCS_verbal", "Bilirubin (Total)", "Platelet Count", "Urine", "Creatinine Blood", "FiO2", "PO2", "Weight", "Dopamine", "Dobutamine", "Epinephrine", "Norepinephrine", "SBP", "DBP"}
import math
import copy
from copy import deepcopy


class MimicDataset(torch.utils.data.Dataset):
    """Prompt/target dataset built from a preprocessed MIMIC-III pickle.

    Every item corresponds to one sampled target row (a row whose ``variable``
    is in ``target_note_list``). The prompt consists of a fixed worked example,
    a one-line patient description and the measurements recorded for the same
    ``ts_ind`` in the 24 hours before the target; the target row's
    ``textvalue`` is the completion.

    Args:
        tokenizer: callable tokenizer exposing ``eos_token``; called on
            strings and expected to return an object with ``input_ids``.
        stage: one of ``"train"``, ``"val"`` or ``"test"``.
        target_note_list: variable names whose rows are used as targets.
        include_text_of_same_type: if true, earlier TEXT rows of a target
            variable are added to the prompt as clinical notes.
        include_measurements_in_text: if true, measurements are written into
            the prompt text.
        max_len: token budget for the patient-specific part of the prompt
            (see the NOTE in ``__getitem__``).

    Raises:
        ValueError: if ``stage`` is not a known split name.
    """

    def __init__(self, tokenizer, stage, target_note_list = ["synthetic_sepsis_label"], include_text_of_same_type=False, include_measurements_in_text=True, max_len=3500):
        # Number of target rows sampled per split.
        self.len_dict = {"train": 15000, "val":3000, "test": 3000}
        # Fail fast (before touching the pickle) on an unknown split name.
        if stage not in self.len_dict:
            raise ValueError(
                "stage must be one of {}, got {!r}".format(sorted(self.len_dict), stage)
            )
        # NOTE: pickle.load executes arbitrary code; only load trusted files.
        with open("mimic_iii_preprocessed_synthetic_michi.pkl", "rb") as infile:
            self.df, _, self.train_ind, self.val_ind, self.test_ind = pickle.load(infile)
        self.tokenizer = tokenizer
        self.offset = -0
        self.split_target = True # Implement That
        self.use_susinf = False

        split_indices = {"train": self.train_ind, "val": self.val_ind, "test": self.test_ind}
        self.df = self.df[self.df.ts_ind.isin(split_indices[stage])]
        if self.use_susinf:
            # Keep only stays with a positive "suspected-infection" row in the first 24 hours.
            self.get_positive_susinfs = self.df[self.df.hour < 24]
            self.susinfs = self.get_positive_susinfs[self.get_positive_susinfs.variable=="suspected-infection"]
            self.susinfs = self.susinfs[self.susinfs.value==1.0]
            self.new_indices = self.susinfs.ts_ind
            self.df = self.df[self.df.ts_ind.isin(self.new_indices)]
        self.df = self.df.reset_index(drop=True)
        self.target_texts = self.df[self.df.variable.isin(target_note_list)]
        self.target_texts = self.target_texts[self.target_texts.hour > abs(self.offset)]

        self.cache = dict()
        self.caching = True
        # Copy so the shared default list can never be mutated through the instance.
        self.target_note_list = list(target_note_list)
        self.include_text_of_same_type=include_text_of_same_type
        self.max_len = max_len
        self.include_measurements_in_text = include_measurements_in_text
        # Fixed seed keeps the sampled subset reproducible across runs.
        self.target_texts = self.target_texts.sample(n=self.len_dict[stage], random_state=1)

    def cache_exists(self, inputdir):
        """Return True if a cache file exists at ``inputdir``."""
        return os.path.exists(inputdir)

    def save_cache(self, outputdir):
        """Pickle the item cache to the file ``outputdir``."""
        with open(outputdir, "wb") as outfile:
            pickle.dump(self.cache, outfile)

    def load_cache(self, inputdir):
        """Replace the item cache with the pickle stored at ``inputdir``.

        NOTE: pickle.load executes arbitrary code; only load trusted files.
        """
        with open(inputdir, "rb") as infile:
            self.cache=pickle.load(infile)

    def precompute_cache(self):
        """Build every item once so that it ends up in the cache."""
        # The original iterated over ``self.index``, which does not exist.
        for i in tqdm.tqdm(range(len(self))):
            self[i]

    def __getitem__(self, idx):
        """Return the tokenized prompt+target for the ``idx``-th sampled target.

        The returned encoding additionally carries ``labels`` (input ids with
        the prompt positions set to -100), ``size_diff`` (number of unmasked
        trailing tokens) and ``ts`` (a 24 x 262 tensor holding per-hour values
        in the first 131 columns and observation flags in the last 131).
        """
        if idx in self.cache and self.caching:
            return self.cache[idx]
        datapoint = self.df.loc[self.target_texts.index[idx]]
        if self.offset < -10:
            datapoint_24hours_before = self.df[(self.df.hour < datapoint.hour-24+0.01) & (self.df.hour > datapoint.hour-24-0.01)]
            same_datapoint_24hours_before = datapoint_24hours_before[datapoint_24hours_before.variable == datapoint.variable]
        other_datapoints = self.df[self.df.ts_ind==datapoint.ts_ind]
        in_time_datapoints = other_datapoints[(other_datapoints.hour > datapoint.hour - 24 + self.offset) & (other_datapoints.hour < datapoint.hour + self.offset)]

        # Demographics fall back to fixed defaults when the rows are missing.
        age_data = other_datapoints[other_datapoints.variable=="Age"]
        age = age_data.iloc[0].value if len(age_data) > 0 else 60
        gender_data = other_datapoints[other_datapoints.variable=="Gender"]
        gender = gender_data.iloc[0].value if len(gender_data) > 0 else 0

        string = "Patient is {} years old and is {}. Given all the information in this text, answer the question at the end.".format(age, "female" if gender else "male")
        textual_information_pre = "\nHere are the clinical texts: "
        measurement_information_pre = "\nHere are the measurements: "
        textual_information = []
        measurement_information = []
        current_time = datapoint.hour
        pre_final_part = """\n Please answer in the following style, where current time values are the correct summary values of the given measurements according to the Sepsis-3 definition. The future values should be your best guesses on how the values develop. The SOFA scores should be calculated according to the Sepsis-3 definition. Only answer like in the given example. Here is the example:
Patient is 63.0 years old and is male. Given all the information in this text, answer the question at the end.
Here are the measurements: DBP at time -23.68: 36.0, SBP at time -23.68: 71.0, DBP at time -23.43: 66.0, SBP at time -23.43: 86.0, DBP at time -23.18: 28.0, SBP at time -23.18: 74.0, DBP at time -22.93: 45.0, SBP at time -22.93: 83.0, DBP at time -22.68: 46.0, SBP at time -22.68: 87.0, DBP at time -22.43: 48.0, SBP at time -22.43: 73.0, DBP at time -22.18: 33.0, SBP at time -22.18: 79.0, DBP at time -21.93: 36.0, SBP at time -21.93: 71.0, DBP at time -21.43: 43.0, GCS_eye at time -21.43: 4.0, GCS_motor at time -21.43: 6.0, GCS_verbal at time -21.43: 5.0, SBP at time -21.43: 67.0, Urine at time -21.43: 150.0, DBP at time -20.68: 34.0, SBP at time -20.68: 69.0, DBP at time -20.1: 36.0, SBP at time -20.1: 86.0, DBP at time -19.93: 54.0, SBP at time -19.93: 80.0, DBP at time -19.43: 41.0, SBP at time -19.43: 84.0, Bilirubin (Total) at time -19.27: 1.8, Creatinine Blood at time -19.27: 1.4, Platelet Count at time -19.27: 235.0, DBP at time -19.18: 54.0, SBP at time -19.18: 84.0, DBP at time -18.93: 50.0, SBP at time -18.93: 84.0, PO2 at time -18.87: 141.0, DBP at time -18.77: 46.0, SBP at time -18.77: 77.0, DBP at time -18.68: 44.0, SBP at time -18.68: 90.0, DBP at time -18.43: 39.0, SBP at time -18.43: 87.0, DBP at time -17.43: 45.0, GCS_eye at time -17.43: 4.0, GCS_motor at time -17.43: 6.0, GCS_verbal at time -17.43: 5.0, SBP at time -17.43: 97.0, Urine at time -17.43: 280.0, DBP at time -16.43: 39.0, SBP at time -16.43: 82.0, Urine at time -16.43: 80.0, DBP at time -15.43: 38.5, SBP at time -15.43: 84.0, Urine at time -15.43: 45.0, DBP at time -15.1: 37.0, SBP at time -15.1: 72.0, DBP at time -14.43: 33.0, SBP at time -14.43: 78.0, Urine at time -14.43: 120.0, DBP at time -13.43: 45.5, SBP at time -13.43: 82.0, Urine at time -13.43: 40.0, DBP at time -12.43: 37.0, SBP at time -12.43: 76.5, Urine at time -12.43: 75.0, Creatinine Blood at time -11.68: 1.4, Platelet Count at time -11.68: 287.0, DBP at time -11.43: 33.5, SBP at time -11.43: 84.5, PO2 at time -10.87: 
147.0, DBP at time -10.43: 45.0, SBP at time -10.43: 100.0, Urine at time -10.43: 80.0, DBP at time -9.43: 50.0, GCS_eye at time -9.43: 4.0, GCS_motor at time -9.43: 6.0, GCS_verbal at time -9.43: 5.0, SBP at time -9.43: 95.0, Urine at time -9.43: 100.0, DBP at time -8.43: 47.0, SBP at time -8.43: 96.0, Urine at time -8.43: 100.0, DBP at time -7.43: 44.0, SBP at time -7.43: 97.0, Urine at time -7.43: 30.0, DBP at time -6.43: 44.0, SBP at time -6.43: 97.0, Urine at time -6.43: 85.0, DBP at time -5.43: 41.0, GCS_eye at time -5.43: 4.0, GCS_motor at time -5.43: 6.0, GCS_verbal at time -5.43: 5.0, SBP at time -5.43: 90.5, Urine at time -5.43: 30.0, DBP at time -4.43: 42.0, SBP at time -4.43: 97.0, Urine at time -4.43: 60.0, DBP at time -3.43: 40.0, SBP at time -3.43: 98.0, Urine at time -3.43: 100.0, PO2 at time -2.73: 144.0, DBP at time -2.43: 41.0, SBP at time -2.43: 95.0, Urine at time -2.43: 100.0, DBP at time -1.43: 38.0, SBP at time -1.43: 89.0, Urine at time -1.43: 60.0, DBP at time -0.43: 45.0, GCS_eye at time -0.43: 4.0, GCS_motor at time -0.43: 6.0, GCS_verbal at time -0.43: 5.0, SBP at time -0.43: 97.0, Urine at time -0.43: 50.0
Now answer the following question: 
Doctors suspect an infection, based on this information and the other information in this text, will the patient be classified as septic tomorrow?
First we need to calculate the SOFA scores given the extracted values. The SOFA scores for the current time are the following: 
The minimum value of GCS_eye is 4.0, GCS_motor is 6.0 and GCS_verbal is 5.0, this produces the sum 15.0 and means the CNS SOFA is 0.
Because minimum MAP is 43.333, max Dopamine is 0, max Dobutamine is 0, max Epinephrine is 0 and max Norepinephrine is 0 with a patient weight of 80 kg, the cardiovascular SOFA is 1.
Given that minimum PO2 is 141.0 and minimum FiO2 is 1 the calculated PAO2FIO2 is 141.0, this means the respiratory SOFA is 3.
Because the minimum Platelet count is 235.0 the coagulation SOFA is 0.
The maximum Bilirubin (Total) is 1.8 leading to a liver SOFA of 1.
Because total Urine output is 1585.0 and maximum creatinine in the blood is 1.4 the renal SOFA is 1.
To summarize: the patient has a total SOFA score of 6.
Now we need to calculate the SOFA scores with forecasted values. The SOFA scores in the future based on the forecasted values are the following: 
The minimum value of GCS_eye will be 4.0, GCS_motor will be 6.0 and GCS_verbal will be 5.0, this produces the sum 15.0 and means the CNS SOFA will be 0.
Because future minimum MAP will be 55.667, future max Dopamine will be 0, future max Dobutamine will be 0, future max Epinephrine will be 0 and future max Norepinephrine will be 0 with a patient weight of 80 kg, the cardiovascular SOFA will be 1.
Given that minimum PO2 will be 141.0 and minimum FiO2 will be 1 the forecasted PAO2FIO2 will be 141.0, this means the respiratory SOFA will be 3.
Because the Platelet count will be 295.0 the coagulation SOFA is going to be 0.
The maximum Bilirubin (Total) will be 1.8 leading to a liver SOFA of 1.
Because Urine output will be 1635.0 and maximum creatinine in the blood will be 1.1 the renal SOFA will be 0.
To summarize: the patient will have a future total SOFA score of 5.
The patient will not develop sepsis in the next 24 hours, because total SOFA changed only by -1 and infection is suspected
The example is now finished. Say "The patient will develop sepsis" in the last sentence if the criteria are met (if total SOFA changed by 2 and infection is suspected).
"""
        final_part = "\nNow answer the following question in the given format: \n"
        ts = torch.zeros((24, 262))
        # Walk the window from the most recent row backwards so that, when the
        # token budget is hit, the oldest rows are the ones left out.
        for information in in_time_datapoints[::-1].itertuples():
            appended_to = None
            if information.TABLE == "TEXT" and information.variable in self.target_note_list and self.include_text_of_same_type:
                textual_information.append("Clinical Note: \n" + information.textvalue)
                appended_to = textual_information
            elif information.variable in stuff_to_use:
                relative_hour = information.hour-current_time+abs(self.offset)
                if self.include_measurements_in_text:
                    measurement_information.append(information.variable +" at time {}: ".format(round(relative_hour, 2)) + str(round(information.value * information.std + information.mean, 3)))
                    appended_to = measurement_information
                # relative_hour is negative, so the floor is used as a negative
                # row index into the 24-row tensor (-1 is the most recent hour).
                hour_bucket = math.floor(relative_hour)
                index = var_to_ind[information.variable]-1
                ts[hour_bucket, index] = information.value
                ts[hour_bucket, index+131] = 1
            if appended_to is None:
                # Prompt text unchanged since the last check: nothing to re-tokenize.
                continue
            # NOTE(review): the budget check excludes the worked example and the
            # section headers, so the final prompt is longer than max_len tokens.
            current_text = string + ("\n".join(textual_information[::-1]) if self.include_text_of_same_type else "") + (", ".join(measurement_information[::-1]) if self.include_measurements_in_text else "") + final_part
            if len(self.tokenizer(current_text)["input_ids"]) > self.max_len:
                # Over budget: drop the entry just added but keep scanning,
                # because a shorter, older entry may still fit.
                appended_to.pop()

        current_text = pre_final_part + string + (textual_information_pre + "\n".join(textual_information[::-1]) if self.include_text_of_same_type else "") + ( measurement_information_pre+", ".join(measurement_information[::-1]) if self.include_measurements_in_text else "") + final_part
        target_text = datapoint.textvalue

        complete_text = current_text + target_text
        item = self.tokenizer(complete_text+self.tokenizer.eos_token, return_length=True)
        if self.split_target:
            # Treat everything up to and including the first "?" of the target
            # as part of the prompt, so it is masked out of the loss as well.
            target_start = target_text.split("?")[0]+"?"
            item2 = self.tokenizer(current_text+target_start)
        else:
            item2 = self.tokenizer(current_text)
        target_labels = copy.deepcopy(item.input_ids)
        # -100 is ignored by the cross-entropy loss: only the completion is trained on.
        for i in range(len(item2.input_ids)):
            target_labels[i]=-100
        item["labels"]=target_labels
        item["size_diff"]=len(item.input_ids) - len(item2.input_ids)
        item["ts"] = ts

        if self.offset < -10:
            item["previousnote"] = same_datapoint_24hours_before.iloc[0].textvalue
        if self.caching:
            self.cache[idx]=item
        return item

    def __len__(self):
        """Return the number of sampled target rows."""
        return len(self.target_texts)







def setup():
    """Initialise the default torch.distributed process group with the NCCL backend."""
    dist.init_process_group(backend="nccl")

def cleanup():
    """Destroy the default torch.distributed process group."""
    dist.destroy_process_group()


def setup_model(model_name, use_lora=False):
    """Load a Llama-family causal LM and its tokenizer.

    The model is loaded in bfloat16 with FlashAttention-2 and
    ``device_map="auto"``. The tokenizer's pad token is set to its EOS token,
    and the tokenizer is also stored in the module-level ``tokenizer`` global,
    which the collate functions read.

    Args:
        model_name: Hugging Face model id or local path.
        use_lora: when true, wrap the model with LoRA adapters on all linear
            layers (r=16, alpha=16, dropout=0.1). The default (False) returns
            the plain model, which is what this function did before the
            parameter existed.

    Returns:
        ``(model, tokenizer)``.
    """
    global tokenizer
    from peft import LoraConfig, get_peft_model
    from transformers import AutoTokenizer, LlamaForCausalLM

    model = LlamaForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    if not use_lora:
        return model, tokenizer

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules="all-linear",
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, config), tokenizer


def get_date_of_run():
    """Return (and print) the current local time as a run identifier.

    The format is ``YYYY-MM-DD-HH:MM:SS_AM/PM`` with a 12-hour clock,
    e.g. ``2022-05-07-08:31:12_PM``.
    """
    date_of_run = f"{datetime.now():%Y-%m-%d-%I:%M:%S_%p}"
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """Convert a byte count to gibibytes, rounded to four decimal places."""
    bytes_per_gib = 1 << 30  # 1024 ** 3
    return round(item / bytes_per_gib, ndigits=4)


def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    """Run one training epoch with gradient accumulation.

    Args:
        args: namespace providing ``batch_multiplier``, the number of batches
            whose gradients are accumulated before each optimizer step.
        model: causal LM called with ``input_ids``, ``labels`` and
            ``attention_mask``; its output must expose ``["loss"]``.
        rank: process rank; only rank 0 shows a progress bar and prints.
        world_size: unused, kept for interface compatibility.
        train_loader: iterable of batches (dicts of tensors) with a length.
        optimizer: optimizer stepped every ``args.batch_multiplier`` batches.
        epoch: epoch number, used only in the log line.
        sampler: unused, kept for interface compatibility.

    Returns:
        0-dim CUDA tensor with the mean per-sample loss of the epoch. Each
        batch loss (unscaled by ``batch_multiplier``) is weighted by its batch
        size. The previous implementation divided by the number of dict keys
        in the batch instead of the number of samples.
    """
    model.train()
    # [sum of batch_loss * batch_size, number of samples seen]
    fsdp_loss = torch.zeros(2).cuda()

    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    counter=0
    optimizer.zero_grad()
    for batch in train_loader:
        input_ids = batch["input_ids"].cuda()
        output = model(input_ids=input_ids, labels=batch["labels"].cuda(), attention_mask=batch["attention_mask"].cuda())
        batch_loss = output["loss"]
        # Scale so that the accumulated gradient is the mean over the window.
        (batch_loss / args.batch_multiplier).backward()
        counter+=1
        if counter%args.batch_multiplier == 0:
            optimizer.step()
            optimizer.zero_grad()
        batch_size = input_ids.size(0)
        fsdp_loss[0] += batch_loss.item() * batch_size
        fsdp_loss[1] += batch_size
        if rank==0:
            inner_pbar.update(1)

    # Apply gradients of a trailing, incomplete accumulation window instead of
    # silently discarding them.
    if counter%args.batch_multiplier != 0:
        optimizer.step()
        optimizer.zero_grad()

    train_accuracy = fsdp_loss[0] / fsdp_loss[1]

    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy


def validation(model, rank, world_size, val_loader):
    """Compute the mean per-sample loss of ``model`` over ``val_loader``.

    Args:
        model: causal LM called with ``input_ids``, ``labels`` and
            ``attention_mask``; its output must expose ``["loss"]``.
        rank: process rank; only rank 0 shows a progress bar and prints.
        world_size: unused, kept for interface compatibility.
        val_loader: iterable of batches (dicts of tensors) with a length.

    Returns:
        0-dim CUDA tensor with the validation loss. Each batch loss is
        weighted by its batch size. The previous implementation divided by
        the number of dict keys in the batch instead of the number of samples.
    """
    model.eval()
    # [sum of batch_loss * batch_size, number of samples seen]
    fsdp_loss = torch.zeros(2).cuda()
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].cuda()
            output = model(input_ids=input_ids, labels=batch["labels"].cuda(), attention_mask=batch["attention_mask"].cuda())
            batch_size = input_ids.size(0)
            fsdp_loss[0] += output["loss"].item() * batch_size
            fsdp_loss[1] += batch_size
            if rank==0:
                inner_pbar.update(1)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

def test(model, rank, world_size, val_loader, filename):
    """Generate completions for every batch and pickle them next to ``filename``.

    Each collected record is a dict with the decoded ``input``, generated
    ``output`` (tokens after the prompt) and ``gold`` text. Decoding uses the
    module-level ``tokenizer``.

    Args:
        model: model exposing ``generate``.
        rank: process rank; only rank 0 shows a progress bar.
        world_size: unused, kept for interface compatibility.
        val_loader: iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels``, where ``labels`` is a sequence of per-sample token
            id sequences (as produced by ``collate_fn_partial``).
        filename: path whose extension is replaced by ``.generations`` to form
            the output file name.

    Returns:
        Always 0.
    """
    model.eval()
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Testing Epoch"
        )
    collection = list()
    with torch.no_grad():
        for batch in val_loader:
            # Leave generous head-room beyond the longest gold completion.
            max_len = max([len(x) for x in batch["labels"]])
            output = model.generate(input_ids=batch["input_ids"].cuda(), attention_mask=batch["attention_mask"].cuda(), max_new_tokens=max_len+1000)
            for inputs, outputs, gold in zip(batch["input_ids"], output, batch["labels"]):
                # generate() returns prompt + continuation; strip the prompt.
                print("generated_stuff", outputs[len(inputs):])
                result = {"input": tokenizer.decode(inputs), "output": tokenizer.decode(outputs[len(inputs):]), "gold": tokenizer.decode(gold) }
                print(result)
                collection.append(result)
            if rank==0:
                inner_pbar.update(1)
    if rank == 0:
        inner_pbar.close()
    # splitext only strips an extension of the final path component, so a
    # filename without extension or a dotted directory name stays intact.
    with open(os.path.splitext(filename)[0]+".generations", "wb") as outfile:
        pickle.dump(collection, outfile)
    return 0


def _numeric_token_spans(tokens):
    """Locate runs of number-like tokens.

    A run starts at a token that begins with a digit and continues through
    tokens that begin with a digit or a ".". Each closed run is returned as
    ``[start, end]`` with ``end`` the index of the first token after the run;
    a run still open at the end of ``tokens`` is returned as ``[start]`` only.
    """
    digits = re.compile(r"\d+")
    dot = re.compile(r"\.")
    spans = []
    start = None
    for i, token in enumerate(tokens):
        if start is not None:
            if digits.match(token) or dot.match(token):
                continue
            spans.append([start, i])
            start = None
        elif digits.match(token):
            start = i
    if start is not None:
        spans.append([start])
    return spans


def force_decoding(model, rank, world_size, val_loader, filename, tokenizer):
    """Generate from gold prefixes cut at selected numeric positions.

    For the first sample of every batch, the gold completion is split at the
    ends of its number-like token runs. For each hard-coded run index, the
    gold tokens up to that run's end are appended to the prompt and the model
    generates the continuation up to (roughly) the end of the next run. All
    results are pickled next to ``filename`` with the extension ``.forced``.

    Args:
        model: model exposing ``generate``.
        rank: process rank; only rank 0 shows a progress bar.
        world_size: unused, kept for interface compatibility.
        val_loader: iterable of batches with ``input_ids``, ``attention_mask``
            and ``labels`` (a sequence of per-sample token id lists).
            NOTE(review): the concatenation with a single gold prefix assumes
            a batch size of 1 - confirm against the caller.
        filename: path whose extension is replaced by ``.forced``.
        tokenizer: tokenizer used for token conversion and decoding.

    Returns:
        Always 0.

    Raises:
        IndexError: if the gold completion has fewer numeric runs than the
            hard-coded indices require, or if its last run is unterminated.
    """
    model.eval()
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Testing Epoch"
        )
    collection = list()
    with torch.no_grad():
        for batch in val_loader:
            tokens = tokenizer.convert_ids_to_tokens(batch["labels"][0])
            spans = _numeric_token_spans(tokens)
            # Indices of the numeric runs at which decoding is forced.
            for idx in [31, 38, 46, 48, 50, 53, 54, -2]:
                prefix_end = spans[idx][1]
                next_end = spans[idx+1][1]
                gold_parts = torch.tensor(batch["labels"][0][:prefix_end], dtype=torch.long).unsqueeze(0)
                gold_mask = torch.ones((1,prefix_end), dtype=torch.long)
                output = model.generate(input_ids=torch.cat([batch["input_ids"], gold_parts], dim=1).cuda(), attention_mask=torch.cat([batch["attention_mask"], gold_mask], dim=1).cuda(), max_new_tokens=(next_end-prefix_end+5))
                for inputs, outputs, gold in zip(batch["input_ids"], output, batch["labels"]):
                    result = {"level":idx, "input": tokenizer.decode(torch.cat([inputs,torch.tensor(gold[:prefix_end])], dim=0)), "output": tokenizer.decode(outputs[len(inputs)+prefix_end:]), "gold": tokenizer.decode(gold[prefix_end:next_end+5]) }
                    print(result)
                    collection.append(result)
            if rank==0:
                inner_pbar.update(1)
    if rank == 0:
        inner_pbar.close()
    # splitext only strips an extension of the final path component, so a
    # filename without extension or a dotted directory name stays intact.
    with open(os.path.splitext(filename)[0]+".forced", "wb") as outfile:
        pickle.dump(collection, outfile)
    return 0

def collate_fn_partial(data):
    """Collate prompts only, left-padded, keeping completions as raw id lists.

    data: list of per-sample encodings providing "input_ids" (prompt followed
          by completion), "length" (a sequence whose first entry is the total
          token count) and "size_diff" (number of completion tokens).

    Returns a dict with "input_ids" / "attention_mask" (left-padded prompt
    tensors), "labels" (list of completion id lists) and "lengths" (prompt
    lengths). Padding uses the id of the module-level tokenizer's EOS token.
    """
    global tokenizer
    batch_size = len(data)
    # Only the first "length" entry of each sample is considered.
    prompt_lengths = [
        total - sample["size_diff"] for sample in data for total in sample["length"][:1]
    ]
    max_len = max(prompt_lengths)
    pad_id = tokenizer(tokenizer.eos_token, add_special_tokens=False).input_ids[-1]
    input_ids_batched = torch.full((batch_size, max_len), pad_id, dtype=torch.long)
    attention_masks_batched = torch.zeros((batch_size, max_len), dtype=torch.long)
    labels_batched = []
    for row, sample in enumerate(data):
        ids = sample["input_ids"]
        prompt_ids = ids[:len(ids) - sample["size_diff"]]
        completion_ids = ids[len(prompt_ids):]
        # Right-align the prompt so that generation continues directly after it.
        input_ids_batched[row, -len(prompt_ids):] = torch.tensor(prompt_ids)
        attention_masks_batched[row, -len(prompt_ids):] = 1
        labels_batched.append(completion_ids)
    lengths = torch.stack(
        [torch.tensor(sample["length"]) - torch.tensor(sample["size_diff"]) for sample in data]
    )
    return {"input_ids":input_ids_batched, "attention_mask": attention_masks_batched, "labels": labels_batched, "lengths":lengths}

def collate_fn_full(data):
    """Collate full prompt+completion sequences into left-padded tensors.

    data: list of per-sample encodings providing "input_ids",
          "attention_mask", "labels" (all of the same per-sample length) and
          "length" (a sequence holding the token count).

    Returns a dict with "input_ids", "attention_mask" and "labels" tensors of
    shape (batch, max_len) plus the stacked "lengths". Padding positions get
    the id of the module-level tokenizer's EOS token, attention 0 and label
    -100 (ignored by the loss).
    """
    global tokenizer
    max_len = max([y for x in data for y in x["length"]])
    pad_id = tokenizer(tokenizer.eos_token, add_special_tokens=False).input_ids[-1]
    input_ids_batched = torch.full((len(data), max_len), pad_id, dtype=torch.long)
    labels_batched = torch.full((len(data), max_len), -100, dtype=torch.long)
    attention_masks_batched = torch.zeros((len(data), max_len), dtype=torch.long)
    for i, datapoint in enumerate(data):
        # Right-align every sequence; the padding stays on the left.
        input_ids_batched[i, -len(datapoint["input_ids"]):] = torch.tensor(datapoint["input_ids"])
        attention_masks_batched[i, -len(datapoint["attention_mask"]):] = torch.tensor(datapoint["attention_mask"])
        labels_batched[i, -len(datapoint["labels"]):] = torch.tensor(datapoint["labels"])
    lengths = torch.stack([torch.tensor(x["length"]) for x in data])

    return {"input_ids":input_ids_batched, "attention_mask": attention_masks_batched, "labels": labels_batched, "lengths":lengths }

def fsdp_main(args):
    """Build the model and data loaders, then run the evaluation/training script.

    Runs as a single process (rank 0, world size 1); the FSDP wrapping is
    kept only as an inert string literal. As written, the function loads
    optional weights from ``args.load_path``, runs ``force_decoding`` and
    ``test`` on the test split and then terminates the process with
    ``sys.exit(0)``. The initial validation, checkpointing and epoch loop
    that follow that call are preserved but are not reached.

    Args:
        args: parsed command-line namespace. Reads ``batch_size``,
            ``test_batch_size``, ``lr``, ``gamma``, ``track_memory`` and
            ``load_path``; the training section additionally reads
            ``epochs``, ``run_validation`` and ``save_model``.
    """
    # ``collate_fn_full`` reads the module-level ``tokenizer``. Declaring it
    # global makes the assignment below populate that name instead of creating
    # a function-local shadow. NOTE(review): ``setup_model`` may already set
    # the global itself; this binding is harmless in that case.
    global tokenizer

    #model, tokenizer = setup_model("meta-llama/Llama-3.2-3B")
    #model, tokenizer = setup_model("gradientai/Llama-3-8B-Instruct-Gradient-1048k")

    model, tokenizer = setup_model("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
    # Hard-coded single-process values; the environment lookups are disabled.
    local_rank = 0#int(os.environ['LOCAL_RANK'])
    rank = 0#int(os.environ['RANK'])
    world_size = 1#int(os.environ['WORLD_SIZE'])


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = MimicDataset(tokenizer, "train")
    val_dataset = MimicDataset(tokenizer, "val")
    test_dataset = MimicDataset(tokenizer, "test")
    # Disabled dataset-cache logic, kept as an inert string literal.
    """
    if not train_dataset.cache_exists("train_exp1.cache"):
        train_dataset.precompute_cache()
        train_dataset.save_cache("train_exp1.cache")
    else:
        train_dataset.load_cache("train_exp1.cache")
    if not val_dataset.cache_exists("val_exp1.cache"):
        val_dataset.precompute_cache()
        val_dataset.save_cache("val_exp1.cache")
    else:
        val_dataset.load_cache("val_exp1.cache")
    if not test_dataset.cache_exists("test_exp1.cache"):
        test_dataset.precompute_cache()
        test_dataset.save_cache("test_exp1.cache")
    else:
        test_dataset.load_cache("test_exp1.cache")
    """
    sampler1 = RandomSampler(train_dataset)
    # NOTE: sampler2 is built over val_dataset but is also handed to the
    # test-set loaders below through test_kwargs / force_kwargs.
    sampler2 = SequentialSampler(val_dataset)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}

    force_kwargs = {'batch_size': 1, 'sampler': sampler2}
    # Not applied to any loader while the two update() calls stay disabled.
    cuda_kwargs = {'num_workers': 4,
                    'pin_memory': True,
                    'shuffle': False}
    #train_kwargs.update(cuda_kwargs)
    #test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,collate_fn=collate_fn_full,**train_kwargs)
    val_loader_full = torch.utils.data.DataLoader(val_dataset, collate_fn=collate_fn_full,**test_kwargs)
    val_loader_partial = torch.utils.data.DataLoader(val_dataset, collate_fn=collate_fn_partial,**test_kwargs)

    test_loader_full = torch.utils.data.DataLoader(test_dataset, collate_fn=collate_fn_full,**test_kwargs)
    test_loader_partial = torch.utils.data.DataLoader(test_dataset, collate_fn=collate_fn_partial,**test_kwargs)
    test_loader_force = torch.utils.data.DataLoader(test_dataset, collate_fn=collate_fn_partial,**force_kwargs)
    #t5_auto_wrap_policy = functools.partial(
    #    transformer_auto_wrap_policy,
    #    transformer_layer_cls={
    #        T5Block,
    #    },
    #)
    # Only consumed by the disabled FSDP wrapping further down. A
    # size_based_auto_wrap_policy(min_num_params=20000) variant used to be
    # assigned first and was immediately overwritten by this one.
    my_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            LlamaDecoderLayer,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    #torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and dist.is_nccl_available()
    )
    print("bf16 ready", bf16_ready)
    mp_policy = None # defaults to fp32
    fpSixteen = MixedPrecision(
        param_dtype=torch.float16,
        # Gradient communication precision.
         reduce_dtype=torch.float16,
        # Buffer precision.
        buffer_dtype=torch.float16,
    )
    model.cuda()
    #fpSixteen=None
    # Disabled FSDP wrapping / activation-checkpoint wrapper, kept as an
    # inert string literal.
    """
    model = FSDP(model,
        auto_wrap_policy=my_auto_wrap_policy,
        mixed_precision=fpSixteen,
        #sharding_strategy=sharding_strategy,
        cpu_offload=CPUOffload(offload_params=True),
        device_id=torch.cuda.current_device(), use_orig_params=True)
    #print(model)

    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    """
    check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
    #apply_activation_checkpointing(
    #       model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    #)
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    # Multiplies the learning rate by gamma on every scheduler.step().
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "pretrain_textualisation-reduced-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    if args.load_path:
        # Deserialize onto the CPU first: load_state_dict copies the values
        # into the model's existing (CUDA) parameters, so the checkpoint does
        # not need a second full copy on the GPU, and it no longer depends on
        # the device it was saved from.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(args.load_path, map_location="cpu"))
    #validat(model, rank, world_size, val_loader_partial)
    force_decoding(model, rank, world_size, test_loader_force, "from_scratch_1shot_deepseek.pt", tokenizer)
    test(model, rank, world_size, test_loader_partial, "from_scratch_1shot_deepseek.pt")
    # NOTE(review): the process ends here, so the validation / checkpoint /
    # training code below is unreachable. Left in place on the assumption that
    # this is a temporary evaluation-only switch; confirm before removing.
    sys.exit(0)
    curr_val_loss = validation(model, rank, world_size, val_loader_full)
    cpu_state = model.state_dict()

    epoch = "start"
    if rank == 0:
        currEpoch = (
             "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
        )
        save_name = file_save_name + "-" + time_of_run + "-" + str(args.lr) + "-" + currEpoch

        torch.save(cpu_state, save_name)
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader_full)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        # Checkpoint only when the validation loss improved on the best so far.
        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                #save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                save_name = file_save_name + "-" + time_of_run + "-" + str(args.lr) + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")


if __name__ == '__main__':
    # Command-line settings. Help texts use argparse's %(default)s so the
    # documented default can never drift from the real one.
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=.000002, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: %(default)s)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: %(default)s)')
    # The next three flags are store_false with default=True: the feature is
    # on unless the flag is passed.
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='pass this flag to disable gpu memory tracking (enabled by default)')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='pass this flag to disable running the validation (enabled by default)')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='pass this flag to disable saving the current model (enabled by default)')
    parser.add_argument('--batch_multiplier', type=int, default=1,
                        help='batch multiplier (default: %(default)s)')
    parser.add_argument("--load_path", type=str, default="",
                        help='path of a saved state dict to load (default: "", nothing is loaded)')

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)



