import os
import json

import math
import torch
import torch.nn.functional as F
import wandb
from dataclasses import replace
from .config import save_config, TrainingParams
from tqdm import tqdm
from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import random
import time
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import asdict


def loss_fn(outputs, targets):
    """Return the mean cross-entropy between logits and integer class targets.

    Args:
        outputs: Logits whose last dimension indexes the classes; all leading
            dimensions are flattened into one.
        targets: Class indices with one entry per flattened row of ``outputs``.
            They are cast to ``torch.long`` before the loss is computed.

    Returns:
        A scalar tensor holding the loss (default ``mean`` reduction).
    """
    # reshape (rather than view) also accepts non-contiguous tensors such as
    # transposed or sliced logits; it only copies when a view is impossible.
    flat_logits = outputs.reshape(-1, outputs.size(-1))
    flat_labels = targets.to(torch.long).reshape(-1)
    return F.cross_entropy(flat_logits, flat_labels)

def accuracy_fn(outputs, targets):
    """Count how many arg-max predictions match the targets.

    Args:
        outputs: Logits whose last dimension indexes the classes.
        targets: Class indices, one per prediction; cast to ``torch.long``.

    Returns:
        A ``(correct, total)`` tuple of Python ints: the number of matching
        predictions and the number of targets compared.
    """
    # reshape (rather than view) so that non-contiguous targets, e.g. a
    # transposed or strided slice, are flattened instead of raising.
    preds = outputs.argmax(-1).reshape(-1)
    truth = targets.to(torch.long).reshape(-1)
    correct = (preds == truth).sum().item()
    total = truth.size(0)
    return correct, total

class Tester:
    """Measures final-position accuracy of a model on cached held-out batches.

    Held-out batches are sampled the first time they are needed and kept on
    the instance, so every later evaluation reuses the same data.
    """

    def __init__(self) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Cache of held-out batches; filled lazily by get_test_batches.
        self.test_batches = []

    def get_test_batches(self, task, config, context_length):
        """Return the cached held-out batches, sampling more until there are enough.

        Batches are drawn with ``task.sample_batch(..., hold_out=True,
        unshuffled=True)`` until the cache holds at least
        ``config.evaluation_size / config.batch_size`` of them.

        Note: the cache is not keyed on ``task``, ``config`` or
        ``context_length``; batches sampled by an earlier call are returned
        as-is even if a later call passes different arguments.
        """
        while len(self.test_batches) < config.evaluation_size / config.batch_size:
            self.test_batches.append(task.sample_batch(
                batch_size=config.batch_size,
                k_shots=config.k_shots,
                max_length=context_length,
                hold_out=True,
                unshuffled=True
            ))
        return self.test_batches

    def evaluate(self, model, task, config):
        """Evaluate ``model`` on the held-out batches.

        Accuracy is computed on the last sequence position only. The model is
        put into eval mode for the duration of the call (so layers such as
        dropout behave deterministically) and its previous mode is restored
        afterwards, even if evaluation raises.

        Args:
            model: Callable mapping an input tensor to logits, or to a
                sequence whose first element is the logits. Must expose
                ``model.config.block_size``.
            task: Provides ``sample_batch``, ``pad_token_id`` and ``summarize``.
            config: Reads ``evaluation_size``, ``batch_size``, ``k_shots`` and
                ``leftpad``.

        Returns:
            ``(accuracy, summary)`` where ``accuracy`` is the fraction of
            correct final-position predictions (0.0 when nothing was
            evaluated) and ``summary`` is whatever ``task.summarize`` returns.
        """
        correct_t, total = 0, 0
        predictions = []

        # Only toggle the mode when the model reports that it is training, so
        # plain callables without train()/eval() keep working.
        was_training = getattr(model, "training", False)
        if was_training:
            model.eval()

        try:
            with torch.no_grad():
                batches = self.get_test_batches(task, config, model.config.block_size)
                for test_batch in batches:
                    # Non-tensor entries are dropped; only tensors are moved.
                    test_batch = {k: v.to(self.device)
                            for k, v in test_batch.items() if isinstance(v, torch.Tensor)}

                    # Apply padding
                    if config.leftpad:
                        context_length = model.config.block_size
                        pad_token_id = task.pad_token_id
                        input_padding = context_length - test_batch["inputs"].size(1)
                        target_padding = context_length - test_batch["targets"].size(1)
                        padded_inputs = F.pad(test_batch["inputs"], (input_padding, 0),
                                value=pad_token_id)
                        padded_targets = F.pad(test_batch["targets"], (target_padding, 0),
                                value=pad_token_id)
                    else:
                        padded_inputs = test_batch["inputs"]
                        padded_targets = test_batch["targets"]

                    # Forward pass
                    outputs = model(padded_inputs)
                    if not isinstance(outputs, torch.Tensor):
                        outputs = outputs[0]

                    # One argmax serves both the accuracy and the summary.
                    batch_predictions = outputs.argmax(-1)

                    # Compute accuracy only on last positions
                    last_preds = batch_predictions[:, -1:].reshape(-1)
                    last_truth = padded_targets[:, -1:].to(torch.long).reshape(-1)
                    correct_t += (last_preds == last_truth).sum().item()
                    total += last_truth.size(0)

                    # Accumulate predictions for use in the summary
                    predictions.append(batch_predictions.to('cpu'))
        finally:
            if was_training:
                model.train()

        # Guard against an empty evaluation set instead of dividing by zero.
        accuracy = correct_t / total if total else 0.0
        summary = task.summarize(batches, predictions, accuracy)
        return accuracy, summary

class Trainer:
    """Trains a model on batches sampled from a task, with periodic evaluation.

    Evaluation is delegated to a ``Tester`` instance; progress plots, model
    checkpoints and a metadata file are written under the output directory.
    """

    def __init__(self) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tester = Tester()

    def train_step(self, model, task, config, optimiser, scheduler):
        """Sample one batch, take one optimiser step and return the loss.

        Args:
            model: Callable mapping an input tensor to logits (or to a
                sequence whose first element is the logits); must expose
                ``model.config.block_size``.
            task: Provides ``sample_batch`` and ``pad_token_id``.
            config: Reads ``batch_size``, ``k_shots``, ``leftpad`` and
                ``final_token_only``.
            optimiser: Optimiser stepped once per call.
            scheduler: Learning-rate scheduler stepped once per call, after
                the optimiser.

        Returns:
            The training loss for this batch as a Python float.
        """
        context_length = model.config.block_size

        train_batch = task.sample_batch(
            batch_size=config.batch_size,
            k_shots=config.k_shots,
            max_length=context_length,
        )

        # Move input data to the correct device (non-tensor entries are dropped)
        train_batch = {k: v.to(self.device)
                for k, v in train_batch.items() if isinstance(v, torch.Tensor)}

        # Left-pad inputs and targets up to the context length
        if config.leftpad:
            pad_token_id = task.pad_token_id
            input_padding = context_length - train_batch["inputs"].size(1)
            target_padding = context_length - train_batch["targets"].size(1)
            padded_inputs = F.pad(train_batch["inputs"], (input_padding, 0), value=pad_token_id)
            padded_targets = F.pad(train_batch["targets"], (target_padding, 0), value=pad_token_id)
        else:
            padded_inputs = train_batch["inputs"]
            padded_targets = train_batch["targets"]

        # Forward pass
        outputs = model(padded_inputs)
        if not isinstance(outputs, torch.Tensor):
            outputs = outputs[0]

        if config.final_token_only:
            # Only compute loss on final token
            masked_outputs = outputs[:, -1, :].reshape(-1, outputs.size(-1))
            masked_targets = padded_targets[:, -1].reshape(-1)
        else:
            # Keep only positions whose target is not the padding token
            mask = (padded_targets != task.pad_token_id).reshape(-1)
            masked_outputs = outputs.reshape(-1, outputs.size(-1))[mask]
            masked_targets = padded_targets.reshape(-1)[mask]

        loss = loss_fn(masked_outputs, masked_targets)

        # Backward pass
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        scheduler.step()

        return loss.item()

    @staticmethod
    def _build_scheduler(optimiser, config):
        """Build the linear-warmup then cosine-annealing learning-rate schedule."""
        # NOTE(review): T_max is the full n_steps although the cosine phase only
        # starts after lr_warmup_steps, so with a non-zero warmup the schedule
        # looks like it never reaches eta_min by the last step — confirm intended.
        train_scheduler = CosineAnnealingLR(
            optimizer=optimiser,
            T_max=config.n_steps,
            eta_min=0
        )
        warmup_scheduler = LinearLR(optimiser,
                start_factor=1e-4, end_factor=1, total_iters=config.lr_warmup_steps)
        return SequentialLR(optimiser,
                [warmup_scheduler, train_scheduler], [config.lr_warmup_steps])

    @staticmethod
    def _save_loss_plot(losses, output_dir):
        """Write the training-loss curve to ``<output_dir>/progress/train_loss.png``."""
        plt.figure(figsize=(3, 3))
        plt.plot(losses)
        plt.grid(True)
        plt.title('Training Loss')
        # One point per training step, not per epoch.
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.savefig(f'{output_dir}/progress/train_loss.png', bbox_inches='tight', dpi=300)
        plt.close()

    def fit(self, model, task, metadata, config):
        """Train ``model`` for ``config.n_steps`` steps and return the losses.

        Every ``config.evaluation_steps`` steps the model is evaluated, the
        loss curve is re-plotted and, if the evaluation accuracy improved,
        the weights are saved as the best model. When
        ``config.checkpoint_steps`` is truthy, the weights are additionally
        saved every that many steps. Metrics are sent to Weights & Biases
        when ``config.use_wandb`` is set.

        Args:
            model: Module to train; moved to ``self.device``.
            task: Task object passed through to ``train_step`` and the tester.
            metadata: JSON-serialisable object written to ``metadata.json``.
            config: Training configuration (a dataclass instance when
                ``use_wandb`` is set, since it is passed to ``asdict``).

        Returns:
            A list with one training-loss float per step.
        """
        random.seed(config.seed)
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)

        model = model.to(self.device)

        # Setup optimiser and scheduler
        optimiser = torch.optim.AdamW(
            params=model.parameters(),
            lr=config.lr
        )
        scheduler = self._build_scheduler(optimiser, config)

        if config.use_wandb:
            wandb.init(
                project=config.wandb_project,
                entity=config.wandb_entity,
                name=config.wandb_run_name,
                config=asdict(config)
            )

        # Nest the output under the wandb run name when wandb is in use
        output_dir = config.output_dir
        if config.use_wandb:
            output_dir = os.path.join(config.output_dir, wandb.run.name)

        # Create both output directories up front so that checkpoint saving
        # does not depend on an evaluation step having run first.
        os.makedirs(f'{output_dir}/progress', exist_ok=True)
        os.makedirs(f'{output_dir}/models', exist_ok=True)

        # Write metadata file
        self.write_metadata(output_dir, metadata)

        # Training
        losses = []
        eval_accuracy = 0.0
        best_eval_accuracy = 0.0
        pbar = tqdm(range(config.n_steps), desc="Training", unit="it")
        for step in pbar:
            log = {}
            if step % config.evaluation_steps == 0:
                eval_accuracy, summary = self.tester.evaluate(model, task, config)
                if config.task_name == 'navigation':
                    # The navigation task's summary bundles three extra accuracies.
                    summary, success_acc, collision_acc, awareness_acc = summary
                    pbar.write(f'Step: {step}, Collision Acc: {round(collision_acc * 100, 4)}, Success Acc: {round(success_acc * 100, 4)}; {summary}')
                    log |= { "eval_accuracy": eval_accuracy,
                           "success_accuracy(o)": success_acc,
                           "collision_accuracy(x)": collision_acc,
                           "x vs o accuracy": awareness_acc}
                else:
                    pbar.write(f'Step: {step}, Acc: {round(eval_accuracy * 100, 4)}; {summary}')
                    log |= { "eval_accuracy": eval_accuracy }

                self._save_loss_plot(losses, output_dir)

                if eval_accuracy > best_eval_accuracy:
                    best_eval_accuracy = eval_accuracy
                    torch.save(model.state_dict(), f'{output_dir}/models/gpt_umm_mod_best.pt')

            if config.checkpoint_steps and (step % config.checkpoint_steps == 0):
                torch.save(model.state_dict(), f'{output_dir}/models/gpt_umm_mod_{step}.pt')

            train_loss = self.train_step(model, task, config, optimiser, scheduler)
            pbar.set_postfix({'Step': step, 'Train Loss': round(train_loss, 4)})
            # Record each loss exactly once so the list index equals the step.
            losses.append(train_loss)
            log |= {
                "train_loss": train_loss,
                "learning_rate": optimiser.param_groups[0]["lr"],
                "step": step
            }

            if config.use_wandb:
                wandb.log(log)

        if config.use_wandb:
            wandb.finish()

        return losses

    def compute_loss(self, outputs, targets):
        """Return the mean cross-entropy, skipping targets equal to ``-1``.

        Args:
            outputs: Logits whose last dimension indexes the classes.
            targets: Class indices, one per flattened row of ``outputs``;
                entries equal to ``-1`` are ignored by the loss.

        Returns:
            A scalar loss tensor.
        """
        # Follow the logits' device instead of assuming CUDA is available.
        flat_targets = targets.to(outputs.device).to(torch.long).reshape(-1)
        loss = F.cross_entropy(
            outputs.reshape(-1, outputs.size(-1)),
            flat_targets,
            ignore_index=-1
        )
        return loss

    def write_metadata(self, output_dir, metadata, filename="metadata.json"):
        """Write ``metadata`` as indented UTF-8 JSON to ``output_dir/filename``."""
        filepath = os.path.join(output_dir, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

