import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
import os
from tqdm import tqdm
import torch.nn.functional as F
from data.standard_data import StandardData, ValStandardData,LabelTestStandardData,LabelTrainStandardData,MultiTaskDataset
import argparse
import glob
from model.MTANet import MTANNet
import matplotlib.pyplot as plt



class PCGrad:
    """Optimizer wrapper that applies PCGrad-style gradient surgery.

    Intended call order for every optimisation step:
    ``zero_grad()``, then ``store_grad(loss_k)`` once per task,
    then ``pc_backward()`` and finally ``step()``.

    NOTE: ``PCGrad`` is defined a second time further down in this file; the
    later definition is the one bound to the name at runtime.
    """

    def __init__(self, optimizer):
        """Wrap ``optimizer``; its param groups define which tensors are handled."""
        self._optim = optimizer
        # Created up front so the other methods never hit a missing attribute
        # when they are called before the first zero_grad().
        self._grads = []
        self._params = self._trainable_params()

    def _trainable_params(self):
        """Return every parameter of the wrapped optimizer that requires grad."""
        return [p
                for group in self._optim.param_groups
                for p in group['params']
                if p.requires_grad]

    def zero_grad(self):
        """Forget all stored task gradients and clear ``.grad`` on the parameters."""
        self._grads = []
        self._params = self._trainable_params()
        for p in self._params:
            p.grad = None

    def store_grad(self, loss):
        """Back-propagate ``loss`` and keep a copy of the resulting gradients.

        The graph is retained so that further task losses computed from the
        same forward pass can be back-propagated afterwards.
        """
        # Reset first so the copy below holds this task's gradient only.
        self._optim.zero_grad()
        loss.backward(retain_graph=True)
        self._grads.append(
            [None if p.grad is None else p.grad.clone() for p in self._params]
        )

    def pc_backward(self):
        """Combine the stored task gradients and write the sum to ``p.grad``.

        Every task gradient g_i is compared with each *other* task's original
        gradient g_j (in task order; the published algorithm uses a random
        order). When their inner product over the parameters both tasks
        touch is negative, g_i is replaced by
        ``g_i - (g_i . g_j) / ||g_j||^2 * g_j``, using one coefficient for the
        whole gradient vector rather than one per parameter tensor.
        """
        projected = []
        for i, own in enumerate(self._grads):
            g_i = list(own)
            for j, g_j in enumerate(self._grads):
                if j == i:
                    continue
                shared = [(a, b) for a, b in zip(g_i, g_j)
                          if a is not None and b is not None]
                if not shared:
                    continue
                dot = sum(torch.sum(a * b) for a, b in shared)
                if dot < 0:
                    # Projection: remove the component of g_i that points
                    # against g_j.
                    norm_sq = sum(torch.sum(b * b) for _, b in shared)
                    coef = dot / (norm_sq + 1e-12)
                    g_i = [a - coef * b if a is not None and b is not None else a
                           for a, b in zip(g_i, g_j)]
            projected.append(g_i)

        for k, p in enumerate(self._params):
            total = torch.zeros_like(p)
            for task_grads in projected:
                if task_grads[k] is not None:
                    total += task_grads[k]
            p.grad = total

    def step(self):
        """Apply the wrapped optimizer's update using the current ``.grad`` values."""
        self._optim.step()



# Training hyper-parameters shared by every run in this script.
BATCH_SIZE = 64
LEARNING_RATE = 5e-3
WEIGHT_DECAY = 1e-4
LABEL_SMOOTHING = 0.1
EPOCHS_PRETRAIN = 600
# Not referenced in the visible part of this file.
EPOCHS_FINETUNE = 10

# Command-line interface: every option is mandatory. --choose_feature and
# --num_classes are parsed but not read in the visible part of this file.
parser = argparse.ArgumentParser(description='ResNet18 Training Multiple CSVs')
for _flag, _kind, _help in (
    ('--train_dir', str, 'Directory containing training CSVs'),
    ('--choose_feature', str, 'Choose_feature'),
    ('--val_csv', str, 'Path to validation CSV'),
    ('--log_file', str, 'Path to save training results'),
    ('--num_classes', int, 'Number of output classes'),
):
    parser.add_argument(_flag, type=_kind, required=True, help=_help)
args = parser.parse_args()

# Use the first GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


    
# The validation set is built once and reused for every training CSV.
val_dataset = MultiTaskDataset(csv_file=args.val_csv, train=False)
# NOTE(review): drop_last=True discards the final partial batch, so up to
# BATCH_SIZE - 1 validation samples are never evaluated - confirm this is intended.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8,drop_last=True)

# os.path.dirname() returns '' for a bare filename and os.makedirs('') raises
# FileNotFoundError, so only create the directory when the path names one.
log_dir = os.path.dirname(args.log_file)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

class PCGrad:
    """Optimizer wrapper that applies PCGrad-style gradient surgery.

    Intended call order for every optimisation step:
    ``zero_grad()``, then ``store_grad(loss_k)`` once per task,
    then ``pc_backward()`` and finally ``step()``.

    NOTE: ``PCGrad`` is also defined earlier in this file; this later
    definition replaces the earlier one when the module is executed.
    """

    def __init__(self, optimizer):
        """Wrap ``optimizer``; its param groups define which tensors are handled."""
        self._optim = optimizer
        # Created up front so the other methods never hit a missing attribute
        # when they are called before the first zero_grad().
        self._grads = []
        self._params = self._trainable_params()

    def _trainable_params(self):
        """Return every parameter of the wrapped optimizer that requires grad."""
        return [p
                for group in self._optim.param_groups
                for p in group['params']
                if p.requires_grad]

    def zero_grad(self):
        """Forget all stored task gradients and clear ``.grad`` on the parameters."""
        self._grads = []
        self._params = self._trainable_params()
        for p in self._params:
            p.grad = None

    def store_grad(self, loss):
        """Back-propagate ``loss`` and keep a copy of the resulting gradients.

        The graph is retained so that further task losses computed from the
        same forward pass can be back-propagated afterwards.
        """
        # Reset first so the copy below holds this task's gradient only.
        self._optim.zero_grad()
        loss.backward(retain_graph=True)
        self._grads.append(
            [None if p.grad is None else p.grad.clone() for p in self._params]
        )

    def pc_backward(self):
        """Combine the stored task gradients and write the sum to ``p.grad``.

        Every task gradient g_i is compared with each *other* task's original
        gradient g_j (in task order; the published algorithm uses a random
        order). When their inner product over the parameters both tasks
        touch is negative, g_i is replaced by
        ``g_i - (g_i . g_j) / ||g_j||^2 * g_j``, using one coefficient for the
        whole gradient vector rather than one per parameter tensor.
        """
        projected = []
        for i, own in enumerate(self._grads):
            g_i = list(own)
            for j, g_j in enumerate(self._grads):
                if j == i:
                    continue
                shared = [(a, b) for a, b in zip(g_i, g_j)
                          if a is not None and b is not None]
                if not shared:
                    continue
                dot = sum(torch.sum(a * b) for a, b in shared)
                if dot < 0:
                    # Projection: remove the component of g_i that points
                    # against g_j.
                    norm_sq = sum(torch.sum(b * b) for _, b in shared)
                    coef = dot / (norm_sq + 1e-12)
                    g_i = [a - coef * b if a is not None and b is not None else a
                           for a, b in zip(g_i, g_j)]
            projected.append(g_i)

        for k, p in enumerate(self._params):
            total = torch.zeros_like(p)
            for task_grads in projected:
                if task_grads[k] is not None:
                    total += task_grads[k]
            p.grad = total

    def step(self):
        """Apply the wrapped optimizer's update using the current ``.grad`` values."""
        self._optim.step()


class MultiTaskResNet(nn.Module):
    """ResNet-50 feature extractor shared by one MLP classification head per task.

    Args:
        num_classes_list: number of output classes of each task; one head is
            built per entry, in the given order.
        hidden_dim: width of the hidden layer inside every head.
        dropout: dropout probability applied after the hidden ReLU.
    """

    def __init__(self, num_classes_list, hidden_dim=512, dropout=0.5):
        super().__init__()
        base_model = models.resnet50(pretrained=False)
        # Keep every child module except the last one (the original FC layer).
        self.backbone = nn.Sequential(*list(base_model.children())[:-1])
        in_features = base_model.fc.in_features

        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, num_classes)
            )
            for num_classes in num_classes_list
        ])

    def forward(self, x):
        """Return a list holding one ``[B, num_classes_i]`` logits tensor per head."""
        # Drop the two trailing singleton dims: [B, 2048, 1, 1] -> [B, 2048].
        x = self.backbone(x).squeeze(-1).squeeze(-1)
        return [head(x) for head in self.heads]
def multi_task_accuracy(outputs, targets):
    """Compute top-1 accuracy for every task and their unweighted mean.

    Args:
        outputs: sequence with one logits tensor per task; predictions are
            taken with ``argmax`` over dimension 1.
        targets: tensor whose column ``i`` holds the labels of task ``i``.

    Returns:
        ``(mean_accuracy, per_task_accuracies)`` where the second item is a
        list of floats in task order.
    """
    batch_size = targets.size(0)
    per_task = []
    for task_idx, logits in enumerate(outputs):
        hits = (logits.argmax(1) == targets[:, task_idx]).sum().item()
        per_task.append(hits / batch_size)
    mean_acc = sum(per_task) / len(per_task)
    return mean_acc, per_task

def loss_fn(outputs, labels):
    """Label-smoothed cross-entropy for a single task.

    The target distribution puts ``1 - LABEL_SMOOTHING`` on the true class and
    spreads ``LABEL_SMOOTHING`` evenly over the remaining classes.

    Args:
        outputs: logits; classes are indexed along dimension 1.
        labels: integer class indices, one per row of ``outputs``.

    Returns:
        Scalar tensor: the loss averaged over the rows of ``outputs``.
    """
    log_probs = F.log_softmax(outputs, dim=1)
    n_classes = outputs.size(1)
    off_value = LABEL_SMOOTHING / (n_classes - 1)
    # The soft targets are constants; keep them out of the autograd graph.
    with torch.no_grad():
        soft_targets = torch.full_like(log_probs, off_value)
        soft_targets.scatter_(1, labels.data.unsqueeze(1), 1 - LABEL_SMOOTHING)
    per_sample = (-soft_targets * log_probs).sum(dim=1)
    return per_sample.mean()


def train_epoch(model, dataloader, optimizer):
    """Train ``model`` for one pass over ``dataloader`` on the summed-task loss.

    Each batch is moved to the module-level ``device``, the loss from
    ``multi_task_loss`` is back-propagated and ``optimizer`` takes one step.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by batch size, where the
        accuracy is the task-averaged value from ``multi_task_accuracy``.
    """
    model.train()
    loss_sum = 0
    acc_sum = 0
    seen = 0

    progress = tqdm(dataloader, desc="Training", leave=False)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = multi_task_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        loss_sum += loss.item() * n
        batch_acc = multi_task_accuracy(outputs, labels)[0]
        acc_sum += batch_acc * n
        seen += n

    return loss_sum / seen, acc_sum / seen


def train_epoch_pcgrad(model, dataloader, optimizer, loss_fn):
    """Train ``model`` for one pass over ``dataloader`` using ``PCGrad``.

    For every batch, ``loss_fn(outputs[i], labels[:, i])`` is evaluated per
    task and handed to ``PCGrad.store_grad``; ``pc_backward`` then combines
    the task gradients before the optimizer step.

    Args:
        model: network returning one logits tensor per task.
        dataloader: yields ``(images, labels)`` batches.
        optimizer: optimizer wrapped by ``PCGrad``.
        loss_fn: per-task loss callable ``(logits, labels) -> scalar tensor``.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by batch size; the loss
        reported is ``multi_task_loss`` and is used for logging only.
    """
    model.train()
    loss_sum = 0
    acc_sum = 0
    seen = 0

    surgeon = PCGrad(optimizer)

    progress = tqdm(dataloader, desc="Training with PCGrad", leave=False)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        surgeon.zero_grad()
        for task_idx, logits in enumerate(outputs):
            surgeon.store_grad(loss_fn(logits, labels[:, task_idx]))
        surgeon.pc_backward()
        surgeon.step()

        # Bookkeeping only; no gradients are needed from here on.
        with torch.no_grad():
            n = labels.size(0)
            loss_sum += multi_task_loss(outputs, labels).item() * n
            batch_acc = multi_task_accuracy(outputs, labels)[0]
            acc_sum += batch_acc * n
            seen += n

    return loss_sum / seen, acc_sum / seen
def multi_task_loss(outputs, targets, smoothing=LABEL_SMOOTHING):
    """Average the label-smoothed cross-entropy over all tasks.

    Args:
        outputs: sequence with one logits tensor per task (classes along
            dimension 1).
        targets: tensor whose column ``i`` holds the labels of task ``i``.
        smoothing: probability mass moved from the true class and spread
            evenly over the other classes.

    Returns:
        Scalar tensor: the mean over tasks of each task's batch-mean loss.
    """
    accumulated = 0
    for task_idx, logits in enumerate(outputs):
        task_labels = targets[:, task_idx]
        log_probs = F.log_softmax(logits, dim=1)
        off_value = smoothing / (logits.size(1) - 1)
        # The soft targets are constants; keep them out of the autograd graph.
        with torch.no_grad():
            soft_targets = torch.full_like(log_probs, off_value)
            soft_targets.scatter_(1, task_labels.unsqueeze(1), 1 - smoothing)
        accumulated += (-soft_targets * log_probs).sum(dim=1).mean()
    return accumulated / len(outputs)

def validate_epoch(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Returns:
        ``(mean_loss, mean_avg_accuracy, mean_task_accuracies)``. Every value
        is weighted by batch size; the last item is a list with one accuracy
        per task.
    """
    model.eval()
    loss_sum = 0
    seen = 0
    weighted_avg = []    # task-averaged accuracy * batch size, per batch
    weighted_tasks = []  # per-task accuracy * batch size, per batch

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validating", leave=False)
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            n = labels.size(0)
            loss_sum += multi_task_loss(outputs, labels).item() * n
            avg_acc, task_accs = multi_task_accuracy(outputs, labels)
            weighted_avg.append(avg_acc * n)
            weighted_tasks.append([acc * n for acc in task_accs])
            seen += n

    mean_avg_acc = sum(weighted_avg) / seen
    # zip(*...) regroups the per-batch rows into one column per task.
    mean_task_accs = [sum(column) / seen for column in zip(*weighted_tasks)]

    return loss_sum / seen, mean_avg_acc, mean_task_accs
import re

def natural_key(s):
    """Return a sort key that compares runs of digits in ``s`` numerically.

    The string is split into alternating text and digit chunks; digit chunks
    become ``int`` and text chunks are lower-cased, so ``"a2"`` sorts before
    ``"a10"`` and letter case is ignored.
    """
    # str.isdecimal() is used instead of isdigit(): characters such as '²'
    # satisfy isdigit() but are rejected by int(), which raised ValueError.
    return [int(text) if text.isdecimal() else text.lower()
            for text in re.split(r'(\d+)', s)]

# Training CSVs are processed in natural (digit-aware) filename order.
csv_list = sorted(glob.glob(os.path.join(args.train_dir, '*.csv')), key=natural_key)

# True: log the best validation result of each run.
# False: log the mean of the last (up to) five validation results.
SAVE_BEST = True

# Task order matches the accuracy columns of the log header written below.
task_names = ["Floor_Color", "Wall_Color", "Object_Color", "Object_Size", "Object_Type", "Azimuth"]
num_classes_per_task = [10, 10, 10, 8, 4, 15]
num_tasks = len(task_names)

with open(args.log_file, 'w') as log_f:
    log_f.write("CSV_Name, Val_Accuracy, Floor_Color_Acc, Wall_Color_Acc, Object_Color_Acc, Object_Size_Acc, Object_Type_Acc, Azimuth_Acc\n")

    for csv_path in csv_list:
        csv_name = os.path.basename(csv_path)
        print(f"\n=== Training on: {csv_name} ===")
        train_dataset = MultiTaskDataset(csv_file=csv_path, train=True)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

        # A fresh model and optimizer are created for every CSV.
        model = MTANNet(num_tasks=num_tasks, num_classes_per_task=num_classes_per_task).to(device)
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

        best_val_acc = 0.0
        # Start from zeros (not an empty list) so the log row keeps all of its
        # columns even if no validation result ever exceeds 0.
        best_task_accs = [0.0] * num_tasks
        val_epochs = []           # 1-based epoch number of every validation run
        val_avg_acc_list = []
        val_task_acc_lists = []

        for epoch in range(EPOCHS_PRETRAIN):
            train_loss, train_acc = train_epoch(model, train_loader, optimizer)
            # Validate every 50th epoch, starting with epoch 250.
            if (epoch + 1) % 50 == 0 and epoch >= 200:
                val_loss, val_avg_acc, val_task_accs = validate_epoch(model, val_loader)
                print(f"[Val @ Epoch {epoch+1}] Loss: {val_loss:.4f} Avg Acc: {val_avg_acc:.4f}, Task Accs: {[f'{a:.4f}' for a in val_task_accs]}")

                val_epochs.append(epoch + 1)
                val_avg_acc_list.append(val_avg_acc)
                val_task_acc_lists.append(val_task_accs.copy())

                if val_avg_acc > best_val_acc:
                    best_val_acc = val_avg_acc
                    best_task_accs = val_task_accs.copy()

        # Mean over the last (up to) five validation runs; fall back to zeros
        # when no validation ran instead of dividing by zero.
        recent_avgs = val_avg_acc_list[-5:]
        recent_tasks = val_task_acc_lists[-5:]
        if recent_avgs:
            recent_avg_acc = sum(recent_avgs) / len(recent_avgs)
            recent_task_accs = [sum(task[i] for task in recent_tasks) / len(recent_tasks)
                                for i in range(num_tasks)]
        else:
            recent_avg_acc = 0.0
            recent_task_accs = [0.0] * num_tasks

        if SAVE_BEST:
            save_acc = best_val_acc
            save_task_accs = best_task_accs
        else:
            save_acc = recent_avg_acc
            save_task_accs = recent_task_accs

        log_f.write(f"{csv_name}, {save_acc:.4f}, " + ', '.join([f"{a:.4f}" for a in save_task_accs]) + "\n")
        log_f.flush()

        # Plot against the real epoch numbers so the 'Epoch' axis label is
        # accurate (previously the x values were validation indices 0..k-1).
        plt.figure(figsize=(10, 6))
        plt.plot(val_epochs, val_avg_acc_list, label='Avg_Acc')

        for i, task_name in enumerate(task_names):
            task_acc = [epoch_accs[i] for epoch_accs in val_task_acc_lists]
            plt.plot(val_epochs, task_acc, label=task_name)

        plt.xlabel('Epoch')
        plt.ylabel('Validation Accuracy')
        plt.title(f'Validation Accuracy Curves: {csv_name}')
        plt.legend()
        plt.grid(True)

        plot_save_path = os.path.join(os.path.dirname(args.log_file), f'{csv_name}_val_acc_curve.png')
        plt.savefig(plot_save_path)
        plt.close()