#!/usr/bin/env python
# coding: utf-8
# %%
import argparse
import os

# %%
from datasets import load_dataset
from teneva import sample
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from tqdm import tqdm

from heads_as_modules import *

import neptune

from peft import get_peft_model, LoraConfig, TaskType

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.utils.data import DataLoader, BatchSampler
from torch.utils.data import IterableDataset, DataLoader
import torch.multiprocessing as mp

import os
import socket

# %%
def batchify(dataset, batch_size):
    """Group the items of an iterable into consecutive lists.

    Args:
        dataset: Any iterable of examples.
        batch_size: Number of examples per emitted list.

    Yields:
        Lists holding ``batch_size`` examples each, in input order. Whatever
        is left over once ``dataset`` is exhausted is yielded as one final,
        shorter list; nothing is yielded for an empty input.
    """
    pending = []
    for item in dataset:
        pending.append(item)
        if len(pending) != batch_size:
            continue
        # Hand out the filled list and start collecting into a fresh one.
        filled, pending = pending, []
        yield filled
    if len(pending) > 0:
        yield pending


class DistributedIterableDataset(IterableDataset):
    """Iterable dataset view that keeps every ``world_size``-th example.

    While iterating over the wrapped ``dataset``, only the examples whose
    zero-based position ``p`` satisfies ``p % world_size == rank`` are
    yielded; all other examples are consumed and dropped.
    """

    def __init__(self, dataset, rank, world_size):
        # Plain attribute storage; no validation is performed here.
        self.world_size = world_size
        self.rank = rank
        self.dataset = dataset

    def __iter__(self):
        # Kept as a generator function so the wrapped dataset is only
        # touched once iteration actually starts.
        yield from (
            sample
            for position, sample in enumerate(self.dataset)
            if position % self.world_size == self.rank
        )

# %%
def shift_batch(batch, dim=None):
    """Build left-shifted copies of ``batch`` along its second axis.

    Args:
        batch: Tensor with at least two dimensions; shifting happens on
            axis 1.
        dim: Number of shifted copies to produce. ``None`` disables
            shifting.

    Returns:
        ``(batch, batch.shape[1])`` unchanged when ``dim`` is ``None``.
        Otherwise ``(copies, kept)`` where ``kept = batch.shape[1] - dim``
        and ``copies[k]`` is ``batch`` rolled left by ``k + 1`` positions
        and truncated to its first ``kept`` columns.
    """
    if dim is None:
        return batch, batch.shape[1]

    kept = batch.shape[1] - dim
    copies = []
    for offset in range(1, dim + 1):
        rolled = torch.roll(batch, -offset, dims=1)
        # Truncation drops the columns that wrapped around from the front.
        copies.append(rolled[:, :kept])
    return copies, kept


# %%
class CombinedModel(nn.Module):
    """Combine a base model with a new output head.

    The base model is called with ``output_hidden_states=True`` and the last
    entry of its ``hidden_states`` is passed to ``new_head``, which is
    expected to return a ``(logits, loss)`` pair.
    """

    def __init__(self, base_model, new_head):
        super().__init__()
        self.base_model = base_model
        self.new_head = new_head

    def forward(self, input_ids, attention_mask=None, targets=None):
        """Run the base model and feed its last hidden state to the head.

        Args:
            input_ids: Token ids forwarded to the base model.
            attention_mask: Optional mask forwarded to the base model.
            targets: Optional targets forwarded to the head unchanged.

        Returns:
            The ``(logits, loss)`` pair produced by ``new_head``.
        """
        outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden_state = outputs.hidden_states[-1]
        logits, loss = self.new_head(last_hidden_state, targets=targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens one at a time.

        Args:
            idx: Tensor of token ids; new tokens are appended along dim 1.
            max_new_tokens: Number of sampling steps.
            temperature: Divisor applied to the last-step logits.
            top_k: When given, only the ``top_k`` largest logits stay
                eligible for sampling.

        Returns:
            ``idx`` extended by ``max_new_tokens`` sampled token ids.
        """
        # Kept as an instance attribute for compatibility; the value is so
        # large that the cropping below does not trigger in practice.
        self.block_size = 999999999

        for _ in range(max_new_tokens):
            # If the sequence context is growing too long we must crop it:
            if idx.size(1) <= self.block_size:
                idx_cond = idx
            else:
                idx_cond = idx[:, -self.block_size:]

            # Forward the model to get the logits for the index in the sequence:
            logits, _ = self(idx_cond)

            # Pluck the logits at the final step and scale by temperature:
            logits = logits[:, -1, :] / temperature

            # Optionally crop the logits to only the top k options:
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Apply softmax to convert logits to (normalized) probabilities:
            probs = F.softmax(logits, dim=-1)

            # Sample from the distribution:
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append sampled index to the running sequence and continue:
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def generate_mh(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Extend ``idx`` in chunks using the multi-token head.

        Runs ``max_new_tokens // new_head.d`` forward passes and appends the
        first element returned by each pass to ``idx`` along dim 1.

        Args:
            idx: Tensor of token ids used as the running context.
            max_new_tokens: Upper bound on the number of appended tokens.
            temperature: Accepted for signature parity with ``generate``;
                not used here.
            top_k: Accepted for signature parity with ``generate``; not
                used here.

        Returns:
            ``idx`` with the head outputs appended.
        """
        # Both attributes are still set because the original did so and
        # outside code may read them.
        self.block_size = 999999999
        self.d = self.new_head.d

        for _ in range(max_new_tokens // self.d):
            # The whole context is always used: the former cropping branch
            # was disabled with `if True:` and could never run.
            # NOTE(review): presumably the head returns sampled token ids
            # (not logits) when `targets` is None -- confirm in the head.
            idxs_next, _ = self(idx)
            idx = torch.cat((idx, idxs_next), dim=1)

        return idx


# Default prompt for the generation demo in `show`: the header and docstring
# of a Python function, left for the model to complete.
DEFAULT_PROMPT = """
def hello_world():
    '''
    This function just prints hello world, nothing more.
    '''
"""

# %%
def show(model, tokenizer, epoch, loss, i, prompt=DEFAULT_PROMPT, is_default=False):
    """Print a generation demo for the current model state.

    Args:
        model: Object exposing ``parameters()``, ``generate`` and
            ``generate_mh``.
        tokenizer: Callable tokenizer that also provides ``decode``.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        loss: Object whose ``item()`` value is printed.
        i: Iteration counter to print.
        prompt: Text the generation starts from.
        is_default: Use ``generate`` when true, ``generate_mh`` otherwise.

    Returns:
        The decoded text of the first generated sequence.
    """
    # Place the prompt on whatever device the model's parameters live on.
    target_device = next(model.parameters()).device
    encoded = tokenizer(prompt, return_tensors='pt')
    prompt_ids = encoded['input_ids'].to(target_device)

    sampler = model.generate if is_default else model.generate_mh
    generated = sampler(prompt_ids, 200)
    text_out = tokenizer.decode(generated[0], skip_special_tokens=True)

    print(f"\n---> Epoch {epoch + 1}, iter {i}, Loss: {loss.item()}")
    print(f'Demo for generate method: {text_out}')

    return text_out


# %%
def init(
        device,
        dim=2,
        is_default=False,
        checkpoint_path=None,
        rank=None,
        add_lora=False,
        lora_r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        lora_target_modules=None
    ):
    """Build the tokenizer and the combined model used for training.

    Loads the "Daoguang/PyCodeGPT" tokenizer and causal-LM weights, attaches
    either a ``DefaultHead`` or a ``CPHead``, optionally restores a
    checkpoint, and wraps the base model with a LoRA adapter.

    Args:
        device: Device the combined model is moved to and onto which the
            checkpoint is mapped.
        dim: Passed to ``CPHead`` as ``n_tokens``; unused for the default
            head.
        is_default: Select ``DefaultHead`` instead of ``CPHead``.
        checkpoint_path: Optional path to a state dict that is loaded into
            the combined model before LoRA is applied.
        rank: Passed to ``CPHead`` as ``r``; unused for the default head.
        add_lora: Must be true; see Raises.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha.
        lora_dropout: LoRA dropout probability.
        lora_target_modules: Names of the sub-modules that receive LoRA
            adapters. ``None`` selects ``["q_proj", "v_proj"]``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``add_lora`` is false. This is only detected after
            the model (and checkpoint, if any) has been loaded.
    """
    # Imported locally: the module never imports `warnings`, so the warning
    # below used to fail with NameError instead of warning.
    import warnings

    # Resolve the default here rather than in the signature so that no
    # mutable list is shared between calls.
    if lora_target_modules is None:
        lora_target_modules = ["q_proj", "v_proj"]

    tokenizer = AutoTokenizer.from_pretrained("Daoguang/PyCodeGPT")
    model_base = AutoModelForCausalLM.from_pretrained("Daoguang/PyCodeGPT")

    input_dim = model_base.config.hidden_size
    output_dim = tokenizer.vocab_size

    if is_default:
        head = DefaultHead(input_dim, output_dim)
    else:
        head = CPHead(input_dim, output_dim, n_tokens=dim, r=rank)

    model = CombinedModel(model_base, head).to(device)

    if checkpoint_path and add_lora:
        warnings.warn("Loading a checkpoint before applying LoRA adapter. This may affect the model's performance.", RuntimeWarning)
    if checkpoint_path:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded model from checkpoint: {checkpoint_path}")

    if add_lora:
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            # Honour the caller's choice; this argument was previously
            # ignored in favour of a hard-coded list.
            target_modules=list(lora_target_modules),
            lora_dropout=lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM
        )
        model_base = get_peft_model(model_base, lora_config)
        model.base_model = model_base
        model_base.print_trainable_parameters()
    else:
        raise ValueError('add_lora must be True!')

    return model, tokenizer


# %%
def train(
        rank,
        world_size,
        is_default=False,
        dim=None,
        neptune_project=None,
        run_name=None,
        checkpoint_path=None,
        freeze_base_model=True,
        rank_cp=None,
        add_lora=False,
        lora_r=16,
        lora_alpha=64,
        lora_dropout=0.1,
        master_addr=None,
        master_port=None
    ):
    """Per-process training entry point for ``mp.spawn``.

    Joins the NCCL process group, streams Python files from
    "codeparrot/github-code", and trains the combined model wrapped in
    ``DistributedDataParallel`` with gradient accumulation. Rank 0 also logs
    to Neptune, saves checkpoints and prints a generation demo at the end.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes.
        is_default: Train the default head instead of the CP head.
        dim: Number of shifted target copies for the CP head; ``None``
            means 2. Must be ``None`` when ``is_default`` is true.
        neptune_project: Neptune project passed to ``neptune.init_run``.
        run_name: Neptune run name; also part of the checkpoint file name.
        checkpoint_path: Optional checkpoint forwarded to ``init``.
        freeze_base_model: Freeze the base model when LoRA is not used.
        rank_cp: Rank forwarded to ``init`` for the CP head.
        add_lora: Forwarded to ``init``.
        lora_r: Forwarded to ``init``.
        lora_alpha: Forwarded to ``init``.
        lora_dropout: Forwarded to ``init``.
        master_addr: Value for ``MASTER_ADDR``; ``None`` keeps the inherited
            environment value.
        master_port: Value for ``MASTER_PORT``; ``None`` keeps the inherited
            environment value.

    Raises:
        ValueError: If both ``is_default`` and ``dim`` are given.
    """
    # Validate before any process group or Neptune run is created, so a bad
    # argument combination leaves nothing behind.
    if is_default and dim is not None:
        raise ValueError('Only is_default, or dim, not both!')
    if dim is None:
        dim = 2

    if master_addr is not None:
        os.environ['MASTER_ADDR'] = master_addr
    if master_port is not None:
        os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    logger = None
    try:
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

        if rank == 0:
            logger = neptune.init_run(
                project=neptune_project,
                git_ref=False,
                name=run_name
            )

        batch_size = 5
        max_seq_length = 2048
        grad_acc_steps = 5
        learning_rate = 1e-4

        # Log parameters to Neptune once, with `dim` already resolved.
        if logger is not None:
            logger['parameters'] = {
                'batch_size': batch_size,
                'grad_acc_steps': grad_acc_steps,
                'dim': dim,
                'is_default': is_default,
                'freeze_base_model': freeze_base_model,
                'rank': rank_cp,
                'learning_rate': learning_rate,
                'max_seq_length': max_seq_length,
                'add_lora': add_lora,
                'lora_r': lora_r,
                'lora_alpha': lora_alpha,
                'lora_dropout': lora_dropout,
            }

        ds = load_dataset("codeparrot/github-code", streaming=True, split="train", data_dir='../.gh_code_data', cache_dir="../.hf_cache")
        ds = ds.shuffle(buffer_size=1_000, seed=42)
        ds = ds.filter(lambda x: x['language'] == 'Python')

        # Each process keeps only the examples assigned to its rank.
        distributed_ds = DistributedIterableDataset(ds, rank, world_size)
        train_loader = DataLoader(distributed_ds, batch_size=batch_size)

        model, tokenizer = init(
            device,
            dim,
            is_default=is_default,
            checkpoint_path=checkpoint_path,
            rank=rank_cp,
            add_lora=add_lora,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout
        )

        # Freeze BEFORE wrapping in DDP: DDP registers its gradient
        # reduction hooks at construction time, so the set of trainable
        # parameters must not change afterwards.
        if freeze_base_model and not add_lora:
            for param in model.base_model.parameters():
                param.requires_grad = False
        elif not add_lora:
            print("Base model is not frozen and will be fine-tuned.")

        model = DDP(model, device_ids=[rank])

        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        if rank == 0:
            print('\n\n>>> Start training ...')

        i = 0
        loss = None
        for epoch in tqdm(range(10)):
            ds.set_epoch(epoch)
            optimizer.zero_grad()

            # Losses collected since the last Neptune report (rank 0 only).
            window_losses = []
            for batch in train_loader:
                inputs = tokenizer(batch['code'], return_tensors='pt', padding=True, truncation=True, max_length=max_seq_length)
                if is_default:
                    labels = inputs['input_ids'][:, 1:].to(device)
                    seq_len = labels.shape[1]
                else:
                    labels, seq_len = shift_batch(inputs['input_ids'], dim=dim)
                    labels = [elem.to(device) for elem in labels]
                xs, attn_mask = inputs['input_ids'], inputs['attention_mask']
                # Trim the inputs so they line up with the shifted targets.
                xs = xs[:, :seq_len]
                attn_mask = attn_mask[:, :seq_len]

                # Forward pass
                _, loss = model(xs.to(device), attention_mask=attn_mask.to(device), targets=labels)

                # Backward pass; DDP averages the gradients across ranks
                # here, so no extra all-reduce is needed before the step.
                loss.backward()

                if i % grad_acc_steps == grad_acc_steps - 1:
                    optimizer.step()
                    optimizer.zero_grad()

                if rank == 0:
                    window_losses.append(loss.item())
                i += 1

                if i % 100 == 0 and rank == 0:
                    # Report the mean over the window; previously the mean
                    # was computed and thrown away.
                    if window_losses:
                        logger['train/loss'].append(sum(window_losses) / len(window_losses))
                    window_losses = []

                if i % 2_000 == 0 and rank == 0:
                    os.makedirs('logs', exist_ok=True)
                    checkpoint_filename = f'logs/PyCodeGPT_{run_name}.pt'
                    torch.save(model.module.state_dict(), checkpoint_filename)
                    logger['train/checkpoint'].upload(checkpoint_filename)

        if rank == 0:
            # `loss` stays None when the loader never produced a batch.
            if loss is not None:
                show(model.module, tokenizer, epoch, loss, i, is_default=is_default)
            print("\n\n+++ Training complete!\n\n")
    finally:
        # Stop the Neptune run even when training fails.
        if logger is not None:
            logger.stop()

    dist.destroy_process_group()



# %%
def main():
    """Parse the command line and spawn one training process per GPU.

    The ``--gpus`` value is exported as ``CUDA_VISIBLE_DEVICES`` and its
    number of entries becomes the world size. The rendezvous address and
    port are exported as ``MASTER_ADDR``/``MASTER_PORT`` and also passed to
    each ``train`` process.
    """
    option_specs = (
        ('--gpus', dict(type=str, default='0', help='Comma-separated list of GPU ids to use')),
        ('--dim', dict(type=int, default=None, help='dimension of head (default: 2)')),
        ('--head', dict(type=str, default='default', help='head type, subject of the later extension')),
        ('--checkpoint', dict(type=str, default=None, help='path to the checkpoint file to load the model from')),
        ('--unfreeze-base', dict(action='store_true', help='unfreeze the base model for fine-tuning')),
        ('--rank', dict(type=int, default=None, help='rank for CP head (ignored for default head)')),
        ('--add_lora', dict(action='store_true', help='add LoRA adapter for fine-tuning')),
        ('--lora_r', dict(type=int, default=8, help='rank of LoRA adapter')),
        ('--lora_alpha', dict(type=int, default=32, help='alpha parameter for LoRA')),
        ('--lora_dropout', dict(type=float, default=0.1, help='dropout probability for LoRA')),
    )
    cli = argparse.ArgumentParser(description="Run training script with specified GPU.")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Converting each id to int rejects a malformed --gpus value up front;
    # only the count of ids is used afterwards.
    gpu_ids = [int(token) for token in opts.gpus.split(',')]
    n_procs = len(gpu_ids)
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpus
    use_default_head = opts.head == 'default'

    # Rendezvous endpoint, exported here and handed to every worker.
    rendezvous_host = socket.gethostbyname(socket.gethostname())
    rendezvous_port = '12355'  # any free port works
    os.environ['MASTER_ADDR'] = rendezvous_host
    os.environ['MASTER_PORT'] = rendezvous_port

    # The run name encodes the options that distinguish one run from another.
    run_label = f"gpus_{opts.gpus}_dim_{opts.dim}_head_{opts.head}_unfreeze_{opts.unfreeze_base}_rank_{opts.rank}_lora_{opts.add_lora}_lorar_{opts.lora_r}"

    project = "a-wernon/llmtelora"

    # Positional arguments for `train`, following the rank that mp.spawn
    # supplies as the first argument.
    worker_args = (
        n_procs,
        use_default_head,
        opts.dim,
        project,
        run_label,
        opts.checkpoint,
        not opts.unfreeze_base,
        opts.rank,
        opts.add_lora,
        opts.lora_r,
        opts.lora_alpha,
        opts.lora_dropout,
        rendezvous_host,
        rendezvous_port,
    )
    mp.spawn(train, args=worker_args, nprocs=n_procs, join=True)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()